diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b0090356..649aebe80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,21 @@ endif() # variables set by vendor.py set(DUCKDB_INCLUDE_DIRS + src/duckdb/extension/core_functions/include + src/duckdb/extension/icu/include + src/duckdb/extension/icu/third_party/icu/common + src/duckdb/extension/icu/third_party/icu/i18n + src/duckdb/extension/json/include + src/duckdb/extension/parquet/include src/duckdb/src/include + src/duckdb/third_party/brotli/common + src/duckdb/third_party/brotli/common + src/duckdb/third_party/brotli/dec + src/duckdb/third_party/brotli/dec + src/duckdb/third_party/brotli/enc + src/duckdb/third_party/brotli/enc + src/duckdb/third_party/brotli/include + src/duckdb/third_party/brotli/include src/duckdb/third_party/concurrentqueue src/duckdb/third_party/fast_float src/duckdb/third_party/fastpforlib @@ -49,45 +63,30 @@ set(DUCKDB_INCLUDE_DIRS src/duckdb/third_party/jaro_winkler src/duckdb/third_party/jaro_winkler/details src/duckdb/third_party/lz4 - src/duckdb/third_party/brotli/include - src/duckdb/third_party/brotli/common - src/duckdb/third_party/brotli/dec - src/duckdb/third_party/brotli/enc + src/duckdb/third_party/lz4 + src/duckdb/third_party/mbedtls + src/duckdb/third_party/mbedtls/include src/duckdb/third_party/mbedtls/include src/duckdb/third_party/mbedtls/library src/duckdb/third_party/miniz + src/duckdb/third_party/parquet src/duckdb/third_party/pcg src/duckdb/third_party/pdqsort src/duckdb/third_party/re2 src/duckdb/third_party/ska_sort src/duckdb/third_party/skiplist + src/duckdb/third_party/snappy src/duckdb/third_party/tdigest + src/duckdb/third_party/thrift src/duckdb/third_party/utf8proc src/duckdb/third_party/utf8proc/include src/duckdb/third_party/vergesort src/duckdb/third_party/yyjson/include src/duckdb/third_party/zstd/include - src/duckdb/extension/core_functions/include - src/duckdb/extension/parquet/include - src/duckdb/third_party/parquet - 
src/duckdb/third_party/thrift - src/duckdb/third_party/lz4 - src/duckdb/third_party/brotli/include - src/duckdb/third_party/brotli/common - src/duckdb/third_party/brotli/dec - src/duckdb/third_party/brotli/enc - src/duckdb/third_party/snappy - src/duckdb/third_party/mbedtls - src/duckdb/third_party/mbedtls/include - src/duckdb/third_party/zstd/include - src/duckdb/extension/icu/include - src/duckdb/extension/icu/third_party/icu/common - src/duckdb/extension/icu/third_party/icu/i18n - src/duckdb/extension/json/include) + src/duckdb/third_party/zstd/include) set(JEMALLOC_INCLUDE_DIRS - src/duckdb/extension/jemalloc/include - src/duckdb/extension/jemalloc/jemalloc/include) + src/duckdb/third_party/jemalloc/include) set(DUCKDB_DEFINITIONS -DDUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED @@ -96,31 +95,186 @@ set(DUCKDB_DEFINITIONS -DDUCKDB_EXTENSION_JSON_LINKED) set(DUCKDB_SRC_FILES - src/duckdb/ub_src_catalog.cpp - src/duckdb/ub_src_catalog_catalog_entry.cpp - src/duckdb/ub_src_catalog_catalog_entry_dependency.cpp - src/duckdb/ub_src_catalog_default.cpp + src/duckdb/extension/core_functions/core_functions_extension.cpp + src/duckdb/extension/core_functions/function_list.cpp + src/duckdb/extension/core_functions/lambda_functions.cpp + src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp + src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp + src/duckdb/extension/core_functions/scalar/math/numeric.cpp + src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp + src/duckdb/extension/icu/icu-current.cpp + src/duckdb/extension/icu/icu-dateadd.cpp + src/duckdb/extension/icu/icu-datefunc.cpp + src/duckdb/extension/icu/icu-datepart.cpp + src/duckdb/extension/icu/icu-datesub.cpp + src/duckdb/extension/icu/icu-datetrunc.cpp + src/duckdb/extension/icu/icu-list-range.cpp + src/duckdb/extension/icu/icu-makedate.cpp + src/duckdb/extension/icu/icu-strptime.cpp + src/duckdb/extension/icu/icu-table-range.cpp + src/duckdb/extension/icu/icu-timebucket.cpp + 
src/duckdb/extension/icu/icu-timezone.cpp + src/duckdb/extension/icu/icu_extension.cpp + src/duckdb/extension/icu/third_party/icu/common/bytestriebuilder.cpp + src/duckdb/extension/icu/third_party/icu/common/loadednormalizer2impl.cpp + src/duckdb/extension/icu/third_party/icu/common/localebuilder.cpp + src/duckdb/extension/icu/third_party/icu/common/locdistance.cpp + src/duckdb/extension/icu/third_party/icu/common/loclikelysubtags.cpp + src/duckdb/extension/icu/third_party/icu/common/locmap.cpp + src/duckdb/extension/icu/third_party/icu/common/putil.cpp + src/duckdb/extension/icu/third_party/icu/common/static_unicode_sets.cpp + src/duckdb/extension/icu/third_party/icu/common/ubidi_props.cpp + src/duckdb/extension/icu/third_party/icu/common/ubiditransform.cpp + src/duckdb/extension/icu/third_party/icu/common/ucase.cpp + src/duckdb/extension/icu/third_party/icu/common/ucharstriebuilder.cpp + src/duckdb/extension/icu/third_party/icu/common/ucptrie.cpp + src/duckdb/extension/icu/third_party/icu/common/uhash.cpp + src/duckdb/extension/icu/third_party/icu/common/uloc.cpp + src/duckdb/extension/icu/third_party/icu/common/umapfile.cpp + src/duckdb/extension/icu/third_party/icu/common/umutablecptrie.cpp + src/duckdb/extension/icu/third_party/icu/common/unames.cpp + src/duckdb/extension/icu/third_party/icu/common/unifiedcache.cpp + src/duckdb/extension/icu/third_party/icu/common/unormcmp.cpp + src/duckdb/extension/icu/third_party/icu/common/uresbund.cpp + src/duckdb/extension/icu/third_party/icu/common/uresdata.cpp + src/duckdb/extension/icu/third_party/icu/common/uscript_props.cpp + src/duckdb/extension/icu/third_party/icu/common/ustrcase.cpp + src/duckdb/extension/icu/third_party/icu/common/utrie.cpp + src/duckdb/extension/icu/third_party/icu/common/utrie2.cpp + src/duckdb/extension/icu/third_party/icu/common/utrie2_builder.cpp + src/duckdb/extension/icu/third_party/icu/common/uvector.cpp + src/duckdb/extension/icu/third_party/icu/common/wintz.cpp + 
src/duckdb/extension/icu/third_party/icu/i18n/buddhcal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/calendar.cpp + src/duckdb/extension/icu/third_party/icu/i18n/cecal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/chnsecal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/choicfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/collationroot.cpp + src/duckdb/extension/icu/third_party/icu/i18n/compactdecimalformat.cpp + src/duckdb/extension/icu/third_party/icu/i18n/coptccal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/currpinf.cpp + src/duckdb/extension/icu/third_party/icu/i18n/dayperiodrules.cpp + src/duckdb/extension/icu/third_party/icu/i18n/decNumber.cpp + src/duckdb/extension/icu/third_party/icu/i18n/decimfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-bignum-dtoa.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-bignum.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-cached-powers.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-double-to-string.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-fast-dtoa.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-string-to-double.cpp + src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-strtod.cpp + src/duckdb/extension/icu/third_party/icu/i18n/dtfmtsym.cpp + src/duckdb/extension/icu/third_party/icu/i18n/dtitvfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/dtptngen.cpp + src/duckdb/extension/icu/third_party/icu/i18n/ethpccal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/fmtable.cpp + src/duckdb/extension/icu/third_party/icu/i18n/formattedval_sbimpl.cpp + src/duckdb/extension/icu/third_party/icu/i18n/gregocal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/gregoimp.cpp + src/duckdb/extension/icu/third_party/icu/i18n/hebrwcal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/indiancal.cpp + 
src/duckdb/extension/icu/third_party/icu/i18n/islamcal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/iso8601cal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/japancal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/measfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/measunit_extra.cpp + src/duckdb/extension/icu/third_party/icu/i18n/msgfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/nfrs.cpp + src/duckdb/extension/icu/third_party/icu/i18n/nfrule.cpp + src/duckdb/extension/icu/third_party/icu/i18n/nfsubs.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_asformat.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_capi.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_compact.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_currencysymbols.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_decimalquantity.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_fluent.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_formatimpl.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_grouping.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_integerwidth.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_longnames.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_mapper.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_multiplier.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_output.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_patternmodifier.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_patternstring.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_rounding.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_scientific.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_simple.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_skeletons.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_symbolswrapper.cpp + 
src/duckdb/extension/icu/third_party/icu/i18n/number_usageprefs.cpp + src/duckdb/extension/icu/third_party/icu/i18n/number_utils.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_affixes.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_compositions.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_currency.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_decimal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_impl.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_parsednumber.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_scientific.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_symbols.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numparse_validators.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numrange_capi.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numrange_fluent.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numrange_impl.cpp + src/duckdb/extension/icu/third_party/icu/i18n/numsys.cpp + src/duckdb/extension/icu/third_party/icu/i18n/persncal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/pluralranges.cpp + src/duckdb/extension/icu/third_party/icu/i18n/plurfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/plurrule.cpp + src/duckdb/extension/icu/third_party/icu/i18n/quantityformatter.cpp + src/duckdb/extension/icu/third_party/icu/i18n/rbnf.cpp + src/duckdb/extension/icu/third_party/icu/i18n/smpdtfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/string_segment.cpp + src/duckdb/extension/icu/third_party/icu/i18n/taiwncal.cpp + src/duckdb/extension/icu/third_party/icu/i18n/timezone.cpp + src/duckdb/extension/icu/third_party/icu/i18n/tmutfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/tzgnames.cpp + src/duckdb/extension/icu/third_party/icu/i18n/ucol.cpp + src/duckdb/extension/icu/third_party/icu/i18n/ucol_res.cpp + 
src/duckdb/extension/icu/third_party/icu/i18n/units_complexconverter.cpp + src/duckdb/extension/icu/third_party/icu/i18n/units_data.cpp + src/duckdb/extension/icu/third_party/icu/i18n/units_router.cpp + src/duckdb/extension/icu/third_party/icu/i18n/upluralrules.cpp + src/duckdb/extension/icu/third_party/icu/i18n/vtzone.cpp + src/duckdb/extension/icu/third_party/icu/i18n/windtfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/winnmfmt.cpp + src/duckdb/extension/icu/third_party/icu/i18n/wintzimpl.cpp + src/duckdb/extension/icu/third_party/icu/stubdata/stubdata.cpp + src/duckdb/extension/json/json_common.cpp + src/duckdb/extension/json/json_deserializer.cpp + src/duckdb/extension/json/json_enums.cpp + src/duckdb/extension/json/json_extension.cpp + src/duckdb/extension/json/json_functions.cpp + src/duckdb/extension/json/json_multi_file_info.cpp + src/duckdb/extension/json/json_reader.cpp + src/duckdb/extension/json/json_scan.cpp + src/duckdb/extension/json/json_serializer.cpp + src/duckdb/extension/json/serialize_json.cpp + src/duckdb/extension/parquet/column_reader.cpp + src/duckdb/extension/parquet/column_writer.cpp + src/duckdb/extension/parquet/parquet_column_schema.cpp + src/duckdb/extension/parquet/parquet_crypto.cpp + src/duckdb/extension/parquet/parquet_extension.cpp + src/duckdb/extension/parquet/parquet_field_id.cpp + src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp + src/duckdb/extension/parquet/parquet_float16.cpp + src/duckdb/extension/parquet/parquet_geometry.cpp + src/duckdb/extension/parquet/parquet_metadata.cpp + src/duckdb/extension/parquet/parquet_multi_file_info.cpp + src/duckdb/extension/parquet/parquet_prefetch_cost_model.cpp + src/duckdb/extension/parquet/parquet_reader.cpp + src/duckdb/extension/parquet/parquet_shredding.cpp + src/duckdb/extension/parquet/parquet_statistics.cpp + src/duckdb/extension/parquet/parquet_timestamp.cpp + src/duckdb/extension/parquet/parquet_writer.cpp + 
src/duckdb/extension/parquet/serialize_parquet.cpp + src/duckdb/extension/parquet/zstd_file_system.cpp + src/duckdb/generated_extension_loader_package_build.cpp src/duckdb/src/common/adbc/adbc.cpp - src/duckdb/ub_src_common_adbc_nanoarrow.cpp - src/duckdb/ub_src_common.cpp - src/duckdb/ub_src_common_arrow_appender.cpp - src/duckdb/ub_src_common_arrow.cpp src/duckdb/src/common/crypto/md5.cpp - src/duckdb/ub_src_common_enums.cpp - src/duckdb/ub_src_common_exception.cpp - src/duckdb/ub_src_common_multi_file.cpp - src/duckdb/ub_src_common_operator.cpp - src/duckdb/ub_src_common_progress_bar.cpp - src/duckdb/ub_src_common_row_operations.cpp - src/duckdb/ub_src_common_serializer.cpp - src/duckdb/ub_src_common_sort.cpp - src/duckdb/ub_src_common_tree_renderer.cpp - src/duckdb/ub_src_common_types.cpp - src/duckdb/ub_src_common_types_column.cpp - src/duckdb/ub_src_common_types_row.cpp - src/duckdb/ub_src_common_types_variant.cpp + src/duckdb/src/common/util/util_parsed_expression.cpp src/duckdb/src/common/value_operations/comparison_operations.cpp - src/duckdb/ub_src_common_vector.cpp src/duckdb/src/common/vector_operations/boolean_operators.cpp src/duckdb/src/common/vector_operations/comparison_operators.cpp src/duckdb/src/common/vector_operations/generators.cpp @@ -131,64 +285,21 @@ set(DUCKDB_SRC_FILES src/duckdb/src/common/vector_operations/vector_copy.cpp src/duckdb/src/common/vector_operations/vector_hash.cpp src/duckdb/src/common/vector_operations/vector_storage.cpp - src/duckdb/ub_src_execution.cpp - src/duckdb/ub_src_execution_expression_executor.cpp - src/duckdb/ub_src_execution_index_art.cpp - src/duckdb/ub_src_execution_index.cpp - src/duckdb/ub_src_execution_nested_loop_join.cpp - src/duckdb/ub_src_execution_operator_aggregate.cpp - src/duckdb/ub_src_execution_operator_csv_scanner_buffer_manager.cpp src/duckdb/src/execution/operator/csv_scanner/encode/csv_encoder.cpp - src/duckdb/ub_src_execution_operator_csv_scanner_scanner.cpp - 
src/duckdb/ub_src_execution_operator_csv_scanner_sniffer.cpp - src/duckdb/ub_src_execution_operator_csv_scanner_state_machine.cpp - src/duckdb/ub_src_execution_operator_csv_scanner_table_function.cpp - src/duckdb/ub_src_execution_operator_csv_scanner_util.cpp src/duckdb/src/execution/operator/filter/physical_filter.cpp - src/duckdb/ub_src_execution_operator_helper.cpp - src/duckdb/ub_src_execution_operator_join.cpp - src/duckdb/ub_src_execution_operator_order.cpp - src/duckdb/ub_src_execution_operator_persistent.cpp - src/duckdb/ub_src_execution_operator_projection.cpp - src/duckdb/ub_src_execution_operator_scan.cpp - src/duckdb/ub_src_execution_operator_schema.cpp - src/duckdb/ub_src_execution_operator_set.cpp - src/duckdb/ub_src_execution_physical_plan.cpp - src/duckdb/ub_src_execution_sample.cpp - src/duckdb/ub_src_function_aggregate_distributive.cpp src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp - src/duckdb/ub_src_function.cpp - src/duckdb/ub_src_function_cast.cpp src/duckdb/src/function/cast/union/from_struct.cpp - src/duckdb/ub_src_function_cast_variant.cpp - src/duckdb/ub_src_function_pragma.cpp - src/duckdb/ub_src_function_scalar_comparison.cpp - src/duckdb/ub_src_function_scalar_compressed_materialization.cpp - src/duckdb/ub_src_function_scalar.cpp src/duckdb/src/function/scalar/date/strftime.cpp - src/duckdb/ub_src_function_scalar_generic.cpp src/duckdb/src/function/scalar/geometry/geometry_functions.cpp - src/duckdb/ub_src_function_scalar_list.cpp src/duckdb/src/function/scalar/map/map_contains.cpp - src/duckdb/ub_src_function_scalar_operator.cpp src/duckdb/src/function/scalar/sequence/nextval.cpp - src/duckdb/ub_src_function_scalar_string.cpp - src/duckdb/ub_src_function_scalar_string_regexp.cpp - src/duckdb/ub_src_function_scalar_struct.cpp - src/duckdb/ub_src_function_scalar_system.cpp - src/duckdb/ub_src_function_scalar_variant.cpp - src/duckdb/ub_src_function_table_arrow.cpp - src/duckdb/ub_src_function_table.cpp - 
src/duckdb/ub_src_function_table_system.cpp src/duckdb/src/function/table/version/pragma_version.cpp src/duckdb/src/function/variant/variant_shredding.cpp - src/duckdb/ub_src_function_window.cpp - src/duckdb/ub_src_logging.cpp - src/duckdb/ub_src_main.cpp - src/duckdb/ub_src_main_buffered_data.cpp src/duckdb/src/main/capi/aggregate_function-c.cpp src/duckdb/src/main/capi/appender-c.cpp src/duckdb/src/main/capi/arrow-c.cpp + src/duckdb/src/main/capi/cast/from_decimal-c.cpp + src/duckdb/src/main/capi/cast/utils-c.cpp src/duckdb/src/main/capi/cast_function-c.cpp src/duckdb/src/main/capi/catalog-c.cpp src/duckdb/src/main/capi/config-c.cpp @@ -216,67 +327,75 @@ set(DUCKDB_SRC_FILES src/duckdb/src/main/capi/table_function-c.cpp src/duckdb/src/main/capi/threading-c.cpp src/duckdb/src/main/capi/value-c.cpp - src/duckdb/src/main/capi/cast/from_decimal-c.cpp - src/duckdb/src/main/capi/cast/utils-c.cpp - src/duckdb/ub_src_main_chunk_scan_state.cpp src/duckdb/src/main/extension/extension_alias.cpp src/duckdb/src/main/extension/extension_helper.cpp src/duckdb/src/main/extension/extension_install.cpp src/duckdb/src/main/extension/extension_load.cpp src/duckdb/src/main/extension/extension_loader.cpp src/duckdb/src/main/http/http_util.cpp - src/duckdb/ub_src_main_relation.cpp - src/duckdb/ub_src_main_secret.cpp - src/duckdb/ub_src_main_settings.cpp - src/duckdb/ub_src_optimizer.cpp - src/duckdb/ub_src_optimizer_compressed_materialization.cpp - src/duckdb/ub_src_optimizer_join_order.cpp src/duckdb/src/optimizer/matcher/expression_matcher.cpp - src/duckdb/ub_src_optimizer_pullup.cpp - src/duckdb/ub_src_optimizer_pushdown.cpp - src/duckdb/ub_src_optimizer_rule.cpp - src/duckdb/ub_src_optimizer_statistics_expression.cpp - src/duckdb/ub_src_optimizer_statistics_operator.cpp - src/duckdb/ub_src_parallel.cpp - src/duckdb/ub_src_parser.cpp - src/duckdb/ub_src_parser_constraints.cpp - src/duckdb/ub_src_parser_expression.cpp - src/duckdb/ub_src_parser_parsed_data.cpp - 
src/duckdb/ub_src_parser_peg.cpp - src/duckdb/ub_src_parser_peg_tokenizer.cpp - src/duckdb/ub_src_parser_peg_transformer.cpp - src/duckdb/ub_src_parser_query_node.cpp - src/duckdb/ub_src_parser_statement.cpp - src/duckdb/ub_src_parser_tableref.cpp - src/duckdb/ub_src_planner.cpp - src/duckdb/ub_src_planner_binder_expression.cpp - src/duckdb/ub_src_planner_binder_query_node.cpp - src/duckdb/ub_src_planner_binder_statement.cpp - src/duckdb/ub_src_planner_binder_tableref.cpp - src/duckdb/ub_src_planner_expression.cpp - src/duckdb/ub_src_planner_expression_binder.cpp - src/duckdb/ub_src_planner_filter.cpp - src/duckdb/ub_src_planner_operator.cpp - src/duckdb/ub_src_planner_subquery.cpp - src/duckdb/ub_src_storage.cpp - src/duckdb/ub_src_storage_buffer.cpp - src/duckdb/ub_src_storage_checkpoint.cpp - src/duckdb/ub_src_storage_compression_alp.cpp - src/duckdb/ub_src_storage_compression.cpp - src/duckdb/ub_src_storage_compression_chimp.cpp - src/duckdb/ub_src_storage_compression_dict_fsst.cpp - src/duckdb/ub_src_storage_compression_dictionary.cpp - src/duckdb/ub_src_storage_compression_roaring.cpp - src/duckdb/ub_src_storage_external_file_cache.cpp - src/duckdb/ub_src_storage_metadata.cpp - src/duckdb/ub_src_storage_serialization.cpp - src/duckdb/ub_src_storage_statistics.cpp - src/duckdb/ub_src_storage_table.cpp - src/duckdb/ub_src_storage_table_variant.cpp - src/duckdb/ub_src_transaction.cpp + src/duckdb/third_party/brotli/common/constants.cpp + src/duckdb/third_party/brotli/common/context.cpp + src/duckdb/third_party/brotli/common/dictionary.cpp + src/duckdb/third_party/brotli/common/platform.cpp + src/duckdb/third_party/brotli/common/shared_dictionary.cpp + src/duckdb/third_party/brotli/common/transform.cpp + src/duckdb/third_party/brotli/dec/bit_reader.cpp + src/duckdb/third_party/brotli/dec/decode.cpp + src/duckdb/third_party/brotli/dec/huffman.cpp + src/duckdb/third_party/brotli/dec/state.cpp + src/duckdb/third_party/brotli/enc/backward_references.cpp + 
src/duckdb/third_party/brotli/enc/backward_references_hq.cpp + src/duckdb/third_party/brotli/enc/bit_cost.cpp + src/duckdb/third_party/brotli/enc/block_splitter.cpp + src/duckdb/third_party/brotli/enc/brotli_bit_stream.cpp + src/duckdb/third_party/brotli/enc/cluster.cpp + src/duckdb/third_party/brotli/enc/command.cpp + src/duckdb/third_party/brotli/enc/compound_dictionary.cpp + src/duckdb/third_party/brotli/enc/compress_fragment.cpp + src/duckdb/third_party/brotli/enc/compress_fragment_two_pass.cpp + src/duckdb/third_party/brotli/enc/dictionary_hash.cpp + src/duckdb/third_party/brotli/enc/encode.cpp + src/duckdb/third_party/brotli/enc/encoder_dict.cpp + src/duckdb/third_party/brotli/enc/entropy_encode.cpp + src/duckdb/third_party/brotli/enc/fast_log.cpp + src/duckdb/third_party/brotli/enc/histogram.cpp + src/duckdb/third_party/brotli/enc/literal_cost.cpp + src/duckdb/third_party/brotli/enc/memory.cpp + src/duckdb/third_party/brotli/enc/metablock.cpp + src/duckdb/third_party/brotli/enc/static_dict.cpp + src/duckdb/third_party/brotli/enc/utf8_util.cpp + src/duckdb/third_party/fastpforlib/bitpacking.cpp src/duckdb/third_party/fmt/format.cc src/duckdb/third_party/fsst/libfsst.cpp + src/duckdb/third_party/hyperloglog/hyperloglog.cpp + src/duckdb/third_party/hyperloglog/sds.cpp + src/duckdb/third_party/lz4/lz4.cpp + src/duckdb/third_party/mbedtls/library/aes.cpp + src/duckdb/third_party/mbedtls/library/asn1parse.cpp + src/duckdb/third_party/mbedtls/library/asn1write.cpp + src/duckdb/third_party/mbedtls/library/base64.cpp + src/duckdb/third_party/mbedtls/library/bignum.cpp + src/duckdb/third_party/mbedtls/library/bignum_core.cpp + src/duckdb/third_party/mbedtls/library/cipher.cpp + src/duckdb/third_party/mbedtls/library/cipher_wrap.cpp + src/duckdb/third_party/mbedtls/library/constant_time.cpp + src/duckdb/third_party/mbedtls/library/gcm.cpp + src/duckdb/third_party/mbedtls/library/md.cpp + src/duckdb/third_party/mbedtls/library/oid.cpp + 
src/duckdb/third_party/mbedtls/library/pem.cpp + src/duckdb/third_party/mbedtls/library/pk.cpp + src/duckdb/third_party/mbedtls/library/pk_wrap.cpp + src/duckdb/third_party/mbedtls/library/pkparse.cpp + src/duckdb/third_party/mbedtls/library/platform.cpp + src/duckdb/third_party/mbedtls/library/platform_util.cpp + src/duckdb/third_party/mbedtls/library/rsa.cpp + src/duckdb/third_party/mbedtls/library/rsa_alt_helpers.cpp + src/duckdb/third_party/mbedtls/library/sha1.cpp + src/duckdb/third_party/mbedtls/library/sha256.cpp + src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp src/duckdb/third_party/miniz/miniz.cpp + src/duckdb/third_party/parquet/parquet_types.cpp src/duckdb/third_party/re2/re2/bitmap256.cc src/duckdb/third_party/re2/re2/bitstate.cc src/duckdb/third_party/re2/re2/compile.cc @@ -300,35 +419,14 @@ set(DUCKDB_SRC_FILES src/duckdb/third_party/re2/re2/unicode_groups.cc src/duckdb/third_party/re2/util/rune.cc src/duckdb/third_party/re2/util/strutil.cc - src/duckdb/third_party/hyperloglog/hyperloglog.cpp - src/duckdb/third_party/hyperloglog/sds.cpp src/duckdb/third_party/skiplist/SkipList.cpp - src/duckdb/third_party/fastpforlib/bitpacking.cpp + src/duckdb/third_party/snappy/snappy-sinksource.cc + src/duckdb/third_party/snappy/snappy.cc + src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp + src/duckdb/third_party/thrift/thrift/transport/TBufferTransports.cpp + src/duckdb/third_party/thrift/thrift/transport/TTransportException.cpp src/duckdb/third_party/utf8proc/utf8proc.cpp src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp - src/duckdb/third_party/mbedtls/library/aes.cpp - src/duckdb/third_party/mbedtls/library/asn1parse.cpp - src/duckdb/third_party/mbedtls/library/asn1write.cpp - src/duckdb/third_party/mbedtls/library/base64.cpp - src/duckdb/third_party/mbedtls/library/bignum.cpp - src/duckdb/third_party/mbedtls/library/bignum_core.cpp - src/duckdb/third_party/mbedtls/library/cipher.cpp - src/duckdb/third_party/mbedtls/library/cipher_wrap.cpp - 
src/duckdb/third_party/mbedtls/library/constant_time.cpp - src/duckdb/third_party/mbedtls/library/gcm.cpp - src/duckdb/third_party/mbedtls/library/md.cpp - src/duckdb/third_party/mbedtls/library/oid.cpp - src/duckdb/third_party/mbedtls/library/pem.cpp - src/duckdb/third_party/mbedtls/library/pk.cpp - src/duckdb/third_party/mbedtls/library/pk_wrap.cpp - src/duckdb/third_party/mbedtls/library/pkparse.cpp - src/duckdb/third_party/mbedtls/library/platform.cpp - src/duckdb/third_party/mbedtls/library/platform_util.cpp - src/duckdb/third_party/mbedtls/library/rsa.cpp - src/duckdb/third_party/mbedtls/library/rsa_alt_helpers.cpp - src/duckdb/third_party/mbedtls/library/sha1.cpp - src/duckdb/third_party/mbedtls/library/sha256.cpp - src/duckdb/third_party/mbedtls/mbedtls_wrapper.cpp src/duckdb/third_party/yyjson/yyjson.cpp src/duckdb/third_party/zstd/common/debug.cpp src/duckdb/third_party/zstd/common/entropy_common.cpp @@ -362,311 +460,214 @@ set(DUCKDB_SRC_FILES src/duckdb/third_party/zstd/dict/divsufsort.cpp src/duckdb/third_party/zstd/dict/fastcover.cpp src/duckdb/third_party/zstd/dict/zdict.cpp - src/duckdb/extension/core_functions/core_functions_extension.cpp - src/duckdb/extension/core_functions/lambda_functions.cpp - src/duckdb/extension/core_functions/function_list.cpp + src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp + src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp src/duckdb/ub_extension_core_functions_aggregate_nested.cpp - src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp - src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp src/duckdb/ub_extension_core_functions_aggregate_regression.cpp - src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp - src/duckdb/ub_extension_core_functions_scalar_generic.cpp - src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp - src/duckdb/ub_extension_core_functions_scalar_debug.cpp - 
src/duckdb/extension/core_functions/scalar/math/numeric.cpp src/duckdb/ub_extension_core_functions_scalar_array.cpp - src/duckdb/ub_extension_core_functions_scalar_union.cpp src/duckdb/ub_extension_core_functions_scalar_blob.cpp + src/duckdb/ub_extension_core_functions_scalar_date.cpp + src/duckdb/ub_extension_core_functions_scalar_debug.cpp + src/duckdb/ub_extension_core_functions_scalar_generic.cpp src/duckdb/ub_extension_core_functions_scalar_list.cpp src/duckdb/ub_extension_core_functions_scalar_map.cpp - src/duckdb/ub_extension_core_functions_scalar_date.cpp - src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp src/duckdb/ub_extension_core_functions_scalar_random.cpp src/duckdb/ub_extension_core_functions_scalar_string.cpp src/duckdb/ub_extension_core_functions_scalar_struct.cpp - src/duckdb/extension/parquet/parquet_reader.cpp - src/duckdb/extension/parquet/parquet_statistics.cpp - src/duckdb/extension/parquet/zstd_file_system.cpp - src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp - src/duckdb/extension/parquet/parquet_shredding.cpp - src/duckdb/extension/parquet/parquet_multi_file_info.cpp - src/duckdb/extension/parquet/parquet_metadata.cpp - src/duckdb/extension/parquet/parquet_timestamp.cpp - src/duckdb/extension/parquet/parquet_crypto.cpp - src/duckdb/extension/parquet/serialize_parquet.cpp - src/duckdb/extension/parquet/parquet_column_schema.cpp - src/duckdb/extension/parquet/parquet_extension.cpp - src/duckdb/extension/parquet/parquet_field_id.cpp - src/duckdb/extension/parquet/column_writer.cpp - src/duckdb/extension/parquet/column_reader.cpp - src/duckdb/extension/parquet/parquet_writer.cpp - src/duckdb/extension/parquet/parquet_float16.cpp - src/duckdb/extension/parquet/parquet_geometry.cpp + src/duckdb/ub_extension_core_functions_scalar_union.cpp + src/duckdb/ub_extension_icu_third_party_icu_common.cpp + src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp + src/duckdb/ub_extension_json_json_functions.cpp + 
src/duckdb/ub_extension_parquet_decoder.cpp src/duckdb/ub_extension_parquet_reader.cpp src/duckdb/ub_extension_parquet_reader_variant.cpp - src/duckdb/ub_extension_parquet_decoder.cpp src/duckdb/ub_extension_parquet_writer.cpp src/duckdb/ub_extension_parquet_writer_variant.cpp - src/duckdb/third_party/parquet/parquet_types.cpp - src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp - src/duckdb/third_party/thrift/thrift/transport/TTransportException.cpp - src/duckdb/third_party/thrift/thrift/transport/TBufferTransports.cpp - src/duckdb/third_party/snappy/snappy.cc - src/duckdb/third_party/snappy/snappy-sinksource.cc - src/duckdb/third_party/lz4/lz4.cpp - src/duckdb/third_party/brotli/common/constants.cpp - src/duckdb/third_party/brotli/common/context.cpp - src/duckdb/third_party/brotli/common/dictionary.cpp - src/duckdb/third_party/brotli/common/platform.cpp - src/duckdb/third_party/brotli/common/shared_dictionary.cpp - src/duckdb/third_party/brotli/common/transform.cpp - src/duckdb/third_party/brotli/dec/bit_reader.cpp - src/duckdb/third_party/brotli/dec/decode.cpp - src/duckdb/third_party/brotli/dec/huffman.cpp - src/duckdb/third_party/brotli/dec/state.cpp - src/duckdb/third_party/brotli/enc/backward_references.cpp - src/duckdb/third_party/brotli/enc/backward_references_hq.cpp - src/duckdb/third_party/brotli/enc/bit_cost.cpp - src/duckdb/third_party/brotli/enc/block_splitter.cpp - src/duckdb/third_party/brotli/enc/brotli_bit_stream.cpp - src/duckdb/third_party/brotli/enc/cluster.cpp - src/duckdb/third_party/brotli/enc/command.cpp - src/duckdb/third_party/brotli/enc/compound_dictionary.cpp - src/duckdb/third_party/brotli/enc/compress_fragment.cpp - src/duckdb/third_party/brotli/enc/compress_fragment_two_pass.cpp - src/duckdb/third_party/brotli/enc/dictionary_hash.cpp - src/duckdb/third_party/brotli/enc/encode.cpp - src/duckdb/third_party/brotli/enc/encoder_dict.cpp - src/duckdb/third_party/brotli/enc/entropy_encode.cpp - 
src/duckdb/third_party/brotli/enc/fast_log.cpp - src/duckdb/third_party/brotli/enc/histogram.cpp - src/duckdb/third_party/brotli/enc/literal_cost.cpp - src/duckdb/third_party/brotli/enc/memory.cpp - src/duckdb/third_party/brotli/enc/metablock.cpp - src/duckdb/third_party/brotli/enc/static_dict.cpp - src/duckdb/third_party/brotli/enc/utf8_util.cpp - src/duckdb/extension/icu/./icu-timebucket.cpp - src/duckdb/extension/icu/./icu-datepart.cpp - src/duckdb/extension/icu/./icu-datetrunc.cpp - src/duckdb/extension/icu/./icu-current.cpp - src/duckdb/extension/icu/./icu_extension.cpp - src/duckdb/extension/icu/./icu-table-range.cpp - src/duckdb/extension/icu/./icu-datefunc.cpp - src/duckdb/extension/icu/./icu-makedate.cpp - src/duckdb/extension/icu/./icu-dateadd.cpp - src/duckdb/extension/icu/./icu-timezone.cpp - src/duckdb/extension/icu/./icu-strptime.cpp - src/duckdb/extension/icu/./icu-datesub.cpp - src/duckdb/extension/icu/./icu-list-range.cpp - src/duckdb/extension/icu/third_party/icu/common/static_unicode_sets.cpp - src/duckdb/extension/icu/third_party/icu/common/utrie2_builder.cpp - src/duckdb/extension/icu/third_party/icu/common/umutablecptrie.cpp - src/duckdb/extension/icu/third_party/icu/common/unames.cpp - src/duckdb/extension/icu/third_party/icu/common/uscript_props.cpp - src/duckdb/extension/icu/third_party/icu/common/putil.cpp - src/duckdb/extension/icu/third_party/icu/common/ubiditransform.cpp - src/duckdb/extension/icu/third_party/icu/common/locmap.cpp - src/duckdb/extension/icu/third_party/icu/common/locdistance.cpp - src/duckdb/extension/icu/third_party/icu/common/ucptrie.cpp - src/duckdb/extension/icu/third_party/icu/common/uhash.cpp - src/duckdb/extension/icu/third_party/icu/common/localebuilder.cpp - src/duckdb/extension/icu/third_party/icu/common/uvector.cpp - src/duckdb/extension/icu/third_party/icu/common/ucase.cpp - src/duckdb/extension/icu/third_party/icu/common/utrie2.cpp - src/duckdb/extension/icu/third_party/icu/common/ubidi_props.cpp - 
src/duckdb/extension/icu/third_party/icu/common/ustrcase.cpp - src/duckdb/extension/icu/third_party/icu/common/uresdata.cpp - src/duckdb/extension/icu/third_party/icu/common/umapfile.cpp - src/duckdb/extension/icu/third_party/icu/common/utrie.cpp - src/duckdb/extension/icu/third_party/icu/common/loadednormalizer2impl.cpp - src/duckdb/extension/icu/third_party/icu/common/ucharstriebuilder.cpp - src/duckdb/extension/icu/third_party/icu/common/unormcmp.cpp - src/duckdb/extension/icu/third_party/icu/common/wintz.cpp - src/duckdb/extension/icu/third_party/icu/common/uresbund.cpp - src/duckdb/extension/icu/third_party/icu/common/uloc.cpp - src/duckdb/extension/icu/third_party/icu/common/loclikelysubtags.cpp - src/duckdb/extension/icu/third_party/icu/common/unifiedcache.cpp - src/duckdb/extension/icu/third_party/icu/common/bytestriebuilder.cpp - src/duckdb/ub_extension_icu_third_party_icu_common.cpp - src/duckdb/extension/icu/third_party/icu/i18n/ucol.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_grouping.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_symbolswrapper.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_capi.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-cached-powers.cpp - src/duckdb/extension/icu/third_party/icu/i18n/fmtable.cpp - src/duckdb/extension/icu/third_party/icu/i18n/string_segment.cpp - src/duckdb/extension/icu/third_party/icu/i18n/chnsecal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/plurrule.cpp - src/duckdb/extension/icu/third_party/icu/i18n/smpdtfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_skeletons.cpp - src/duckdb/extension/icu/third_party/icu/i18n/nfsubs.cpp - src/duckdb/extension/icu/third_party/icu/i18n/coptccal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-bignum.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numrange_fluent.cpp - src/duckdb/extension/icu/third_party/icu/i18n/taiwncal.cpp - 
src/duckdb/extension/icu/third_party/icu/i18n/measfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/decimfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_decimalquantity.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-string-to-double.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_usageprefs.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_patternstring.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_formatimpl.cpp - src/duckdb/extension/icu/third_party/icu/i18n/dtfmtsym.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_integerwidth.cpp - src/duckdb/extension/icu/third_party/icu/i18n/wintzimpl.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_rounding.cpp - src/duckdb/extension/icu/third_party/icu/i18n/dtitvfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_symbols.cpp - src/duckdb/extension/icu/third_party/icu/i18n/quantityformatter.cpp - src/duckdb/extension/icu/third_party/icu/i18n/timezone.cpp - src/duckdb/extension/icu/third_party/icu/i18n/indiancal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/msgfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/persncal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/dtptngen.cpp - src/duckdb/extension/icu/third_party/icu/i18n/choicfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_affixes.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_mapper.cpp - src/duckdb/extension/icu/third_party/icu/i18n/compactdecimalformat.cpp - src/duckdb/extension/icu/third_party/icu/i18n/buddhcal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_scientific.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-bignum-dtoa.cpp - src/duckdb/extension/icu/third_party/icu/i18n/gregocal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_decimal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/cecal.cpp - 
src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-double-to-string.cpp - src/duckdb/extension/icu/third_party/icu/i18n/hebrwcal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_longnames.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_currencysymbols.cpp - src/duckdb/extension/icu/third_party/icu/i18n/upluralrules.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_compact.cpp - src/duckdb/extension/icu/third_party/icu/i18n/nfrule.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-strtod.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_parsednumber.cpp - src/duckdb/extension/icu/third_party/icu/i18n/vtzone.cpp - src/duckdb/extension/icu/third_party/icu/i18n/plurfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/windtfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/calendar.cpp - src/duckdb/extension/icu/third_party/icu/i18n/formattedval_sbimpl.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numrange_capi.cpp - src/duckdb/extension/icu/third_party/icu/i18n/units_complexconverter.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_utils.cpp - src/duckdb/extension/icu/third_party/icu/i18n/tzgnames.cpp - src/duckdb/extension/icu/third_party/icu/i18n/nfrs.cpp - src/duckdb/extension/icu/third_party/icu/i18n/decNumber.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numsys.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_patternmodifier.cpp - src/duckdb/extension/icu/third_party/icu/i18n/units_router.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_impl.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_scientific.cpp - src/duckdb/extension/icu/third_party/icu/i18n/islamcal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/double-conversion-fast-dtoa.cpp - src/duckdb/extension/icu/third_party/icu/i18n/units_data.cpp - src/duckdb/extension/icu/third_party/icu/i18n/collationroot.cpp - src/duckdb/extension/icu/third_party/icu/i18n/currpinf.cpp 
- src/duckdb/extension/icu/third_party/icu/i18n/gregoimp.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_simple.cpp - src/duckdb/extension/icu/third_party/icu/i18n/rbnf.cpp - src/duckdb/extension/icu/third_party/icu/i18n/iso8601cal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/tmutfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_asformat.cpp - src/duckdb/extension/icu/third_party/icu/i18n/pluralranges.cpp - src/duckdb/extension/icu/third_party/icu/i18n/dayperiodrules.cpp - src/duckdb/extension/icu/third_party/icu/i18n/ethpccal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_multiplier.cpp - src/duckdb/extension/icu/third_party/icu/i18n/ucol_res.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_currency.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_fluent.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numrange_impl.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_compositions.cpp - src/duckdb/extension/icu/third_party/icu/i18n/winnmfmt.cpp - src/duckdb/extension/icu/third_party/icu/i18n/number_output.cpp - src/duckdb/extension/icu/third_party/icu/i18n/japancal.cpp - src/duckdb/extension/icu/third_party/icu/i18n/numparse_validators.cpp - src/duckdb/extension/icu/third_party/icu/i18n/measunit_extra.cpp - src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp - src/duckdb/extension/icu/third_party/icu/stubdata/stubdata.cpp - src/duckdb/extension/json/json_reader.cpp - src/duckdb/extension/json/json_functions.cpp - src/duckdb/extension/json/json_deserializer.cpp - src/duckdb/extension/json/json_serializer.cpp - src/duckdb/extension/json/json_scan.cpp - src/duckdb/extension/json/serialize_json.cpp - src/duckdb/extension/json/json_enums.cpp - src/duckdb/extension/json/json_extension.cpp - src/duckdb/extension/json/json_multi_file_info.cpp - src/duckdb/extension/json/json_common.cpp - src/duckdb/ub_extension_json_json_functions.cpp - src/duckdb/generated_extension_loader_package_build.cpp) + 
src/duckdb/ub_src_catalog.cpp + src/duckdb/ub_src_catalog_catalog_entry.cpp + src/duckdb/ub_src_catalog_catalog_entry_dependency.cpp + src/duckdb/ub_src_catalog_default.cpp + src/duckdb/ub_src_common.cpp + src/duckdb/ub_src_common_adbc_nanoarrow.cpp + src/duckdb/ub_src_common_allocator.cpp + src/duckdb/ub_src_common_arrow.cpp + src/duckdb/ub_src_common_arrow_appender.cpp + src/duckdb/ub_src_common_enums.cpp + src/duckdb/ub_src_common_exception.cpp + src/duckdb/ub_src_common_multi_file.cpp + src/duckdb/ub_src_common_operator.cpp + src/duckdb/ub_src_common_progress_bar.cpp + src/duckdb/ub_src_common_row_operations.cpp + src/duckdb/ub_src_common_serializer.cpp + src/duckdb/ub_src_common_sort.cpp + src/duckdb/ub_src_common_tree_renderer.cpp + src/duckdb/ub_src_common_types.cpp + src/duckdb/ub_src_common_types_column.cpp + src/duckdb/ub_src_common_types_row.cpp + src/duckdb/ub_src_common_types_variant.cpp + src/duckdb/ub_src_common_vector.cpp + src/duckdb/ub_src_execution.cpp + src/duckdb/ub_src_execution_expression_executor.cpp + src/duckdb/ub_src_execution_index.cpp + src/duckdb/ub_src_execution_index_art.cpp + src/duckdb/ub_src_execution_nested_loop_join.cpp + src/duckdb/ub_src_execution_operator_aggregate.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_buffer_manager.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_scanner.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_sniffer.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_state_machine.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_table_function.cpp + src/duckdb/ub_src_execution_operator_csv_scanner_util.cpp + src/duckdb/ub_src_execution_operator_helper.cpp + src/duckdb/ub_src_execution_operator_join.cpp + src/duckdb/ub_src_execution_operator_order.cpp + src/duckdb/ub_src_execution_operator_persistent.cpp + src/duckdb/ub_src_execution_operator_projection.cpp + src/duckdb/ub_src_execution_operator_scan.cpp + src/duckdb/ub_src_execution_operator_schema.cpp + 
src/duckdb/ub_src_execution_operator_set.cpp + src/duckdb/ub_src_execution_physical_plan.cpp + src/duckdb/ub_src_execution_sample.cpp + src/duckdb/ub_src_function.cpp + src/duckdb/ub_src_function_aggregate_distributive.cpp + src/duckdb/ub_src_function_cast.cpp + src/duckdb/ub_src_function_cast_variant.cpp + src/duckdb/ub_src_function_pragma.cpp + src/duckdb/ub_src_function_scalar.cpp + src/duckdb/ub_src_function_scalar_comparison.cpp + src/duckdb/ub_src_function_scalar_compressed_materialization.cpp + src/duckdb/ub_src_function_scalar_generic.cpp + src/duckdb/ub_src_function_scalar_list.cpp + src/duckdb/ub_src_function_scalar_operator.cpp + src/duckdb/ub_src_function_scalar_string.cpp + src/duckdb/ub_src_function_scalar_string_regexp.cpp + src/duckdb/ub_src_function_scalar_struct.cpp + src/duckdb/ub_src_function_scalar_system.cpp + src/duckdb/ub_src_function_scalar_variant.cpp + src/duckdb/ub_src_function_table.cpp + src/duckdb/ub_src_function_table_arrow.cpp + src/duckdb/ub_src_function_table_system.cpp + src/duckdb/ub_src_function_window.cpp + src/duckdb/ub_src_logging.cpp + src/duckdb/ub_src_main.cpp + src/duckdb/ub_src_main_buffered_data.cpp + src/duckdb/ub_src_main_chunk_scan_state.cpp + src/duckdb/ub_src_main_relation.cpp + src/duckdb/ub_src_main_secret.cpp + src/duckdb/ub_src_main_settings.cpp + src/duckdb/ub_src_optimizer.cpp + src/duckdb/ub_src_optimizer_compressed_materialization.cpp + src/duckdb/ub_src_optimizer_join_order.cpp + src/duckdb/ub_src_optimizer_pullup.cpp + src/duckdb/ub_src_optimizer_pushdown.cpp + src/duckdb/ub_src_optimizer_rule.cpp + src/duckdb/ub_src_optimizer_statistics_expression.cpp + src/duckdb/ub_src_optimizer_statistics_operator.cpp + src/duckdb/ub_src_parallel.cpp + src/duckdb/ub_src_parser.cpp + src/duckdb/ub_src_parser_constraints.cpp + src/duckdb/ub_src_parser_expression.cpp + src/duckdb/ub_src_parser_parsed_data.cpp + src/duckdb/ub_src_parser_peg.cpp + src/duckdb/ub_src_parser_peg_tokenizer.cpp + 
src/duckdb/ub_src_parser_peg_transformer.cpp + src/duckdb/ub_src_parser_query_node.cpp + src/duckdb/ub_src_parser_statement.cpp + src/duckdb/ub_src_parser_tableref.cpp + src/duckdb/ub_src_planner.cpp + src/duckdb/ub_src_planner_binder_expression.cpp + src/duckdb/ub_src_planner_binder_query_node.cpp + src/duckdb/ub_src_planner_binder_statement.cpp + src/duckdb/ub_src_planner_binder_tableref.cpp + src/duckdb/ub_src_planner_expression.cpp + src/duckdb/ub_src_planner_expression_binder.cpp + src/duckdb/ub_src_planner_filter.cpp + src/duckdb/ub_src_planner_operator.cpp + src/duckdb/ub_src_planner_subquery.cpp + src/duckdb/ub_src_storage.cpp + src/duckdb/ub_src_storage_buffer.cpp + src/duckdb/ub_src_storage_checkpoint.cpp + src/duckdb/ub_src_storage_compression.cpp + src/duckdb/ub_src_storage_compression_alp.cpp + src/duckdb/ub_src_storage_compression_chimp.cpp + src/duckdb/ub_src_storage_compression_dict_fsst.cpp + src/duckdb/ub_src_storage_compression_dictionary.cpp + src/duckdb/ub_src_storage_compression_roaring.cpp + src/duckdb/ub_src_storage_external_file_cache.cpp + src/duckdb/ub_src_storage_metadata.cpp + src/duckdb/ub_src_storage_serialization.cpp + src/duckdb/ub_src_storage_statistics.cpp + src/duckdb/ub_src_storage_table.cpp + src/duckdb/ub_src_storage_table_variant.cpp + src/duckdb/ub_src_transaction.cpp) set(JEMALLOC_SRC_FILES - src/duckdb/extension/jemalloc/jemalloc_extension.cpp - src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c - src/duckdb/extension/jemalloc/jemalloc/src/arena.c - src/duckdb/extension/jemalloc/jemalloc/src/background_thread.c - src/duckdb/extension/jemalloc/jemalloc/src/base.c - src/duckdb/extension/jemalloc/jemalloc/src/batcher.c - src/duckdb/extension/jemalloc/jemalloc/src/bin.c - src/duckdb/extension/jemalloc/jemalloc/src/bin_info.c - src/duckdb/extension/jemalloc/jemalloc/src/bitmap.c - src/duckdb/extension/jemalloc/jemalloc/src/buf_writer.c - src/duckdb/extension/jemalloc/jemalloc/src/cache_bin.c - 
src/duckdb/extension/jemalloc/jemalloc/src/ckh.c - src/duckdb/extension/jemalloc/jemalloc/src/counter.c - src/duckdb/extension/jemalloc/jemalloc/src/ctl.c - src/duckdb/extension/jemalloc/jemalloc/src/decay.c - src/duckdb/extension/jemalloc/jemalloc/src/div.c - src/duckdb/extension/jemalloc/jemalloc/src/ecache.c - src/duckdb/extension/jemalloc/jemalloc/src/edata.c - src/duckdb/extension/jemalloc/jemalloc/src/edata_cache.c - src/duckdb/extension/jemalloc/jemalloc/src/ehooks.c - src/duckdb/extension/jemalloc/jemalloc/src/emap.c - src/duckdb/extension/jemalloc/jemalloc/src/eset.c - src/duckdb/extension/jemalloc/jemalloc/src/exp_grow.c - src/duckdb/extension/jemalloc/jemalloc/src/extent.c - src/duckdb/extension/jemalloc/jemalloc/src/extent_dss.c - src/duckdb/extension/jemalloc/jemalloc/src/extent_mmap.c - src/duckdb/extension/jemalloc/jemalloc/src/fxp.c - src/duckdb/extension/jemalloc/jemalloc/src/san.c - src/duckdb/extension/jemalloc/jemalloc/src/san_bump.c - src/duckdb/extension/jemalloc/jemalloc/src/hook.c - src/duckdb/extension/jemalloc/jemalloc/src/hpa.c - src/duckdb/extension/jemalloc/jemalloc/src/hpa_hooks.c - src/duckdb/extension/jemalloc/jemalloc/src/hpdata.c - src/duckdb/extension/jemalloc/jemalloc/src/inspect.c - src/duckdb/extension/jemalloc/jemalloc/src/large.c - src/duckdb/extension/jemalloc/jemalloc/src/log.c - src/duckdb/extension/jemalloc/jemalloc/src/malloc_io.c - src/duckdb/extension/jemalloc/jemalloc/src/mutex.c - src/duckdb/extension/jemalloc/jemalloc/src/nstime.c - src/duckdb/extension/jemalloc/jemalloc/src/pa.c - src/duckdb/extension/jemalloc/jemalloc/src/pa_extra.c - src/duckdb/extension/jemalloc/jemalloc/src/pai.c - src/duckdb/extension/jemalloc/jemalloc/src/pac.c - src/duckdb/extension/jemalloc/jemalloc/src/pages.c - src/duckdb/extension/jemalloc/jemalloc/src/peak_event.c - src/duckdb/extension/jemalloc/jemalloc/src/prof.c - src/duckdb/extension/jemalloc/jemalloc/src/prof_data.c - src/duckdb/extension/jemalloc/jemalloc/src/prof_log.c - 
src/duckdb/extension/jemalloc/jemalloc/src/prof_recent.c - src/duckdb/extension/jemalloc/jemalloc/src/prof_stats.c - src/duckdb/extension/jemalloc/jemalloc/src/prof_sys.c - src/duckdb/extension/jemalloc/jemalloc/src/psset.c - src/duckdb/extension/jemalloc/jemalloc/src/rtree.c - src/duckdb/extension/jemalloc/jemalloc/src/safety_check.c - src/duckdb/extension/jemalloc/jemalloc/src/sc.c - src/duckdb/extension/jemalloc/jemalloc/src/sec.c - src/duckdb/extension/jemalloc/jemalloc/src/stats.c - src/duckdb/extension/jemalloc/jemalloc/src/sz.c - src/duckdb/extension/jemalloc/jemalloc/src/tcache.c - src/duckdb/extension/jemalloc/jemalloc/src/test_hooks.c - src/duckdb/extension/jemalloc/jemalloc/src/thread_event.c - src/duckdb/extension/jemalloc/jemalloc/src/ticker.c - src/duckdb/extension/jemalloc/jemalloc/src/tsd.c - src/duckdb/extension/jemalloc/jemalloc/src/util.c - src/duckdb/extension/jemalloc/jemalloc/src/witness.c - src/duckdb/extension/jemalloc/jemalloc/src/zone.c) + src/duckdb/third_party/jemalloc/src/arena.c + src/duckdb/third_party/jemalloc/src/background_thread.c + src/duckdb/third_party/jemalloc/src/base.c + src/duckdb/third_party/jemalloc/src/batcher.c + src/duckdb/third_party/jemalloc/src/bin.c + src/duckdb/third_party/jemalloc/src/bin_info.c + src/duckdb/third_party/jemalloc/src/bitmap.c + src/duckdb/third_party/jemalloc/src/buf_writer.c + src/duckdb/third_party/jemalloc/src/cache_bin.c + src/duckdb/third_party/jemalloc/src/ckh.c + src/duckdb/third_party/jemalloc/src/counter.c + src/duckdb/third_party/jemalloc/src/ctl.c + src/duckdb/third_party/jemalloc/src/decay.c + src/duckdb/third_party/jemalloc/src/div.c + src/duckdb/third_party/jemalloc/src/ecache.c + src/duckdb/third_party/jemalloc/src/edata.c + src/duckdb/third_party/jemalloc/src/edata_cache.c + src/duckdb/third_party/jemalloc/src/ehooks.c + src/duckdb/third_party/jemalloc/src/emap.c + src/duckdb/third_party/jemalloc/src/eset.c + src/duckdb/third_party/jemalloc/src/exp_grow.c + 
src/duckdb/third_party/jemalloc/src/extent.c + src/duckdb/third_party/jemalloc/src/extent_dss.c + src/duckdb/third_party/jemalloc/src/extent_mmap.c + src/duckdb/third_party/jemalloc/src/fxp.c + src/duckdb/third_party/jemalloc/src/hook.c + src/duckdb/third_party/jemalloc/src/hpa.c + src/duckdb/third_party/jemalloc/src/hpa_hooks.c + src/duckdb/third_party/jemalloc/src/hpdata.c + src/duckdb/third_party/jemalloc/src/inspect.c + src/duckdb/third_party/jemalloc/src/jemalloc.c + src/duckdb/third_party/jemalloc/src/large.c + src/duckdb/third_party/jemalloc/src/log.c + src/duckdb/third_party/jemalloc/src/malloc_io.c + src/duckdb/third_party/jemalloc/src/mutex.c + src/duckdb/third_party/jemalloc/src/nstime.c + src/duckdb/third_party/jemalloc/src/pa.c + src/duckdb/third_party/jemalloc/src/pa_extra.c + src/duckdb/third_party/jemalloc/src/pac.c + src/duckdb/third_party/jemalloc/src/pages.c + src/duckdb/third_party/jemalloc/src/pai.c + src/duckdb/third_party/jemalloc/src/peak_event.c + src/duckdb/third_party/jemalloc/src/prof.c + src/duckdb/third_party/jemalloc/src/prof_data.c + src/duckdb/third_party/jemalloc/src/prof_log.c + src/duckdb/third_party/jemalloc/src/prof_recent.c + src/duckdb/third_party/jemalloc/src/prof_stats.c + src/duckdb/third_party/jemalloc/src/prof_sys.c + src/duckdb/third_party/jemalloc/src/psset.c + src/duckdb/third_party/jemalloc/src/rtree.c + src/duckdb/third_party/jemalloc/src/safety_check.c + src/duckdb/third_party/jemalloc/src/san.c + src/duckdb/third_party/jemalloc/src/san_bump.c + src/duckdb/third_party/jemalloc/src/sc.c + src/duckdb/third_party/jemalloc/src/sec.c + src/duckdb/third_party/jemalloc/src/stats.c + src/duckdb/third_party/jemalloc/src/sz.c + src/duckdb/third_party/jemalloc/src/tcache.c + src/duckdb/third_party/jemalloc/src/test_hooks.c + src/duckdb/third_party/jemalloc/src/thread_event.c + src/duckdb/third_party/jemalloc/src/ticker.c + src/duckdb/third_party/jemalloc/src/tsd.c + src/duckdb/third_party/jemalloc/src/util.c + 
src/duckdb/third_party/jemalloc/src/witness.c + src/duckdb/third_party/jemalloc/src/zone.c) # a few OS-specific variables @@ -785,7 +786,7 @@ target_compile_definitions(duckdb_java PRIVATE if(NOT MSVC AND NOT ZOS) target_compile_definitions(duckdb_java PRIVATE - -DDUCKDB_EXTENSION_JEMALLOC_LINKED) + -DDUCKDB_ENABLE_JEMALLOC) endif() if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp index 2b685e309..3186d2dee 100644 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp @@ -12,13 +12,12 @@ namespace { template struct AvgState { + static constexpr const char *STATE_NAMES[] = {"count", "value"}; + using STATE_TYPE = StructStateType; + uint64_t count; T value; - void Initialize() { - this->count = 0; - } - void Combine(const AvgState &other) { this->count += other.count; this->value += other.value; @@ -26,14 +25,12 @@ struct AvgState { }; struct IntervalAvgState { + static constexpr const char *STATE_NAMES[] = {"count", "value"}; + using STATE_TYPE = StructStateType; + int64_t count; interval_t value; - void Initialize() { - this->count = 0; - this->value = interval_t(); - } - void Combine(const IntervalAvgState &other) { this->count += other.count; this->value = AddOperator::Operation(this->value, other.value); @@ -41,15 +38,13 @@ struct IntervalAvgState { }; struct KahanAvgState { + static constexpr const char *STATE_NAMES[] = {"count", "value", "err"}; + using STATE_TYPE = StructStateType; + uint64_t count; double value; double err; - void Initialize() { - this->count = 0; - this->err = 0.0; - } - void Combine(const KahanAvgState &other) { this->count += other.count; KahanAddInternal(other.value, this->value, this->err); @@ -75,10 +70,6 @@ struct AverageDecimalBindData : public FunctionData { }; struct AverageSetOperation { - template - static void 
Initialize(STATE &state) { - state.Initialize(); - } template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target.Combine(source); @@ -172,12 +163,6 @@ struct KahanAverageOperation : public BaseSumOperation { - // Override BaseSumOperation::Initialize because - // IntervalAvgState does not have an assignment constructor from 0 - static void Initialize(IntervalAvgState &state) { - AverageSetOperation::Initialize(state); - } - template static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { if (state.count == 0) { @@ -239,47 +224,27 @@ struct TimeTZAverageOperation : public BaseSumOperation children; - children.emplace_back("count", LogicalType::UBIGINT); - children.emplace_back("value", function.GetArguments()[0]); - return LogicalType::STRUCT(std::move(children)); -} - -LogicalType GetKahanAvgStateType(const BoundAggregateFunction &function) { - child_list_t children; - children.emplace_back("count", LogicalType::UBIGINT); - children.emplace_back("value", LogicalType::DOUBLE); - children.emplace_back("err", LogicalType::DOUBLE); - return LogicalType::STRUCT(std::move(children)); -} - AggregateFunction GetAverageAggregate(PhysicalType type) { switch (type) { case PhysicalType::INT16: { return AggregateFunction::UnaryAggregate, int16_t, double, IntegerAverageOperation>( - LogicalType::SMALLINT, LogicalType::DOUBLE) - .SetStructStateExport(GetAvgStateType); + LogicalType::SMALLINT, LogicalType::DOUBLE); } case PhysicalType::INT32: { return AggregateFunction::UnaryAggregate, int32_t, double, IntegerAverageOperationHugeint>( - LogicalType::INTEGER, LogicalType::DOUBLE) - .SetStructStateExport(GetAvgStateType); + LogicalType::INTEGER, LogicalType::DOUBLE); } case PhysicalType::INT64: { return AggregateFunction::UnaryAggregate, int64_t, double, IntegerAverageOperationHugeint>( - LogicalType::BIGINT, LogicalType::DOUBLE) - .SetStructStateExport(GetAvgStateType); + LogicalType::BIGINT, 
LogicalType::DOUBLE); } case PhysicalType::INT128: { return AggregateFunction::UnaryAggregate, hugeint_t, double, HugeintAverageOperation>( - LogicalType::HUGEINT, LogicalType::DOUBLE) - .SetStructStateExport(GetAvgStateType); + LogicalType::HUGEINT, LogicalType::DOUBLE); } case PhysicalType::INTERVAL: { return AggregateFunction::UnaryAggregate( - LogicalType::INTERVAL, LogicalType::INTERVAL) - .SetStructStateExport(GetAvgStateType); + LogicalType::INTERVAL, LogicalType::INTERVAL); } default: throw InternalException("Unimplemented average aggregate"); @@ -313,31 +278,24 @@ AggregateFunctionSet AvgFun::GetFunctions() { avg.AddFunction(GetAverageAggregate(PhysicalType::INT128)); avg.AddFunction(GetAverageAggregate(PhysicalType::INTERVAL)); avg.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericAverageOperation>( - LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetAvgStateType)); + LogicalType::DOUBLE, LogicalType::DOUBLE)); avg.AddFunction(AggregateFunction::UnaryAggregate, int64_t, int64_t, DiscreteAverageOperation>( - LogicalType::TIMESTAMP, LogicalType::TIMESTAMP) - .SetStructStateExport(GetAvgStateType)); + LogicalType::TIMESTAMP, LogicalType::TIMESTAMP)); avg.AddFunction(AggregateFunction::UnaryAggregate, int64_t, int64_t, DiscreteAverageOperation>( - LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ) - .SetStructStateExport(GetAvgStateType)); + LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ)); avg.AddFunction(AggregateFunction::UnaryAggregate, int64_t, int64_t, DiscreteAverageOperation>( - LogicalType::TIME, LogicalType::TIME) - .SetStructStateExport(GetAvgStateType)); + LogicalType::TIME, LogicalType::TIME)); avg.AddFunction( AggregateFunction::UnaryAggregate, dtime_tz_t, dtime_tz_t, TimeTZAverageOperation>( - LogicalType::TIME_TZ, LogicalType::TIME_TZ) - .SetStructStateExport(GetAvgStateType)); + LogicalType::TIME_TZ, LogicalType::TIME_TZ)); return avg; } AggregateFunction FAvgFun::GetFunction() { - auto function 
= AggregateFunction::UnaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetKahanAvgStateType); - return function; + return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, + LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp index c3cb919bb..bf53a5ad3 100644 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp +++ b/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp @@ -6,34 +6,8 @@ namespace duckdb { -LogicalType GetCorrStateType() { - child_list_t covar_children; - covar_children.emplace_back("count", LogicalType::UBIGINT); - covar_children.emplace_back("meanx", LogicalType::DOUBLE); - covar_children.emplace_back("meany", LogicalType::DOUBLE); - covar_children.emplace_back("co_moment", LogicalType::DOUBLE); - auto cov_pop_type = LogicalType::STRUCT(std::move(covar_children)); - - child_list_t stddev_types; - stddev_types.emplace_back("count", LogicalType::UBIGINT); - stddev_types.emplace_back("mean", LogicalType::DOUBLE); - stddev_types.emplace_back("dsquared", LogicalType::DOUBLE); - auto stddev_type = LogicalType::STRUCT(std::move(stddev_types)); - - child_list_t state_children; - state_children.emplace_back("cov_pop", std::move(cov_pop_type)); - state_children.emplace_back("dev_pop_x", stddev_type); - state_children.emplace_back("dev_pop_y", stddev_type); - return LogicalType::STRUCT(std::move(state_children)); -} - -LogicalType GetCorrExportStateType(const BoundAggregateFunction &) { - return GetCorrStateType(); -} - AggregateFunction CorrFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetCorrExportStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git 
a/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp index 0c5a848a8..8a166daf2 100644 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp +++ b/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp @@ -3,29 +3,14 @@ namespace duckdb { -namespace { - -LogicalType GetCovarStateType(const BoundAggregateFunction &) { - child_list_t child_types; - child_types.emplace_back("count", LogicalType::UBIGINT); - child_types.emplace_back("meanx", LogicalType::DOUBLE); - child_types.emplace_back("meany", LogicalType::DOUBLE); - child_types.emplace_back("co_moment", LogicalType::DOUBLE); - return LogicalType::STRUCT(std::move(child_types)); -} - -} // namespace - AggregateFunction CovarPopFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetCovarStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } AggregateFunction CovarSampFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetCovarStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp index 3b252d80d..6e1971cd3 100644 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp +++ b/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp @@ -4,46 +4,29 @@ namespace duckdb { -namespace { - -LogicalType GetStddevStateType(const BoundAggregateFunction &) { - child_list_t child_types; - child_types.emplace_back("count", LogicalType::UBIGINT); - child_types.emplace_back("mean", LogicalType::DOUBLE); - child_types.emplace_back("dsquared", LogicalType::DOUBLE); - return 
LogicalType::STRUCT(std::move(child_types)); -} - -} // namespace - AggregateFunction StdDevSampFun::GetFunction() { return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE) - .SetStructStateExport(GetStddevStateType); + LogicalType::DOUBLE); } AggregateFunction StdDevPopFun::GetFunction() { return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE) - .SetStructStateExport(GetStddevStateType); + LogicalType::DOUBLE); } AggregateFunction VarPopFun::GetFunction() { return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE) - .SetStructStateExport(GetStddevStateType); + LogicalType::DOUBLE); } AggregateFunction VarSampFun::GetFunction() { return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE) - .SetStructStateExport(GetStddevStateType); + LogicalType::DOUBLE); } AggregateFunction StandardErrorOfTheMeanFun::GetFunction() { return AggregateFunction::UnaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetStddevStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp index 1c0eb2d6e..51af608c3 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp @@ -14,11 +14,6 @@ struct ApproxDistinctCountState { }; struct ApproxCountDistinctFunction { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target.hll.Merge(source.hll); @@ -34,21 +29,6 @@ struct ApproxCountDistinctFunction { } }; -void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state, - idx_t count) { 
- D_ASSERT(input_count == 1); - auto &input = inputs[0]; - - if (count > STANDARD_VECTOR_SIZE) { - throw InternalException("ApproxCountDistinct - count must be at most vector size"); - } - Vector hash_vec(LogicalType::HASH, count); - VectorOperations::Hash(input, hash_vec, count); - - auto agg_state = reinterpret_cast(state); - agg_state->hll.Update(input, hash_vec, count); -} - void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { D_ASSERT(input_count == 1); @@ -80,8 +60,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) AggregateFunction::StateInitialize, ApproxCountDistinctUpdateFunction, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - ApproxCountDistinctSimpleUpdateFunction); + AggregateFunction::StateFinalize, nullptr); fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp index 44af8f5c2..ee6210754 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp @@ -84,11 +84,6 @@ struct ArgMinMaxState : public ArgMinMaxStateBase { template struct ArgMinMaxBase { - template - static void Initialize(STATE &state) { - new (&state) STATE; - } - template static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { state.~STATE(); @@ -201,25 +196,25 @@ struct ArgMinMaxBase { }; struct SpecializedGenericArgMinMaxState { - static bool CreateExtraState(idx_t count) { + static bool CreateExtraState() { // nop extra state return false; } - static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) { + static void PrepareData(const Vector &by, bool &, UnifiedVectorFormat &result) { by.ToUnifiedFormat(result); } 
}; template struct GenericArgMinMaxState { - static Vector CreateExtraState(idx_t count) { - return Vector(LogicalType::BLOB, count); + static Vector CreateExtraState() { + return Vector(LogicalType::BLOB); } - static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { + static void PrepareData(const Vector &by, Vector &extra_state, UnifiedVectorFormat &result) { OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count); + CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers); extra_state.ToUnifiedFormat(result); } }; @@ -240,8 +235,8 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { using BY_TYPE = typename STATE::BY_TYPE; auto &by = inputs[1]; UnifiedVectorFormat bdata; - auto extra_state = UPDATE_TYPE::CreateExtraState(count); - UPDATE_TYPE::PrepareData(by, count, extra_state, bdata); + auto extra_state = UPDATE_TYPE::CreateExtraState(); + UPDATE_TYPE::PrepareData(by, extra_state, bdata); const auto bys = UnifiedVectorFormat::GetData(bdata); UnifiedVectorFormat sdata; @@ -309,7 +304,7 @@ struct VectorArgMinMaxBase : ArgMinMaxBase { // slice with a selection vector and generate sort keys SelectionVector sel(assign_sel, assign_count); Vector sliced_input(arg, sel, assign_count); - CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); + CreateSortKeyHelpers::CreateSortKey(sliced_input, modifiers, sort_key); auto sort_key_data = FlatVector::GetData(sort_key); // now assign sort keys @@ -665,11 +660,11 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp UnifiedVectorFormat n_format; UnifiedVectorFormat state_format; - auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); - auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count); + auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(); + auto 
arg_extra_state = STATE::ARG_TYPE::CreateExtraState(); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last); - STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last); + STATE::VAL_TYPE::PrepareData(val_vector, val_extra_state, val_format, bind_data.nulls_last); + STATE::ARG_TYPE::PrepareData(arg_vector, arg_extra_state, arg_format, bind_data.nulls_last); n_vector.ToUnifiedFormat(n_format); state_vector.ToUnifiedFormat(state_format); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp index 7c5a3845a..d84e7483b 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp @@ -1,10 +1,10 @@ #include "core_functions/aggregate/distributive_functions.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/operator/aggregate_operators.hpp" #include "duckdb/common/vector_operations/aggregate_executor.hpp" #include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/function/aggregate/distributive_function_utils.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" namespace duckdb { @@ -12,74 +12,41 @@ namespace { template struct BitState { - using TYPE = T; - bool is_set; + using value_type = T; + using STATE_TYPE = OptionalStateType; T value; + bool is_set; }; -template -LogicalType GetBitStateType(const BoundAggregateFunction &function) { - child_list_t child_types; - child_types.emplace_back("is_set", LogicalType::BOOLEAN); - - LogicalType value_type = function.GetReturnType(); - child_types.emplace_back("value", value_type); - - return LogicalType::STRUCT(std::move(child_types)); -} - -LogicalType 
GetBitStringStateType(const BoundAggregateFunction &function) { - child_list_t child_types; - child_types.emplace_back("is_set", LogicalType::BOOLEAN); - child_types.emplace_back("value", function.GetReturnType()); - return LogicalType::STRUCT(std::move(child_types)); -} - template AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) { switch (type.id()) { case LogicalTypeId::TINYINT: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); case LogicalTypeId::SMALLINT: - return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); case LogicalTypeId::INTEGER: - return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); case LogicalTypeId::BIGINT: - return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); case LogicalTypeId::HUGEINT: - return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); case LogicalTypeId::UTINYINT: - return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); case LogicalTypeId::USMALLINT: - return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); case 
LogicalTypeId::UINTEGER: - return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); case LogicalTypeId::UBIGINT: - return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); case LogicalTypeId::UHUGEINT: - return AggregateFunction::UnaryAggregate, uhugeint_t, uhugeint_t, OP>(type, type) - .SetStructStateExport(GetBitStateType); + return AggregateFunction::UnaryAggregate, uhugeint_t, uhugeint_t, OP>(type, type); default: throw InternalException("Unimplemented bitfield type for unary aggregate"); } } struct BitwiseOperation { - template - static void Initialize(STATE &state) { - // If there are no matching rows, returns a null value. - state.is_set = false; - } - template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { if (!state.is_set) { @@ -96,11 +63,6 @@ struct BitwiseOperation { OP::template Operation(state, input, unary_input); } - template - static void Assign(STATE &state, INPUT_TYPE input) { - state.value = typename STATE::TYPE(input); - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { if (!source.is_set) { @@ -109,13 +71,18 @@ struct BitwiseOperation { } if (!target.is_set) { // target is NULL, use source value directly. 
- OP::template Assign(target, source.value); + OP::template Assign(target, source.value); target.is_set = true; } else { - OP::template Execute(target, source.value); + OP::template Execute(target, source.value); } } + template + static void Assign(STATE &state, INPUT_TYPE input) { + state.value = typename STATE::value_type(input); + } + template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (!state.is_set) { @@ -130,26 +97,55 @@ struct BitwiseOperation { } }; -struct BitAndOperation : public BitwiseOperation { +template +struct NumericBitwiseOperation : public BitwiseOperation, public ClusteredStateCopy { template - static void Execute(STATE &state, INPUT_TYPE input) { - state.value &= typename STATE::TYPE(input); - ; + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input) { + if (!local.is_set) { + Assign(local, input); + local.is_set = true; + } else { + OP::template Execute(local, input); + } + } + + template + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input, idx_t count) { + if (count != 0) { + UpdateClusteredLocal(local, input); + } } }; -struct BitOrOperation : public BitwiseOperation { +template +struct SimpleBitwiseOperation : public NumericBitwiseOperation> { template static void Execute(STATE &state, INPUT_TYPE input) { - state.value |= typename STATE::TYPE(input); - ; + state.value = + OP::template Operation(state.value, typename STATE::value_type(input)); } }; -struct BitXorOperation : public BitwiseOperation { +using BitAndOperation = SimpleBitwiseOperation; +using BitOrOperation = SimpleBitwiseOperation; + +struct BitXorOperation : public NumericBitwiseOperation { + using NumericBitwiseOperation::UpdateClusteredLocal; + template static void Execute(STATE &state, INPUT_TYPE input) { - state.value ^= typename STATE::TYPE(input); + state.value = + BitXor::template Operation(state.value, typename STATE::value_type(input)); + } + + template + static void 
UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input, idx_t count) { + if ((count & 1) != 0) { + NumericBitwiseOperation::template UpdateClusteredLocal(local, input); + } else if (count != 0 && !local.is_set) { + local.value = typename STATE::value_type(0); + local.is_set = true; + } } template @@ -161,6 +157,8 @@ struct BitXorOperation : public BitwiseOperation { } }; +using BitStringState = BitState; + struct BitStringBitwiseOperation : public BitwiseOperation { template static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { @@ -229,12 +227,9 @@ AggregateFunctionSet BitAndFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_and.AddFunction(GetBitfieldUnaryAggregate(type)); } - - auto bit_string_fun = - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringAndOperation>( - LogicalType::BIT, LogicalType::BIT); - bit_string_fun.SetStructStateExport(GetBitStringStateType); - bit_and.AddFunction(bit_string_fun); + bit_and.AddFunction( + AggregateFunction::UnaryAggregateDestructor( + LogicalType::BIT, LogicalType::BIT)); return bit_and; } @@ -243,11 +238,9 @@ AggregateFunctionSet BitOrFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_or.AddFunction(GetBitfieldUnaryAggregate(type)); } - auto bit_string_fun = - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringOrOperation>( - LogicalType::BIT, LogicalType::BIT); - bit_string_fun.SetStructStateExport(GetBitStringStateType); - bit_or.AddFunction(bit_string_fun); + bit_or.AddFunction( + AggregateFunction::UnaryAggregateDestructor( + LogicalType::BIT, LogicalType::BIT)); return bit_or; } @@ -256,11 +249,9 @@ AggregateFunctionSet BitXorFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_xor.AddFunction(GetBitfieldUnaryAggregate(type)); } - auto bit_string_fun = - AggregateFunction::UnaryAggregateDestructor, string_t, string_t, BitStringXorOperation>( - LogicalType::BIT, LogicalType::BIT); - 
bit_string_fun.SetStructStateExport(GetBitStringStateType); - bit_xor.AddFunction(bit_string_fun); + bit_xor.AddFunction( + AggregateFunction::UnaryAggregateDestructor( + LogicalType::BIT, LogicalType::BIT)); return bit_xor; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp index bcbeb4f74..2c2d4837e 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -67,11 +67,6 @@ struct BitstringAggBindData : public FunctionData { struct BitStringAggOperation { static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits - template - static void Initialize(STATE &state) { - state.is_set = false; - } - template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { auto &bind_agg_data = unary_input.input.bind_data->template Cast(); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp index 64938c759..a39e13f2b 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp @@ -1,91 +1,50 @@ #include "core_functions/aggregate/distributive_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/function/function_set.hpp" +#include "duckdb/common/operator/aggregate_operators.hpp" namespace duckdb { namespace { struct BoolState { - bool empty; - bool val; + using STATE_TYPE = OptionalStateType; + bool value; + bool is_set; }; -struct BoolAndFunFunction { - template - static void Initialize(STATE &state) { - state.val = true; - 
state.empty = true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val = target.val && source.val; - target.empty = target.empty && source.empty; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } +template +struct BoolAggregate { + // No Initialize: StateInitialize falls back to memset(state, 0) = nullopt template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.empty = false; - state.val = input && state.val; + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + state.value = REDUCE_OP::template Operation(state.is_set ? state.value : INIT_VALUE, bool(input)); + state.is_set = true; } template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { - for (idx_t i = 0; i < count; i++) { + if (count > 0) { Operation(state, input, unary_input); } } - static bool IgnoreNull() { - return true; - } -}; - -struct BoolOrFunFunction { - template - static void Initialize(STATE &state) { - state.val = false; - state.empty = true; - } template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val = target.val || source.val; - target.empty = target.empty && source.empty; + if (!source.is_set) { + return; + } + target.value = REDUCE_OP::template Operation(target.is_set ? 
target.value : INIT_VALUE, source.value); + target.is_set = true; } template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { + if (!state.is_set) { finalize_data.ReturnNull(); return; } - target = state.val; - } - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.empty = false; - state.val = input || state.val; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } + target = state.value; } static bool IgnoreNull() { @@ -93,12 +52,8 @@ struct BoolOrFunFunction { } }; -LogicalType GetBoolAndStateType(const BoundAggregateFunction &function) { - child_list_t child_types; - child_types.emplace_back("empty", LogicalType::BOOLEAN); - child_types.emplace_back("val", LogicalType::BOOLEAN); - return LogicalType::STRUCT(std::move(child_types)); -} +using BoolAndFunFunction = BoolAggregate; +using BoolOrFunFunction = BoolAggregate; } // namespace @@ -107,7 +62,7 @@ AggregateFunction BoolOrFun::GetFunction() { LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); - return fun.SetStructStateExport(GetBoolAndStateType); + return fun; } AggregateFunction BoolAndFun::GetFunction() { @@ -115,7 +70,7 @@ AggregateFunction BoolAndFun::GetFunction() { LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); fun.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); - return fun.SetStructStateExport(GetBoolAndStateType); + return fun; } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp 
b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp index 95eb5a2e8..6f023641f 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp @@ -9,6 +9,9 @@ namespace duckdb { namespace { struct KurtosisState { + static constexpr const char *STATE_NAMES[] = {"n", "sum", "sum_sqr", "sum_cub", "sum_four"}; + using STATE_TYPE = StructStateType; + idx_t n; double sum; double sum_sqr; @@ -22,12 +25,6 @@ struct KurtosisFlagNoBiasCorrection {}; template struct KurtosisOperation { - template - static void Initialize(STATE &state) { - state.n = 0; - state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { @@ -100,16 +97,6 @@ struct KurtosisOperation { } }; -LogicalType GetKurtosisStateType(const BoundAggregateFunction &function) { - child_list_t children; - children.emplace_back("n", LogicalType::UBIGINT); - children.emplace_back("sum", LogicalType::DOUBLE); - children.emplace_back("sum_sqr", LogicalType::DOUBLE); - children.emplace_back("sum_cub", LogicalType::DOUBLE); - children.emplace_back("sum_four", LogicalType::DOUBLE); - return LogicalType::STRUCT(std::move(children)); -} - } // namespace AggregateFunction KurtosisFun::GetFunction() { @@ -117,7 +104,6 @@ AggregateFunction KurtosisFun::GetFunction() { AggregateFunction::UnaryAggregate>( LogicalType::DOUBLE, LogicalType::DOUBLE); result.SetFallible(); - result.SetStructStateExport(GetKurtosisStateType); return result; } @@ -126,7 +112,6 @@ AggregateFunction KurtosisPopFun::GetFunction() { KurtosisOperation>( LogicalType::DOUBLE, LogicalType::DOUBLE); result.SetFallible(); - result.SetStructStateExport(GetKurtosisStateType); return result; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp 
b/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp index 3bfa40af2..424d87b4b 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp @@ -1,7 +1,9 @@ #include "core_functions/aggregate/distributive_functions.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/multiply.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/function/function_set.hpp" namespace duckdb { @@ -9,65 +11,36 @@ namespace duckdb { namespace { struct ProductState { + static constexpr const char *STATE_NAMES[] = {"empty", "val"}; + using STATE_TYPE = StructStateType; + bool empty; double val; }; -struct ProductFunction { - template - static void Initialize(STATE &state) { - state.val = 1; - state.empty = true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val *= source.val; - target.empty = target.empty && source.empty; +struct ProductReduce { + template + static T Operation(T left, T right) { + return MultiplyOperator::template Operation(left, right); } +}; - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (state.empty) { - state.empty = false; - } - state.val *= input; - } +struct ProductFunction : public EmptyValAggregate> { + using EmptyValAggregate>::UpdateClusteredLocal; - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { + template + static void UpdateClusteredLocal(STATE &local, 
const INPUT_TYPE &input, idx_t count) { for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); + EmptyValAggregate>::template UpdateClusteredLocal(local, input); } } - - static bool IgnoreNull() { - return true; - } }; -LogicalType GetProductStateType(const BoundAggregateFunction &function) { - child_list_t children; - children.emplace_back("empty", LogicalType::BOOLEAN); - children.emplace_back("val", LogicalType::DOUBLE); - return LogicalType::STRUCT(std::move(children)); -} - } // namespace AggregateFunction ProductFun::GetFunction() { return AggregateFunction::UnaryAggregate( - LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE) - .SetStructStateExport(GetProductStateType); + LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp index 2cca496e7..720d60dbc 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp @@ -3,25 +3,23 @@ #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/common/algorithm.hpp" +#include namespace duckdb { namespace { struct SkewState { - size_t n; + static constexpr const char *STATE_NAMES[] = {"n", "sum", "sum_sqr", "sum_cub"}; + using STATE_TYPE = StructStateType; + + idx_t n; double sum; double sum_sqr; double sum_cub; }; struct SkewnessOperation { - template - static void Initialize(STATE &state) { - state.n = 0; - state.sum = state.sum_sqr = state.sum_cub = 0; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { @@ -58,15 +56,21 @@ struct SkewnessOperation { } double n = state.n; double temp = 1 / n; - auto p = std::pow(temp * (state.sum_sqr - state.sum * 
state.sum * temp), 3); - if (p < 0) { - p = 0; // Shouldn't be below 0 but floating points are weird + double raw_m2 = state.sum_sqr - state.sum * state.sum * temp; + // Only treat finite second-moment noise as zero; overflow should still surface as out-of-range. + if (Value::DoubleIsFinite(raw_m2) && Value::DoubleIsFinite(state.sum_sqr) && + // Scale the tolerance by the accumulated squared magnitude instead of using a fixed epsilon. + std::abs(raw_m2) <= std::numeric_limits::epsilon() * std::max(1.0, std::abs(state.sum_sqr))) { + finalize_data.ReturnNull(); + return; } - double div = std::sqrt(p); - if (div == 0) { - target = NAN; + double variance = temp * raw_m2; + if (variance <= 0) { + finalize_data.ReturnNull(); return; } + auto p = std::pow(variance, 3); + double div = std::sqrt(p); double temp1 = std::sqrt(n * (n - 1)) / (n - 2); target = temp1 * temp * (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div; diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp index b80e1e4b0..84663c72c 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp @@ -13,9 +13,9 @@ namespace duckdb { namespace { struct StringAggState { - idx_t size; - idx_t alloc_size; - char *dataptr; + string_t value; + bool is_set; + uint32_t alloc_size; }; struct StringAggBindData : public FunctionData { @@ -34,19 +34,12 @@ struct StringAggBindData : public FunctionData { }; struct StringAggFunction { - template - static void Initialize(STATE &state) { - state.dataptr = nullptr; - state.alloc_size = 0; - state.size = 0; - } - template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.dataptr) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { - target = 
string_t(state.dataptr, state.size); + target = state.value; } } @@ -56,31 +49,41 @@ struct StringAggFunction { static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, const char *str, const char *sep, idx_t str_size, idx_t sep_size) { - if (!state.dataptr) { - // first iteration: allocate space for the string and copy it into the state - state.alloc_size = MaxValue(8, NextPowerOfTwo(str_size)); - state.dataptr = char_ptr_cast(allocator.Allocate(state.alloc_size)); - state.size = str_size; - memcpy(state.dataptr, str, str_size); + idx_t new_size; + auto current_size = state.value.GetSize(); + if (!state.is_set) { + new_size = str_size; } else { - // subsequent iteration: first check if we have space to place the string and separator - idx_t required_size = state.size + str_size + sep_size; - if (required_size > state.alloc_size) { - // no space! allocate extra space - const auto old_size = state.alloc_size; - while (state.alloc_size < required_size) { - state.alloc_size *= 2; - } - state.dataptr = - char_ptr_cast(allocator.Reallocate(data_ptr_cast(state.dataptr), old_size, state.alloc_size)); + new_size = current_size + sep_size + str_size; + } + char *target_data; + if (new_size > state.alloc_size && new_size > string_t::INLINE_LENGTH) { + // need to (re-)allocate space + if (new_size > NumericLimits::Maximum()) { + throw InvalidInputException("string_agg string size exceeds maximum string size"); + } + state.alloc_size = NextPowerOfTwo(new_size); + target_data = char_ptr_cast(allocator.Allocate(state.alloc_size)); + if (current_size > 0) { + // copy over the current data + memcpy(target_data, state.value.GetData(), current_size); } + state.value = string_t(target_data, new_size); + } else { + target_data = state.value.GetDataWriteable(); + } + if (!state.is_set) { + memcpy(target_data, str, str_size); + state.is_set = true; + } else { // copy the separator - memcpy(state.dataptr + state.size, sep, sep_size); - state.size += sep_size; 
+ target_data += current_size; + memcpy(target_data, sep, sep_size); // copy the string - memcpy(state.dataptr + state.size, str, str_size); - state.size += str_size; + target_data += sep_size; + memcpy(target_data, str, str_size); } + state.value.SetSizeAndFinalize(new_size, state.alloc_size); } static inline void PerformOperation(StringAggState &state, ArenaAllocator &allocator, string_t str, @@ -104,12 +107,11 @@ struct StringAggFunction { template static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - if (!source.dataptr) { + if (!source.is_set) { // source is not set: skip combining return; } - PerformOperation(target, aggr_input_data.allocator, - string_t(source.dataptr, UnsafeNumericCast(source.size)), aggr_input_data.bind_data); + PerformOperation(target, aggr_input_data.allocator, source.value, aggr_input_data.bind_data); } }; @@ -167,7 +169,7 @@ AggregateFunctionSet StringAggFun::GetFunctions() { AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, - AggregateFunction::UnaryUpdate, StringAggBind); + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), StringAggBind); string_agg_param.SetSerializeCallback(StringAggSerialize); string_agg_param.SetDeserializeCallback(StringAggDeserialize); string_agg.AddFunction(string_agg_param); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp index b86db63c5..9fa75794e 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/bignum.hpp" #include "duckdb/common/types/decimal.hpp" +#include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include 
"duckdb/common/serializer/deserializer.hpp" @@ -12,60 +13,185 @@ namespace { struct SumSetOperation { template - static void Initialize(STATE &state) { - state.Initialize(); + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + if (!source.is_set) { + return; + } + if (!target.is_set) { + target.value = source.value; + target.is_set = true; + } else { + target.value += source.value; + } } + template + static void AddValues(STATE &state, idx_t count) { + state.is_set = true; + } +}; + +struct KahanSumSetOperation { template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target.Combine(source); } template static void AddValues(STATE &state, idx_t count) { - state.isset = true; + state.is_set = true; } }; -struct IntegerSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = Hugeint::Convert(state.value); +template +struct ClusteredSumStateCopy : public BASE, public ClusteredStateCopy { + template + static void FlushClusteredLocal(STATE &state, STATE &local, bool saw_value) { + if (saw_value) { + local.is_set = true; } + state = local; + } +}; + +template +struct ClusteredAddOp { + template + static void Execute(STATE &local, const INPUT_TYPE &input) { + ADD_OP::template AddNumber(local, input); + } + + template + static void Execute(STATE &local, const INPUT_TYPE &input, idx_t count) { + ADD_OP::template AddConstant(local, input, count); } }; -struct SumToHugeintOperation : public BaseSumOperation { +template +struct ClusteredSumOperation : public ClusteredSumStateCopy { + static constexpr bool ClusteredI64HugeintSum = std::is_same>::value; + + template + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input) { + LOCAL_OP::template Execute(local, input); + } + + template + static void UpdateClusteredLocal(STATE &local, const 
INPUT_TYPE &input, idx_t count) { + LOCAL_OP::template Execute(local, input, count); + } + template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { target = state.value; } } + + template + static void ExecuteFlatI64HugeintSum(const INPUT_TYPE *vals, const ClusteredAggr &clustered, + const SelectionVector *isel, const sel_t *cluster_iter) { + idx_t pos = 0; + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + const auto *run_sel = clustered.group_runs[r].sel; + const idx_t run_count = clustered.group_runs[r].count; + if (run_count == 0) { + continue; + } + state.is_set = true; + int64_t local64 = 0; + auto add_row = [&](idx_t idx) { + const int64_t v = static_cast(vals[idx]); + if (DUCKDB_UNLIKELY(!I64VectorSumSafe(v))) { + state.value = Hugeint::Add(state.value, v); + } else { + local64 += v; + } + }; + if (cluster_iter) { + for (idx_t k = 0; k < run_count; k++) { + add_row(cluster_iter[pos + k]); + } + } else if (isel && run_sel) { + for (idx_t k = 0; k < run_count; k++) { + add_row(isel->get_index(run_sel[k])); + } + } else if (isel) { + for (idx_t k = 0; k < run_count; k++) { + add_row(isel->get_index(k)); + } + } else if (run_sel) { + for (idx_t k = 0; k < run_count; k++) { + add_row(run_sel[k]); + } + } else { + for (idx_t k = 0; k < run_count; k++) { + add_row(k); + } + } + pos += run_count; + state.value = Hugeint::Add(state.value, local64); + } + } + + template + static void ExecuteDictI64HugeintSum(Vector &input, const ClusteredAggr &clustered, const int64_t *safe_vals) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + const auto *dict_sel = idata.sel->data(); + auto &validity = idata.validity; + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + const auto *run_sel = 
clustered.group_runs[r].sel; + const idx_t run_count = clustered.group_runs[r].count; + int64_t local64 = 0; + if (run_sel) { + for (idx_t k = 0; k < run_count; k++) { + local64 += safe_vals[dict_sel[run_sel[k]]]; + } + } else { + for (idx_t k = 0; k < run_count; k++) { + local64 += safe_vals[dict_sel[k]]; + } + } + if (local64 != 0) { + state.is_set = true; + state.value = Hugeint::Add(state.value, local64); + } else if (!state.is_set) { // rare: we added 0 -- were all values NULL? + for (idx_t k = 0; k < run_count; k++) { + const idx_t i = dict_sel[run_sel ? run_sel[k] : k]; + if (validity.RowIsValidUnsafe(i)) { // we added non-NULL + state.is_set = true; + break; + } + } + } + } + } }; -template -struct DoubleSumOperation : public BaseSumOperation { +struct IntegerSumOperation + : public ClusteredSumOperation, ClusteredAddOp> { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { - target = state.value; + target = Hugeint::Convert(state.value); } } }; -using NumericSumOperation = DoubleSumOperation; -using KahanSumOperation = DoubleSumOperation; +using SumToHugeintOperation = + ClusteredSumOperation, ClusteredAddOp>; +using NumericSumOperation = + ClusteredSumOperation, ClusteredAddOp>; -struct HugeintSumOperation : public BaseSumOperation { +struct KahanSumOperation : public BaseSumOperation { template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { target = state.value; @@ -73,36 +199,8 @@ struct HugeintSumOperation : public BaseSumOperation -static LogicalType GetValueLogicalType(); - -template <> -LogicalType GetValueLogicalType() { - return LogicalType::BIGINT; -} -template <> -LogicalType GetValueLogicalType() { - return LogicalType::HUGEINT; -} -template <> -LogicalType GetValueLogicalType() { - return 
LogicalType::DOUBLE; -} - -template -LogicalType GetSumStateType(const BoundAggregateFunction &function) { - child_list_t child_types; - child_types.emplace_back("isset", LogicalType::BOOLEAN); - - LogicalType value_type = GetValueLogicalType(); - // Use the return type when its physical representation matches the state type - if (function.GetReturnType().InternalType() == value_type.InternalType()) { - value_type = function.GetReturnType(); - } - child_types.emplace_back("value", value_type); - - return LogicalType::STRUCT(std::move(child_types)); -} +using HugeintSumOperation = + ClusteredSumOperation, ClusteredAddOp>; unique_ptr SumNoOverflowBind(BindAggregateFunctionInput &input) { throw BinderException("sum_no_overflow is for internal use only!"); @@ -123,22 +221,22 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { case PhysicalType::INT32: { auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( LogicalType::INTEGER, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; + function.SetName("sum_no_overflow"); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); function.SetBindCallback(SumNoOverflowBind); function.SetSerializeCallback(SumNoOverflowSerialize); function.SetDeserializeCallback(SumNoOverflowDeserialize); - return function.SetStructStateExport(GetSumStateType); + return function; } case PhysicalType::INT64: { auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( LogicalType::BIGINT, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; + function.SetName("sum_no_overflow"); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); function.SetBindCallback(SumNoOverflowBind); function.SetSerializeCallback(SumNoOverflowSerialize); function.SetDeserializeCallback(SumNoOverflowDeserialize); - return function.SetStructStateExport(GetSumStateType); + return function; } default: throw BinderException("Unsupported 
internal type for sum_no_overflow"); @@ -147,7 +245,8 @@ AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { AggregateFunction GetSumAggregateNoOverflowDecimal() { AggregateFunction aggr({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, - nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind); + nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), + SumNoOverflowBind); aggr.SetSerializeCallback(SumNoOverflowSerialize); aggr.SetDeserializeCallback(SumNoOverflowDeserialize); return aggr; @@ -183,7 +282,7 @@ unique_ptr SumPropagateStats(ClientContext &context, BoundAggreg return nullptr; } // total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum - expr.function.ReplaceImplementation(GetSumAggregateNoOverflow(internal_type)); + expr.FunctionMutable().ReplaceImplementation(GetSumAggregateNoOverflow(internal_type)); } return nullptr; } @@ -194,13 +293,13 @@ AggregateFunction GetSumAggregate(PhysicalType type) { auto function = AggregateFunction::UnaryAggregate, bool, hugeint_t, IntegerSumOperation>( LogicalType::BOOLEAN, LogicalType::HUGEINT); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); - return function.SetStructStateExport(GetSumStateType); + return function; } case PhysicalType::INT16: { auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( LogicalType::SMALLINT, LogicalType::HUGEINT); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); - return function.SetStructStateExport(GetSumStateType); + return function; } case PhysicalType::INT32: { @@ -209,7 +308,7 @@ AggregateFunction GetSumAggregate(PhysicalType type) { LogicalType::INTEGER, LogicalType::HUGEINT); function.SetStatisticsCallback(SumPropagateStats); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); - return 
function.SetStructStateExport(GetSumStateType); + return function; } case PhysicalType::INT64: { auto function = @@ -217,14 +316,14 @@ AggregateFunction GetSumAggregate(PhysicalType type) { LogicalType::BIGINT, LogicalType::HUGEINT); function.SetStatisticsCallback(SumPropagateStats); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); - return function.SetStructStateExport(GetSumStateType); + return function; } case PhysicalType::INT128: { auto function = AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( LogicalType::HUGEINT, LogicalType::HUGEINT); function.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); - return function.SetStructStateExport(GetSumStateType); + return function; } default: throw InternalException("Unimplemented sum aggregate"); @@ -249,11 +348,6 @@ struct BignumState { }; struct BignumOperation { - template - static void Initialize(STATE &state) { - state.is_set = false; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { @@ -312,8 +406,7 @@ AggregateFunctionSet SumFun::GetFunctions() { sum.AddFunction(GetSumAggregate(PhysicalType::INT64)); sum.AddFunction(GetSumAggregate(PhysicalType::INT128)); sum.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericSumOperation>( - LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetSumStateType)); + LogicalType::DOUBLE, LogicalType::DOUBLE)); sum.AddFunction(AggregateFunction::UnaryAggregate( LogicalType::BIGNUM, LogicalType::BIGNUM)); return sum; @@ -331,18 +424,9 @@ AggregateFunctionSet SumNoOverflowFun::GetFunctions() { return sum_no_overflow; } -LogicalType GetKahanSumStateType(const BoundAggregateFunction &function) { - child_list_t children; - children.emplace_back("isset", LogicalType::BOOLEAN); - children.emplace_back("value", LogicalType::DOUBLE); - children.emplace_back("err", LogicalType::DOUBLE); - return 
LogicalType::STRUCT(std::move(children)); -} - AggregateFunction KahanSumFun::GetFunction() { return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE) - .SetStructStateExport(GetKahanSumStateType); + LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp index 5f77eb9f7..ba1030fe3 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -192,14 +192,9 @@ struct ApproxTopKState { }; struct ApproxTopKOperation { - template - static void Initialize(STATE &state) { - state.state = nullptr; - } - template - static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector, - idx_t offset, idx_t count) { + static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, + const Vector &top_k_vector, idx_t offset, idx_t count) { auto &state = aggr_state.GetState(); if (state.values.empty()) { static constexpr int64_t MAX_APPROX_K = 1000000; @@ -321,11 +316,11 @@ void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp using STATE = ApproxTopKState; auto &input = inputs[0]; - auto &top_k_vector = inputs[1]; + const auto &top_k_vector = inputs[1]; - auto extra_state = OP::CreateExtraState(count); + auto extra_state = OP::CreateExtraState(); UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); + OP::PrepareData(input, extra_state, input_data); auto states = state_vector.Values(); auto data = UnifiedVectorFormat::GetData(input_data); diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp index f5e7d118a..90aec94a9 100644 --- 
a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -103,12 +103,6 @@ struct ApproximateQuantileBindData : public FunctionData { struct ApproxQuantileOperation { using SAVE_TYPE = duckdb_tdigest::Value; - template - static void Initialize(STATE &state) { - state.pos = 0; - state.h = nullptr; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { @@ -273,7 +267,7 @@ unique_ptr BindApproxQuantile(BindAggregateFunctionInput &input) { AggregateFunction ApproxQuantileDecimalFunction(const LogicalType &type) { auto function = GetApproximateQuantileDecimalAggregateFunction(type); - function.name = "approx_quantile"; + function.SetName("approx_quantile"); function.SetSerializeCallback(ApproximateQuantileBindData::Serialize); function.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); return function; @@ -337,8 +331,8 @@ AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, con return AggregateFunction( {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); + AggregateFunction::StateFinalize, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), AggregateFunction::NoBind(), AggregateFunction::StateDestroy); } template @@ -394,7 +388,7 @@ AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) { auto function = GetApproxQuantileListAggregateFunction(type); - function.name = "approx_quantile"; + function.SetName("approx_quantile"); 
function.SetSerializeCallback(ApproximateQuantileBindData::Serialize); function.SetDeserializeCallback(ApproximateQuantileBindData::Deserialize); return function; diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp index 9079a0aa9..098f64330 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp @@ -341,7 +341,8 @@ unique_ptr BindMedianAbsoluteDeviationDecimal(BindAggregateFunctio AggregateFunctionSet MadFun::GetFunctions() { AggregateFunctionSet mad("mad"); mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindMedianAbsoluteDeviationDecimal)); + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), BindMedianAbsoluteDeviationDecimal)); const vector MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp index 798e44e84..11ecfabe6 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp @@ -452,7 +452,8 @@ AggregateFunction GetFallbackModeFunction(const LogicalType &type) { AggregateFunction aggr({type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr); + AggregateFunction::StateVoidFinalize, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate()); aggr.SetStateDestructorCallback(AggregateFunction::StateDestroy); return aggr; } @@ -518,7 +519,8 @@ unique_ptr 
BindModeAggregate(BindAggregateFunctionInput &input) { AggregateFunctionSet ModeFun::GetFunctions() { AggregateFunctionSet mode("mode"); mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, BindModeAggregate)); + nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), BindModeAggregate)); return mode; } @@ -574,7 +576,8 @@ AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) { AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, nullptr); + AggregateFunction::StateFinalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate()); func.SetStateDestructorCallback(AggregateFunction::StateDestroy); func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; @@ -620,7 +623,8 @@ unique_ptr BindEntropyAggregate(BindAggregateFunctionInput &input) AggregateFunctionSet EntropyFun::GetFunctions() { AggregateFunctionSet entropy("entropy"); entropy.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalType::DOUBLE, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, BindEntropyAggregate)); + nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), BindEntropyAggregate)); return entropy; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp index d0c655795..147ce9a11 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp @@ -440,8 +440,8 @@ static AggregateFunction QuantileListAggregate(const LogicalType &input_type, co {input_type}, result_type, AggregateFunction::StateSize, 
AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); + AggregateFunction::StateFinalize, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), AggregateFunction::NoBind(), AggregateFunction::StateDestroy); } struct ListDiscreteQuantile { @@ -672,7 +672,7 @@ static bool CanInterpolate(const LogicalType &type) { struct MedianFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = CanInterpolate(type) ? GetContinuousQuantile(type) : GetDiscreteQuantile(type); - fun.name = "median"; + fun.SetName("median"); fun.SetSerializeCallback(QuantileBindData::Serialize); fun.SetDeserializeCallback(Deserialize); return fun; @@ -697,7 +697,7 @@ struct MedianFunction { struct DiscreteQuantileListFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetDiscreteQuantileList(type); - fun.name = "quantile_disc"; + fun.SetName("quantile_disc"); fun.SetBindCallback(Bind); fun.SetSerializeCallback(QuantileBindData::Serialize); fun.SetDeserializeCallback(Deserialize); @@ -726,7 +726,7 @@ struct DiscreteQuantileListFunction { struct DiscreteQuantileFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetDiscreteQuantile(type); - fun.name = "quantile_disc"; + fun.SetName("quantile_disc"); fun.SetBindCallback(Bind); fun.SetSerializeCallback(QuantileBindData::Serialize); fun.SetDeserializeCallback(Deserialize); @@ -760,7 +760,7 @@ struct DiscreteQuantileFunction { struct ContinuousQuantileFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetContinuousQuantile(type); - fun.name = "quantile_cont"; + fun.SetName("quantile_cont"); fun.SetBindCallback(Bind); fun.SetSerializeCallback(QuantileBindData::Serialize); fun.SetDeserializeCallback(Deserialize); @@ -792,7 +792,7 
@@ struct ContinuousQuantileFunction { struct ContinuousQuantileListFunction { static AggregateFunction GetAggregate(const LogicalType &type) { auto fun = GetContinuousQuantileList(type); - fun.name = "quantile_cont"; + fun.SetName("quantile_cont"); fun.SetBindCallback(Bind); fun.SetSerializeCallback(QuantileBindData::Serialize); fun.SetDeserializeCallback(Deserialize); diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp index c98ab3d24..95fc63b02 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -92,14 +92,6 @@ struct ReservoirQuantileBindData : public FunctionData { }; struct ReservoirQuantileOperation { - template - static void Initialize(STATE &state) { - state.v = nullptr; - state.len = 0; - state.pos = 0; - state.r_samp = nullptr; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { @@ -251,8 +243,8 @@ AggregateFunction ReservoirQuantileListAggregate(const LogicalType &input_type, return AggregateFunction( {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); + AggregateFunction::StateFinalize, FunctionNullHandling::DEFAULT_NULL_HANDLING, + AggregateFunction::NoClusterUpdate(), AggregateFunction::NoBind(), AggregateFunction::StateDestroy); } template diff --git a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp index 3083145e7..f6ff7a597 100644 --- 
a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp @@ -21,11 +21,6 @@ struct HistogramBinState { unsafe_vector *bin_boundaries; unsafe_vector *counts; - void Initialize() { - bin_boundaries = nullptr; - counts = nullptr; - } - void Destroy() { if (bin_boundaries) { delete bin_boundaries; @@ -42,7 +37,7 @@ struct HistogramBinState { } template - void InitializeBins(Vector &bin_vector, idx_t count, idx_t pos, AggregateInputData &aggr_input) { + void InitializeBins(Vector &bin_vector, idx_t pos, AggregateInputData &aggr_input) { bin_boundaries = new unsafe_vector(); counts = new unsafe_vector(); auto bin_counts = bin_vector.Values(); @@ -53,10 +48,9 @@ struct HistogramBinState { auto bin_list = bin_entry.GetValue(); auto &bin_child = ListVector::GetChildMutable(bin_vector); - auto bin_count = ListVector::GetListSize(bin_vector); UnifiedVectorFormat bin_child_data; - auto extra_state = OP::CreateExtraState(bin_count); - OP::PrepareData(bin_child, bin_count, extra_state, bin_child_data); + auto extra_state = OP::CreateExtraState(); + OP::PrepareData(bin_child, extra_state, bin_child_data); bin_boundaries->reserve(bin_list.length); for (idx_t i = 0; i < bin_list.length; i++) { @@ -81,11 +75,6 @@ struct HistogramBinState { }; struct HistogramBinFunction { - template - static void Initialize(STATE &state) { - state.Initialize(); - } - template static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { state.Destroy(); @@ -155,9 +144,9 @@ void HistogramBinUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, auto &input = inputs[0]; auto &bin_vector = inputs[1]; - auto extra_state = OP::CreateExtraState(count); + auto extra_state = OP::CreateExtraState(); UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); + OP::PrepareData(input, extra_state, input_data); auto states = state_vector.Values *>(); auto data = 
UnifiedVectorFormat::GetData(input_data); @@ -168,7 +157,7 @@ void HistogramBinUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, } auto &state = *states[i].GetValue(); if (!state.IsSet()) { - state.template InitializeBins(bin_vector, count, i, aggr_input); + state.template InitializeBins(bin_vector, i, aggr_input); } auto bin_entry = HIST::template GetBin(data[idx], *state.bin_boundaries); ++(*state.counts)[bin_entry]; @@ -265,7 +254,7 @@ void IsHistogramOtherBinFunction(DataChunk &args, ExpressionState &state, Vector } auto v = OtherBucketValue(input_type); Vector ref(v, count_t(args.size())); - VectorOperations::NotDistinctFrom(args.data[0], ref, result, args.size()); + VectorOperations::NotDistinctFrom(args.data[0], ref, result); // Set NULL if input is NULL. UnifiedVectorFormat input_data; diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp index 09a3167b5..298ce0b8d 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp @@ -13,11 +13,6 @@ namespace duckdb { namespace { template struct HistogramFunction { - template - static void Initialize(STATE &state) { - state.hist = nullptr; - } - template static void Destroy(STATE &state, AggregateInputData &) { if (state.hist) { @@ -68,9 +63,9 @@ void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, id auto &input = inputs[0]; - auto extra_state = OP::CreateExtraState(count); + auto extra_state = OP::CreateExtraState(); UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); + OP::PrepareData(input, extra_state, input_data); auto states = state_vector.Values *>(); auto input_values = UnifiedVectorFormat::GetData(input_data); diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp 
b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp index 64576cf33..f56481c01 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -36,12 +36,6 @@ struct ListAggState { }; struct ListFunction { - template - static void Initialize(STATE &state) { - state.linked_list.total_capacity = 0; - state.linked_list.first_segment = nullptr; - state.linked_list.last_segment = nullptr; - } static bool IgnoreNull() { return false; } diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp index 0f82395e3..2679a6fa9 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp @@ -9,17 +9,14 @@ namespace duckdb { namespace { struct RegrState { + static constexpr const char *STATE_NAMES[] = {"sum", "count"}; + using STATE_TYPE = StructStateType; + double sum; uint64_t count; }; struct RegrAvgFunction { - template - static void Initialize(STATE &state) { - state.sum = 0; - state.count = 0; - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target.sum += source.sum; @@ -54,25 +51,16 @@ struct RegrAvgYFunction : RegrAvgFunction { } }; -LogicalType GetRegrAvgStateType(const BoundAggregateFunction &) { - child_list_t child_types; - child_types.emplace_back("sum", LogicalType::DOUBLE); - child_types.emplace_back("count", LogicalType::UBIGINT); - return LogicalType::STRUCT(std::move(child_types)); -} - } // namespace AggregateFunction RegrAvgxFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrAvgStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } AggregateFunction RegrAvgyFun::GetFunction() { return 
AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrAvgStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp index ec464e7fa..05b285fe2 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp @@ -1,28 +1,17 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "core_functions/aggregate/regression_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "core_functions/aggregate/regression/regr_count.hpp" #include "duckdb/function/function_set.hpp" namespace duckdb { -namespace { - -LogicalType GetRegrCountStateType(const BoundAggregateFunction &) { - child_list_t child_types; - child_types.emplace_back("count", LogicalType::UBIGINT); - return LogicalType::STRUCT(std::move(child_types)); -} - -} // namespace - AggregateFunction RegrCountFun::GetFunction() { auto regr_count = AggregateFunction::BinaryAggregate( LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); - regr_count.name = "regr_count"; + regr_count.SetName("regr_count"); regr_count.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - return regr_count.SetStructStateExport(GetRegrCountStateType); + return regr_count; } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp index d93b540ed..d115a242d 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp @@ -8,6 
+8,9 @@ namespace duckdb { namespace { struct RegrInterceptState { + static constexpr const char *STATE_NAMES[] = {"count", "sum_x", "sum_y", "slope"}; + using STATE_TYPE = StructStateType; + uint64_t count; double sum_x; double sum_y; @@ -15,14 +18,6 @@ struct RegrInterceptState { }; struct RegrInterceptOperation { - template - static void Initialize(STATE &state) { - state.count = 0; - state.sum_x = 0; - state.sum_y = 0; - RegrSlopeOperation::Initialize(state.slope); - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { state.count++; @@ -60,37 +55,11 @@ struct RegrInterceptOperation { } }; -LogicalType GetRegrInterceptStateType(const BoundAggregateFunction &) { - child_list_t covpop_children; - covpop_children.emplace_back("count", LogicalType::UBIGINT); - covpop_children.emplace_back("meanx", LogicalType::DOUBLE); - covpop_children.emplace_back("meany", LogicalType::DOUBLE); - covpop_children.emplace_back("co_moment", LogicalType::DOUBLE); - auto covpop_type = LogicalType::STRUCT(std::move(covpop_children)); - - child_list_t varpop_children; - varpop_children.emplace_back("count", LogicalType::UBIGINT); - varpop_children.emplace_back("mean", LogicalType::DOUBLE); - varpop_children.emplace_back("dsquared", LogicalType::DOUBLE); - auto varpop_type = LogicalType::STRUCT(std::move(varpop_children)); - - child_list_t state_children; - state_children.emplace_back("count", LogicalType::UBIGINT); - state_children.emplace_back("sum_x", LogicalType::DOUBLE); - state_children.emplace_back("sum_y", LogicalType::DOUBLE); - child_list_t slope_children; - slope_children.emplace_back("cov_pop", std::move(covpop_type)); - slope_children.emplace_back("var_pop", std::move(varpop_type)); - state_children.emplace_back("slope", LogicalType::STRUCT(std::move(slope_children))); - return LogicalType::STRUCT(std::move(state_children)); -} - } // namespace AggregateFunction RegrInterceptFun::GetFunction() { return 
AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrInterceptStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp index 7f6407cad..2643a4962 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp @@ -13,19 +13,15 @@ namespace duckdb { namespace { struct RegrR2State { + static constexpr const char *STATE_NAMES[] = {"corr", "var_pop_x", "var_pop_y"}; + using STATE_TYPE = StructStateType; + CorrState corr; StddevState var_pop_x; StddevState var_pop_y; }; struct RegrR2Operation { - template - static void Initialize(STATE &state) { - CorrOperation::Initialize(state.corr); - STDDevBaseOperation::Initialize(state.var_pop_x); - STDDevBaseOperation::Initialize(state.var_pop_y); - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { CorrOperation::Operation(state.corr, y, x, idata); @@ -61,39 +57,11 @@ struct RegrR2Operation { } }; -LogicalType GetRegrR2StateType(const BoundAggregateFunction &) { - child_list_t covar_children; - covar_children.emplace_back("count", LogicalType::UBIGINT); - covar_children.emplace_back("meanx", LogicalType::DOUBLE); - covar_children.emplace_back("meany", LogicalType::DOUBLE); - covar_children.emplace_back("co_moment", LogicalType::DOUBLE); - auto cov_pop_type = LogicalType::STRUCT(std::move(covar_children)); - - child_list_t stddev_types; - stddev_types.emplace_back("count", LogicalType::UBIGINT); - stddev_types.emplace_back("mean", LogicalType::DOUBLE); - stddev_types.emplace_back("dsquared", LogicalType::DOUBLE); - auto stddev_type = LogicalType::STRUCT(std::move(stddev_types)); - - child_list_t 
corr_children; - corr_children.emplace_back("cov_pop", std::move(cov_pop_type)); - corr_children.emplace_back("dev_pop_x", stddev_type); - corr_children.emplace_back("dev_pop_y", stddev_type); - auto corr_state = LogicalType::STRUCT(std::move(corr_children)); - - child_list_t state_children; - state_children.emplace_back("corr", corr_state); - state_children.emplace_back("var_pop_x", stddev_type); - state_children.emplace_back("var_pop_y", stddev_type); - return LogicalType::STRUCT(std::move(state_children)); -} - } // namespace AggregateFunction RegrR2Fun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrR2StateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp index 70f18cda3..c00201fd5 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp @@ -12,34 +12,9 @@ namespace duckdb { -namespace { - -LogicalType GetRegrSlopeStateType(const BoundAggregateFunction &) { - child_list_t covar_children; - covar_children.emplace_back("count", LogicalType::UBIGINT); - covar_children.emplace_back("meanx", LogicalType::DOUBLE); - covar_children.emplace_back("meany", LogicalType::DOUBLE); - covar_children.emplace_back("co_moment", LogicalType::DOUBLE); - auto cov_pop_type = LogicalType::STRUCT(std::move(covar_children)); - - child_list_t stddev_types; - stddev_types.emplace_back("count", LogicalType::UBIGINT); - stddev_types.emplace_back("mean", LogicalType::DOUBLE); - stddev_types.emplace_back("dsquared", LogicalType::DOUBLE); - auto stddev_type = LogicalType::STRUCT(std::move(stddev_types)); - - child_list_t state_children; - state_children.emplace_back("cov_pop", 
std::move(cov_pop_type)); - state_children.emplace_back("var_pop", std::move(stddev_type)); - return LogicalType::STRUCT(std::move(state_children)); -} - -} // namespace - AggregateFunction RegrSlopeFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrSlopeStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp index ec01aa50d..3673b0dbf 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp @@ -11,17 +11,14 @@ namespace duckdb { namespace { struct RegrSState { + static constexpr const char *STATE_NAMES[] = {"count", "var_pop"}; + using STATE_TYPE = StructStateType; + uint64_t count; StddevState var_pop; }; struct RegrBaseOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - STDDevBaseOperation::Initialize(state.var_pop); - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { RegrCountFunction::Combine(source.count, target.count, aggr_input_data); @@ -60,31 +57,16 @@ struct RegrSYYOperation : RegrBaseOperation { } }; -LogicalType GetRegrSStateType(const BoundAggregateFunction &) { - child_list_t stddev_types; - stddev_types.emplace_back("count", LogicalType::UBIGINT); - stddev_types.emplace_back("mean", LogicalType::DOUBLE); - stddev_types.emplace_back("dsquared", LogicalType::DOUBLE); - auto stddev_type = LogicalType::STRUCT(std::move(stddev_types)); - - child_list_t state_children; - state_children.emplace_back("count", LogicalType::UBIGINT); - state_children.emplace_back("var_pop", stddev_type); - return 
LogicalType::STRUCT(std::move(state_children)); -} - } // namespace AggregateFunction RegrSXXFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrSStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } AggregateFunction RegrSYYFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrSStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp index e46d025f8..375377a6d 100644 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp +++ b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp @@ -11,17 +11,14 @@ namespace duckdb { namespace { struct RegrSXyState { + static constexpr const char *STATE_NAMES[] = {"count", "cov_pop"}; + using STATE_TYPE = StructStateType; + uint64_t count; CovarState cov_pop; }; struct RegrSXYOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - CovarOperation::Initialize(state.cov_pop); - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { RegrCountFunction::Operation(state.count, y, x, idata); @@ -47,27 +44,11 @@ struct RegrSXYOperation { } }; -LogicalType GetRegrSXYStateType(const BoundAggregateFunction &) { - child_list_t covar_children; - covar_children.emplace_back("count", LogicalType::UBIGINT); - covar_children.emplace_back("meanx", LogicalType::DOUBLE); - covar_children.emplace_back("meany", LogicalType::DOUBLE); - covar_children.emplace_back("co_moment", LogicalType::DOUBLE); - auto cov_pop_type = LogicalType::STRUCT(std::move(covar_children)); - - 
child_list_t state_children; - state_children.emplace_back("count", LogicalType::UBIGINT); - state_children.emplace_back("cov_pop", std::move(cov_pop_type)); - - return LogicalType::STRUCT(std::move(state_children)); -} - } // namespace AggregateFunction RegrSXYFun::GetFunction() { return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE) - .SetStructStateExport(GetRegrSXYStateType); + LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); } } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp index e73b42d36..50337fef4 100644 --- a/src/duckdb/extension/core_functions/function_list.cpp +++ b/src/duckdb/extension/core_functions/function_list.cpp @@ -120,6 +120,7 @@ static const StaticFunctionDefinition core_functions[] = { DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun), DUCKDB_SCALAR_FUNCTION_SET(BitStringFun), DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), + DUCKDB_SCALAR_FUNCTION(BitStringSortKeyFun), DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), DUCKDB_SCALAR_FUNCTION(CanCastImplicitlyFun), diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp index 2a5d066f2..4aac57eb1 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp @@ -15,6 +15,9 @@ namespace duckdb { struct CorrState { + static constexpr const char *STATE_NAMES[] = {"cov_pop", "dev_pop_x", "dev_pop_y"}; + using STATE_TYPE = StructStateType; + CovarState cov_pop; StddevState dev_pop_x; StddevState dev_pop_y; @@ -23,13 +26,6 @@ struct CorrState { // Returns the correlation coefficient for non-null pairs in a group. 
// CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y)) struct CorrOperation { - template - static void Initialize(STATE &state) { - CovarOperation::Initialize(state.cov_pop); - STDDevBaseOperation::Initialize(state.dev_pop_x); - STDDevBaseOperation::Initialize(state.dev_pop_y); - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { CovarOperation::Operation(state.cov_pop, y, x, idata); diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp index 1908dfad1..719736066 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp @@ -14,6 +14,9 @@ namespace duckdb { struct CovarState { + static constexpr const char *STATE_NAMES[] = {"count", "meanx", "meany", "co_moment"}; + using STATE_TYPE = StructStateType; + uint64_t count; double meanx; double meany; @@ -21,14 +24,6 @@ struct CovarState { }; struct CovarOperation { - template - static void Initialize(STATE &state) { - state.count = 0; - state.meanx = 0; - state.meany = 0; - state.co_moment = 0; - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { // update running mean and d^2 diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp index 3a7c9adde..0150248c1 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp @@ -15,6 +15,9 @@ namespace duckdb { struct StddevState { + static constexpr const char *STATE_NAMES[] = {"count", "mean", 
"dsquared"}; + using STATE_TYPE = StructStateType; + uint64_t count; // n double mean; // M1 double dsquared; // M2 @@ -23,13 +26,6 @@ struct StddevState { // Streaming approximate standard deviation using Welford's // method, DOI: 10.2307/1266577 struct STDDevBaseOperation { - template - static void Initialize(STATE &state) { - state.count = 0; - state.mean = 0; - state.dsquared = 0; - } - template static void Execute(STATE &state, const INPUT_TYPE &input) { // update running mean and d^2 diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp index b0b7ad0de..690565cb7 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp @@ -19,11 +19,11 @@ struct HistogramFunctor { FlatVector::GetDataMutable(result)[offset] = value; } - static bool CreateExtraState(idx_t count) { + static bool CreateExtraState() { return false; } - static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { + static void PrepareData(const Vector &input, bool &, UnifiedVectorFormat &result) { input.ToUnifiedFormat(result); } @@ -66,11 +66,11 @@ struct HistogramStringFunctor : HistogramStringFunctorBase { FlatVector::GetDataMutable(result)[offset] = StringVector::AddStringOrBlob(result, value); } - static bool CreateExtraState(idx_t count) { + static bool CreateExtraState() { return false; } - static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { + static void PrepareData(const Vector &input, bool &, UnifiedVectorFormat &result) { input.ToUnifiedFormat(result); } }; @@ -82,13 +82,14 @@ struct HistogramGenericFunctor : HistogramStringFunctorBase { OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); } - static Vector 
CreateExtraState(idx_t count) { - return Vector(LogicalType::BLOB, count); + static Vector CreateExtraState() { + return Vector(LogicalType::BLOB); } - static void PrepareData(Vector &input, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { + static void PrepareData(const Vector &input, Vector &extra_state, UnifiedVectorFormat &result) { OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, extra_state); + extra_state.Reserve(input.size()); + CreateSortKeyHelpers::CreateSortKey(input, modifiers, extra_state); input.Flatten(); extra_state.Flatten(); FlatVector::ValidityMutable(extra_state).Initialize(FlatVector::Validity(input)); diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp index 40366ef6a..ac9836499 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp @@ -16,11 +16,6 @@ namespace duckdb { struct RegrCountFunction { - template - static void Initialize(STATE &state) { - state = 0; - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target += source; diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp index 77aecea6c..c9d487df7 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp @@ -13,17 +13,14 @@ namespace duckdb { struct RegrSlopeState { + static constexpr const char *STATE_NAMES[] = {"cov_pop", "var_pop"}; + 
using STATE_TYPE = StructStateType; + CovarState cov_pop; StddevState var_pop; }; struct RegrSlopeOperation { - template - static void Initialize(STATE &state) { - CovarOperation::Initialize(state.cov_pop); - STDDevBaseOperation::Initialize(state.var_pop); - } - template static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { CovarOperation::Operation(state.cov_pop, y, x, idata); diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp index fadf94e47..4c847b46a 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/multiply.hpp" #include "duckdb/function/aggregate_state.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/common/operator/cast_operators.hpp" namespace duckdb { @@ -25,32 +26,22 @@ static inline void KahanAddInternal(double input, double &summed, double &err) { template struct SumState { - bool isset; + using value_type = T; + using STATE_TYPE = OptionalStateType; T value; - - void Initialize() { - this->isset = false; - this->value = 0; - } - - void Combine(const SumState &other) { - this->isset = other.isset || this->isset; - this->value += other.value; - } + bool is_set; }; struct KahanSumState { - bool isset; + static constexpr const char *STATE_NAMES[] = {"value", "err"}; + using STATE_TYPE = OptionalStateType>; + double value; double err; - - void Initialize() { - this->isset = false; - this->err = 0.0; - } + bool is_set; void Combine(const KahanSumState &other) { - this->isset = other.isset || this->isset; + this->is_set = other.is_set || this->is_set; KahanAddInternal(other.value, this->value, this->err); 
KahanAddInternal(other.err, this->value, this->err); } @@ -161,12 +152,6 @@ struct AddToHugeint { template struct BaseSumOperation { - template - static void Initialize(STATE &state) { - state.value = 0; - STATEOP::template Initialize(state); - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { STATEOP::template Combine(source, target, aggr_input_data); diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp index 056d5f477..41cc5f565 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp @@ -55,4 +55,14 @@ struct BitStringFun { static ScalarFunctionSet GetFunctions(); }; +struct BitStringSortKeyFun { + static constexpr const char *Name = "bitstring_byte_comparable"; + static constexpr const char *Parameters = "bitstring"; + static constexpr const char *Description = "Converts a BIT to a binary comparable sort key"; + static constexpr const char *Example = "bitstring_byte_comparable('1010'::BIT)"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp index a793e5fea..1c18aaad0 100644 --- a/src/duckdb/extension/core_functions/lambda_functions.cpp +++ b/src/duckdb/extension/core_functions/lambda_functions.cpp @@ -80,7 +80,7 @@ struct ListTransformFunctor { offset += entry.length; } //! 
Appends the lambda vector to the result's child vector - static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *, + static void AppendResult(Vector &result, const Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *, ListFilterInfo &, LambdaExecuteInfo &) { ListVector::Append(result, lambda_vector, elem_cnt); } @@ -102,8 +102,8 @@ struct ListFilterFunctor { entry_lengths.push_back(entry.length); } //! Uses the lambda vector to filter the incoming list and to append the filtered list to the result vector - static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *result_entries, - ListFilterInfo &info, LambdaExecuteInfo &execute_info) { + static void AppendResult(Vector &result, const Vector &lambda_vector, const idx_t elem_cnt, + list_entry_t *result_entries, ListFilterInfo &info, LambdaExecuteInfo &execute_info) { idx_t count = 0; SelectionVector sel(elem_cnt); @@ -177,9 +177,6 @@ LambdaFunctions::GetMutableColumnInfo(vector &data) static void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::ColumnInfo &column_info, const vector &column_infos, const Vector &index_vector, LambdaExecuteInfo &info) { - info.input_chunk.SetCardinality(elem_cnt); - info.lambda_chunk.SetCardinality(elem_cnt); - // slice the child vector Vector slice(column_info.vector, column_info.sel, elem_cnt); @@ -378,7 +375,7 @@ unique_ptr LambdaFunctions::ListLambdaBind(ClientContext &context, // get the lambda expression and put it in the bind info auto &bound_lambda_expr = arguments[1]->Cast(); - auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); + auto lambda_expr = std::move(bound_lambda_expr.LambdaExprMutable()); if (lambda_expr->IsVolatile()) { bound_function.SetVolatile(); } diff --git a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp index da8c616db..ce695607a 100644 --- 
a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp @@ -87,7 +87,7 @@ template static void ArrayFixedCombine(DataChunk &args, ExpressionState &state, Vector &result) { const auto &lstate = state.Cast(); const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.GetName(); + const auto &func_name = expr.Function().GetName(); const auto count = args.size(); auto &lhs_child = ArrayVector::GetChildMutable(args.data[0]); @@ -149,7 +149,7 @@ template static void ArrayGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { const auto &lstate = state.Cast(); const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.GetName(); + const auto &func_name = expr.Function().GetName(); const auto count = args.size(); auto &lhs_child = ArrayVector::GetChildMutable(args.data[0]); diff --git a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp index 929f4787d..0d388ddfb 100644 --- a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp +++ b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp @@ -10,7 +10,7 @@ namespace duckdb { template static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { + args.data[0], args.data[1], result, [&](string_t input, int32_t n) { if (n < 0) { throw InvalidInputException("The bitstring length cannot be negative"); } @@ -48,6 +48,34 @@ ScalarFunctionSet BitStringFun::GetFunctions() { return bitstring; } +namespace { + +// Keep the key bytes above the BLOB escape range while preserving 0 < 1. 
+constexpr data_t BIT_SORT_KEY_ZERO = 2; +constexpr data_t BIT_SORT_KEY_ONE = 3; + +string_t CreateBitStringSortKey(string_t input, Vector &result) { + const auto bit_length = Bit::BitLength(input); + auto target = StringVector::EmptyString(result, bit_length); + auto data = data_ptr_cast(target.GetDataWriteable()); + for (idx_t bit_idx = 0; bit_idx < bit_length; bit_idx++) { + data[bit_idx] = Bit::GetBit(input, bit_idx) ? BIT_SORT_KEY_ONE : BIT_SORT_KEY_ZERO; + } + target.Finalize(); + return target; +} + +} // namespace + +static void BitStringSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, + [&](string_t input) { return CreateBitStringSortKey(input, result); }); +} + +ScalarFunction BitStringSortKeyFun::GetFunction() { + return ScalarFunction({LogicalType::BIT}, LogicalType::BLOB, BitStringSortKeyFunction); +} + //===--------------------------------------------------------------------===// // get_bit //===--------------------------------------------------------------------===// @@ -77,8 +105,7 @@ ScalarFunction GetBitFun::GetFunction() { //===--------------------------------------------------------------------===// static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), - [&](string_t input, int32_t n, int32_t new_value) { + args.data[0], args.data[1], args.data[2], result, [&](string_t input, int32_t n, int32_t new_value) { if (new_value != 0 && new_value != 1) { throw InvalidInputException("The new bit must be 1 or 0"); } diff --git a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp index bad98ad09..8b7f5526f 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp @@ -27,12 +27,12 @@ struct Base64DecodeOperator { void 
Base64EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } void Base64DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } } // namespace diff --git a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp index 49000a2d9..ae68c5741 100644 --- a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp +++ b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp @@ -52,14 +52,14 @@ struct UnaryBlobDecodeOperator { void UnaryDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::Execute(args.data[0], result, args.size()); + UnaryExecutor::Execute(args.data[0], result); StringVector::AddHeapReference(result, args.data[0]); } void BinaryDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { // decode is also a nop cast, but requires verification if the provided string is actually BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, string_t error_option) { + args.data[0], args.data[1], result, [&](string_t input, string_t error_option) { auto input_data = input.GetDataWriteable(); auto input_length = input.GetSize(); @@ -78,7 +78,6 @@ void BinaryDecodeFunction(DataChunk &args, ExpressionState &state, Vector &resul auto target = StringVector::EmptyString(result, new_str.size()); auto output = target.GetDataWriteable(); memcpy(output, new_str.data(), 
new_str.size()); - output[new_str.size()] = '\0'; target.Finalize(); return target; } diff --git a/src/duckdb/extension/core_functions/scalar/date/age.cpp b/src/duckdb/extension/core_functions/scalar/date/age.cpp index 1f3048e44..7e4777853 100644 --- a/src/duckdb/extension/core_functions/scalar/date/age.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/age.cpp @@ -17,7 +17,7 @@ static void AgeFunctionStandard(DataChunk &input, ExpressionState &state, Vector auto current_date = Timestamp::FromDatetime( Timestamp::GetDate(MetaTransaction::Get(state.GetContext()).start_timestamp), dtime_t(0)); - UnaryExecutor::Execute(input.data[0], result, input.size(), + UnaryExecutor::Execute(input.data[0], result, [&](timestamp_t input) -> optional { if (input.IsFinite()) { return Interval::GetAge(current_date, input); @@ -31,8 +31,7 @@ static void AgeFunction(DataChunk &input, ExpressionState &state, Vector &result D_ASSERT(input.ColumnCount() == 2); BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), - [&](timestamp_t input1, timestamp_t input2) -> optional { + input.data[0], input.data[1], result, [&](timestamp_t input1, timestamp_t input2) -> optional { if (input1.IsFinite() && input2.IsFinite()) { return Interval::GetAge(input1, input2); } else { diff --git a/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp b/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp index e18d90c86..9b69de281 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp @@ -16,8 +16,8 @@ namespace duckdb { namespace { struct DateDiff { template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::Execute(left, right, result, count, [&](TA startdate, TB enddate) -> optional { + static inline void BinaryExecute(const Vector &left, const Vector &right, Vector &result) { + BinaryExecutor::Execute(left, 
right, result, [&](TA startdate, TB enddate) -> optional { if (startdate.IsFinite() && enddate.IsFinite()) { return OP::template Operation(startdate, enddate); } else { @@ -361,55 +361,55 @@ struct DateDiffTernaryOperator { }; template -void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { +void DateDiffBinaryExecutor(DatePartSpecifier type, const Vector &left, const Vector &right, Vector &result) { switch (type) { case DatePartSpecifier::YEAR: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::MONTH: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::DAY: case DatePartSpecifier::DOW: case DatePartSpecifier::ISODOW: case DatePartSpecifier::DOY: case DatePartSpecifier::JULIAN_DAY: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::DECADE: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::CENTURY: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::MILLENNIUM: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::QUARTER: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::WEEK: case DatePartSpecifier::YEARWEEK: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::ISOYEAR: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::MICROSECONDS: - DateDiff::BinaryExecute(left, right, result, 
count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::MILLISECONDS: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::SECOND: case DatePartSpecifier::EPOCH: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::MINUTE: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; case DatePartSpecifier::HOUR: - DateDiff::BinaryExecute(left, right, result, count); + DateDiff::BinaryExecute(left, right, result); break; default: throw NotImplementedException("Specifier type not implemented for DATEDIFF"); @@ -419,9 +419,9 @@ void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, template void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; + const auto &part_arg = args.data[0]; + const auto &start_arg = args.data[1]; + const auto &end_arg = args.data[2]; if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { // Common case of constant part. 
@@ -429,9 +429,9 @@ void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { throw InternalException("DateDiff called with constant NULL part"); } const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateDiffBinaryExecutor(type, start_arg, end_arg, result, args.size()); + DateDiffBinaryExecutor(type, start_arg, end_arg, result); } else { - TernaryExecutor::Execute(part_arg, start_arg, end_arg, result, args.size(), + TernaryExecutor::Execute(part_arg, start_arg, end_arg, result, DateDiffTernaryOperator::Operation); } } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp index 78c3cc278..e1f7f7bdd 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp @@ -161,7 +161,7 @@ struct DatePart { D_ASSERT(input.ColumnCount() >= 1); using IOP = PartOperator; std::nullptr_t no_data = nullptr; - UnaryExecutor::GenericExecute(input.data[0], result, input.size(), no_data, true); + UnaryExecutor::GenericExecute(input.data[0], result, no_data, true); } struct YearOperator { @@ -413,7 +413,7 @@ struct DatePart { static void Inverse(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 1); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](int64_t input) { + UnaryExecutor::Execute(input.data[0], result, [&](int64_t input) { // millisecond amounts provided to epoch_ms should never be considered infinite // instead such values will just throw when converted to microseconds return Timestamp::FromEpochMsPossiblyInfinite(input); @@ -541,11 +541,11 @@ struct DatePart { template static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 2); - auto &offset = input.data[0]; - auto &timetz = input.data[1]; + const auto &offset = input.data[0]; + const 
auto &timetz = input.data[1]; auto func = DatePart::TimezoneOperator::Operation; - BinaryExecutor::Execute(offset, timetz, result, input.size(), func); + BinaryExecutor::Execute(offset, timetz, result, func); } template @@ -772,7 +772,7 @@ struct DatePart { template void DatePartCachedFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast>(); - UnaryExecutor::Execute(args.data[0], result, args.size(), + UnaryExecutor::Execute(args.data[0], result, [&](T input) { return lstate.cache.ExtractElement(input); }); } @@ -1745,11 +1745,11 @@ int64_t ExtractElement(DatePartSpecifier type, T element) { template void DatePartFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &spec_arg = args.data[0]; - auto &date_arg = args.data[1]; + const auto &spec_arg = args.data[0]; + const auto &date_arg = args.data[1]; BinaryExecutor::Execute( - spec_arg, date_arg, result, args.size(), [&](string_t specifier, T date) -> optional { + spec_arg, date_arg, result, [&](string_t specifier, T date) -> optional { if (date.IsFinite()) { return ExtractElement(GetDatePartSpecifier(specifier.GetString()), date); } else { @@ -1989,11 +1989,11 @@ struct StructDatePart { template static void Function(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); D_ASSERT(args.ColumnCount() == 1); const auto count = args.size(); - Vector &input = args.data[0]; + const Vector &input = args.data[0]; // Type counts const auto BIGINT_COUNT = size_t(DatePartSpecifier::BEGIN_DOUBLE) - size_t(DatePartSpecifier::BEGIN_BIGINT); @@ -2218,7 +2218,7 @@ static void ExecuteGetNanosFromTimestampNs(DataChunk &input, ExpressionState &st D_ASSERT(input.ColumnCount() == 1); auto func = GetEpochNanosOperator::Operation; - UnaryExecutor::Execute(input.data[0], 
result, input.size(), func); + UnaryExecutor::Execute(input.data[0], result, func); } ScalarFunctionSet EpochNsFun::GetFunctions() { diff --git a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp index c1304cd74..b9557e8f7 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp @@ -22,8 +22,8 @@ struct DateSub { } template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::Execute(left, right, result, count, [&](TA startdate, TB enddate) -> optional { + static inline void BinaryExecute(const Vector &left, const Vector &right, Vector &result) { + BinaryExecutor::Execute(left, right, result, [&](TA startdate, TB enddate) -> optional { if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { return OP::template Operation(startdate, enddate); } else { @@ -362,53 +362,53 @@ struct DateSubTernaryOperator { }; template -void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { +void DateSubBinaryExecutor(DatePartSpecifier type, const Vector &left, const Vector &right, Vector &result) { switch (type) { case DatePartSpecifier::YEAR: case DatePartSpecifier::ISOYEAR: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::MONTH: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::DAY: case DatePartSpecifier::DOW: case DatePartSpecifier::ISODOW: case DatePartSpecifier::DOY: case DatePartSpecifier::JULIAN_DAY: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::DECADE: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); 
break; case DatePartSpecifier::CENTURY: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::MILLENNIUM: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::QUARTER: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::WEEK: case DatePartSpecifier::YEARWEEK: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::MICROSECONDS: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::MILLISECONDS: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::SECOND: case DatePartSpecifier::EPOCH: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::MINUTE: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; case DatePartSpecifier::HOUR: - DateSub::BinaryExecute(left, right, result, count); + DateSub::BinaryExecute(left, right, result); break; default: throw NotImplementedException("Specifier type not implemented for DATESUB"); @@ -418,9 +418,9 @@ void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, template void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; + const auto &part_arg = args.data[0]; + const auto &start_arg = args.data[1]; + const auto &end_arg = args.data[2]; if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { // Common case of constant part. 
@@ -428,9 +428,9 @@ void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { throw InternalException("DateSub called with constant NULL part"); } const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateSubBinaryExecutor(type, start_arg, end_arg, result, args.size()); + DateSubBinaryExecutor(type, start_arg, end_arg, result); } else { - TernaryExecutor::Execute(part_arg, start_arg, end_arg, result, args.size(), + TernaryExecutor::Execute(part_arg, start_arg, end_arg, result, DateSubTernaryOperator::Operation); } } diff --git a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp index 00ec4bd9b..c0f96a5c2 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp @@ -23,8 +23,8 @@ struct DateTrunc { } template - static inline void UnaryExecute(Vector &left, Vector &result, idx_t count) { - UnaryExecutor::Execute(left, result, count, UnaryFunction); + static inline void UnaryExecute(const Vector &left, Vector &result) { + UnaryExecutor::Execute(left, result, UnaryFunction); } struct MillenniumOperator { @@ -408,55 +408,55 @@ struct DateTruncBinaryOperator { }; template -void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result, idx_t count) { +void DateTruncUnaryExecutor(DatePartSpecifier type, const Vector &left, Vector &result) { switch (type) { case DatePartSpecifier::MILLENNIUM: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::CENTURY: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::DECADE: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::YEAR: - DateTrunc::UnaryExecute(left, result, count); + 
DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::QUARTER: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::MONTH: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::WEEK: case DatePartSpecifier::YEARWEEK: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::ISOYEAR: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::DAY: case DatePartSpecifier::DOW: case DatePartSpecifier::ISODOW: case DatePartSpecifier::DOY: case DatePartSpecifier::JULIAN_DAY: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::HOUR: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::MINUTE: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::SECOND: case DatePartSpecifier::EPOCH: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::MILLISECONDS: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; case DatePartSpecifier::MICROSECONDS: - DateTrunc::UnaryExecute(left, result, count); + DateTrunc::UnaryExecute(left, result); break; default: throw NotImplementedException("Specifier type not implemented for DATETRUNC"); @@ -466,8 +466,8 @@ void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result template void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &part_arg = args.data[0]; - auto &date_arg = args.data[1]; + const auto &part_arg = args.data[0]; + const auto &date_arg = args.data[1]; if 
(part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { // Common case of constant part. @@ -475,10 +475,9 @@ void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) throw InternalException("DateTrunc called with constant NULL part"); } const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateTruncUnaryExecutor(type, date_arg, result, args.size()); + DateTruncUnaryExecutor(type, date_arg, result); } else { - BinaryExecutor::ExecuteStandard(part_arg, date_arg, result, - args.size()); + BinaryExecutor::ExecuteStandard(part_arg, date_arg, result); } } diff --git a/src/duckdb/extension/core_functions/scalar/date/epoch.cpp b/src/duckdb/extension/core_functions/scalar/date/epoch.cpp index 9af33ed9f..4c3bef43d 100644 --- a/src/duckdb/extension/core_functions/scalar/date/epoch.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/epoch.cpp @@ -22,7 +22,7 @@ struct EpochSecOperator { void EpochSecFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 1); - UnaryExecutor::Execute(input.data[0], result, input.size()); + UnaryExecutor::Execute(input.data[0], result); } struct NormalizedIntervalOperator { @@ -35,7 +35,7 @@ struct NormalizedIntervalOperator { void NormalizedIntervalFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 1); - UnaryExecutor::Execute(input.data[0], result, input.size()); + UnaryExecutor::Execute(input.data[0], result); } struct TimeTZSortKeyOperator { @@ -48,7 +48,7 @@ struct TimeTZSortKeyOperator { void TimeTZSortKeyFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 1); - UnaryExecutor::Execute(input.data[0], result, input.size()); + UnaryExecutor::Execute(input.data[0], result); } } // namespace diff --git a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp index 
6a69725c3..4e75692af 100644 --- a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp @@ -29,12 +29,11 @@ struct MakeDateOperator { template void ExecuteMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = input.data[2]; + const auto &yyyy = input.data[0]; + const auto &mm = input.data[1]; + const auto &dd = input.data[2]; - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), - MakeDateOperator::Operation); + TernaryExecutor::Execute(yyyy, mm, dd, result, MakeDateOperator::Operation); } template @@ -51,7 +50,7 @@ template void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { // this should be guaranteed by the binder D_ASSERT(input.ColumnCount() == 1); - auto &vec = input.data[0]; + const auto &vec = input.data[0]; const auto count = input.size(); auto iter = vec.Values>(); @@ -92,11 +91,11 @@ struct MakeTimeOperator { template void ExecuteMakeTime(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = input.data[2]; + const auto &yyyy = input.data[0]; + const auto &mm = input.data[1]; + const auto &dd = input.data[2]; - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), + TernaryExecutor::Execute(yyyy, mm, dd, result, MakeTimeOperator::Operation); } @@ -122,7 +121,7 @@ template void ExecuteMakeTimestamp(DataChunk &input, ExpressionState &state, Vector &result) { if (input.ColumnCount() == 1) { auto func = MakeTimestampOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); + UnaryExecutor::Execute(input.data[0], result, func); return; } @@ -137,7 +136,7 @@ void ExecuteMakeTimestampNs(DataChunk &input, ExpressionState &state, Vector &re D_ASSERT(input.ColumnCount() == 1); 
auto func = MakeTimestampOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); + UnaryExecutor::Execute(input.data[0], result, func); return; } diff --git a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp index a2f22b1ee..d8f96706e 100644 --- a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp @@ -229,8 +229,8 @@ template void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(bucket_width_arg)) { @@ -241,23 +241,23 @@ void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) switch (bucket_width_type) { case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), + bucket_width_arg, ts_arg, result, TimeBucket::WidthConvertibleToMicrosBinaryOperator::Operation); break; case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), + bucket_width_arg, ts_arg, result, TimeBucket::WidthConvertibleToMonthsBinaryOperator::Operation); break; case TimeBucket::BucketWidthType::UNCLASSIFIED: - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, TimeBucket::BinaryOperator::Operation); break; default: throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); } } else { - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), + BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, 
TimeBucket::BinaryOperator::Operation); } } @@ -266,9 +266,9 @@ template void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &offset_arg = args.data[2]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; + const auto &offset_arg = args.data[2]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(bucket_width_arg)) { @@ -279,17 +279,17 @@ void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &r switch (bucket_width_type) { case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, TimeBucket::OffsetWidthConvertibleToMicrosTernaryOperator::Operation); break; case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, TimeBucket::OffsetWidthConvertibleToMonthsTernaryOperator::Operation); break; case TimeBucket::BucketWidthType::UNCLASSIFIED: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, TimeBucket::OffsetTernaryOperator::Operation); break; default: @@ -297,7 +297,7 @@ void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &r } } else { TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, TimeBucket::OffsetTernaryOperator::Operation); } } @@ -306,9 +306,9 @@ template void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto 
&origin_arg = args.data[2]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; + const auto &origin_arg = args.data[2]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -321,17 +321,17 @@ void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &r switch (bucket_width_type) { case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, TimeBucket::OriginWidthConvertibleToMicrosTernaryOperator::Operation); break; case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, TimeBucket::OriginWidthConvertibleToMonthsTernaryOperator::Operation); break; case TimeBucket::BucketWidthType::UNCLASSIFIED: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, TimeBucket::OriginTernaryOperator::Operation); break; default: @@ -340,7 +340,7 @@ void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &r } } else { TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, TimeBucket::OriginTernaryOperator::Operation); } } diff --git a/src/duckdb/extension/core_functions/scalar/debug/index_key.cpp b/src/duckdb/extension/core_functions/scalar/debug/index_key.cpp index 19b01017e..70eb22374 100644 --- a/src/duckdb/extension/core_functions/scalar/debug/index_key.cpp +++ b/src/duckdb/extension/core_functions/scalar/debug/index_key.cpp @@ -29,23 +29,26 @@ static TableDescription ExtractTableDescription(const child_list_t fields["table"] = ""; for (idx_t i = 0; i < field_types.size(); i++) { - auto field_name = 
StringUtil::Lower(field_types[i].first); + auto field_name = StringUtil::Lower(field_types[i].first.GetIdentifierName()); if (fields.find(field_name) == fields.end()) { - throw BinderException("index_key: unknown field '%s' in path", field_types[i].first); + throw BinderException("index_key: unknown field '%s' in path", field_types[i].first.GetIdentifierName()); } auto &field_value = field_values[i]; if (field_value.IsNull()) { - throw BinderException("index_key: path field '%s' cannot be NULL", field_types[i].first); + throw BinderException("index_key: path field '%s' cannot be NULL", + field_types[i].first.GetIdentifierName()); } if (field_value.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("index_key: path field '%s' must be VARCHAR", field_types[i].first); + throw BinderException("index_key: path field '%s' must be VARCHAR", + field_types[i].first.GetIdentifierName()); } auto value = StringValue::Get(field_value); if (value.empty()) { - throw BinderException("index_key: path field '%s' cannot be empty", field_types[i].first); + throw BinderException("index_key: path field '%s' cannot be empty", + field_types[i].first.GetIdentifierName()); } fields[field_name] = value; } @@ -54,7 +57,7 @@ static TableDescription ExtractTableDescription(const child_list_t throw BinderException("index_key: path must contain a 'table' field"); } - return TableDescription(fields["catalog"], fields["schema"], fields["table"]); + return TableDescription(Identifier(fields["catalog"]), Identifier(fields["schema"]), Identifier(fields["table"])); } static TableDescription EvaluateTableDescription(ClientContext &context, const Expression &expr) { @@ -95,25 +98,26 @@ static string GetStringArgument(ClientContext &context, const Expression &expr, return StringValue::Get(value); } -static BoundIndex &FindBoundIndex(TableIndexList &index_list, const string &index_name, const TableDescription &path) { +static BoundIndex &FindBoundIndex(TableIndexList &index_list, const 
Identifier &index_name, + const TableDescription &path) { auto found = index_list.Find(index_name); if (found) { return *found; } auto qualified_table = ParseInfo::QualifierToString(path.database, path.schema, path.table); - vector available; + vector available; for (auto &idx : index_list.Indexes()) { available.push_back(idx.GetIndexName()); } if (available.empty()) { throw CatalogException("index_key: index '%s' was not found on table %s. No indexes found on this table.", - index_name, qualified_table); + index_name.GetIdentifierName(), qualified_table); } auto available_list = StringUtil::Join(available, ", "); - throw CatalogException("index_key: index '%s' was not found on table %s. Available indexes: %s", index_name, - qualified_table, available_list); + throw CatalogException("index_key: index '%s' was not found on table %s. Available indexes: %s", + index_name.GetIdentifierName(), qualified_table, available_list); } struct IndexKeyBindData : public FunctionData { @@ -156,7 +160,7 @@ static unique_ptr IndexKeyBind(BindScalarFunctionInput &input) { data_table_info.BindIndexes(context); auto &index_list = data_table_info.GetIndexes(); - auto &bound_index = FindBoundIndex(index_list, index_name, path); + auto &bound_index = FindBoundIndex(index_list, Identifier(index_name), path); auto index_type = bound_index.GetIndexType(); if (index_type != ART::TYPE_NAME) { @@ -190,7 +194,7 @@ static unique_ptr IndexKeyBind(BindScalarFunctionInput &input) { static void IndexKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &bind_data = func_expr.bind_info->Cast(); + auto &bind_data = func_expr.BindInfo()->Cast(); idx_t count = args.size(); @@ -200,7 +204,6 @@ static void IndexKeyFunction(DataChunk &args, ExpressionState &state, Vector &re for (idx_t i = 0; i < bind_data.key_types.size(); i++) { key_chunk.data[i].Reference(args.data[INDEX_KEY_FIXED_ARGS + i]); } - key_chunk.SetCardinality(count); auto &art = 
bind_data.art; unsafe_vector keys(count); diff --git a/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp b/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp index 93ec32b23..4b4310c08 100644 --- a/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp +++ b/src/duckdb/extension/core_functions/scalar/debug/sleep.cpp @@ -17,8 +17,9 @@ struct NullResultType { static void SleepFunction(DataChunk &input, ExpressionState &state, Vector &result) { input.Flatten(); - GenericExecutor::ExecuteUnary, NullResultType>(input.data[0], result, input.size(), - [](PrimitiveType input) { + auto &context = state.GetContext(); + GenericExecutor::ExecuteUnary, NullResultType>(input.data[0], result, + [&context](PrimitiveType input) { // Sleep for the specified number of // milliseconds (clamp negative values to // 0) @@ -26,7 +27,7 @@ static void SleepFunction(DataChunk &input, ExpressionState &state, Vector &resu if (sleep_ms < 0) { sleep_ms = 0; } - ThreadUtil::SleepMs(sleep_ms); + ThreadUtil::SleepMs(sleep_ms, context); return NullResultType(); }); } diff --git a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp index 977f91005..0659bafc4 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp @@ -3,14 +3,37 @@ namespace duckdb { +namespace { +struct AliasBindData : public FunctionData { + explicit AliasBindData(Identifier alias_p) : alias(std::move(alias_p)) { + } + + Identifier alias; + + unique_ptr Copy() const override { + return make_uniq(alias); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return alias == other.alias; + } +}; + +unique_ptr AliasBind(BindScalarFunctionInput &input) { + return make_uniq(input.GetArguments()[0]->GetName()); +} +} // namespace + static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto 
&func_expr = state.expr.Cast(); - Value v(state.expr.GetAlias().empty() ? func_expr.children[0]->GetName() : state.expr.GetAlias()); + auto &bind_data = func_expr.BindInfo()->Cast(); + Value v(state.expr.GetAlias().empty() ? bind_data.alias : state.expr.GetAlias()); result.Reference(v, count_t(args.size())); } ScalarFunction AliasFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); + auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction, AliasBind); fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return fun; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp index 8610f349d..0afcad373 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp @@ -430,15 +430,15 @@ unique_ptr BindEquiWidthFunction(BindScalarFunctionInput &input) { template void EquiWidthBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { static constexpr int64_t MAX_BIN_COUNT = 1000000; - auto &min_arg = args.data[0]; - auto &max_arg = args.data[1]; - auto &bin_count = args.data[2]; - auto &nice_rounding = args.data[3]; + const auto &min_arg = args.data[0]; + const auto &max_arg = args.data[1]; + const auto &bin_count = args.data[2]; + const auto &nice_rounding = args.data[3]; Vector intermediate_result(LogicalType::LIST(OP::LOGICAL_TYPE)); GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, PrimitiveType, GenericListType>>( - min_arg, max_arg, bin_count, nice_rounding, intermediate_result, args.size(), + min_arg, max_arg, bin_count, nice_rounding, intermediate_result, [&](PrimitiveType min_p, PrimitiveType max_p, PrimitiveType bins_p, PrimitiveType nice_rounding_p) { if (max_p.val < min_p.val) { diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp 
b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp index 03d528bdb..9609196b0 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp @@ -29,7 +29,7 @@ struct CurrentSettingBindData : public FunctionData { void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.value, count_t(args.size())); } diff --git a/src/duckdb/extension/core_functions/scalar/generic/least.cpp b/src/duckdb/extension/core_functions/scalar/generic/least.cpp index 2a3ac3315..ee3307a08 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/least.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/least.cpp @@ -50,7 +50,7 @@ struct LeastGreatestSortKeyState : public FunctionLocalState { template unique_ptr LeastGreatestSortKeyInit(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { - return make_uniq(expr.children.size(), OP::NullOrdering()); + return make_uniq(expr.GetChildren().size(), OP::NullOrdering()); } template @@ -82,10 +82,8 @@ struct SortKeyLeastGreatest { auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); lstate.sort_keys.Reset(); for (idx_t c_idx = 0; c_idx < args.ColumnCount(); c_idx++) { - CreateSortKeyHelpers::CreateSortKey(args.data[c_idx], args.size(), lstate.modifiers, - lstate.sort_keys.data[c_idx]); + CreateSortKeyHelpers::CreateSortKey(args.data[c_idx], lstate.modifiers, lstate.sort_keys.data[c_idx]); } - lstate.sort_keys.SetCardinality(args.size()); return lstate.sort_keys; } diff --git a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp index 949b74e27..76c273b9c 100644 --- 
a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp @@ -23,7 +23,7 @@ struct StatsBindData : public FunctionData { void StatsFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.stats, count_t(args.size())); } diff --git a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp index 76c491a20..894512a5b 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp @@ -76,7 +76,7 @@ unique_ptr CurrentSchemasBind(BindScalarFunctionInput &input) { // current_schemas void CurrentSchemasFunction(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.result, count_t(input.size())); } @@ -85,8 +85,9 @@ void InSearchPathFunction(DataChunk &input, ExpressionState &state, Vector &resu auto &context = state.GetContext(); auto &search_path = ClientData::Get(context).catalog_search_path; BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t db_name, string_t schema_name) { - return search_path->SchemaInSearchPath(context, db_name.GetString(), schema_name.GetString()); + input.data[0], input.data[1], result, [&](string_t db_name, string_t schema_name) { + return search_path->SchemaInSearchPath(context, Identifier(db_name.GetString()), + Identifier(schema_name.GetString())); }); } diff --git a/src/duckdb/extension/core_functions/scalar/generic/type_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/type_functions.cpp index 97833ba42..b1e087562 100644 
--- a/src/duckdb/extension/core_functions/scalar/generic/type_functions.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/type_functions.cpp @@ -83,7 +83,7 @@ static unique_ptr BindMakeTypeFunctionExpression(FunctionBindExpress // Evaluate all arguments to constant values for (auto &child : input.children) { - string name = child->GetAlias(); + string name = child->GetAlias().GetIdentifierName(); if (!child->IsFoldable()) { throw BinderException("make_type function arguments must be constant expressions"); } @@ -103,7 +103,7 @@ static unique_ptr BindMakeTypeFunctionExpression(FunctionBindExpress for (idx_t i = 1; i < args.size(); i++) { auto &arg = args[i]; auto result = make_uniq(arg.second); - result->SetAlias(arg.first); + result->SetAlias(Identifier(arg.first)); type_args.push_back(std::move(result)); } diff --git a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp index 0d71cf5e2..dc931000b 100644 --- a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp @@ -271,7 +271,7 @@ void ArraySliceFunction(DataChunk &args, ExpressionState &state, Vector &result) } auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); auto begin_is_empty = info.begin_is_empty; auto end_is_empty = info.end_is_empty; switch (result.GetType().id()) { diff --git a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp index 65a77bc01..c760f971b 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp @@ -28,7 +28,7 @@ unique_ptr ListAggregatesInitLocalState(ExpressionState &sta FunctionData *bind_data) { return make_uniq(BufferAllocator::Get(state.GetContext())); } -// 
FIXME: benchmark the use of simple_update against using update (if applicable) +// FIXME: benchmark the use of cluster_update against using update (if applicable) unique_ptr ListAggregatesBindFailure(BoundScalarFunction &bound_function) { bound_function.GetArguments()[0] = LogicalType::SQLNULL; @@ -94,10 +94,10 @@ struct StateVector { ~StateVector() { // NOLINT // destroy objects within the aggregate states auto &aggr = aggr_expr->Cast(); - if (aggr.function.HasStateDestructorCallback()) { + if (aggr.Function().HasStateDestructorCallback()) { ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.GetStateDestructorCallback()(state_vector, aggr_input_data, count); + AggregateInputData aggr_input_data(aggr, allocator); + aggr.Function().GetStateDestructorCallback()(state_vector, aggr_input_data, count); } } @@ -130,13 +130,13 @@ struct FinalizeGenericValueFunctor { struct AggregateFunctor { template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + static void ListExecuteFunction(Vector &result, const Vector &state_vector, idx_t count) { } }; struct DistinctFunctor { template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + static void ListExecuteFunction(Vector &result, const Vector &state_vector, idx_t count) { UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(sdata); auto states = UnifiedVectorFormat::GetData *>(sdata); @@ -179,7 +179,7 @@ struct DistinctFunctor { struct UniqueFunctor { template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { + static void ListExecuteFunction(Vector &result, const Vector &state_vector, idx_t count) { UnifiedVectorFormat sdata; state_vector.ToUnifiedFormat(sdata); auto states = UnifiedVectorFormat::GetData *>(sdata); @@ -214,13 +214,13 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector 
&res // get the aggregate function auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); auto &aggr = info.aggr_expr->Cast(); auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().arena_allocator; allocator.Reset(); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); - D_ASSERT(aggr.function.HasStateUpdateCallback()); + D_ASSERT(aggr.Function().HasStateUpdateCallback()); auto &child_vector = ListVector::GetChildMutable(lists); child_vector.Flatten(); @@ -233,7 +233,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res auto list_entries = UnifiedVectorFormat::GetData(lists_data); // state_buffer holds the state for each list of this chunk - idx_t size = aggr.function.GetStateSizeCallback()(aggr.function); + idx_t size = aggr.Function().GetStateSizeCallback()(aggr.Function()); auto state_buffer = make_unsafe_uniq_array_uninitialized(size * count); // state vector for initialize and finalize @@ -252,7 +252,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res // initialize the state for this list auto state_ptr = state_buffer.get() + size * i; states[i] = state_ptr; - aggr.function.GetStateInitCallback()(aggr.function, states[i]); + aggr.Function().GetStateInitCallback()(aggr.Function(), states[i]); auto lists_index = lists_data.sel->get_index(i); const auto &list_entry = list_entries[lists_index]; @@ -273,7 +273,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res if (states_idx == STANDARD_VECTOR_SIZE) { // update the aggregate state(s) Vector slice(child_vector, sel_vector, states_idx); - aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); + aggr.Function().GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); // reset values 
states_idx = 0; @@ -289,17 +289,17 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res // update the remaining elements of the last list(s) if (states_idx != 0) { Vector slice(child_vector, sel_vector, states_idx); - aggr.function.GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); + aggr.Function().GetStateUpdateCallback()(&slice, aggr_input_data, 1, state_vector_update, states_idx); } if (IS_AGGR) { // finalize all the aggregate states - aggr.function.GetStateFinalizeCallback()(state_vector.state_vector, aggr_input_data, result, count, 0); + aggr.Function().GetStateFinalizeCallback()(state_vector.state_vector, aggr_input_data, result, count, 0); } else { // finalize manually to use the map - D_ASSERT(aggr.function.GetArguments().size() == 1); - auto key_type = aggr.function.GetArguments()[0]; + D_ASSERT(aggr.Function().GetArguments().size() == 1); + auto key_type = aggr.Function().GetArguments()[0]; switch (key_type.InternalType()) { #ifndef DUCKDB_SMALLER_BINARY @@ -396,13 +396,13 @@ unique_ptr ListAggregatesBindFunction(ClientContext &context, Boun FunctionBinder function_binder(context); auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children)); - bound_function.GetArguments()[0] = LogicalType::LIST(bound_aggr_function->function.GetArguments()[0]); + bound_function.GetArguments()[0] = LogicalType::LIST(bound_aggr_function->Function().GetArguments()[0]); if (IS_AGGR) { - bound_function.SetReturnType(bound_aggr_function->function.GetReturnType()); + bound_function.SetReturnType(bound_aggr_function->Function().GetReturnType()); } // check if the aggregate function consumed all the extra input arguments - if (bound_aggr_function->children.size() > 1) { + if (bound_aggr_function->GetChildren().size() > 1) { throw InvalidInputException( "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind", 
bound_aggr_function->ToString()); @@ -445,8 +445,8 @@ unique_ptr ListAggregatesBind(BindScalarFunctionInput &input) { } // look up the aggregate function in the catalog - auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, DEFAULT_SCHEMA, - function_name); + auto &func = Catalog::GetSystemCatalog(context).GetEntry( + context, Identifier::DefaultSchema(), Identifier(function_name)); D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); if (is_parameter) { diff --git a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp index 810728a8b..c286e3abf 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp @@ -15,9 +15,7 @@ template static void ListGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { const auto &lstate = state.Cast(); const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.GetName(); - - auto count = args.size(); + const auto &func_name = expr.Function().GetName(); auto &lhs_vec = args.data[0]; auto &rhs_vec = args.data[1]; @@ -46,7 +44,7 @@ static void ListGenericFold(DataChunk &args, ExpressionState &state, Vector &res auto rhs_data = FlatVector::GetData(rhs_child); BinaryExecutor::Execute( - lhs_vec, rhs_vec, result, count, [&](const list_entry_t &left, const list_entry_t &right) -> optional { + lhs_vec, rhs_vec, result, [&](const list_entry_t &left, const list_entry_t &right) -> optional { if (left.length != right.length) { throw InvalidInputException( "%s: list dimensions must be equal, got left length '%d' and right length '%d'", func_name, diff --git a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp index 73914598b..07496c2d5 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp @@ -19,16 +19,16 @@ static unique_ptr ListFilterBind(BindScalarFunctionInput &input) { auto &bound_lambda_expr = arguments[1]->Cast(); // try to cast to boolean, if the return type of the lambda filter expression is not already boolean - if (bound_lambda_expr.lambda_expr->GetReturnType() != LogicalType::BOOLEAN) { - auto cast_lambda_expr = - BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), LogicalType::BOOLEAN); - bound_lambda_expr.lambda_expr = std::move(cast_lambda_expr); + if (bound_lambda_expr.LambdaExpr()->GetReturnType() != LogicalType::BOOLEAN) { + auto cast_lambda_expr = BoundCastExpression::AddCastToType( + context, std::move(bound_lambda_expr.LambdaExprMutable()), LogicalType::BOOLEAN); + bound_lambda_expr.LambdaExprMutable() = std::move(cast_lambda_expr); } arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); bound_function.SetReturnType(arguments[0]->GetReturnType()); - auto has_index = bound_lambda_expr.parameter_count == 2; + auto has_index = bound_lambda_expr.ParameterCount() == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp index 9e1144bfc..5e2bf4b94 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp @@ -36,8 +36,8 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); - CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + 
CreateSortKeyHelpers::CreateSortKey(l_child, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, order_modifiers, r_sortkey_vec); const auto l_sortkey_ptr = FlatVector::GetData(l_sortkey_vec); const auto r_sortkey_ptr = FlatVector::GetData(r_sortkey_vec); @@ -45,7 +45,7 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul string_set_t set; BinaryExecutor::Execute( - l_vec, r_vec, result, args.size(), [&](const list_entry_t &l_list, const list_entry_t &r_list) { + l_vec, r_vec, result, [&](const list_entry_t &l_list, const list_entry_t &r_list) { // Short circuit if either list is empty if (l_list.length == 0 || r_list.length == 0) { return false; @@ -95,7 +95,7 @@ static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &resul static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { const auto &func_expr = state.expr.Cast(); - const auto swap = func_expr.function.GetName() == "<@"; + const auto swap = func_expr.Function().GetName() == "<@"; auto &l_vec = args.data[swap ? 1 : 0]; auto &r_vec = args.data[swap ? 
0 : 1]; @@ -126,8 +126,8 @@ static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector & const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); - CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(l_child, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, order_modifiers, r_sortkey_vec); const auto build_data = FlatVector::GetData(l_sortkey_vec); const auto probe_data = FlatVector::GetData(r_sortkey_vec); @@ -135,7 +135,7 @@ static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector & string_set_t set; BinaryExecutor::Execute( - l_vec, r_vec, result, args.size(), [&](const list_entry_t &build_list, const list_entry_t &probe_list) { + l_vec, r_vec, result, [&](const list_entry_t &build_list, const list_entry_t &probe_list) { // Short circuit if the probe list is empty if (probe_list.length == 0) { return true; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp index 0ede2c95b..2c58e1717 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp @@ -155,7 +155,7 @@ bool ExecuteReduce(const idx_t loops, ReduceExecuteInfo &execute_info, LambdaFun // create the input chunk DataChunk input_chunk; input_chunk.InitializeEmpty(execute_info.input_types); - input_chunk.SetCardinality(reduced_row_idx); + input_chunk.SetChildCardinality(reduced_row_idx); const idx_t element_offset = info.has_index ? 
1 : 0; const idx_t accumulator_offset = element_offset + 1; @@ -190,7 +190,7 @@ bool ExecuteReduce(const idx_t loops, ReduceExecuteInfo &execute_info, LambdaFun } result_chunk.Reset(); - result_chunk.SetCardinality(reduced_row_idx); + result_chunk.SetChildCardinality(reduced_row_idx); execute_info.expr_executor->Execute(input_chunk, result_chunk); // We need to copy the result into left_slice to avoid data loss due to vector.Reference(...). @@ -232,13 +232,13 @@ unique_ptr ListReduceBind(BindScalarFunctionInput &input) { if (has_initial) { const auto &initial_value_type = arguments[2]->GetReturnType(); auto &bound_lambda_expr = arguments[1]->Cast(); - auto &lambda_return_type = bound_lambda_expr.lambda_expr->GetReturnType(); + auto &lambda_return_type = bound_lambda_expr.LambdaExpr()->GetReturnType(); accumulator_type = ResolveReduceAccumulatorType(context, initial_value_type, lambda_return_type); arguments[2] = BoundCastExpression::AddCastToType(context, std::move(arguments[2]), accumulator_type); } else { auto &bound_lambda_expr = arguments[1]->Cast(); auto list_child_type = LambdaFunctions::DetermineListChildType(arguments[0]->GetReturnType()); - auto &lambda_return_type = bound_lambda_expr.lambda_expr->GetReturnType(); + auto &lambda_return_type = bound_lambda_expr.LambdaExpr()->GetReturnType(); if (!LogicalType::TryGetMaxLogicalType(context, list_child_type, lambda_return_type, accumulator_type)) { throw BinderException("No common super type between list element type %s and lambda return type %s", list_child_type.ToString(), lambda_return_type.ToString()); @@ -246,14 +246,14 @@ unique_ptr ListReduceBind(BindScalarFunctionInput &input) { } auto &bound_lambda_expr = arguments[1]->Cast(); - if (bound_lambda_expr.parameter_count < 2 || bound_lambda_expr.parameter_count > 3) { + if (bound_lambda_expr.ParameterCount() < 2 || bound_lambda_expr.ParameterCount() > 3) { throw BinderException("list_reduce expects a function with 2 or 3 arguments"); } - auto has_index = 
bound_lambda_expr.parameter_count == 3; + auto has_index = bound_lambda_expr.ParameterCount() == 3; - const auto lambda_return_type = bound_lambda_expr.lambda_expr->GetReturnType(); + const auto lambda_return_type = bound_lambda_expr.LambdaExpr()->GetReturnType(); auto cast_lambda_expr = - BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), accumulator_type); + BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.LambdaExprMutable()), accumulator_type); if (!cast_lambda_expr) { throw BinderException("Could not cast lambda return type %s to accumulator type %s", lambda_return_type.ToString(), accumulator_type.ToString()); diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp index 7933a3d34..bee224024 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp @@ -92,7 +92,7 @@ static void SinkDataChunk(const Sort &sort, ExecutionContext &context, OperatorS static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() >= 1 && args.ColumnCount() <= 3); auto count = args.size(); - Vector &input_lists = args.data[0]; + const Vector &input_lists = args.data[0]; result.SetVectorType(VectorType::FLAT_VECTOR); auto &result_validity = FlatVector::ValidityMutable(result); @@ -103,7 +103,7 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &re } auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); // initialize the global and local sorting state auto global_sink_state = info.sort->GetGlobalSinkState(info.context); @@ -211,7 +211,6 @@ static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &re for (;;) { DataChunk result_chunk; result_chunk.Initialize(Allocator::DefaultAllocator(), 
{LogicalType::UINTEGER}); - result_chunk.SetCardinality(0); info.sort->GetData(execution_context, result_chunk, source_input); if (result_chunk.size() == 0) { break; diff --git a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp index fb3ef0a66..b12c580a2 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp @@ -19,8 +19,8 @@ static unique_ptr ListTransformBind(BindScalarFunctionInput &input arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); auto &bound_lambda_expr = arguments[1]->Cast(); - bound_function.SetReturnType(LogicalType::LIST(bound_lambda_expr.lambda_expr->GetReturnType())); - auto has_index = bound_lambda_expr.parameter_count == 2; + bound_function.SetReturnType(LogicalType::LIST(bound_lambda_expr.LambdaExpr()->GetReturnType())); + auto has_index = bound_lambda_expr.ParameterCount() == 2; return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); } diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp index f01d54a70..64e35803f 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp @@ -180,8 +180,6 @@ bool StructFunction(DataChunk &args, Vector &result) { DataChunk chunk; chunk.InitializeEmpty(types); - chunk.SetCardinality(args.size()); - for (idx_t col = 0; col < column_count; col++) { auto &struct_vector = args.data[col]; if (struct_vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { diff --git a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp index 724762f47..e05d04f8a 100644 --- a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp 
+++ b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp @@ -8,7 +8,7 @@ namespace duckdb { static void CardinalityFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &map = args.data[0]; + const auto &map = args.data[0]; auto entries = map.Values(); auto result_data = FlatVector::Writer(result, args.size()); diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp index 48bbaa551..c2ee5863f 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map.cpp @@ -26,8 +26,8 @@ static bool MapIsNull(DataChunk &chunk) { return false; } D_ASSERT(chunk.data.size() == 2); - auto &keys = chunk.data[0]; - auto &values = chunk.data[1]; + const auto &keys = chunk.data[0]; + const auto &values = chunk.data[1]; if (keys.GetType().id() == LogicalTypeId::SQLNULL) { return true; diff --git a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp index caff21735..8d32b248c 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp @@ -48,7 +48,7 @@ void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) auto map_count = args.ColumnCount(); vector map_formats(map_count); for (idx_t i = 0; i < map_count; i++) { - auto &map = args.data[i]; + const auto &map = args.data[i]; map.ToUnifiedFormat(map_formats[i]); } auto result_data = FlatVector::Writer(result, count); @@ -71,7 +71,7 @@ void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) } all_null = false; - auto &keys = MapVector::GetKeys(args.data[map_idx]); + const auto &keys = MapVector::GetKeys(args.data[map_idx]); auto entry = UnifiedVectorFormat::GetData(map_format)[index]; // Update the list for this row @@ -104,8 +104,8 @@ void MapConcatFunction(DataChunk &args, 
ExpressionState &state, Vector &result) D_ASSERT(keys_list.size() == index_to_map.size()); // Get the values from the mapping for (auto &mapping : index_to_map) { - auto &map = args.data[mapping.map_index]; - auto &values = MapVector::GetValues(map); + const auto &map = args.data[mapping.map_index]; + const auto &values = MapVector::GetValues(map); values_list.push_back(values.GetValue(mapping.key_index)); } D_ASSERT(values_list.size() == keys_list.size()); diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp index d1a7fe730..6623c2142 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp @@ -12,12 +12,12 @@ namespace duckdb { static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto count = args.size(); - auto &map = args.data[0]; + const auto &map = args.data[0]; if (map.GetType().id() == LogicalTypeId::SQLNULL) { ConstantVector::SetNull(result, count_t(count)); return; } - MapUtil::ReinterpretMap(result, map, count); + MapUtil::ReinterpretMap(result, map); } ScalarFunction MapEntriesFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp index 42844dff8..43b21df6b 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp @@ -10,11 +10,11 @@ namespace duckdb { static void MapExtractValueFunc(DataChunk &args, ExpressionState &state, Vector &result) { const auto count = args.size(); - auto &map_vec = args.data[0]; - auto &arg_vec = args.data[1]; + const auto &map_vec = args.data[0]; + const auto &arg_vec = args.data[1]; - auto &key_vec = MapVector::GetKeys(map_vec); - auto &val_vec = MapVector::GetValues(map_vec); + const auto &key_vec = 
MapVector::GetKeys(map_vec); + const auto &val_vec = MapVector::GetValues(map_vec); // Collect the matching positions Vector pos_vec(LogicalType::INTEGER, count); @@ -58,11 +58,11 @@ static void MapExtractValueFunc(DataChunk &args, ExpressionState &state, Vector static void MapExtractListFunc(DataChunk &args, ExpressionState &state, Vector &result) { const auto count = args.size(); - auto &map_vec = args.data[0]; - auto &arg_vec = args.data[1]; + const auto &map_vec = args.data[0]; + const auto &arg_vec = args.data[1]; - auto &key_vec = MapVector::GetKeys(map_vec); - auto &val_vec = MapVector::GetValues(map_vec); + const auto &key_vec = MapVector::GetKeys(map_vec); + const auto &val_vec = MapVector::GetValues(map_vec); // Collect the matching positions Vector pos_vec(LogicalType::INTEGER, count); diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp index 9f43ba408..2e1d37756 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp @@ -11,7 +11,7 @@ namespace duckdb { static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto count = args.size(); - MapUtil::ReinterpretMap(result, args.data[0], count); + MapUtil::ReinterpretMap(result, args.data[0]); MapVector::MapConversionVerify(result, count); } diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp index f6b186d42..82009e79b 100644 --- a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp @@ -18,7 +18,6 @@ static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector ConstantVector::SetNull(result, count_t(args.size())); return; } - auto count = args.size(); map.Flatten(); 
D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); @@ -27,7 +26,7 @@ static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector auto &entries = ListVector::GetChildMutable(result); entries.Reference(child); - FlatVector::SetData(result, FlatVector::GetDataMutable(map), count_t(count)); + FlatVector::SetData(result, FlatVector::GetDataMutable(map), count_t(args.size())); FlatVector::SetValidity(result, FlatVector::ValidityMutable(map)); auto list_size = ListVector::GetListSize(map); ListVector::SetListSize(result, list_size); diff --git a/src/duckdb/extension/core_functions/scalar/map/switch.cpp b/src/duckdb/extension/core_functions/scalar/map/switch.cpp index 16273af51..955aa2b71 100644 --- a/src/duckdb/extension/core_functions/scalar/map/switch.cpp +++ b/src/duckdb/extension/core_functions/scalar/map/switch.cpp @@ -53,7 +53,7 @@ unique_ptr SwitchBindReturnType(BindScalarFunctionInput &input) { throw BinderException("SWITCH expected a constant map for the cases"); } auto &func = cases->Cast(); - if (func.function.GetName() != "map") { + if (func.Function().GetName() != "map") { throw BinderException("SWITCH expected a constant map for the cases"); } auto map_value = ExpressionExecutor::EvaluateScalar(context, *cases); @@ -66,13 +66,13 @@ void ExtractConstantExprFromList(unique_ptr &expr, vectorCast(); - if (list_function.function.GetName() != "list_value") { + if (list_function.Function().GetName() != "list_value") { throw BinderException("Expected a list function"); } - if (list_function.children.empty()) { + if (list_function.GetChildren().empty()) { throw BinderException("No values provided for SWITCH expression"); } - for (auto &list_child : list_function.children) { + for (auto &list_child : list_function.GetChildrenMutable()) { if (list_child->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { throw NotImplementedException("Only constant expressions are supported for keys inside SWITCH"); } @@ -100,22 +100,22 @@ unique_ptr 
SwitchBindExpression(FunctionBindExpressionInput &input) cases = std::move(input.children[map_index]); } else if (input.children[map_index]->GetExpressionClass() == ExpressionClass::BOUND_CAST) { auto &cast_expr = input.children[map_index]->Cast(); - if (cast_expr.child->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + if (cast_expr.Child().GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { throw BinderException("SWITCH expected a map function for the cases"); } - cases = std::move(cast_expr.child); + cases = std::move(cast_expr.ChildMutable()); } else { throw BinderException("SWITCH expected a map function for the cases"); } auto &cases_func = cases->Cast(); - D_ASSERT(cases_func.children.size() == 2); + D_ASSERT(cases_func.GetChildren().size() == 2); vector> keys_unpacked; vector> values_unpacked; - ExtractConstantExprFromList(cases_func.children[0], keys_unpacked); - ExtractConstantExprFromList(cases_func.children[1], values_unpacked); + ExtractConstantExprFromList(cases_func.GetChildrenMutable()[0], keys_unpacked); + ExtractConstantExprFromList(cases_func.GetChildrenMutable()[1], values_unpacked); - result->case_checks.reserve(keys_unpacked.size()); + result->CaseChecksMutable().reserve(keys_unpacked.size()); for (idx_t i = 0; i < keys_unpacked.size(); i++) { BoundCaseCheck case_check; if (base_expr) { @@ -136,12 +136,12 @@ unique_ptr SwitchBindExpression(FunctionBindExpressionInput &input) function_data.return_type.ToString(), then_type.ToString()); } case_check.then_expr = std::move(values_unpacked[i]); - result->case_checks.push_back(std::move(case_check)); + result->CaseChecksMutable().push_back(std::move(case_check)); } if (default_expr) { - result->else_expr = std::move(default_expr); + result->ElseMutable() = std::move(default_expr); } else { - result->else_expr = BoundCastExpression::AddCastToType( + result->ElseMutable() = BoundCastExpression::AddCastToType( input.context, make_uniq(Value()), function_data.return_type); } return 
std::move(result); diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp index 065a854a5..553dc6302 100644 --- a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp +++ b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp @@ -149,12 +149,13 @@ static unique_ptr PropagateAbsStats(ClientContext &context, Func max_val = MaxValue(AbsValue(current_min), current_max); } else { // if both current_min and current_max are > 0, then the abs is a no-op and can be removed entirely - *input.expr_ptr = std::move(input.expr.children[0]); + *input.expr_ptr = std::move(input.expr.GetChildrenMutable()[0]); return child_stats[0].ToUnique(); } new_min = Value::Numeric(expr.GetReturnType(), min_val); new_max = Value::Numeric(expr.GetReturnType(), max_val); - expr.function.SetFunctionCallback(ScalarFunction::GetScalarUnaryFunction(expr.GetReturnType())); + expr.FunctionMutable().SetFunctionCallback( + ScalarFunction::GetScalarUnaryFunction(expr.GetReturnType())); } auto stats = NumericStats::CreateEmpty(expr.GetReturnType()); NumericStats::SetMin(stats, new_min); @@ -351,7 +352,7 @@ struct CeilOperator { template static void GenericRoundFunctionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - OP::template Operation(input, DecimalType::GetScale(func_expr.children[0]->GetReturnType()), + OP::template Operation(input, DecimalType::GetScale(func_expr.GetChildren()[0]->GetReturnType()), result); } @@ -391,7 +392,7 @@ struct CeilDecimalOperator { template static void Operation(DataChunk &input, uint8_t scale, Vector &result) { T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { if (input <= 0) { // below 0 we floor the number (e.g. 
-10.5 -> -10) return UnsafeNumericCast(input / power_of_ten); @@ -447,7 +448,7 @@ struct FloorDecimalOperator { template static void Operation(DataChunk &input, uint8_t scale, Vector &result) { T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { if (input < 0) { // below 0 we ceil the number (e.g. -10.5 -> -11) return UnsafeNumericCast(((input + 1) / power_of_ten) - 1); @@ -523,7 +524,7 @@ unique_ptr BindDecimalRoundPrecision(BindScalarFunctionInput &inpu if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); } - auto fname = StringUtil::Upper(bound_function.GetName()); + auto fname = StringUtil::Upper(bound_function.GetName().GetIdentifierName()); if (!arguments[1]->IsFoldable()) { throw NotImplementedException("%s(DECIMAL, INTEGER) with non-constant precision is not supported", fname); } @@ -617,7 +618,7 @@ struct TruncDecimalOperator { template static void Operation(DataChunk &input, uint8_t scale, Vector &result) { T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { // Always floor return UnsafeNumericCast((input / power_of_ten)); }); @@ -628,9 +629,9 @@ struct TruncDecimalNegativePrecisionOperator { template static void Operation(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->GetReturnType()); - auto width = DecimalType::GetWidth(func_expr.children[0]->GetReturnType()); + auto &info = func_expr.BindInfo()->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.GetChildren()[0]->GetReturnType()); + auto width = 
DecimalType::GetWidth(func_expr.GetChildren()[0]->GetReturnType()); if (info.target_scale <= -int32_t(width - source_scale)) { // scale too big for width result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -641,7 +642,7 @@ struct TruncDecimalNegativePrecisionOperator { UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]); T multiply_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { return UnsafeNumericCast(input / divide_power_of_ten * multiply_power_of_ten); }); } @@ -651,10 +652,10 @@ struct TruncDecimalPositivePrecisionOperator { template static void Operation(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->GetReturnType()); + auto &info = func_expr.BindInfo()->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.GetChildren()[0]->GetReturnType()); T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), + UnaryExecutor::Execute(input.data[0], result, [&](T input) { return UnsafeNumericCast(input / power_of_ten); }); } }; @@ -805,7 +806,7 @@ struct RoundDecimalOperator { // and then flooring the number // e.g. 
10.5 + 0.5 = 11, floor(11) = 11 // 10.4 + 0.5 = 10.9, floor(10.9) = 10 - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { if (input < 0) { input -= addition; } else { @@ -850,9 +851,9 @@ struct DecimalRoundNegativePrecisionOperator { template static void Operation(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->GetReturnType()); - auto width = DecimalType::GetWidth(func_expr.children[0]->GetReturnType()); + auto &info = func_expr.BindInfo()->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.GetChildren()[0]->GetReturnType()); + auto width = DecimalType::GetWidth(func_expr.GetChildren()[0]->GetReturnType()); if (info.target_scale <= -int32_t(width - source_scale)) { // scale too big for width result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -864,7 +865,7 @@ struct DecimalRoundNegativePrecisionOperator { T multiply_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]); T addition = divide_power_of_ten / 2; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { if (input < 0) { input -= addition; } else { @@ -879,11 +880,11 @@ struct DecimalRoundPositivePrecisionOperator { template static void Operation(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->GetReturnType()); + auto &info = func_expr.BindInfo()->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.GetChildren()[0]->GetReturnType()); T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]); T addition = power_of_ten / 
2; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + UnaryExecutor::Execute(input.data[0], result, [&](T input) { if (input < 0) { input -= addition; } else { @@ -976,6 +977,16 @@ ScalarFunction ExpFun::GetFunction() { namespace { struct PowOperator { + template + static inline TR Operation(TA base, TB exponent) { + if (base == 0.0 && exponent < 0.0) { + throw OutOfRangeException("zero raised to a negative power is undefined"); + } + return std::pow(base, exponent); + } +}; + +struct IEEEPowOperator { template static inline TR Operation(TA base, TB exponent) { return std::pow(base, exponent); @@ -984,8 +995,10 @@ struct PowOperator { } // namespace ScalarFunction PowOperatorFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction); + ScalarFunction function({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, nullptr, + BindIEEEFloatingBinary); + function.SetFallible(); + return function; } //===--------------------------------------------------------------------===// @@ -1714,6 +1727,9 @@ namespace { struct FactorialOperator { template static inline TR Operation(TA left) { + if (left < 0) { + throw OutOfRangeException("factorial of a negative number is undefined"); + } TR ret = 1; for (TA i = 2; i <= left; i++) { if (!TryMultiplyOperator::Operation(ret, TR(i), ret)) { diff --git a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp index f875213b0..bee840bdf 100644 --- a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp +++ b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp @@ -98,7 +98,7 @@ struct BitwiseANDOperator { void BitwiseANDOperation(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), + 
BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](string_t rhs, string_t lhs) { string_t target = heap.EmptyString(rhs.GetSize()); @@ -136,7 +136,7 @@ struct BitwiseOROperator { void BitwiseOROperation(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), + BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](string_t rhs, string_t lhs) { string_t target = heap.EmptyString(rhs.GetSize()); @@ -174,7 +174,7 @@ struct BitwiseXOROperator { void BitwiseXOROperation(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), + BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](string_t rhs, string_t lhs) { string_t target = heap.EmptyString(rhs.GetSize()); @@ -212,7 +212,7 @@ struct BitwiseNotOperator { void BitwiseNOTOperation(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(args.data[0], result, [&](string_t input) { string_t target = heap.EmptyString(input.GetSize()); Bit::BitwiseNot(input, target); @@ -269,7 +269,7 @@ struct BitwiseShiftLeftOperator { void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, Vector &result) { BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { + args.data[0], args.data[1], result, [&](string_t input, int32_t shift) { auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); if (shift == 0) { return input; @@ -322,7 +322,7 @@ struct BitwiseShiftRightOperator { void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Vector &result) { BinaryExecutor::Execute( - args.data[0], args.data[1], 
result, args.size(), [&](string_t input, int32_t shift) { + args.data[0], args.data[1], result, [&](string_t input, int32_t shift) { auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); if (shift == 0) { return input; diff --git a/src/duckdb/extension/core_functions/scalar/random/random.cpp b/src/duckdb/extension/core_functions/scalar/random/random.cpp index a83eca847..d4b976f56 100644 --- a/src/duckdb/extension/core_functions/scalar/random/random.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/random.cpp @@ -46,17 +46,15 @@ struct ExtractTimestampUuidOperator { template void ExtractVersionFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::Execute(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::Execute(input, result); } template void ExtractTimestampFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::Execute(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::Execute(input, result); } struct RandomLocalState : public FunctionLocalState { diff --git a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp index 54d1d5cd8..7e824d73b 100644 --- a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp +++ b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp @@ -29,7 +29,7 @@ struct SetseedBindData : public FunctionData { void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); auto &input = args.data[0]; input.Flatten(); diff --git a/src/duckdb/extension/core_functions/scalar/string/bar.cpp 
b/src/duckdb/extension/core_functions/scalar/string/bar.cpp index 968e81fa1..cb196590c 100644 --- a/src/duckdb/extension/core_functions/scalar/string/bar.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/bar.cpp @@ -64,24 +64,24 @@ static string_t BarScalarFunction(double x, double min, double max, double max_w static void BarFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); - auto &x_arg = args.data[0]; - auto &min_arg = args.data[1]; - auto &max_arg = args.data[2]; + const auto &x_arg = args.data[0]; + const auto &min_arg = args.data[1]; + const auto &max_arg = args.data[2]; string buffer; auto &heap = StringVector::GetStringHeap(result); if (args.ColumnCount() == 3) { GenericExecutor::ExecuteTernary, PrimitiveType, PrimitiveType, PrimitiveType>( - x_arg, min_arg, max_arg, result, args.size(), + x_arg, min_arg, max_arg, result, [&](PrimitiveType x, PrimitiveType min, PrimitiveType max) { return heap.AddString(BarScalarFunction(x.val, min.val, max.val, 80, buffer)); }); } else { - auto &width_arg = args.data[3]; + const auto &width_arg = args.data[3]; GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, PrimitiveType, PrimitiveType>( - x_arg, min_arg, max_arg, width_arg, result, args.size(), + x_arg, min_arg, max_arg, width_arg, result, [&](PrimitiveType x, PrimitiveType min, PrimitiveType max, PrimitiveType width) { return heap.AddString(BarScalarFunction(x.val, min.val, max.val, width.val, buffer)); diff --git a/src/duckdb/extension/core_functions/scalar/string/chr.cpp b/src/duckdb/extension/core_functions/scalar/string/chr.cpp index 20e0c5a9f..50b67b519 100644 --- a/src/duckdb/extension/core_functions/scalar/string/chr.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/chr.cpp @@ -24,12 +24,12 @@ struct ChrOperator { // the chr function depends on the data always being inlined (which is always possible, since it outputs max 4 bytes) // to enable chr 
when string inlining is disabled we create a special function here static void ChrFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &code_vec = args.data[0]; + const auto &code_vec = args.data[0]; char c[5] = {'\0', '\0', '\0', '\0', '\0'}; int utf8_bytes; auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(code_vec, result, args.size(), [&](int32_t input) { + UnaryExecutor::Execute(code_vec, result, [&](int32_t input) { ChrOperator::GetCodepoint(input, c, utf8_bytes); return heap.AddString(&c[0], UnsafeNumericCast(utf8_bytes)); }); diff --git a/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp b/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp index 91b0fbd33..af7cb4b00 100644 --- a/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp @@ -88,11 +88,11 @@ static int64_t DamerauLevenshteinScalarFunction(Vector &result, const string_t s } static void DamerauLevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &source_vec = args.data[0]; - auto &target_vec = args.data[1]; + const auto &source_vec = args.data[0]; + const auto &target_vec = args.data[1]; BinaryExecutor::Execute( - source_vec, target_vec, result, args.size(), + source_vec, target_vec, result, [&](string_t source, string_t target) { return DamerauLevenshteinScalarFunction(result, source, target); }); } diff --git a/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp b/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp index 87572808f..12e18f8b2 100644 --- a/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp @@ -7,7 +7,7 @@ namespace duckdb { template static void FormatBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = 
StringVector::GetStringHeap(result); - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](int64_t bytes) { + UnaryExecutor::Execute(args.data[0], result, [&](int64_t bytes) { bool is_negative = bytes < 0; idx_t unsigned_bytes; if (bytes < 0) { diff --git a/src/duckdb/extension/core_functions/scalar/string/hamming.cpp b/src/duckdb/extension/core_functions/scalar/string/hamming.cpp index b32a80199..552c1ce22 100644 --- a/src/duckdb/extension/core_functions/scalar/string/hamming.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/hamming.cpp @@ -30,12 +30,12 @@ static int64_t MismatchesScalarFunction(Vector &result, const string_t str, stri } static void MismatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; + const auto &str_vec = args.data[0]; + const auto &tgt_vec = args.data[1]; - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return MismatchesScalarFunction(result, str, tgt); }); + BinaryExecutor::Execute(str_vec, tgt_vec, result, [&](string_t str, string_t tgt) { + return MismatchesScalarFunction(result, str, tgt); + }); } ScalarFunction HammingFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp index 694355041..cf6eeb063 100644 --- a/src/duckdb/extension/core_functions/scalar/string/hex.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/hex.cpp @@ -170,9 +170,8 @@ struct HexUhugeIntOperator { template static void ToHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::ExecuteString(input, result); } struct BinaryStrOperator { @@ -361,27 +360,22 @@ struct FromBinaryOperator { 
template static void ToBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::ExecuteString(input, result); } static void FromBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::ExecuteString(input, result); } static void FromHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); + const auto &input = args.data[0]; + UnaryExecutor::ExecuteString(input, result); } ScalarFunctionSet HexFun::GetFunctions() { diff --git a/src/duckdb/extension/core_functions/scalar/string/instr.cpp b/src/duckdb/extension/core_functions/scalar/string/instr.cpp index e7f819612..0d902fbab 100644 --- a/src/duckdb/extension/core_functions/scalar/string/instr.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/instr.cpp @@ -44,7 +44,7 @@ static unique_ptr InStrPropagateStats(ClientContext &context, Fu // can only propagate stats if the children have stats // for strpos, we only care if the FIRST string has unicode or not if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.SetFunctionCallback( + expr.FunctionMutable().SetFunctionCallback( ScalarFunction::BinaryFunction); } return nullptr; diff --git a/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp 
b/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp index eae31dc9c..6b58e65c6 100644 --- a/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp @@ -43,12 +43,11 @@ static double JaccardScalarFunction(Vector &result, const string_t str, string_t } static void JaccardFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; + const auto &str_vec = args.data[0]; + const auto &tgt_vec = args.data[1]; BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); + str_vec, tgt_vec, result, [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); } ScalarFunction JaccardFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp b/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp index 2933940f7..fd82344c7 100644 --- a/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp @@ -22,10 +22,9 @@ static inline double JaroWinklerScalarFunction(const string_t &s1, const string_ template static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) { auto val = constant.GetValue(0); - idx_t count = args.size(); if (val.IsNull()) { auto &result_validity = FlatVector::ValidityMutable(result); - result_validity.SetAllInvalid(count); + result_validity.SetAllInvalid(other.size()); return; } @@ -34,14 +33,14 @@ static void CachedFunction(Vector &constant, Vector &other, Vector &result, Data D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3); if (args.ColumnCount() == 2) { - UnaryExecutor::Execute(other, result, count, [&](const string_t &other_str) { + UnaryExecutor::Execute(other, result, [&](const string_t &other_str) { 
auto other_str_begin = other_str.GetData(); return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize()); }); } else { - auto &score_cutoff = args.data[2]; + const auto &score_cutoff = args.data[2]; BinaryExecutor::Execute( - other, score_cutoff, result, count, [&](const string_t &other_str, const double_t score_cutoff) { + other, score_cutoff, result, [&](const string_t &other_str, const double_t score_cutoff) { auto other_str_begin = other_str.GetData(); return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff); }); @@ -57,12 +56,12 @@ static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILA D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3); if (args.ColumnCount() == 2) { BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), + args.data[0], args.data[1], result, [&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); }); return; } else { TernaryExecutor::Execute(args.data[0], args.data[1], args.data[2], - result, args.size(), fun); + result, fun); return; } } diff --git a/src/duckdb/extension/core_functions/scalar/string/left_right.cpp b/src/duckdb/extension/core_functions/scalar/string/left_right.cpp index b13ff9560..68b0c77a1 100644 --- a/src/duckdb/extension/core_functions/scalar/string/left_right.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/left_right.cpp @@ -43,12 +43,11 @@ static string_t LeftScalarFunction(Vector &result, const string_t str, int64_t p template static void LeftFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; + const auto &str_vec = args.data[0]; + const auto &pos_vec = args.data[1]; BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return LeftScalarFunction(result, str, pos); }); + str_vec, pos_vec, result, [&](string_t str, int64_t pos) { return 
LeftScalarFunction(result, str, pos); }); } ScalarFunction LeftFun::GetFunction() { @@ -80,11 +79,10 @@ static string_t RightScalarFunction(Vector &result, const string_t str, int64_t template static void RightFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; + const auto &str_vec = args.data[0]; + const auto &pos_vec = args.data[1]; BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return RightScalarFunction(result, str, pos); }); + str_vec, pos_vec, result, [&](string_t str, int64_t pos) { return RightScalarFunction(result, str, pos); }); } ScalarFunction RightFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp b/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp index 24e28b89c..979087b78 100644 --- a/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp @@ -69,12 +69,12 @@ static int64_t LevenshteinScalarFunction(Vector &result, const string_t str, str } static void LevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; + const auto &str_vec = args.data[0]; + const auto &tgt_vec = args.data[1]; - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return LevenshteinScalarFunction(result, str, tgt); }); + BinaryExecutor::Execute(str_vec, tgt_vec, result, [&](string_t str, string_t tgt) { + return LevenshteinScalarFunction(result, str, tgt); + }); } ScalarFunction LevenshteinFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/string/pad.cpp b/src/duckdb/extension/core_functions/scalar/string/pad.cpp index c73cbe41c..b8c7e1003 100644 --- a/src/duckdb/extension/core_functions/scalar/string/pad.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/string/pad.cpp @@ -118,14 +118,14 @@ struct RightPadOperator { template static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vector = args.data[0]; - auto &len_vector = args.data[1]; - auto &pad_vector = args.data[2]; + const auto &str_vector = args.data[0]; + const auto &len_vector = args.data[1]; + const auto &pad_vector = args.data[2]; vector buffer; auto &heap = StringVector::GetStringHeap(result); TernaryExecutor::Execute( - str_vector, len_vector, pad_vector, result, args.size(), [&](string_t str, int32_t len, string_t pad) { + str_vector, len_vector, pad_vector, result, [&](string_t str, int32_t len, string_t pad) { len = MaxValue(len, 0); return heap.AddString(OP::Operation(str, len, pad, buffer)); }); diff --git a/src/duckdb/extension/core_functions/scalar/string/parse_formatted_bytes.cpp b/src/duckdb/extension/core_functions/scalar/string/parse_formatted_bytes.cpp index 2a37d4e7d..a65967ca0 100644 --- a/src/duckdb/extension/core_functions/scalar/string/parse_formatted_bytes.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/parse_formatted_bytes.cpp @@ -4,8 +4,8 @@ namespace duckdb { static void ParseFormattedBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &arg0 = args.data[0]; - UnaryExecutor::Execute(arg0, result, args.size(), [&](string_t str) { + const auto &arg0 = args.data[0]; + UnaryExecutor::Execute(arg0, result, [&](string_t str) { // Invalid input exceptions thrown from ParseFormattedBytes won't be handled but will be thrown as is return StringUtil::ParseFormattedBytes(str.GetString()); }); diff --git a/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp b/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp index aaa61dbe3..076333e5d 100644 --- a/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp @@ -155,14 
+155,13 @@ static void ReadOptionalArgs(DataChunk &args, Vector &sep, Vector &trim, const b template static void TrimPathFunction(DataChunk &args, ExpressionState &state, Vector &result) { // set default values - Vector &path = args.data[0]; + const Vector &path = args.data[0]; Vector separator(string_t("default"), count_t(args.size())); Vector trim_extension(Value::BOOLEAN(false), count_t(args.size())); ReadOptionalArgs(args, separator, trim_extension, FRONT_TRIM); TernaryExecutor::Execute( - path, separator, trim_extension, result, args.size(), - [&](string_t &inputs, string_t input_sep, bool trim_extension) { + path, separator, trim_extension, result, [&](string_t &inputs, string_t input_sep, bool trim_extension) { auto data = inputs.GetData(); auto input_size = inputs.GetSize(); auto sep = GetSeparator(input_sep.GetString()); @@ -201,14 +200,14 @@ static void TrimPathFunction(DataChunk &args, ExpressionState &state, Vector &re static void ParseDirpathFunction(DataChunk &args, ExpressionState &state, Vector &result) { // set default values - Vector &path = args.data[0]; + const Vector &path = args.data[0]; Vector separator(string_t("default"), count_t(args.size())); Vector trim_extension(Value::BOOLEAN(false), count_t(args.size())); ReadOptionalArgs(args, separator, trim_extension, true); auto &heap = StringVector::GetStringHeap(result); BinaryExecutor::Execute( - path, separator, result, args.size(), [&](string_t input_path, string_t input_sep) { + path, separator, result, [&](string_t input_path, string_t input_sep) { auto path = input_path.GetData(); auto path_size = input_path.GetSize(); auto sep = GetSeparator(input_sep.GetString()); diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp index 3956c1539..4f947e38e 100644 --- a/src/duckdb/extension/core_functions/scalar/string/printf.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/printf.cpp @@ -88,10 +88,10 @@ 
struct StringConstructArgument { }; template -static void ConvertArguments(const Vector &input, idx_t count, idx_t arg_idx, +static void ConvertArguments(const Vector &input, idx_t arg_idx, vector>> &result_args) { auto result = input.Values(); - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < input.size(); i++) { auto &args = result_args[i]; if (args.size() != arg_idx - 1) { // this entry has a NULL as one of the parameters @@ -117,40 +117,40 @@ static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &resu auto format_data = args.data[0].Values(); for (idx_t i = 1; i < args.ColumnCount(); i++) { - auto &col = args.data[i]; + const auto &col = args.data[i]; switch (col.GetType().id()) { case LogicalTypeId::BOOLEAN: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::TINYINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::SMALLINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::INTEGER: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::BIGINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::UBIGINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::FLOAT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::HUGEINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::UHUGEINT: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case LogicalTypeId::DOUBLE: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; case 
LogicalTypeId::VARCHAR: - ConvertArguments(col, count, i, format_args); + ConvertArguments(col, i, format_args); break; default: throw InternalException("Unexpected type for printf format"); diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp index 3fe09df50..dfadf9bb0 100644 --- a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp @@ -7,11 +7,11 @@ namespace duckdb { static void RepeatFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto &str_vector = args.data[0]; - auto &cnt_vector = args.data[1]; + const auto &str_vector = args.data[0]; + const auto &cnt_vector = args.data[1]; BinaryExecutor::Execute( - str_vector, cnt_vector, result, args.size(), [&](string_t str, int64_t cnt) { + str_vector, cnt_vector, result, [&](string_t str, int64_t cnt) { auto input_str = str.GetData(); auto size_str = str.GetSize(); idx_t copy_count = cnt <= 0 || size_str == 0 ? 
0 : UnsafeNumericCast(cnt); diff --git a/src/duckdb/extension/core_functions/scalar/string/replace.cpp b/src/duckdb/extension/core_functions/scalar/string/replace.cpp index 0e37ef422..3ebfaa99e 100644 --- a/src/duckdb/extension/core_functions/scalar/string/replace.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/replace.cpp @@ -63,14 +63,14 @@ static string_t ReplaceScalarFunction(const string_t &haystack, const string_t & } static void ReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; + const auto &haystack_vector = args.data[0]; + const auto &needle_vector = args.data[1]; + const auto &thread_vector = args.data[2]; vector buffer; auto &heap = StringVector::GetStringHeap(result); TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), + haystack_vector, needle_vector, thread_vector, result, [&](string_t input_string, string_t needle_string, string_t thread_string) { return heap.AddString(ReplaceScalarFunction(input_string, needle_string, thread_string, buffer)); }); diff --git a/src/duckdb/extension/core_functions/scalar/string/reverse.cpp b/src/duckdb/extension/core_functions/scalar/string/reverse.cpp index 52eeaddc2..8f43e6d6d 100644 --- a/src/duckdb/extension/core_functions/scalar/string/reverse.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/reverse.cpp @@ -45,7 +45,7 @@ struct ReverseOperator { }; static void ReverseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } ScalarFunction ReverseFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/string/to_base.cpp b/src/duckdb/extension/core_functions/scalar/string/to_base.cpp index 740fbcc9f..4869e7caa 100644 --- 
a/src/duckdb/extension/core_functions/scalar/string/to_base.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/to_base.cpp @@ -18,14 +18,13 @@ static unique_ptr ToBaseBind(BindScalarFunctionInput &input) { } static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - auto &radix = args.data[1]; - auto &min_length = args.data[2]; - auto count = args.size(); + const auto &input = args.data[0]; + const auto &radix = args.data[1]; + const auto &min_length = args.data[2]; auto &heap = StringVector::GetStringHeap(result); TernaryExecutor::Execute( - input, radix, min_length, result, count, [&](int64_t input, int32_t radix, int32_t min_length) { + input, radix, min_length, result, [&](int64_t input, int32_t radix, int32_t min_length) { if (input < 0) { throw InvalidInputException("'to_base' number must be greater than or equal to 0"); } diff --git a/src/duckdb/extension/core_functions/scalar/string/translate.cpp b/src/duckdb/extension/core_functions/scalar/string/translate.cpp index e93874b2e..ac8688d5e 100644 --- a/src/duckdb/extension/core_functions/scalar/string/translate.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/translate.cpp @@ -75,14 +75,14 @@ static string_t TranslateScalarFunction(const string_t &haystack, const string_t } static void TranslateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; + const auto &haystack_vector = args.data[0]; + const auto &needle_vector = args.data[1]; + const auto &thread_vector = args.data[2]; vector buffer; auto &heap = StringVector::GetStringHeap(result); TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), + haystack_vector, needle_vector, thread_vector, result, [&](string_t input_string, string_t needle_string, string_t thread_string) { return 
heap.AddString(TranslateScalarFunction(input_string, needle_string, thread_string, buffer)); }); diff --git a/src/duckdb/extension/core_functions/scalar/string/trim.cpp b/src/duckdb/extension/core_functions/scalar/string/trim.cpp index 853851c1f..7c6a83d01 100644 --- a/src/duckdb/extension/core_functions/scalar/string/trim.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/trim.cpp @@ -61,7 +61,7 @@ struct TrimOperator { template static void UnaryTrimFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString>(args.data[0], result); } static void GetIgnoredCodepoints(string_t ignored, unordered_set &ignored_codepoints) { @@ -79,7 +79,7 @@ static void GetIgnoredCodepoints(string_t ignored, unordered_set static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector &result) { BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t input, string_t ignored) { + input.data[0], input.data[1], result, [&](string_t input, string_t ignored) { auto data = input.GetData(); auto size = input.GetSize(); diff --git a/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp b/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp index 18fa9a913..8bc716bf2 100644 --- a/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp @@ -18,7 +18,7 @@ struct URLEncodeOperator { }; static void URLEncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } ScalarFunction UrlEncodeFun::GetFunction() { @@ -39,7 +39,7 @@ struct URLDecodeOperator { }; static void URLDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, 
args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } ScalarFunction UrlDecodeFun::GetFunction() { diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp index cef63e9c3..2ead1a438 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp @@ -12,7 +12,7 @@ namespace duckdb { static void StructInsertFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &starting_vec = args.data[0]; + const auto &starting_vec = args.data[0]; starting_vec.Verify(); auto &starting_child_entries = StructVector::GetEntries(starting_vec); @@ -43,7 +43,7 @@ static unique_ptr StructInsertBind(BindScalarFunctionInput &input) throw InvalidInputException("Can't insert nothing into a STRUCT"); } - case_insensitive_set_t name_collision_set; + identifier_set_t name_collision_set; child_list_t new_children; auto &existing_children = StructType::GetChildTypes(arguments[0]->GetReturnType()); @@ -63,7 +63,7 @@ static unique_ptr StructInsertBind(BindScalarFunctionInput &input) throw BinderException("Duplicate struct entry name \"%s\"", child->GetAlias()); } name_collision_set.insert(child->GetAlias()); - new_children.push_back(make_pair(child->GetAlias(), arguments[i]->GetReturnType())); + new_children.emplace_back(make_pair(child->GetAlias(), arguments[i]->GetReturnType())); } bound_function.SetReturnType(LogicalType::STRUCT(new_children)); diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp index b59ac4369..c0892e242 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_keys.cpp @@ -17,7 +17,7 @@ struct StructKeysBindData : public FunctionData { auto list_writer = 
FlatVector::Writer>(keys_vector, 2); idx_t idx = 0; for (auto &child_writer : list_writer.WriteList(count)) { - child_writer.WriteValue(string_t(child_types[idx++].first)); + child_writer.WriteValue(string_t(child_types[idx++].first.GetIdentifierName())); } list_writer.WriteNull(); } @@ -34,10 +34,10 @@ struct StructKeysBindData : public FunctionData { }; static void StructKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; const idx_t count = args.size(); - auto &data = state.expr.Cast().bind_info->Cast(); + auto &data = state.expr.Cast().BindInfo()->Cast(); auto &keys_vector = data.keys_vector; // If the input is a constant, we must return a CONSTANT_VECTOR @@ -60,7 +60,7 @@ static void StructKeysFunction(DataChunk &args, ExpressionState &state, Vector & static unique_ptr StructKeysBind(BindScalarFunctionInput &input) { auto &arguments = input.GetArguments(); auto return_type = arguments[0]->GetReturnType(); - if (return_type.id() != LogicalTypeId::STRUCT && !return_type.IsAggregateStateStructType()) { + if (return_type.id() != LogicalTypeId::STRUCT) { throw InvalidInputException("struct_keys() expects a STRUCT argument"); } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp index eb867000f..a65d61014 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_update.cpp @@ -12,7 +12,7 @@ namespace duckdb { static void StructUpdateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &starting_vec = args.data[0]; + const auto &starting_vec = args.data[0]; starting_vec.Verify(); auto &starting_child_entries = StructVector::GetEntries(starting_vec); @@ -20,8 +20,8 @@ static void StructUpdateFunction(DataChunk &args, ExpressionState &state, Vector auto &starting_types = 
StructType::GetChildTypes(starting_vec.GetType()); - auto &func_args = state.expr.Cast().children; - auto new_entries = case_insensitive_tree_t(); + auto &func_args = state.expr.Cast().GetChildren(); + auto new_entries = identifier_tree_t(); auto is_new_field = vector(args.ColumnCount(), true); for (idx_t arg_idx = 1; arg_idx < func_args.size(); arg_idx++) { @@ -69,7 +69,7 @@ static unique_ptr StructUpdateBind(BindScalarFunctionInput &input) child_list_t new_children; auto &existing_children = StructType::GetChildTypes(arguments[0]->GetReturnType()); - auto incoming_children = case_insensitive_tree_t(); + auto incoming_children = identifier_tree_t(); auto is_new_field = vector(arguments.size(), true); // Validate incoming arguments and record names @@ -93,7 +93,7 @@ static unique_ptr StructUpdateBind(BindScalarFunctionInput &input) // Update the struct with the new data of the same name auto arg_idx = update->second; auto &new_child = arguments[arg_idx]; - new_children.push_back(make_pair(new_child->GetAlias(), new_child->GetReturnType())); + new_children.emplace_back(make_pair(new_child->GetAlias(), new_child->GetReturnType())); is_new_field[arg_idx] = false; } } @@ -102,7 +102,7 @@ static unique_ptr StructUpdateBind(BindScalarFunctionInput &input) for (idx_t arg_idx = 1; arg_idx < arguments.size(); arg_idx++) { if (is_new_field[arg_idx]) { auto &child = arguments[arg_idx]; - new_children.push_back(make_pair(child->GetAlias(), child->GetReturnType())); + new_children.emplace_back(make_pair(child->GetAlias(), child->GetReturnType())); } } @@ -114,12 +114,12 @@ static unique_ptr StructUpdateStats(ClientContext &context, Func auto &child_stats = input.child_stats; auto &expr = input.expr; - auto incoming_children = case_insensitive_tree_t(); - auto is_new_field = vector(expr.children.size(), true); + auto incoming_children = identifier_tree_t(); + auto is_new_field = vector(expr.GetChildren().size(), true); auto new_stats = 
StructStats::CreateUnknown(expr.GetReturnType()); - for (idx_t arg_idx = 1; arg_idx < expr.children.size(); arg_idx++) { - auto &new_child = expr.children[arg_idx]; + for (idx_t arg_idx = 1; arg_idx < expr.GetChildren().size(); arg_idx++) { + auto &new_child = expr.GetChildren()[arg_idx]; incoming_children.emplace(new_child->GetAlias(), arg_idx); } @@ -138,7 +138,7 @@ static unique_ptr StructUpdateStats(ClientContext &context, Func } } - for (idx_t arg_idx = 1, field_idx = existing_count; arg_idx < expr.children.size(); arg_idx++) { + for (idx_t arg_idx = 1, field_idx = existing_count; arg_idx < expr.GetChildren().size(); arg_idx++) { if (is_new_field[arg_idx]) { StructStats::SetChildStats(new_stats, field_idx++, child_stats[arg_idx]); } diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp index 6d7205d2c..58812fbf7 100644 --- a/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp +++ b/src/duckdb/extension/core_functions/scalar/struct/struct_values.cpp @@ -9,7 +9,7 @@ namespace duckdb { static void StructValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; + const auto &input = args.data[0]; const idx_t count = args.size(); auto &input_children = StructVector::GetEntries(input); diff --git a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp index 784de169e..952e3351b 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp @@ -10,7 +10,7 @@ namespace duckdb { namespace { struct UnionExtractBindData : public FunctionData { - UnionExtractBindData(string key, idx_t index, LogicalType type) + UnionExtractBindData(Identifier key, idx_t index, LogicalType type) : key(std::move(key)), 
index(index), type(std::move(type)) { } @@ -20,7 +20,7 @@ struct UnionExtractBindData : public FunctionData { public: unique_ptr Copy() const override { - return make_uniq(key, index, type); + return make_uniq(Identifier(key), index, type); } bool Equals(const FunctionData &other_p) const override { auto &other = other_p.Cast(); @@ -30,10 +30,10 @@ struct UnionExtractBindData : public FunctionData { void UnionExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); // this should be guaranteed by the binder - auto &vec = args.data[0]; + const auto &vec = args.data[0]; vec.Verify(); D_ASSERT(info.index < UnionType::GetMemberCount(vec.GetType())); @@ -73,7 +73,7 @@ unique_ptr UnionExtractBind(BindScalarFunctionInput &input) { if (key_val.IsNull() || key_str.empty()) { throw BinderException("Key name for union_extract needs to be neither NULL nor empty"); } - string key = StringUtil::Lower(key_str); + auto key = Identifier(key_str); LogicalType return_type; idx_t key_index = 0; @@ -81,7 +81,7 @@ unique_ptr UnionExtractBind(BindScalarFunctionInput &input) { for (size_t i = 0; i < union_member_count; i++) { auto &member_name = UnionType::GetMemberName(arguments[0]->GetReturnType(), i); - if (StringUtil::Lower(member_name) == key) { + if (member_name == key) { found_key = true; key_index = i; return_type = UnionType::GetMemberType(arguments[0]->GetReturnType(), i); @@ -93,11 +93,11 @@ unique_ptr UnionExtractBind(BindScalarFunctionInput &input) { vector candidates; candidates.reserve(union_member_count); for (idx_t i = 0; i < union_member_count; i++) { - candidates.push_back(UnionType::GetMemberName(arguments[0]->GetReturnType(), i)); + candidates.emplace_back(UnionType::GetMemberName(arguments[0]->GetReturnType(), i)); } auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); auto message = 
StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in union\n%s", key, message); + throw BinderException("Could not find key \"%s\" in union\n%s", key.GetIdentifierName(), message); } bound_function.SetReturnType(return_type); diff --git a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp index 579c7e5d8..61dc898c2 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp @@ -39,7 +39,8 @@ unique_ptr UnionTagBind(BindScalarFunctionInput &input) { auto varchar_vector = Vector(LogicalType::VARCHAR, member_count); auto result_data = FlatVector::Writer(varchar_vector, member_count); for (idx_t i = 0; i < member_count; i++) { - result_data.WriteValue(string_t(UnionType::GetMemberName(arguments[0]->GetReturnType(), i))); + result_data.WriteValue( + string_t(UnionType::GetMemberName(arguments[0]->GetReturnType(), i).GetIdentifierName())); } auto enum_type = LogicalType::ENUM(varchar_vector, member_count); bound_function.SetReturnType(enum_type); diff --git a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp index 5f9eba8d5..374931b27 100644 --- a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp +++ b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp @@ -47,7 +47,7 @@ unique_ptr UnionValueBind(BindScalarFunctionInput &input) { child_list_t union_members; - union_members.push_back(make_pair(child->GetAlias(), child->GetReturnType())); + union_members.emplace_back(make_pair(child->GetAlias(), child->GetReturnType())); bound_function.SetReturnType(LogicalType::UNION(std::move(union_members))); return make_uniq(bound_function.GetReturnType()); diff --git a/src/duckdb/extension/icu/icu-current.cpp 
b/src/duckdb/extension/icu/icu-current.cpp index bd54ea73f..7c8dd2ad3 100644 --- a/src/duckdb/extension/icu/icu-current.cpp +++ b/src/duckdb/extension/icu/icu-current.cpp @@ -54,7 +54,7 @@ void RegisterICUCurrentFunctions(ExtensionLoader &loader) { current_date.AddFunction(GetCurrentDateFun()); loader.RegisterFunction(current_date); - current_date.name = "today"; + current_date.SetName("today"); loader.RegisterFunction(current_date); } diff --git a/src/duckdb/extension/icu/icu-dateadd.cpp b/src/duckdb/extension/icu/icu-dateadd.cpp index 4c0aa0c3f..2c99ea7ce 100644 --- a/src/duckdb/extension/icu/icu-dateadd.cpp +++ b/src/duckdb/extension/icu/icu-dateadd.cpp @@ -222,13 +222,13 @@ struct ICUDateAdd : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 1); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); // Subtract argument from current_date (at midnight) const auto end_date = CurrentMidnight(calendar.GetICUCalendar(), state); - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](TA start_date) { + UnaryExecutor::Execute(args.data[0], result, [&](TA start_date) { return OP::template Operation(end_date, start_date, calendar); }); } @@ -244,10 +244,10 @@ struct ICUDateAdd : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 2); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), [&](TA left, TB right) { + BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](TA left, TB right) { return OP::template Operation(left, right, calendar); }); } @@ -264,9 +264,9 @@ struct ICUDateAdd : public ICUDateFunc { return GetBinaryDateFunction(left_type, right_type, LogicalType::TIMESTAMP_TZ); } - static void AddDateAddOperators(const 
string &name, ExtensionLoader &loader) { + static void AddDateAddOperators(const Identifier &name, ExtensionLoader &loader) { // temporal + interval - ScalarFunctionSet set(name); + ScalarFunctionSet set {name}; set.AddFunction(GetDateAddFunction(LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL)); set.AddFunction(GetDateAddFunction(LogicalType::INTERVAL, @@ -284,9 +284,9 @@ struct ICUDateAdd : public ICUDateFunc { return GetBinaryDateFunction(left_type, right_type, LogicalType::INTERVAL); } - static void AddDateSubOperators(const string &name, ExtensionLoader &loader) { + static void AddDateSubOperators(const Identifier &name, ExtensionLoader &loader) { // temporal - interval - ScalarFunctionSet set(name); + ScalarFunctionSet set {name}; set.AddFunction(GetDateAddFunction(LogicalType::TIMESTAMP_TZ, LogicalType::INTERVAL)); @@ -296,9 +296,9 @@ struct ICUDateAdd : public ICUDateFunc { loader.RegisterFunction(set); } - static void AddDateAgeFunctions(const string &name, ExtensionLoader &loader) { + static void AddDateAgeFunctions(const Identifier &name, ExtensionLoader &loader) { // age(temporal, temporal) - ScalarFunctionSet set(name); + ScalarFunctionSet set {name}; set.AddFunction(GetBinaryAgeFunction( LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP_TZ)); set.AddFunction(GetUnaryAgeFunction(LogicalType::TIMESTAMP_TZ)); diff --git a/src/duckdb/extension/icu/icu-datepart.cpp b/src/duckdb/extension/icu/icu-datepart.cpp index 12eeff2e8..adc458909 100644 --- a/src/duckdb/extension/icu/icu-datepart.cpp +++ b/src/duckdb/extension/icu/icu-datepart.cpp @@ -278,14 +278,14 @@ struct ICUDatePart : public ICUDateFunc { static void UnaryTimestampFunction(DataChunk &args, ExpressionState &state, Vector &result) { using BIND_TYPE = BindAdapterData; D_ASSERT(args.ColumnCount() == 1); - auto &date_arg = args.data[0]; + const auto &date_arg = args.data[0]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); 
CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); - UnaryExecutor::Execute(date_arg, result, args.size(), + UnaryExecutor::Execute(date_arg, result, [&](INPUT_TYPE input) -> optional { if (input.IsFinite()) { const auto micros = SetTime(calendar, input); @@ -300,17 +300,16 @@ struct ICUDatePart : public ICUDateFunc { static void BinaryTimestampFunction(DataChunk &args, ExpressionState &state, Vector &result) { using BIND_TYPE = BindAdapterData; D_ASSERT(args.ColumnCount() == 2); - auto &part_arg = args.data[0]; - auto &date_arg = args.data[1]; + const auto &part_arg = args.data[0]; + const auto &date_arg = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); BinaryExecutor::Execute( - part_arg, date_arg, result, args.size(), - [&](string_t specifier, INPUT_TYPE input) -> optional { + part_arg, date_arg, result, [&](string_t specifier, INPUT_TYPE input) -> optional { if (input.IsFinite()) { const auto micros = SetTime(calendar, input); auto adapter = PartCodeBigintFactory(GetDatePartSpecifier(specifier.GetString())); @@ -370,13 +369,13 @@ struct ICUDatePart : public ICUDateFunc { template static void StructFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); D_ASSERT(args.ColumnCount() == 1); const auto count = args.size(); - Vector &input = args.data[0]; + const Vector &input = args.data[0]; auto entries = input.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); @@ -432,7 +431,7 @@ struct ICUDatePart : public ICUDateFunc { auto &context = input.GetClientContext(); auto &arguments = input.GetArguments(); - const auto part_code = 
GetDatePartSpecifier(bound_function.GetName()); + const auto part_code = GetDatePartSpecifier(bound_function.GetName().GetIdentifierName()); if (IsBigintDatepart(part_code)) { using data_t = BindAdapterData; auto adapter = PartCodeBigintFactory(part_code); @@ -468,7 +467,7 @@ struct ICUDatePart : public ICUDateFunc { arguments.erase(arguments.begin()); bound_function.GetArguments().erase(bound_function.GetArguments().begin()); - bound_function.SetName(part_name); + bound_function.SetName(Identifier(part_name)); bound_function.SetReturnType(LogicalType::DOUBLE); bound_function.SetFunctionCallback(UnaryTimestampFunction); @@ -556,10 +555,10 @@ struct ICUDatePart : public ICUDateFunc { } template - static void AddUnaryPartCodeFunctions(const string &name, ExtensionLoader &loader, + static void AddUnaryPartCodeFunctions(const Identifier &name, ExtensionLoader &loader, const LogicalType &result_type = LogicalType::BIGINT, ArgProperties unary_arg0_props = {}) { - ScalarFunctionSet set(name); + ScalarFunctionSet set {name}; set.AddFunction(GetUnaryPartCodeFunction(LogicalType::TIMESTAMP_TZ, result_type)); set.SetUnaryArgProperties(unary_arg0_props); loader.RegisterFunction(set); @@ -581,8 +580,8 @@ struct ICUDatePart : public ICUDateFunc { return result; } - static void AddDatePartFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddDatePartFunctions(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetBinaryPartCodeFunction(LogicalType::TIMESTAMP_TZ)); set.AddFunction(GetStructFunction(LogicalType::TIMESTAMP_TZ)); for (auto &func : set.functions) { @@ -605,8 +604,8 @@ struct ICUDatePart : public ICUDateFunc { return ScalarFunction({temporal_type}, LogicalType::DATE, UnaryTimestampFunction, BindLastDate); } - static void AddLastDayFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddLastDayFunctions(const Identifier 
&name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetLastDayFunction(LogicalType::TIMESTAMP_TZ)); loader.RegisterFunction(set); } @@ -624,8 +623,8 @@ struct ICUDatePart : public ICUDateFunc { return ScalarFunction({temporal_type}, LogicalType::VARCHAR, UnaryTimestampFunction, BindMonthName); } - static void AddMonthNameFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddMonthNameFunctions(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetMonthNameFunction(LogicalType::TIMESTAMP_TZ)); loader.RegisterFunction(set); } @@ -643,8 +642,8 @@ struct ICUDatePart : public ICUDateFunc { return ScalarFunction({temporal_type}, LogicalType::VARCHAR, UnaryTimestampFunction, BindDayName); } - static void AddDayNameFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddDayNameFunctions(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetDayNameFunction(LogicalType::TIMESTAMP_TZ)); loader.RegisterFunction(set); } diff --git a/src/duckdb/extension/icu/icu-datesub.cpp b/src/duckdb/extension/icu/icu-datesub.cpp index a7a03beaf..98f903359 100644 --- a/src/duckdb/extension/icu/icu-datesub.cpp +++ b/src/duckdb/extension/icu/icu-datesub.cpp @@ -90,12 +90,12 @@ struct ICUCalendarSub : public ICUDateFunc { template static void ICUDateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &startdate_arg = args.data[1]; - auto &enddate_arg = args.data[2]; + const auto &part_arg = args.data[0]; + const auto &startdate_arg = args.data[1]; + const auto &enddate_arg = args.data[2]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar(info.calendar->clone()); if 
(part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -105,7 +105,7 @@ struct ICUCalendarSub : public ICUDateFunc { } const auto specifier = ConstantVector::GetData(part_arg)->GetString(); auto part_func = SubtractFactory(GetDatePartSpecifier(specifier)); - BinaryExecutor::Execute(startdate_arg, enddate_arg, result, args.size(), + BinaryExecutor::Execute(startdate_arg, enddate_arg, result, [&](T start_date, T end_date) -> optional { if (start_date.IsFinite() && end_date.IsFinite()) { return part_func(calendar.get(), start_date, end_date); @@ -115,7 +115,7 @@ struct ICUCalendarSub : public ICUDateFunc { }); } else { TernaryExecutor::Execute( - part_arg, startdate_arg, enddate_arg, result, args.size(), + part_arg, startdate_arg, enddate_arg, result, [&](string_t specifier, T start_date, T end_date) -> optional { if (start_date.IsFinite() && end_date.IsFinite()) { auto part_func = SubtractFactory(GetDatePartSpecifier(specifier.GetString())); @@ -132,8 +132,8 @@ struct ICUCalendarSub : public ICUDateFunc { return ScalarFunction({LogicalType::VARCHAR, type, type}, LogicalType::BIGINT, ICUDateSubFunction, Bind); } - static void AddFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunctions(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetFunction(LogicalType::TIMESTAMP_TZ)); set.SetArgProperties(1, ArgProperties().NonIncreasing()); set.SetArgProperties(2, ArgProperties().NonDecreasing()); @@ -217,12 +217,12 @@ struct ICUCalendarDiff : public ICUDateFunc { template static void ICUDateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &startdate_arg = args.data[1]; - auto &enddate_arg = args.data[2]; + const auto &part_arg = args.data[0]; + const auto &startdate_arg = args.data[1]; + const auto &enddate_arg = args.data[2]; auto &func_expr = state.expr.Cast(); - 
auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); @@ -235,19 +235,18 @@ struct ICUCalendarDiff : public ICUDateFunc { const auto part = GetDatePartSpecifier(specifier); auto trunc_func = DiffTruncationFactory(part); auto sub_func = SubtractFactory(part); - BinaryExecutor::Execute(startdate_arg, enddate_arg, result, args.size(), - [&](T start_date, T end_date) -> optional { - if (start_date.IsFinite() && end_date.IsFinite()) { - return DifferenceFunc(calendar, start_date, end_date, - trunc_func, sub_func); - } else { - return nullopt; - } - }); + BinaryExecutor::Execute( + startdate_arg, enddate_arg, result, [&](T start_date, T end_date) -> optional { + if (start_date.IsFinite() && end_date.IsFinite()) { + return DifferenceFunc(calendar, start_date, end_date, trunc_func, sub_func); + } else { + return nullopt; + } + }); } } else { TernaryExecutor::Execute( - part_arg, startdate_arg, enddate_arg, result, args.size(), + part_arg, startdate_arg, enddate_arg, result, [&](string_t specifier, T start_date, T end_date) -> optional { if (start_date.IsFinite() && end_date.IsFinite()) { const auto part = GetDatePartSpecifier(specifier.GetString()); @@ -266,8 +265,8 @@ struct ICUCalendarDiff : public ICUDateFunc { return ScalarFunction({LogicalType::VARCHAR, type, type}, LogicalType::BIGINT, ICUDateDiffFunction, Bind); } - static void AddFunctions(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunctions(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetFunction(LogicalType::TIMESTAMP_TZ)); set.SetArgProperties(1, ArgProperties().NonIncreasing()); set.SetArgProperties(2, ArgProperties().NonDecreasing()); diff --git a/src/duckdb/extension/icu/icu-datetrunc.cpp b/src/duckdb/extension/icu/icu-datetrunc.cpp index a5ef9bfd8..06054d770 100644 --- 
a/src/duckdb/extension/icu/icu-datetrunc.cpp +++ b/src/duckdb/extension/icu/icu-datetrunc.cpp @@ -126,11 +126,11 @@ struct ICUDateTrunc : public ICUDateFunc { template static void ICUDateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &part_arg = args.data[0]; - auto &date_arg = args.data[1]; + const auto &part_arg = args.data[0]; + const auto &date_arg = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar(info.calendar->clone()); if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -140,7 +140,7 @@ struct ICUDateTrunc : public ICUDateFunc { } const auto specifier = ConstantVector::GetData(part_arg)->GetString(); auto truncator = TruncationFactory(GetDatePartSpecifier(specifier)); - UnaryExecutor::Execute(date_arg, result, args.size(), [&](T input) { + UnaryExecutor::Execute(date_arg, result, [&](T input) { if (input.IsFinite()) { auto micros = SetTime(calendar.get(), input); truncator(calendar.get(), micros); @@ -150,17 +150,16 @@ struct ICUDateTrunc : public ICUDateFunc { } }); } else { - BinaryExecutor::Execute( - part_arg, date_arg, result, args.size(), [&](string_t specifier, T input) { - if (input.IsFinite()) { - auto truncator = TruncationFactory(GetDatePartSpecifier(specifier.GetString())); - auto micros = SetTime(calendar.get(), input); - truncator(calendar.get(), micros); - return GetTimeUnsafe(calendar.get(), micros); - } else { - return input; - } - }); + BinaryExecutor::Execute(part_arg, date_arg, result, [&](string_t specifier, T input) { + if (input.IsFinite()) { + auto truncator = TruncationFactory(GetDatePartSpecifier(specifier.GetString())); + auto micros = SetTime(calendar.get(), input); + truncator(calendar.get(), micros); + return GetTimeUnsafe(calendar.get(), micros); + } else { + return input; + } + }); } } @@ -169,8 +168,8 @@ struct ICUDateTrunc : public 
ICUDateFunc { return ScalarFunction({LogicalType::VARCHAR, type}, LogicalType::TIMESTAMP_TZ, ICUDateTruncFunction, Bind); } - static void AddBinaryTimestampFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddBinaryTimestampFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetDateTruncFunction(LogicalType::TIMESTAMP_TZ)); set.SetArgProperties(1, ArgProperties().NonDecreasing()); loader.RegisterFunction(set); diff --git a/src/duckdb/extension/icu/icu-list-range.cpp b/src/duckdb/extension/icu/icu-list-range.cpp index 5c8e25948..ba4e6ddd7 100644 --- a/src/duckdb/extension/icu/icu-list-range.cpp +++ b/src/duckdb/extension/icu/icu-list-range.cpp @@ -126,7 +126,7 @@ struct ICUListRange : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 3); auto &func_expr = state.expr.Cast(); - auto &bind_info = func_expr.bind_info->Cast(); + auto &bind_info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*bind_info.calendar, bind_info.cal_setting); RangeInfoStruct info(args); diff --git a/src/duckdb/extension/icu/icu-makedate.cpp b/src/duckdb/extension/icu/icu-makedate.cpp index c0debc32f..e7f3d6a98 100644 --- a/src/duckdb/extension/icu/icu-makedate.cpp +++ b/src/duckdb/extension/icu/icu-makedate.cpp @@ -97,7 +97,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { template static void FromMicros(DataChunk &input, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T micros) { + UnaryExecutor::Execute(input.data[0], result, [&](T micros) { const auto result = timestamp_t(micros); if (!result.IsFinite()) { throw ConversionException("Timestamp microseconds out of range: %ld", micros); @@ -109,7 +109,7 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { template static void Execute(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = 
func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); @@ -157,8 +157,8 @@ struct ICUMakeTimestampTZFunc : public ICUDateFunc { return function; } - static void AddFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(GetSenaryFunction(LogicalType::BIGINT)); set.AddFunction(GetSeptenaryFunction(LogicalType::BIGINT)); ScalarFunction function({LogicalType::BIGINT}, LogicalType::TIMESTAMP_TZ, FromMicros); diff --git a/src/duckdb/extension/icu/icu-strptime.cpp b/src/duckdb/extension/icu/icu-strptime.cpp index 2caddd805..48cd4da90 100644 --- a/src/duckdb/extension/icu/icu-strptime.cpp +++ b/src/duckdb/extension/icu/icu-strptime.cpp @@ -178,16 +178,16 @@ struct ICUStrptime : public ICUDateFunc { template static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &str_arg = args.data[0]; - auto &fmt_arg = args.data[1]; + const auto &str_arg = args.data[0]; + const auto &fmt_arg = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); D_ASSERT(fmt_arg.GetVectorType() == VectorType::CONSTANT_VECTOR); - UnaryExecutor::Execute(str_arg, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(str_arg, result, [&](string_t input) { T parsed; ParseOne(calendar, input, info.formats, parsed); return parsed; @@ -235,11 +235,11 @@ struct ICUStrptime : public ICUDateFunc { template static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &str_arg = args.data[0]; - auto &fmt_arg = args.data[1]; + const auto &str_arg = 
args.data[0]; + const auto &fmt_arg = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); @@ -248,7 +248,7 @@ struct ICUStrptime : public ICUDateFunc { if (ConstantVector::IsNull(fmt_arg)) { ConstantVector::SetNull(result, count_t(args.size())); } else { - UnaryExecutor::Execute(str_arg, result, args.size(), [&](string_t input) -> optional { + UnaryExecutor::Execute(str_arg, result, [&](string_t input) -> optional { T result; if (TryParseOne(calendar, input, info.formats, result)) { return result; @@ -335,7 +335,7 @@ struct ICUStrptime : public ICUDateFunc { return bound_function.GetBindCallback()(new_input); } - static void TailPatch(const string &name, ExtensionLoader &loader, const vector &types) { + static void TailPatch(const Identifier &name, ExtensionLoader &loader, const vector &types) { // Find the old function auto &scalar_function = loader.GetFunction(name); auto &functions = scalar_function.functions.functions; @@ -368,7 +368,7 @@ struct ICUStrptime : public ICUDateFunc { bound_function.SetBindCallback(StrpTimeBindFunction); } - static void AddBinaryTimestampFunction(const string &name, ExtensionLoader &loader) { + static void AddBinaryTimestampFunction(const Identifier &name, ExtensionLoader &loader) { vector types {LogicalType::VARCHAR, LogicalType::VARCHAR}; TailPatch(name, loader, types); @@ -513,7 +513,7 @@ struct ICUStrptime : public ICUDateFunc { bind_scalar_function_t ICUStrptime::bind_strptime = nullptr; // NOLINT struct ICUStrftime : public ICUDateFunc { - static void ParseFormatSpecifier(string_t &format_str, StrfTimeFormat &format) { + static void ParseFormatSpecifier(const string_t &format_str, StrfTimeFormat &format) { const auto format_specifier = format_str.GetString(); const auto error = StrTimeFormat::ParseFormatSpecifier(format_specifier, format); if 
(!error.empty()) { @@ -581,11 +581,11 @@ struct ICUStrftime : public ICUDateFunc { template static void ICUStrftimeFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto &src_arg = args.data[0]; - auto &fmt_arg = args.data[1]; + const auto &src_arg = args.data[0]; + const auto &fmt_arg = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar(info.calendar->clone()); const auto tz_name = info.tz_setting.c_str(); @@ -597,30 +597,30 @@ struct ICUStrftime : public ICUDateFunc { StrfTimeFormat format; ParseFormatSpecifier(*ConstantVector::GetData(fmt_arg), format); - UnaryExecutor::Execute(src_arg, result, args.size(), [&](T input) { + UnaryExecutor::Execute(src_arg, result, [&](T input) { if (input.IsFinite()) { return Operation(calendar.get(), input, tz_name, format, result); } else { - return StringVector::AddString(result, Timestamp::ToString(timestamp_t(input.value))); + return StringVector::AddString(result, Date::ToInfinity(input)); } }); } else { BinaryExecutor::Execute( - src_arg, fmt_arg, result, args.size(), [&](T input, string_t format_specifier) { + src_arg, fmt_arg, result, [&](T input, string_t format_specifier) { if (input.IsFinite()) { StrfTimeFormat format; ParseFormatSpecifier(format_specifier, format); return Operation(calendar.get(), input, tz_name, format, result); } else { - return StringVector::AddString(result, Timestamp::ToString(timestamp_t(input.value))); + return StringVector::AddString(result, Date::ToInfinity(input)); } }); } } - static void AddBinaryTimestampFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddBinaryTimestampFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ, LogicalType::VARCHAR}, LogicalType::VARCHAR, 
ICUStrftimeFunction, Bind)); set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ_NS, LogicalType::VARCHAR}, LogicalType::VARCHAR, @@ -632,7 +632,7 @@ struct ICUStrftime : public ICUDateFunc { static string_t CastOperation(icu::Calendar *calendar, T input, Vector &result) { // Infinity is always formatted the same way if (!input.IsFinite()) { - return StringVector::AddString(result, Timestamp::ToString(timestamp_t(input.value))); + return StringVector::AddString(result, Date::ToInfinity(input)); } // decompose the timestamp diff --git a/src/duckdb/extension/icu/icu-table-range.cpp b/src/duckdb/extension/icu/icu-table-range.cpp index 918c760eb..960bb4a85 100644 --- a/src/duckdb/extension/icu/icu-table-range.cpp +++ b/src/duckdb/extension/icu/icu-table-range.cpp @@ -113,8 +113,8 @@ struct ICUTableRange { input.Flatten(); for (idx_t c = 0; c < input.ColumnCount(); c++) { if (FlatVector::IsNull(input.data[c], row_id)) { - result.start = timestamp_tz_t(0); - result.end = timestamp_tz_t(0); + result.start = timestamp_tz_t::epoch(); + result.end = timestamp_tz_t::epoch(); result.increment = interval_t(); result.greater_than_check = true; result.inclusive_bound = false; @@ -201,7 +201,6 @@ struct ICUTableRange { } if (state.empty_range) { // empty range - output.SetCardinality(0); state.current_input_row++; state.initialized_row = false; return OperatorResultType::HAVE_MORE_OUTPUT; @@ -247,6 +246,7 @@ struct ICUTableRange { RangeDateTimeLocalInit); generate_series_function.in_out_function = ICUTableRangeFunction; generate_series_function.cardinality = Cardinality; + generate_series_function.return_type = TableFunctionReturnType::SET_RETURNING_FUNCTION; generate_series.AddFunction(generate_series_function); loader.RegisterFunction(generate_series); diff --git a/src/duckdb/extension/icu/icu-timebucket.cpp b/src/duckdb/extension/icu/icu-timebucket.cpp index eb13ca5cb..257eefbd6 100644 --- a/src/duckdb/extension/icu/icu-timebucket.cpp +++ 
b/src/duckdb/extension/icu/icu-timebucket.cpp @@ -131,10 +131,10 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 2); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar(info.calendar->clone()); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), [&](TA left, TB right) { + BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](TA left, TB right) { return OP::template Operation(left, right, calendar); }); } @@ -144,11 +144,11 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 3); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar(info.calendar->clone()); TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), [&](TA ta, TB tb, TC tc) { + args.data[0], args.data[1], args.data[2], result, [&](TA ta, TB tb, TC tc) { return OP::template Operation(args.data[0], args.data[1], args.data[2], calendar.get()); }); } @@ -364,12 +364,12 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 2); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); SetTimeZone(calendar.GetICUCalendar(), string_t("UTC")); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(bucket_width_arg)) { @@ -380,25 +380,25 @@ struct ICUTimeBucket : public ICUDateFunc { switch (bucket_width_type) { case BucketWidthType::CONVERTIBLE_TO_MICROS: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, 
timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return WidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts, calendar); }); break; case BucketWidthType::CONVERTIBLE_TO_DAYS: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return WidthConvertibleToDaysBinaryOperator::Operation(bucket_width, ts, calendar); }); break; case BucketWidthType::CONVERTIBLE_TO_MONTHS: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return WidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts, calendar); }); break; case BucketWidthType::UNCLASSIFIED: BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return BinaryOperator::Operation(bucket_width, ts, calendar); }); break; @@ -407,7 +407,7 @@ struct ICUTimeBucket : public ICUDateFunc { } } else { BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return BinaryOperator::Operation(bucket_width, ts, calendar); }); } @@ -417,13 +417,13 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 3); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); SetTimeZone(calendar.GetICUCalendar(), string_t("UTC")); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto 
&offset_arg = args.data[2]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; + const auto &offset_arg = args.data[2]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(bucket_width_arg)) { @@ -434,7 +434,7 @@ struct ICUTimeBucket : public ICUDateFunc { switch (bucket_width_type) { case BucketWidthType::CONVERTIBLE_TO_MICROS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, interval_t offset) { return OffsetWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, offset, calendar); @@ -442,7 +442,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::CONVERTIBLE_TO_DAYS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, interval_t offset) { return OffsetWidthConvertibleToDaysTernaryOperator::Operation(bucket_width, ts, offset, calendar); @@ -450,7 +450,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::CONVERTIBLE_TO_MONTHS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, interval_t offset) { return OffsetWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, offset, calendar); @@ -458,7 +458,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::UNCLASSIFIED: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, interval_t offset) { return OffsetTernaryOperator::Operation(bucket_width, ts, offset, calendar); }); @@ -468,7 +468,7 @@ struct ICUTimeBucket : public 
ICUDateFunc { } } else { TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), + bucket_width_arg, ts_arg, offset_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, interval_t offset) { return OffsetTernaryOperator::Operation(bucket_width, ts, offset, calendar); }); @@ -479,13 +479,13 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 3); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); SetTimeZone(calendar.GetICUCalendar(), string_t("UTC")); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &origin_arg = args.data[2]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; + const auto &origin_arg = args.data[2]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -498,7 +498,7 @@ struct ICUTimeBucket : public ICUDateFunc { switch (bucket_width_type) { case BucketWidthType::CONVERTIBLE_TO_MICROS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, timestamp_tz_t origin) { return OriginWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, origin, calendar); @@ -506,7 +506,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::CONVERTIBLE_TO_DAYS: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, timestamp_tz_t origin) { return OriginWidthConvertibleToDaysTernaryOperator::Operation(bucket_width, ts, origin, calendar); @@ -514,7 +514,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::CONVERTIBLE_TO_MONTHS: 
TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, timestamp_tz_t origin) { return OriginWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, origin, calendar); @@ -522,7 +522,7 @@ struct ICUTimeBucket : public ICUDateFunc { break; case BucketWidthType::UNCLASSIFIED: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, timestamp_tz_t origin) -> optional { return OriginTernaryOperator::Operation(bucket_width, ts, origin, calendar); @@ -534,7 +534,7 @@ struct ICUTimeBucket : public ICUDateFunc { } } else { TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), + bucket_width_arg, ts_arg, origin_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, timestamp_tz_t origin) -> optional { return OriginTernaryOperator::Operation(bucket_width, ts, origin, calendar); }); @@ -545,12 +545,12 @@ struct ICUTimeBucket : public ICUDateFunc { D_ASSERT(args.ColumnCount() == 3); auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); TZCalendar calendar(*info.calendar, info.cal_setting); - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &tz_arg = args.data[2]; + const auto &bucket_width_arg = args.data[0]; + const auto &ts_arg = args.data[1]; + const auto &tz_arg = args.data[2]; if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && tz_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -566,7 +566,7 @@ struct ICUTimeBucket : public ICUDateFunc { origin = ICUDateFunc::FromNaive(calendar.GetICUCalendar(), Timestamp::FromEpochMicroSeconds(DEFAULT_ORIGIN_MICROS_1)); BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t 
bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return TimeZoneWidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts, origin, calendar); }); @@ -575,7 +575,7 @@ struct ICUTimeBucket : public ICUDateFunc { origin = ICUDateFunc::FromNaive(calendar.GetICUCalendar(), Timestamp::FromEpochMicroSeconds(DEFAULT_ORIGIN_MICROS_1)); BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return TimeZoneWidthConvertibleToDaysBinaryOperator::Operation(bucket_width, ts, origin, calendar); }); @@ -584,14 +584,14 @@ struct ICUTimeBucket : public ICUDateFunc { origin = ICUDateFunc::FromNaive(calendar.GetICUCalendar(), Timestamp::FromEpochMicroSeconds(DEFAULT_ORIGIN_MICROS_2)); BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), [&](interval_t bucket_width, timestamp_tz_t ts) { + bucket_width_arg, ts_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts) { return TimeZoneWidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts, origin, calendar); }); break; case BucketWidthType::UNCLASSIFIED: TernaryExecutor::Execute( - bucket_width_arg, ts_arg, tz_arg, result, args.size(), + bucket_width_arg, ts_arg, tz_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, string_t tz) { return TimeZoneTernaryOperator::Operation(bucket_width, ts, tz, calendar); }); @@ -601,8 +601,7 @@ struct ICUTimeBucket : public ICUDateFunc { } } else { TernaryExecutor::Execute( - bucket_width_arg, ts_arg, tz_arg, result, args.size(), - [&](interval_t bucket_width, timestamp_tz_t ts, string_t tz) { + bucket_width_arg, ts_arg, tz_arg, result, [&](interval_t bucket_width, timestamp_tz_t ts, string_t tz) { return TimeZoneTernaryOperator::Operation(bucket_width, ts, tz, calendar); }); } diff --git 
a/src/duckdb/extension/icu/icu-timezone.cpp b/src/duckdb/extension/icu/icu-timezone.cpp index f7560b027..b068399e2 100644 --- a/src/duckdb/extension/icu/icu-timezone.cpp +++ b/src/duckdb/extension/icu/icu-timezone.cpp @@ -102,79 +102,6 @@ static void ICUTimeZoneFunction(ClientContext &context, TableFunctionInput &data is_dst.Append(Value::BOOLEAN(dst_offset_ms != 0)); ++index; } - output.SetCardinality(index); -} - -// Wrap the multiply-named and non-type-safe cast utilities. -struct ICUCast { - template - static inline DST Operation(SRC input) { - throw NotImplementedException("Naive timezone cast could not be performed!"); - } -}; - -// From naive types to TIMESTAMP_TZ -template <> -timestamp_tz_t ICUCast::Operation(timestamp_t src) { - return timestamp_tz_t(src); -} - -template <> -timestamp_tz_t ICUCast::Operation(timestamp_ms_t src) { - return CastTimestampMsToUs::Operation(src); -} - -template <> -timestamp_tz_t ICUCast::Operation(timestamp_ns_t src) { - return CastTimestampMsToUs::Operation(src); -} - -template <> -timestamp_tz_t ICUCast::Operation(timestamp_sec_t src) { - return timestamp_tz_t(CastTimestampSecToUs::Operation(src)); -} - -template <> -timestamp_tz_t ICUCast::Operation(date_t src) { - return timestamp_tz_t(Cast::Operation(src)); -} - -template <> -timestamp_tz_ns_t ICUCast::Operation(date_t src) { - return timestamp_tz_ns_t(Cast::Operation(src)); -} - -// From TIMESTAMP_TZ to naive types -template <> -timestamp_sec_t ICUCast::Operation(timestamp_tz_t src) { - return CastTimestampUsToSec::Operation(src); -} - -template <> -timestamp_ms_t ICUCast::Operation(timestamp_tz_t src) { - return CastTimestampUsToMs::Operation(src); -} - -template <> -timestamp_t ICUCast::Operation(timestamp_tz_t src) { - return timestamp_t(src); -} - -template <> -timestamp_ns_t ICUCast::Operation(timestamp_tz_t src) { - return CastTimestampUsToNs::Operation(src); -} - -// From TIMESTAMP_TZ_NS -template <> -timestamp_ns_t ICUCast::Operation(timestamp_tz_ns_t src) { - 
return timestamp_ns_t(src); -} - -// From TIME_TZ -template <> -dtime_tz_t ICUCast::Operation(dtime_tz_t src) { - return src; } struct ICUFromNaiveTimestamp : public ICUDateFunc { @@ -443,7 +370,7 @@ struct ICULocalTimestampFunc : public ICUDateFunc { static timestamp_t GetLocalTimestamp(ExpressionState &state) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); @@ -458,8 +385,8 @@ struct ICULocalTimestampFunc : public ICUDateFunc { rdata[0] = GetLocalTimestamp(state); } - static void AddFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(ScalarFunction({}, LogicalType::TIMESTAMP, Execute, BindNow)); loader.RegisterFunction(set); } @@ -474,8 +401,8 @@ struct ICULocalTimeFunc : public ICUDateFunc { rdata[0] = Timestamp::GetTime(local); } - static void AddFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(ScalarFunction({}, LogicalType::TIME, Execute, ICULocalTimestampFunc::BindNow)); loader.RegisterFunction(set); } @@ -589,36 +516,34 @@ struct ICUTimeZoneFunc : public ICUDateFunc { template static void Execute(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); CalendarPtr calendar_ptr(info.calendar->clone()); auto calendar = calendar_ptr.get(); // Two cases: constant TZ, variable TZ D_ASSERT(input.ColumnCount() == 2); - auto &tz_vec = input.data[0]; - auto &ts_vec = input.data[1]; + const auto &tz_vec = input.data[0]; + const auto &ts_vec = input.data[1]; if 
(tz_vec.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(tz_vec)) { throw InternalException("ICUTimeZone called with constant NULL tz"); } SetTimeZone(calendar, *ConstantVector::GetData(tz_vec)); - UnaryExecutor::Execute(ts_vec, result, input.size(), - [&](SRC ts) { return OP::Operation(calendar, ts); }); + UnaryExecutor::Execute(ts_vec, result, [&](SRC ts) { return OP::Operation(calendar, ts); }); } else { - BinaryExecutor::Execute(tz_vec, ts_vec, result, input.size(), - [&](string_t tz_id, SRC ts) { - if (ts.IsFinite()) { - SetTimeZone(calendar, tz_id); - return OP::Operation(calendar, ts); - } else { - return ICUCast::Operation(ts); - } - }); + BinaryExecutor::Execute(tz_vec, ts_vec, result, [&](string_t tz_id, SRC ts) { + if (ts.IsFinite()) { + SetTimeZone(calendar, tz_id); + return OP::Operation(calendar, ts); + } else { + return Cast::Operation(ts); + } + }); } } - static void AddFunction(const string &name, ExtensionLoader &loader) { - ScalarFunctionSet set(name); + static void AddFunction(const Identifier &name, ExtensionLoader &loader) { + ScalarFunctionSet set {name}; set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP_TZ, Execute, Bind)); set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP_TZ}, LogicalType::TIMESTAMP, diff --git a/src/duckdb/extension/icu/icu_extension.cpp b/src/duckdb/extension/icu/icu_extension.cpp index 3ec3ff696..0edeaa572 100644 --- a/src/duckdb/extension/icu/icu_extension.cpp +++ b/src/duckdb/extension/icu/icu_extension.cpp @@ -105,8 +105,8 @@ struct IcuBindData : public FunctionData { static string EncodeFunctionName(const string &collation) { return FUNCTION_PREFIX + collation; } - static string DecodeFunctionName(const string &fname) { - return fname.substr(FUNCTION_PREFIX.size()); + static string DecodeFunctionName(const Identifier &fname) { + return fname.GetIdentifierName().substr(FUNCTION_PREFIX.size()); } }; @@ -131,12 
+131,12 @@ static void ICUCollateFunction(DataChunk &args, ExpressionState &state, Vector & const char HEX_TABLE[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); auto &collator = *info.collator; duckdb::unique_ptr buffer; int32_t buffer_size = 0; - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(args.data[0], result, [&](string_t input) { // create a sort key from the string const auto string_size = idx_t(ICUGetSortKey(collator, input, buffer, buffer_size)); // convert the sort key to hexadecimal @@ -200,7 +200,8 @@ static duckdb::unique_ptr ICUSortKeyBind(BindScalarFunctionInput & static ScalarFunction GetICUCollateFunction(const string &collation, const string &tag) { string fname = IcuBindData::EncodeFunctionName(collation); - ScalarFunction result(fname, {LogicalType::VARCHAR}, LogicalType::VARCHAR, ICUCollateFunction, ICUCollateBind); + ScalarFunction result(Identifier(fname), {LogicalType::VARCHAR}, LogicalType::VARCHAR, ICUCollateFunction, + ICUCollateBind); //! 
collation tag is added into the Function extra info result.extra_info = tag; result.SetSerializeCallback(IcuBindData::Serialize); @@ -225,7 +226,7 @@ unique_ptr GetNormalizedTimeZone(string &tz_str) { return tz; } - // Map UTC±NN00 to Etc/GMT±N + // Map UTC±NN00 to Etc/UTC±N do { if (tz_str.size() <= 4) { break; @@ -234,34 +235,63 @@ unique_ptr GetNormalizedTimeZone(string &tz_str) { break; } - // Parse the offset, allowing single digits idx_t pos = 3; - int hh, mm, ss; - if (!Timestamp::TryParseUTCOffset(tz_str.data(), pos, tz_str.size(), hh, mm, ss, false)) { + const auto utc = tz_str[pos++]; + // Invert the sign (UTC and Etc use opposite sign conventions) + // https://en.wikipedia.org/wiki/Tz_database#Area + auto sign = utc; + if (utc == '+') { + sign = '-'; + ; + } else if (utc == '-') { + sign = '+'; + } else { break; } - if (pos < tz_str.size() || mm || ss) { + + // Collect remaining characters (digits and colons) + string remainder; + for (; pos < tz_str.size(); ++pos) { + const auto ch = tz_str[pos]; + if (ch != ':' && !StringUtil::CharacterIsDigit(ch)) { + break; + } + remainder += ch; + } + if (pos < tz_str.size()) { break; } - // Invert the sign (UTC and Etc use opposite sign conventions) - // https://en.wikipedia.org/wiki/Tz_database#Area - hh = -hh; + // Step 1: Strip leading zeros + idx_t start = 0; + while (start < remainder.size() && remainder[start] == '0') { + ++start; + } + remainder = remainder.substr(start); + + // Step 2: Parse hours based on whether colon is present + string hours_str; + auto colon_idx = remainder.find(':'); + if (colon_idx != string::npos) { + // Has colon: split by colon, part before colon is hours + hours_str = remainder.substr(0, colon_idx); + } else if (remainder.size() <= 2) { + // 1-2 digits: entire string is hours + hours_str = remainder; + } else { + // No colon, 3+ digits: HHMM format, last 2 are minutes, rest are hours + hours_str = remainder.substr(0, remainder.size() - 2); + } - // Build the mapped timezone 
string with single digit hour offsets + // Build the mapped timezone string string mapped = "Etc/GMT"; - if (hh < 0) { - mapped += "-"; - hh = -hh; + if (hours_str.empty()) { + // Zero offset + mapped += "+0"; } else { - mapped += "+"; - } - if (hh >= 10) { - mapped += UnsafeNumericCast('0' + hh / 10); - hh %= 10; + mapped += sign; + mapped += hours_str; } - mapped += UnsafeNumericCast('0' + hh); - // Final sanity check if (tz = GetKnownTimeZone(mapped)) { tz_str = mapped; @@ -382,7 +412,6 @@ static void ICUCalendarFunction(ClientContext &context, TableFunctionInput &data ++index; } - output.SetCardinality(index); } static void SetICUCalendar(ClientContext &context, SetScope scope, Value ¶meter) { @@ -442,7 +471,7 @@ static void LoadInternal(ExtensionLoader &loader) { } collation = StringUtil::Lower(collation); - CreateCollationInfo info(collation, GetICUCollateFunction(collation, ""), false, false); + CreateCollationInfo info(Identifier(collation), GetICUCollateFunction(collation, ""), false, false); loader.RegisterCollation(info); } diff --git a/src/duckdb/extension/jemalloc/include/jemalloc_extension.hpp b/src/duckdb/extension/jemalloc/include/jemalloc_extension.hpp deleted file mode 100644 index 5092262c4..000000000 --- a/src/duckdb/extension/jemalloc/include/jemalloc_extension.hpp +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// jemalloc_extension.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" - -namespace duckdb { - -class JemallocExtension : public Extension { -public: - void Load(ExtensionLoader &loader) override; - std::string Name() override; - std::string Version() const override; - - static data_ptr_t Allocate(PrivateAllocatorData *private_data, idx_t size); - static void Free(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size); - static data_ptr_t 
Reallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, idx_t size); - - static int64_t DecayDelay(); - static void ThreadFlush(idx_t threshold); - static void ThreadIdle(); - static void FlushAll(); - static void SetBackgroundThreads(bool enable); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/json/include/json_executors.hpp b/src/duckdb/extension/json/include/json_executors.hpp index dae5edbbd..dda7d560b 100644 --- a/src/duckdb/extension/json/include/json_executors.hpp +++ b/src/duckdb/extension/json/include/json_executors.hpp @@ -28,8 +28,8 @@ struct JSONExecutors { auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) -> optional { + const auto &inputs = args.data[0]; + UnaryExecutor::Execute(inputs, result, [&](string_t input) -> optional { auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); return fun(doc->root, alc, result); }); @@ -41,16 +41,16 @@ struct JSONExecutors { template static void BinaryExecute(DataChunk &args, ExpressionState &state, Vector &result, const json_function_t fun) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; + const auto &inputs = args.data[0]; if (info.constant) { // Constant path const char *ptr = info.ptr; const idx_t &len = info.len; if (info.path_type == JSONCommon::JSONPathType::REGULAR) { - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) -> optional { + UnaryExecutor::Execute(inputs, result, [&](string_t input) -> optional { auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); auto val = JSONCommon::GetUnsafe(doc->root, ptr, len); if 
(SET_NULL_IF_NOT_FOUND && !val) { @@ -62,7 +62,7 @@ struct JSONExecutors { } else { D_ASSERT(info.path_type == JSONCommon::JSONPathType::WILDCARD); vector vals; - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(inputs, result, [&](string_t input) { vals.clear(); auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); @@ -80,7 +80,7 @@ struct JSONExecutors { for (idx_t i = 0; i < vals.size(); i++) { auto &val = vals[i]; D_ASSERT(val != nullptr); // Wildcard extract shouldn't give back nullptrs - auto fun_result = fun(val, alc, result); + auto fun_result = fun(val, alc, child_entry); if (fun_result.has_value()) { child_vals[current_size + i] = fun_result.value(); } else { @@ -120,7 +120,7 @@ struct JSONExecutors { template static void ExecuteMany(DataChunk &args, ExpressionState &state, Vector &result, const json_function_t fun) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); D_ASSERT(info.ptrs.size() == info.lens.size()); diff --git a/src/duckdb/extension/json/include/json_functions.hpp b/src/duckdb/extension/json/include/json_functions.hpp index c1cffe5a5..82842b4b1 100644 --- a/src/duckdb/extension/json/include/json_functions.hpp +++ b/src/duckdb/extension/json/include/json_functions.hpp @@ -115,7 +115,7 @@ class JSONFunctions { template static void AddAliases(const vector &names, FUNCTION_INFO fun, vector &functions) { for (auto &name : names) { - fun.name = name; + fun.name = Identifier(name); functions.push_back(fun); } } diff --git a/src/duckdb/extension/json/include/json_multi_file_info.hpp b/src/duckdb/extension/json/include/json_multi_file_info.hpp index 061a6a459..e4b28f04d 100644 --- a/src/duckdb/extension/json/include/json_multi_file_info.hpp +++ 
b/src/duckdb/extension/json/include/json_multi_file_info.hpp @@ -32,7 +32,7 @@ struct JSONMultiFileInfo : MultiFileReaderInterface { const vector &expected_types) override; unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, + void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) override; optional_idx MaxThreads(const MultiFileBindData &bind_data, const MultiFileGlobalState &global_state, FileExpandResult expand_result) override; diff --git a/src/duckdb/extension/json/include/json_reader_options.hpp b/src/duckdb/extension/json/include/json_reader_options.hpp index 7b9048cd6..67f5f5059 100644 --- a/src/duckdb/extension/json/include/json_reader_options.hpp +++ b/src/duckdb/extension/json/include/json_reader_options.hpp @@ -32,13 +32,22 @@ struct DateFormatMap { return candidate_formats.find(type)->second.back(); } + const vector &GetFormats(LogicalTypeId type) const { + D_ASSERT(candidate_formats.find(type) != candidate_formats.end()); + return candidate_formats.find(type)->second; + } + public: static void AddFormat(type_id_map_t> &candidate_formats, LogicalTypeId type, const string &format_string) { auto &formats = candidate_formats[type]; formats.emplace_back(); formats.back().format_specifier = format_string; - StrpTimeFormat::ParseFormatSpecifier(formats.back().format_specifier, formats.back()); + const auto error = StrpTimeFormat::ParseFormatSpecifier(formats.back().format_specifier, formats.back()); + if (!error.empty()) { + formats.pop_back(); + throw InvalidInputException(error); + } } static bool HasFormats(const type_id_map_t> &candidate_formats, LogicalTypeId type) { @@ -74,14 +83,6 @@ class MutableDateFormatMap { return true; } - void ShrinkFormatsToSize(LogicalTypeId type, idx_t size) { - lock_guard lock(format_lock); - auto &formats = date_format_map.candidate_formats[type]; - 
while (formats.size() > size) { - formats.pop_back(); - } - } - private: mutex format_lock; DateFormatMap &date_format_map; diff --git a/src/duckdb/extension/json/include/json_scan.hpp b/src/duckdb/extension/json/include/json_scan.hpp index 5a54be913..1f78942f0 100644 --- a/src/duckdb/extension/json/include/json_scan.hpp +++ b/src/duckdb/extension/json/include/json_scan.hpp @@ -136,7 +136,7 @@ struct JSONLocalTableFunctionState : public LocalTableFunctionState { struct JSONScan { public: static void AutoDetect(ClientContext &context, MultiFileBindData &bind_data, vector &return_types, - vector &names); + vector &names); static void Serialize(Serializer &serializer, const optional_ptr bind_data, const TableFunction &function); diff --git a/src/duckdb/extension/json/include/json_serializer.hpp b/src/duckdb/extension/json/include/json_serializer.hpp index d9bdd51dc..dfcaa4383 100644 --- a/src/duckdb/extension/json/include/json_serializer.hpp +++ b/src/duckdb/extension/json/include/json_serializer.hpp @@ -10,6 +10,8 @@ struct JsonSerializer : Serializer { yyjson_mut_doc *doc; yyjson_mut_val *current_tag; vector stack; + vector stack_can_be_omitted; + vector property_can_be_omitted; // Skip writing property if null bool skip_if_null = false; @@ -20,6 +22,9 @@ struct JsonSerializer : Serializer { inline yyjson_mut_val *Current() { return stack.back(); }; + inline bool CurrentPropertyCanBeOmitted() { + return !property_can_be_omitted.empty() && property_can_be_omitted.back(); + } // Either adds a value to the current object with the current tag, or appends it to the current array void PushValue(yyjson_mut_val *val); @@ -27,7 +32,8 @@ struct JsonSerializer : Serializer { public: explicit JsonSerializer(yyjson_mut_doc *doc, bool skip_if_null, bool skip_if_empty, bool skip_if_default, SerializationOptions options_p = SerializationOptions()) - : doc(doc), stack({yyjson_mut_obj(doc)}), skip_if_null(skip_if_null), skip_if_empty(skip_if_empty) { + : doc(doc), 
stack({yyjson_mut_obj(doc)}), stack_can_be_omitted({false}), skip_if_null(skip_if_null), + skip_if_empty(skip_if_empty) { options = std::move(options_p); options.serialize_enum_as_string = true; options.serialize_default_values = !skip_if_default; diff --git a/src/duckdb/extension/json/json_deserializer.cpp b/src/duckdb/extension/json/json_deserializer.cpp index 451e7334c..7ea58ce1c 100644 --- a/src/duckdb/extension/json/json_deserializer.cpp +++ b/src/duckdb/extension/json/json_deserializer.cpp @@ -224,15 +224,19 @@ uint64_t JsonDeserializer::ReadUnsignedInt64() { } float JsonDeserializer::ReadFloat() { - auto val = GetNextValue(); - if (!yyjson_is_real(val)) { - ThrowTypeError(val, "float"); - } - return yyjson_get_real(val); + return LossyNumericCast(ReadDouble()); } double JsonDeserializer::ReadDouble() { auto val = GetNextValue(); + if (yyjson_is_raw(val)) { + double res; + string_t input(yyjson_get_raw(val)); + if (!TryCast::Operation(input, res)) { + throw InvalidInputException("Expected a floating point number, but got the string %s", input.GetString()); + } + return res; + } if (!yyjson_is_real(val)) { ThrowTypeError(val, "double"); } diff --git a/src/duckdb/extension/json/json_extension.cpp b/src/duckdb/extension/json/json_extension.cpp index 9ff72e490..a871bbba8 100644 --- a/src/duckdb/extension/json/json_extension.cpp +++ b/src/duckdb/extension/json/json_extension.cpp @@ -57,10 +57,10 @@ static void LoadInternal(ExtensionLoader &loader) { auto copy_fun = JSONFunctions::GetJSONCopyFunction(); loader.RegisterFunction(copy_fun); copy_fun.extension = "ndjson"; - copy_fun.name = "ndjson"; + copy_fun.SetName("ndjson"); loader.RegisterFunction(copy_fun); copy_fun.extension = "jsonl"; - copy_fun.name = "jsonl"; + copy_fun.SetName("jsonl"); loader.RegisterFunction(copy_fun); // JSON macro's diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index 696ce646d..66dac5c31 100644 --- 
a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -103,9 +103,8 @@ bool JSONReadManyFunctionData::Equals(const FunctionData &other_p) const { unique_ptr JSONReadManyFunctionData::Bind(BindScalarFunctionInput &input) { auto &context = input.GetClientContext(); - auto &bound_function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); - D_ASSERT(bound_function.GetArguments().size() == 2); + D_ASSERT(input.GetBoundFunction().GetArguments().size() == 2); if (arguments[1]->HasParameter()) { throw ParameterNotResolvedException(); } @@ -236,7 +235,7 @@ unique_ptr JSONFunctions::ReadJSONReplacement(ClientContext &context, if (!FileSystem::HasGlob(table_name)) { auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); + table_function->alias = Identifier(fs.ExtractBaseName(table_name)); } return std::move(table_function); diff --git a/src/duckdb/extension/json/json_functions/copy_json.cpp b/src/duckdb/extension/json/json_functions/copy_json.cpp index 5fd762b6c..34c7cd6fe 100644 --- a/src/duckdb/extension/json/json_functions/copy_json.cpp +++ b/src/duckdb/extension/json/json_functions/copy_json.cpp @@ -1,4 +1,6 @@ #include "duckdb/function/copy_function.hpp" +#include "duckdb/common/bind_helpers.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/operator_expression.hpp" @@ -19,37 +21,96 @@ static void ThrowJSONCopyParameterException(const string &loption) { throw BinderException("COPY (FORMAT JSON) parameter %s expects a single argument.", loption); } +static void ThrowJSONCopyNullException(const string &loption) { + throw BinderException("COPY (FORMAT JSON) parameter \"%s\" cannot be NULL.", loption); +} + +static void ThrowJSONCopyTypeException(const string &loption, const Value &value, const 
string &expected_type) { + throw BinderException("COPY (FORMAT JSON) parameter \"%s\" expects a %s argument, but got %s.", loption, + expected_type, value.type()); +} + +static const Value &GetSingleJSONCopyValue(const string &loption, const vector &values) { + if (values.size() != 1) { + ThrowJSONCopyParameterException(loption); + } + if (values.back().IsNull()) { + ThrowJSONCopyNullException(loption); + } + return values.back(); +} + +static string GetSingleJSONCopyString(const string &loption, const vector &values) { + auto &value = GetSingleJSONCopyValue(loption, values); + if (value.type().id() != LogicalTypeId::VARCHAR) { + ThrowJSONCopyTypeException(loption, value, "VARCHAR"); + } + return StringValue::Get(value); +} + +static bool GetJSONCopyBoolean(Binder &binder, const string &loption, const vector &values) { + if (values.size() > 1) { + throw InvalidInputException("Copy option \"%s\" did not expect a list as argument", loption); + } + if (values.empty()) { + return true; + } + if (values.back().IsNull()) { + ThrowJSONCopyNullException(loption); + } + return values.back().CastAs(binder.context, LogicalType::BOOLEAN).GetValue(); +} + +static unique_ptr PushJSONFormatProjection(unique_ptr source_ref, + const Identifier &function_name, const string &format) { + auto format_node = make_uniq(); + format_node->from_table = std::move(source_ref); + + auto columns_star = make_uniq(); + columns_star->IsColumnsMutable() = true; + vector> args; + args.push_back(std::move(columns_star)); + args.push_back(make_uniq(Value(format))); + format_node->select_list.push_back(make_uniq(function_name, std::move(args))); + + auto format_stmt = make_uniq(); + format_stmt->node = std::move(format_node); + return make_uniq(std::move(format_stmt)); +} + static BoundStatement CopyToJSONPlan(Binder &binder, CopyStatement &stmt) { static const unordered_set SUPPORTED_BASE_OPTIONS { - "compression", "encoding", "use_tmp_file", "overwrite_or_ignore", "overwrite", "append", 
"filename_pattern", - "file_extension", "per_thread_output", "file_size_bytes", - // "partition_by", unsupported - "return_files", "preserve_order", "return_stats", "write_partition_columns", "write_empty_file", - "hive_file_pattern"}; + "compression", "encoding", "use_tmp_file", "overwrite_or_ignore", "overwrite", + "append", "filename_pattern", "file_extension", "per_thread_output", "file_size_bytes", + "partition_by", "return_files", "preserve_order", "return_stats", "write_partition_columns", + "write_empty_file", "hive_file_pattern"}; auto ©_info = *stmt.info; // Parse the options, creating options for the CSV writer while doing so string date_format; string timestamp_format; + // Partition columns are kept as separate columns (instead of being packed into the JSON object), so that the + // COPY writer can partition on them. By default they are excluded from the written JSON, matching the behavior + // of the other formats. WRITE_PARTITION_COLUMNS keeps them inside the JSON object instead. + vector partition_columns; + bool write_partition_columns = false; + vector original_column_names; // We insert the JSON file extension here so it works properly with PER_THREAD_OUTPUT/FILE_SIZE_BYTES etc. 
case_insensitive_map_t> csv_copy_options {{"file_extension", {"json"}}}; for (const auto &kv : copy_info.options) { const auto &loption = StringUtil::Lower(kv.first); if (loption == "dateformat" || loption == "date_format") { - if (kv.second.size() != 1) { - ThrowJSONCopyParameterException(loption); - } - date_format = StringValue::Get(kv.second.back()); + date_format = GetSingleJSONCopyString(loption, kv.second); } else if (loption == "timestampformat" || loption == "timestamp_format") { - if (kv.second.size() != 1) { - ThrowJSONCopyParameterException(loption); - } - timestamp_format = StringValue::Get(kv.second.back()); + timestamp_format = GetSingleJSONCopyString(loption, kv.second); } else if (loption == "array") { if (kv.second.size() > 1) { ThrowJSONCopyParameterException(loption); } + if (!kv.second.empty() && kv.second.back().IsNull()) { + ThrowJSONCopyNullException(loption); + } if (kv.second.empty() || BooleanValue::Get(kv.second.back().DefaultCastAs(LogicalTypeId::BOOLEAN))) { csv_copy_options["prefix"] = {"[\n\t"}; csv_copy_options["suffix"] = {"\n]\n"}; @@ -57,14 +118,48 @@ static BoundStatement CopyToJSONPlan(Binder &binder, CopyStatement &stmt) { } } else if (loption == "file_extension") { // Since we set the file extension to "json" above, we need to override it - csv_copy_options["file_extension"] = {StringValue::Get(kv.second.back())}; + csv_copy_options["file_extension"] = {GetSingleJSONCopyString(loption, kv.second)}; + } else if (loption == "partition_by") { + for (const auto &val : kv.second) { + if (val.IsNull()) { + ThrowJSONCopyNullException(loption); + } + } + if (original_column_names.empty()) { + auto node_copy = copy_info.select_statement->Copy(); + auto child_binder = Binder::CreateBinder(binder.context, &binder); + auto bound = child_binder->Bind(*node_copy); + original_column_names = bound.names; + } + auto converted = ConvertVectorToValue(vector(kv.second)); + auto partition_indices = ParseColumnsOrdered(converted, 
original_column_names, loption); + for (auto &partition_index : partition_indices) { + partition_columns.emplace_back(original_column_names[partition_index]); + } + vector csv_partition_columns; + for (const auto &partition_column : partition_columns) { + csv_partition_columns.push_back(partition_column); + } + csv_copy_options["partition_by"] = std::move(csv_partition_columns); + } else if (loption == "write_partition_columns") { + // Handled below by keeping the partition columns inside the JSON object. We do not forward this to the + // CSV writer, as that would write the (separate) partition columns as their own JSON lines. + write_partition_columns = GetJSONCopyBoolean(binder, loption, kv.second); } else if (SUPPORTED_BASE_OPTIONS.find(loption) != SUPPORTED_BASE_OPTIONS.end()) { + if (!kv.second.empty() && kv.second.back().IsNull()) { + ThrowJSONCopyNullException(loption); + } // We support these base options csv_copy_options.insert(kv); } else { throw BinderException("Unknown option for COPY ... TO ... (FORMAT JSON): \"%s\".", loption); } } + if (!write_partition_columns && !partition_columns.empty() && + partition_columns.size() == original_column_names.size()) { + throw NotImplementedException("No column to write as all columns are specified as partition columns. " + "WRITE_PARTITION_COLUMNS option can be used to write partition columns."); + } // Run the following query to convert everything into a single JSON column, then invoke the CSV writer // SELECT TO_JSON(STRUCT_PACK(*COLUMNS(*))) FROM @@ -72,39 +167,13 @@ static BoundStatement CopyToJSONPlan(Binder &binder, CopyStatement &stmt) { auto inner_select_stmt = make_uniq(); inner_select_stmt->node = std::move(copy_info.select_statement); - unique_ptr source_ref; - if (!date_format.empty() || !timestamp_format.empty()) { - // if we have date_format or timestamp_format defined, we use the json_copy macros to apply them - // e.g. 
SELECT json_copy_strftime_if_date(COLUMNS(*), date_format) FROM - auto strftime_node = make_uniq(); - strftime_node->from_table = make_uniq(std::move(inner_select_stmt)); - - unique_ptr expr; - auto columns_star = make_uniq(); - columns_star->columns = true; - expr = std::move(columns_star); - if (!date_format.empty()) { - // apply date format - vector> args; - args.push_back(std::move(expr)); - args.push_back(make_uniq(Value(date_format))); - auto strftime_date = make_uniq("json_copy_strftime_if_date", std::move(args)); - expr = std::move(strftime_date); - } - if (!timestamp_format.empty()) { - // apply timestamp format - vector> args; - args.push_back(std::move(expr)); - args.push_back(make_uniq(Value(timestamp_format))); - auto strftime_ts = make_uniq("json_copy_strftime_if_timestamp", std::move(args)); - expr = std::move(strftime_ts); - } - strftime_node->select_list.push_back(std::move(expr)); - auto strftime_stmt = make_uniq(); - strftime_stmt->node = std::move(strftime_node); - source_ref = make_uniq(std::move(strftime_stmt)); - } else { - source_ref = make_uniq(std::move(inner_select_stmt)); + auto source_ref = make_uniq(std::move(inner_select_stmt)); + if (!date_format.empty()) { + source_ref = PushJSONFormatProjection(std::move(source_ref), "json_copy_strftime_if_date", date_format); + } + if (!timestamp_format.empty()) { + source_ref = + PushJSONFormatProjection(std::move(source_ref), "json_copy_strftime_if_timestamp", timestamp_format); } // Build outer: SELECT TO_JSON(STRUCT_PACK(*COLUMNS(*))) FROM @@ -113,9 +182,15 @@ static BoundStatement CopyToJSONPlan(Binder &binder, CopyStatement &stmt) { select_node.from_table = std::move(source_ref); auto columns_star = make_uniq(); - columns_star->columns = true; + columns_star->IsColumnsMutable() = true; + if (!write_partition_columns) { + // Exclude the partition columns from the JSON object - they are kept as separate columns below + for (const auto &partition_column : partition_columns) { + 
columns_star->ExcludeListMutable().insert(QualifiedColumnName(Identifier(partition_column))); + } + } auto unpack = make_uniq(ExpressionType::OPERATOR_UNPACK); - unpack->children.push_back(std::move(columns_star)); + unpack->GetChildrenMutable().push_back(std::move(columns_star)); vector> struct_pack_args; struct_pack_args.push_back(std::move(unpack)); @@ -125,6 +200,13 @@ static BoundStatement CopyToJSONPlan(Binder &binder, CopyStatement &stmt) { to_json_args.push_back(std::move(struct_pack)); select_node.select_list.push_back(make_uniq("to_json", std::move(to_json_args))); + // Keep the partition columns as separate columns so the COPY writer can partition on them. The writer routes rows + // into the right files based on these columns but does not write them to disk (WRITE_PARTITION_COLUMNS is handled + // above by keeping the columns inside the JSON object instead). + for (const auto &partition_column : partition_columns) { + select_node.select_list.push_back(make_uniq(Identifier(partition_column))); + } + // Now we can just use the CSV writer copy_info.format = "csv"; copy_info.options = std::move(csv_copy_options); diff --git a/src/duckdb/extension/json/json_functions/json_contains.cpp b/src/duckdb/extension/json/json_functions/json_contains.cpp index 67e786969..6e0744e41 100644 --- a/src/duckdb/extension/json/json_functions/json_contains.cpp +++ b/src/duckdb/extension/json/json_functions/json_contains.cpp @@ -111,8 +111,8 @@ static void JSONContainsFunction(DataChunk &args, ExpressionState &state, Vector auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); - auto &haystacks = args.data[0]; - auto &needles = args.data[1]; + const auto &haystacks = args.data[0]; + const auto &needles = args.data[1]; if (needles.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(needles)) { @@ -120,13 +120,13 @@ static void JSONContainsFunction(DataChunk &args, ExpressionState &state, Vector } auto 
&needle_str = *ConstantVector::GetData(needles); auto needle_doc = JSONCommon::ReadDocument(needle_str, JSONCommon::READ_FLAG, alc); - UnaryExecutor::Execute(haystacks, result, args.size(), [&](string_t haystack_str) { + UnaryExecutor::Execute(haystacks, result, [&](string_t haystack_str) { auto haystack_doc = JSONCommon::ReadDocument(haystack_str, JSONCommon::READ_FLAG, alc); return JSONContains(haystack_doc->root, needle_doc->root); }); } else { BinaryExecutor::Execute( - haystacks, needles, result, args.size(), [&](string_t haystack_str, string_t needle_str) { + haystacks, needles, result, [&](string_t haystack_str, string_t needle_str) { auto needle_doc = JSONCommon::ReadDocument(needle_str, JSONCommon::READ_FLAG, alc); auto haystack_doc = JSONCommon::ReadDocument(haystack_str, JSONCommon::READ_FLAG, alc); return JSONContains(haystack_doc->root, needle_doc->root); diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 82dcf58f0..5cfefab60 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -78,7 +78,6 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp case LogicalTypeId::BIT: case LogicalTypeId::BLOB: case LogicalTypeId::VARCHAR: - case LogicalTypeId::AGGREGATE_STATE: case LogicalTypeId::ENUM: case LogicalTypeId::DATE: case LogicalTypeId::INTERVAL: @@ -94,6 +93,8 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp case LogicalTypeId::BIGNUM: case LogicalTypeId::DECIMAL: return type; + case LogicalTypeId::VARIANT: + return LogicalType::JSON(); case LogicalTypeId::LIST: return LogicalType::LIST(GetJSONType(const_struct_names, ListType::GetChildType(type))); case LogicalTypeId::ARRAY: @@ -103,7 +104,7 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp case LogicalTypeId::STRUCT: { child_list_t child_types; 
for (const auto &child_type : StructType::GetChildTypes(type)) { - const_struct_names.Insert(child_type.first); + const_struct_names.Insert(child_type.first.GetIdentifierName()); child_types.emplace_back(child_type.first, GetJSONType(const_struct_names, child_type.second)); } return LogicalType::STRUCT(child_types); @@ -116,7 +117,7 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(type); member_idx++) { auto &member_name = UnionType::GetMemberName(type, member_idx); auto &member_type = UnionType::GetMemberType(type, member_idx); - const_struct_names.Insert(member_name); + const_struct_names.Insert(member_name.GetIdentifierName()); member_types.emplace_back(member_name, GetJSONType(const_struct_names, member_type)); } return LogicalType::UNION(member_types); @@ -130,6 +131,9 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp static unique_ptr JSONCreateBindParams(BoundScalarFunction &bound_function, vector> &arguments, bool object) { StructNames const_struct_names; + auto &bound_arguments = bound_function.GetArguments(); + bound_arguments.clear(); + bound_arguments.reserve(arguments.size()); for (idx_t i = 0; i < arguments.size(); i++) { auto &type = arguments[i]->GetReturnType(); if (arguments[i]->HasParameter()) { @@ -139,10 +143,10 @@ static unique_ptr JSONCreateBindParams(BoundScalarFunction &bound_ throw BinderException("json_object() keys must be VARCHAR, add an explicit cast to argument \"%s\"", arguments[i]->GetName()); } - bound_function.GetArguments().push_back(LogicalType::VARCHAR); + bound_arguments.push_back(LogicalType::VARCHAR); } else { // Value, cast to types that we can put in JSON - bound_function.GetArguments().push_back(GetJSONType(const_struct_names, type)); + bound_arguments.push_back(GetJSONType(const_struct_names, type)); } } return make_uniq(std::move(const_struct_names)); @@ -278,7 +282,7 @@ inline 
yyjson_mut_val *CreateJSONValueFromJSON(yyjson_mut_doc *doc, const string static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_mut_val *vals[], Vector &value_v, idx_t count); -static void AddKeyValuePairs(yyjson_mut_doc *doc, yyjson_mut_val *objs[], Vector &key_v, yyjson_mut_val *vals[], +static void AddKeyValuePairs(yyjson_mut_doc *doc, yyjson_mut_val *objs[], const Vector &key_v, yyjson_mut_val *vals[], idx_t count) { auto keys = key_v.Values(); for (idx_t i = 0; i < count; i++) { @@ -292,7 +296,7 @@ static void AddKeyValuePairs(yyjson_mut_doc *doc, yyjson_mut_val *objs[], Vector } static void CreateKeyValuePairs(const StructNames &names, yyjson_mut_doc *doc, yyjson_mut_val *objs[], - yyjson_mut_val *vals[], Vector &key_v, Vector &value_v, idx_t count) { + yyjson_mut_val *vals[], const Vector &key_v, Vector &value_v, idx_t count) { CreateValues(names, doc, vals, value_v, count); AddKeyValuePairs(doc, objs, key_v, vals, count); } @@ -304,7 +308,7 @@ static void CreateValuesNull(yyjson_mut_doc *doc, yyjson_mut_val *vals[], idx_t } template -static void TemplatedCreateValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], Vector &value_v, idx_t count) { +static void TemplatedCreateValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], const Vector &value_v, idx_t count) { UnifiedVectorFormat value_data; value_v.ToUnifiedFormat(value_data); auto values = UnifiedVectorFormat::GetData(value_data); @@ -323,7 +327,7 @@ static void TemplatedCreateValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], V } } -static void CreateRawValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], Vector &value_v, idx_t count) { +static void CreateRawValues(yyjson_mut_doc *doc, yyjson_mut_val *vals[], const Vector &value_v, idx_t count) { UnifiedVectorFormat value_data; value_v.ToUnifiedFormat(value_data); auto values = UnifiedVectorFormat::GetData(value_data); @@ -351,7 +355,7 @@ static void CreateValuesStruct(const StructNames &names, yyjson_mut_doc *doc, yy // Add the 
key/value pairs to the values auto &entries = StructVector::GetEntries(value_v); for (idx_t entry_i = 0; entry_i < entries.size(); entry_i++) { - auto &struct_key_v = names.Get(StructType::GetChildName(value_v.GetType(), entry_i), count); + auto &struct_key_v = names.Get(StructType::GetChildName(value_v.GetType(), entry_i).GetIdentifierName(), count); auto &struct_val_v = entries[entry_i]; CreateKeyValuePairs(names, doc, vals, nested_vals, struct_key_v, struct_val_v, count); } @@ -433,7 +437,8 @@ static void CreateValuesUnion(const StructNames &names, yyjson_mut_doc *doc, yyj // Add the key/value pairs to the values for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(value_v.GetType()); member_idx++) { auto &member_val_v = UnionVector::GetMember(value_v, member_idx); - auto &member_key_v = names.Get(UnionType::GetMemberName(value_v.GetType(), member_idx), count); + auto &member_key_v = + names.Get(UnionType::GetMemberName(value_v.GetType(), member_idx).GetIdentifierName(), count); // This implementation is not optimal since we convert the entire member vector, // and then skip the rows not matching the tag afterwards. 
@@ -643,14 +648,13 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::VALIDITY: case LogicalTypeId::TABLE: case LogicalTypeId::LAMBDA: - case LogicalTypeId::AGGREGATE_STATE: throw InternalException("Unsupported type arrived at JSON create function"); } } static void ObjectFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); @@ -679,7 +683,7 @@ static void ObjectFunction(DataChunk &args, ExpressionState &state, Vector &resu static void ArrayFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); @@ -737,7 +741,7 @@ static void ToJSONFunctionInternal(const StructNames &names, Vector &input, cons static void ToJSONFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); diff --git a/src/duckdb/extension/json/json_functions/json_deep_merge.cpp b/src/duckdb/extension/json/json_functions/json_deep_merge.cpp index daa335a7d..fa76975bf 100644 --- a/src/duckdb/extension/json/json_functions/json_deep_merge.cpp +++ b/src/duckdb/extension/json/json_functions/json_deep_merge.cpp @@ -54,10 +54,10 @@ static yyjson_mut_val *DeepMerge(yyjson_mut_doc *doc, yyjson_mut_val *orig, yyjs return builder; } -static inline void DeepMergeReadObjects(yyjson_mut_doc *doc, Vector &input, 
yyjson_mut_val *objs[], const idx_t count) { +static inline void DeepMergeReadObjects(yyjson_mut_doc *doc, const Vector &input, yyjson_mut_val *objs[]) { + const idx_t count = input.size(); UnifiedVectorFormat input_data; - auto &input_vector = input; - input_vector.ToUnifiedFormat(input_data); + input.ToUnifiedFormat(input_data); auto inputs = UnifiedVectorFormat::GetData(input_data); for (idx_t i = 0; i < count; i++) { @@ -79,11 +79,11 @@ static void DeepMergeFunction(DataChunk &args, ExpressionState &state, Vector &r const auto count = args.size(); auto origs = JSONCommon::AllocateArray(alc, count); - DeepMergeReadObjects(doc, args.data[0], origs, count); + DeepMergeReadObjects(doc, args.data[0], origs); auto patches = JSONCommon::AllocateArray(alc, count); for (idx_t arg_idx = 1; arg_idx < args.data.size(); arg_idx++) { - DeepMergeReadObjects(doc, args.data[arg_idx], patches, count); + DeepMergeReadObjects(doc, args.data[arg_idx], patches); for (idx_t i = 0; i < count; i++) { if (patches[i] == nullptr) { origs[i] = nullptr; diff --git a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp index 230bca7b7..2912b0ec4 100644 --- a/src/duckdb/extension/json/json_functions/json_merge_patch.cpp +++ b/src/duckdb/extension/json/json_functions/json_merge_patch.cpp @@ -14,11 +14,11 @@ static inline yyjson_mut_val *MergePatch(yyjson_mut_doc *doc, yyjson_mut_val *or return yyjson_mut_merge_patch(doc, orig, patch); } -static inline void ReadObjects(yyjson_mut_doc *doc, Vector &input, yyjson_mut_val *objs[], const idx_t count) { +static inline void ReadObjects(yyjson_mut_doc *doc, const Vector &input, yyjson_mut_val *objs[]) { auto entries = input.Values(); // Read the documents - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < input.size(); i++) { auto entry = entries[i]; if (!entry.IsValid()) { objs[i] = nullptr; @@ -39,12 +39,12 @@ static void MergePatchFunction(DataChunk &args, 
ExpressionState &state, Vector & // Read the first json arg auto origs = JSONCommon::AllocateArray(alc, count); - ReadObjects(doc, args.data[0], origs, count); + ReadObjects(doc, args.data[0], origs); // Read the next json args one by one and merge them into the first json arg auto patches = JSONCommon::AllocateArray(alc, count); for (idx_t arg_idx = 1; arg_idx < args.data.size(); arg_idx++) { - ReadObjects(doc, args.data[arg_idx], patches, count); + ReadObjects(doc, args.data[arg_idx], patches); for (idx_t i = 0; i < count; i++) { if (patches[i] == nullptr) { // Next json arg is NULL, obj becomes NULL diff --git a/src/duckdb/extension/json/json_functions/json_serialize_plan.cpp b/src/duckdb/extension/json/json_functions/json_serialize_plan.cpp index a2d2b1418..95bdfd785 100644 --- a/src/duckdb/extension/json/json_functions/json_serialize_plan.cpp +++ b/src/duckdb/extension/json/json_functions/json_serialize_plan.cpp @@ -113,10 +113,10 @@ static bool OperatorSupportsSerialization(LogicalOperator &op, string &operator_ static void JsonSerializePlanFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &local_state = JSONFunctionLocalState::ResetAndGet(state); auto alc = local_state.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; + const auto &inputs = args.data[0]; auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); if (!state.HasContext()) { throw InvalidInputException("json_serialize_plan: No client context available"); @@ -124,7 +124,7 @@ static void JsonSerializePlanFunction(DataChunk &args, ExpressionState &state, V auto &context = state.GetContext(); auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(inputs, result, [&](string_t input) { auto doc = JSONCommon::CreateDocument(alc); auto result_obj = yyjson_mut_obj(doc); yyjson_mut_doc_set_root(doc, 
result_obj); @@ -201,23 +201,19 @@ static void JsonSerializePlanFunction(DataChunk &args, ExpressionState &state, V ScalarFunctionSet JSONFunctions::GetSerializePlanFunction() { ScalarFunctionSet set("json_serialize_plan"); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::JSON(), JsonSerializePlanFunction, - JsonSerializePlanBind, nullptr, JSONFunctionLocalState::Init)); + ScalarFunction func({}, LogicalType::JSON(), JsonSerializePlanFunction, JsonSerializePlanBind, nullptr, + JSONFunctionLocalState::Init); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN}, LogicalType::JSON(), - JsonSerializePlanFunction, JsonSerializePlanBind, nullptr, - JSONFunctionLocalState::Init)); + func.GetSignature() + .AddParameter("sql", LogicalType::VARCHAR) + .AddParameter("skip_null", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("skip_empty", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("skip_default", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("format", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("optimize", LogicalType::BOOLEAN, Value::BOOLEAN(false)); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, - LogicalType::JSON(), JsonSerializePlanFunction, JsonSerializePlanBind, nullptr, - JSONFunctionLocalState::Init)); + set.AddFunction(std::move(func)); - set.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, LogicalType::JSON(), - JsonSerializePlanFunction, JsonSerializePlanBind, nullptr, JSONFunctionLocalState::Init)); - set.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, - LogicalType::JSON(), JsonSerializePlanFunction, JsonSerializePlanBind, nullptr, JSONFunctionLocalState::Init)); return set; } diff --git 
a/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp b/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp index 78d8a54b4..3c3cbd923 100644 --- a/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp +++ b/src/duckdb/extension/json/json_functions/json_serialize_sql.cpp @@ -6,6 +6,7 @@ #include "json_deserializer.hpp" #include "json_functions.hpp" #include "json_serializer.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" namespace duckdb { @@ -86,13 +87,13 @@ static unique_ptr JsonSerializeBind(BindScalarFunctionInput &input static void JsonSerializeFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &local_state = JSONFunctionLocalState::ResetAndGet(state); auto alc = local_state.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; + const auto &inputs = args.data[0]; auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(inputs, result, [&](string_t input) { auto doc = JSONCommon::CreateDocument(alc); auto result_obj = yyjson_mut_obj(doc); yyjson_mut_doc_set_root(doc, result_obj); @@ -110,8 +111,7 @@ static void JsonSerializeFunction(DataChunk &args, ExpressionState &state, Vecto auto &select = statement->Cast(); auto options = make_uniq(); - options->serialization_compatibility = - state.GetContext().db->config.options.serialization_compatibility; + options->storage_compatibility = state.GetContext().db->config.options.storage_compatibility; auto json = JsonSerializer::Serialize(select, doc, info.skip_if_null, info.skip_if_empty, info.skip_if_default, *options); @@ -154,23 +154,18 @@ static void JsonSerializeFunction(DataChunk &args, ExpressionState &state, Vecto ScalarFunctionSet JSONFunctions::GetSerializeSqlFunction() { ScalarFunctionSet 
set("json_serialize_sql"); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::JSON(), JsonSerializeFunction, - JsonSerializeBind, nullptr, JSONFunctionLocalState::Init)); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN}, LogicalType::JSON(), - JsonSerializeFunction, JsonSerializeBind, nullptr, JSONFunctionLocalState::Init)); + ScalarFunction func({}, LogicalType::JSON(), JsonSerializeFunction, JsonSerializeBind, nullptr, + JSONFunctionLocalState::Init); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, - LogicalType::JSON(), JsonSerializeFunction, JsonSerializeBind, nullptr, - JSONFunctionLocalState::Init)); + func.GetSignature() + .AddParameter("sql", LogicalType::VARCHAR) + .AddParameter("skip_null", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("skip_empty", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("skip_default", LogicalType::BOOLEAN, Value::BOOLEAN(false)) + .AddParameter("format", LogicalType::BOOLEAN, Value::BOOLEAN(false)); - set.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, LogicalType::JSON(), - JsonSerializeFunction, JsonSerializeBind, nullptr, JSONFunctionLocalState::Init)); - - set.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN, LogicalType::BOOLEAN}, - LogicalType::JSON(), JsonSerializeFunction, JsonSerializeBind, nullptr, JSONFunctionLocalState::Init)); + set.AddFunction(std::move(func)); return set; } @@ -215,6 +210,11 @@ static vector> DeserializeSelectStatement(string_t i if (!stmt->node) { throw ParserException("Error parsing json: no select node found in json"); } + ParsedExpressionIterator::EnumerateQueryNodeChildren(*stmt->node, [](unique_ptr &child) { + if (!child) { + throw ParserException("Error parsing json: null expression found in json"); + } + }); 
result.push_back(std::move(stmt)); } @@ -227,10 +227,10 @@ static vector> DeserializeSelectStatement(string_t i static void JsonDeserializeFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &local_state = JSONFunctionLocalState::ResetAndGet(state); auto alc = local_state.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; + const auto &inputs = args.data[0]; auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(inputs, result, [&](string_t input) { auto stmts = DeserializeSelectStatement(input, alc); // Combine all statements into a single semicolon separated string string str; @@ -247,8 +247,10 @@ static void JsonDeserializeFunction(DataChunk &args, ExpressionState &state, Vec ScalarFunctionSet JSONFunctions::GetDeserializeSqlFunction() { ScalarFunctionSet set("json_deserialize_sql"); - set.AddFunction(ScalarFunction({LogicalType::JSON()}, LogicalType::VARCHAR, JsonDeserializeFunction, nullptr, - nullptr, JSONFunctionLocalState::Init)); + auto function = ScalarFunction({LogicalType::JSON()}, LogicalType::VARCHAR, JsonDeserializeFunction, nullptr, + nullptr, JSONFunctionLocalState::Init); + function.SetFallible(); + set.AddFunction(std::move(function)); return set; } diff --git a/src/duckdb/extension/json/json_functions/json_strip_nulls.cpp b/src/duckdb/extension/json/json_functions/json_strip_nulls.cpp index 0403f227d..d86c49ecc 100644 --- a/src/duckdb/extension/json/json_functions/json_strip_nulls.cpp +++ b/src/duckdb/extension/json/json_functions/json_strip_nulls.cpp @@ -34,8 +34,8 @@ static void StripNullsFunction(DataChunk &args, ExpressionState &state, Vector & auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + const auto &inputs = args.data[0]; + 
UnaryExecutor::Execute(inputs, result, [&](string_t input) { auto doc = JSONCommon::ReadDocument(input, JSONCommon::READ_FLAG, alc); auto mut_doc = yyjson_doc_mut_copy(doc, alc); auto root = yyjson_mut_doc_get_root(mut_doc); diff --git a/src/duckdb/extension/json/json_functions/json_structure.cpp b/src/duckdb/extension/json/json_functions/json_structure.cpp index 19e382d07..b2f4600aa 100644 --- a/src/duckdb/extension/json/json_functions/json_structure.cpp +++ b/src/duckdb/extension/json/json_functions/json_structure.cpp @@ -287,7 +287,7 @@ void JSONStructureNode::EliminateCandidateTypes(const idx_t vec_count, Vector &s } template -bool TryParse(Vector &string_vector, StrpTimeFormat &format, const idx_t count) { +bool TryParse(const Vector &string_vector, StrpTimeFormat &format, const idx_t count) { const auto strings = FlatVector::GetData(string_vector); const auto &validity = FlatVector::Validity(string_vector); @@ -336,7 +336,6 @@ bool JSONStructureNode::EliminateCandidateFormats(const idx_t vec_count, Vector } if (success) { - date_format_map.ShrinkFormatsToSize(type, i); return true; } } @@ -658,7 +657,7 @@ static double CalculateTypeSimilarity(const LogicalType &merged, const LogicalTy const auto &merged_child_types = StructType::GetChildTypes(merged); const auto &type_child_types = StructType::GetChildTypes(type); - unordered_map merged_child_types_map; + identifier_map_t merged_child_types_map; for (const auto &merged_child : merged_child_types) { merged_child_types_map.emplace(merged_child.first, merged_child.second); } @@ -690,6 +689,9 @@ static double CalculateTypeSimilarity(const LogicalType &merged, const LogicalTy } case LogicalTypeId::LIST: { // Only lists can be merged into a list + if (type.id() != LogicalTypeId::LIST) { + return -1; + } D_ASSERT(type.id() == LogicalTypeId::LIST); const auto &merged_child_type = ListType::GetChildType(merged); const auto &type_child_type = ListType::GetChildType(type); diff --git 
a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp index 435352ed9..b2c6364f2 100644 --- a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp +++ b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp @@ -250,7 +250,7 @@ static void InitializeLocalState(JSONTableInOutLocalState &lstate, DataChunk &in // Parse path, default to root if not given Value path_value("$"); if (input.data.size() > 1) { - auto &path_vector = input.data[1]; + const auto &path_vector = input.data[1]; if (ConstantVector::IsNull(path_vector)) { return; } diff --git a/src/duckdb/extension/json/json_functions/json_transform.cpp b/src/duckdb/extension/json/json_functions/json_transform.cpp index bd5f06e93..298e690e7 100644 --- a/src/duckdb/extension/json/json_functions/json_transform.cpp +++ b/src/duckdb/extension/json/json_functions/json_transform.cpp @@ -302,8 +302,26 @@ static bool TransformFromString(yyjson_val *vals[], Vector &result, const idx_t return success; } +static bool TryOneFormat(const StrpTimeFormat &fmt, LogicalTypeId type_id, const string_t &input, Vector &result, + idx_t i, string &error_message) { + if (type_id == LogicalTypeId::DATE) { + date_t val; + if (!TryParseDate::Operation(fmt, input, val, error_message)) { + return false; + } + FlatVector::GetDataMutable(result)[i] = val; + } else { + timestamp_t val; + if (!TryParseTimeStamp::Operation(fmt, input, val, error_message)) { + return false; + } + FlatVector::GetDataMutable(result)[i] = val; + } + return true; +} + template -static bool TransformStringWithFormat(Vector &string_vector, const StrpTimeFormat &format, const idx_t count, +static bool TransformStringWithFormat(const Vector &string_vector, const StrpTimeFormat &format, const idx_t count, Vector &result, JSONTransformOptions &options) { const auto source_strings = FlatVector::GetData(string_vector); const auto &source_validity = FlatVector::Validity(string_vector); 
@@ -334,22 +352,47 @@ static bool TransformFromStringWithFormat(yyjson_val *vals[], Vector &result, co success = false; } - const auto &result_type = result.GetType().id(); - auto &format = options.date_format_map->GetFormat(result_type); + const auto result_type_id = result.GetType().id(); + D_ASSERT(result_type_id == LogicalTypeId::DATE || result_type_id == LogicalTypeId::TIMESTAMP); + const auto &formats = options.date_format_map->GetFormats(result_type_id); + const auto source_strings = FlatVector::GetData(string_vector); + const auto &source_validity = FlatVector::Validity(string_vector); + auto &target_validity = FlatVector::ValidityMutable(result); - switch (result_type) { - case LogicalTypeId::DATE: - if (!TransformStringWithFormat(string_vector, format, count, result, options)) { - success = false; + // Probe first non-null value to find the most-specific matching format for this vector + idx_t matched_idx = DConstants::INVALID_INDEX; + for (idx_t i = 0; i < count && matched_idx == DConstants::INVALID_INDEX; i++) { + if (!source_validity.RowIsValid(i)) { + continue; } - break; - case LogicalTypeId::TIMESTAMP: - if (!TransformStringWithFormat(string_vector, format, count, result, options)) { - success = false; + string error_message; + for (idx_t j = formats.size(); j > 0; j--) { + if (TryOneFormat(formats[j - 1], result_type_id, source_strings[i], result, i, error_message)) { + matched_idx = j - 1; + break; + } + } + } + + const idx_t parse_limit = (matched_idx != DConstants::INVALID_INDEX) ? 
matched_idx + 1 : formats.size(); + for (idx_t i = 0; i < count; i++) { + if (!source_validity.RowIsValid(i)) { + target_validity.SetInvalid(i); + continue; + } + bool parsed = false; + string error_message; + // Reverse iter formats, beginning at matched_idx until one parses + for (idx_t j = parse_limit; j > 0 && !parsed; j--) { + parsed = TryOneFormat(formats[j - 1], result_type_id, source_strings[i], result, i, error_message); + } + if (!parsed) { + target_validity.SetInvalid(i); + if (success && options.strict_cast) { + options.object_index = i; + success = false; + } } - break; - default: - throw InternalException("No date/timestamp formats for %s", EnumUtil::ToString(result.GetType().id())); } return success; } @@ -515,7 +558,7 @@ static bool TransformObjectInternal(yyjson_val *objects[], yyjson_alc *alc, Vect const auto actual_i = column_index ? column_index->GetChildIndex(child_i).GetPrimaryIndex() : child_i; projected_indices.insert(actual_i); - child_names.push_back(StructType::GetChildName(result.GetType(), actual_i)); + child_names.emplace_back(StructType::GetChildName(result.GetType(), actual_i)); child_vectors.push_back(&child_vs[actual_i]); } @@ -796,7 +839,7 @@ static bool TransformValueIntoUnion(yyjson_val **vals, yyjson_alc *alc, Vector & auto fields = UnionType::CopyMemberTypes(type); vector names; for (const auto &field : fields) { - names.push_back(field.first); + names.emplace_back(field.first); } bool success = true; @@ -919,7 +962,6 @@ bool JSONTransform::Transform(yyjson_val *vals[], yyjson_alc *alc, Vector &resul throw InternalException("Unimplemented physical type for decimal"); } } - case LogicalTypeId::AGGREGATE_STATE: case LogicalTypeId::ENUM: case LogicalTypeId::DATE: case LogicalTypeId::INTERVAL: @@ -953,7 +995,7 @@ bool JSONTransform::Transform(yyjson_val *vals[], yyjson_alc *alc, Vector &resul } } -static bool TransformFunctionInternal(Vector &input, const idx_t count, Vector &result, yyjson_alc *alc, +static bool 
TransformFunctionInternal(const Vector &input, const idx_t count, Vector &result, yyjson_alc *alc, JSONTransformOptions &options) { UnifiedVectorFormat input_data; input.ToUnifiedFormat(input_data); diff --git a/src/duckdb/extension/json/json_functions/json_valid.cpp b/src/duckdb/extension/json/json_functions/json_valid.cpp index 1b4c0dbe4..73c80b2f7 100644 --- a/src/duckdb/extension/json/json_functions/json_valid.cpp +++ b/src/duckdb/extension/json/json_functions/json_valid.cpp @@ -5,8 +5,8 @@ namespace duckdb { static void ValidFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &lstate = JSONFunctionLocalState::ResetAndGet(state); auto alc = lstate.json_allocator->GetYYAlc(); - auto &inputs = args.data[0]; - UnaryExecutor::Execute(inputs, result, args.size(), [&](string_t input) { + const auto &inputs = args.data[0]; + UnaryExecutor::Execute(inputs, result, [&](string_t input) { return JSONCommon::ReadDocumentUnsafe(input, JSONCommon::READ_FLAG, alc); }); } diff --git a/src/duckdb/extension/json/json_functions/read_json.cpp b/src/duckdb/extension/json/json_functions/read_json.cpp index 65209f1f4..911143ba6 100644 --- a/src/duckdb/extension/json/json_functions/read_json.cpp +++ b/src/duckdb/extension/json/json_functions/read_json.cpp @@ -15,7 +15,7 @@ static inline LogicalType RemoveDuplicateStructKeys(const LogicalType &type, con case_insensitive_set_t child_names; child_list_t child_types; for (auto &child_type : StructType::GetChildTypes(type)) { - auto insert_success = child_names.insert(child_type.first).second; + auto insert_success = child_names.insert(child_type.first.GetIdentifierName()).second; if (!insert_success) { if (ignore_errors) { continue; @@ -147,7 +147,7 @@ class JSONSchemaTask : public BaseExecutorTask { }; void JSONScan::AutoDetect(ClientContext &context, MultiFileBindData &bind_data, vector &return_types, - vector &names) { + vector &names) { auto &json_data = bind_data.bind_data->Cast(); MutableDateFormatMap 
date_format_map(*json_data.date_format_map); @@ -253,7 +253,7 @@ TableFunction JSONFunctions::GetReadJSONTableFunction(shared_ptr f TableFunctionSet CreateJSONFunctionInfo(string name, shared_ptr info) { auto table_function = JSONFunctions::GetReadJSONTableFunction(std::move(info)); - table_function.name = std::move(name); + table_function.SetName(Identifier(std::move(name))); table_function.named_parameters["maximum_depth"] = LogicalType::BIGINT; table_function.named_parameters["field_appearance_threshold"] = LogicalType::DOUBLE; table_function.named_parameters["convert_strings_to_integers"] = LogicalType::BOOLEAN; diff --git a/src/duckdb/extension/json/json_functions/read_json_objects.cpp b/src/duckdb/extension/json/json_functions/read_json_objects.cpp index fd2818ffd..6189e4f08 100644 --- a/src/duckdb/extension/json/json_functions/read_json_objects.cpp +++ b/src/duckdb/extension/json/json_functions/read_json_objects.cpp @@ -6,7 +6,7 @@ namespace duckdb { -TableFunction GetReadJSONObjectsTableFunction(string name, shared_ptr function_info) { +TableFunction GetReadJSONObjectsTableFunction(Identifier name, shared_ptr function_info) { MultiFileFunction table_function(std::move(name)); JSONScan::TableFunctionDefaults(table_function); table_function.function_info = std::move(function_info); diff --git a/src/duckdb/extension/json/json_multi_file_info.cpp b/src/duckdb/extension/json/json_multi_file_info.cpp index e11362764..607bf62b9 100644 --- a/src/duckdb/extension/json/json_multi_file_info.cpp +++ b/src/duckdb/extension/json/json_multi_file_info.cpp @@ -86,7 +86,7 @@ bool JSONMultiFileInfo::ParseOption(ClientContext &context, const string &key, c if (val.IsNull()) { throw BinderException("read_json \"columns\" parameter type specification cannot be NULL."); } - options.name_list.push_back(name); + options.name_list.emplace_back(name); if (val.type().id() != LogicalTypeId::VARCHAR) { throw BinderException("read_json \"columns\" parameter type specification must be 
VARCHAR."); } @@ -267,12 +267,12 @@ unique_ptr JSONMultiFileInfo::InitializeBindData(MultiFileBin return std::move(json_data); } -void JSONMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, +void JSONMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) { auto &json_data = bind_data.bind_data->Cast(); auto &options = json_data.options; - names = options.name_list; + names = StringsToIdentifiers(options.name_list); return_types = options.sql_type_list; if (options.record_type == JSONRecordType::AUTO_DETECT && return_types.size() > 1) { // More than one specified column implies records @@ -303,7 +303,7 @@ void JSONMultiFileInfo::BindReader(ClientContext &context, vector & JSONScan::AutoDetect(context, bind_data, return_types, names); D_ASSERT(return_types.size() == names.size()); } - json_data.key_names = names; + json_data.key_names = IdentifiersToStrings(names); bind_data.multi_file_reader->BindOptions(bind_data.file_options, *bind_data.file_list, return_types, names, bind_data.reader_bind); @@ -320,12 +320,12 @@ void JSONMultiFileInfo::BindReader(ClientContext &context, vector & // JSON may contain columns such as "id" and "Id", which are duplicates for us due to case-insensitivity // We rename them so we can parse the file anyway. 
Note that we can't change json_data.key_names, // because the JSON reader gets columns by exact name, not position - case_insensitive_map_t name_collision_count; + identifier_map_t name_collision_count; for (auto &col_name : names) { // Taken from CSV header_detection.cpp while (name_collision_count.find(col_name) != name_collision_count.end()) { name_collision_count[col_name] += 1; - col_name = col_name + "_" + to_string(name_collision_count[col_name]); + col_name = Identifier(col_name + "_" + to_string(name_collision_count[col_name])); } name_collision_count[col_name] = 0; } @@ -349,7 +349,7 @@ void JSONMultiFileInfo::BindReader(ClientContext &context, vector & // re-use readers for (auto &union_reader : bind_data.union_readers) { auto &json_reader = union_reader->reader->Cast(); - union_reader->names = names; + union_reader->names = IdentifiersToStrings(names); union_reader->types = return_types; union_reader->reader->columns = MultiFileColumnDefinition::ColumnsFromNamesAndTypes(names, return_types); json_reader.Reset(); @@ -431,7 +431,8 @@ shared_ptr JSONMultiFileInfo::CreateReader(ClientContext &contex const MultiFileBindData &bind_data_p) { auto &json_data = bind_data_p.bind_data->Cast(); auto reader = make_shared_ptr(context, json_data.options, union_data.GetFileName()); - reader->columns = MultiFileColumnDefinition::ColumnsFromNamesAndTypes(union_data.names, union_data.types); + reader->columns = + MultiFileColumnDefinition::ColumnsFromNamesAndTypes(StringsToIdentifiers(union_data.names), union_data.types); return std::move(reader); } diff --git a/src/duckdb/extension/json/json_scan.cpp b/src/duckdb/extension/json/json_scan.cpp index 5e86d524a..a090f4b91 100644 --- a/src/duckdb/extension/json/json_scan.cpp +++ b/src/duckdb/extension/json/json_scan.cpp @@ -33,7 +33,7 @@ void JSONScanData::InitializeFormats(bool auto_detect_p) { {LogicalTypeId::DATE, {"%m-%d-%Y", "%m-%d-%y", "%d-%m-%Y", "%d-%m-%y", "%Y-%m-%d", "%y-%m-%d"}}, {LogicalTypeId::TIMESTAMP, 
{"%Y-%m-%d %H:%M:%S.%f", "%m-%d-%Y %I:%M:%S %p", "%m-%d-%y %I:%M:%S %p", "%d-%m-%Y %H:%M:%S", - "%d-%m-%y %H:%M:%S", "%Y-%m-%d %H:%M:%S", "%y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", + "%d-%m-%y %H:%M:%S", "%Y-%m-%d %H:%M:%S", "%y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%S%z", "%Y-%m-%dT%H:%M:%S.%f%z"}}, }; diff --git a/src/duckdb/extension/json/json_serializer.cpp b/src/duckdb/extension/json/json_serializer.cpp index 55ce53f3e..0b38354c3 100644 --- a/src/duckdb/extension/json/json_serializer.cpp +++ b/src/duckdb/extension/json/json_serializer.cpp @@ -21,23 +21,29 @@ void JsonSerializer::PushValue(yyjson_mut_val *val) { void JsonSerializer::OnPropertyBegin(const field_id_t, const char *tag) { current_tag = yyjson_mut_strcpy(doc, tag); + property_can_be_omitted.push_back(false); } void JsonSerializer::OnPropertyEnd() { + property_can_be_omitted.pop_back(); } -void JsonSerializer::OnOptionalPropertyBegin(const field_id_t, const char *tag, bool) { - current_tag = yyjson_mut_strcpy(doc, tag); +void JsonSerializer::OnOptionalPropertyBegin(const field_id_t, const char *tag, bool present) { + if (present) { + current_tag = yyjson_mut_strcpy(doc, tag); + } + property_can_be_omitted.push_back(present); } void JsonSerializer::OnOptionalPropertyEnd(bool) { + property_can_be_omitted.pop_back(); } //------------------------------------------------------------------------- // Nested Types //------------------------------------------------------------------------- void JsonSerializer::OnNullableBegin(bool present) { - if (!present && !skip_if_null) { + if (!present && !(skip_if_null && yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted())) { WriteNull(); } } @@ -47,32 +53,39 @@ void JsonSerializer::OnNullableEnd() { void JsonSerializer::OnListBegin(idx_t count) { auto new_value = yyjson_mut_arr(doc); + auto can_be_omitted = yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted(); // We always push a value 
to the stack, we just don't add it as a child to the current value // if skipping empty. Even though it is "unnecessary" to create an empty value just to discard it, // this allows the rest of the code to keep on like normal. - if (!(count == 0 && skip_if_empty)) { + if (!(count == 0 && skip_if_empty && can_be_omitted)) { PushValue(new_value); } stack.push_back(new_value); + stack_can_be_omitted.push_back(can_be_omitted); } void JsonSerializer::OnListEnd() { stack.pop_back(); + stack_can_be_omitted.pop_back(); } void JsonSerializer::OnObjectBegin() { auto new_value = yyjson_mut_obj(doc); + auto can_be_omitted = yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted(); PushValue(new_value); stack.push_back(new_value); + stack_can_be_omitted.push_back(can_be_omitted); } void JsonSerializer::OnObjectEnd() { auto obj = Current(); auto count = yyjson_mut_obj_size(obj); + auto can_be_omitted = stack_can_be_omitted.back(); stack.pop_back(); + stack_can_be_omitted.pop_back(); - if (count == 0 && skip_if_empty && !stack.empty()) { + if (count == 0 && skip_if_empty && can_be_omitted && !stack.empty()) { // remove obj from parent since it was empty auto parent = Current(); if (yyjson_mut_is_arr(parent)) { @@ -106,7 +119,7 @@ void JsonSerializer::OnObjectEnd() { // Primitive Types //------------------------------------------------------------------------- void JsonSerializer::WriteNull() { - if (skip_if_null) { + if (skip_if_null && yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted()) { return; } auto val = yyjson_mut_null(doc); @@ -171,18 +184,33 @@ void JsonSerializer::WriteValue(uhugeint_t value) { stack.pop_back(); } +yyjson_mut_val *DoubleToJSONValue(yyjson_mut_doc *doc, double value) { + if (Value::FloatIsFinite(value)) { + // simple - finite json + return yyjson_mut_real(doc, value); + } + // represent infinities as strings + const char *str; + if (!Value::IsNan(value)) { + str = value < 0 ? 
"-Infinity" : "Infinity"; + } else { + str = "NAN"; + } + return yyjson_mut_raw(doc, str); +} + void JsonSerializer::WriteValue(float value) { - auto val = yyjson_mut_real(doc, value); + auto val = DoubleToJSONValue(doc, value); PushValue(val); } void JsonSerializer::WriteValue(double value) { - auto val = yyjson_mut_real(doc, value); + auto val = DoubleToJSONValue(doc, value); PushValue(val); } void JsonSerializer::WriteValue(const string &value) { - if (skip_if_empty && value.empty()) { + if (skip_if_empty && value.empty() && yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted()) { return; } auto val = yyjson_mut_strncpy(doc, value.c_str(), value.size()); @@ -190,7 +218,7 @@ void JsonSerializer::WriteValue(const string &value) { } void JsonSerializer::WriteValue(const string_t value) { - if (skip_if_empty && value.GetSize() == 0) { + if (skip_if_empty && value.GetSize() == 0 && yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted()) { return; } auto val = yyjson_mut_strncpy(doc, value.GetData(), value.GetSize()); @@ -198,7 +226,7 @@ void JsonSerializer::WriteValue(const string_t value) { } void JsonSerializer::WriteValue(const char *value) { - if (skip_if_empty && strlen(value) == 0) { + if (skip_if_empty && strlen(value) == 0 && yyjson_mut_is_obj(Current()) && CurrentPropertyCanBeOmitted()) { return; } auto val = yyjson_mut_strcpy(doc, value); diff --git a/src/duckdb/extension/loader/dummy_static_extension_loader.cpp b/src/duckdb/extension/loader/dummy_static_extension_loader.cpp index 9275653ea..72792727a 100644 --- a/src/duckdb/extension/loader/dummy_static_extension_loader.cpp +++ b/src/duckdb/extension/loader/dummy_static_extension_loader.cpp @@ -4,6 +4,10 @@ // Link this to libduckdb_static.a to get a working system. 
namespace duckdb { +ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const string &extension) { + return ExtensionLoadResult::NOT_LOADED; +} + void ExtensionHelper::LoadAllExtensions(DuckDB &db) { // nop } diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index b659fc0d7..87abad05d 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -160,7 +160,7 @@ unique_ptr ColumnReader::Stats(idx_t row_group_idx_p, const vect } uint64_t ColumnReader::TotalCompressedSize() { - if (!chunk) { + if (IsSkipped()) { return 0; } @@ -171,8 +171,9 @@ uint64_t ColumnReader::TotalCompressedSize() { // apparently is not the first page of the data. Therefore we determine the address of the first page by taking the // minimum of all page offsets. idx_t ColumnReader::FileOffset() const { - if (!chunk) { - throw std::runtime_error("FileOffset called on ColumnReader with no chunk"); + if (IsSkipped()) { + //! 
This column reader is skipped + return 0; } auto min_offset = NumericLimits::Maximum(); if (chunk->meta_data.__isset.dictionary_page_offset) { @@ -281,7 +282,8 @@ bool ColumnReader::PageIsFilteredOut(PageHeader &page_hdr, optional_ptrCheckStatistics(*stats) == FilterPropagateResult::FILTER_ALWAYS_FALSE) { + auto &expr_filter = filter->Cast(); + if (stats && expr_filter.CheckStatistics(*stats) == FilterPropagateResult::FILTER_ALWAYS_FALSE) { page_is_filtered_out = true; } } @@ -326,7 +328,8 @@ void ColumnReader::ReadData(const data_ptr_t buffer, const uint32_t buffer_size, } } -void ColumnReader::PrepareRead(optional_ptr filter, optional_ptr filter_state) { +void ColumnReader::PrepareRead(optional_ptr filter, optional_ptr filter_state, + idx_t rows_to_skip) { encoding = ColumnEncoding::INVALID; defined_decoder.reset(); page_is_filtered_out = false; @@ -352,10 +355,21 @@ void ColumnReader::PrepareRead(optional_ptr filter, optional_ } if (PageIsFilteredOut(page_hdr, filter)) { - // this page has been filtered out so we don't need to read it return; } + if (rows_to_skip > 0 && (page_hdr.type == PageType::DATA_PAGE || page_hdr.type == PageType::DATA_PAGE_V2)) { + bool is_v1 = page_hdr.type == PageType::DATA_PAGE; + idx_t page_num_values = + NumericCast(is_v1 ? 
page_hdr.data_page_header.num_values : page_hdr.data_page_header_v2.num_values); + if (rows_to_skip >= page_num_values) { + trans.Skip(page_hdr.compressed_page_size); + page_is_filtered_out = true; + page_rows_available = page_num_values; + return; + } + } + switch (page_hdr.type) { case PageType::DATA_PAGE_V2: PreparePageV2(page_hdr); @@ -453,7 +467,7 @@ void ColumnReader::PreparePage(PageHeader &page_hdr) { if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { if (compressed_page_size != NumericCast(page_hdr.uncompressed_page_size)) { - throw std::runtime_error("Page size mismatch"); + throw InternalException("Page size mismatch"); } ReadData(block->ptr, compressed_page_size, page_hdr.type); return; @@ -634,11 +648,11 @@ void ColumnReader::BeginRead(data_ptr_t define_out, data_ptr_t repeat_out) { } idx_t ColumnReader::ReadPageHeaders(idx_t max_read, optional_ptr filter, - optional_ptr filter_state) { + optional_ptr filter_state, idx_t rows_to_skip) { int8_t page_ordinal = 0; while (page_rows_available == 0) { aad_crypto_metadata.page_ordinal = page_ordinal; - PrepareRead(filter, filter_state); + PrepareRead(filter, filter_state, rows_to_skip); page_ordinal++; } return MinValue(MinValue(max_read, page_rows_available), STANDARD_VECTOR_SIZE); @@ -860,7 +874,7 @@ void ColumnReader::ApplyPendingSkips(data_ptr_t define_out, data_ptr_t repeat_ou BeginRead(nullptr, nullptr); while (to_skip > 0) { - auto skip_now = ReadPageHeaders(to_skip); + auto skip_now = ReadPageHeaders(to_skip, nullptr, nullptr, to_skip); if (page_is_filtered_out) { // the page has been filtered out entirely - skip page_rows_available -= skip_now; diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 0ab28f104..31ba7f0a3 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -117,7 +117,8 @@ void ColumnWriterStatistics::WriteGeoStats(duckdb_parquet::GeospatialStatistics 
//===--------------------------------------------------------------------===// // ColumnWriter //===--------------------------------------------------------------------===// -ColumnWriter::ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema_p, vector schema_path_p) +ColumnWriter::ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema_p, + vector schema_path_p) : writer(writer), column_schema(std::move(column_schema_p)), schema_path(std::move(schema_path_p)) { can_have_nulls = column_schema.repetition_type == duckdb_parquet::FieldRepetitionType::OPTIONAL; } @@ -132,7 +133,7 @@ bool ColumnWriter::TryExportPreparedShreddingType(ShreddingType &result) const { if (!child_writer->TryExportPreparedShreddingType(child_shredding_type)) { continue; } - writer_shredding_type.AddChild(child_writer->Schema().name, std::move(child_shredding_type)); + writer_shredding_type.AddChild(Identifier(child_writer->Schema().name), std::move(child_shredding_type)); has_shredding = true; } if (!has_shredding) { @@ -271,13 +272,13 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat //===--------------------------------------------------------------------===// unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - vector path_in_schema, const LogicalType &type, - const string &name, bool allow_geometry, + vector path_in_schema, const LogicalType &type, + const Identifier &name, bool allow_geometry, optional_ptr field_ids, optional_ptr shredding_types, idx_t max_repeat, idx_t max_define, bool can_have_nulls) { const bool parquet_write_timestamp_as_int96 = writer.WriteTimestampAsInt96(); - path_in_schema.push_back(name); + path_in_schema.emplace_back(name); if (!can_have_nulls) { max_define--; @@ -353,8 +354,7 @@ unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &cont std::move(child_writers)); } - if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION || 
- type.id() == LogicalTypeId::AGGREGATE_STATE) { + if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { auto struct_column = ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); if (field_id && field_id->set) { @@ -372,7 +372,7 @@ unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &cont allow_geometry, child_field_ids, shredding_type, max_repeat, max_define + 1, true)); } - return make_uniq(writer, std::move(struct_column), std::move(path_in_schema), + return make_uniq(writer, std::move(struct_column), path_in_schema, std::move(child_writers)); } @@ -392,10 +392,9 @@ unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &cont } if (is_list) { - return make_uniq(writer, std::move(list_column), std::move(path_in_schema), - std::move(child_writer)); + return make_uniq(writer, std::move(list_column), path_in_schema, std::move(child_writer)); } else { - return make_uniq(writer, std::move(list_column), std::move(path_in_schema), + return make_uniq(writer, std::move(list_column), path_in_schema, std::move(child_writer)); } } @@ -446,104 +445,99 @@ unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &cont switch (type.id()) { case LogicalTypeId::BOOLEAN: - return make_uniq(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq(writer, std::move(schema), path_in_schema); case LogicalTypeId::TINYINT: - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::SMALLINT: - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::INTEGER: case LogicalTypeId::DATE: - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::BIGINT: case LogicalTypeId::TIME: - return 
make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: if (parquet_write_timestamp_as_int96) { return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); } - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::TIMESTAMP_MS: if (parquet_write_timestamp_as_int96) { return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); } - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::TIME_TZ: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::HUGEINT: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::UHUGEINT: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_TZ_NS: if (parquet_write_timestamp_as_int96) { return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); } return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::TIMESTAMP_SEC: if (parquet_write_timestamp_as_int96) { return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); } return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::UTINYINT: - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); 
case LogicalTypeId::USMALLINT: - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::UINTEGER: - return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::UBIGINT: - return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case LogicalTypeId::FLOAT: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::DOUBLE: return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); case LogicalTypeId::DECIMAL: switch (type.InternalType()) { case PhysicalType::INT16: - return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case PhysicalType::INT32: - return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); case PhysicalType::INT64: - return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); default: - return make_uniq(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq(writer, std::move(schema), path_in_schema); } case LogicalTypeId::BLOB: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::GEOMETRY: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::VARCHAR: return make_uniq>(writer, std::move(schema), - std::move(path_in_schema)); + path_in_schema); case LogicalTypeId::UUID: return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), 
path_in_schema); case LogicalTypeId::INTERVAL: return make_uniq>( - writer, std::move(schema), std::move(path_in_schema)); + writer, std::move(schema), path_in_schema); case LogicalTypeId::ENUM: - return make_uniq(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq(writer, std::move(schema), path_in_schema); case LogicalTypeId::SQLNULL: // All values are NULL - use INT32 as physical type (values are never read, only definition levels matter) - return make_uniq>(writer, std::move(schema), std::move(path_in_schema)); + return make_uniq>(writer, std::move(schema), path_in_schema); default: throw InternalException("Unsupported type \"%s\" in Parquet writer", type.ToString()); } diff --git a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp index b7e7d6ce6..fc3bb7cf9 100644 --- a/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp +++ b/src/duckdb/extension/parquet/decoder/dictionary_decoder.cpp @@ -6,7 +6,6 @@ #include "column_reader.hpp" #include "parquet_reader.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/table_filter_state.hpp" #include "duckdb/storage/table/column_segment.hpp" @@ -154,45 +153,10 @@ void DictionaryDecoder::Skip(uint8_t *defines, idx_t skip_count) { } bool DictionaryDecoder::DictionarySupportsFilter(const TableFilter &filter, TableFilterState &filter_state) { - switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction = filter.Cast(); - auto &state = filter_state.Cast(); - for (idx_t child_idx = 0; child_idx < conjunction.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction.child_filters[child_idx]; - auto &child_state = *state.child_states[child_idx]; - if (!DictionarySupportsFilter(child_filter, child_state)) { - return false; - } - } - return true; - } - case TableFilterType::CONJUNCTION_AND: { - 
auto &conjunction = filter.Cast(); - auto &state = filter_state.Cast(); - for (idx_t child_idx = 0; child_idx < conjunction.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction.child_filters[child_idx]; - auto &child_state = *state.child_states[child_idx]; - if (!DictionarySupportsFilter(child_filter, child_state)) { - return false; - } - } - return true; - } - case TableFilterType::EXPRESSION_FILTER: { - // expression filters can only be pushed into the dictionary if they filter out NULL values - auto &expr_filter = - ExpressionFilter::GetExpressionFilter(filter, "DictionaryDecoder::DictionarySupportsFilter"); - auto &state = filter_state.Cast(); - auto emits_nulls = expr_filter.EvaluateWithConstant(*state.executor, Value(reader.Type())); - return !emits_nulls; - } - case TableFilterType::DYNAMIC_FILTER: - case TableFilterType::OPTIONAL_FILTER: - case TableFilterType::STRUCT_EXTRACT: - default: - return false; - } + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "DictionaryDecoder::DictionarySupportsFilter"); + auto &state = filter_state.Cast(); + auto emits_nulls = expr_filter.EvaluateWithConstant(*state.executor, Value(reader.Type())); + return !emits_nulls; } bool DictionaryDecoder::CanFilter(const TableFilter &filter, TableFilterState &filter_state) { diff --git a/src/duckdb/extension/parquet/include/column_reader.hpp b/src/duckdb/extension/parquet/include/column_reader.hpp index ffece12c7..40bef960d 100644 --- a/src/duckdb/extension/parquet/include/column_reader.hpp +++ b/src/duckdb/extension/parquet/include/column_reader.hpp @@ -137,6 +137,9 @@ class ColumnReader { idx_t MaxRepeat() const { return column_schema.max_repeat; } + inline bool IsSkipped() const { + return !chunk; + } void InitializeCryptoMetadata(const duckdb_parquet::EncryptionAlgorithm &encryption_algorithm, idx_t row_group_ordinal_p) { @@ -246,7 +249,7 @@ class ColumnReader { void BeginRead(data_ptr_t define_out, data_ptr_t repeat_out); void 
FinishRead(idx_t read_count); idx_t ReadPageHeaders(idx_t max_read, optional_ptr filter = nullptr, - optional_ptr filter_state = nullptr); + optional_ptr filter_state = nullptr, idx_t rows_to_skip = 0); idx_t ReadInternal(ColumnReaderInput &input, Vector &result); //! Prepare a read of up to "max_read" rows and read the defines/repeats. //! Returns whether all values are valid (i.e., not NULL) @@ -353,7 +356,8 @@ class ColumnReader { private: void AllocateBlock(idx_t size); - void PrepareRead(optional_ptr filter, optional_ptr filter_state); + void PrepareRead(optional_ptr filter, optional_ptr filter_state, + idx_t rows_to_skip = 0); void PreparePage(PageHeader &page_hdr); void PrepareDataPage(PageHeader &page_hdr); void PreparePageV2(PageHeader &page_hdr); diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index e9d37cead..db0942e27 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -115,7 +115,7 @@ class ColumnWriter { static constexpr uint16_t PARQUET_DEFINE_VALID = UINT16_C(65535); public: - ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); + ColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); virtual ~ColumnWriter(); public: @@ -177,8 +177,8 @@ class ColumnWriter { //! Create the column writer for a specific type recursively static unique_ptr CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, - vector path_in_schema, const LogicalType &type, - const string &name, bool allow_geometry, + vector path_in_schema, const LogicalType &type, + const Identifier &name, bool allow_geometry, optional_ptr field_ids, optional_ptr shredding_types, idx_t max_repeat = 0, idx_t max_define = 1, @@ -232,7 +232,7 @@ class ColumnWriter { //! 
The parent writer (if this is a nested field) optional_ptr parent; ParquetColumnSchema column_schema; - vector schema_path; + vector schema_path; bool can_have_nulls; protected: diff --git a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp index aa2d117ba..117fc3934 100644 --- a/src/duckdb/extension/parquet/include/parquet_column_schema.hpp +++ b/src/duckdb/extension/parquet/include/parquet_column_schema.hpp @@ -31,7 +31,7 @@ using duckdb_parquet::SchemaElement; using duckdb_parquet::FileMetaData; struct ParquetOptions; -enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, EXPRESSION, VARIANT, GEOMETRY }; +enum class ParquetColumnSchemaType { COLUMN, FILE_ROW_NUMBER, EXPRESSION, VARIANT, GEOMETRY, FILE_ROW_GROUP_NUMBER }; enum class ParquetExtraTypeInfo { NONE, @@ -54,7 +54,7 @@ struct ParquetColumnSchema { public: //! Writer constructors - static ParquetColumnSchema FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, + static ParquetColumnSchema FromLogicalType(const Identifier &name, const LogicalType &type, idx_t max_define, idx_t max_repeat, idx_t column_index, duckdb_parquet::FieldRepetitionType::type repetition_type, bool allow_geometry, @@ -72,6 +72,7 @@ struct ParquetColumnSchema { vector &&children, ParquetColumnSchemaType schema_type = ParquetColumnSchemaType::COLUMN); static ParquetColumnSchema FileRowNumber(); + static ParquetColumnSchema FileRowGroupNumber(); public: unique_ptr Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, diff --git a/src/duckdb/extension/parquet/include/parquet_field_id.hpp b/src/duckdb/extension/parquet/include/parquet_field_id.hpp index 7036632dc..294587dd0 100644 --- a/src/duckdb/extension/parquet/include/parquet_field_id.hpp +++ b/src/duckdb/extension/parquet/include/parquet_field_id.hpp @@ -22,7 +22,7 @@ class Serializer; struct ChildFieldIDs { ChildFieldIDs(); ChildFieldIDs Copy() 
const; - unique_ptr> ids; + unique_ptr> ids; void Serialize(Serializer &serializer) const; static ChildFieldIDs Deserialize(Deserializer &source); @@ -42,7 +42,7 @@ struct FieldID { static FieldID Deserialize(Deserializer &source); public: - static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, const vector &sql_types); static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, unordered_set &unique_field_ids, diff --git a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp index 00d787570..2054853bb 100644 --- a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp +++ b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp @@ -57,7 +57,6 @@ class ParquetFileMetadataCache : public ObjectCacheEntry { ParquetCacheValidity IsValid(const OpenFileInfo &info, ClientContext &context) const; private: - bool validate; timestamp_t last_modified; string version_tag; }; diff --git a/src/duckdb/extension/parquet/include/parquet_multi_file_info.hpp b/src/duckdb/extension/parquet/include/parquet_multi_file_info.hpp index 7e88d3343..b45f1835a 100644 --- a/src/duckdb/extension/parquet/include/parquet_multi_file_info.hpp +++ b/src/duckdb/extension/parquet/include/parquet_multi_file_info.hpp @@ -60,7 +60,7 @@ struct ParquetMultiFileInfo : MultiFileReaderInterface { vector &expected_types) override; bool ParseOption(ClientContext &context, const string &key, const Value &val, MultiFileOptions &file_options, BaseFileReaderOptions &options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, + void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) override; unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, 
unique_ptr options) override; diff --git a/src/duckdb/extension/parquet/include/parquet_prefetch_cost_model.hpp b/src/duckdb/extension/parquet/include/parquet_prefetch_cost_model.hpp new file mode 100644 index 000000000..f00346028 --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_prefetch_cost_model.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// parquet_prefetch_cost_model.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types.hpp" + +namespace duckdb { + +struct NetworkThroughputEstimate; + +//! Cost model that decides how large a gap between two needed byte ranges may be +struct PrefetchCostModel { + //! round-trip latency + request setup + double latency_seconds; + //! single-stream throughput, in bytes per second + double bandwidth_bytes_per_s; + + //! Floor for the coalescing gap + static constexpr uint64_t GAP_MIN = 1ULL << 12; //! 4 KiB + //! Ceiling for the coalescing gap + static constexpr uint64_t GAP_MAX = 32ULL << 20; //! 32 MiB + + //! The coalescing gap implied by this model, clamped to [GAP_MIN, GAP_MAX] + uint64_t GetColumnGapSize() const; +}; + +class PrefetchCostModelState { +public: + PrefetchCostModelState() = default; + + //! Refine the model from a measured network throughput estimate. + void RefineFromEstimate(const NetworkThroughputEstimate &estimate); + + bool TryGetColumnGapSize(uint64_t &result) const; + +private: + //! Weight applied to an estimate + static constexpr double ALPHA = 0.5; + + //! Measured cost model + PrefetchCostModel model = {0, 0}; + //! 
Whether we have switched from the seed to measured values + bool measured_adopted = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp index 3bdeb454e..a684ff253 100644 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ b/src/duckdb/extension/parquet/include/parquet_reader.hpp @@ -16,6 +16,7 @@ #include #include "duckdb.hpp" +#include "duckdb/common/enum_util.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/storage/external_file_cache/caching_file_system.hpp" #include "duckdb/common/common.hpp" @@ -28,6 +29,7 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "column_reader.hpp" +#include "parquet_prefetch_cost_model.hpp" #include "parquet_file_metadata_cache.hpp" #include "parquet_rle_bp_decoder.hpp" #include "parquet_types.h" @@ -81,6 +83,7 @@ class EncryptionUtil; class PhysicalOperator; class Serializer; class TableFilter; +class ThriftFileTransport; struct CryptoMetaData; struct GlobalTableFunctionState; struct LocalTableFunctionState; @@ -88,10 +91,34 @@ struct PartitionStatistics; struct TableFilterState; struct ParquetReaderPrefetchConfig { - // Percentage of data in a row group span that should be scanned for enabling whole group prefetch + //! Percentage of data in a row group span that should be scanned for enabling whole group prefetch static constexpr double WHOLE_GROUP_PREFETCH_MINIMUM_SCAN = 0.95; + //! How many row groups need to produce at least one surviving row (from filtering) + static constexpr double PREFETCH_FILTER_MINIMUM_MATCH_RATIO = 0.9; }; +enum class ParquetPrefetchStrategy : uint8_t { + NONE, + WHOLE_GROUP, //! Prefetches the whole group + PREFETCH_FILTERS, //! Used when we have fully selective filters, so they are prefetched earlier + COLUMN_WISE_EAGER //! 
Used when we have projections and optional only filters +}; + +const char *ParquetPrefetchStrategyToString(ParquetPrefetchStrategy strategy); + +enum class ParquetPrefetchStrategyOption : uint8_t { + AUTO, //! Uses the runtime strategy to pick between ParquetPrefetchStrategy + WHOLE_GROUP, //! Always do the whole row group +}; + +ParquetPrefetchStrategyOption ParquetPrefetchStrategyOptionFromString(const string &value); + +template <> +const char *EnumUtil::ToChars(ParquetPrefetchStrategyOption value); + +template <> +ParquetPrefetchStrategyOption EnumUtil::FromString(const char *value); + struct ParquetScanFilter { ParquetScanFilter(ClientContext &context, ProjectionIndex filter_idx, TableFilter &filter); ~ParquetScanFilter(); @@ -102,6 +129,65 @@ struct ParquetScanFilter { unique_ptr filter_state; }; +struct ParquetPrefetchColumn { + string name; + idx_t offset; + + ParquetPrefetchColumn(string name_p, idx_t offset_p) : name(std::move(name_p)), offset(offset_p) { + } + + bool operator<(const ParquetPrefetchColumn &other) const { + return offset < other.offset; + } +}; +//! Logger-only prefetch metrics: populated only when ParquetPrefetch logging is enabled. +struct ParquetLoggerPrefetchMetrics { + //! Physical prefetch groups (column-name batches) issued for the current row group, in order + vector> prefetch_groups; + //! Tracks which scan_filters were evaluated in the current row group (indexed by scan_filter position) + vector filters_used; + //! Prefetch strategy chosen for the current row group + ParquetPrefetchStrategy strategy = ParquetPrefetchStrategy::NONE; + //! Accepted column gap (bytes) + uint64_t accepted_column_gap = 0; + + void Reset() { + prefetch_groups.clear(); + std::fill(filters_used.begin(), filters_used.end(), false); + strategy = ParquetPrefetchStrategy::NONE; + accepted_column_gap = 0; + } + + //! 
Build a prefetch group + void GeneratePrefetchGroup(ThriftFileTransport &trans, vector &requested_columns, + const vector &all_chunks); +}; + +struct ParquetPrefetchMetrics { + //! Whether any filter was evaluated against the current row group + bool filter_ran = false; + //! Whether the current row group produced at least one surviving row after filtering + bool had_match = false; + //! Total number of row groups for which filters were evaluated across the scan + idx_t row_groups_executed = 0; + //! Number of those row groups that produced at least one surviving row + idx_t row_groups_with_matches = 0; + + ParquetLoggerPrefetchMetrics logger; + + void FinalizeRowGroupSelectivity() { + if (filter_ran) { + row_groups_executed++; + if (had_match) { + row_groups_with_matches++; + } + } + filter_ran = false; + had_match = false; + logger.Reset(); + } +}; + struct ParquetReaderScanState { public: ColumnReader &GetColumnReader(idx_t i); @@ -111,7 +197,7 @@ struct ParquetReaderScanState { int64_t current_group; idx_t offset_in_group; idx_t group_offset; - unique_ptr file_handle; + shared_ptr file_handle; vector> column_readers; duckdb_base_std::unique_ptr thrift_file_proto; @@ -123,14 +209,35 @@ struct ParquetReaderScanState { bool prefetch_mode = false; bool current_group_prefetched = false; + //! Number of filter head counts, used for prefetching + idx_t filter_head_count = 0; + //! true once the filters ran + bool filter_done = false; + //! Surviving row count + idx_t filter_count = 0; + //! Filter columns kept across the payload BLOCKED + DataChunk filter_stash; + + ParquetPrefetchMetrics prefetch_metrics; //! Per-thread adaptive filter cache MultiFileAdaptiveFilterCache adaptive_filter_cache; //! Table filter list vector scan_filters; + //! true once the filter at this index has driven the surviving row count to zero + vector filter_eliminated_all_rows; //! (optional) pointer to the PhysicalOperator for logging optional_ptr op; + + //! 
Per-thread counters for row groups whose data was read / skipped, incremented as row groups are processed. + //! Read in ParquetScanGetMetrics and surfaced as the standard row_groups_scanned / total_row_groups_to_scan + //! profiling metrics (the profiler sums them across threads). + idx_t row_groups_read = 0; + idx_t row_groups_skipped = 0; + + //! Prefetch cost model + PrefetchCostModelState cost_model_state; }; struct ParquetColumnDefinition { @@ -163,6 +270,7 @@ struct ParquetOptions { vector schema; idx_t explicit_cardinality = 0; bool can_have_nan = false; // if floats or doubles can contain NaN values + ParquetPrefetchStrategyOption prefetch_strategy = ParquetPrefetchStrategyOption::AUTO; }; struct ParquetOptionsSerialization { @@ -185,7 +293,7 @@ struct ParquetUnionData : public BaseUnionData { ~ParquetUnionData() override; optional_idx TryGetCardinalityEstimate() const override; - unique_ptr GetStatistics(ClientContext &context, const string &name) override; + unique_ptr GetStatistics(ClientContext &context, const Identifier &name) override; ParquetOptions options; shared_ptr metadata; @@ -193,6 +301,10 @@ struct ParquetUnionData : public BaseUnionData { }; class ParquetReader : public BaseFileReader { +public: + //! 
Virtual column identifier for the "file_row_group_number" column (the file-relative row group index of each row) + static constexpr column_t COLUMN_IDENTIFIER_FILE_ROW_GROUP_NUMBER = UINT64_C(9223372036854775820); + public: ParquetReader(ClientContext &context, OpenFileInfo file, ParquetOptions parquet_options, shared_ptr metadata = nullptr); @@ -213,7 +325,7 @@ class ParquetReader : public BaseFileReader { } shared_ptr GetUnionData(idx_t file_idx) override; - unique_ptr GetStatistics(ClientContext &context, const string &name) override; + unique_ptr GetStatistics(ClientContext &context, const Identifier &name) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; @@ -244,15 +356,16 @@ class ParquetReader : public BaseFileReader { uint32_t ReadDataEncrypted(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, CryptoMetaData &aad_crypto_metadata) const; - unique_ptr ReadStatistics(const string &name); + unique_ptr ReadStatistics(const Identifier &name); CachingFileHandle &GetHandle() { return *file_handle; } static unique_ptr ReadStatistics(ClientContext &context, ParquetOptions parquet_options, - shared_ptr metadata, const string &name); - static unique_ptr ReadStatistics(const ParquetUnionData &union_data, const string &name); + shared_ptr metadata, + const Identifier &name); + static unique_ptr ReadStatistics(const ParquetUnionData &union_data, const Identifier &name); LogicalType DeriveLogicalType(const SchemaElement &s_ele, ParquetColumnSchema &schema) const; static LogicalType DeriveLogicalType(const SchemaElement &s_ele, const ParquetOptions &options, @@ -283,9 +396,33 @@ class ParquetReader : public BaseFileReader { const duckdb_parquet::RowGroup &GetGroup(ParquetReaderScanState &state); uint64_t GetGroupCompressedSize(ParquetReaderScanState &state); idx_t GetGroupOffset(ParquetReaderScanState &state); - // Group span is the 
distance between the min page offset and the max page offset plus the max page compressed size + //! Group span is the distance between the min page offset and the max page offset plus the max page compressed size uint64_t GetGroupSpan(ParquetReaderScanState &state); void PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t out_col_idx); + //! Whole-group prefetch strategy. + ParquetPrefetchStrategy WholeGroupPrefetch(ParquetReaderScanState &state, ThriftFileTransport &trans, + const duckdb_parquet::RowGroup &group, uint64_t total_row_group_span, + bool log_prefetch); + //! Column-wise prefetch strategy. + ParquetPrefetchStrategy ColumnWisePrefetch(ParquetReaderScanState &state, ThriftFileTransport &trans, + const duckdb_parquet::RowGroup &group, bool filters_look_unselective, + bool log_prefetch) const; + //! Switch to the next row group and schedule its I/O (prepare column buffers, prefetch the bytes). + AsyncResult Schedule(ClientContext &context, ParquetReaderScanState &state, DataChunk &result, bool log_prefetch); + //! Process up to STANDARD_VECTOR_SIZE rows of the current row group into result. + AsyncResult Process(ParquetReaderScanState &state, DataChunk &result, bool log_prefetch); + //! Process filters + AsyncResult ProcessFilters(ParquetReaderScanState &state, DataChunk &result, idx_t scan_count, uint8_t *define_ptr, + uint8_t *repeat_ptr, bool log_prefetch); + //! Run the filters into state.sel; returns the surviving row count. Advances every filter column. + idx_t EvaluateFilters(ParquetReaderScanState &state, DataChunk &result, idx_t scan_count, uint8_t *define_ptr, + uint8_t *repeat_ptr, bool log_prefetch); + //! Async-fetch the surviving payload columns (stashing the filter columns); empty if no fetch is needed. + vector> ScheduleRemainingColumns(ParquetReaderScanState &state, DataChunk &result, + idx_t scan_count); + //! Read the remaining (non-filter) columns into result. 
+ void DecodeRemainingColumns(ParquetReaderScanState &state, DataChunk &result, idx_t filter_count, + uint8_t *define_ptr, uint8_t *repeat_ptr); ParquetColumnSchema ParseColumnSchema(const SchemaElement &s_ele, idx_t max_define, idx_t max_repeat, idx_t schema_index, idx_t column_index, ParquetColumnSchemaType type = ParquetColumnSchemaType::COLUMN); diff --git a/src/duckdb/extension/parquet/include/parquet_shredding.hpp b/src/duckdb/extension/parquet/include/parquet_shredding.hpp index 410f51b86..2df78c712 100644 --- a/src/duckdb/extension/parquet/include/parquet_shredding.hpp +++ b/src/duckdb/extension/parquet/include/parquet_shredding.hpp @@ -30,7 +30,7 @@ struct ChildShreddingTypes { static ChildShreddingTypes Deserialize(Deserializer &source); public: - unique_ptr> types; + unique_ptr> types; }; struct ShreddingType { @@ -50,8 +50,8 @@ struct ShreddingType { public: static ShreddingType GetShreddingTypes(const Value &val, ClientContext &context); - void AddChild(const string &name, ShreddingType &&child); - optional_ptr GetChild(const string &name) const; + void AddChild(const Identifier &name, ShreddingType &&child); + optional_ptr GetChild(const Identifier &name) const; public: bool set = false; diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index a64df16b2..ea297fa6a 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -134,8 +134,7 @@ struct ParquetWriteLocalState : public LocalFunctionData { struct ParquetWriteGlobalState : public GlobalFunctionData { public: - ParquetWriteGlobalState() { - } + ParquetWriteGlobalState(); public: unique_ptr writer; diff --git a/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp index 82f614d45..924686323 100644 --- 
a/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp @@ -71,4 +71,36 @@ class RowNumberColumnReader : public ColumnReader { idx_t row_group_offset; }; +//! Reads the file-relative row group number as a virtual column that's not actually stored in the file +class RowGroupColumnReader : public ColumnReader { +public: + static constexpr const PhysicalType TYPE = PhysicalType::INT64; + +public: + RowGroupColumnReader(const ParquetReader &reader, const ParquetColumnSchema &schema); + +public: + idx_t Read(ColumnReaderInput &input, Vector &result) override; + + void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override; + + void Skip(idx_t num_values) override { + } + idx_t GroupRowsAvailable() override { + return NumericLimits::Maximum(); + }; + uint64_t TotalCompressedSize() override { + return 0; + } + idx_t FileOffset() const override { + return 0; + } + void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { + } + +private: + //! The index of the row group currently being read + idx_t row_group_idx = 0; +}; + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/thrift_tools.hpp b/src/duckdb/extension/parquet/include/thrift_tools.hpp index de5500680..96fe51384 100644 --- a/src/duckdb/extension/parquet/include/thrift_tools.hpp +++ b/src/duckdb/extension/parquet/include/thrift_tools.hpp @@ -8,7 +8,6 @@ #pragma once -#include #include "thrift/protocol/TCompactProtocol.h" #include "thrift/transport/TBufferTransports.h" @@ -39,27 +38,42 @@ struct ReadHead { // Materialize [buffer_ptr], should call before access. // TODO(hjiang): Currently it's only used for `Prefetch` operation, should be able to save one copy. 
- void Materialize() { + void Materialize(Allocator &allocator) { if (handle_group.GetHandles().size() == 1) { buffer_ptr = handle_group.Ptr(); } else { - local_buffer = Allocator::DefaultAllocator().Allocate(size); + local_buffer = allocator.Allocate(size); handle_group.CopyTo(local_buffer.get(), size); buffer_ptr = local_buffer.get(); } } + + void Fetch(CachingFileHandle &file_handle) { + if (GetEnd() > file_handle.GetFileSize()) { + throw std::runtime_error("Prefetch registered requested for bytes outside file"); + } + handle_group = file_handle.Read(size, location); + Materialize(file_handle.GetBufferAllocator()); + data_isset = true; + } }; -// Comparator for ReadHeads that are either overlapping, adjacent, or within ALLOW_GAP bytes from each other struct ReadHeadComparator { - static constexpr uint64_t ALLOW_GAP = 1 << 14; // 16 KiB + static constexpr uint64_t DEFAULT_ACCEPTED_COLUMN_GAP = 1 << 14; // 16 KiB + + ReadHeadComparator() = default; + explicit ReadHeadComparator(uint64_t accepted_column_gap_p) : accepted_column_gap(accepted_column_gap_p) { + } + + uint64_t accepted_column_gap = DEFAULT_ACCEPTED_COLUMN_GAP; + bool operator()(const ReadHead *a, const ReadHead *b) const { auto a_start = a->location; auto a_end = a->location + a->size; auto b_start = b->location; - if (a_end <= NumericLimits::Maximum() - ALLOW_GAP) { - a_end += ALLOW_GAP; + if (a_end <= NumericLimits::Maximum() - accepted_column_gap) { + a_end += accepted_column_gap; } return a_start < b_start && a_end < b_start; @@ -70,11 +84,13 @@ struct ReadHeadComparator { // 1: register all ranges that will be read, merging ranges that are consecutive // 2: prefetch all registered ranges struct ReadAheadBuffer { - explicit ReadAheadBuffer(CachingFileHandle &file_handle_p) : file_handle(file_handle_p) { + explicit ReadAheadBuffer(CachingFileHandle &file_handle_p, + uint64_t accepted_column_gap = ReadHeadComparator::DEFAULT_ACCEPTED_COLUMN_GAP) + : merge_set(ReadHeadComparator(accepted_column_gap)), 
file_handle(file_handle_p) { } - // The list of read heads - std::list read_heads; + // The list of read heads. + vector> read_heads; // Set for merging consecutive ranges std::set merge_set; @@ -82,6 +98,11 @@ struct ReadAheadBuffer { idx_t total_size = 0; + void SetAcceptedColumnGap(uint64_t accepted_column_gap) { + D_ASSERT(merge_set.empty()); + merge_set = std::set(ReadHeadComparator(accepted_column_gap)); + } + // Add a read head to the prefetching list void AddReadHead(idx_t pos, uint64_t len, bool merge_buffers = true) { // Attempt to merge with existing @@ -98,9 +119,9 @@ struct ReadAheadBuffer { } } - read_heads.emplace_front(ReadHead(pos, len)); + read_heads.insert(read_heads.begin(), make_shared_ptr(pos, len)); total_size += len; - auto &read_head = read_heads.front(); + auto &read_head = *read_heads.front(); if (merge_buffers) { merge_set.insert(&read_head); @@ -117,8 +138,8 @@ struct ReadAheadBuffer { // Returns the relevant read head ReadHead *GetReadHead(idx_t pos) { for (auto &read_head : read_heads) { - if (pos >= read_head.location && pos < read_head.GetEnd()) { - return &read_head; + if (pos >= read_head->location && pos < read_head->GetEnd()) { + return read_head.get(); } } return nullptr; @@ -127,12 +148,7 @@ struct ReadAheadBuffer { // Prefetch all read heads void Prefetch() { for (auto &read_head : read_heads) { - if (read_head.GetEnd() > file_handle.GetFileSize()) { - throw std::runtime_error("Prefetch registered requested for bytes outside file"); - } - read_head.handle_group = file_handle.Read(read_head.size, read_head.location); - read_head.Materialize(); - read_head.data_isset = true; + read_head->Fetch(file_handle); } } }; @@ -141,9 +157,19 @@ class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTra public: static constexpr uint64_t PREFETCH_FALLBACK_BUFFERSIZE = 1000000; - ThriftFileTransport(CachingFileHandle &file_handle_p, bool prefetch_mode_p) + ThriftFileTransport(CachingFileHandle &file_handle_p, bool 
prefetch_mode_p, + uint64_t accepted_column_gap = ReadHeadComparator::DEFAULT_ACCEPTED_COLUMN_GAP) : file_handle(file_handle_p), location(0), size(file_handle.GetFileSize()), - ra_buffer(ReadAheadBuffer(file_handle)), prefetch_mode(prefetch_mode_p) { + ra_buffer(ReadAheadBuffer(file_handle, accepted_column_gap)), prefetch_mode(prefetch_mode_p) { + } + + void SetAcceptedColumnGap(uint64_t accepted_column_gap) { + ra_buffer.SetAcceptedColumnGap(accepted_column_gap); + } + + // The accepted column gap currently used to coalesce adjacent ranges. + uint64_t GetAcceptedColumnGap() const { + return ra_buffer.merge_set.key_comp().accepted_column_gap; } uint32_t read(uint8_t *buf, uint32_t len) { @@ -152,9 +178,7 @@ class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTra D_ASSERT(location - prefetch_buffer->location + len <= prefetch_buffer->size); if (!prefetch_buffer->data_isset) { - prefetch_buffer->handle_group = file_handle.Read(prefetch_buffer->size, prefetch_buffer->location); - prefetch_buffer->Materialize(); - prefetch_buffer->data_isset = true; + prefetch_buffer->Fetch(file_handle); } memcpy(buf, prefetch_buffer->buffer_ptr + location - prefetch_buffer->location, len); } else if (prefetch_mode && len < PREFETCH_FALLBACK_BUFFERSIZE && len > 0) { @@ -218,6 +242,14 @@ class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTra return ra_buffer.GetReadHead(pos); } + vector> &GetReadHeads() { + return ra_buffer.read_heads; + } + + CachingFileHandle &GetCachingFileHandle() const { + return file_handle; + } + idx_t GetSize() const { return size; } diff --git a/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp index 404430e1e..912ff62f5 100644 --- a/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/array_column_writer.hpp @@ -14,7 +14,7 @@ namespace duckdb { class 
ArrayColumnWriter : public ListColumnWriter { public: - ArrayColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + ArrayColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, unique_ptr child_writer_p) : ListColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writer_p)) { } diff --git a/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp index 70a5f0cb3..c9c446e34 100644 --- a/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/boolean_column_writer.hpp @@ -26,7 +26,7 @@ struct ParquetColumnSchema; class BooleanColumnWriter : public PrimitiveColumnWriter { public: - BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); + BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~BooleanColumnWriter() override = default; public: diff --git a/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp index 1cb5bc511..41df7530a 100644 --- a/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/decimal_column_writer.hpp @@ -26,7 +26,8 @@ struct ParquetColumnSchema; class FixedDecimalColumnWriter : public PrimitiveColumnWriter { public: - FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); + FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, + vector schema_path_p); ~FixedDecimalColumnWriter() override = default; public: diff --git a/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp 
index 5a1efd5c4..884e7b40b 100644 --- a/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/enum_column_writer.hpp @@ -29,12 +29,9 @@ struct ParquetColumnSchema; class EnumColumnWriter : public PrimitiveColumnWriter { public: - EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); + EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p); ~EnumColumnWriter() override = default; - uint32_t bit_width; - -public: unique_ptr InitializeStatsState() override; void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats_p, ColumnWriterPageState *page_state_p, @@ -58,6 +55,10 @@ class EnumColumnWriter : public PrimitiveColumnWriter { template void WriteEnumInternal(WriteStream &temp_writer, Vector &input_column, idx_t chunk_start, idx_t chunk_end, EnumWriterPageState &page_state); + + uint32_t bit_width; + //! Tracks which enum values have actually been written. 
+ vector seen_enum; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp index ba841ac4e..21ea168fb 100644 --- a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp @@ -26,7 +26,7 @@ class ListColumnWriterState : public ColumnWriterState { class ListColumnWriter : public ColumnWriter { public: - ListColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + ListColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, unique_ptr child_writer_p) : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { child_writers.push_back(std::move(child_writer_p)); diff --git a/src/duckdb/extension/parquet/include/writer/parquet_write_operators.hpp b/src/duckdb/extension/parquet/include/writer/parquet_write_operators.hpp index 249dc84c6..ef86a5efc 100644 --- a/src/duckdb/extension/parquet/include/writer/parquet_write_operators.hpp +++ b/src/duckdb/extension/parquet/include/writer/parquet_write_operators.hpp @@ -319,6 +319,22 @@ struct ParquetTimeTZOperator : public BaseParquetOperator { static TGT Operation(SRC input) { return input.time().micros; } + + template + static unique_ptr InitializeStats() { + return make_uniq>(); + } + + template + static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { + auto &numeric_stats = stats->Cast>(); + if (LessThan::Operation(target_value, numeric_stats.min)) { + numeric_stats.min = target_value; + } + if (GreaterThan::Operation(target_value, numeric_stats.max)) { + numeric_stats.max = target_value; + } + } }; struct ParquetHugeintOperator : public BaseParquetOperator { diff --git a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp 
index 4ec315567..740830ca3 100644 --- a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp @@ -72,7 +72,7 @@ class PrimitiveColumnWriterState : public ColumnWriterState { //! Base class for writing non-compound types (ex. numerics, strings) class PrimitiveColumnWriter : public ColumnWriter { public: - PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); + PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path); ~PrimitiveColumnWriter() override = default; //! We limit the uncompressed page size to 100MB diff --git a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp index f0dc047bd..3deb031a8 100644 --- a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp @@ -14,7 +14,7 @@ namespace duckdb { class StructColumnWriter : public ColumnWriter { public: - StructColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + StructColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, vector> child_writers_p) : ColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { child_writers = std::move(child_writers_p); diff --git a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index f9d2bedcb..cd3477014 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -119,7 +119,7 @@ class StandardWriterPageState : public ColumnWriterPageState { template class StandardColumnWriter : public PrimitiveColumnWriter { public: - 
StandardColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p) + StandardColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p) : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } ~StandardColumnWriter() override = default; diff --git a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp index b495afc35..16d66bc7e 100644 --- a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp @@ -24,8 +24,8 @@ struct ArrayAnalyzeData; struct VariantAnalyzeData { public: - VariantAnalyzeData() { - } + VariantAnalyzeData(); + ~VariantAnalyzeData(); public: //! Map for every value what type it is @@ -36,8 +36,8 @@ struct VariantAnalyzeData { idx_t total_count = 0; //! Map for every decimal value what physical type it has - unique_ptr object_data = nullptr; - unique_ptr array_data = nullptr; + unique_ptr object_data; + unique_ptr array_data; }; struct ObjectAnalyzeData { @@ -71,7 +71,7 @@ struct VariantAnalyzeSchemaState : public ParquetAnalyzeSchemaState { class VariantColumnWriter : public StructColumnWriter { public: - VariantColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, + VariantColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path_p, vector> child_writers_p) : StructColumnWriter(writer, std::move(column_schema), std::move(schema_path_p), std::move(child_writers_p)) { } diff --git a/src/duckdb/extension/parquet/parquet_column_schema.cpp b/src/duckdb/extension/parquet/parquet_column_schema.cpp index dc0acc6b8..555be0eac 100644 --- a/src/duckdb/extension/parquet/parquet_column_schema.cpp +++ b/src/duckdb/extension/parquet/parquet_column_schema.cpp @@ -19,12 +19,12 @@ void 
ParquetColumnSchema::SetSchemaIndex(idx_t schema_idx) { //! Writer constructors -ParquetColumnSchema ParquetColumnSchema::FromLogicalType(const string &name, const LogicalType &type, idx_t max_define, - idx_t max_repeat, idx_t column_index, +ParquetColumnSchema ParquetColumnSchema::FromLogicalType(const Identifier &name, const LogicalType &type, + idx_t max_define, idx_t max_repeat, idx_t column_index, duckdb_parquet::FieldRepetitionType::type repetition_type, bool allow_geometry, ParquetColumnSchemaType schema_type) { ParquetColumnSchema res; - res.name = name; + res.name = name.GetIdentifierName(); res.max_define = max_define; res.max_repeat = max_repeat; res.column_index = column_index; @@ -95,24 +95,49 @@ ParquetColumnSchema ParquetColumnSchema::FileRowNumber() { return res; } +ParquetColumnSchema ParquetColumnSchema::FileRowGroupNumber() { + ParquetColumnSchema res; + res.name = "file_row_group_number"; + res.max_define = 0; + res.max_repeat = 0; + res.schema_index = 0; + res.column_index = 0; + res.schema_type = ParquetColumnSchemaType::FILE_ROW_GROUP_NUMBER; + res.type = LogicalType::UBIGINT; + res.repetition_type = duckdb_parquet::FieldRepetitionType::type::OPTIONAL; + return res; +} + unique_ptr ParquetColumnSchema::Stats(const FileMetaData &file_meta_data, const ParquetOptions &parquet_options, idx_t row_group_idx_p, const vector &columns) const { if (schema_type == ParquetColumnSchemaType::EXPRESSION) { return nullptr; } - if (schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { + if (schema_type == ParquetColumnSchemaType::FILE_ROW_GROUP_NUMBER) { + // the row group number is constant within a row group - set min and max to the row group index auto stats = NumericStats::CreateUnknown(type); + NumericStats::SetMin(stats, Value::UBIGINT(UnsafeNumericCast(row_group_idx_p))); + NumericStats::SetMax(stats, Value::UBIGINT(UnsafeNumericCast(row_group_idx_p))); + stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); + return stats.ToUnique(); + } + if 
(schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { auto &row_groups = file_meta_data.row_groups; D_ASSERT(row_group_idx_p < row_groups.size()); + if (row_groups[row_group_idx_p].num_rows == 0) { + return NumericStats::CreateEmpty(type).ToUnique(); + } + idx_t row_group_offset_min = 0; for (idx_t i = 0; i < row_group_idx_p; i++) { row_group_offset_min += row_groups[i].num_rows; } + auto stats = NumericStats::CreateUnknown(type); NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); - NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min + - row_groups[row_group_idx_p].num_rows))); + NumericStats::SetMax(stats, Value::BIGINT(UnsafeNumericCast( + row_group_offset_min + row_groups[row_group_idx_p].num_rows - 1))); stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); return stats.ToUnique(); } diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp index 5c8c97760..db6aa3a93 100644 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ b/src/duckdb/extension/parquet/parquet_crypto.cpp @@ -147,7 +147,7 @@ ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const V const auto &keys = ParquetKeys::Get(context); for (idx_t i = 0; i < StructType::GetChildCount(arg.type()); i++) { auto &struct_key = child_types[i].first; - if (StringUtil::Lower(struct_key) == "footer_key") { + if (struct_key == "footer_key") { const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR)); if (!keys.HasKey(footer_key_name)) { throw BinderException( @@ -157,9 +157,9 @@ ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context, const V // footer key name provided - read the key from the config const auto &keys = ParquetKeys::Get(context); footer_key = keys.GetKey(footer_key_name); - } else if (StringUtil::Lower(struct_key) == "footer_key_value") { + } else if (struct_key == "footer_key_value") { footer_key = 
StringValue::Get(children[i].DefaultCastAs(LogicalType::BLOB)); - } else if (StringUtil::Lower(struct_key) == "column_keys") { + } else if (struct_key == "column_keys") { throw NotImplementedException("Parquet encryption_config column_keys not yet implemented"); } else { throw BinderException("Unknown key in encryption_config \"%s\"", struct_key); diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 945aa98ea..f0821f14d 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -160,7 +160,8 @@ static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &inp } static unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { + const vector &names, + const vector &sql_types) { D_ASSERT(names.size() == sql_types.size()); bool compression_level_set = false; auto bind_data = make_uniq(); @@ -213,7 +214,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { throw NotImplementedException("The 'auto' option is not yet implemented for 'shredding'"); } else { - case_insensitive_set_t variant_names; + identifier_set_t variant_names; for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { if (sql_types[col_idx].id() != LogicalTypeId::VARIANT) { continue; @@ -229,8 +230,9 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun const auto &struct_children = StructValue::GetChildren(shredding_types_value); D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); - auto it = variant_names.find(col_name); + const auto &col_name = + StringUtil::Lower(StructType::GetChildName(struct_type, 
i).GetIdentifierName()); + auto it = variant_names.find(Identifier(col_name)); if (it == variant_names.end()) { string names; for (const auto &entry : variant_names) { @@ -252,7 +254,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } } const auto &child_value = struct_children[i]; - bind_data->shredding_types.AddChild(col_name, + bind_data->shredding_types.AddChild(Identifier(col_name), ShreddingType::GetShreddingTypes(child_value, context)); } } @@ -349,7 +351,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } bind_data->sql_types = sql_types; - bind_data->column_names = names; + bind_data->column_names = IdentifiersToStrings(names); return std::move(bind_data); } @@ -556,6 +558,29 @@ ParquetVersion EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template <> +const char *EnumUtil::ToChars(ParquetPrefetchStrategyOption value) { + switch (value) { + case ParquetPrefetchStrategyOption::AUTO: + return "AUTO"; + case ParquetPrefetchStrategyOption::WHOLE_GROUP: + return "WHOLE_GROUP"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); + } +} + +template <> +ParquetPrefetchStrategyOption EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "AUTO")) { + return ParquetPrefetchStrategyOption::AUTO; + } + if (StringUtil::Equals(value, "WHOLE_GROUP")) { + return ParquetPrefetchStrategyOption::WHOLE_GROUP; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template <> const char *EnumUtil::ToChars(GeoParquetVersion value) { switch (value) { @@ -804,7 +829,7 @@ static unique_ptr ParquetScanReplacement(ClientContext &context, Repla if (!FileSystem::HasGlob(table_name)) { auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); + table_function->alias = 
Identifier(fs.ExtractBaseName(table_name)); } return std::move(table_function); @@ -904,9 +929,9 @@ static void LoadInternal(ExtensionLoader &loader) { fs.RegisterSubSystem(FileCompressionType::ZSTD, make_uniq()); auto scan_fun = ParquetScanFunction::GetFunctionSet(); - scan_fun.name = "read_parquet"; + scan_fun.SetName("read_parquet"); loader.RegisterFunction(scan_fun); - scan_fun.name = "parquet_scan"; + scan_fun.SetName("parquet_scan"); loader.RegisterFunction(scan_fun); // parquet_metadata @@ -937,7 +962,6 @@ static void LoadInternal(ExtensionLoader &loader) { loader.RegisterFunction(VariantColumnWriter::GetTransformFunction()); CopyFunction function("parquet"); - function.supports_sql_null = true; function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; function.copy_to_propagate_statistics = ParquetCopyToPropagateStatistics; diff --git a/src/duckdb/extension/parquet/parquet_field_id.cpp b/src/duckdb/extension/parquet/parquet_field_id.cpp index 49a22747b..8c354ce2f 100644 --- a/src/duckdb/extension/parquet/parquet_field_id.cpp +++ b/src/duckdb/extension/parquet/parquet_field_id.cpp @@ -17,7 +17,7 @@ namespace duckdb { constexpr const char *FieldID::DUCKDB_FIELD_ID; -ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { +ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { } ChildFieldIDs ChildFieldIDs::Copy() const { @@ -64,7 +64,7 @@ static case_insensitive_map_t GetChildNameToTypeMap(const LogicalTy return name_to_type_map; } -static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, +static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, vector &child_types) { switch (type.id()) { case LogicalTypeId::LIST: @@ -88,7 +88,7 @@ static void GetChildNamesAndTypes(const LogicalType &type, vector &child } // LCOV_EXCL_STOP } -void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, +void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t 
&field_id, const vector &names, const vector &sql_types) { D_ASSERT(names.size() == sql_types.size()); for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { @@ -103,7 +103,7 @@ void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const } // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first - vector child_names; + vector child_names; vector child_types; GetChildNamesAndTypes(col_type, child_names, child_types); GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); @@ -122,7 +122,7 @@ void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids const auto &struct_children = StructValue::GetChildren(field_ids_value); D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i).GetIdentifierName()); if (col_name == FieldID::DUCKDB_FIELD_ID) { continue; } @@ -141,7 +141,8 @@ void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids "column is a partition column. 
Available column names: [%s]", col_name, names); } - D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys + D_ASSERT(field_ids.ids->find(Identifier(col_name)) == + field_ids.ids->end()); // Caught by STRUCT - deduplicates keys const auto &child_value = struct_children[i]; const auto &child_type = child_value.type(); @@ -172,7 +173,7 @@ void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids } field_id = FieldID(UnsafeNumericCast(field_id_int)); } - auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); + auto inserted = field_ids.ids->insert(make_pair(Identifier(col_name), std::move(field_id))); D_ASSERT(inserted.second); if (child_field_ids_value) { diff --git a/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp b/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp index 3cf35b007..1f48f7671 100644 --- a/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp +++ b/src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp @@ -22,7 +22,7 @@ ParquetFileMetadataCache::ParquetFileMetadataCache(unique_ptr geo_metadata, unique_ptr crypto_metadata, idx_t footer_size) : metadata(std::move(file_metadata)), geo_metadata(std::move(geo_metadata)), - crypto_metadata(std::move(crypto_metadata)), footer_size(footer_size), validate(handle.Validate()), + crypto_metadata(std::move(crypto_metadata)), footer_size(footer_size), last_modified(handle.GetLastModifiedTime()), version_tag(handle.GetVersionTag()) { } @@ -60,7 +60,7 @@ optional_idx ParquetFileMetadataCache::GetEstimatedCacheMemory() const { } bool ParquetFileMetadataCache::IsValid(CachingFileHandle &new_handle) const { - return ExternalFileCache::IsValid(validate, version_tag, last_modified, new_handle.GetVersionTag(), + return ExternalFileCache::IsValid(new_handle.Validate(), version_tag, last_modified, new_handle.GetVersionTag(), new_handle.GetLastModifiedTime()); } @@ -69,7 +69,7 @@ 
ParquetCacheValidity ParquetFileMetadataCache::IsValid(const OpenFileInfo &info, if (validation_mode == CacheValidationMode::NO_VALIDATION) { return ParquetCacheValidity::VALID; } - if (validation_mode == CacheValidationMode::VALIDATE_REMOTE && FileSystem::IsRemoteFile(info.path)) { + if (validation_mode == CacheValidationMode::VALIDATE_REMOTE && !FileSystem::IsRemoteFile(info.path)) { return ParquetCacheValidity::VALID; } if (info.extended_info == nullptr) { diff --git a/src/duckdb/extension/parquet/parquet_geometry.cpp b/src/duckdb/extension/parquet/parquet_geometry.cpp index 86ec8a748..a5c7adc29 100644 --- a/src/duckdb/extension/parquet/parquet_geometry.cpp +++ b/src/duckdb/extension/parquet/parquet_geometry.cpp @@ -393,7 +393,7 @@ unique_ptr GeometryColumnReader::Create(const ParquetReader &reade // TODO: Pass the actual target type here so we get the CRS information too auto func = StGeomfromwkbFun::GetFunction(); - func.name = "ST_GeomFromWKB"; + func.SetName("ST_GeomFromWKB"); auto read_expr = func.Bind(context, std::move(args)); auto type_expr = BoundCastExpression::AddDefaultCastToType(std::move(read_expr), schema.type); return make_uniq(context, std::move(string_reader), std::move(type_expr), schema); diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 45d5b5792..4f8f16125 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -855,7 +855,8 @@ void ParquetBloomProbeProcessor::InitializeInternal(ClientContext &context, Parq allocator = &BufferAllocator::Get(context); auto column_type = reader.GetColumns()[probe_column_idx.GetIndex()].type; auto comparison = BoundComparisonExpression::Create( - ExpressionType::COMPARE_EQUAL, make_uniq(probe_column_name, column_type, 0), + ExpressionType::COMPARE_EQUAL, + make_uniq(Identifier(probe_column_name), column_type, 0), make_uniq(probe_constant.CastAs(context, column_type))); filter 
= make_uniq(std::move(comparison)); } @@ -941,7 +942,7 @@ void ParquetMetaDataOperator::BindSchema(file_meta_types, file_meta_names); child_list_t file_meta_children; for (idx_t i = 0; i < file_meta_types.size(); i++) { - file_meta_children.push_back(make_pair(file_meta_names[i], file_meta_types[i])); + file_meta_children.emplace_back(make_pair(file_meta_names[i], file_meta_types[i])); } return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(file_meta_children)))); @@ -951,7 +952,7 @@ void ParquetMetaDataOperator::BindSchema(row_group_types, row_group_names); child_list_t row_group_children; for (idx_t i = 0; i < row_group_types.size(); i++) { - row_group_children.push_back(make_pair(row_group_names[i], row_group_types[i])); + row_group_children.emplace_back(make_pair(row_group_names[i], row_group_types[i])); } return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(row_group_children)))); @@ -961,7 +962,7 @@ void ParquetMetaDataOperator::BindSchema(schema_types, schema_names); child_list_t schema_children; for (idx_t i = 0; i < schema_types.size(); i++) { - schema_children.push_back(make_pair(schema_names[i], schema_types[i])); + schema_children.emplace_back(make_pair(schema_names[i], schema_types[i])); } return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(schema_children)))); @@ -971,7 +972,7 @@ void ParquetMetaDataOperator::BindSchema(kv_types, kv_names); child_list_t kv_children; for (idx_t i = 0; i < kv_types.size(); i++) { - kv_children.push_back(make_pair(kv_names[i], kv_types[i])); + kv_children.emplace_back(make_pair(kv_names[i], kv_types[i])); } return_types.emplace_back(LogicalType::LIST(LogicalType::STRUCT(std::move(kv_children)))); } @@ -1110,8 +1111,6 @@ void ParquetMetaDataOperator::Function(ClientContext &context, TableFunctionInpu rows_to_output = left_in_vector; } - output.SetCardinality(output_count + rows_to_output); - for (idx_t i = 0; i < rows_to_output; ++i) { 
local_state.processor->ReadRow(output_vectors, local_state.row_idx + i, *local_state.reader); } diff --git a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp index 174c549c2..258dead1a 100644 --- a/src/duckdb/extension/parquet/parquet_multi_file_info.cpp +++ b/src/duckdb/extension/parquet/parquet_multi_file_info.cpp @@ -101,7 +101,7 @@ struct ParquetReadLocalState : public LocalTableFunctionState { }; static void ParseFileRowNumberOption(MultiFileReaderBindData &bind_data, ParquetOptions &options, - vector &return_types, vector &names) { + vector &return_types, vector &names) { if (options.file_row_number) { if (StringUtil::CIFind(names, "file_row_number") != DConstants::INVALID_INDEX) { throw BinderException( @@ -113,7 +113,7 @@ static void ParseFileRowNumberOption(MultiFileReaderBindData &bind_data, Parquet } } -static void BindSchema(ClientContext &context, vector &return_types, vector &names, +static void BindSchema(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) { auto &parquet_bind = bind_data.bind_data->Cast(); auto &options = parquet_bind.GetParquetOptions(); @@ -125,7 +125,7 @@ static void BindSchema(ClientContext &context, vector &return_types } auto &reader_bind = bind_data.reader_bind; - vector schema_col_names; + vector schema_col_names; vector schema_col_types; schema_col_names.reserve(options.schema.size()); schema_col_types.reserve(options.schema.size()); @@ -143,7 +143,7 @@ static void BindSchema(ClientContext &context, vector &return_types for (idx_t i = 0; i < options.schema.size(); i++) { const auto &column = options.schema[i]; - schema_col_names.push_back(column.name); + schema_col_names.push_back(Identifier(column.name)); schema_col_types.push_back(column.type); auto res = MultiFileColumnDefinition(column.name, column.type); @@ -163,7 +163,7 @@ static void BindSchema(ClientContext &context, vector &return_types if 
(options.file_row_number) { MultiFileColumnDefinition res("file_row_number", LogicalType::BIGINT); res.identifier = Value::INTEGER(MultiFileReader::ORDINAL_FIELD_ID); - schema_col_names.push_back(res.name); + schema_col_names.push_back(Identifier(res.name)); schema_col_types.push_back(res.type); reader_bind.schema.emplace_back(res); } @@ -187,8 +187,8 @@ unique_ptr ParquetMultiFileInfo::CreateInterface(Clien return make_uniq(); } -void ParquetMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, - MultiFileBindData &bind_data) { +void ParquetMultiFileInfo::BindReader(ClientContext &context, vector &return_types, + vector &names, MultiFileBindData &bind_data) { auto &parquet_bind = bind_data.bind_data->Cast(); auto &options = parquet_bind.GetParquetOptions(); if (!options.schema.empty()) { @@ -248,21 +248,21 @@ static void VerifyParquetSchemaParameter(const Value &schema) { throw InvalidInputException( "'schema' expects the STRUCT to have 3 children, 'name', 'type' and 'default_value"); } - if (!StringUtil::CIEquals(children[0].first, "name")) { + if (children[0].first != "name") { throw InvalidInputException("'schema' expects the first field of the struct to be called 'name'"); } if (children[0].second.id() != LogicalTypeId::VARCHAR) { throw InvalidInputException("'schema' expects the 'name' field to be of type VARCHAR, not %s", LogicalTypeIdToString(children[0].second.id())); } - if (!StringUtil::CIEquals(children[1].first, "type")) { + if (children[1].first != "type") { throw InvalidInputException("'schema' expects the second field of the struct to be called 'type'"); } if (children[1].second.id() != LogicalTypeId::VARCHAR) { throw InvalidInputException("'schema' expects the 'type' field to be of type VARCHAR, not %s", LogicalTypeIdToString(children[1].second.id())); } - if (!StringUtil::CIEquals(children[2].first, "default_value")) { + if (children[2].first != "default_value") { throw InvalidInputException("'schema' expects the 
third field of the struct to be called 'default_value'"); } //! NOTE: default_value can be any type @@ -288,7 +288,8 @@ static void ParquetScanSerialize(Serializer &serializer, const optional_ptr ParquetScanDeserialize(Deserializer &deserialize file_path.emplace_back(path); } FileGlobInput input(FileGlobOptions::FALLBACK_GLOB, "parquet"); + input.allow_empty = serialization.file_options.allow_empty; auto multi_file_reader = MultiFileReader::Create(function); auto file_list = multi_file_reader->CreateFileList(context, Value::LIST(LogicalType::VARCHAR, file_path), input); @@ -326,6 +328,23 @@ static vector ParquetGetRowIdColumns(ClientContext &context, optional_ return result; } +static void ParquetScanGetMetrics(TableFunctionGetMetricsInput &input) { + // emit the shared multi-file metrics (files read, filenames, rows scanned) + MultiFileFunction::MultiFileGetMetrics(input); + // report row groups read vs. considered as the standard per-thread row-group metrics: the profiler sums + // row_groups_scanned / total_row_groups_to_scan across threads, and "skipped" = total - scanned + if (!input.local_state) { + return; + } + auto &local = input.local_state->Cast(); + if (!local.local_state) { + return; + } + auto &scan_state = local.local_state->Cast().scan_state; + input.operator_metrics.row_groups_scanned = scan_state.row_groups_read; + input.operator_metrics.total_row_groups_to_scan = scan_state.row_groups_read + scan_state.row_groups_skipped; +} + ParquetMetadataCacheEntry::ParquetMetadataCacheEntry(shared_ptr metadata_p, ParquetCacheValidity validity_p, bool has_deletes_p) : metadata(std::move(metadata_p)), validity(validity_p), has_deletes(has_deletes_p) { @@ -420,7 +439,9 @@ TableFunctionSet ParquetScanFunction::GetFunctionSet() { table_function.named_parameters["encryption_config"] = LogicalTypeId::ANY; table_function.named_parameters["parquet_version"] = LogicalType::VARCHAR; table_function.named_parameters["can_have_nan"] = LogicalType::BOOLEAN; + 
table_function.named_parameters["prefetch_strategy"] = LogicalType::VARCHAR; table_function.statistics_extended = MultiFileFunction::MultiFileScanStatsExtended; + table_function.get_metrics = ParquetScanGetMetrics; table_function.supports_pushdown_extract = ParquetScanSupportPushdownExtract; table_function.serialize = ParquetScanSerialize; table_function.deserialize = ParquetScanDeserialize; @@ -474,6 +495,13 @@ bool ParquetMultiFileInfo::ParseCopyOption(ClientContext &context, const string options.can_have_nan = GetBooleanArgument(key, values); return true; } + if (key == "prefetch_strategy") { + if (values.size() != 1) { + throw BinderException("Parquet prefetch_strategy cannot be empty!"); + } + options.prefetch_strategy = ParquetPrefetchStrategyOptionFromString(StringValue::Get(values[0])); + return true; + } return false; } @@ -528,6 +556,10 @@ bool ParquetMultiFileInfo::ParseOption(ClientContext &context, const string &ori options.encryption_config = ParquetEncryptionConfig::Create(context, val); return true; } + if (key == "prefetch_strategy") { + options.prefetch_strategy = ParquetPrefetchStrategyOptionFromString(StringValue::Get(val)); + return true; + } return false; } @@ -656,7 +688,7 @@ unique_ptr ParquetMultiFileInfo::GetCardinality(ClientContext &c return make_uniq(per_file_cardinality * file_count); } -unique_ptr ParquetReader::GetStatistics(ClientContext &context, const string &name) { +unique_ptr ParquetReader::GetStatistics(ClientContext &context, const Identifier &name) { return ReadStatistics(name); } @@ -668,6 +700,8 @@ double ParquetReader::GetProgressInFile(ClientContext &context) { void ParquetMultiFileInfo::GetVirtualColumns(ClientContext &, MultiFileBindData &, virtual_column_map_t &result) { result.insert(make_pair(MultiFileReader::COLUMN_IDENTIFIER_FILE_ROW_NUMBER, TableColumn("file_row_number", LogicalType::BIGINT))); + result.insert(make_pair(ParquetReader::COLUMN_IDENTIFIER_FILE_ROW_GROUP_NUMBER, + TableColumn("file_row_group_number", 
LogicalType::UBIGINT))); } shared_ptr ParquetMultiFileInfo::CreateReader(ClientContext &context, GlobalTableFunctionState &, @@ -696,7 +730,7 @@ shared_ptr ParquetReader::GetUnionData(idx_t file_idx) { result->names.reserve(columns.size()); result->types.reserve(columns.size()); for (auto &column : columns) { - result->names.push_back(column.name); + result->names.push_back(column.name.GetIdentifierName()); result->types.push_back(column.type); } diff --git a/src/duckdb/extension/parquet/parquet_prefetch_cost_model.cpp b/src/duckdb/extension/parquet/parquet_prefetch_cost_model.cpp new file mode 100644 index 000000000..f6b2ecf97 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_prefetch_cost_model.cpp @@ -0,0 +1,49 @@ +#include "parquet_prefetch_cost_model.hpp" + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +constexpr uint64_t PrefetchCostModel::GAP_MIN; +constexpr uint64_t PrefetchCostModel::GAP_MAX; +constexpr double PrefetchCostModelState::ALPHA; + +uint64_t PrefetchCostModel::GetColumnGapSize() const { + const double column_gap_size = latency_seconds * bandwidth_bytes_per_s; + if (!(column_gap_size > 0)) { // also catches NaN + return GAP_MIN; + } + if (column_gap_size >= static_cast(GAP_MAX)) { + return GAP_MAX; + } + return MaxValue(GAP_MIN, static_cast(column_gap_size)); +} + +void PrefetchCostModelState::RefineFromEstimate(const NetworkThroughputEstimate &estimate) { + const double latency = estimate.latency_seconds; + const double bandwidth = estimate.bandwidth_bytes_per_s; + // Ignore degenerate estimates (also catches NaN). 
+ if (!(latency > 0) || !(bandwidth > 0)) { + return; + } + + if (!measured_adopted) { + model.latency_seconds = latency; + model.bandwidth_bytes_per_s = bandwidth; + measured_adopted = true; + } else { + model.latency_seconds = ALPHA * latency + (1.0 - ALPHA) * model.latency_seconds; + model.bandwidth_bytes_per_s = ALPHA * bandwidth + (1.0 - ALPHA) * model.bandwidth_bytes_per_s; + } +} + +bool PrefetchCostModelState::TryGetColumnGapSize(uint64_t &result) const { + if (!measured_adopted) { + return false; + } + result = model.GetColumnGapSize(); + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index a522c6fcf..ef008e162 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -5,6 +5,7 @@ #include #include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/reference_map.hpp" #include "duckdb/function/partition_stats.hpp" #include "parquet_types.h" #include "column_reader.hpp" @@ -18,9 +19,11 @@ #include "reader/variant_column_reader.hpp" #include "reader/struct_column_reader.hpp" #include "thrift_tools.hpp" +#include "parquet_prefetch_cost_model.hpp" #include "duckdb/common/encryption_state.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/object_cache.hpp" #include "duckdb/planner/table_filter_state.hpp" @@ -51,7 +54,6 @@ #include "duckdb/main/setting_info.hpp" #include "duckdb/original/std/memory.hpp" #include "duckdb/planner/expression.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/table_filter_set.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" @@ -68,6 +70,125 @@ namespace duckdb { +const char 
*ParquetPrefetchStrategyToString(ParquetPrefetchStrategy strategy) { + switch (strategy) { + case ParquetPrefetchStrategy::WHOLE_GROUP: + return "whole_group"; + case ParquetPrefetchStrategy::PREFETCH_FILTERS: + return "prefetch_filters"; + case ParquetPrefetchStrategy::COLUMN_WISE_EAGER: + return "column_wise_eager"; + case ParquetPrefetchStrategy::NONE: + return "none"; + default: + throw InternalException("Invalid ParquetPrefetchStrategy for ParquetPrefetchStrategyToString()"); + } +} + +ParquetPrefetchStrategyOption ParquetPrefetchStrategyOptionFromString(const string &value) { + auto lower = StringUtil::Lower(value); + if (lower == "auto") { + return ParquetPrefetchStrategyOption::AUTO; + } + if (lower == "whole_group") { + return ParquetPrefetchStrategyOption::WHOLE_GROUP; + } + throw BinderException("Unrecognized prefetch_strategy '%s' (supported: 'auto', 'whole_group')", value); +} + +static idx_t ParquetColumnChunkFileOffset(const duckdb_parquet::ColumnChunk &chunk) { + idx_t offset = NumericCast(chunk.meta_data.data_page_offset); + if (chunk.meta_data.__isset.dictionary_page_offset) { + offset = MinValue(offset, NumericCast(chunk.meta_data.dictionary_page_offset)); + } + if (chunk.meta_data.__isset.index_page_offset) { + offset = MinValue(offset, NumericCast(chunk.meta_data.index_page_offset)); + } + return offset; +} + +void ParquetLoggerPrefetchMetrics::GeneratePrefetchGroup(ThriftFileTransport &trans, + vector &requested_columns, + const vector &all_chunks) { + if (requested_columns.empty()) { + return; + } + + // Let's use ReadHead to bucket registrations so we are use we are getting whatever the ThriftFileTransport does + vector> bucket_order; + reference_map_t> buckets; + set registered_offsets; + for (auto &requested_column : requested_columns) { + registered_offsets.insert(requested_column.offset); + auto rh_ptr = trans.GetReadHead(requested_column.offset); + if (!rh_ptr) { + continue; + } + auto &rh = *rh_ptr; + auto inserted = 
buckets.emplace(rh, vector {}); + if (inserted.second) { + bucket_order.emplace_back(rh); + } + inserted.first->second.push_back(requested_column); + } + + for (auto &chunk : all_chunks) { + // We gotta figure out if we had any hitchhiker + idx_t chunk_start = ParquetColumnChunkFileOffset(chunk); + if (registered_offsets.count(chunk_start) > 0) { + continue; + } + idx_t chunk_end = chunk_start + NumericCast(chunk.meta_data.total_compressed_size); + // Figure out what physical I/O this file offset lives in + auto rh_ptr = trans.GetReadHead(chunk_start); + if (!rh_ptr || chunk_end > rh_ptr->GetEnd()) { + continue; + } + auto &rh = *rh_ptr; + // Hitchhiker only counts if the ReadHead was created for one of our registrations. + auto bucket_it = buckets.find(rh); + if (bucket_it == buckets.end()) { + continue; + } + auto &path = chunk.meta_data.path_in_schema; + string name = path.empty() ? string() : StringUtil::Join(path, "."); + // we tag hitchhikers with *column_name* + bucket_it->second.emplace_back("*" + std::move(name) + "*", chunk_start); + } + + // Issue one log group per ReadHead, on the order they were registered + for (auto &rh_ref : bucket_order) { + auto &bucket = buckets[rh_ref]; + std::sort(bucket.begin(), bucket.end()); + vector names; + names.reserve(bucket.size()); + for (auto &col : bucket) { + names.push_back(std::move(col.name)); + } + prefetch_groups.push_back(std::move(names)); + } + requested_columns.clear(); +} + +static void LogRowGroupPrefetch(ClientContext &context, const string &file_path, idx_t row_group_id, + ParquetReaderScanState &state) { + const bool fully_filtered = !state.prefetch_metrics.had_match; + auto &filters_used = state.prefetch_metrics.logger.filters_used; + vector minimal_filters; + minimal_filters.reserve(filters_used.size()); + for (idx_t i = 0; i < state.scan_filters.size(); i++) { + if (i < filters_used.size() && filters_used[i]) { + auto &scan_filter = state.scan_filters[i]; + MultiFileLocalIndex 
local_idx(scan_filter.filter_idx); + minimal_filters.push_back(state.GetColumnReader(local_idx).Schema().name); + } + } + DUCKDB_LOG(context, ParquetPrefetchLogType, file_path, row_group_id, fully_filtered, + ParquetPrefetchStrategyToString(state.prefetch_metrics.logger.strategy), + state.prefetch_metrics.logger.prefetch_groups, minimal_filters, + state.prefetch_metrics.logger.accepted_column_gap); +} + using duckdb_parquet::ColumnChunk; using duckdb_parquet::ConvertedType; using duckdb_parquet::FieldRepetitionType; @@ -79,8 +200,9 @@ using duckdb_parquet::Statistics; using duckdb_parquet::Type; static unique_ptr -CreateThriftFileProtocol(QueryContext context, CachingFileHandle &file_handle, bool prefetch_mode) { - auto transport = duckdb_base_std::make_shared(file_handle, prefetch_mode); +CreateThriftFileProtocol(QueryContext context, CachingFileHandle &file_handle, bool prefetch_mode, + uint64_t accepted_column_gap = ReadHeadComparator::DEFAULT_ACCEPTED_COLUMN_GAP) { + auto transport = duckdb_base_std::make_shared(file_handle, prefetch_mode, accepted_column_gap); return make_uniq>(std::move(transport)); } @@ -460,6 +582,8 @@ unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &con switch (schema.schema_type) { case ParquetColumnSchemaType::FILE_ROW_NUMBER: return make_uniq(*this, schema); + case ParquetColumnSchemaType::FILE_ROW_GROUP_NUMBER: + return make_uniq(*this, schema); case ParquetColumnSchemaType::GEOMETRY: { return GeometryColumnReader::Create(*this, schema, context); } @@ -786,6 +910,10 @@ static ParquetColumnSchema FileRowNumberSchema() { return ParquetColumnSchema::FileRowNumber(); } +static ParquetColumnSchema FileRowGroupNumberSchema() { + return ParquetColumnSchema::FileRowGroupNumber(); +} + unique_ptr ParquetReader::ParseSchema(ClientContext &context) { auto file_meta_data = GetFileMetadata(); idx_t next_schema_idx = 0; @@ -869,6 +997,8 @@ void ParquetReader::InitializeSchema(ClientContext &context) { void 
ParquetReader::AddVirtualColumn(column_t virtual_column_id) { if (virtual_column_id == MultiFileReader::COLUMN_IDENTIFIER_FILE_ROW_NUMBER) { root_schema->children.push_back(FileRowNumberSchema()); + } else if (virtual_column_id == ParquetReader::COLUMN_IDENTIFIER_FILE_ROW_GROUP_NUMBER) { + root_schema->children.push_back(FileRowGroupNumberSchema()); } else { throw InternalException("Unsupported virtual column id %d for parquet reader", virtual_column_id); } @@ -914,7 +1044,6 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " "metadata is located at the end of the file. Write the stream to disk first and read from there instead."); } - // read the extended file open info (if any) optional_idx footer_size; if (file.extended_info) { @@ -938,11 +1067,11 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq metadata = LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, encryption_util, footer_size); } else { - metadata = ObjectCache::GetObjectCache(context_p).Get(file.path); + metadata = ObjectCache::GetObjectCache(context_p).GetWithTypePrefix(file.path); if (!metadata || !metadata->IsValid(*file_handle)) { metadata = LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, encryption_util, footer_size); - ObjectCache::GetObjectCache(context_p).Put(file.path, metadata); + ObjectCache::GetObjectCache(context_p).PutWithTypePrefix(file.path, metadata); } } } else { @@ -962,7 +1091,7 @@ bool ParquetReader::MetadataCacheEnabled(ClientContext &context) { shared_ptr ParquetReader::GetMetadataCacheEntry(ClientContext &context, const OpenFileInfo &file) { - return ObjectCache::GetObjectCache(context).Get(file.path); + return ObjectCache::GetObjectCache(context).GetWithTypePrefix(file.path); } ParquetUnionData::~ParquetUnionData() { @@ -978,7 +1107,7 @@ 
optional_idx ParquetUnionData::TryGetCardinalityEstimate() const { return optional_idx(); } -unique_ptr ParquetUnionData::GetStatistics(ClientContext &context, const string &name) { +unique_ptr ParquetUnionData::GetStatistics(ClientContext &context, const Identifier &name) { if (reader) { return reader->Cast().GetStatistics(context, name); } @@ -1023,7 +1152,7 @@ static unique_ptr ReadStatisticsInternal(const FileMetaData &fil return column_stats; } -unique_ptr ParquetReader::ReadStatistics(const string &name) { +unique_ptr ParquetReader::ReadStatistics(const Identifier &name) { idx_t file_col_idx; for (file_col_idx = 0; file_col_idx < columns.size(); file_col_idx++) { if (columns[file_col_idx].name == name) { @@ -1039,12 +1168,12 @@ unique_ptr ParquetReader::ReadStatistics(const string &name) { unique_ptr ParquetReader::ReadStatistics(ClientContext &context, ParquetOptions parquet_options, shared_ptr metadata, - const string &name) { + const Identifier &name) { ParquetReader reader(context, std::move(parquet_options), std::move(metadata)); return reader.ReadStatistics(name); } -unique_ptr ParquetReader::ReadStatistics(const ParquetUnionData &union_data, const string &name) { +unique_ptr ParquetReader::ReadStatistics(const ParquetUnionData &union_data, const Identifier &name) { const auto &col_names = union_data.names; idx_t file_col_idx; @@ -1172,62 +1301,20 @@ idx_t ParquetReader::GetGroupOffset(ParquetReaderScanState &state) { return min_offset; } -static FilterPropagateResult CheckParquetStringFilter(BaseStatistics &stats, const Statistics &pq_col_stats, - const TableFilter &filter) { - switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_filter = filter.Cast(); - auto and_result = FilterPropagateResult::FILTER_ALWAYS_TRUE; - for (auto &child_filter : conjunction_filter.child_filters) { - auto child_prune_result = CheckParquetStringFilter(stats, pq_col_stats, *child_filter); - if (child_prune_result == 
FilterPropagateResult::FILTER_ALWAYS_FALSE) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (child_prune_result != and_result) { - and_result = FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - } - return and_result; - } - case TableFilterType::EXPRESSION_FILTER: { - auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "CheckParquetStringFilter"); - auto &expr = *expr_filter.expr; - - // Handle comparison expressions (from ConstantFilter conversion) - if (BoundComparisonExpression::IsComparison(expr)) { - auto &comp = expr.Cast(); - auto &right = BoundComparisonExpression::Right(comp); - if (right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - auto &constant = right.Cast(); - if (constant.value.type().id() == LogicalTypeId::VARCHAR) { - auto &min_value = pq_col_stats.min_value; - auto &max_value = pq_col_stats.max_value; - return StringStats::CheckZonemap(const_data_ptr_cast(min_value.c_str()), min_value.size(), - const_data_ptr_cast(max_value.c_str()), max_value.size(), - comp.GetExpressionType(), StringValue::Get(constant.value)); - } - } - } - return filter.Cast().CheckStatistics(stats); - } - default: - return filter.CheckStatistics(stats); - } -} - static FilterPropagateResult CheckParquetFloatFilter(ColumnReader &reader, const Statistics &pq_col_stats, const TableFilter &filter) { // floating point values can have values in the [min, max] domain AND nan values // check both stats against the filter + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "CheckParquetFloatFilter"); auto &type = reader.Type(); auto nan_stats = NumericStats::CreateUnknown(type); auto nan_value = Value("nan").DefaultCastAs(type); NumericStats::SetMin(nan_stats, nan_value); NumericStats::SetMax(nan_stats, nan_value); - auto nan_prune = filter.CheckStatistics(nan_stats); + auto nan_prune = expr_filter.CheckStatistics(nan_stats); auto min_max_stats = ParquetStatisticsUtils::CreateNumericStats(reader.Type(), reader.Schema(), 
pq_col_stats); - auto prune = filter.CheckStatistics(*min_max_stats); + auto prune = expr_filter.CheckStatistics(*min_max_stats); // if EITHER of them cannot be pruned - we cannot prune if (prune == FilterPropagateResult::NO_PRUNING_POSSIBLE || @@ -1278,12 +1365,6 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i if (is_expression) { // no pruning possible for expressions prune_result = FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else if (!is_generated_column && has_min_max && column_reader.Type().id() == LogicalTypeId::VARCHAR) { - // our StringStats only store the first 8 bytes of strings (even if Parquet has longer string stats) - // however, when reading remote Parquet files, skipping row groups is really important - // here, we implement a special case to check the full length for string filters - prune_result = - CheckParquetStringFilter(*stats, group.columns[schema_column_index].meta_data.statistics, filter); } else if (!is_generated_column && has_min_max && (column_reader.Type().id() == LogicalTypeId::FLOAT || column_reader.Type().id() == LogicalTypeId::DOUBLE) && @@ -1294,7 +1375,9 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i prune_result = CheckParquetFloatFilter(column_reader, group.columns[schema_column_index].meta_data.statistics, filter); } else { - prune_result = filter.CheckStatistics(*stats); + auto &expr_filter = + ExpressionFilter::GetExpressionFilter(filter, "ParquetReader::PrepareRowGroupBuffer"); + prune_result = expr_filter.CheckStatistics(*stats); } // check the bloom filter if present if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE && !column_reader.Type().IsNested() && @@ -1312,10 +1395,7 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i } } - for (auto &column_reader : state.column_readers) { - column_reader->InitializeRead(state.group_idx_list[state.current_group], group.columns, - *state.thrift_file_proto); - } + 
column_reader.InitializeRead(state.group_idx_list[state.current_group], group.columns, *state.thrift_file_proto); } idx_t ParquetReader::NumRows() const { @@ -1357,6 +1437,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat auto flags = FileFlags::FILE_FLAGS_READ; if (ShouldAndCanPrefetch(context, *file_handle)) { state.prefetch_mode = true; + flags |= FileFlags::FILE_FLAGS_PARALLEL_ACCESS; if (file_handle->IsRemoteFile()) { flags |= FileFlags::FILE_FLAGS_DIRECT_IO; } @@ -1374,8 +1455,20 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat state.scan_filters.emplace_back(context, entry.GetIndex(), entry.Filter()); } } + if (state.filter_eliminated_all_rows.size() != state.scan_filters.size()) { + state.filter_eliminated_all_rows.assign(state.scan_filters.size(), false); + } - state.thrift_file_proto = CreateThriftFileProtocol(context, *state.file_handle, state.prefetch_mode); + uint64_t accepted_column_gap = ReadHeadComparator::DEFAULT_ACCEPTED_COLUMN_GAP; + if (state.prefetch_mode && !state.file_handle->OnDiskFile()) { + NetworkThroughputEstimate estimate; + if (state.file_handle->TryGetNetworkThroughput(estimate)) { + state.cost_model_state.RefineFromEstimate(estimate); + } + state.cost_model_state.TryGetColumnGapSize(accepted_column_gap); + } + state.thrift_file_proto = + CreateThriftFileProtocol(context, *state.file_handle, state.prefetch_mode, accepted_column_gap); state.column_readers.resize(column_indexes.size()); for (idx_t i = 0; i < column_indexes.size(); i++) { @@ -1436,23 +1529,22 @@ struct ParquetPartitionRowGroup : public PartitionRowGroup { D_ASSERT(metadata.row_groups.size() > row_group_idx); D_ASSERT(root_schema->children.size() > primary_index); + // Special handle generated columns. 
+ const auto &column_schema = root_schema->children[primary_index]; + if (column_schema.schema_type == ParquetColumnSchemaType::FILE_ROW_NUMBER) { + return true; + } + const auto &row_group = metadata.row_groups[row_group_idx]; const auto &column_chunk = row_group.columns[primary_index]; - if (!column_chunk.__isset.meta_data || !column_chunk.meta_data.__isset.statistics) { - return false; - } - const auto &col_stats = column_chunk.meta_data.statistics; - if (col_stats.__isset.is_min_value_exact && col_stats.__isset.is_max_value_exact) { - return col_stats.is_min_value_exact && col_stats.is_max_value_exact; + if (column_chunk.__isset.meta_data && column_chunk.meta_data.__isset.statistics && + column_chunk.meta_data.statistics.__isset.is_min_value_exact && + column_chunk.meta_data.statistics.__isset.is_max_value_exact) { + const auto &stats = column_chunk.meta_data.statistics; + return stats.is_min_value_exact && stats.is_max_value_exact; } - // Pre-PARQUET-2352 (Oct 2023) the spec required min/max, when present, to be the exact - // min/max; in practice some writers truncated long binary values, and PARQUET-2352 added - // is_*_value_exact so readers could detect that. Fixed-width physical types are not - // subject to such truncation, so when the flag is missing we mirror arrow-rs - // (parquet/src/file/statistics.rs ValueStatistics::new) and treat them as exact. - const auto t = column_chunk.meta_data.type; - return t != Type::BYTE_ARRAY && t != Type::FIXED_LEN_BYTE_ARRAY; + return false; } }; @@ -1475,108 +1567,413 @@ void ParquetReader::GetPartitionStats(const duckdb_parquet::FileMetaData &metada } } +// An I/O task that fetches the bytes of a ReadHead. 
+class ParquetIOAsyncTask : public AsyncTask { +public: + ParquetIOAsyncTask(shared_ptr read_head, shared_ptr file_handle) + : read_head(std::move(read_head)), file_handle(std::move(file_handle)) { + } + + void Execute() override { + read_head->Fetch(*file_handle); + } + +private: + shared_ptr read_head; + shared_ptr file_handle; +}; + +// Async I/O tasks for the read heads in the index range [from, to), skipping any that are already fetched. +static vector> +CollectIOTasks(ThriftFileTransport &trans, const shared_ptr &file_handle, idx_t from, idx_t to) { + auto &read_heads = trans.GetReadHeads(); + vector> io_tasks; + for (idx_t i = from; i < to; i++) { + auto &read_head = read_heads[i]; + if (!read_head->data_isset) { + io_tasks.push_back(make_uniq(read_head, file_handle)); + } + } + return io_tasks; +} + +ParquetPrefetchStrategy ParquetReader::WholeGroupPrefetch(ParquetReaderScanState &state, ThriftFileTransport &trans, + const duckdb_parquet::RowGroup &group, + uint64_t total_row_group_span, bool log_prefetch) { + if (!state.current_group_prefetched) { + auto total_compressed_size = GetGroupCompressedSize(state); + if (total_compressed_size > 0) { + trans.RegisterPrefetch(GetGroupOffset(state), total_row_group_span, false); + trans.FinalizeRegistration(); + } + state.current_group_prefetched = true; + } + if (log_prefetch) { + state.prefetch_metrics.logger.strategy = ParquetPrefetchStrategy::WHOLE_GROUP; + vector pending_registrations; + for (auto &column_chunk : group.columns) { + auto &path = column_chunk.meta_data.path_in_schema; + string name = path.empty() ? 
string() : StringUtil::Join(path, "."); + pending_registrations.push_back({name, ParquetColumnChunkFileOffset(column_chunk)}); + } + state.prefetch_metrics.logger.GeneratePrefetchGroup(trans, pending_registrations, group.columns); + } + return ParquetPrefetchStrategy::WHOLE_GROUP; +} + +ParquetPrefetchStrategy ParquetReader::ColumnWisePrefetch(ParquetReaderScanState &state, ThriftFileTransport &trans, + const duckdb_parquet::RowGroup &group, + bool filters_look_unselective, bool log_prefetch) const { + bool has_non_optional_filter = false; + if (filters) { + for (auto &entry : *filters) { + if (!ExpressionFilter::IsRootOptionalFilter(entry.Filter())) { + has_non_optional_filter = true; + } + } + } + const bool lazy_fetch = has_non_optional_filter && !filters_look_unselective; + + vector already_registered(column_ids.size(), false); + vector pending_registrations; + + if (lazy_fetch && !state.scan_filters.empty()) { + auto &adaptive_filter = state.adaptive_filter_cache.GetAdaptiveFilter(); + const auto &permutation = adaptive_filter.GetPermutation(); + for (idx_t i = 0; i < state.scan_filters.size(); i++) { + // We must do the filter grouping based on the set that excludes all rows + if (i > 0 && state.filter_eliminated_all_rows[permutation[i - 1]]) { + trans.FinalizeRegistration(); + if (log_prefetch) { + state.prefetch_metrics.logger.GeneratePrefetchGroup(trans, pending_registrations, group.columns); + } + } + auto &scan_filter = state.scan_filters[permutation[i]]; + MultiFileLocalIndex local_idx(scan_filter.filter_idx); + auto &child = state.GetColumnReader(local_idx); + child.RegisterPrefetch(trans, true); + if (log_prefetch) { + pending_registrations.emplace_back(child.Schema().name, child.FileOffset()); + } + already_registered[local_idx.GetIndex()] = true; + } + // Done with the filter columns, time to merge the non-filter columns. 
+ trans.FinalizeRegistration(); + state.filter_head_count = trans.GetReadHeads().size(); + if (log_prefetch) { + state.prefetch_metrics.logger.GeneratePrefetchGroup(trans, pending_registrations, group.columns); + } + } + + for (idx_t i = 0; i < column_ids.size(); i++) { + // We do the remaining columns that have not yet been prefetched by the previous steps + // These can be the remaining projections, or all columns if the filters were not selective + if (already_registered[i]) { + continue; + } + auto col_idx = MultiFileLocalIndex(i); + auto &child = state.GetColumnReader(col_idx); + if (child.IsSkipped()) { + //! Column reader is skipped entirely + continue; + } + child.RegisterPrefetch(trans, true); + if (log_prefetch) { + pending_registrations.emplace_back(child.Schema().name, child.FileOffset()); + } + } + trans.FinalizeRegistration(); + const auto strategy = + lazy_fetch ? ParquetPrefetchStrategy::PREFETCH_FILTERS : ParquetPrefetchStrategy::COLUMN_WISE_EAGER; + if (log_prefetch) { + state.prefetch_metrics.logger.GeneratePrefetchGroup(trans, pending_registrations, group.columns); + state.prefetch_metrics.logger.strategy = strategy; + } + return strategy; +} + AsyncResult ParquetReader::Scan(ClientContext &context, ParquetReaderScanState &state, DataChunk &result) { result.Reset(); if (state.finished) { return SourceResultType::FINISHED; } + const bool log_prefetch = + Logger::Get(context).ShouldLog(ParquetPrefetchLogType::NAME, ParquetPrefetchLogType::LEVEL); + // see if we have to switch to the next row group in the parquet file if (state.current_group < 0 || (int64_t)state.offset_in_group >= GetGroup(state).num_rows) { - state.current_group++; - state.offset_in_group = 0; + return Schedule(context, state, result, log_prefetch); + } - auto &trans = reinterpret_cast(*state.thrift_file_proto->getTransport()); - trans.ClearPrefetch(); - state.current_group_prefetched = false; + return Process(state, result, log_prefetch); +} - if ((idx_t)state.current_group == 
state.group_idx_list.size()) { - state.finished = true; - return SourceResultType::FINISHED; - } +AsyncResult ParquetReader::Schedule(ClientContext &context, ParquetReaderScanState &state, DataChunk &result, + bool log_prefetch) { + state.current_group++; + state.offset_in_group = 0; - // TODO: only need this if we have a deletion vector? - state.group_offset = GetRowGroupOffset(*this, state.group_idx_list[state.current_group]); + auto &trans = reinterpret_cast(*state.thrift_file_proto->getTransport()); + trans.ClearPrefetch(); + state.current_group_prefetched = false; + state.filter_head_count = 0; - uint64_t to_scan_compressed_bytes = 0; - for (idx_t i = 0; i < column_ids.size(); i++) { - auto col_idx = MultiFileLocalIndex(i); - PrepareRowGroupBuffer(state, col_idx); - to_scan_compressed_bytes += state.GetColumnReader(i).TotalCompressedSize(); + if (state.prefetch_mode && !state.file_handle->OnDiskFile()) { + NetworkThroughputEstimate estimate; + if (state.file_handle->TryGetNetworkThroughput(estimate)) { + state.cost_model_state.RefineFromEstimate(estimate); } + uint64_t accepted_column_gap = ReadHeadComparator::DEFAULT_ACCEPTED_COLUMN_GAP; + state.cost_model_state.TryGetColumnGapSize(accepted_column_gap); + trans.SetAcceptedColumnGap(accepted_column_gap); + } - auto &group = GetGroup(state); - if (state.op) { - DUCKDB_LOG(context, PhysicalOperatorLogType, *state.op, "ParquetReader", - state.offset_in_group == (idx_t)group.num_rows ? 
"SkipRowGroup" : "ReadRowGroup", - {{"file", file.path}, {"row_group_id", to_string(state.group_idx_list[state.current_group])}}); - } + if (log_prefetch && state.prefetch_metrics.filter_ran && state.current_group > 0) { + LogRowGroupPrefetch(context, file.path, state.group_idx_list[state.current_group - 1], state); + } + state.prefetch_metrics.FinalizeRowGroupSelectivity(); - if (state.prefetch_mode && state.offset_in_group != (idx_t)group.num_rows) { - uint64_t total_row_group_span = GetGroupSpan(state); + if ((idx_t)state.current_group == state.group_idx_list.size()) { + state.finished = true; + return SourceResultType::FINISHED; + } - double scan_percentage = (double)(to_scan_compressed_bytes) / static_cast(total_row_group_span); + // TODO: only need this if we have a deletion vector? + state.group_offset = GetRowGroupOffset(*this, state.group_idx_list[state.current_group]); - if (to_scan_compressed_bytes > total_row_group_span) { - throw IOException( - "The parquet file '%s' seems to have incorrectly set page offsets. This interferes with DuckDB's " - "prefetching optimization. DuckDB may still be able to scan this file by manually disabling the " - "prefetching mechanism using: 'SET disable_parquet_prefetching=true'.", - GetFileName()); - } + uint64_t to_scan_compressed_bytes = 0; + for (idx_t i = 0; i < column_ids.size(); i++) { + auto col_idx = MultiFileLocalIndex(i); + PrepareRowGroupBuffer(state, col_idx); + to_scan_compressed_bytes += state.GetColumnReader(i).TotalCompressedSize(); + } - if (!filters && scan_percentage > ParquetReaderPrefetchConfig::WHOLE_GROUP_PREFETCH_MINIMUM_SCAN) { - // Prefetch the whole row group - if (!state.current_group_prefetched) { - auto total_compressed_size = GetGroupCompressedSize(state); - if (total_compressed_size > 0) { - trans.Prefetch(GetGroupOffset(state), total_row_group_span); - } - state.current_group_prefetched = true; - } - } else { - // lazy fetching is when all tuples in a column can be skipped. 
With lazy fetching the buffer is only - // fetched on the first read to that buffer. - bool lazy_fetch = false; - if (filters) { - // check if any filter is non-optional - bool has_non_optional_filter = false; - for (auto &entry : *filters) { - if (entry.Filter().filter_type != TableFilterType::OPTIONAL_FILTER) { - has_non_optional_filter = true; - } - } - if (has_non_optional_filter) { - lazy_fetch = true; - } - } + auto &group = GetGroup(state); + const bool row_group_skipped = state.offset_in_group == (idx_t)group.num_rows; + if (row_group_skipped) { + ++state.row_groups_skipped; + } else { + ++state.row_groups_read; + } + if (state.op) { + DUCKDB_LOG(context, PhysicalOperatorLogType, *state.op, "ParquetReader", + row_group_skipped ? "SkipRowGroup" : "ReadRowGroup", + {{"file", file.path}, {"row_group_id", to_string(state.group_idx_list[state.current_group])}}); + } - // Prefetch column-wise - for (idx_t i = 0; i < column_ids.size(); i++) { - auto col_idx = MultiFileLocalIndex(i); - bool has_filter = false; - if (filters) { - auto filter = filters->TryGetFilterByColumnIndex(col_idx); - if (filter && filter->filter_type != TableFilterType::OPTIONAL_FILTER) { - has_filter = true; - } - } - state.GetColumnReader(col_idx).RegisterPrefetch(trans, !has_filter); - } + vector> io_tasks; + if (state.prefetch_mode && state.offset_in_group != (idx_t)group.num_rows) { + uint64_t total_row_group_span = GetGroupSpan(state); - trans.FinalizeRegistration(); + double scan_percentage = (double)(to_scan_compressed_bytes) / static_cast(total_row_group_span); - if (!lazy_fetch) { - trans.PrefetchRegistered(); + if (to_scan_compressed_bytes > total_row_group_span) { + throw IOException( + "The parquet file '%s' seems to have incorrectly set page offsets. This interferes with DuckDB's " + "prefetching optimization. 
DuckDB may still be able to scan this file by manually disabling the " + "prefetching mechanism using: 'SET disable_parquet_prefetching=true'.", + GetFileName()); + } + // We basically do eager fetch if we do a whole-group or column-wise prefetch, if we do filters we do it lazily + ParquetPrefetchStrategy strategy = ParquetPrefetchStrategy::NONE; + if (parquet_options.prefetch_strategy == ParquetPrefetchStrategyOption::WHOLE_GROUP) { + strategy = WholeGroupPrefetch(state, trans, group, total_row_group_span, log_prefetch); + } else { + bool filters_look_unselective = false; + if (filters) { + if (state.prefetch_metrics.row_groups_executed == 0) { + filters_look_unselective = true; + } else if (static_cast(state.prefetch_metrics.row_groups_with_matches) / + static_cast(state.prefetch_metrics.row_groups_executed) > + ParquetReaderPrefetchConfig::PREFETCH_FILTER_MINIMUM_MATCH_RATIO) { + filters_look_unselective = true; } } + + if ((!filters || filters_look_unselective) && + scan_percentage > ParquetReaderPrefetchConfig::WHOLE_GROUP_PREFETCH_MINIMUM_SCAN) { + strategy = WholeGroupPrefetch(state, trans, group, total_row_group_span, log_prefetch); + } else { + strategy = ColumnWisePrefetch(state, trans, group, filters_look_unselective, log_prefetch); + } + } + auto read_head_count = trans.GetReadHeads().size(); + switch (strategy) { + case ParquetPrefetchStrategy::PREFETCH_FILTERS: + // schedule only the filter columns' I/O, they are last in our list + io_tasks = + CollectIOTasks(trans, state.file_handle, read_head_count - state.filter_head_count, read_head_count); + break; + case ParquetPrefetchStrategy::WHOLE_GROUP: + case ParquetPrefetchStrategy::COLUMN_WISE_EAGER: + // schedule the I/O for all columns up front + io_tasks = CollectIOTasks(trans, state.file_handle, 0, read_head_count); + break; + default: + throw InternalException("Unexpected parquet prefetch strategy when scheduling I/O"); } - result.Reset(); - return SourceResultType::HAVE_MORE_OUTPUT; + if 
(log_prefetch) { + state.prefetch_metrics.logger.accepted_column_gap = trans.GetAcceptedColumnGap(); + } + } + result.Reset(); + if (!io_tasks.empty()) { + return AsyncResult(std::move(io_tasks), TaskSchedulerType::ASYNC); } + return SourceResultType::HAVE_MORE_OUTPUT; +} +idx_t ParquetReader::EvaluateFilters(ParquetReaderScanState &state, DataChunk &result, idx_t scan_count, + uint8_t *define_ptr, uint8_t *repeat_ptr, bool log_prefetch) { + idx_t filter_count = result.size(); + D_ASSERT(filter_count == scan_count); + + state.sel.Initialize(nullptr); + D_ASSERT(!filters || state.scan_filters.size() == filters->FilterCount()); + + bool is_first_filter = true; + if (deletion_filter) { + auto row_start = UnsafeNumericCast(state.offset_in_group + state.group_offset); + filter_count = deletion_filter->Filter(row_start, scan_count, state.sel); + //! FIXME: does this need to be set? + //! As part of 'DirectFilter' we also initialize reads of the child readers + is_first_filter = false; + } + if (!filters) { + return filter_count; + } + + // first load the columns that are used in filters + auto &adaptive_filter = state.adaptive_filter_cache.GetAdaptiveFilter(); + auto filter_state = adaptive_filter.BeginFilter(); + const auto &permutation = adaptive_filter.GetPermutation(); + for (idx_t i = 0; i < state.scan_filters.size(); i++) { + auto &scan_filter = state.scan_filters[permutation[i]]; + MultiFileLocalIndex local_idx(scan_filter.filter_idx); + auto &child_reader = state.GetColumnReader(local_idx); + if (filter_count == 0) { + child_reader.Skip(scan_count); + continue; + } + + auto &result_vector = result.data[local_idx.GetIndex()]; + ColumnReaderInput reader_input(scan_count, define_ptr, repeat_ptr); + child_reader.Filter(reader_input, result_vector, scan_filter.filter, *scan_filter.filter_state, state.sel, + filter_count, is_first_filter); + if (log_prefetch) { + auto &filters_used = state.prefetch_metrics.logger.filters_used; + if (filters_used.size() != 
state.scan_filters.size()) { + filters_used.assign(state.scan_filters.size(), false); + } + filters_used[permutation[i]] = true; + } + is_first_filter = false; + if (filter_count == 0) { + state.filter_eliminated_all_rows[permutation[i]] = true; + } + } + adaptive_filter.EndFilter(filter_state); + state.prefetch_metrics.filter_ran = true; + if (filter_count > 0) { + state.prefetch_metrics.had_match = true; + } + return filter_count; +} + +// Whether column position `i` is used by a filter. +static bool IsFilterColumn(const vector &scan_filters, idx_t column_index) { + for (auto &scan_filter : scan_filters) { + if (MultiFileLocalIndex(scan_filter.filter_idx).GetIndex() == column_index) { + return true; + } + } + return false; +} + +void ParquetReader::DecodeRemainingColumns(ParquetReaderScanState &state, DataChunk &result, idx_t filter_count, + uint8_t *define_ptr, uint8_t *repeat_ptr) { + for (idx_t i = 0; i < column_ids.size(); i++) { + MultiFileLocalIndex col_idx(i); + if (IsFilterColumn(state.scan_filters, i)) { + // already decoded (or skipped) by EvaluateFilters + continue; + } + if (filter_count == 0) { + state.GetColumnReader(col_idx).Skip(result.size()); + continue; + } + auto &result_vector = result.data[i]; + auto &child_reader = state.GetColumnReader(col_idx); + if (metadata->crypto_metadata->encryption_algorithm.__isset.AES_GCM_V1) { + child_reader.InitializeCryptoMetadata(metadata->crypto_metadata->encryption_algorithm, + GetGroup(state).ordinal); + } + ColumnReaderInput reader_input(result.size(), define_ptr, repeat_ptr); + child_reader.Select(reader_input, result_vector, state.sel, filter_count); + } +} + +// When doing ParquetPrefetchStrategy::PREFETCH_FILTERS, we gotta block mid processing to get the other columns. +// on the re-entry the chunk is reset in the multifilescan which borkes our datachunk, hence we need to copy for +// ownership. +// FIXME: maybe we can change this to be able to move? or not have the multifile reset? 
bigger refactor though. +static void CopyFilterColumns(const vector &scan_filters, DataChunk &src, DataChunk &dst, + idx_t count) { + for (auto &scan_filter : scan_filters) { + auto idx = MultiFileLocalIndex(scan_filter.filter_idx).GetIndex(); + VectorOperations::Copy(src.data[idx], dst.data[idx], count, 0, 0); + } +} + +vector> ParquetReader::ScheduleRemainingColumns(ParquetReaderScanState &state, DataChunk &result, + idx_t scan_count) { + if (state.filter_count == 0 || state.filter_head_count == 0) { + return {}; + } + auto &trans = reinterpret_cast(*state.thrift_file_proto->getTransport()); + auto payload_head_count = trans.GetReadHeads().size() - state.filter_head_count; + // payload columns are the leading read heads + auto io_tasks = CollectIOTasks(trans, state.file_handle, 0, payload_head_count); + if (io_tasks.empty()) { + // no payload to do + return {}; + } + // stash the decoded filter columns and empty the chunk for the BLOCKED return + if (state.filter_stash.data.empty()) { + state.filter_stash.Initialize(allocator, result.GetTypes()); + } + state.filter_stash.Reset(); + CopyFilterColumns(state.scan_filters, result, state.filter_stash, scan_count); + result.Reset(); + state.filter_done = true; + return io_tasks; +} + +AsyncResult ParquetReader::ProcessFilters(ParquetReaderScanState &state, DataChunk &result, idx_t scan_count, + uint8_t *define_ptr, uint8_t *repeat_ptr, bool log_prefetch) { + if (state.filter_done) { + // the filters already ran, we restore the stashed columns + state.filter_done = false; + CopyFilterColumns(state.scan_filters, state.filter_stash, result, scan_count); + } else { + // we need to evaluate the filters, then async-fetch the payload and resume into the decode + state.filter_count = EvaluateFilters(state, result, scan_count, define_ptr, repeat_ptr, log_prefetch); + auto io_tasks = ScheduleRemainingColumns(state, result, scan_count); + if (!io_tasks.empty()) { + return AsyncResult(std::move(io_tasks), 
TaskSchedulerType::ASYNC); + } + } + DecodeRemainingColumns(state, result, state.filter_count, define_ptr, repeat_ptr); + if (scan_count != state.filter_count) { + result.Slice(state.sel, state.filter_count); + } + return SourceResultType::HAVE_MORE_OUTPUT; +} + +AsyncResult ParquetReader::Process(ParquetReaderScanState &state, DataChunk &result, bool log_prefetch) { auto scan_count = MinValue(STANDARD_VECTOR_SIZE, GetGroup(state).num_rows - state.offset_in_group); - result.SetCardinality(scan_count); + result.SetChildCardinality(scan_count); if (scan_count == 0) { state.finished = true; @@ -1593,67 +1990,10 @@ AsyncResult ParquetReader::Scan(ClientContext &context, ParquetReaderScanState & auto repeat_ptr = (uint8_t *)state.repeat_buf.ptr; if (filters || deletion_filter) { - idx_t filter_count = result.size(); - D_ASSERT(filter_count == scan_count); - vector need_to_read(column_ids.size(), true); - - state.sel.Initialize(nullptr); - D_ASSERT(!filters || state.scan_filters.size() == filters->FilterCount()); - - bool is_first_filter = true; - if (deletion_filter) { - auto row_start = UnsafeNumericCast(state.offset_in_group + state.group_offset); - filter_count = deletion_filter->Filter(row_start, scan_count, state.sel); - //! FIXME: does this need to be set? - //! 
As part of 'DirectFilter' we also initialize reads of the child readers - is_first_filter = false; - } - - if (filters) { - // first load the columns that are used in filters - auto &adaptive_filter = state.adaptive_filter_cache.GetAdaptiveFilter(); - auto filter_state = adaptive_filter.BeginFilter(); - const auto &permutation = adaptive_filter.GetPermutation(); - for (idx_t i = 0; i < state.scan_filters.size(); i++) { - if (filter_count == 0) { - // if no rows are left we can stop checking filters - break; - } - auto &scan_filter = state.scan_filters[permutation[i]]; - MultiFileLocalIndex local_idx(scan_filter.filter_idx); - - auto &result_vector = result.data[local_idx.GetIndex()]; - auto &child_reader = state.GetColumnReader(local_idx); - ColumnReaderInput reader_input(scan_count, define_ptr, repeat_ptr); - child_reader.Filter(reader_input, result_vector, scan_filter.filter, *scan_filter.filter_state, - state.sel, filter_count, is_first_filter); - need_to_read[local_idx.GetIndex()] = false; - is_first_filter = false; - } - adaptive_filter.EndFilter(filter_state); - } - - // we still may have to read some cols - for (idx_t i = 0; i < column_ids.size(); i++) { - MultiFileLocalIndex col_idx(i); - if (!need_to_read[col_idx]) { - continue; - } - if (filter_count == 0) { - state.GetColumnReader(col_idx).Skip(result.size()); - continue; - } - auto &result_vector = result.data[i]; - auto &child_reader = state.GetColumnReader(col_idx); - if (metadata->crypto_metadata->encryption_algorithm.__isset.AES_GCM_V1) { - child_reader.InitializeCryptoMetadata(metadata->crypto_metadata->encryption_algorithm, - GetGroup(state).ordinal); - } - ColumnReaderInput reader_input(result.size(), define_ptr, repeat_ptr); - child_reader.Select(reader_input, result_vector, state.sel, filter_count); - } - if (scan_count != filter_count) { - result.Slice(state.sel, filter_count); + auto res = ProcessFilters(state, result, scan_count, define_ptr, repeat_ptr, log_prefetch); + if 
(res.GetResultType() == AsyncResultType::BLOCKED) { + // we must wait on the payload I/O + return res; } } else { for (idx_t i = 0; i < column_ids.size(); i++) { diff --git a/src/duckdb/extension/parquet/parquet_shredding.cpp b/src/duckdb/extension/parquet/parquet_shredding.cpp index 02bea0eb9..825ce055f 100644 --- a/src/duckdb/extension/parquet/parquet_shredding.cpp +++ b/src/duckdb/extension/parquet/parquet_shredding.cpp @@ -10,7 +10,7 @@ namespace duckdb { class ClientContext; -ChildShreddingTypes::ChildShreddingTypes() : types(make_uniq>()) { +ChildShreddingTypes::ChildShreddingTypes() : types(make_uniq>()) { } ChildShreddingTypes ChildShreddingTypes::Copy() const { @@ -99,11 +99,11 @@ static ShreddingType ConvertShreddingTypeRecursive(const LogicalType &type) { throw BinderException("VARIANT can only be shredded on LIST/STRUCT/ANY/non-nested type, not %s", type.ToString()); } -void ShreddingType::AddChild(const string &name, ShreddingType &&child) { +void ShreddingType::AddChild(const Identifier &name, ShreddingType &&child) { children.types->emplace(name, std::move(child)); } -optional_ptr ShreddingType::GetChild(const string &name) const { +optional_ptr ShreddingType::GetChild(const Identifier &name) const { auto it = children.types->find(name); if (it == children.types->end()) { return nullptr; diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp index ac7898627..e789f657c 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -18,7 +18,6 @@ #include "duckdb/storage/statistics/struct_stats.hpp" #include "duckdb/storage/statistics/list_stats.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include 
"duckdb/planner/expression/bound_constant_expression.hpp" @@ -40,7 +39,6 @@ #include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/timestamp.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/statistics/geometry_stats.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" @@ -254,7 +252,7 @@ Value ParquetStatisticsUtils::ConvertValueInternal(const LogicalType &type, cons } else if (stats.size() == sizeof(int64_t)) { val = Load(stats_data); } else { - throw InvalidInputException("Incorrect stats size for type TIME"); + throw InvalidInputException("Incorrect stats size for type TIME_NS"); } switch (schema_ele.type_info) { case ParquetExtraTypeInfo::UNIT_MS: @@ -359,6 +357,10 @@ Value ParquetStatisticsUtils::ConvertValueInternal(const LogicalType &type, cons } } +bool IsVariantNull(const string &str) { + return str.size() == 1 && str[0] == '\0'; +} + static bool ConvertUnshreddedStats(BaseStatistics &result, optional_ptr input_p) { D_ASSERT(result.GetType().id() == LogicalTypeId::UINTEGER); @@ -369,14 +371,16 @@ static bool ConvertUnshreddedStats(BaseStatistics &result, optional_ptr(result, 0); NumericStats::SetMax(result, 0); @@ -436,6 +440,17 @@ static bool ConvertShreddedStats(BaseStatistics &result, optional_ptr ParquetStatisticsUtils::TransformParquetStatistics(const LogicalType &type, const ParquetColumnSchema &schema, const duckdb_parquet::Statistics &parquet_stats, bool can_have_nan, @@ -479,15 +494,23 @@ ParquetStatisticsUtils::TransformParquetStatistics(const LogicalType &type, cons case LogicalTypeId::VARCHAR: { auto string_stats = StringStats::CreateUnknown(type); const bool is_varchar = type.id() == LogicalTypeId::VARCHAR; - if (parquet_stats.__isset.min_value && StringColumnReader::IsValid(parquet_stats.min_value, is_varchar)) { - StringStats::SetMin(string_stats, parquet_stats.min_value); - } else if 
(parquet_stats.__isset.min && StringColumnReader::IsValid(parquet_stats.min, is_varchar)) { - StringStats::SetMin(string_stats, parquet_stats.min); - } - if (parquet_stats.__isset.max_value && StringColumnReader::IsValid(parquet_stats.max_value, is_varchar)) { - StringStats::SetMax(string_stats, parquet_stats.max_value); - } else if (parquet_stats.__isset.max && StringColumnReader::IsValid(parquet_stats.max, is_varchar)) { - StringStats::SetMax(string_stats, parquet_stats.max); + auto min_stats_type = parquet_stats.__isset.is_min_value_exact && parquet_stats.is_min_value_exact + ? StringStatsType::EXACT_STATS + : StringStatsType::TRUNCATED_STATS; + auto max_stats_type = parquet_stats.__isset.is_max_value_exact && parquet_stats.is_max_value_exact + ? StringStatsType::EXACT_STATS + : StringStatsType::TRUNCATED_STATS; + if (parquet_stats.__isset.min_value && + StringStatsAreValid(parquet_stats.min_value, is_varchar, min_stats_type)) { + StringStats::SetMin(string_stats, parquet_stats.min_value, min_stats_type); + } else if (parquet_stats.__isset.min && StringStatsAreValid(parquet_stats.min, is_varchar, min_stats_type)) { + StringStats::SetMin(string_stats, parquet_stats.min, min_stats_type); + } + if (parquet_stats.__isset.max_value && + StringStatsAreValid(parquet_stats.max_value, is_varchar, max_stats_type)) { + StringStats::SetMax(string_stats, parquet_stats.max_value, max_stats_type); + } else if (parquet_stats.__isset.max && StringStatsAreValid(parquet_stats.max, is_varchar, max_stats_type)) { + StringStats::SetMax(string_stats, parquet_stats.max, max_stats_type); } return string_stats.ToUnique(); } @@ -656,7 +679,7 @@ static bool HasFilterConstants(const Expression &expr) { return false; } auto &constant = right.Cast(); - return !constant.value.IsNull(); + return !constant.GetValue().IsNull(); } if (expr.GetExpressionClass() != ExpressionClass::BOUND_CONJUNCTION) { return false; @@ -671,31 +694,8 @@ static bool HasFilterConstants(const Expression &expr) { } 
static bool HasFilterConstants(const TableFilter &duckdb_filter) { - switch (duckdb_filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and_filter = duckdb_filter.Cast(); - bool child_has_constant = false; - for (auto &child_filter : conjunction_and_filter.child_filters) { - child_has_constant |= HasFilterConstants(*child_filter); - } - return child_has_constant; - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_or_filter = duckdb_filter.Cast(); - bool child_has_constant = false; - for (auto &child_filter : conjunction_or_filter.child_filters) { - child_has_constant |= HasFilterConstants(*child_filter); - } - return child_has_constant; - } - case TableFilterType::EXPRESSION_FILTER: { - auto &expr_filter = - ExpressionFilter::GetExpressionFilter(duckdb_filter, "ParquetStatistics::HasFilterConstants"); - return HasFilterConstants(*expr_filter.expr); - } - default: - return false; - } + auto &expr_filter = ExpressionFilter::GetExpressionFilter(duckdb_filter, "ParquetStatistics::HasFilterConstants"); + return HasFilterConstants(*expr_filter.expr); } template @@ -749,8 +749,8 @@ static bool ApplyBloomFilter(const Expression &expr, ParquetBloomFilter &bloom_f return false; } auto &constant = right.Cast(); - D_ASSERT(!constant.value.IsNull()); - auto hash = ValueXXH64(constant.value); + D_ASSERT(!constant.GetValue().IsNull()); + auto hash = ValueXXH64(constant.GetValue()); return hash > 0 && !bloom_filter.FilterCheck(hash); } if (expr.GetExpressionClass() != ExpressionClass::BOUND_CONJUNCTION) { @@ -775,30 +775,8 @@ static bool ApplyBloomFilter(const Expression &expr, ParquetBloomFilter &bloom_f } static bool ApplyBloomFilter(const TableFilter &duckdb_filter, ParquetBloomFilter &bloom_filter) { - switch (duckdb_filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and_filter = duckdb_filter.Cast(); - bool any_children_true = false; - for (auto &child_filter : 
conjunction_and_filter.child_filters) { - any_children_true |= ApplyBloomFilter(*child_filter, bloom_filter); - } - return any_children_true; - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_or_filter = duckdb_filter.Cast(); - bool all_children_true = true; - for (auto &child_filter : conjunction_or_filter.child_filters) { - all_children_true &= ApplyBloomFilter(*child_filter, bloom_filter); - } - return all_children_true; - } - case TableFilterType::EXPRESSION_FILTER: { - auto &expr_filter = ExpressionFilter::GetExpressionFilter(duckdb_filter, "ParquetStatistics::ApplyBloomFilter"); - return ApplyBloomFilter(*expr_filter.expr, bloom_filter); - } - default: - return false; - } + auto &expr_filter = ExpressionFilter::GetExpressionFilter(duckdb_filter, "ParquetStatistics::ApplyBloomFilter"); + return ApplyBloomFilter(*expr_filter.expr, bloom_filter); } bool ParquetStatisticsUtils::BloomFilterSupported(const LogicalTypeId &type_id) { diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 802b7b615..535d1d050 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -486,10 +486,10 @@ void ParquetWriter::InitializeColumnWriters() { column_writers.clear(); D_ASSERT(options.sql_types.size() == unique_names.size()); for (idx_t i = 0; i < options.sql_types.size(); i++) { - vector path_in_schema; + vector path_in_schema; const bool can_have_nulls = options.not_null_columns.empty() || !options.not_null_columns[i]; column_writers.push_back(ColumnWriter::CreateWriterRecursive( - context, *this, path_in_schema, types[i], unique_names[i], allow_geometry, &options.field_ids, + context, *this, path_in_schema, types[i], Identifier(unique_names[i]), allow_geometry, &options.field_ids, &options.shredding_types, 0, 1, can_have_nulls)); } } @@ -550,7 +550,7 @@ PreparedParquetLayout ParquetWriter::ExportPreparedLayout() const { if 
(!column_writer->TryExportPreparedShreddingType(column_shredding_type)) { continue; } - result.shredding_types.AddChild(column_writer->Schema().name, std::move(column_shredding_type)); + result.shredding_types.AddChild(Identifier(column_writer->Schema().name), std::move(column_shredding_type)); } return result; } @@ -1337,4 +1337,6 @@ void ParquetWriter::SetWrittenStatistics(CopyFunctionFileStatistics &written_sta //! NOTE: the actual accumulators for the writers are created after FinalizeSchema() is called } +ParquetWriteGlobalState::ParquetWriteGlobalState() = default; + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/list_column_reader.cpp b/src/duckdb/extension/parquet/reader/list_column_reader.cpp index 3ace32542..8bedf4219 100644 --- a/src/duckdb/extension/parquet/reader/list_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/list_column_reader.cpp @@ -59,7 +59,7 @@ struct TemplatedListReader { data.result_ptr[offset].length = 0; } - static void AppendVector(optional_ptr result_out, Vector &read_vector, idx_t child_idx) { + static void AppendVector(optional_ptr result_out, const Vector &read_vector, idx_t child_idx) { ListVector::Append(*result_out, read_vector, child_idx); } }; diff --git a/src/duckdb/extension/parquet/reader/row_number_column_reader.cpp b/src/duckdb/extension/parquet/reader/row_number_column_reader.cpp index 157d2ee8d..cbe2679ce 100644 --- a/src/duckdb/extension/parquet/reader/row_number_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/row_number_column_reader.cpp @@ -64,4 +64,25 @@ idx_t RowNumberColumnReader::Read(ColumnReaderInput &input, Vector &result) { return num_values; } +//===--------------------------------------------------------------------===// +// Row Group Column Reader +//===--------------------------------------------------------------------===// +RowGroupColumnReader::RowGroupColumnReader(const ParquetReader &reader, const ParquetColumnSchema &schema) + : ColumnReader(reader, 
schema) { +} + +void RowGroupColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, + TProtocol &protocol_p) { + row_group_idx = row_group_idx_p; +} + +idx_t RowGroupColumnReader::Read(ColumnReaderInput &input, Vector &result) { + auto &num_values = input.num_values; + + // the row group number is constant for all rows within a row group - emit a constant vector + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = UnsafeNumericCast(row_group_idx); + return num_values; +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp index c7f38a764..c6aff53e4 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp @@ -336,7 +336,7 @@ namespace { struct ShreddedVariantField { public: - explicit ShreddedVariantField(const string &field_name) : field_name(field_name) { + explicit ShreddedVariantField(const Identifier &field_name) : field_name(field_name.GetIdentifierName()) { } public: diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp index 76f864368..bb53ce2f2 100644 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ b/src/duckdb/extension/parquet/serialize_parquet.cpp @@ -13,22 +13,22 @@ namespace duckdb { void ChildFieldIDs::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "ids", ids.operator*()); + serializer.WritePropertyWithDefault>(100, "ids", ids.operator*()); } ChildFieldIDs ChildFieldIDs::Deserialize(Deserializer &deserializer) { ChildFieldIDs result; - deserializer.ReadPropertyWithDefault>(100, "ids", result.ids.operator*()); + deserializer.ReadPropertyWithDefault>(100, "ids", result.ids.operator*()); return result; } void 
ChildShreddingTypes::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "types", types.operator*()); + serializer.WritePropertyWithDefault>(100, "types", types.operator*()); } ChildShreddingTypes ChildShreddingTypes::Deserialize(Deserializer &deserializer) { ChildShreddingTypes result; - deserializer.ReadPropertyWithDefault>(100, "types", result.types.operator*()); + deserializer.ReadPropertyWithDefault>(100, "types", result.types.operator*()); return result; } @@ -85,6 +85,7 @@ void ParquetOptionsSerialization::Serialize(Serializer &serializer) const { /* [Deleted] (bool) "parquet_options.debug_use_openssl" */ serializer.WritePropertyWithDefault(106, "explicit_cardinality", parquet_options.explicit_cardinality, 0); serializer.WritePropertyWithDefault(107, "can_have_nan", parquet_options.can_have_nan, false); + serializer.WritePropertyWithDefault(108, "prefetch_strategy", parquet_options.prefetch_strategy, ParquetPrefetchStrategyOption::AUTO); } ParquetOptionsSerialization ParquetOptionsSerialization::Deserialize(Deserializer &deserializer) { @@ -97,6 +98,7 @@ ParquetOptionsSerialization ParquetOptionsSerialization::Deserialize(Deserialize deserializer.ReadDeletedProperty(105, "debug_use_openssl"); deserializer.ReadPropertyWithExplicitDefault(106, "explicit_cardinality", result.parquet_options.explicit_cardinality, 0); deserializer.ReadPropertyWithExplicitDefault(107, "can_have_nan", result.parquet_options.can_have_nan, false); + deserializer.ReadPropertyWithExplicitDefault(108, "prefetch_strategy", result.parquet_options.prefetch_strategy, ParquetPrefetchStrategyOption::AUTO); return result; } diff --git a/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp b/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp index 30e3f997a..3929b2981 100644 --- a/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/boolean_column_writer.cpp @@ -47,7 +47,7 @@ class 
BooleanWriterPageState : public ColumnWriterPageState { }; BooleanColumnWriter::BooleanColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, - vector schema_path_p) + vector schema_path_p) : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } diff --git a/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp b/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp index 0038179f4..f6ff05eb1 100644 --- a/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/decimal_column_writer.cpp @@ -82,7 +82,7 @@ class FixedDecimalStatistics : public ColumnWriterStatistics { }; FixedDecimalColumnWriter::FixedDecimalColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, - vector schema_path_p) + vector schema_path_p) : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { } diff --git a/src/duckdb/extension/parquet/writer/enum_column_writer.cpp b/src/duckdb/extension/parquet/writer/enum_column_writer.cpp index c9044c408..b030cc307 100644 --- a/src/duckdb/extension/parquet/writer/enum_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/enum_column_writer.cpp @@ -22,6 +22,18 @@ class Vector; using duckdb_parquet::Encoding; +namespace { +class EnumStatisticsState : public StringStatisticsState { +public: + bool MinIsExact() override { + return false; + } + bool MaxIsExact() override { + return false; + } +}; +} // namespace + class EnumWriterPageState : public ColumnWriterPageState { public: explicit EnumWriterPageState(uint32_t bit_width) : encoder(bit_width), written_value(false) { @@ -32,13 +44,14 @@ class EnumWriterPageState : public ColumnWriterPageState { }; EnumColumnWriter::EnumColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, - vector schema_path_p) + vector schema_path_p) : PrimitiveColumnWriter(writer, std::move(column_schema), std::move(schema_path_p)) { bit_width = 
RleBpDecoder::ComputeBitWidthFromValueCount(EnumType::GetSize(Type())); + seen_enum.resize(EnumType::GetSize(Type()), false); } unique_ptr EnumColumnWriter::InitializeStatsState() { - return make_uniq(); + return make_uniq(); } template @@ -55,6 +68,7 @@ void EnumColumnWriter::WriteEnumInternal(WriteStream &temp_writer, Vector &input page_state.written_value = true; } page_state.encoder.WriteValue(temp_writer, ptr[r]); + seen_enum[ptr[r]] = true; } } } @@ -108,7 +122,6 @@ idx_t EnumColumnWriter::DictionarySize(PrimitiveColumnWriterState &state_p) { void EnumColumnWriter::FlushDictionary(PrimitiveColumnWriterState &state, ColumnWriterStatistics *stats_p) { auto &stats = stats_p->Cast(); - // write the enum values to a dictionary page auto &enum_values = EnumType::GetValuesInsertOrder(Type()); auto enum_count = EnumType::GetSize(Type()); auto string_values = FlatVector::GetData(enum_values); @@ -116,11 +129,13 @@ void EnumColumnWriter::FlushDictionary(PrimitiveColumnWriterState &state, Column auto temp_writer = make_uniq(BufferAllocator::Get(writer.GetContext())); for (idx_t r = 0; r < enum_count; r++) { D_ASSERT(!FlatVector::IsNull(enum_values, r)); - // update the statistics - stats.Update(string_values[r]); - // write this string value to the dictionary - temp_writer->Write(string_values[r].GetSize()); - temp_writer->WriteData(const_data_ptr_cast(string_values[r].GetData()), string_values[r].GetSize()); + if (seen_enum[r]) { + stats.Update(string_values[r]); + temp_writer->Write(string_values[r].GetSize()); + temp_writer->WriteData(const_data_ptr_cast(string_values[r].GetData()), string_values[r].GetSize()); + } else { + temp_writer->Write(0); + } } // flush the dictionary page and add it to the to-be-written pages WriteDictionary(state, std::move(temp_writer), enum_count); diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index dd3b696e0..b4e52d028 100644 --- 
a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -31,7 +31,7 @@ constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_PAGE_SIZE; constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, - vector schema_path) + vector schema_path) : ColumnWriter(writer, std::move(column_schema), std::move(schema_path)) { } @@ -45,7 +45,7 @@ void PrimitiveColumnWriter::RegisterToRowGroup(duckdb_parquet::RowGroup &row_gro duckdb_parquet::ColumnChunk column_chunk; column_chunk.__isset.meta_data = true; column_chunk.meta_data.codec = writer.GetCodec(); - column_chunk.meta_data.path_in_schema = schema_path; + column_chunk.meta_data.path_in_schema = IdentifiersToStrings(schema_path); column_chunk.meta_data.num_values = 0; column_chunk.meta_data.type = writer.GetType(SchemaIndex()); row_group.columns.push_back(std::move(column_chunk)); @@ -287,6 +287,29 @@ void PrimitiveColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, id void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &state, duckdb_parquet::ColumnChunk &column_chunk) { + auto add_encoding = [&](duckdb_parquet::Encoding::type encoding) { + for (const auto &existing_encoding : column_chunk.meta_data.encodings) { + if (existing_encoding == encoding) { + return; + } + } + column_chunk.meta_data.encodings.push_back(encoding); + }; + + for (const auto &write_info : state.write_info) { + // only care about data page encodings, data_page_header.encoding is meaningless for dict + switch (write_info.page_header.type) { + case PageType::DATA_PAGE: + add_encoding(write_info.page_header.data_page_header.encoding); + break; + case PageType::DATA_PAGE_V2: + add_encoding(write_info.page_header.data_page_header_v2.encoding); + break; + default: + break; + } + } + if (!state.stats_state) { 
return; } @@ -351,15 +374,6 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta *state.stats_state->GetGeoStats(), gpq_version); } } - - for (const auto &write_info : state.write_info) { - // only care about data page encodings, data_page_header.encoding is meaningless for dict - if (write_info.page_header.type != PageType::DATA_PAGE && - write_info.page_header.type != PageType::DATA_PAGE_V2) { - continue; - } - column_chunk.meta_data.encodings.push_back(write_info.page_header.data_page_header.encoding); - } } void PrimitiveColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { diff --git a/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp b/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp index 04acaec50..325057ae3 100644 --- a/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp +++ b/src/duckdb/extension/parquet/writer/variant/analyze_variant.cpp @@ -212,6 +212,9 @@ void VariantColumnWriter::AnalyzeSchemaFinalize(const ParquetAnalyzeSchemaState LogicalType shredded_type; if (!ConstructShreddedType(state.analyze_data, shredded_type)) { //! Can't shred, keep the original children + //! Mark as analyzed to prevent re-analysis from modifying child_writers + //! 
after InitializeSchemaElements has already locked in the schema + is_analyzed = true; return; } is_analyzed = true; diff --git a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp index 7d43331cc..7ca8444fe 100644 --- a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp +++ b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp @@ -565,9 +565,18 @@ static void WritePrimitiveValueData(const UnifiedVariantVectorData &variant, idx case VariantLogicalType::DECIMAL: { auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); - if (decimal_data.width <= 4 || decimal_data.width > 38) { + if (decimal_data.width > 38) { throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, decimal_data.scale); + } else if (decimal_data.width <= 4) { + // DuckDB uses INT16 to store small decimals, but parquet only supports DECIMAL4 at minimum, so here we + // promote to INT32. 
+ WritePrimitiveTypeHeader(value_data); + Store(NumericCast(decimal_data.scale), value_data); + value_data++; + const int32_t promoted = Load(decimal_data.value_ptr); + Store(promoted, value_data); + value_data += sizeof(int32_t); } else if (decimal_data.width <= 9) { WritePrimitiveTypeHeader(value_data); Store(NumericCast(decimal_data.scale), value_data); @@ -900,7 +909,7 @@ static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &r // - metadata = BLOB // - value = BLOB - auto &variant_vec = input.data[0]; + const auto &variant_vec = input.data[0]; auto count = input.size(); RecursiveUnifiedVectorFormat recursive_format; diff --git a/src/duckdb/extension/parquet/writer/variant_column_writer.cpp b/src/duckdb/extension/parquet/writer/variant_column_writer.cpp new file mode 100644 index 000000000..020ffb6b9 --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant_column_writer.cpp @@ -0,0 +1,8 @@ +#include "writer/variant_column_writer.hpp" + +namespace duckdb { + +VariantAnalyzeData::VariantAnalyzeData() = default; +VariantAnalyzeData::~VariantAnalyzeData() = default; + +} // namespace duckdb diff --git a/src/duckdb/generated_extension_loader_package_build.cpp b/src/duckdb/generated_extension_loader_package_build.cpp index 95c8dad03..8f6906398 100644 --- a/src/duckdb/generated_extension_loader_package_build.cpp +++ b/src/duckdb/generated_extension_loader_package_build.cpp @@ -1,8 +1,31 @@ +#ifndef DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED +#define DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED 1 +#endif + +#ifndef DUCKDB_EXTENSION_PARQUET_LINKED +#define DUCKDB_EXTENSION_PARQUET_LINKED 1 +#endif + +#ifndef DUCKDB_EXTENSION_ICU_LINKED +#define DUCKDB_EXTENSION_ICU_LINKED 1 +#endif + +#ifndef DUCKDB_EXTENSION_JSON_LINKED +#define DUCKDB_EXTENSION_JSON_LINKED 1 +#endif + +#if DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED #include "core_functions_extension.hpp" +#endif +#if DUCKDB_EXTENSION_PARQUET_LINKED #include "parquet_extension.hpp" +#endif +#if 
DUCKDB_EXTENSION_ICU_LINKED #include "icu_extension.hpp" +#endif +#if DUCKDB_EXTENSION_JSON_LINKED #include "json_extension.hpp" -#include "jemalloc_extension.hpp" +#endif #include "duckdb/main/extension/generated_extension_loader.hpp" #include "duckdb/main/extension_helper.hpp" @@ -10,37 +33,48 @@ namespace duckdb { //! Looks through the package_build.py-generated list of extensions that are linked into DuckDB currently to try load ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string &extension) { - +#if DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED if (extension=="core_functions") { db.LoadStaticExtension(); return ExtensionLoadResult::LOADED_EXTENSION; } - +#endif +#if DUCKDB_EXTENSION_PARQUET_LINKED if (extension=="parquet") { db.LoadStaticExtension(); return ExtensionLoadResult::LOADED_EXTENSION; } - +#endif +#if DUCKDB_EXTENSION_ICU_LINKED if (extension=="icu") { db.LoadStaticExtension(); return ExtensionLoadResult::LOADED_EXTENSION; } - +#endif +#if DUCKDB_EXTENSION_JSON_LINKED if (extension=="json") { db.LoadStaticExtension(); return ExtensionLoadResult::LOADED_EXTENSION; } - - if (extension=="jemalloc") { - db.LoadStaticExtension(); - return ExtensionLoadResult::LOADED_EXTENSION; - } - +#endif + return ExtensionLoadResult::NOT_LOADED; } vector LinkedExtensions(){ - vector VEC = {"core_functions", "parquet", "icu", "json", "jemalloc" + vector VEC = { +#if DUCKDB_EXTENSION_CORE_FUNCTIONS_LINKED + "core_functions", +#endif +#if DUCKDB_EXTENSION_PARQUET_LINKED + "parquet", +#endif +#if DUCKDB_EXTENSION_ICU_LINKED + "icu", +#endif +#if DUCKDB_EXTENSION_JSON_LINKED + "json", +#endif }; return VEC; } diff --git a/src/duckdb/src/catalog/catalog.cpp b/src/duckdb/src/catalog/catalog.cpp index b66f0452d..50f5920ff 100644 --- a/src/duckdb/src/catalog/catalog.cpp +++ b/src/duckdb/src/catalog/catalog.cpp @@ -63,7 +63,7 @@ const AttachedDatabase &Catalog::GetAttached() const { return db; } -const string &Catalog::GetName() const { +const Identifier 
&Catalog::GetName() const { return GetAttached().GetName(); } @@ -75,11 +75,11 @@ Catalog &Catalog::GetSystemCatalog(ClientContext &context) { return Catalog::GetSystemCatalog(*context.db); } -const string &GetDefaultCatalog(CatalogEntryRetriever &retriever) { +Identifier GetDefaultCatalog(CatalogEntryRetriever &retriever) { return DatabaseManager::GetDefaultDatabase(retriever.GetContext()); } -optional_ptr Catalog::GetCatalogEntry(CatalogEntryRetriever &retriever, const string &catalog_name) { +optional_ptr Catalog::GetCatalogEntry(CatalogEntryRetriever &retriever, const Identifier &catalog_name) { auto &context = retriever.GetContext(); auto &db_manager = DatabaseManager::Get(context); if (catalog_name == TEMP_CATALOG) { @@ -96,20 +96,20 @@ optional_ptr Catalog::GetCatalogEntry(CatalogEntryRetriever &retriever, return &entry->GetCatalog(); } -optional_ptr Catalog::GetCatalogEntry(ClientContext &context, const string &catalog_name) { +optional_ptr Catalog::GetCatalogEntry(ClientContext &context, const Identifier &catalog_name) { CatalogEntryRetriever entry_retriever(context); return GetCatalogEntry(entry_retriever, catalog_name); } -Catalog &Catalog::GetCatalog(CatalogEntryRetriever &retriever, const string &catalog_name) { +Catalog &Catalog::GetCatalog(CatalogEntryRetriever &retriever, const Identifier &catalog_name) { auto catalog = Catalog::GetCatalogEntry(retriever, catalog_name); if (!catalog) { - throw BinderException("Catalog \"%s\" does not exist!", catalog_name); + throw BinderException("Catalog \"%s\" does not exist!", catalog_name.GetIdentifierName()); } return *catalog; } -Catalog &Catalog::GetCatalog(ClientContext &context, const string &catalog_name) { +Catalog &Catalog::GetCatalog(ClientContext &context, const Identifier &catalog_name) { CatalogEntryRetriever entry_retriever(context); return GetCatalog(entry_retriever, catalog_name); } @@ -345,20 +345,44 @@ unique_ptr Catalog::BindAlterAddIndex(Binder &binder, TableCata throw 
NotImplementedException("BindAlterAddIndex not supported by this catalog"); } +unique_ptr Catalog::RemoteExecute(ClientContext &context, unique_ptr node) { + throw NotImplementedException("RemoteExecute(QueryNode) not supported by this catalog"); +} + +unique_ptr Catalog::RemoteExecute(ClientContext &context, const string &sql) { + throw NotImplementedException("RemoteExecute(string) not supported by this catalog"); +} + +bool Catalog::SupportsPushdown(const ParsedExpression &expression) { + return true; +} + +bool Catalog::SupportsPushdown(const TableRef &ref) { + return true; +} + +bool Catalog::SupportsPushdown(const QueryNode &node) { + return true; +} + +string Catalog::GetConnectDisplay() { + return GetAttached().GetName().GetIdentifierName(); +} + //===--------------------------------------------------------------------===// // Lookup Structures //===--------------------------------------------------------------------===// struct CatalogLookup { - CatalogLookup(Catalog &catalog, CatalogType catalog_type, string schema_p, string name_p) + CatalogLookup(Catalog &catalog, CatalogType catalog_type, Identifier schema_p, Identifier name_p) : catalog(catalog), schema(std::move(schema_p)), name(std::move(name_p)), lookup_info(catalog_type, name) { } - CatalogLookup(Catalog &catalog, string schema_p, const EntryLookupInfo &lookup_p) + CatalogLookup(Catalog &catalog, Identifier schema_p, const EntryLookupInfo &lookup_p) : catalog(catalog), schema(std::move(schema_p)), name(lookup_p.GetEntryName()), lookup_info(lookup_p, name) { } Catalog &catalog; - string schema; - string name; + Identifier schema; + Identifier name; EntryLookupInfo lookup_info; }; @@ -374,7 +398,7 @@ void Catalog::DropEntry(ClientContext &context, DropInfo &info) { CatalogEntryRetriever retriever(context); EntryLookupInfo lookup_info(info.type, info.name); - auto lookup = LookupEntry(retriever, info.schema, lookup_info, info.if_not_found); + auto lookup = LookupEntry(retriever, 
info.schema.GetIdentifierName(), lookup_info, info.if_not_found); if (!lookup.Found()) { return; } @@ -382,11 +406,11 @@ void Catalog::DropEntry(ClientContext &context, DropInfo &info) { lookup.schema->DropEntry(context, info); } -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &schema) { +SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const Identifier &schema) { return GetSchema(GetCatalogTransaction(context), schema); } -SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const string &schema) { +SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const Identifier &schema) { EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema); return GetSchema(transaction, schema_lookup); } @@ -395,16 +419,17 @@ SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const EntryLookup return *Catalog::GetSchema(context, schema_lookup, OnEntryNotFound::THROW_EXCEPTION); } -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &catalog_name, const string &schema) { +SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema) { return *GetSchema(context, catalog_name, schema, OnEntryNotFound::THROW_EXCEPTION); } -optional_ptr Catalog::GetSchema(ClientContext &context, const string &schema, +optional_ptr Catalog::GetSchema(ClientContext &context, const Identifier &schema, OnEntryNotFound if_not_found) { return GetSchema(GetCatalogTransaction(context), schema, if_not_found); } -optional_ptr Catalog::GetSchema(CatalogTransaction transaction, const string &schema, +optional_ptr Catalog::GetSchema(CatalogTransaction transaction, const Identifier &schema, OnEntryNotFound if_not_found) { EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema); return LookupSchema(transaction, schema_lookup, if_not_found); @@ -415,13 +440,13 @@ optional_ptr Catalog::GetSchema(ClientContext &context, cons return 
LookupSchema(GetCatalogTransaction(context), schema_lookup, if_not_found); } -optional_ptr Catalog::GetSchema(ClientContext &context, const string &catalog_name, - const string &schema, OnEntryNotFound if_not_found) { +optional_ptr Catalog::GetSchema(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema, OnEntryNotFound if_not_found) { EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema); return GetSchema(context, catalog_name, schema_lookup, if_not_found); } -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &catalog_name, +SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const Identifier &catalog_name, const EntryLookupInfo &schema_lookup) { return *Catalog::GetSchema(context, catalog_name, schema_lookup, OnEntryNotFound::THROW_EXCEPTION); } @@ -430,7 +455,11 @@ SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const Ent return *LookupSchema(transaction, schema_lookup, OnEntryNotFound::THROW_EXCEPTION); } -bool Catalog::CheckAmbiguousCatalogOrSchema(ClientContext &context, const string &schema) { +bool Catalog::CheckAmbiguousCatalogOrSchema(ClientContext &context, const Identifier &schema) { + if (Supports(RemoteCapability::IS_REMOTE)) { + // skip this check for remote catalogs + return false; + } EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema); return !!GetSchema(context, schema_lookup, OnEntryNotFound::RETURN_NULL); } @@ -461,8 +490,8 @@ vector Catalog::SimilarEntriesInSchemas(ClientContext &cont return results; } -vector GetCatalogEntries(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema) { +vector GetCatalogEntries(CatalogEntryRetriever &retriever, const Identifier &catalog, + const Identifier &schema) { auto &context = retriever.GetContext(); vector entries; auto &search_path = retriever.GetSearchPath(); @@ -490,7 +519,7 @@ vector GetCatalogEntries(CatalogEntryRetriever &retriever, c if (entries.empty()) { auto 
catalog_entry = Catalog::GetCatalogEntry(context, catalog); if (catalog_entry) { - entries.emplace_back(catalog, catalog_entry->GetDefaultSchema()); + entries.emplace_back(catalog, Identifier(catalog_entry->GetDefaultSchema())); } else { entries.emplace_back(catalog, DEFAULT_SCHEMA); } @@ -502,11 +531,11 @@ vector GetCatalogEntries(CatalogEntryRetriever &retriever, c return entries; } -void FindMinimalQualification(CatalogEntryRetriever &retriever, const string &catalog_name, const string &schema_name, - bool &qualify_database, bool &qualify_schema) { +void FindMinimalQualification(CatalogEntryRetriever &retriever, const Identifier &catalog_name, + const Identifier &schema_name, bool &qualify_database, bool &qualify_schema) { // check if we can we qualify ONLY the schema bool found = false; - auto entries = GetCatalogEntries(retriever, INVALID_CATALOG, schema_name); + auto entries = GetCatalogEntries(retriever, Identifier::InvalidCatalog(), schema_name); for (auto &entry : entries) { if (entry.catalog == catalog_name && entry.schema == schema_name) { found = true; @@ -520,7 +549,7 @@ void FindMinimalQualification(CatalogEntryRetriever &retriever, const string &ca } // check if we can qualify ONLY the catalog found = false; - entries = GetCatalogEntries(retriever, catalog_name, INVALID_SCHEMA); + entries = GetCatalogEntries(retriever, catalog_name, Identifier::InvalidSchema()); for (auto &entry : entries) { if (entry.catalog == catalog_name && entry.schema == schema_name) { found = true; @@ -663,7 +692,7 @@ CatalogException Catalog::UnrecognizedConfigurationError(ClientContext &context, for (auto &entry : DBConfig::GetConfig(context).GetExtensionSettings()) { potential_names.push_back(entry.first); } - throw CatalogException::MissingEntry("configuration parameter", name, potential_names); + throw CatalogException::MissingEntry("configuration parameter", Identifier(name), potential_names); } CatalogException Catalog::CreateMissingEntryException(CatalogEntryRetriever 
&retriever, @@ -770,7 +799,7 @@ CatalogException Catalog::CreateMissingEntryException(CatalogEntryRetriever &ret } } else if (!entries.empty()) { for (auto &entry : entries) { - suggestions.insert(entry.name); + suggestions.insert(entry.name.GetIdentifierName()); } } @@ -791,7 +820,7 @@ CatalogEntryLookup Catalog::TryLookupEntryInternal(CatalogTransaction transactio if (lookup_info.GetAtClause() && !SupportsTimeTravel()) { return {nullptr, nullptr, ErrorData(BinderException("Catalog type does not support time travel"))}; } - auto schema_lookup = EntryLookupInfo::SchemaLookup(lookup_info, schema); + auto schema_lookup = EntryLookupInfo::SchemaLookup(lookup_info, Identifier(schema)); auto schema_entry = LookupSchema(transaction, schema_lookup, OnEntryNotFound::RETURN_NULL); if (!schema_entry) { return {nullptr, nullptr, ErrorData()}; @@ -809,11 +838,11 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con reference_set_t schemas; if (IsInvalidSchema(schema)) { // try all schemas for this catalog - auto entries = GetCatalogEntries(retriever, GetName(), INVALID_SCHEMA); + auto entries = GetCatalogEntries(retriever, GetName(), Identifier::InvalidSchema()); for (auto &entry : entries) { auto &candidate_schema = entry.schema; auto transaction = GetCatalogTransaction(context); - auto result = TryLookupEntryInternal(transaction, candidate_schema, lookup_info); + auto result = TryLookupEntryInternal(transaction, candidate_schema.GetIdentifierName(), lookup_info); if (result.Found()) { return result; } @@ -837,7 +866,7 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con } // Check if the default database is actually attached. CreateMissingEntryException will throw binder exception // otherwise. 
- if (!GetCatalogEntry(context, GetDefaultCatalog(retriever))) { + if (!GetCatalogEntry(context, Identifier(GetDefaultCatalog(retriever)))) { auto except = CatalogException("%s with name %s does not exist!", CatalogTypeToString(lookup_info.GetCatalogType()), lookup_info.GetEntryName()); return {nullptr, nullptr, ErrorData(except)}; @@ -886,7 +915,8 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con CatalogEntryLookup result; for (auto &lookup : lookups) { auto transaction = lookup.catalog.GetCatalogTransaction(context); - result = lookup.catalog.TryLookupEntryInternal(transaction, lookup.schema, lookup.lookup_info); + result = + lookup.catalog.TryLookupEntryInternal(transaction, lookup.schema.GetIdentifierName(), lookup.lookup_info); if (result.Found()) { break; } @@ -933,7 +963,7 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con // If we have a specific schema name and no schemas were found, the schema doesn't exist. // Throw an error about the schema instead of the table if (schemas.empty() && !lookups.empty() && lookup_info.GetCatalogType() == CatalogType::TABLE_ENTRY) { - string schema_name = lookups[0].schema; + auto &schema_name = lookups[0].schema; if (!IsInvalidSchema(schema_name)) { EntryLookupInfo schema_lookup(CatalogType::SCHEMA_ENTRY, schema_name, lookup_info.GetErrorContext()); string relation_name = schema_name + "." + lookup_info.GetEntryName(); @@ -947,7 +977,7 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con // Check if the default database is actually attached. CreateMissingEntryException will throw binder exception // otherwise. 
- if (!GetCatalogEntry(context, GetDefaultCatalog(retriever))) { + if (!GetCatalogEntry(context, Identifier(GetDefaultCatalog(retriever)))) { auto except = CatalogException("%s with name %s does not exist!", CatalogTypeToString(lookup_info.GetCatalogType()), lookup_info.GetEntryName()); return {nullptr, nullptr, ErrorData(except)}; @@ -959,7 +989,7 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con CatalogEntryLookup Catalog::TryLookupDefaultTable(CatalogEntryRetriever &retriever, const EntryLookupInfo &lookup_info, bool allow_ignore_at_clause) { - auto catalog_by_name = GetCatalogEntry(retriever, lookup_info.GetEntryName()); + auto catalog_by_name = GetCatalogEntry(retriever, lookup_info.GetEntryIdentifier()); if (catalog_by_name && catalog_by_name->HasDefaultTable()) { auto transaction = catalog_by_name->GetCatalogTransaction(retriever.GetContext()); @@ -975,15 +1005,15 @@ CatalogEntryLookup Catalog::TryLookupDefaultTable(CatalogEntryRetriever &retriev at_clause = lookup_info.GetAtClause(); } - EntryLookupInfo info = EntryLookupInfo(CatalogType::TABLE_ENTRY, table_name, at_clause, context); + EntryLookupInfo info = EntryLookupInfo(CatalogType::TABLE_ENTRY, Identifier(table_name), at_clause, context); return catalog_by_name->TryLookupEntryInternal(transaction, table_schema, info); } return {nullptr, nullptr, ErrorData()}; } -CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema, const EntryLookupInfo &lookup_info, +CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found) { auto entries = GetCatalogEntries(retriever, catalog, schema); vector lookups; @@ -1017,33 +1047,33 @@ CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, con return TryLookupEntry(retriever, lookups, lookup_info, if_not_found, 
allow_default_table_lookup); } -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const string &catalog_name, - const string &schema_name, const string &name) { +CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const Identifier &catalog_name, + const Identifier &schema_name, const Identifier &name) { EntryLookupInfo lookup_info(catalog_type, name); return GetEntry(context, catalog_name, schema_name, lookup_info); } -optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const string &schema, - const string &name, OnEntryNotFound if_not_found) { +optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const Identifier &schema, + const Identifier &name, OnEntryNotFound if_not_found) { EntryLookupInfo lookup_info(catalog_type, name); return GetEntry(context, schema, lookup_info, if_not_found); } -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const string &schema_name, - const string &name) { +CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType catalog_type, const Identifier &schema_name, + const Identifier &name) { EntryLookupInfo lookup_info(catalog_type, name); return GetEntry(context, schema_name, lookup_info); } -optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, const string &schema_name, +optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, const Identifier &schema_name, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found) { - auto lookup_entry = TryLookupEntry(retriever, schema_name, lookup_info, if_not_found); + auto lookup_entry = TryLookupEntry(retriever, schema_name.GetIdentifierName(), lookup_info, if_not_found); // Try autoloading extension to resolve lookup if (!lookup_entry.Found()) { if (AutoLoadExtensionByCatalogEntry(*retriever.GetContext().db, lookup_info.GetCatalogType(), lookup_info.GetEntryName())) { - lookup_entry = TryLookupEntry(retriever, 
schema_name, lookup_info, if_not_found); + lookup_entry = TryLookupEntry(retriever, schema_name.GetIdentifierName(), lookup_info, if_not_found); } } @@ -1054,18 +1084,18 @@ optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, c return lookup_entry.entry.get(); } -optional_ptr Catalog::GetEntry(ClientContext &context, const string &schema_name, +optional_ptr Catalog::GetEntry(ClientContext &context, const Identifier &schema_name, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found) { CatalogEntryRetriever retriever(context); return GetEntry(retriever, schema_name, lookup_info, if_not_found); } -CatalogEntry &Catalog::GetEntry(ClientContext &context, const string &schema, const EntryLookupInfo &lookup_info) { +CatalogEntry &Catalog::GetEntry(ClientContext &context, const Identifier &schema, const EntryLookupInfo &lookup_info) { return *Catalog::GetEntry(context, schema, lookup_info, OnEntryNotFound::THROW_EXCEPTION); } -optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema, const EntryLookupInfo &lookup_info, +optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found) { auto result = TryLookupEntry(retriever, catalog, schema, lookup_info, if_not_found); @@ -1087,21 +1117,22 @@ optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, c } return result.entry.get(); } -optional_ptr Catalog::GetEntry(ClientContext &context, const string &catalog, const string &schema, - const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found) { +optional_ptr Catalog::GetEntry(ClientContext &context, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, + OnEntryNotFound if_not_found) { CatalogEntryRetriever retriever(context); return GetEntry(retriever, catalog, schema, lookup_info, if_not_found); } -CatalogEntry 
&Catalog::GetEntry(ClientContext &context, const string &catalog, const string &schema, +CatalogEntry &Catalog::GetEntry(ClientContext &context, const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info) { return *Catalog::GetEntry(context, catalog, schema, lookup_info, OnEntryNotFound::THROW_EXCEPTION); } -optional_ptr Catalog::GetSchema(CatalogEntryRetriever &retriever, const string &catalog_name, +optional_ptr Catalog::GetSchema(CatalogEntryRetriever &retriever, const Identifier &catalog_name, const EntryLookupInfo &schema_lookup, OnEntryNotFound if_not_found) { - auto entries = GetCatalogEntries(retriever, catalog_name, schema_lookup.GetEntryName()); + auto entries = GetCatalogEntries(retriever, catalog_name, schema_lookup.GetEntryIdentifier()); for (idx_t i = 0; i < entries.size(); i++) { auto catalog = Catalog::GetCatalogEntry(retriever, entries[i].catalog); if (!catalog) { @@ -1116,12 +1147,13 @@ optional_ptr Catalog::GetSchema(CatalogEntryRetriever &retri } // Catalog has not been found. 
if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw CatalogException(schema_lookup.GetErrorContext(), "Catalog with name %s does not exist!", catalog_name); + throw CatalogException(schema_lookup.GetErrorContext(), "Catalog with name %s does not exist!", + catalog_name.GetIdentifierName()); } return nullptr; } -optional_ptr Catalog::GetSchema(ClientContext &context, const string &catalog_name, +optional_ptr Catalog::GetSchema(ClientContext &context, const Identifier &catalog_name, const EntryLookupInfo &schema_lookup, OnEntryNotFound if_not_found) { CatalogEntryRetriever retriever(context); @@ -1150,7 +1182,7 @@ vector> Catalog::GetSchemas(CatalogEntryRetriever catalogs.push_back(catalog); } } else { - catalogs.push_back(Catalog::GetCatalog(retriever, catalog_name)); + catalogs.push_back(Catalog::GetCatalog(retriever, Identifier(catalog_name))); } vector> result; for (auto catalog : catalogs) { @@ -1208,7 +1240,7 @@ void Catalog::Alter(CatalogTransaction transaction, AlterInfo &info) { if (transaction.HasContext()) { CatalogEntryRetriever retriever(transaction.GetContext()); EntryLookupInfo lookup_info(info.GetCatalogType(), info.name); - auto lookup = LookupEntry(retriever, info.schema, lookup_info, info.if_not_found); + auto lookup = LookupEntry(retriever, info.schema.GetIdentifierName(), lookup_info, info.if_not_found); if (!lookup.Found()) { return; } @@ -1259,9 +1291,9 @@ bool Catalog::HasDefaultTable() const { return !default_table.empty(); } -void Catalog::SetDefaultTable(const string &schema, const string &name) { - default_table = name; - default_table_schema = schema; +void Catalog::SetDefaultTable(const Identifier &schema, const Identifier &name) { + default_table = name.GetIdentifierName(); + default_table_schema = schema.GetIdentifierName(); } string Catalog::GetDefaultTable() const { @@ -1295,7 +1327,10 @@ void Catalog::OnDetach(ClientContext &context) { bool Catalog::HasConflictingAttachOptions(const string &path, const AttachOptions 
&options) { auto const db_type = options.db_type.empty() ? "duckdb" : options.db_type; - return GetDBPath() != path || GetCatalogType() != db_type; + // Normalize through the extension alias table so that equivalent forms + auto canonical_actual = ExtensionHelper::ApplyExtensionAlias(GetCatalogType()); + auto canonical_requested = ExtensionHelper::ApplyExtensionAlias(db_type); + return GetDBPath() != path || !StringUtil::CIEquals(canonical_actual, canonical_requested); } } // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp index 8fca4a954..d54d28510 100644 --- a/src/duckdb/src/catalog/catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry.cpp @@ -9,12 +9,12 @@ namespace duckdb { -CatalogEntry::CatalogEntry(CatalogType type, string name_p, idx_t oid) +CatalogEntry::CatalogEntry(CatalogType type, Identifier name_p, idx_t oid) : oid(oid), type(type), set(nullptr), name(std::move(name_p)), deleted(false), temporary(false), internal(false), parent(nullptr) { } -CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, string name_p) +CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, Identifier name_p) : CatalogEntry(type, std::move(name_p), catalog.GetDatabase().GetDatabaseManager().NextOid()) { } @@ -54,13 +54,13 @@ string CatalogEntry::ToSQL() const { void CatalogEntry::SetChild(unique_ptr child_p) { child = std::move(child_p); if (child) { - child->parent = this; + child->parent.store(this); } } unique_ptr CatalogEntry::TakeChild() { if (child) { - child->parent = nullptr; + child->parent.store(nullptr); } return std::move(child); } @@ -69,7 +69,7 @@ bool CatalogEntry::HasChild() const { return child != nullptr; } bool CatalogEntry::HasParent() const { - return parent != nullptr; + return parent.load() != nullptr; } CatalogEntry &CatalogEntry::Child() { @@ -77,7 +77,11 @@ CatalogEntry &CatalogEntry::Child() { } CatalogEntry &CatalogEntry::Parent() { - return *parent; + return 
*parent.load(); +} + +const CatalogEntry &CatalogEntry::Parent() const { + return *parent.load(); } Catalog &CatalogEntry::ParentCatalog() { @@ -115,7 +119,7 @@ void CatalogEntry::Rollback(CatalogEntry &prev_entry) { void CatalogEntry::OnDrop() { } -InCatalogEntry::InCatalogEntry(CatalogType type, Catalog &catalog, string name) +InCatalogEntry::InCatalogEntry(CatalogType type, Catalog &catalog, Identifier name) : CatalogEntry(type, catalog, std::move(name)), catalog(catalog) { } diff --git a/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp b/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp index b8b2f0e87..c1741e9f7 100644 --- a/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp +++ b/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp @@ -18,10 +18,10 @@ void ColumnDependencyManager::AddGeneratedColumn(const ColumnDefinition &column, column.GetListOfDependencies(referenced_columns); vector indices; for (auto &col : referenced_columns) { - if (!list.ColumnExists(col)) { + if (!list.ColumnExists(Identifier(col))) { throw BinderException("Column \"%s\" referenced by generated column does not exist", col); } - auto &entry = list.GetColumn(col); + auto &entry = list.GetColumn(Identifier(col)); indices.push_back(entry.Logical()); } return AddGeneratedColumn(column.Logical(), indices); diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp index f91d6569f..5f81bab17 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp @@ -7,7 +7,7 @@ namespace duckdb { -IndexDataTableInfo::IndexDataTableInfo(shared_ptr info_p, const string &index_name_p) +IndexDataTableInfo::IndexDataTableInfo(shared_ptr info_p, const Identifier &index_name_p) : info(std::move(info_p)), index_name(index_name_p) { } @@ -44,11 +44,11 @@ unique_ptr DuckIndexEntry::Copy(ClientContext 
&context) const { return std::move(result); } -string DuckIndexEntry::GetSchemaName() const { +Identifier DuckIndexEntry::GetSchemaName() const { return GetDataTableInfo().GetSchemaName(); } -string DuckIndexEntry::GetTableName() const { +Identifier DuckIndexEntry::GetTableName() const { return GetDataTableInfo().GetTableName(); } diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp index 5609da548..cc0fe3afb 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp @@ -255,7 +255,7 @@ optional_ptr DuckSchemaEntry::CreateIndex(CatalogTransaction trans // currently, we can not alter PK/FK/UNIQUE constraints // concurrency-safe name checks against other INDEX catalog entries happens in the catalog if (info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT && - !table.GetStorage().IndexNameIsUnique(info.index_name)) { + !table.GetStorage().IndexNameIsUnique(info.index_name.GetIdentifierName())) { throw CatalogException("An index with the name " + info.index_name + " already exists!"); } @@ -307,7 +307,7 @@ void DuckSchemaEntry::Alter(CatalogTransaction transaction, AlterInfo &info) { throw CatalogException("Couldn't change ownership!"); } } else { - string name = info.name; + auto &name = info.name; if (!set.AlterEntry(transaction, name, info)) { throw CatalogException::MissingEntry(type, name, string()); } @@ -376,17 +376,17 @@ void DuckSchemaEntry::OnDropEntry(CatalogTransaction transaction, CatalogEntry & optional_ptr DuckSchemaEntry::LookupEntry(CatalogTransaction transaction, const EntryLookupInfo &lookup_info) { - return GetCatalogSet(lookup_info.GetCatalogType()).GetEntry(transaction, lookup_info.GetEntryName()); + return GetCatalogSet(lookup_info.GetCatalogType()).GetEntry(transaction, lookup_info.GetEntryIdentifier()); } CatalogSet::EntryLookup DuckSchemaEntry::LookupEntryDetailed(CatalogTransaction 
transaction, const EntryLookupInfo &lookup_info) { - return GetCatalogSet(lookup_info.GetCatalogType()).GetEntryDetailed(transaction, lookup_info.GetEntryName()); + return GetCatalogSet(lookup_info.GetCatalogType()).GetEntryDetailed(transaction, lookup_info.GetEntryIdentifier()); } SimilarCatalogEntry DuckSchemaEntry::GetSimilarEntry(CatalogTransaction transaction, const EntryLookupInfo &lookup_info) { - return GetCatalogSet(lookup_info.GetCatalogType()).SimilarEntry(transaction, lookup_info.GetEntryName()); + return GetCatalogSet(lookup_info.GetCatalogType()).SimilarEntry(transaction, lookup_info.GetEntryIdentifier()); } CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index c77083781..94ba907e3 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -36,7 +36,7 @@ IndexStorageInfo GetIndexInfo(const IndexConstraintType type, const bool v1_0_0_ auto &table_info = info->Cast(); auto constraint_name = EnumUtil::ToString(type) + "_"; auto name = constraint_name + table_info.table + "_" + to_string(id); - IndexStorageInfo index_info(name); + IndexStorageInfo index_info {Identifier(name)}; if (!v1_0_0_storage) { index_info.options.emplace("v1_0_0_storage", v1_0_0_storage); } @@ -45,6 +45,19 @@ IndexStorageInfo GetIndexInfo(const IndexConstraintType type, const bool v1_0_0_ static void CheckTypeIsSupported(const LogicalType &logical_type, AttachedDatabase &db) { TypeVisitor::Contains(logical_type, [&](const LogicalType &type) { + if (type.IsAggregateState()) { + const auto storage_version = db.GetStorageManager().GetStorageVersion(); + + if (storage_version < StorageVersion::V2_0_0) { + auto required = GetStorageVersionName(StorageVersion::V2_0_0, false); + auto current = GetStorageVersionName(storage_version, false); + + throw 
InvalidInputException("Aggregate state columns are not supported in storage versions prior to %s " + "(database \"%s\" is using storage version %s)", + required, db.GetName(), current); + } + return false; + } switch (type.id()) { case LogicalTypeId::TYPE: { throw InvalidInputException("A table cannot be created with a 'TYPE' column"); @@ -113,7 +126,8 @@ DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, Bou column_defs.push_back(col_def.Copy()); } storage = make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), - schema.name, name, std::move(column_defs), std::move(info.data)); + schema.name.GetIdentifierName(), name.GetIdentifierName(), + std::move(column_defs), std::move(info.data)); // Create the unique indexes for the UNIQUE, PRIMARY KEY, and FOREIGN KEY constraints. idx_t indexes_idx = 0; @@ -332,8 +346,8 @@ void DuckTableEntry::UndoAlter(ClientContext &context, AlterInfo &info) { static void RenameExpression(ParsedExpression &root_expr, RenameColumnInfo &info) { ParsedExpressionIterator::VisitExpressionMutable(root_expr, [&](ColumnRefExpression &colref) { - if (colref.column_names.back() == info.old_name) { - colref.column_names.back() = info.new_name; + if (colref.ColumnNames().back() == info.old_name) { + colref.ColumnNamesMutable().back() = info.new_name; } }); } @@ -382,7 +396,7 @@ unique_ptr DuckTableEntry::RenameColumn(ClientContext &context, Re case ConstraintType::FOREIGN_KEY: { // FOREIGN KEY constraint: possibly need to rename columns auto &fk = copy->Cast(); - vector columns = fk.pk_columns; + vector columns = fk.pk_columns; if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { columns = fk.fk_columns; } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { @@ -449,7 +463,8 @@ unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddCo // When replaying WAL, only Bind DEFAULT for added Column auto &catalog_name = schema.ParentCatalog().GetName(); 
auto &schema_name = schema.name; - binder->BindDefaultValue(info.new_column, bound_defaults, catalog_name, schema_name); + binder->BindDefaultValue(info.new_column, bound_defaults, catalog_name.GetIdentifierName(), + schema_name.GetIdentifierName()); } auto new_storage = make_shared_ptr(context, *storage, info.new_column, *bound_defaults.back()); return make_uniq(catalog, schema, *bound_create_info, new_storage); @@ -461,12 +476,10 @@ struct StructMappingInfo { ErrorData error; }; -unique_ptr PackExpression(unique_ptr expr, string name) { - expr->SetAlias(std::move(name)); - vector> children; - children.push_back(std::move(expr)); - auto res = make_uniq("struct_pack", std::move(children)); - return std::move(res); +unique_ptr PackExpression(unique_ptr expr, Identifier name) { + vector children; + children.emplace_back(std::move(name), std::move(expr)); + return make_uniq("struct_pack", std::move(children)); } static child_list_t GetChildList(const LogicalType &type) { @@ -509,7 +522,7 @@ static LogicalType ConstructNewType(const LogicalType &original_type, child_list } } -Value ConstructMapping(const string &name, const LogicalType &type) { +Value ConstructMapping(const Identifier &name, const LogicalType &type) { if (!type.IsNested()) { return Value(name); } @@ -528,7 +541,7 @@ Value ConstructMapping(const string &name, const LogicalType &type) { return Value::STRUCT(std::move(child_mapping)); } -StructMappingInfo AddFieldToStruct(const LogicalType &type, const vector &column_path, +StructMappingInfo AddFieldToStruct(const LogicalType &type, const vector &column_path, const ColumnDefinition &new_field, idx_t depth = 0) { if (!type.IsNested()) { throw BinderException("Column '%s' is not a nested type, ADD COLUMN can only be used on nested types", @@ -548,7 +561,7 @@ StructMappingInfo AddFieldToStruct(const LogicalType &type, const vector // root path - we are adding at this level // check if a field with this name already exists for (auto &entry : child_list) { - 
if (StringUtil::CIEquals(entry.first, new_field.Name())) { + if (entry.first == new_field.Name()) { // already exists! result.error = ErrorData(CatalogException("Duplicate field \"%s\" - field already exists in struct %s", new_field.Name(), current_component)); @@ -573,8 +586,7 @@ StructMappingInfo AddFieldToStruct(const LogicalType &type, const vector auto &next_component = column_path[depth + 1]; bool found = false; for (auto &entry : child_list) { - if ((type.id() == LogicalTypeId::LIST && StringUtil::CIEquals(next_component, "element")) || - StringUtil::CIEquals(entry.first, next_component)) { + if ((type.id() == LogicalTypeId::LIST && next_component == "element") || entry.first == next_component) { // found the entry - recurse auto child_res = AddFieldToStruct(entry.second, column_path, new_field, depth + 1); if (child_res.error.HasError()) { @@ -701,7 +713,7 @@ void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_i case ConstraintType::FOREIGN_KEY: { auto copy = constraint->Copy(); auto &fk = copy->Cast(); - vector columns = fk.pk_columns; + vector columns = fk.pk_columns; if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { columns = fk.fk_columns; } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { @@ -783,7 +795,7 @@ struct DroppedFieldMapping { ErrorData error; }; -DroppedFieldMapping DropFieldFromStruct(const LogicalType &type, const vector &column_path, idx_t depth) { +DroppedFieldMapping DropFieldFromStruct(const LogicalType &type, const vector &column_path, idx_t depth) { if (!type.IsNested()) { throw CatalogException("Cannot drop field from column \"%s\" - not a nested type", column_path[0]); } @@ -798,19 +810,20 @@ DroppedFieldMapping DropFieldFromStruct(const LogicalType &type, const vector DuckTableEntry::RemoveField(ClientContext &context, Rem return ChangeColumnType(context, change_column_type); } -DroppedFieldMapping RenameFieldFromStruct(const LogicalType &type, const vector 
&column_path, +DroppedFieldMapping RenameFieldFromStruct(const LogicalType &type, const vector &column_path, const string &new_name, idx_t depth) { if (!type.IsNested()) { throw CatalogException("Cannot rename field from column \"%s\" - not a nested type", column_path[0]); @@ -895,25 +908,24 @@ DroppedFieldMapping RenameFieldFromStruct(const LogicalType &type, const vector< auto field_name = entry.first; Value mapping_value; LogicalType type_value; - if ((type.id() == LogicalTypeId::LIST && StringUtil::CIEquals(rename_entry, "element")) || - StringUtil::CIEquals(field_name, rename_entry)) { + if ((type.id() == LogicalTypeId::LIST && rename_entry == "element") || field_name == rename_entry) { // this is the entry we are renaming found = true; if (last_entry) { if (type.id() != LogicalTypeId::STRUCT) { throw CatalogException( "Cannot rename field '%s' from column '%s' - can only rename fields inside a struct", - column_path.back(), column_path.front()); + column_path.back().GetIdentifierName(), column_path.front().GetIdentifierName()); } // we are renaming this entry for (auto &sub_entry : child_types) { - if (StringUtil::CIEquals(new_name, sub_entry.first)) { + if (sub_entry.first == new_name) { throw CatalogException( "Cannot rename field %s from column %s to %s - a field with this name already exists", column_path.back(), column_path.front(), new_name); } } - field_name = new_name; + field_name = Identifier(new_name); mapping_value = ConstructMapping(entry.first, entry.second); type_value = entry.second; } else { @@ -956,7 +968,7 @@ unique_ptr DuckTableEntry::RenameField(ClientContext &context, Ren // follow the path auto &col = GetColumn(info.column_path[0]); - auto res = RenameFieldFromStruct(col.Type(), info.column_path, info.new_name, 1); + auto res = RenameFieldFromStruct(col.Type(), info.column_path, info.new_name.GetIdentifierName(), 1); if (res.error.HasError()) { res.error.Throw(); } @@ -1245,7 +1257,7 @@ void DuckTableEntry::Rollback(CatalogEntry 
&prev_entry) { // Find all index-based constraints that exist in rollback_table, but not in table. // Then, remove them. - unordered_set names; + identifier_set_t names; for (const auto &constraint : prev_table.GetConstraints()) { if (constraint->type != ConstraintType::UNIQUE) { continue; @@ -1326,7 +1338,7 @@ void DuckTableEntry::CommitAlter(string &column_name, CommitDropState &drop_stat auto &root_column_name = column_path[0]; idx_t column_position = 0; for (auto &col : columns.Logical()) { - if (StringUtil::CIEquals(col.Name(), root_column_name)) { + if (col.Name() == root_column_name) { // No need to alter storage, removed column is generated column if (col.Generated()) { return; @@ -1358,8 +1370,18 @@ TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr return TableScanFunction::GetFunction(); } -vector DuckTableEntry::GetColumnSegmentInfo(const QueryContext &context) { - return storage->GetColumnSegmentInfo(context); +vector DuckTableEntry::GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options) { + return storage->GetColumnSegmentInfo(context, options); +} + +void DuckTableEntry::InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) { + storage->InitializeColumnSegmentInfoScan(state); +} + +bool DuckTableEntry::ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) { + return storage->ScanColumnSegmentInfo(context, state, result); } TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { @@ -1375,6 +1397,11 @@ optional_ptr DuckTableEntry::CreateTrigger(CatalogTransaction tran if (old_entry) { return nullptr; } + } else if (info.on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { + auto old_entry = triggers->GetEntry(transaction, entry_name); + if (old_entry) { + triggers->DropEntry(transaction, entry_name, false); + } } if (!triggers->CreateEntry(transaction, entry_name, std::move(trigger), dependencies)) 
{ throw CatalogException::EntryAlreadyExists(CatalogType::TRIGGER_ENTRY, entry_name); @@ -1391,7 +1418,7 @@ void DuckTableEntry::ScanTriggersNonTransactional(const std::functionScan(callback); } -bool DuckTableEntry::DropTrigger(CatalogTransaction transaction, const string &name, bool cascade) { +bool DuckTableEntry::DropTrigger(CatalogTransaction transaction, const Identifier &name, bool cascade) { return triggers->DropEntry(transaction, name, cascade); } diff --git a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp index 700d8213a..a5c29faeb 100644 --- a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp @@ -30,17 +30,17 @@ SimilarCatalogEntry SchemaCatalogEntry::GetSimilarEntry(CatalogTransaction trans const EntryLookupInfo &lookup_info) { SimilarCatalogEntry result; Scan(transaction.GetContext(), lookup_info.GetCatalogType(), [&](CatalogEntry &entry) { - auto entry_score = StringUtil::SimilarityRating(entry.name, lookup_info.GetEntryName()); + auto entry_score = StringUtil::SimilarityRating(entry.name.GetIdentifierName(), lookup_info.GetEntryName()); if (entry_score > result.score) { result.score = entry_score; - result.name = entry.name; + result.name = Identifier(entry.name.GetIdentifierName()); } }); return result; } optional_ptr SchemaCatalogEntry::GetEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { + const Identifier &name) { EntryLookupInfo lookup_info(type, name); return LookupEntry(transaction, lookup_info); } diff --git a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp index e2014ba7a..43b771f43 100644 --- a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp @@ -114,7 +114,7 @@ string 
SequenceCatalogEntry::ToSQL() const { duckdb::stringstream ss; ss << "CREATE SEQUENCE "; - ss << name; + ss << name.GetIdentifierName(); ss << " INCREMENT BY " << seq_data.increment; ss << " MINVALUE " << seq_data.min_value; ss << " MAXVALUE " << seq_data.max_value; diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 2fe6a4a94..9fcec52c0 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -53,7 +53,7 @@ StorageIndex TableCatalogEntry::GetStorageIndex(const ColumnIndex &column_id) co return result; } -LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) const { +LogicalIndex TableCatalogEntry::GetColumnIndex(Identifier &column_name, bool if_exists) const { auto entry = columns.GetColumnIndex(column_name); if (!entry.IsValid()) { if (if_exists) { @@ -61,11 +61,12 @@ LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exis } vector column_names; for (auto &col : columns.Logical()) { - column_names.push_back(col.Name()); + column_names.emplace_back(col.Name()); } - auto candidates = StringUtil::CandidatesErrorMessage(column_names, column_name, "Did you mean"); - throw BinderException("Table \"%s\" does not have a column with name \"%s\"\n%s", name, column_name, - candidates); + auto candidates = + StringUtil::CandidatesErrorMessage(column_names, column_name.GetIdentifierName(), "Did you mean"); + throw BinderException("Table \"%s\" does not have a column with name \"%s\"\n%s", name.GetIdentifierName(), + column_name, candidates); } return entry; } @@ -74,11 +75,11 @@ unique_ptr TableCatalogEntry::GetSample() { return nullptr; } -bool TableCatalogEntry::ColumnExists(const string &name) const { +bool TableCatalogEntry::ColumnExists(const Identifier &name) const { return columns.ColumnExists(name); } -const ColumnDefinition 
&TableCatalogEntry::GetColumn(const string &name) const { +const ColumnDefinition &TableCatalogEntry::GetColumn(const Identifier &name) const { return columns.GetColumn(name); } @@ -116,7 +117,7 @@ string TableCatalogEntry::ColumnsToSQL(const ColumnList &columns, const vector multi_key_pks; + identifier_set_t multi_key_pks; vector extra_constraints; for (auto &constraint : constraints) { if (constraint->type == ConstraintType::NOT_NULL) { @@ -159,7 +160,7 @@ string TableCatalogEntry::ColumnsToSQL(const ColumnList &columns, const vector TableCatalogEntry::GetColumnSegmentInfo(const QueryContext &context) { +vector TableCatalogEntry::GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options) { return {}; } +void TableCatalogEntry::InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) { +} + +bool TableCatalogEntry::ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) { + return false; +} + void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, ClientContext &context) { // check the constraints and indexes of the table to see if we need to project any additional columns @@ -367,6 +377,8 @@ vector> TableCatalogEntry::GetTriggersForEv TriggerTiming timing, TriggerEventType event_type) const { vector> result; + // CatalogSet is backed by case_insensitive_tree_t (a map with case-insensitive comparator), + // so ScanTriggers yields entries in alphabetical order by name ScanTriggers(transaction, [&](CatalogEntry &entry) { auto &trigger = entry.Cast(); if (trigger.timing == timing && trigger.event_type == event_type) { diff --git a/src/duckdb/src/catalog/catalog_entry/trigger_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/trigger_catalog_entry.cpp index 6d0ff987d..8ac215d86 100644 --- a/src/duckdb/src/catalog/catalog_entry/trigger_catalog_entry.cpp +++ 
b/src/duckdb/src/catalog/catalog_entry/trigger_catalog_entry.cpp @@ -12,6 +12,7 @@ TriggerCatalogEntry::TriggerCatalogEntry(Catalog &catalog, SchemaCatalogEntry &s : StandardEntry(CatalogType::TRIGGER_ENTRY, schema, catalog, info.trigger_name), base_table(unique_ptr_cast(info.base_table->Copy())), timing(info.timing), event_type(info.event_type), columns(info.columns), for_each(info.for_each), + referencing_new_table(info.referencing_new_table), referencing_old_table(info.referencing_old_table), trigger_action(info.trigger_action->Copy()) { this->temporary = info.temporary; this->comment = info.comment; @@ -34,6 +35,8 @@ unique_ptr TriggerCatalogEntry::GetInfo() const { result->event_type = event_type; result->columns = columns; result->for_each = for_each; + result->referencing_new_table = referencing_new_table; + result->referencing_old_table = referencing_old_table; result->trigger_action = trigger_action->Copy(); result->dependencies = dependencies; result->comment = comment; @@ -60,6 +63,15 @@ string TriggerCatalogEntry::ToSQL() const { } ss << " ON "; ss << ParseInfo::QualifierToString(base_table->catalog_name, base_table->schema_name, base_table->table_name); + if (!referencing_new_table.empty() || !referencing_old_table.empty()) { + ss << " REFERENCING"; + if (!referencing_new_table.empty()) { + ss << " NEW TABLE AS " << SQLIdentifier(referencing_new_table); + } + if (!referencing_old_table.empty()) { + ss << " OLD TABLE AS " << SQLIdentifier(referencing_old_table); + } + } ss << " FOR EACH " << EnumUtil::ToString(for_each); ss << " " << trigger_action->ToString(); ss << ";"; diff --git a/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp index 8c0da475e..c0df1cbf9 100644 --- a/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp @@ -74,19 +74,30 @@ unique_ptr ViewCatalogEntry::AlterEntry(ClientContext &context, Al 
auto &comment_on_column_info = info.Cast(); auto copied_view = Copy(context); + Identifier resolved_column_name = comment_on_column_info.column_name; auto view_columns = GetColumnInfo(); if (view_columns) { // if the view is bound - verify the name we are commenting on exists auto &names = view_columns->names; - auto entry = std::find(names.begin(), names.end(), comment_on_column_info.column_name); + auto match_name = [&](const auto &n) { + return resolved_column_name == n; + }; + auto entry = std::find_if(names.begin(), names.end(), match_name); if (entry == names.end()) { - throw BinderException("View \"%s\" does not have a column with name \"%s\"", name, - comment_on_column_info.column_name); + // the column name might be a view alias - check those as well + auto alias_entry = std::find_if(aliases.begin(), aliases.end(), match_name); + if (alias_entry == aliases.end()) { + throw BinderException("View \"%s\" does not have a column with name \"%s\"", + name.GetIdentifierName(), resolved_column_name.GetIdentifierName()); + } + auto alias_index = NumericCast(std::distance(aliases.begin(), alias_entry)); + D_ASSERT(alias_index < names.size()); + resolved_column_name = names[alias_index]; } } // apply the comment to the view auto &copied_view_entry = copied_view->Cast(); - copied_view_entry.column_comments[comment_on_column_info.column_name] = comment_on_column_info.comment_value; + copied_view_entry.column_comments[resolved_column_name] = comment_on_column_info.comment_value; return copied_view; } @@ -167,7 +178,7 @@ void ViewCatalogEntry::BindView(ClientContext &context, BindViewAction action) { bind_thread = thread_id {}; } -void ViewCatalogEntry::UpdateBinding(const vector &types_p, const vector &names_p) { +void ViewCatalogEntry::UpdateBinding(const vector &types_p, const vector &names_p) { auto columns = view_columns.atomic_load(); if (columns && columns->types == types_p && columns->names == names_p) { // already bound with the current info diff --git 
a/src/duckdb/src/catalog/catalog_entry_retriever.cpp b/src/duckdb/src/catalog/catalog_entry_retriever.cpp index ece24b1e0..ea2c6baf7 100644 --- a/src/duckdb/src/catalog/catalog_entry_retriever.cpp +++ b/src/duckdb/src/catalog/catalog_entry_retriever.cpp @@ -11,7 +11,7 @@ namespace duckdb { -LogicalType CatalogEntryRetriever::GetType(Catalog &catalog, const string &schema, const string &name, +LogicalType CatalogEntryRetriever::GetType(Catalog &catalog, const Identifier &schema, const Identifier &name, OnEntryNotFound on_entry_not_found) { EntryLookupInfo lookup_info(CatalogType::TYPE_ENTRY, name); auto result = GetEntry(catalog, schema, lookup_info, on_entry_not_found); @@ -22,7 +22,7 @@ LogicalType CatalogEntryRetriever::GetType(Catalog &catalog, const string &schem return type_entry.user_type; } -LogicalType CatalogEntryRetriever::GetType(const string &catalog, const string &schema, const string &name, +LogicalType CatalogEntryRetriever::GetType(const Identifier &catalog, const Identifier &schema, const Identifier &name, OnEntryNotFound on_entry_not_found) { EntryLookupInfo lookup_info(CatalogType::TYPE_ENTRY, name); auto result = GetEntry(catalog, schema, lookup_info, on_entry_not_found); @@ -33,13 +33,13 @@ LogicalType CatalogEntryRetriever::GetType(const string &catalog, const string & return type_entry.user_type; } -optional_ptr CatalogEntryRetriever::GetEntry(const string &catalog, const string &schema, +optional_ptr CatalogEntryRetriever::GetEntry(const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found) { return ReturnAndCallback(Catalog::GetEntry(*this, catalog, schema, lookup_info, on_entry_not_found)); } -optional_ptr CatalogEntryRetriever::GetSchema(const string &catalog, +optional_ptr CatalogEntryRetriever::GetSchema(const Identifier &catalog, const EntryLookupInfo &schema_lookup_p, OnEntryNotFound on_entry_not_found) { EntryLookupInfo schema_lookup(schema_lookup_p, at_clause); @@ 
-54,7 +54,7 @@ optional_ptr CatalogEntryRetriever::GetSchema(const string & return result; } -optional_ptr CatalogEntryRetriever::GetEntry(Catalog &catalog, const string &schema, +optional_ptr CatalogEntryRetriever::GetEntry(Catalog &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found) { return ReturnAndCallback(catalog.GetEntry(*this, schema, lookup_info, on_entry_not_found)); diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp index 540109496..3a7ce8f30 100644 --- a/src/duckdb/src/catalog/catalog_search_path.cpp +++ b/src/duckdb/src/catalog/catalog_search_path.cpp @@ -7,12 +7,13 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database_manager.hpp" +#include "duckdb/main/extension_callback_manager.hpp" #include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { -CatalogSearchEntry::CatalogSearchEntry(string catalog_p, string schema_p) +CatalogSearchEntry::CatalogSearchEntry(Identifier catalog_p, Identifier schema_p) : catalog(std::move(catalog_p)), schema(std::move(schema_p)) { } @@ -24,13 +25,14 @@ string CatalogSearchEntry::ToString() const { } } -string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { - for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' || input[i] == ',' || input[i] == '"') { - return "\"" + StringUtil::Replace(input, "\"", "\"\"") + "\""; +string CatalogSearchEntry::WriteOptionallyQuoted(const Identifier &input) { + auto &name = input.GetIdentifierName(); + for (idx_t i = 0; i < name.size(); i++) { + if (name[i] == '.' 
|| name[i] == ',' || name[i] == '"') { + return "\"" + StringUtil::Replace(name, "\"", "\"\"") + "\""; } } - return input; + return name; } string CatalogSearchEntry::ListToString(const vector &input) { @@ -104,7 +106,7 @@ CatalogSearchEntry CatalogSearchEntry::ParseInternal(const string &input, idx_t if (schema.empty()) { throw ParserException("Unexpected end of entry - empty CatalogSearchEntry"); } - return CatalogSearchEntry(std::move(catalog), std::move(schema)); + return CatalogSearchEntry(Identifier(std::move(catalog)), Identifier(std::move(schema))); } CatalogSearchEntry CatalogSearchEntry::Parse(const string &input) { @@ -139,6 +141,10 @@ void CatalogSearchPath::Reset() { SetPathsInternal(empty); } +void CatalogSearchPath::RefreshSetPaths() { + SetPathsInternal(set_paths); +} + string CatalogSearchPath::GetSetName(CatalogSetPathType set_type) { switch (set_type) { case CatalogSetPathType::SET_SCHEMA: @@ -173,7 +179,8 @@ void CatalogSearchPath::Set(vector new_paths, CatalogSetPath if (path.catalog.empty()) { auto catalog = Catalog::GetCatalogEntry(context, path.schema); if (catalog) { - auto schema = catalog->GetSchema(context, catalog->GetDefaultSchema(), OnEntryNotFound::RETURN_NULL); + auto schema = + catalog->GetSchema(context, Identifier(catalog->GetDefaultSchema()), OnEntryNotFound::RETURN_NULL); if (schema) { path.catalog = std::move(path.schema); path.schema = schema->name; @@ -186,7 +193,7 @@ void CatalogSearchPath::Set(vector new_paths, CatalogSetPath if (set_type == CatalogSetPathType::SET_SCHEMA) { if (new_paths[0].catalog == TEMP_CATALOG || new_paths[0].catalog == SYSTEM_CATALOG) { throw CatalogException("%s cannot be set to internal schema \"%s\"", GetSetName(set_type), - new_paths[0].catalog); + new_paths[0].catalog.GetIdentifierName()); } } SetPathsInternal(std::move(new_paths)); @@ -208,68 +215,68 @@ vector CatalogSearchPath::Get() const { return res; } -string CatalogSearchPath::GetDefaultSchema(const string &catalog) const { +Identifier 
CatalogSearchPath::GetDefaultSchema(const Identifier &catalog) const { for (auto &path : paths) { if (path.catalog == TEMP_CATALOG) { continue; } - if (StringUtil::CIEquals(path.catalog, catalog)) { + if (path.catalog == catalog) { return path.schema; } } return DEFAULT_SCHEMA; } -string CatalogSearchPath::GetDefaultSchema(ClientContext &context, const string &catalog) const { +Identifier CatalogSearchPath::GetDefaultSchema(ClientContext &context, const Identifier &catalog) const { for (auto &path : paths) { if (path.catalog == TEMP_CATALOG) { continue; } - if (StringUtil::CIEquals(path.catalog, catalog)) { + if (path.catalog == catalog) { return path.schema; } } auto catalog_entry = Catalog::GetCatalogEntry(context, catalog); if (catalog_entry) { - return catalog_entry->GetDefaultSchema(); + return Identifier(catalog_entry->GetDefaultSchema()); } return DEFAULT_SCHEMA; } -string CatalogSearchPath::GetDefaultCatalog(const string &schema) const { +Identifier CatalogSearchPath::GetDefaultCatalog(const Identifier &schema) const { if (DefaultSchemaGenerator::IsDefaultSchema(schema)) { - return SYSTEM_CATALOG; + return Identifier::SystemCatalog(); } for (auto &path : paths) { if (path.catalog == TEMP_CATALOG) { continue; } - if (StringUtil::CIEquals(path.schema, schema)) { + if (path.schema == schema) { return path.catalog; } } - return INVALID_CATALOG; + return Identifier::InvalidCatalog(); } -vector CatalogSearchPath::GetCatalogsForSchema(const string &schema) const { - vector catalogs; +vector CatalogSearchPath::GetCatalogsForSchema(const Identifier &schema) const { + vector catalogs; if (DefaultSchemaGenerator::IsDefaultSchema(schema)) { - catalogs.push_back(SYSTEM_CATALOG); + catalogs.push_back(Identifier::SystemCatalog()); } else { for (auto &path : paths) { - if (StringUtil::CIEquals(path.schema, schema) || path.schema.empty()) { - catalogs.push_back(path.catalog); + if (path.schema == schema || path.schema.empty()) { + catalogs.emplace_back(path.catalog); } } } 
return catalogs; } -vector CatalogSearchPath::GetSchemasForCatalog(const string &catalog) const { - vector schemas; +vector CatalogSearchPath::GetSchemasForCatalog(const Identifier &catalog) const { + vector schemas; for (auto &path : paths) { - if (!path.schema.empty() && StringUtil::CIEquals(path.catalog, catalog)) { - schemas.push_back(path.schema); + if (!path.schema.empty() && path.catalog == catalog) { + schemas.emplace_back(path.schema); } } return schemas; @@ -293,19 +300,22 @@ void CatalogSearchPath::SetPathsInternal(vector new_paths) { paths.emplace_back(INVALID_CATALOG, DEFAULT_SCHEMA); paths.emplace_back(SYSTEM_CATALOG, DEFAULT_SCHEMA); paths.emplace_back(SYSTEM_CATALOG, "pg_catalog"); + // set extension schemas on the search path, if any + for (auto &schema : ExtensionCallbackManager::Get(context).GetExtensionSchemas()) { + paths.emplace_back(Identifier(SYSTEM_CATALOG), Identifier(schema)); + } } -bool CatalogSearchPath::SchemaInSearchPath(ClientContext &context, const string &catalog_name, - const string &schema_name) const { +bool CatalogSearchPath::SchemaInSearchPath(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema_name) const { for (auto &path : paths) { - if (!StringUtil::CIEquals(path.schema, schema_name)) { + if (path.schema != schema_name) { continue; } - if (StringUtil::CIEquals(path.catalog, catalog_name)) { + if (path.catalog == catalog_name) { return true; } - if (IsInvalidCatalog(path.catalog) && - StringUtil::CIEquals(catalog_name, DatabaseManager::GetDefaultDatabase(context))) { + if (IsInvalidCatalog(path.catalog) && catalog_name == DatabaseManager::GetDefaultDatabase(context)) { return true; } } diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index 1a12f1517..f96b398ab 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -45,7 +45,7 @@ void CatalogEntryMap::UpdateEntry(unique_ptr catalog_entry) { 
entry->second->SetChild(std::move(existing)); } -case_insensitive_tree_t> &CatalogEntryMap::Entries() { +identifier_tree_t> &CatalogEntryMap::Entries() { return entries; } @@ -77,7 +77,7 @@ void CatalogEntryMap::DropEntry(CatalogEntry &entry) { } } -optional_ptr CatalogEntryMap::GetEntry(const string &name) { +optional_ptr CatalogEntryMap::GetEntry(const Identifier &name) { auto entry = entries.find(name); if (entry == entries.end()) { return nullptr; @@ -92,7 +92,7 @@ CatalogSet::CatalogSet(Catalog &catalog_p, unique_ptr defaults CatalogSet::~CatalogSet() { } -bool CatalogSet::StartChain(CatalogTransaction transaction, const string &name, unique_lock &read_lock) { +bool CatalogSet::StartChain(CatalogTransaction transaction, const Identifier &name, unique_lock &read_lock) { D_ASSERT(!map.GetEntry(name)); // check if there is a default entry @@ -130,7 +130,7 @@ static bool IsDependencyEntry(CatalogEntry &entry) { return entry.type == CatalogType::DEPENDENCY_ENTRY; } -void CatalogSet::CheckCatalogEntryInvariants(CatalogEntry &value, const string &name) { +void CatalogSet::CheckCatalogEntryInvariants(CatalogEntry &value, const Identifier &name) { if (value.internal && !catalog.IsSystemCatalog() && name != DEFAULT_SCHEMA) { throw InternalException("Attempting to create internal entry \"%s\" in non-system catalog - internal entries " "can only be created in the system catalog", @@ -169,8 +169,9 @@ optional_ptr CatalogSet::CreateCommittedEntry(unique_ptr value, - unique_lock &read_lock, bool should_be_empty) { +bool CatalogSet::CreateEntryInternal(CatalogTransaction transaction, const Identifier &name, + unique_ptr value, unique_lock &read_lock, + bool should_be_empty) { auto entry_value = map.GetEntry(name); if (!entry_value) { // Add a dummy node to start the chain @@ -195,7 +196,7 @@ bool CatalogSet::CreateEntryInternal(CatalogTransaction transaction, const strin return true; } -bool CatalogSet::CreateEntry(CatalogTransaction transaction, const string &name, 
unique_ptr value, +bool CatalogSet::CreateEntry(CatalogTransaction transaction, const Identifier &name, unique_ptr value, const LogicalDependencyList &dependencies) { CheckCatalogEntryInvariants(*value, name); @@ -212,13 +213,13 @@ bool CatalogSet::CreateEntry(CatalogTransaction transaction, const string &name, return CreateEntryInternal(transaction, name, std::move(value), read_lock); } -bool CatalogSet::CreateEntry(ClientContext &context, const string &name, unique_ptr value, +bool CatalogSet::CreateEntry(ClientContext &context, const Identifier &name, unique_ptr value, const LogicalDependencyList &dependencies) { return CreateEntry(catalog.GetCatalogTransaction(context), name, std::move(value), dependencies); } //! This method is used to retrieve an entry for the purpose of making a new version, through an alter/drop/create -optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, const string &name) { +optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, const Identifier &name) { auto entry_value = map.GetEntry(name); if (!entry_value) { return nullptr; @@ -260,14 +261,15 @@ bool CatalogSet::AlterOwnership(CatalogTransaction transaction, ChangeOwnershipI } } if (!owner_entry) { - throw CatalogException("CatalogElement \"%s.%s\" does not exist!", info.owner_schema, info.owner_name); + throw CatalogException("CatalogElement \"%s.%s\" does not exist!", info.owner_schema.GetIdentifierName(), + info.owner_name.GetIdentifierName()); } write_lock.unlock(); catalog.GetDependencyManager()->AddOwnership(transaction, *owner_entry, *entry); return true; } -bool CatalogSet::RenameEntryInternal(CatalogTransaction transaction, CatalogEntry &old, const string &new_name, +bool CatalogSet::RenameEntryInternal(CatalogTransaction transaction, CatalogEntry &old, const Identifier &new_name, AlterInfo &alter_info, unique_lock &read_lock) { auto &original_name = old.name; @@ -306,7 +308,7 @@ bool CatalogSet::RenameEntryInternal(CatalogTransaction 
transaction, CatalogEntr return CreateEntryInternal(transaction, new_name, std::move(renamed_node), read_lock); } -bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info) { +bool CatalogSet::AlterEntry(CatalogTransaction transaction, const Identifier &name, AlterInfo &alter_info) { // If the entry does not exist, we error auto entry = GetEntry(transaction, name); if (!entry) { @@ -360,7 +362,7 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, value->timestamp = transaction.transaction_id; value->set = this; - if (!StringUtil::CIEquals(value->name, entry->name)) { + if (!(value->name == entry->name)) { if (!RenameEntryInternal(transaction, *entry, value->name, alter_info, read_lock)) { return false; } @@ -395,7 +397,7 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, return true; } -bool CatalogSet::DropDependencies(CatalogTransaction transaction, const string &name, bool cascade, +bool CatalogSet::DropDependencies(CatalogTransaction transaction, const Identifier &name, bool cascade, bool allow_drop_internal) { auto entry = GetEntry(transaction, name); if (!entry) { @@ -411,7 +413,7 @@ bool CatalogSet::DropDependencies(CatalogTransaction transaction, const string & return true; } -bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string &name, bool allow_drop_internal) { +bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const Identifier &name, bool allow_drop_internal) { // lock the catalog for writing // we can only delete an entry that exists auto entry = GetEntryInternal(transaction, name); @@ -440,7 +442,8 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string return true; } -bool CatalogSet::DropEntry(CatalogTransaction transaction, const string &name, bool cascade, bool allow_drop_internal) { +bool CatalogSet::DropEntry(CatalogTransaction transaction, const Identifier &name, bool 
cascade, + bool allow_drop_internal) { if (!DropDependencies(transaction, name, cascade, allow_drop_internal)) { return false; } @@ -449,7 +452,7 @@ bool CatalogSet::DropEntry(CatalogTransaction transaction, const string &name, b return DropEntryInternal(transaction, name, allow_drop_internal); } -bool CatalogSet::DropEntry(ClientContext &context, const string &name, bool cascade, bool allow_drop_internal) { +bool CatalogSet::DropEntry(ClientContext &context, const Identifier &name, bool cascade, bool allow_drop_internal) { return DropEntry(catalog.GetCatalogTransaction(context), name, cascade, allow_drop_internal); } @@ -563,22 +566,22 @@ CatalogEntry &CatalogSet::GetCommittedEntry(CatalogEntry ¤t) { return entry.get(); } -SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, const string &name) { +SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, const Identifier &name) { unique_lock lock(catalog_lock); CreateDefaultEntries(transaction, lock); SimilarCatalogEntry result; for (auto &kv : map.Entries()) { - auto entry_score = StringUtil::SimilarityRating(kv.first, name); + auto entry_score = StringUtil::SimilarityRating(kv.first.GetIdentifierName(), name.GetIdentifierName()); if (entry_score > result.score) { result.score = entry_score; - result.name = kv.first; + result.name = Identifier(kv.first.GetIdentifierName()); } } return result; } -optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction transaction, const string &name, +optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction transaction, const Identifier &name, unique_lock &read_lock) { // no entry found with this name, check for defaults if (!defaults || defaults->created_all_entries) { @@ -614,7 +617,7 @@ optional_ptr CatalogSet::CreateDefaultEntry(CatalogTransaction tra return GetEntry(transaction, name); } -CatalogSet::EntryLookup CatalogSet::GetEntryDetailed(CatalogTransaction transaction, const string &name) { +CatalogSet::EntryLookup 
CatalogSet::GetEntryDetailed(CatalogTransaction transaction, const Identifier &name) { unique_lock read_lock(catalog_lock); auto entry_value = map.GetEntry(name); if (entry_value) { @@ -631,7 +634,7 @@ CatalogSet::EntryLookup CatalogSet::GetEntryDetailed(CatalogTransaction transact return EntryLookup {nullptr, EntryLookup::FailureReason::DELETED}; } } - D_ASSERT(StringUtil::CIEquals(name, current.name)); + D_ASSERT(current.name == name); return EntryLookup {¤t, EntryLookup::FailureReason::SUCCESS}; } auto default_entry = CreateDefaultEntry(transaction, name, read_lock); @@ -641,12 +644,12 @@ CatalogSet::EntryLookup CatalogSet::GetEntryDetailed(CatalogTransaction transact return EntryLookup {default_entry, EntryLookup::FailureReason::SUCCESS}; } -optional_ptr CatalogSet::GetEntry(CatalogTransaction transaction, const string &name) { +optional_ptr CatalogSet::GetEntry(CatalogTransaction transaction, const Identifier &name) { auto lookup = GetEntryDetailed(transaction, name); return lookup.result; } -optional_ptr CatalogSet::GetEntry(ClientContext &context, const string &name) { +optional_ptr CatalogSet::GetEntry(ClientContext &context, const Identifier &name) { return GetEntry(catalog.GetCatalogTransaction(context), name); } @@ -665,7 +668,7 @@ void CatalogSet::Undo(CatalogEntry &entry) { auto &to_be_removed_node = entry.Parent(); to_be_removed_node.Rollback(entry); - D_ASSERT(StringUtil::CIEquals(entry.name, to_be_removed_node.name)); + D_ASSERT(entry.name == to_be_removed_node.name); if (!to_be_removed_node.HasParent()) { to_be_removed_node.Child().SetAsRoot(); } @@ -745,14 +748,14 @@ void CatalogSet::ScanWithReturn(ClientContext &context, const std::function &callback, - const string &prefix) { + const Identifier &prefix) { // lock the catalog set unique_lock lock(catalog_lock); CreateDefaultEntries(transaction, lock); auto &entries = map.Entries(); auto it = entries.lower_bound(prefix); - auto end = entries.upper_bound(prefix + char(255)); + auto end = 
entries.upper_bound(Identifier(prefix.GetIdentifierName() + char(255))); for (; it != end; it++) { auto &entry = *it->second; auto &entry_for_transaction = GetEntryForTransaction(transaction, entry); diff --git a/src/duckdb/src/catalog/default/default_coordinate_systems.cpp b/src/duckdb/src/catalog/default/default_coordinate_systems.cpp index 759c24e5f..6dac4991e 100644 --- a/src/duckdb/src/catalog/default/default_coordinate_systems.cpp +++ b/src/duckdb/src/catalog/default/default_coordinate_systems.cpp @@ -44,7 +44,7 @@ DefaultCoordinateSystemGenerator::DefaultCoordinateSystemGenerator(Catalog &cata } unique_ptr DefaultCoordinateSystemGenerator::CreateDefaultEntry(ClientContext &context, - const string &entry_name) { + const Identifier &entry_name) { if (schema.name != DEFAULT_SCHEMA) { return nullptr; } @@ -62,14 +62,14 @@ unique_ptr DefaultCoordinateSystemGenerator::CreateDefaultEntry(Cl return nullptr; } -vector DefaultCoordinateSystemGenerator::GetDefaultEntries() { +vector DefaultCoordinateSystemGenerator::GetDefaultEntries() { if (schema.name != DEFAULT_SCHEMA) { return {}; } - vector entries; + vector entries; for (const auto &crs_definition : DEFAULT_CRS_DEFINITIONS) { - entries.push_back(crs_definition.name); + entries.emplace_back(crs_definition.name); } return entries; } diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp index df651376c..afe503e16 100644 --- a/src/duckdb/src/catalog/default/default_functions.cpp +++ b/src/duckdb/src/catalog/default/default_functions.cpp @@ -8,162 +8,218 @@ namespace duckdb { static const DefaultMacro internal_macros[] = { - {DEFAULT_SCHEMA, "current_role", "() AS 'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_user", "() AS 'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_catalog", "() AS main.current_database()"}, // name of current database (called "catalog" in the SQL standard) - 
{DEFAULT_SCHEMA, "user", "() AS current_user"}, // equivalent to current_user - {DEFAULT_SCHEMA, "session_user", "() AS 'duckdb'"}, // session user name - {"pg_catalog", "inet_client_addr", "() AS NULL"}, // address of the remote connection - {"pg_catalog", "inet_client_port", "() AS NULL"}, // port of the remote connection - {"pg_catalog", "inet_server_addr", "() AS NULL"}, // address of the local connection - {"pg_catalog", "inet_server_port", "() AS NULL"}, // port of the local connection - {"pg_catalog", "pg_my_temp_schema", "() AS 0"}, // OID of session's temporary schema, or 0 if none - {"pg_catalog", "pg_is_other_temp_schema", "(schema_id) AS false"}, // is schema another session's temporary schema? - - {"pg_catalog", "pg_conf_load_time", "() AS current_timestamp"}, // configuration load time - {"pg_catalog", "pg_postmaster_start_time", "() AS current_timestamp"}, // server start time - - {"pg_catalog", "pg_typeof", "(expression) AS lower(typeof(expression))"}, // get the data type of any value - - {"pg_catalog", "current_database", "() AS system.main.current_database()"}, // name of current database (called "catalog" in the SQL standard) - {"pg_catalog", "current_query", "() AS system.main.current_query()"}, // the currently executing query (NULL if not inside a plpgsql function) - {"pg_catalog", "current_schema", "() AS system.main.current_schema()"}, // name of current schema - {"pg_catalog", "current_schemas", "(include_implicit) AS system.main.current_schemas(include_implicit)"}, // names of schemas in search path - - // privilege functions - {"pg_catalog", "has_any_column_privilege", "(\"table\", privilege) AS true, (user, \"table\", privilege) AS true"}, //boolean //does current/named user have privilege for any column of table - {"pg_catalog", "has_column_privilege", "(\"table\", \"column\", privilege) AS true, (user, \"table\", \"column\", privilege) AS true"}, //boolean //does current/named user have privilege for column - {"pg_catalog", 
"has_database_privilege", "(database, privilege) AS true, (user, database, privilege) AS true"}, //boolean //does current/named user have privilege for database - {"pg_catalog", "has_foreign_data_wrapper_privilege", "(fdw, privilege) AS true, (user, fdw, privilege) AS true"}, //boolean //does current/named user have privilege for foreign-data wrapper - {"pg_catalog", "has_function_privilege", "(function, privilege) AS true, (user, function, privilege) AS true"}, //boolean //does current/named user have privilege for function - {"pg_catalog", "has_language_privilege", "(language, privilege) AS true, (user, language, privilege) AS true"}, //boolean //does current/named user have privilege for language - {"pg_catalog", "has_schema_privilege", "(schema, privilege) AS true, (user, schema, privilege) AS true"}, //boolean //does current/named user have privilege for schema - {"pg_catalog", "has_sequence_privilege", "(sequence, privilege) AS true, (user, sequence, privilege) AS true"}, //boolean //does current/named user have privilege for sequence - {"pg_catalog", "has_server_privilege", "(server, privilege) AS true, (user, server, privilege) AS true"}, //boolean //does current/named user have privilege for foreign server - {"pg_catalog", "has_table_privilege", "(\"table\", privilege) AS true, (user, \"table\", privilege) AS true"}, //boolean //does current/named user have privilege for table - {"pg_catalog", "has_tablespace_privilege", "(tablespace, privilege) AS true, (user, tablespace, privilege) AS true"}, //boolean //does current/named user have privilege for tablespace - - // various postgres system functions - {"pg_catalog", "pg_get_viewdef", "(oid) AS (select sql from duckdb_views() v where v.view_oid=oid)"}, - {"pg_catalog", "pg_get_constraintdef", "(constraint_oid) AS (select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000), (constraint_oid, 
pretty_bool) AS pg_get_constraintdef(constraint_oid)"}, - {"pg_catalog", "pg_get_expr", "(pg_node_tree, relation_oid) AS pg_node_tree"}, - {"pg_catalog", "format_pg_type", "(logical_type, type_name) AS case upper(logical_type) when 'FLOAT' then 'float4' when 'DOUBLE' then 'float8' when 'DECIMAL' then 'numeric' when 'ENUM' then lower(type_name) when 'VARCHAR' then 'varchar' when 'BLOB' then 'bytea' when 'TIMESTAMP' then 'timestamp' when 'TIME' then 'time' when 'TIMESTAMP WITH TIME ZONE' then 'timestamptz' when 'TIME WITH TIME ZONE' then 'timetz' when 'SMALLINT' then 'int2' when 'INTEGER' then 'int4' when 'BIGINT' then 'int8' when 'BOOLEAN' then 'bool' else lower(logical_type) end"}, - {"pg_catalog", "format_type", "(type_oid, typemod) AS (select format_pg_type(logical_type, type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, - {"pg_catalog", "map_to_pg_oid", "(type_name) AS case type_name when 'bool' then 16 when 'int16' then 21 when 'int' then 23 when 'bigint' then 20 when 'date' then 1082 when 'time' then 1083 when 'datetime' then 1114 when 'dec' then 1700 when 'float' then 700 when 'double' then 701 when 'bpchar' then 1043 when 'binary' then 17 when 'interval' then 1186 when 'timestamptz' then 1184 when 'timestamp with time zone' then 1184 when 'timetz' then 1266 when 'time with time zone' then 1266 when 'bit' then 1560 when 'guid' then 2950 else null end"}, // map duckdb_oid to pg_oid. 
If no corresponding type, return null - - {"pg_catalog", "pg_has_role", "(user, role, privilege) AS true, (role, privilege) AS true"}, //boolean //does current/named user have privilege for role - - {"pg_catalog", "col_description", "(table_oid, column_number) AS NULL"}, // get comment for a table column - {"pg_catalog", "obj_description", "(object_oid, catalog_name) AS NULL"}, // get comment for a database object - {"pg_catalog", "shobj_description", "(object_oid, catalog_name) AS NULL"}, // get comment for a shared database object - - // visibility functions - {"pg_catalog", "pg_collation_is_visible", "(collation_oid) AS true"}, - {"pg_catalog", "pg_conversion_is_visible", "(conversion_oid) AS true"}, - {"pg_catalog", "pg_function_is_visible", "(function_oid) AS true"}, - {"pg_catalog", "pg_opclass_is_visible", "(opclass_oid) AS true"}, - {"pg_catalog", "pg_operator_is_visible", "(operator_oid) AS true"}, - {"pg_catalog", "pg_opfamily_is_visible", "(opclass_oid) AS true"}, - {"pg_catalog", "pg_table_is_visible", "(table_oid) AS true"}, - {"pg_catalog", "pg_ts_config_is_visible", "(config_oid) AS true"}, - {"pg_catalog", "pg_ts_dict_is_visible", "(dict_oid) AS true"}, - {"pg_catalog", "pg_ts_parser_is_visible", "(parser_oid) AS true"}, - {"pg_catalog", "pg_ts_template_is_visible", "(template_oid) AS true"}, - {"pg_catalog", "pg_type_is_visible", "(type_oid) AS true"}, - - {"pg_catalog", "pg_size_pretty", "(bytes) AS format_bytes(bytes)"}, - {"pg_catalog", "pg_sleep", "(seconds) AS sleep_ms(CAST(seconds * 1000 AS BIGINT))"}, - - {DEFAULT_SCHEMA, "round_even", "(x, n) AS CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, - {DEFAULT_SCHEMA, "roundbankers", "(x, n) AS round_even(x, n)"}, - {DEFAULT_SCHEMA, "nullif", "(a, b) AS CASE WHEN a=b THEN NULL ELSE a END"}, - {DEFAULT_SCHEMA, "list_append", "(l, e) AS list_concat(l, list_value(e))"}, - {DEFAULT_SCHEMA, "array_append", "(arr, el) AS list_append(arr, el)"}, - 
{DEFAULT_SCHEMA, "list_prepend", "(e, l) AS list_concat(list_value(e), l)"}, - {DEFAULT_SCHEMA, "array_prepend", "(el, arr) AS list_prepend(el, arr)"}, - {DEFAULT_SCHEMA, "array_pop_back", "(arr) AS arr[:LEN(arr)-1]"}, - {DEFAULT_SCHEMA, "array_pop_front", "(arr) AS arr[2:]"}, - {DEFAULT_SCHEMA, "array_push_back", "(arr, e) AS list_concat(arr, list_value(e))"}, - {DEFAULT_SCHEMA, "array_push_front", "(arr, e) AS list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", "(arr, sep) AS case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, - // Test default parameters - {DEFAULT_SCHEMA, "array_to_string_comma_default", "(arr, sep := ',') AS case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, - - {DEFAULT_SCHEMA, "generate_subscripts", "(arr, dim) AS unnest(generate_series(1, array_length(arr, dim)))"}, - {DEFAULT_SCHEMA, "fdiv", "(x, y) AS floor(x/y)"}, - {DEFAULT_SCHEMA, "fmod", "(x, y) AS (x-y*floor(x/y))"}, - {DEFAULT_SCHEMA, "split_part", "(string, delimiter, \"position\") AS if(string IS NOT NULL AND delimiter IS NOT NULL AND position IS NOT NULL, coalesce(string_split(string, delimiter)[position],''), NULL)"}, - {DEFAULT_SCHEMA, "geomean", "(x) AS exp(avg(ln(x)))"}, - {DEFAULT_SCHEMA, "geometric_mean", "(x) AS geomean(x)"}, - - {DEFAULT_SCHEMA, "weighted_avg", "(value, weight) AS SUM(value * weight) / SUM(CASE WHEN value IS NOT NULL THEN weight ELSE 0 END)"}, - {DEFAULT_SCHEMA, "wavg", "(value, weight) AS weighted_avg(value, weight)"}, - - {DEFAULT_SCHEMA, "list_reverse", "(l) AS l[:-:-1]"}, - {DEFAULT_SCHEMA, "array_reverse", "(l) AS list_reverse(l)"}, - - // algebraic list aggregates - {DEFAULT_SCHEMA, "list_avg", "(l) AS list_aggr(l, 'avg')"}, - {DEFAULT_SCHEMA, "list_var_samp", "(l) AS list_aggr(l, 'var_samp')"}, - {DEFAULT_SCHEMA, "list_var_pop", "(l) AS list_aggr(l, 'var_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_pop", "(l) AS list_aggr(l, 'stddev_pop')"}, 
- {DEFAULT_SCHEMA, "list_stddev_samp", "(l) AS list_aggr(l, 'stddev_samp')"}, - {DEFAULT_SCHEMA, "list_sem", "(l) AS list_aggr(l, 'sem')"}, - - // distributive list aggregates - {DEFAULT_SCHEMA, "list_approx_count_distinct", "(l) AS list_aggr(l, 'approx_count_distinct')"}, - {DEFAULT_SCHEMA, "list_bit_xor", "(l) AS list_aggr(l, 'bit_xor')"}, - {DEFAULT_SCHEMA, "list_bit_or", "(l) AS list_aggr(l, 'bit_or')"}, - {DEFAULT_SCHEMA, "list_bit_and", "(l) AS list_aggr(l, 'bit_and')"}, - {DEFAULT_SCHEMA, "list_bool_and", "(l) AS list_aggr(l, 'bool_and')"}, - {DEFAULT_SCHEMA, "list_bool_or", "(l) AS list_aggr(l, 'bool_or')"}, - {DEFAULT_SCHEMA, "list_count", "(l) AS list_aggr(l, 'count')"}, - {DEFAULT_SCHEMA, "list_entropy", "(l) AS list_aggr(l, 'entropy')"}, - {DEFAULT_SCHEMA, "list_last", "(l) AS list_aggr(l, 'last')"}, - {DEFAULT_SCHEMA, "list_first", "(l) AS list_aggr(l, 'first')"}, - {DEFAULT_SCHEMA, "list_any_value", "(l) AS list_aggr(l, 'any_value')"}, - {DEFAULT_SCHEMA, "list_kurtosis", "(l) AS list_aggr(l, 'kurtosis')"}, - {DEFAULT_SCHEMA, "list_kurtosis_pop", "(l) AS list_aggr(l, 'kurtosis_pop')"}, - {DEFAULT_SCHEMA, "list_min", "(l) AS list_aggr(l, 'min')"}, - {DEFAULT_SCHEMA, "list_max", "(l) AS list_aggr(l, 'max')"}, - {DEFAULT_SCHEMA, "list_product", "(l) AS list_aggr(l, 'product')"}, - {DEFAULT_SCHEMA, "list_skewness", "(l) AS list_aggr(l, 'skewness')"}, - {DEFAULT_SCHEMA, "list_sum", "(l) AS list_aggr(l, 'sum')"}, - {DEFAULT_SCHEMA, "list_string_agg", "(l) AS list_aggr(l, 'string_agg')"}, - - // holistic list aggregates - {DEFAULT_SCHEMA, "list_mode", "(l) AS list_aggr(l, 'mode')"}, - {DEFAULT_SCHEMA, "list_median", "(l) AS list_aggr(l, 'median')"}, - {DEFAULT_SCHEMA, "list_mad", "(l) AS list_aggr(l, 'mad')"}, - - // nested list aggregates - {DEFAULT_SCHEMA, "list_histogram", "(l) AS list_aggr(l, 'histogram')"}, - - // map functions - {DEFAULT_SCHEMA, "map_contains_entry", "(map, key, value) AS contains(map_entries(map), {'key': key, 'value': value})"}, - 
{DEFAULT_SCHEMA, "map_contains_value", "(map, value) AS contains(map_values(map), value)"}, - - // date functions - {DEFAULT_SCHEMA, "date_add", "(date, \"interval\") AS date + interval"}, - - // date functions - convenience macro for getting days in month - {DEFAULT_SCHEMA, "days_in_month", "(date) AS day(last_day(date))"}, - - // timestamptz functions - {DEFAULT_SCHEMA, "ago", "(i) AS current_timestamp - i::interval"}, - - // regexp functions - {DEFAULT_SCHEMA, "regexp_split_to_table", "(text, pattern) AS unnest(string_split_regex(text, pattern))"}, - - // storage helper functions - {DEFAULT_SCHEMA, "get_block_size", "(db_name) AS (SELECT block_size FROM pragma_database_size() WHERE database_name = db_name)"}, - - // string functions - {DEFAULT_SCHEMA, "md5_number_upper", "(param) AS ((md5_number(param)::bit::varchar)[65:])::bit::uint64"}, - {DEFAULT_SCHEMA, "md5_number_lower", "(param) AS ((md5_number(param)::bit::varchar)[:64])::bit::uint64"}, - - {nullptr, nullptr, nullptr} - }; + {DEFAULT_SCHEMA, "current_role", "() AS 'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_user", "() AS 'duckdb'"}, // user name of current execution context + {DEFAULT_SCHEMA, "current_catalog", + "() AS main.current_database()"}, // name of current database (called "catalog" in the SQL standard) + {DEFAULT_SCHEMA, "user", "() AS current_user"}, // equivalent to current_user + {DEFAULT_SCHEMA, "session_user", "() AS 'duckdb'"}, // session user name + {"pg_catalog", "inet_client_addr", "() AS NULL"}, // address of the remote connection + {"pg_catalog", "inet_client_port", "() AS NULL"}, // port of the remote connection + {"pg_catalog", "inet_server_addr", "() AS NULL"}, // address of the local connection + {"pg_catalog", "inet_server_port", "() AS NULL"}, // port of the local connection + {"pg_catalog", "pg_my_temp_schema", "() AS 0"}, // OID of session's temporary schema, or 0 if none + {"pg_catalog", "pg_is_other_temp_schema", "(schema_id) AS 
false"}, // is schema another session's temporary schema? + + {"pg_catalog", "pg_conf_load_time", "() AS current_timestamp"}, // configuration load time + {"pg_catalog", "pg_postmaster_start_time", "() AS current_timestamp"}, // server start time + + {"pg_catalog", "pg_typeof", "(expression) AS lower(typeof(expression))"}, // get the data type of any value + + {"pg_catalog", "current_database", + "() AS system.main.current_database()"}, // name of current database (called "catalog" in the SQL standard) + {"pg_catalog", "current_query", + "() AS system.main.current_query()"}, // the currently executing query (NULL if not inside a plpgsql function) + {"pg_catalog", "current_schema", "() AS system.main.current_schema()"}, // name of current schema + {"pg_catalog", "current_schemas", + "(include_implicit) AS system.main.current_schemas(include_implicit)"}, // names of schemas in search path + + // privilege functions + {"pg_catalog", "has_any_column_privilege", + "(\"table\", privilege) AS true, (user, \"table\", privilege) AS true"}, // boolean //does current/named user have + // privilege for any column of table + {"pg_catalog", "has_column_privilege", + "(\"table\", \"column\", privilege) AS true, (user, \"table\", \"column\", privilege) AS true"}, // boolean //does + // current/named + // user have + // privilege for + // column + {"pg_catalog", "has_database_privilege", + "(database, privilege) AS true, (user, database, privilege) AS true"}, // boolean //does current/named user have + // privilege for database + {"pg_catalog", "has_foreign_data_wrapper_privilege", + "(fdw, privilege) AS true, (user, fdw, privilege) AS true"}, // boolean //does current/named user have privilege + // for foreign-data wrapper + {"pg_catalog", "has_function_privilege", + "(function, privilege) AS true, (user, function, privilege) AS true"}, // boolean //does current/named user have + // privilege for function + {"pg_catalog", "has_language_privilege", + "(language, privilege) AS true, 
(user, language, privilege) AS true"}, // boolean //does current/named user have + // privilege for language + {"pg_catalog", "has_schema_privilege", + "(schema, privilege) AS true, (user, schema, privilege) AS true"}, // boolean //does current/named user have + // privilege for schema + {"pg_catalog", "has_sequence_privilege", + "(sequence, privilege) AS true, (user, sequence, privilege) AS true"}, // boolean //does current/named user have + // privilege for sequence + {"pg_catalog", "has_server_privilege", + "(server, privilege) AS true, (user, server, privilege) AS true"}, // boolean //does current/named user have + // privilege for foreign server + {"pg_catalog", "has_table_privilege", + "(\"table\", privilege) AS true, (user, \"table\", privilege) AS true"}, // boolean //does current/named user have + // privilege for table + {"pg_catalog", "has_tablespace_privilege", + "(tablespace, privilege) AS true, (user, tablespace, privilege) AS true"}, // boolean //does current/named user + // have privilege for tablespace + + // various postgres system functions + {"pg_catalog", "pg_get_viewdef", "(oid) AS (select sql from duckdb_views() v where v.view_oid=oid)"}, + {"pg_catalog", "pg_get_constraintdef", + "(constraint_oid) AS (select constraint_text from duckdb_constraints() d_constraint where " + "d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000), " + "(constraint_oid, pretty_bool) AS pg_get_constraintdef(constraint_oid)"}, + {"pg_catalog", "pg_get_expr", "(pg_node_tree, relation_oid) AS pg_node_tree"}, + {"pg_catalog", "format_pg_type", + "(logical_type, type_name) AS case upper(logical_type) when 'FLOAT' then 'float4' when 'DOUBLE' then 'float8' " + "when 'DECIMAL' then 'numeric' when 'ENUM' then lower(type_name) when 'VARCHAR' then 'varchar' when 'BLOB' then " + "'bytea' when 'TIMESTAMP' then 'timestamp' when 'TIME' then 'time' when 'TIMESTAMP WITH TIME ZONE' then " + "'timestamptz' when 'TIME WITH TIME ZONE' 
then 'timetz' when 'SMALLINT' then 'int2' when 'INTEGER' then 'int4' " + "when 'BIGINT' then 'int8' when 'BOOLEAN' then 'bool' else lower(logical_type) end"}, + {"pg_catalog", "format_type", + "(type_oid, typemod) AS (select format_pg_type(logical_type, type_name) from duckdb_types() t where " + "t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, + {"pg_catalog", "map_to_pg_oid", + "(type_name) AS case type_name when 'bool' then 16 when 'int16' then 21 when 'int' then 23 when 'bigint' then 20 " + "when 'date' then 1082 when 'time' then 1083 when 'datetime' then 1114 when 'dec' then 1700 when 'float' then 700 " + "when 'double' then 701 when 'bpchar' then 1043 when 'binary' then 17 when 'interval' then 1186 when " + "'timestamptz' then 1184 when 'timestamp with time zone' then 1184 when 'timetz' then 1266 when 'time with time " + "zone' then 1266 when 'bit' then 1560 when 'guid' then 2950 else null end"}, // map duckdb_oid to pg_oid. 
If no + // corresponding type, return null + + {"pg_catalog", "pg_has_role", + "(user, role, privilege) AS true, (role, privilege) AS true"}, // boolean //does current/named user have privilege + // for role + + {"pg_catalog", "col_description", "(table_oid, column_number) AS NULL"}, // get comment for a table column + {"pg_catalog", "obj_description", "(object_oid, catalog_name) AS NULL"}, // get comment for a database object + {"pg_catalog", "shobj_description", + "(object_oid, catalog_name) AS NULL"}, // get comment for a shared database object + + // visibility functions + {"pg_catalog", "pg_collation_is_visible", "(collation_oid) AS true"}, + {"pg_catalog", "pg_conversion_is_visible", "(conversion_oid) AS true"}, + {"pg_catalog", "pg_function_is_visible", "(function_oid) AS true"}, + {"pg_catalog", "pg_opclass_is_visible", "(opclass_oid) AS true"}, + {"pg_catalog", "pg_operator_is_visible", "(operator_oid) AS true"}, + {"pg_catalog", "pg_opfamily_is_visible", "(opclass_oid) AS true"}, + {"pg_catalog", "pg_table_is_visible", "(table_oid) AS true"}, + {"pg_catalog", "pg_ts_config_is_visible", "(config_oid) AS true"}, + {"pg_catalog", "pg_ts_dict_is_visible", "(dict_oid) AS true"}, + {"pg_catalog", "pg_ts_parser_is_visible", "(parser_oid) AS true"}, + {"pg_catalog", "pg_ts_template_is_visible", "(template_oid) AS true"}, + {"pg_catalog", "pg_type_is_visible", "(type_oid) AS true"}, + + {"pg_catalog", "pg_size_pretty", "(bytes) AS format_bytes(bytes)"}, + {"pg_catalog", "pg_sleep", "(seconds) AS sleep_ms(CAST(seconds * 1000 AS BIGINT))"}, + + {DEFAULT_SCHEMA, "round_even", + "(x, n) AS CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, + {DEFAULT_SCHEMA, "roundbankers", "(x, n) AS round_even(x, n)"}, + {DEFAULT_SCHEMA, "nullif", "(a, b) AS CASE WHEN a=b THEN NULL ELSE a END"}, + {DEFAULT_SCHEMA, "list_append", "(l, e) AS list_concat(l, list_value(e))"}, + {DEFAULT_SCHEMA, "array_append", "(arr, el) AS list_append(arr, 
el)"}, + {DEFAULT_SCHEMA, "list_prepend", "(e, l) AS list_concat(list_value(e), l)"}, + {DEFAULT_SCHEMA, "array_prepend", "(el, arr) AS list_prepend(el, arr)"}, + {DEFAULT_SCHEMA, "array_pop_back", "(arr) AS arr[:LEN(arr)-1]"}, + {DEFAULT_SCHEMA, "array_pop_front", "(arr) AS arr[2:]"}, + {DEFAULT_SCHEMA, "array_push_back", "(arr, e) AS list_concat(arr, list_value(e))"}, + {DEFAULT_SCHEMA, "array_push_front", "(arr, e) AS list_concat(list_value(e), arr)"}, + {DEFAULT_SCHEMA, "array_to_string", + "(arr, sep) AS case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) end"}, + // Test default parameters + {DEFAULT_SCHEMA, "array_to_string_comma_default", + "(arr, sep := ',') AS case len(arr::varchar[]) when 0 then '' else list_aggr(arr::varchar[], 'string_agg', sep) " + "end"}, + + {DEFAULT_SCHEMA, "generate_subscripts", "(arr, dim) AS unnest(generate_series(1, array_length(arr, dim)))"}, + {DEFAULT_SCHEMA, "fdiv", "(x, y) AS floor(x/y)"}, + {DEFAULT_SCHEMA, "fmod", "(x, y) AS (x-y*floor(x/y))"}, + {DEFAULT_SCHEMA, "split_part", + "(string, delimiter, \"position\") AS if(string IS NOT NULL AND delimiter IS NOT NULL AND position IS NOT NULL, " + "coalesce(string_split(string, delimiter)[position],''), NULL)"}, + {DEFAULT_SCHEMA, "geomean", "(x) AS exp(avg(ln(x)))"}, + {DEFAULT_SCHEMA, "geometric_mean", "(x) AS geomean(x)"}, + + {DEFAULT_SCHEMA, "weighted_avg", + "(value, weight) AS SUM(value * weight) / SUM(CASE WHEN value IS NOT NULL THEN weight ELSE 0 END)"}, + {DEFAULT_SCHEMA, "wavg", "(value, weight) AS weighted_avg(value, weight)"}, + + {DEFAULT_SCHEMA, "list_reverse", "(l) AS l[:-:-1]"}, + {DEFAULT_SCHEMA, "array_reverse", "(l) AS list_reverse(l)"}, + + // algebraic list aggregates + {DEFAULT_SCHEMA, "list_avg", "(l) AS list_aggr(l, 'avg')"}, + {DEFAULT_SCHEMA, "list_var_samp", "(l) AS list_aggr(l, 'var_samp')"}, + {DEFAULT_SCHEMA, "list_var_pop", "(l) AS list_aggr(l, 'var_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_pop", "(l) AS 
list_aggr(l, 'stddev_pop')"}, + {DEFAULT_SCHEMA, "list_stddev_samp", "(l) AS list_aggr(l, 'stddev_samp')"}, + {DEFAULT_SCHEMA, "list_sem", "(l) AS list_aggr(l, 'sem')"}, + + // distributive list aggregates + {DEFAULT_SCHEMA, "list_approx_count_distinct", "(l) AS list_aggr(l, 'approx_count_distinct')"}, + {DEFAULT_SCHEMA, "list_bit_xor", "(l) AS list_aggr(l, 'bit_xor')"}, + {DEFAULT_SCHEMA, "list_bit_or", "(l) AS list_aggr(l, 'bit_or')"}, + {DEFAULT_SCHEMA, "list_bit_and", "(l) AS list_aggr(l, 'bit_and')"}, + {DEFAULT_SCHEMA, "list_bool_and", "(l) AS list_aggr(l, 'bool_and')"}, + {DEFAULT_SCHEMA, "list_bool_or", "(l) AS list_aggr(l, 'bool_or')"}, + {DEFAULT_SCHEMA, "list_count", "(l) AS list_aggr(l, 'count')"}, + {DEFAULT_SCHEMA, "list_entropy", "(l) AS list_aggr(l, 'entropy')"}, + {DEFAULT_SCHEMA, "list_last", "(l) AS list_aggr(l, 'last')"}, + {DEFAULT_SCHEMA, "list_first", "(l) AS list_aggr(l, 'first')"}, + {DEFAULT_SCHEMA, "list_any_value", "(l) AS list_aggr(l, 'any_value')"}, + {DEFAULT_SCHEMA, "list_kurtosis", "(l) AS list_aggr(l, 'kurtosis')"}, + {DEFAULT_SCHEMA, "list_kurtosis_pop", "(l) AS list_aggr(l, 'kurtosis_pop')"}, + {DEFAULT_SCHEMA, "list_min", "(l) AS list_aggr(l, 'min')"}, + {DEFAULT_SCHEMA, "list_max", "(l) AS list_aggr(l, 'max')"}, + {DEFAULT_SCHEMA, "list_product", "(l) AS list_aggr(l, 'product')"}, + {DEFAULT_SCHEMA, "list_skewness", "(l) AS list_aggr(l, 'skewness')"}, + {DEFAULT_SCHEMA, "list_sum", "(l) AS list_aggr(l, 'sum')"}, + {DEFAULT_SCHEMA, "list_string_agg", "(l) AS list_aggr(l, 'string_agg')"}, + + // holistic list aggregates + {DEFAULT_SCHEMA, "list_mode", "(l) AS list_aggr(l, 'mode')"}, + {DEFAULT_SCHEMA, "list_median", "(l) AS list_aggr(l, 'median')"}, + {DEFAULT_SCHEMA, "list_mad", "(l) AS list_aggr(l, 'mad')"}, + + // nested list aggregates + {DEFAULT_SCHEMA, "list_histogram", "(l) AS list_aggr(l, 'histogram')"}, + + // map functions + {DEFAULT_SCHEMA, "map_contains_entry", + "(map, key, value) AS contains(map_entries(map), 
{'key': key, 'value': value})"}, + {DEFAULT_SCHEMA, "map_contains_value", "(map, value) AS contains(map_values(map), value)"}, + + // date functions + {DEFAULT_SCHEMA, "date_add", "(date, \"interval\") AS date + interval"}, + + // date functions - convenience macro for getting days in month + {DEFAULT_SCHEMA, "days_in_month", "(date) AS day(last_day(date))"}, + + // timestamptz functions + {DEFAULT_SCHEMA, "ago", "(i) AS current_timestamp - i::interval"}, + + // regexp functions + {DEFAULT_SCHEMA, "regexp_split_to_table", "(text, pattern) AS unnest(string_split_regex(text, pattern))"}, + + // storage helper functions + {DEFAULT_SCHEMA, "get_block_size", + "(db_name) AS (SELECT block_size FROM pragma_database_size() WHERE database_name = db_name)"}, + + // string functions + {DEFAULT_SCHEMA, "md5_number_upper", "(param) AS ((md5_number(param)::bit::varchar)[65:])::bit::uint64"}, + {DEFAULT_SCHEMA, "md5_number_lower", "(param) AS ((md5_number(param)::bit::varchar)[:64])::bit::uint64"}, + + {nullptr, nullptr, nullptr}}; unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(const DefaultMacro &default_macro) { auto bind_info = make_uniq(CatalogType::MACRO_ENTRY); @@ -187,20 +243,20 @@ unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(co } } bind_info->macros = std::move(macro_info.macros); - bind_info->schema = default_macro.schema; - bind_info->name = default_macro.name; + bind_info->schema = Identifier(default_macro.schema); + bind_info->name = Identifier(default_macro.name); bind_info->temporary = true; bind_info->internal = true; return bind_info; } -static bool DefaultFunctionMatches(const DefaultMacro ¯o, const string &schema, const string &name) { +static bool DefaultFunctionMatches(const DefaultMacro ¯o, const Identifier &schema, const Identifier &name) { return macro.schema == schema && macro.name == name; } -static unique_ptr GetDefaultFunction(const string &input_schema, const string &input_name) { - auto schema = 
StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); +static unique_ptr GetDefaultFunction(const Identifier &input_schema, const Identifier &input_name) { + auto &schema = input_schema; + auto &name = input_name; for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { if (DefaultFunctionMatches(internal_macros[index], schema, name)) { return DefaultFunctionGenerator::CreateInternalMacroInfo(internal_macros[index]); @@ -214,7 +270,7 @@ DefaultFunctionGenerator::DefaultFunctionGenerator(Catalog &catalog, SchemaCatal } unique_ptr DefaultFunctionGenerator::CreateDefaultEntry(ClientContext &context, - const string &entry_name) { + const Identifier &entry_name) { auto info = GetDefaultFunction(schema.name, entry_name); if (info) { return make_uniq_base(catalog, schema, info->Cast()); @@ -222,8 +278,8 @@ unique_ptr DefaultFunctionGenerator::CreateDefaultEntry(ClientCont return nullptr; } -vector DefaultFunctionGenerator::GetDefaultEntries() { - vector result; +vector DefaultFunctionGenerator::GetDefaultEntries() { + vector result; for (idx_t index = 0; internal_macros[index].name != nullptr; index++) { if (StringUtil::Lower(internal_macros[index].name) != internal_macros[index].name) { throw InternalException("Default macro name %s should be lowercase", internal_macros[index].name); diff --git a/src/duckdb/src/catalog/default/default_generator.cpp b/src/duckdb/src/catalog/default/default_generator.cpp index 2fbb2b646..265ac34c0 100644 --- a/src/duckdb/src/catalog/default/default_generator.cpp +++ b/src/duckdb/src/catalog/default/default_generator.cpp @@ -8,12 +8,12 @@ DefaultGenerator::DefaultGenerator(Catalog &catalog) : catalog(catalog), created DefaultGenerator::~DefaultGenerator() { } -unique_ptr DefaultGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { +unique_ptr DefaultGenerator::CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) { throw 
InternalException("CreateDefaultEntry with ClientContext called but not supported in this generator"); } unique_ptr DefaultGenerator::CreateDefaultEntry(CatalogTransaction transaction, - const string &entry_name) { + const Identifier &entry_name) { if (!transaction.context) { // no context - cannot create default entry return nullptr; diff --git a/src/duckdb/src/catalog/default/default_schemas.cpp b/src/duckdb/src/catalog/default/default_schemas.cpp index 64aaf56d2..b1f66d9a8 100644 --- a/src/duckdb/src/catalog/default/default_schemas.cpp +++ b/src/duckdb/src/catalog/default/default_schemas.cpp @@ -11,10 +11,9 @@ struct DefaultSchema { static const DefaultSchema internal_schemas[] = {{"information_schema"}, {"pg_catalog"}, {nullptr}}; -bool DefaultSchemaGenerator::IsDefaultSchema(const string &input_schema) { - auto schema = StringUtil::Lower(input_schema); +bool DefaultSchemaGenerator::IsDefaultSchema(const Identifier &input_schema) { for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { - if (internal_schemas[index].name == schema) { + if (internal_schemas[index].name == input_schema) { return true; } } @@ -25,18 +24,18 @@ DefaultSchemaGenerator::DefaultSchemaGenerator(Catalog &catalog) : DefaultGenera } unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(CatalogTransaction transaction, - const string &entry_name) { + const Identifier &entry_name) { if (IsDefaultSchema(entry_name)) { CreateSchemaInfo info; - info.schema = StringUtil::Lower(entry_name); + info.schema = Identifier(StringUtil::Lower(entry_name.GetIdentifierName())); info.internal = true; return make_uniq_base(catalog, info); } return nullptr; } -vector DefaultSchemaGenerator::GetDefaultEntries() { - vector result; +vector DefaultSchemaGenerator::GetDefaultEntries() { + vector result; for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { result.emplace_back(internal_schemas[index].name); } diff --git 
a/src/duckdb/src/catalog/default/default_table_functions.cpp b/src/duckdb/src/catalog/default/default_table_functions.cpp index 0351e8ce3..e2bc4f584 100644 --- a/src/duckdb/src/catalog/default/default_table_functions.cpp +++ b/src/duckdb/src/catalog/default/default_table_functions.cpp @@ -79,7 +79,7 @@ SELECT * EXCLUDE(input_type, scope, aliases) 'profiling_coverage', 'profiling_output', 'profiling_mode', - 'custom_profiling_settings' + 'tracked_metrics' ); )"}, {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} @@ -103,13 +103,13 @@ DefaultTableFunctionGenerator::CreateInternalTableMacroInfo(const DefaultTableMa throw InternalException("Expected a single expression"); } function->parameters.push_back(make_uniq(named_param.name)); - function->default_parameters.insert(make_pair(named_param.name, std::move(expr_list[0]))); + function->default_parameters.insert(Identifier(named_param.name), std::move(expr_list[0])); } auto type = CatalogType::TABLE_MACRO_ENTRY; auto bind_info = make_uniq(type); - bind_info->schema = default_macro.schema; - bind_info->name = default_macro.name; + bind_info->schema = Identifier(default_macro.schema); + bind_info->name = Identifier(default_macro.name); bind_info->temporary = true; bind_info->internal = true; bind_info->macros.push_back(std::move(function)); @@ -129,11 +129,10 @@ DefaultTableFunctionGenerator::CreateTableMacroInfo(const DefaultTableMacro &def return CreateInternalTableMacroInfo(default_macro, std::move(result)); } -static unique_ptr GetDefaultTableFunction(const string &input_schema, const string &input_name) { - auto schema = StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); +static unique_ptr GetDefaultTableFunction(const Identifier &input_schema, + const Identifier &input_name) { for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { - if (internal_table_macros[index].schema == schema && internal_table_macros[index].name == name) { + if 
(internal_table_macros[index].schema == input_schema && internal_table_macros[index].name == input_name) { return DefaultTableFunctionGenerator::CreateTableMacroInfo(internal_table_macros[index]); } } @@ -141,7 +140,7 @@ static unique_ptr GetDefaultTableFunction(const string &inpu } unique_ptr DefaultTableFunctionGenerator::CreateDefaultEntry(ClientContext &context, - const string &entry_name) { + const Identifier &entry_name) { auto info = GetDefaultTableFunction(schema.name, entry_name); if (info) { return make_uniq_base(catalog, schema, info->Cast()); @@ -149,8 +148,8 @@ unique_ptr DefaultTableFunctionGenerator::CreateDefaultEntry(Clien return nullptr; } -vector DefaultTableFunctionGenerator::GetDefaultEntries() { - vector result; +vector DefaultTableFunctionGenerator::GetDefaultEntries() { + vector result; for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { if (StringUtil::Lower(internal_table_macros[index].name) != internal_table_macros[index].name) { throw InternalException("Default macro name %s should be lowercase", internal_table_macros[index].name); diff --git a/src/duckdb/src/catalog/default/default_types.cpp b/src/duckdb/src/catalog/default/default_types.cpp index a59be63de..16b60a0f9 100644 --- a/src/duckdb/src/catalog/default/default_types.cpp +++ b/src/duckdb/src/catalog/default/default_types.cpp @@ -47,7 +47,7 @@ LogicalType BindDecimalType(BindLogicalTypeInput &input) { if (scale_value.DefaultTryCastAs(LogicalTypeId::UTINYINT)) { scale = scale_value.GetValueUnsafe(); } else { - throw BinderException("DECIMAL type scale must be between 0 and %d", Decimal::MAX_WIDTH_DECIMAL - 1); + throw BinderException("DECIMAL type scale must be between 0 and %d", Decimal::MAX_WIDTH_DECIMAL); } } @@ -301,14 +301,14 @@ LogicalType BindStructType(BindLogicalTypeInput &input) { // Named struct case D_ASSERT(all_name); child_list_t children; - case_insensitive_set_t name_collision_set; + identifier_set_t name_collision_set; for (auto &arg : 
arguments) { auto &child_name = arg.GetName(); - if (name_collision_set.find(child_name) != name_collision_set.end()) { + if (name_collision_set.find(Identifier(child_name)) != name_collision_set.end()) { throw BinderException("Duplicate STRUCT type argument name \"%s\"", child_name); } - name_collision_set.insert(child_name); + name_collision_set.insert(Identifier(child_name)); children.emplace_back(child_name, TypeValue::GetType(arg.GetValue())); } @@ -359,7 +359,7 @@ LogicalType BindUnionType(BindLogicalTypeInput &input) { } child_list_t children; - case_insensitive_set_t name_collision_set; + identifier_set_t name_collision_set; for (auto &arg : arguments) { if (!arg.HasName()) { @@ -375,11 +375,11 @@ LogicalType BindUnionType(BindLogicalTypeInput &input) { auto &entry_name = arg.GetName(); auto entry_type = TypeValue::GetType(arg.GetValue()); - if (name_collision_set.find(entry_name) != name_collision_set.end()) { + if (name_collision_set.find(Identifier(entry_name)) != name_collision_set.end()) { throw BinderException("Duplicate UNION type member name \"%s\"", entry_name); } - name_collision_set.insert(entry_name); + name_collision_set.insert(Identifier(entry_name)); children.emplace_back(entry_name, entry_type); } @@ -546,10 +546,10 @@ const builtin_type_array BUILTIN_TYPES = {{{"decimal", LogicalTypeId::DECIMAL, B {"geometry", LogicalTypeId::GEOMETRY, BindGeometryType}, {"type", LogicalTypeId::TYPE, nullptr}}}; -optional_ptr TryGetDefaultTypeEntry(const string &name) { +optional_ptr TryGetDefaultTypeEntry(const Identifier &name) { auto &internal_types = BUILTIN_TYPES; for (auto &type : internal_types) { - if (StringUtil::CIEquals(name, type.name)) { + if (name == type.name) { return &type; } } @@ -561,10 +561,10 @@ optional_ptr TryGetDefaultTypeEntry(const string &name) { //---------------------------------------------------------------------------------------------------------------------- // Default Type Generator 
//---------------------------------------------------------------------------------------------------------------------- -LogicalTypeId DefaultTypeGenerator::GetDefaultType(const string &name) { +LogicalTypeId DefaultTypeGenerator::GetDefaultType(const Identifier &name) { auto &internal_types = BUILTIN_TYPES; for (auto &type : internal_types) { - if (StringUtil::CIEquals(name, type.name)) { + if (name == type.name) { return type.type; } } @@ -572,7 +572,7 @@ LogicalTypeId DefaultTypeGenerator::GetDefaultType(const string &name) { } LogicalType DefaultTypeGenerator::TryDefaultBind(const string &name, const vector> ¶ms) { - auto entry = TryGetDefaultTypeEntry(name); + auto entry = TryGetDefaultTypeEntry(Identifier(name)); if (!entry) { return LogicalTypeId::INVALID; } @@ -598,7 +598,8 @@ DefaultTypeGenerator::DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry : DefaultGenerator(catalog), schema(schema) { } -unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { +unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext &context, + const Identifier &entry_name) { if (schema.name != DEFAULT_SCHEMA) { return nullptr; } @@ -615,8 +616,8 @@ unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext return make_uniq_base(catalog, schema, info); } -vector DefaultTypeGenerator::GetDefaultEntries() { - vector result; +vector DefaultTypeGenerator::GetDefaultEntries() { + vector result; if (schema.name != DEFAULT_SCHEMA) { return result; } diff --git a/src/duckdb/src/catalog/default/default_views.cpp b/src/duckdb/src/catalog/default/default_views.cpp index ff22352b8..6323beb10 100644 --- a/src/duckdb/src/catalog/default/default_views.cpp +++ b/src/duckdb/src/catalog/default/default_views.cpp @@ -13,8 +13,13 @@ struct DefaultView { }; static const DefaultView internal_views[] = { - {DEFAULT_SCHEMA, "pragma_database_list", "SELECT database_oid AS seq, database_name AS name, path AS file FROM duckdb_databases() 
WHERE NOT internal ORDER BY 1"}, - {DEFAULT_SCHEMA, "sqlite_master", "select 'table' \"type\", table_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_tables union all select 'view' \"type\", view_name \"name\", view_name \"tbl_name\", 0 rootpage, sql from duckdb_views union all select 'index' \"type\", index_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_indexes;"}, + {DEFAULT_SCHEMA, "pragma_database_list", + "SELECT database_oid AS seq, database_name AS name, path AS file FROM duckdb_databases() WHERE NOT internal ORDER " + "BY 1"}, + {DEFAULT_SCHEMA, "sqlite_master", + "select 'table' \"type\", table_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_tables union " + "all select 'view' \"type\", view_name \"name\", view_name \"tbl_name\", 0 rootpage, sql from duckdb_views union " + "all select 'index' \"type\", index_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_indexes;"}, {DEFAULT_SCHEMA, "sqlite_schema", "SELECT * FROM sqlite_master"}, {DEFAULT_SCHEMA, "sqlite_temp_master", "SELECT * FROM sqlite_master"}, {DEFAULT_SCHEMA, "sqlite_temp_schema", "SELECT * FROM sqlite_master"}, @@ -26,50 +31,204 @@ static const DefaultView internal_views[] = { {DEFAULT_SCHEMA, "duckdb_tables", "SELECT * FROM duckdb_tables() WHERE NOT internal"}, {DEFAULT_SCHEMA, "duckdb_types", "SELECT * FROM duckdb_types()"}, {DEFAULT_SCHEMA, "duckdb_views", "SELECT * FROM duckdb_views() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_logs", "SELECT * FROM duckdb_logs(denormalized_table=true)"}, + {DEFAULT_SCHEMA, "duckdb_logs", "SELECT * FROM duckdb_logs(denormalized_table=true)"}, {"pg_catalog", "pg_am", "SELECT 0 oid, 'art' amname, NULL amhandler, 'i' amtype"}, - {"pg_catalog", "pg_prepared_statements", "SELECT name, statement, NULL prepare_time, parameter_types, result_types, NULL from_sql, NULL generic_plans, NULL custom_plans from duckdb_prepared_statements()"}, - {"pg_catalog", "pg_attribute", "SELECT 
table_oid attrelid, column_name attname, data_type_id atttypid, 0 attstattarget, NULL attlen, column_index attnum, 0 attndims, -1 attcacheoff, case when data_type ilike '%decimal%' then numeric_precision*1000+numeric_scale else -1 end atttypmod, false attbyval, NULL attstorage, NULL attalign, NOT is_nullable attnotnull, column_default IS NOT NULL atthasdef, false atthasmissing, '' attidentity, '' attgenerated, false attisdropped, true attislocal, 0 attinhcount, 0 attcollation, NULL attcompression, NULL attacl, NULL attoptions, NULL attfdwoptions, NULL attmissingval FROM duckdb_columns()"}, - {"pg_catalog", "pg_attrdef", "SELECT column_index oid, table_oid adrelid, column_index adnum, column_default adbin from duckdb_columns() where column_default is not null;"}, - {"pg_catalog", "pg_class", "SELECT table_oid oid, table_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, estimated_size::real reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, index_count > 0 relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'r' relkind, column_count relnatts, check_constraint_count relchecks, false relhasoids, has_primary_key relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_tables() UNION ALL SELECT view_oid oid, view_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'v' relkind, column_count relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, 
false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_views() UNION ALL SELECT sequence_oid oid, sequence_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'S' relkind, 0 relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_sequences() UNION ALL SELECT index_oid oid, index_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, 't' relpersistence, 'i' relkind, NULL relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_indexes()"}, - {"pg_catalog", "pg_constraint", "SELECT table_oid*1000000+constraint_index oid, constraint_text conname, schema_oid connamespace, CASE constraint_type WHEN 'CHECK' then 'c' WHEN 'UNIQUE' then 'u' WHEN 'PRIMARY KEY' THEN 'p' WHEN 'FOREIGN KEY' THEN 'f' ELSE 'x' END contype, false condeferrable, false condeferred, true convalidated, table_oid conrelid, 0 contypid, 0 conindid, 0 conparentid, 0 confrelid, NULL confupdtype, NULL confdeltype, NULL confmatchtype, true conislocal, 0 coninhcount, 
false connoinherit, constraint_column_indexes conkey, NULL confkey, NULL conpfeqop, NULL conppeqop, NULL conffeqop, NULL conexclop, expression conbin FROM duckdb_constraints()"}, + {"pg_catalog", "pg_prepared_statements", + "SELECT name, statement, NULL prepare_time, parameter_types, result_types, NULL from_sql, NULL generic_plans, " + "NULL custom_plans from duckdb_prepared_statements()"}, + {"pg_catalog", "pg_attribute", + "SELECT table_oid attrelid, column_name attname, data_type_id atttypid, 0 attstattarget, NULL attlen, " + "column_index attnum, 0 attndims, -1 attcacheoff, case when data_type ilike '%decimal%' then " + "numeric_precision*1000+numeric_scale else -1 end atttypmod, false attbyval, NULL attstorage, NULL attalign, NOT " + "is_nullable attnotnull, column_default IS NOT NULL atthasdef, false atthasmissing, '' attidentity, '' " + "attgenerated, false attisdropped, true attislocal, 0 attinhcount, 0 attcollation, NULL attcompression, NULL " + "attacl, NULL attoptions, NULL attfdwoptions, NULL attmissingval FROM duckdb_columns()"}, + {"pg_catalog", "pg_attrdef", + "SELECT column_index oid, table_oid adrelid, column_index adnum, column_default adbin from duckdb_columns() where " + "column_default is not null;"}, + {"pg_catalog", "pg_class", + "SELECT table_oid oid, table_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, " + "0 relfilenode, 0 reltablespace, 0 relpages, estimated_size::real reltuples, 0 relallvisible, 0 reltoastrelid, 0 " + "reltoastidxid, index_count > 0 relhasindex, false relisshared, case when temporary then 't' else 'p' end " + "relpersistence, 'r' relkind, column_count relnatts, check_constraint_count relchecks, false relhasoids, " + "has_primary_key relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, " + "true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, " + "NULL relacl, NULL reloptions, NULL 
relpartbound FROM duckdb_tables() UNION ALL SELECT view_oid oid, view_name " + "relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 " + "relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, " + "case when temporary then 't' else 'p' end relpersistence, 'v' relkind, column_count relnatts, 0 relchecks, false " + "relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false " + "relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL " + "relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_views() UNION ALL SELECT sequence_oid " + "oid, sequence_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, " + "0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, " + "false relisshared, case when temporary then 't' else 'p' end relpersistence, 'S' relkind, 0 relnatts, 0 " + "relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, " + "false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 " + "relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_sequences() UNION ALL " + "SELECT index_oid oid, index_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, " + "0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, " + "false relhasindex, false relisshared, 't' relpersistence, 'i' relkind, NULL relnatts, 0 relchecks, false " + "relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false " + "relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 
relfrozenxid, NULL " + "relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_indexes()"}, + {"pg_catalog", "pg_constraint", + "SELECT table_oid*1000000+constraint_index oid, constraint_text conname, schema_oid connamespace, CASE " + "constraint_type WHEN 'CHECK' then 'c' WHEN 'UNIQUE' then 'u' WHEN 'PRIMARY KEY' THEN 'p' WHEN 'FOREIGN KEY' THEN " + "'f' ELSE 'x' END contype, false condeferrable, false condeferred, true convalidated, table_oid conrelid, 0 " + "contypid, 0 conindid, 0 conparentid, 0 confrelid, NULL confupdtype, NULL confdeltype, NULL confmatchtype, true " + "conislocal, 0 coninhcount, false connoinherit, constraint_column_indexes conkey, NULL confkey, NULL conpfeqop, " + "NULL conppeqop, NULL conffeqop, NULL conexclop, expression conbin FROM duckdb_constraints()"}, {"pg_catalog", "pg_collation", "SELECT NULL::OID oid, NULL::VARCHAR collname WHERE FALSE"}, - {"pg_catalog", "pg_database", "SELECT database_oid oid, database_name datname FROM duckdb_databases()"}, + {"pg_catalog", "pg_database", + "SELECT database_oid oid, database_name datname, true::boolean datallowconn, false::boolean AS datistemplate FROM " + "duckdb_databases()"}, {"pg_catalog", "pg_depend", "SELECT * FROM duckdb_dependencies()"}, - {"pg_catalog", "pg_description", "SELECT table_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_tables() WHERE NOT internal UNION ALL SELECT table_oid AS objoid, database_oid AS classoid, column_index AS objsubid, comment AS description FROM duckdb_columns() WHERE NOT internal UNION ALL SELECT view_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_views() WHERE NOT internal UNION ALL SELECT index_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_indexes UNION ALL SELECT sequence_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_sequences() UNION ALL SELECT type_oid AS 
objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_types() WHERE NOT internal UNION ALL SELECT function_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_functions() WHERE NOT internal;"}, - {"pg_catalog", "pg_enum", "SELECT NULL oid, a.type_oid enumtypid, list_position(b.labels, a.elabel) enumsortorder, a.elabel enumlabel FROM (SELECT UNNEST(labels) elabel, type_oid FROM duckdb_types() WHERE logical_type='ENUM') a JOIN duckdb_types() b ON a.type_oid=b.type_oid;"}, - {"pg_catalog", "pg_index", "SELECT index_oid indexrelid, table_oid indrelid, 0 indnatts, 0 indnkeyatts, is_unique indisunique, is_primary indisprimary, false indisexclusion, true indimmediate, false indisclustered, true indisvalid, false indcheckxmin, true indisready, true indislive, false indisreplident, NULL::INT[] indkey, NULL::OID[] indcollation, NULL::OID[] indclass, NULL::INT[] indoption, expressions indexprs, NULL indpred FROM duckdb_indexes()"}, - {"pg_catalog", "pg_indexes", "SELECT schema_name schemaname, table_name tablename, index_name indexname, NULL \"tablespace\", sql indexdef FROM duckdb_indexes()"}, - {"pg_catalog", "pg_namespace", "SELECT oid, schema_name nspname, 0 nspowner, NULL nspacl FROM duckdb_schemas() where database_name=current_database()"}, - {"pg_catalog", "pg_proc", "SELECT f.function_oid oid, function_name proname, s.oid pronamespace, NULL proowner, NULL prolang, 0 procost, 0 prorows, varargs provariadic, 0 prosupport, CASE function_type WHEN 'aggregate' THEN 'a' ELSE 'f' END prokind, false prosecdef, false proleakproof, false proisstrict, function_type = 'table' proretset, case (stability) when 'CONSISTENT' then 'i' when 'CONSISTENT_WITHIN_QUERY' then 's' when 'VOLATILE' then 'v' end provolatile, 'u' proparallel, length(parameters) pronargs, 0 pronargdefaults, return_type prorettype, parameter_types proargtypes, NULL proallargtypes, NULL proargmodes, parameters proargnames, NULL proargdefaults, 
NULL protrftypes, NULL prosrc, NULL probin, macro_definition prosqlbody, NULL proconfig, NULL proacl, function_type = 'aggregate' proisagg, FROM duckdb_functions() f LEFT JOIN duckdb_schemas() s USING (database_name, schema_name)"}, - {"pg_catalog", "pg_sequence", "SELECT sequence_oid seqrelid, 0 seqtypid, start_value seqstart, increment_by seqincrement, max_value seqmax, min_value seqmin, 0 seqcache, cycle seqcycle FROM duckdb_sequences()"}, - {"pg_catalog", "pg_sequences", "SELECT schema_name schemaname, sequence_name sequencename, 'duckdb' sequenceowner, 0 data_type, start_value, min_value, max_value, increment_by, cycle, 0 cache_size, last_value FROM duckdb_sequences()"}, - {"pg_catalog", "pg_settings", "SELECT name, value setting, description short_desc, CASE WHEN input_type = 'VARCHAR' THEN 'string' WHEN input_type = 'BOOLEAN' THEN 'bool' WHEN input_type IN ('BIGINT', 'UBIGINT') THEN 'integer' ELSE input_type END vartype FROM duckdb_settings()"}, - {"pg_catalog", "pg_tables", "SELECT schema_name schemaname, table_name tablename, 'duckdb' tableowner, NULL \"tablespace\", index_count > 0 hasindexes, false hasrules, false hastriggers FROM duckdb_tables()"}, + {"pg_catalog", "pg_description", + "SELECT table_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_tables() " + "WHERE NOT internal UNION ALL SELECT table_oid AS objoid, database_oid AS classoid, column_index AS objsubid, " + "comment AS description FROM duckdb_columns() WHERE NOT internal UNION ALL SELECT view_oid AS objoid, " + "database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_views() WHERE NOT internal UNION ALL " + "SELECT index_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_indexes " + "UNION ALL SELECT sequence_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM " + "duckdb_sequences() UNION ALL SELECT type_oid AS objoid, database_oid AS classoid, 0 AS 
objsubid, comment AS " + "description FROM duckdb_types() WHERE NOT internal UNION ALL SELECT function_oid AS objoid, database_oid AS " + "classoid, 0 AS objsubid, comment AS description FROM duckdb_functions() WHERE NOT internal;"}, + {"pg_catalog", "pg_enum", + "SELECT NULL oid, a.type_oid enumtypid, list_position(b.labels, a.elabel) enumsortorder, a.elabel enumlabel FROM " + "(SELECT UNNEST(labels) elabel, type_oid FROM duckdb_types() WHERE logical_type='ENUM') a JOIN duckdb_types() b " + "ON a.type_oid=b.type_oid;"}, + {"pg_catalog", "pg_index", + "SELECT index_oid indexrelid, table_oid indrelid, 0 indnatts, 0 indnkeyatts, is_unique indisunique, is_primary " + "indisprimary, false indisexclusion, true indimmediate, false indisclustered, true indisvalid, false " + "indcheckxmin, true indisready, true indislive, false indisreplident, NULL::INT[] indkey, NULL::OID[] " + "indcollation, NULL::OID[] indclass, NULL::INT[] indoption, expressions indexprs, NULL indpred FROM " + "duckdb_indexes()"}, + {"pg_catalog", "pg_indexes", + "SELECT schema_name schemaname, table_name tablename, index_name indexname, NULL \"tablespace\", sql indexdef " + "FROM duckdb_indexes()"}, + {"pg_catalog", "pg_namespace", + "SELECT oid, schema_name nspname, 0 nspowner, NULL nspacl FROM duckdb_schemas() where " + "database_name=current_database()"}, + {"pg_catalog", "pg_proc", + "SELECT f.function_oid oid, function_name proname, s.oid pronamespace, NULL proowner, NULL prolang, 0 procost, 0 " + "prorows, varargs provariadic, 0 prosupport, CASE function_type WHEN 'aggregate' THEN 'a' ELSE 'f' END prokind, " + "false prosecdef, false proleakproof, false proisstrict, function_type = 'table' proretset, case (stability) " + "when 'CONSISTENT' then 'i' when 'CONSISTENT_WITHIN_QUERY' then 's' when 'VOLATILE' then 'v' end provolatile, 'u' " + "proparallel, length(parameters) pronargs, 0 pronargdefaults, return_type prorettype, parameter_types " + "proargtypes, NULL proallargtypes, NULL proargmodes, 
parameters proargnames, NULL proargdefaults, NULL " + "protrftypes, NULL prosrc, NULL probin, macro_definition prosqlbody, NULL proconfig, NULL proacl, function_type = " + "'aggregate' proisagg, FROM duckdb_functions() f LEFT JOIN duckdb_schemas() s USING (database_name, " + "schema_name)"}, + {"pg_catalog", "pg_sequence", + "SELECT sequence_oid seqrelid, 0 seqtypid, start_value seqstart, increment_by seqincrement, max_value seqmax, " + "min_value seqmin, 0 seqcache, cycle seqcycle FROM duckdb_sequences()"}, + {"pg_catalog", "pg_sequences", + "SELECT schema_name schemaname, sequence_name sequencename, 'duckdb' sequenceowner, 0 data_type, start_value, " + "min_value, max_value, increment_by, cycle, 0 cache_size, last_value FROM duckdb_sequences()"}, + {"pg_catalog", "pg_settings", + "SELECT name, value setting, description short_desc, CASE WHEN input_type = 'VARCHAR' THEN 'string' WHEN " + "input_type = 'BOOLEAN' THEN 'bool' WHEN input_type IN ('BIGINT', 'UBIGINT') THEN 'integer' ELSE input_type END " + "vartype FROM duckdb_settings()"}, + {"pg_catalog", "pg_tables", + "SELECT schema_name schemaname, table_name tablename, 'duckdb' tableowner, NULL \"tablespace\", index_count > 0 " + "hasindexes, false hasrules, false hastriggers FROM duckdb_tables()"}, {"pg_catalog", "pg_tablespace", "SELECT 0 oid, 'pg_default' spcname, 0 spcowner, NULL spcacl, NULL spcoptions"}, - {"pg_catalog", "pg_type", "SELECT CASE WHEN type_oid IS NULL THEN NULL WHEN logical_type = 'ENUM' AND type_name <> 'enum' THEN type_oid ELSE map_to_pg_oid(type_name) END oid, format_pg_type(logical_type, type_name) typname, schema_oid typnamespace, 0 typowner, type_size typlen, false typbyval, CASE WHEN logical_type='ENUM' THEN 'e' else 'b' end typtype, CASE WHEN type_category='NUMERIC' THEN 'N' WHEN type_category='STRING' THEN 'S' WHEN type_category='DATETIME' THEN 'D' WHEN type_category='BOOLEAN' THEN 'B' WHEN type_category='COMPOSITE' THEN 'C' WHEN type_category='USER' THEN 'U' ELSE 'X' END 
typcategory, false typispreferred, true typisdefined, NULL typdelim, NULL typrelid, NULL typsubscript, NULL typelem, NULL typarray, NULL typinput, NULL typoutput, NULL typreceive, NULL typsend, NULL typmodin, NULL typmodout, NULL typanalyze, 'd' typalign, 'p' typstorage, NULL typnotnull, NULL typbasetype, NULL typtypmod, NULL typndims, NULL typcollation, NULL typdefaultbin, NULL typdefault, NULL typacl FROM duckdb_types() WHERE type_oid IS NOT NULL;"}, - {"pg_catalog", "pg_views", "SELECT schema_name schemaname, view_name viewname, 'duckdb' viewowner, sql definition FROM duckdb_views()"}, - {"information_schema", "columns", "SELECT database_name table_catalog, schema_name table_schema, table_name, column_name, column_index ordinal_position, column_default, CASE WHEN is_nullable THEN 'YES' ELSE 'NO' END is_nullable, data_type, character_maximum_length, NULL::INT character_octet_length, numeric_precision, numeric_precision_radix, numeric_scale, NULL::INT datetime_precision, NULL::VARCHAR interval_type, NULL::INT interval_precision, NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, NULL::VARCHAR character_set_name, NULL::VARCHAR collation_catalog, NULL::VARCHAR collation_schema, NULL::VARCHAR collation_name, NULL::VARCHAR domain_catalog, NULL::VARCHAR domain_schema, NULL::VARCHAR domain_name, NULL::VARCHAR udt_catalog, NULL::VARCHAR udt_schema, NULL::VARCHAR udt_name, NULL::VARCHAR scope_catalog, NULL::VARCHAR scope_schema, NULL::VARCHAR scope_name, NULL::BIGINT maximum_cardinality, NULL::VARCHAR dtd_identifier, NULL::BOOL is_self_referencing, NULL::BOOL is_identity, NULL::VARCHAR identity_generation, NULL::VARCHAR identity_start, NULL::VARCHAR identity_increment, NULL::VARCHAR identity_maximum, NULL::VARCHAR identity_minimum, NULL::BOOL identity_cycle, NULL::VARCHAR is_generated, NULL::VARCHAR generation_expression, NULL::BOOL is_updatable, comment AS COLUMN_COMMENT FROM duckdb_columns;"}, - {"information_schema", "schemata", "SELECT 
database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL::VARCHAR default_character_set_catalog, NULL::VARCHAR default_character_set_schema, NULL::VARCHAR default_character_set_name, sql sql_path FROM duckdb_schemas()"}, - {"information_schema", "tables", "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary THEN 'LOCAL TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL::VARCHAR self_referencing_column_name, NULL::VARCHAR reference_generation, NULL::VARCHAR user_defined_type_catalog, NULL::VARCHAR user_defined_type_schema, NULL::VARCHAR user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN 'PRESERVE' ELSE NULL END commit_action, comment AS TABLE_COMMENT FROM duckdb_tables() UNION ALL SELECT database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL commit_action, comment AS TABLE_COMMENT FROM duckdb_views;"}, - {"information_schema", "character_sets", "SELECT NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, 'UTF8' character_set_name, 'UCS' character_repertoire, 'UTF8' form_of_use, current_database() default_collate_catalog, 'pg_catalog' default_collate_schema, 'ucs_basic' default_collate_name;"}, - {"information_schema", "referential_constraints", "SELECT f.database_name constraint_catalog, f.schema_name constraint_schema, f.constraint_name constraint_name, c.database_name unique_constraint_catalog, c.schema_name unique_constraint_schema, c.constraint_name unique_constraint_name, 'NONE' match_option, 'NO ACTION' update_rule, 'NO ACTION' delete_rule FROM duckdb_constraints() c, duckdb_constraints() f WHERE f.constraint_type = 'FOREIGN KEY' AND (c.constraint_type = 'UNIQUE' OR c.constraint_type = 'PRIMARY KEY') AND 
f.database_oid = c.database_oid AND f.schema_oid = c.schema_oid AND lower(f.referenced_table) = lower(c.table_name) AND [lower(x) for x in f.referenced_column_names] = [lower(x) for x in c.constraint_column_names]"}, - {"information_schema", "key_column_usage", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, UNNEST(constraint_column_names) column_name, UNNEST(generate_series(1, len(constraint_column_names))) ordinal_position, CASE constraint_type WHEN 'FOREIGN KEY' THEN UNNEST (generate_series(1, len(constraint_column_names))) ELSE NULL END position_in_unique_constraint FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY' OR constraint_type = 'PRIMARY KEY' OR constraint_type = 'UNIQUE';"}, - {"information_schema", "table_constraints", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, CASE constraint_type WHEN 'NOT NULL' THEN 'CHECK' ELSE constraint_type END constraint_type, 'NO' is_deferrable, 'NO' initially_deferred, 'YES' enforced, 'YES' nulls_distinct FROM duckdb_constraints() WHERE constraint_type = 'PRIMARY KEY' OR constraint_type = 'FOREIGN KEY' OR constraint_type = 'UNIQUE' OR constraint_type = 'CHECK' OR constraint_type = 'NOT NULL';"}, - {"information_schema", "constraint_column_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, column_name, database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type, constraint_text FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type NOT IN ('NOT NULL') );"}, - {"information_schema", "constraint_table_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, database_name AS constraint_catalog, schema_name AS 
constraint_schema, constraint_name, constraint_type FROM duckdb_constraints() WHERE constraint_type NOT IN ('NOT NULL');"}, - {"information_schema", "check_constraints", "SELECT database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, CASE constraint_type WHEN 'NOT NULL' THEN column_name || ' IS NOT NULL' ELSE constraint_text END AS check_clause FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type IN ('CHECK', 'NOT NULL'));"}, - {"information_schema", "views", "SELECT database_name AS table_catalog, schema_name AS table_schema, view_name AS table_name, sql AS view_definition, 'NONE' AS check_option, 'NO' AS is_updatable, 'NO' AS is_insertable_into, 'NO' AS is_trigger_updatable, 'NO' AS is_trigger_deletable, 'NO' AS is_trigger_insertable_into FROM duckdb_views();"}, + {"pg_catalog", "pg_type", + "SELECT CASE WHEN type_oid IS NULL THEN NULL WHEN logical_type = 'ENUM' AND type_name <> 'enum' THEN type_oid " + "ELSE map_to_pg_oid(type_name) END oid, format_pg_type(logical_type, type_name) typname, schema_oid typnamespace, " + "0 typowner, type_size typlen, false typbyval, CASE WHEN logical_type='ENUM' THEN 'e' else 'b' end typtype, CASE " + "WHEN type_category='NUMERIC' THEN 'N' WHEN type_category='STRING' THEN 'S' WHEN type_category='DATETIME' THEN " + "'D' WHEN type_category='BOOLEAN' THEN 'B' WHEN type_category='COMPOSITE' THEN 'C' WHEN type_category='USER' THEN " + "'U' ELSE 'X' END typcategory, false typispreferred, true typisdefined, NULL typdelim, NULL typrelid, NULL " + "typsubscript, NULL typelem, NULL typarray, NULL typinput, NULL typoutput, NULL typreceive, NULL typsend, NULL " + "typmodin, NULL typmodout, NULL typanalyze, 'd' typalign, 'p' typstorage, NULL typnotnull, NULL typbasetype, NULL " + "typtypmod, NULL typndims, NULL typcollation, NULL typdefaultbin, NULL typdefault, NULL typacl FROM " + "duckdb_types() WHERE type_oid IS NOT NULL;"}, + 
{"pg_catalog", "pg_views", + "SELECT schema_name schemaname, view_name viewname, 'duckdb' viewowner, sql definition FROM duckdb_views()"}, + {"information_schema", "columns", + "SELECT database_name table_catalog, schema_name table_schema, table_name, column_name, column_index " + "ordinal_position, column_default, CASE WHEN is_nullable THEN 'YES' ELSE 'NO' END is_nullable, data_type, " + "character_maximum_length, NULL::INT character_octet_length, numeric_precision, numeric_precision_radix, " + "numeric_scale, NULL::INT datetime_precision, NULL::VARCHAR interval_type, NULL::INT interval_precision, " + "NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, NULL::VARCHAR character_set_name, " + "NULL::VARCHAR collation_catalog, NULL::VARCHAR collation_schema, NULL::VARCHAR collation_name, NULL::VARCHAR " + "domain_catalog, NULL::VARCHAR domain_schema, NULL::VARCHAR domain_name, NULL::VARCHAR udt_catalog, NULL::VARCHAR " + "udt_schema, NULL::VARCHAR udt_name, NULL::VARCHAR scope_catalog, NULL::VARCHAR scope_schema, NULL::VARCHAR " + "scope_name, NULL::BIGINT maximum_cardinality, NULL::VARCHAR dtd_identifier, NULL::BOOL is_self_referencing, " + "NULL::BOOL is_identity, NULL::VARCHAR identity_generation, NULL::VARCHAR identity_start, NULL::VARCHAR " + "identity_increment, NULL::VARCHAR identity_maximum, NULL::VARCHAR identity_minimum, NULL::BOOL identity_cycle, " + "CASE WHEN is_generated THEN 'ALWAYS' ELSE 'NEVER' END is_generated, generation_expression, NULL::BOOL " + "is_updatable, comment AS COLUMN_COMMENT FROM duckdb_columns;"}, + {"information_schema", "schemata", + "SELECT database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL::VARCHAR " + "default_character_set_catalog, NULL::VARCHAR default_character_set_schema, NULL::VARCHAR " + "default_character_set_name, sql sql_path FROM duckdb_schemas()"}, + {"information_schema", "tables", + "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary 
THEN 'LOCAL " + "TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL::VARCHAR self_referencing_column_name, NULL::VARCHAR " + "reference_generation, NULL::VARCHAR user_defined_type_catalog, NULL::VARCHAR user_defined_type_schema, " + "NULL::VARCHAR user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN " + "'PRESERVE' ELSE NULL END commit_action, comment AS TABLE_COMMENT FROM duckdb_tables() UNION ALL SELECT " + "database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL " + "self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL " + "user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL " + "commit_action, comment AS TABLE_COMMENT FROM duckdb_views;"}, + {"information_schema", "character_sets", + "SELECT NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, 'UTF8' character_set_name, 'UCS' " + "character_repertoire, 'UTF8' form_of_use, current_database() default_collate_catalog, 'pg_catalog' " + "default_collate_schema, 'ucs_basic' default_collate_name;"}, + {"information_schema", "referential_constraints", + "SELECT f.database_name constraint_catalog, f.schema_name constraint_schema, f.constraint_name constraint_name, " + "c.database_name unique_constraint_catalog, c.schema_name unique_constraint_schema, c.constraint_name " + "unique_constraint_name, 'NONE' match_option, 'NO ACTION' update_rule, 'NO ACTION' delete_rule FROM " + "duckdb_constraints() c, duckdb_constraints() f WHERE f.constraint_type = 'FOREIGN KEY' AND (c.constraint_type = " + "'UNIQUE' OR c.constraint_type = 'PRIMARY KEY') AND f.database_oid = c.database_oid AND f.schema_oid = " + "c.schema_oid AND lower(f.referenced_table) = lower(c.table_name) AND [lower(x) for x in " + "f.referenced_column_names] = [lower(x) for x in c.constraint_column_names]"}, + {"information_schema", "key_column_usage", + "SELECT database_name 
constraint_catalog, schema_name constraint_schema, constraint_name, database_name " + "table_catalog, schema_name table_schema, table_name, UNNEST(constraint_column_names) column_name, " + "UNNEST(generate_series(1, len(constraint_column_names))) ordinal_position, CASE constraint_type WHEN 'FOREIGN " + "KEY' THEN UNNEST (generate_series(1, len(constraint_column_names))) ELSE NULL END position_in_unique_constraint " + "FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY' OR constraint_type = 'PRIMARY KEY' OR " + "constraint_type = 'UNIQUE';"}, + {"information_schema", "table_constraints", + "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name " + "table_catalog, schema_name table_schema, table_name, CASE constraint_type WHEN 'NOT NULL' THEN 'CHECK' ELSE " + "constraint_type END constraint_type, 'NO' is_deferrable, 'NO' initially_deferred, 'YES' enforced, 'YES' " + "nulls_distinct FROM duckdb_constraints() WHERE constraint_type = 'PRIMARY KEY' OR constraint_type = 'FOREIGN " + "KEY' OR constraint_type = 'UNIQUE' OR constraint_type = 'CHECK' OR constraint_type = 'NOT NULL';"}, + {"information_schema", "constraint_column_usage", + "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, column_name, database_name AS " + "constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type, constraint_text FROM " + "(SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE " + "constraint_type NOT IN ('NOT NULL') );"}, + {"information_schema", "constraint_table_usage", + "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, database_name AS " + "constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type FROM duckdb_constraints() " + "WHERE constraint_type NOT IN ('NOT NULL');"}, + {"information_schema", "check_constraints", + "SELECT database_name AS constraint_catalog, 
schema_name AS constraint_schema, constraint_name, CASE " + "constraint_type WHEN 'NOT NULL' THEN column_name || ' IS NOT NULL' ELSE constraint_text END AS check_clause FROM " + "(SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE " + "constraint_type IN ('CHECK', 'NOT NULL'));"}, + {"information_schema", "views", + "SELECT database_name AS table_catalog, schema_name AS table_schema, view_name AS table_name, sql AS " + "view_definition, 'NONE' AS check_option, 'NO' AS is_updatable, 'NO' AS is_insertable_into, 'NO' AS " + "is_trigger_updatable, 'NO' AS is_trigger_deletable, 'NO' AS is_trigger_insertable_into FROM duckdb_views();"}, {nullptr, nullptr, nullptr}}; -static unique_ptr GetDefaultView(ClientContext &context, const string &input_schema, const string &input_name) { - auto schema = StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); +static unique_ptr GetDefaultView(ClientContext &context, const Identifier &input_schema, + const Identifier &input_name) { + auto schema = StringUtil::Lower(input_schema.GetIdentifierName()); + auto name = StringUtil::Lower(input_name.GetIdentifierName()); for (idx_t index = 0; internal_views[index].name != nullptr; index++) { if (internal_views[index].schema == schema && internal_views[index].name == name) { auto result = make_uniq(); - result->schema = schema; - result->view_name = name; + result->schema = Identifier(schema); + result->view_name = Identifier(name); result->sql = internal_views[index].sql; result->temporary = true; result->internal = true; @@ -84,7 +243,8 @@ DefaultViewGenerator::DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry : DefaultGenerator(catalog), schema(schema) { } -unique_ptr DefaultViewGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { +unique_ptr DefaultViewGenerator::CreateDefaultEntry(ClientContext &context, + const Identifier &entry_name) { auto info = GetDefaultView(context, 
schema.name, entry_name); if (info) { return make_uniq_base(catalog, schema, *info); @@ -92,8 +252,8 @@ unique_ptr DefaultViewGenerator::CreateDefaultEntry(ClientContext return nullptr; } -vector DefaultViewGenerator::GetDefaultEntries() { - vector result; +vector DefaultViewGenerator::GetDefaultEntries() { + vector result; for (idx_t index = 0; internal_views[index].name != nullptr; index++) { if (internal_views[index].schema == schema.name) { result.emplace_back(internal_views[index].name); diff --git a/src/duckdb/src/catalog/dependency_catalog_set.cpp b/src/duckdb/src/catalog/dependency_catalog_set.cpp index bfb3862a4..8bdc16e2b 100644 --- a/src/duckdb/src/catalog/dependency_catalog_set.cpp +++ b/src/duckdb/src/catalog/dependency_catalog_set.cpp @@ -33,7 +33,7 @@ void DependencyCatalogSet::Scan(CatalogTransaction transaction, const std::funct [&](CatalogEntry &entry) { auto &dep = entry.Cast(); auto &from = dep.SourceMangledName(); - if (!StringUtil::CIEquals(from.name, mangled_name.name)) { + if (from.name != mangled_name.name) { return; } callback(entry); diff --git a/src/duckdb/src/catalog/dependency_list.cpp b/src/duckdb/src/catalog/dependency_list.cpp index 2c2bc90d3..59d088d91 100644 --- a/src/duckdb/src/catalog/dependency_list.cpp +++ b/src/duckdb/src/catalog/dependency_list.cpp @@ -42,26 +42,26 @@ LogicalDependency::LogicalDependency() : entry(), catalog() { static string GetSchema(CatalogEntry &entry) { if (entry.type == CatalogType::SCHEMA_ENTRY) { - return entry.name; + return entry.name.GetIdentifierName(); } - return entry.ParentSchema().name; + return entry.ParentSchema().name.GetIdentifierName(); } LogicalDependency::LogicalDependency(CatalogEntry &entry) { - catalog = INVALID_CATALOG; + catalog = Identifier::InvalidCatalog(); if (entry.type == CatalogType::DEPENDENCY_ENTRY) { auto &dependency_entry = entry.Cast(); this->entry = dependency_entry.EntryInfo(); } else { - this->entry.schema = GetSchema(entry); + this->entry.schema = 
Identifier(GetSchema(entry)); this->entry.name = entry.name; this->entry.type = entry.type; catalog = entry.ParentCatalog().GetName(); } } -LogicalDependency::LogicalDependency(optional_ptr catalog_p, CatalogEntryInfo entry_p, string catalog_str) +LogicalDependency::LogicalDependency(optional_ptr catalog_p, CatalogEntryInfo entry_p, Identifier catalog_str) : entry(std::move(entry_p)), catalog(std::move(catalog_str)) { if (catalog_p) { catalog = catalog_p->GetName(); @@ -86,7 +86,7 @@ bool LogicalDependencyList::Contains(CatalogEntry &entry_p) { return set.count(logical_entry); } -void LogicalDependencyList::VerifyDependencies(Catalog &catalog, const string &name) { +void LogicalDependencyList::VerifyDependencies(Catalog &catalog, const Identifier &name) { for (auto &dep : set) { if (dep.catalog != catalog.GetName()) { throw DependencyException( diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp index 6c7c669b3..540cb11ef 100644 --- a/src/duckdb/src/catalog/dependency_manager.cpp +++ b/src/duckdb/src/catalog/dependency_manager.cpp @@ -35,19 +35,19 @@ MangledEntryName::MangledEntryName(const CatalogEntryInfo &info) { auto &schema = info.schema; auto &name = info.name; - this->name = CatalogTypeToString(type) + '\0' + schema + '\0' + name; - AssertMangledName(this->name, 2); + this->name = Identifier(CatalogTypeToString(type) + '\0' + schema + '\0' + name); + AssertMangledName(this->name.GetIdentifierName(), 2); } MangledDependencyName::MangledDependencyName(const MangledEntryName &from, const MangledEntryName &to) { - this->name = from.name + '\0' + to.name; - AssertMangledName(this->name, 5); + this->name = Identifier(from.name + '\0' + to.name); + AssertMangledName(this->name.GetIdentifierName(), 5); } DependencyManager::DependencyManager(DuckCatalog &catalog) : catalog(catalog), subjects(catalog), dependents(catalog) { } -string DependencyManager::GetSchema(const CatalogEntry &entry) { +Identifier 
DependencyManager::GetSchema(const CatalogEntry &entry) { if (entry.type == CatalogType::SCHEMA_ENTRY) { return entry.name; } @@ -66,7 +66,7 @@ MangledEntryName DependencyManager::MangleName(const CatalogEntry &entry) { auto type = entry.type; auto schema = GetSchema(entry); auto name = entry.name; - CatalogEntryInfo info {type, schema, name}; + CatalogEntryInfo info {type, Identifier(schema), name}; return MangleName(info); } @@ -311,7 +311,7 @@ CatalogEntryInfo DependencyManager::GetLookupProperties(const CatalogEntry &entr auto schema = DependencyManager::GetSchema(entry); auto &name = entry.name; auto &type = entry.type; - return CatalogEntryInfo {type, schema, name}; + return CatalogEntryInfo {type, Identifier(schema), name}; } } @@ -697,7 +697,7 @@ void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry dependencies.emplace_back(dep_info); }); - if (has_new_dependencies || !StringUtil::CIEquals(old_obj.name, new_obj.name)) { + if (has_new_dependencies || !(old_obj.name == new_obj.name)) { // The dependencies have changed (e.g. 
SET DEFAULT) or the name has changed // We need to recreate the dependency links CleanupDependencies(transaction, old_obj); @@ -796,7 +796,7 @@ void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntr } static string FormatString(const MangledEntryName &mangled) { - auto input = mangled.name; + auto input = mangled.name.GetIdentifierName(); for (size_t i = 0; i < input.size(); i++) { if (input[i] == '\0') { input[i] = '_'; diff --git a/src/duckdb/src/catalog/duck_catalog.cpp b/src/duckdb/src/catalog/duck_catalog.cpp index d5247c1e3..44bb3054a 100644 --- a/src/duckdb/src/catalog/duck_catalog.cpp +++ b/src/duckdb/src/catalog/duck_catalog.cpp @@ -29,7 +29,7 @@ void DuckCatalog::Initialize(bool load_builtin) { // create the default schema CreateSchemaInfo info; - info.schema = DEFAULT_SCHEMA; + info.schema = Identifier::DefaultSchema(); info.internal = true; info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; CreateSchema(data, info); @@ -134,7 +134,7 @@ optional_ptr DuckCatalog::LookupSchema(CatalogTransaction tr OnEntryNotFound if_not_found) { auto &schema_name = schema_lookup.GetEntryName(); D_ASSERT(!schema_name.empty()); - auto entry = schemas->GetEntry(transaction, schema_name); + auto entry = schemas->GetEntry(transaction, Identifier(schema_name)); if (!entry) { if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { throw CatalogException(schema_lookup.GetErrorContext(), "Schema with name %s does not exist!", schema_name); diff --git a/src/duckdb/src/catalog/entry_lookup_info.cpp b/src/duckdb/src/catalog/entry_lookup_info.cpp index 742fa3486..122365c35 100644 --- a/src/duckdb/src/catalog/entry_lookup_info.cpp +++ b/src/duckdb/src/catalog/entry_lookup_info.cpp @@ -2,17 +2,17 @@ namespace duckdb { -EntryLookupInfo::EntryLookupInfo(CatalogType catalog_type_p, const string &name_p, QueryErrorContext error_context_p) - : catalog_type(catalog_type_p), name(name_p), error_context(error_context_p) { 
+EntryLookupInfo::EntryLookupInfo(CatalogType catalog_type_p, Identifier name_p, QueryErrorContext error_context_p) + : catalog_type(catalog_type_p), name(std::move(name_p)), error_context(error_context_p) { } -EntryLookupInfo::EntryLookupInfo(CatalogType catalog_type_p, const string &name_p, - optional_ptr at_clause_p, QueryErrorContext error_context_p) - : catalog_type(catalog_type_p), name(name_p), at_clause(at_clause_p), error_context(error_context_p) { +EntryLookupInfo::EntryLookupInfo(CatalogType catalog_type_p, Identifier name_p, optional_ptr at_clause_p, + QueryErrorContext error_context_p) + : catalog_type(catalog_type_p), name(std::move(name_p)), at_clause(at_clause_p), error_context(error_context_p) { } -EntryLookupInfo::EntryLookupInfo(const EntryLookupInfo &parent, const string &name_p) - : catalog_type(parent.catalog_type), name(name_p), at_clause(parent.at_clause), +EntryLookupInfo::EntryLookupInfo(const EntryLookupInfo &parent, Identifier name_p) + : catalog_type(parent.catalog_type), name(std::move(name_p)), at_clause(parent.at_clause), error_context(parent.error_context) { } @@ -21,18 +21,22 @@ EntryLookupInfo::EntryLookupInfo(const EntryLookupInfo &parent, optional_ptrcatalog.GetName(); + result += schema->catalog.GetName().GetIdentifierName(); } if (qualify_schema) { if (!result.empty()) { result += "."; } - result += schema->name; + result += schema->name.GetIdentifierName(); } if (!result.empty()) { result += "."; diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp index 278d502a9..36843340a 100644 --- a/src/duckdb/src/common/adbc/adbc.cpp +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -92,7 +92,7 @@ AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *err adbc_driver->ConnectionSetOptionDouble = duckdb_adbc::ConnectionSetOptionDouble; adbc_driver->StatementCancel = duckdb_adbc::StatementCancel; - adbc_driver->StatementExecuteSchema = nullptr; + adbc_driver->StatementExecuteSchema = 
duckdb_adbc::StatementExecuteSchema; adbc_driver->StatementGetOption = duckdb_adbc::StatementGetOption; adbc_driver->StatementGetOptionBytes = duckdb_adbc::StatementGetOptionBytes; adbc_driver->StatementGetOptionDouble = duckdb_adbc::StatementGetOptionDouble; @@ -1493,6 +1493,62 @@ AdbcStatusCode StatementCancel(struct AdbcStatement *statement, struct AdbcError return ADBC_STATUS_OK; } +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error) { + if (!statement) { + SetError(error, "Missing statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!statement->private_data) { + SetError(error, "Invalid statement object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!schema) { + SetError(error, "Missing schema object"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto wrapper = static_cast(statement->private_data); + if (!wrapper->statement) { + SetError(error, "Must call StatementSetSqlQuery before StatementExecuteSchema"); + return ADBC_STATUS_INVALID_STATE; + } + + if (wrapper->conn_wrapper) { + wrapper->conn_wrapper->MaterializeStreams(); + } + + auto count = duckdb_prepared_statement_column_count(wrapper->statement); + std::vector types(count); + std::vector owned_names; + owned_names.reserve(count); + duckdb::vector names(count); + + for (idx_t i = 0; i < count; i++) { + types[i] = duckdb_prepared_statement_column_logical_type(wrapper->statement, i); + auto column_name = duckdb_prepared_statement_column_name(wrapper->statement, i); + owned_names.emplace_back(column_name ? 
column_name : ""); + names[i] = owned_names.back().c_str(); + duckdb_free(const_cast(column_name)); + } + + duckdb_arrow_options arrow_options; + duckdb_connection_get_arrow_options(wrapper->connection, &arrow_options); + + auto res = duckdb_to_arrow_schema(arrow_options, types.data(), names.data(), count, schema); + + for (auto &type : types) { + duckdb_destroy_logical_type(&type); + } + duckdb_destroy_arrow_options(&arrow_options); + + if (res) { + SetError(error, duckdb_error_data_message(res)); + duckdb_destroy_error_data(&res); + return ADBC_STATUS_INVALID_ARGUMENT; + } + return ADBC_STATUS_OK; +} + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, struct AdbcError *error) { if (!statement) { diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator/allocator.cpp similarity index 71% rename from src/duckdb/src/common/allocator.cpp rename to src/duckdb/src/common/allocator/allocator.cpp index 4ee952fc4..25eae3c17 100644 --- a/src/duckdb/src/common/allocator.cpp +++ b/src/duckdb/src/common/allocator/allocator.cpp @@ -9,6 +9,7 @@ #include "duckdb/common/types/timestamp.hpp" #include + #ifdef DUCKDB_DEBUG_ALLOCATION #include "duckdb/common/mutex.hpp" #include "duckdb/common/pair.hpp" @@ -17,21 +18,6 @@ #include #endif -#ifndef USE_JEMALLOC -#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) && \ - INTPTR_MAX == INT64_MAX -#define USE_JEMALLOC -#endif -#endif - -#ifdef USE_JEMALLOC -#include "jemalloc_extension.hpp" -#endif - -#ifdef __GLIBC__ -#include -#endif - namespace duckdb { constexpr const idx_t Allocator::MAXIMUM_ALLOC_SIZE; @@ -180,36 +166,6 @@ data_ptr_t Allocator::ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t s } return new_pointer; } - -data_ptr_t Allocator::DefaultAllocate(PrivateAllocatorData *private_data, idx_t size) { -#ifdef USE_JEMALLOC - return JemallocExtension::Allocate(private_data, size); -#else - auto 
default_allocate_result = malloc(size); - if (!default_allocate_result) { - throw std::bad_alloc(); - } - return data_ptr_cast(default_allocate_result); -#endif -} - -void Allocator::DefaultFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { -#ifdef USE_JEMALLOC - JemallocExtension::Free(private_data, pointer, size); -#else - free(pointer); -#endif -} - -data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, - idx_t size) { -#ifdef USE_JEMALLOC - return JemallocExtension::Reallocate(private_data, pointer, old_size, size); -#else - return data_ptr_cast(realloc(pointer, size)); -#endif -} - shared_ptr &Allocator::DefaultAllocatorReference() { static shared_ptr DEFAULT_ALLOCATOR = make_shared_ptr(); return DEFAULT_ALLOCATOR; @@ -219,72 +175,6 @@ Allocator &Allocator::DefaultAllocator() { return *DefaultAllocatorReference(); } -optional_idx Allocator::DecayDelay() { -#ifdef USE_JEMALLOC - return NumericCast(JemallocExtension::DecayDelay()); -#else - return optional_idx(); -#endif -} - -bool Allocator::SupportsFlush() { -#if defined(USE_JEMALLOC) || defined(__GLIBC__) - return true; -#else - return false; -#endif -} - -static void MallocTrim(idx_t pad) { -#ifdef __GLIBC__ - static constexpr int64_t TRIM_INTERVAL_MS = 100; - static atomic LAST_TRIM_TIMESTAMP_MS {0}; - - int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); - auto current_ts = Timestamp::GetCurrentTimestamp(); - auto current_timestamp_ms = Cast::Operation(current_ts).value; - - if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { - return; // We trimmed less than TRIM_INTERVAL_MS ago - } - if (!LAST_TRIM_TIMESTAMP_MS.compare_exchange_strong(last_trim_timestamp_ms, current_timestamp_ms, - std::memory_order_acquire, std::memory_order_relaxed)) { - return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it - } - - // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim - 
malloc_trim(pad); -#endif -} - -void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { -#ifdef USE_JEMALLOC - if (!allocator_background_threads) { - JemallocExtension::ThreadFlush(threshold); - } -#endif - MallocTrim(thread_count * threshold); -} - -void Allocator::ThreadIdle() { -#ifdef USE_JEMALLOC - JemallocExtension::ThreadIdle(); -#endif -} - -void Allocator::FlushAll() { -#ifdef USE_JEMALLOC - JemallocExtension::FlushAll(); -#endif - MallocTrim(0); -} - -void Allocator::SetBackgroundThreads(bool enable) { -#ifdef USE_JEMALLOC - JemallocExtension::SetBackgroundThreads(enable); -#endif -} - //===--------------------------------------------------------------------===// // Debug Info (extended) //===--------------------------------------------------------------------===// diff --git a/src/duckdb/extension/jemalloc/jemalloc_extension.cpp b/src/duckdb/src/common/allocator/allocator_jemalloc.cpp similarity index 51% rename from src/duckdb/extension/jemalloc/jemalloc_extension.cpp rename to src/duckdb/src/common/allocator/allocator_jemalloc.cpp index 5911d097c..84735b16e 100644 --- a/src/duckdb/extension/jemalloc/jemalloc_extension.cpp +++ b/src/duckdb/src/common/allocator/allocator_jemalloc.cpp @@ -1,41 +1,41 @@ -#include "jemalloc_extension.hpp" - #include "duckdb/common/allocator.hpp" -#include "jemalloc/jemalloc.h" -#include "malloc_ncpus.h" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/string_util.hpp" #include +#include -namespace duckdb { +#if defined(DUCKDB_ENABLE_JEMALLOC) +#if !defined(WIN32) && INTPTR_MAX == INT64_MAX -static void LoadInternal(ExtensionLoader &) { - // NOP: This extension can only be loaded statically -} -void JemallocExtension::Load(ExtensionLoader &loader) { - LoadInternal(loader); -} +#include "jemalloc/jemalloc.h" +#include "duckdb/malloc_ncpus.h" -std::string JemallocExtension::Name() { - return "jemalloc"; -} +#else +#error "jemalloc support is only available on 
64-bit Linux with DUCKDB_ENABLE_JEMALLOC enabled" +#endif -data_ptr_t JemallocExtension::Allocate(PrivateAllocatorData *private_data, idx_t size) { - return data_ptr_cast(duckdb_je_malloc(size)); -} +extern "C" { -void JemallocExtension::Free(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { - duckdb_je_free(pointer); +unsigned duckdb_malloc_ncpus() { +#ifdef DUCKDB_NO_THREADS + return 1; +#else + unsigned concurrency = duckdb::NumericCast(std::thread::hardware_concurrency()); + return std::max(concurrency, 1u); +#endif +} } -data_ptr_t JemallocExtension::Reallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, - idx_t size) { - return data_ptr_cast(duckdb_je_realloc(pointer, size)); +namespace duckdb { + +static string PurgeArenaString(idx_t arena_idx) { + return StringUtil::Format("arena.%llu.purge", arena_idx); } static void JemallocCTL(const char *name, void *old_ptr, size_t *old_len, void *new_ptr, size_t new_len) { if (duckdb_je_mallctl(name, old_ptr, old_len, new_ptr, new_len) != 0) { #ifdef DEBUG - // We only want to throw an exception here when debugging throw InternalException("je_mallctl failed for setting \"%s\"", name); #endif } @@ -58,32 +58,43 @@ static T GetJemallocCTL(const char *name) { return result; } -static inline string PurgeArenaString(idx_t arena_idx) { - return StringUtil::Format("arena.%llu.purge", arena_idx); +data_ptr_t Allocator::DefaultAllocate(PrivateAllocatorData *private_data, idx_t size) { + return data_ptr_cast(duckdb_je_malloc(size)); } -int64_t JemallocExtension::DecayDelay() { - return DUCKDB_JEMALLOC_DECAY; +void Allocator::DefaultFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { + duckdb_je_free(pointer); } -void JemallocExtension::ThreadFlush(idx_t threshold) { - // We flush after exceeding the threshold - if (GetJemallocCTL("thread.peak.read") > threshold) { - return; - } +data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, 
data_ptr_t pointer, idx_t old_size, + idx_t size) { + return data_ptr_cast(duckdb_je_realloc(pointer, size)); +} - // Flush thread-local cache - SetJemallocCTL("thread.tcache.flush"); +optional_idx Allocator::DecayDelay() { + return NumericCast(DUCKDB_JEMALLOC_DECAY); +} - // Flush this thread's arena - const auto purge_arena = PurgeArenaString(idx_t(GetJemallocCTL("thread.arena"))); - SetJemallocCTL(purge_arena.c_str()); +bool Allocator::SupportsFlush() { + return true; +} - // Reset the peak after resetting - SetJemallocCTL("thread.peak.reset"); +void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { + if (!allocator_background_threads) { + // We flush after exceeding the threshold + if (GetJemallocCTL("thread.peak.read") <= threshold) { + return; + } + + // Flush thread-local cache + SetJemallocCTL("thread.tcache.flush"); + + // Reset the peak after flushing + SetJemallocCTL("thread.peak.reset"); + } } -void JemallocExtension::ThreadIdle() { +void Allocator::ThreadIdle() { // Indicate that this thread is idle SetJemallocCTL("thread.idle"); @@ -91,7 +102,7 @@ void JemallocExtension::ThreadIdle() { SetJemallocCTL("thread.peak.reset"); } -void JemallocExtension::FlushAll() { +void Allocator::FlushAll() { // Flush thread-local cache SetJemallocCTL("thread.tcache.flush"); @@ -103,34 +114,12 @@ void JemallocExtension::FlushAll() { SetJemallocCTL("thread.peak.reset"); } -void JemallocExtension::SetBackgroundThreads(bool enable) { +void Allocator::SetBackgroundThreads(bool enable) { #ifndef __APPLE__ SetJemallocCTL("background_thread", enable); #endif } -std::string JemallocExtension::Version() const { -#ifdef EXT_VERSION_JEMALLOC - return EXT_VERSION_JEMALLOC; -#else - return ""; -#endif -} - } // namespace duckdb -extern "C" { - -unsigned duckdb_malloc_ncpus() { -#ifdef DUCKDB_NO_THREADS - return 1 -#else - unsigned concurrency = duckdb::NumericCast(std::thread::hardware_concurrency()); - return std::max(concurrency, 1u); 
#endif -} - -DUCKDB_CPP_EXTENSION_ENTRY(jemalloc, loader) { - duckdb::LoadInternal(loader); -} -} diff --git a/src/duckdb/src/common/allocator/allocator_standard.cpp b/src/duckdb/src/common/allocator/allocator_standard.cpp new file mode 100644 index 000000000..45e03d46b --- /dev/null +++ b/src/duckdb/src/common/allocator/allocator_standard.cpp @@ -0,0 +1,95 @@ +#include "duckdb/common/allocator.hpp" + +#ifndef DUCKDB_ENABLE_JEMALLOC + +#include "duckdb/common/assert.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/timestamp.hpp" + +#ifdef __GLIBC__ +#include +#endif + +#ifdef DUCKDB_DEBUG_ALLOCATION +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/unordered_map.hpp" + +#include +#endif + +namespace duckdb { + +data_ptr_t Allocator::DefaultAllocate(PrivateAllocatorData *private_data, idx_t size) { + auto default_allocate_result = malloc(size); + if (!default_allocate_result) { + throw std::bad_alloc(); + } + return data_ptr_cast(default_allocate_result); +} + +void Allocator::DefaultFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { + free(pointer); +} + +data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t old_size, + idx_t size) { + return data_ptr_cast(realloc(pointer, size)); +} + +optional_idx Allocator::DecayDelay() { + return optional_idx(); +} + +bool Allocator::SupportsFlush() { +#if defined(__GLIBC__) + return true; +#else + return false; +#endif +} + +static void MallocTrim(idx_t pad) { +#ifdef __GLIBC__ + static constexpr int64_t TRIM_INTERVAL_MS = 100; + static atomic LAST_TRIM_TIMESTAMP_MS {0}; + + int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); + auto current_ts = Timestamp::GetCurrentTimestamp(); + auto current_timestamp_ms = Cast::Operation(current_ts).value; + + 
if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { + return; // We trimmed less than TRIM_INTERVAL_MS ago + } + if (!LAST_TRIM_TIMESTAMP_MS.compare_exchange_strong(last_trim_timestamp_ms, current_timestamp_ms, + std::memory_order_acquire, std::memory_order_relaxed)) { + return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it + } + + // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim + malloc_trim(pad); +#endif +} + +void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { + MallocTrim(thread_count * threshold); +} + +void Allocator::ThreadIdle() { + /* no-op */ +} + +void Allocator::FlushAll() { + MallocTrim(0); +} + +void Allocator::SetBackgroundThreads(bool enable) { + /* no-op */ +} + +} // namespace duckdb + +#endif diff --git a/src/duckdb/src/common/arrow/appender/bool_data.cpp b/src/duckdb/src/common/arrow/appender/bool_data.cpp index 1ee6fbd87..645758046 100644 --- a/src/duckdb/src/common/arrow/appender/bool_data.cpp +++ b/src/duckdb/src/common/arrow/appender/bool_data.cpp @@ -8,7 +8,7 @@ void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, result.GetMainBuffer().reserve(byte_count); } -void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { +void ArrowBoolData::Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { idx_t size = to - from; auto &main_buffer = append_data.GetMainBuffer(); auto &validity_buffer = append_data.GetValidityBuffer(); diff --git a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp index 4188a7157..9d5ff74b3 100644 --- a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp +++ b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp @@ -14,7 +14,7 @@ void ArrowFixedSizeListData::Initialize(ArrowAppendData &result, 
const LogicalTy result.child_data.push_back(std::move(child_buffer)); } -void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, +void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); @@ -22,7 +22,7 @@ void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, Vector &input, append_data.AppendValidity(format, from, to); input.Flatten(); auto array_size = ArrayType::GetSize(input.GetType()); - auto &child_vector = ArrayVector::GetChildMutable(input); + auto &child_vector = ArrayVector::GetChild(input); auto &child_data = *append_data.child_data[0]; child_data.append_vector(child_data, child_vector, from * array_size, to * array_size, size * array_size); append_data.row_count += size; diff --git a/src/duckdb/src/common/arrow/appender/null_data.cpp b/src/duckdb/src/common/arrow/appender/null_data.cpp index 5f8806ed6..de7702b07 100644 --- a/src/duckdb/src/common/arrow/appender/null_data.cpp +++ b/src/duckdb/src/common/arrow/appender/null_data.cpp @@ -7,7 +7,7 @@ void ArrowNullData::Initialize(ArrowAppendData &result, const LogicalType &type, // nop } -void ArrowNullData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { +void ArrowNullData::Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { idx_t size = to - from; append_data.row_count += size; } diff --git a/src/duckdb/src/common/arrow/appender/struct_data.cpp b/src/duckdb/src/common/arrow/appender/struct_data.cpp index d6667efa9..018e82013 100644 --- a/src/duckdb/src/common/arrow/appender/struct_data.cpp +++ b/src/duckdb/src/common/arrow/appender/struct_data.cpp @@ -16,7 +16,8 @@ void ArrowStructData::Initialize(ArrowAppendData &result, const LogicalType &typ } } -void ArrowStructData::Append(ArrowAppendData &append_data, Vector &input, 
idx_t from, idx_t to, idx_t input_size) { +void ArrowStructData::Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, + idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); idx_t size = to - from; diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp index d8951379e..382af4b9b 100644 --- a/src/duckdb/src/common/arrow/appender/union_data.cpp +++ b/src/duckdb/src/common/arrow/appender/union_data.cpp @@ -16,7 +16,7 @@ void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type } } -void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { +void ArrowUnionData::Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); idx_t size = to - from; diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index 46f84ae7b..044c57e12 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -78,7 +78,7 @@ void SetArrowStructFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &chi child.children = &root_holder.nested_children_ptr.back()[0]; for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { InitializeChild(*child.children[type_idx], root_holder); - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); + root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first.GetIdentifierName())); child.children[type_idx]->name = root_holder.owned_type_names.back().get(); SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options, context); } @@ -370,7 +370,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co for (size_t type_idx = 0; type_idx < 
child_types.size(); type_idx++) { InitializeChild(*child.children[type_idx], root_holder); - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); + root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first.GetIdentifierName())); child.children[type_idx]->name = root_holder.owned_type_names.back().get(); SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options, context); diff --git a/src/duckdb/src/common/arrow/arrow_type_extension.cpp b/src/duckdb/src/common/arrow/arrow_type_extension.cpp index f9cf046de..24cbba223 100644 --- a/src/duckdb/src/common/arrow/arrow_type_extension.cpp +++ b/src/duckdb/src/common/arrow/arrow_type_extension.cpp @@ -474,7 +474,6 @@ struct ArrowGeometry { duckdb_yyjson::yyjson_doc_free(projjson_doc); } else { - duckdb_yyjson::yyjson_mut_doc_free(doc); throw SerializationException("Could not parse PROJJSON CRS for GeoArrow metadata"); } } break; @@ -523,6 +522,7 @@ struct ArrowGeometry { duckdb_yyjson::yyjson_mut_doc_free(doc); free(json_text); } else { + duckdb_yyjson::yyjson_mut_doc_free(doc); schema_metadata.AddOption(ArrowSchemaMetadata::ARROW_METADATA_KEY, "{}"); } @@ -542,7 +542,7 @@ struct ArrowGeometry { } static void DuckToArrow(ClientContext &context, Vector &source, Vector &result, idx_t count) { - Geometry::ToBinary(source, result, count); + Geometry::ToBinary(source, result); } }; diff --git a/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp index 11406c540..98b253825 100644 --- a/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp +++ b/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp @@ -18,14 +18,14 @@ SinkFinalizeType PhysicalArrowBatchCollector::Finalize(Pipeline &pipeline, Event auto total_tuple_count = gstate.data.Count(); if (total_tuple_count == 0) { // Create the result containing a single empty result conversion - gstate.result = 
make_uniq(statement_type, properties, names, types, + gstate.result = make_uniq(statement_type, properties, IdentifiersToStrings(names), types, context.GetClientProperties(), record_batch_size); return SinkFinalizeType::READY; } // Already create the final query result - gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), - record_batch_size); + gstate.result = make_uniq(statement_type, properties, IdentifiersToStrings(names), types, + context.GetClientProperties(), record_batch_size); // Spawn an event that will populate the conversion result auto &arrow_result = gstate.result->Cast(); auto new_event = make_shared_ptr(arrow_result, gstate.data, pipeline); diff --git a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp index 7934a0438..ebb4f9de7 100644 --- a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp +++ b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp @@ -111,13 +111,13 @@ SinkFinalizeType PhysicalArrowCollector::Finalize(Pipeline &pipeline, Event &eve "PhysicalArrowCollector Finalize contains no chunks, but tuple_count is non-zero (%d)", gstate.tuple_count); } - gstate.result = make_uniq(statement_type, properties, names, types, + gstate.result = make_uniq(statement_type, properties, IdentifiersToStrings(names), types, context.GetClientProperties(), record_batch_size); return SinkFinalizeType::READY; } - gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), - record_batch_size); + gstate.result = make_uniq(statement_type, properties, IdentifiersToStrings(names), types, + context.GetClientProperties(), record_batch_size); auto &arrow_result = gstate.result->Cast(); arrow_result.SetArrowData(std::move(gstate.chunks)); diff --git a/src/duckdb/src/common/bind_helpers.cpp b/src/duckdb/src/common/bind_helpers.cpp index 6a3c02291..cf7cf5cfe 100644 --- a/src/duckdb/src/common/bind_helpers.cpp 
+++ b/src/duckdb/src/common/bind_helpers.cpp @@ -5,6 +5,13 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/main/config.hpp" + #include namespace duckdb { @@ -69,7 +76,7 @@ vector ParseColumnList(const Value &value, vector &names, const st return ParseColumnList(children, names, loption); } -vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption) { +vector ParseColumnsOrdered(const vector &set, const vector &names, const string &loption) { vector result; if (set.empty()) { @@ -77,9 +84,12 @@ vector ParseColumnsOrdered(const vector &set, vector &name } // Maps option to bool indicating if its found and the index in the original set - case_insensitive_map_t> option_map; + identifier_map_t> option_map; for (idx_t i = 0; i < set.size(); i++) { - option_map[set[i].ToString()] = {false, i}; + const auto [it, inserted] = option_map.emplace(make_pair(set[i].ToString(), make_pair(false, i))); + if (!inserted) { + throw BinderException("\"%s\" does now allow duplicate columns (found: %s)", loption, set[i].ToString()); + } } result.resize(option_map.size()); @@ -99,7 +109,7 @@ vector ParseColumnsOrdered(const vector &set, vector &name return result; } -vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption) { +vector ParseColumnsOrdered(const Value &value, const vector &names, const string &loption) { vector result; // Only accept a list of arguments @@ -123,4 +133,57 @@ vector ParseColumnsOrdered(const Value &value, vector &names, con return ParseColumnsOrdered(children, names, loption); } +vector ParseOrderByColumns(Binder &binder, const vector &set, + const 
BoundStatement &bound_statement, const string &loption) { + // Parse + vector order_by_strings; + for (auto &value : set) { + order_by_strings.push_back(value.ToString()); + } + const auto order_by_clause = StringUtil::Join(order_by_strings, ", "); + auto parsed_orders = Parser::ParseOrderList(order_by_clause); + + // Bind + auto &config = DBConfig::GetConfig(binder.context); + auto child_binder = Binder::CreateBinder(binder.context, &binder); + auto table_index = binder.GenerateTableIndex(); + child_binder->bind_context.AddGenericBinding(table_index, "__copy_input", bound_statement.names, + bound_statement.types); + ExpressionBinder expr_binder(*child_binder, binder.context); + vector bound_orders; + for (auto &parsed_order : parsed_orders) { + const auto order_type = config.ResolveOrder(binder.context, parsed_order.type); + const auto null_order = config.ResolveNullOrder(binder.context, order_type, parsed_order.null_order); + bound_orders.emplace_back(order_type, null_order, expr_binder.Bind(parsed_order.expression)); + } + + // Convert BoundColumnRefExpression to BoundReferenceExpression + vector name_set; + case_insensitive_map_t>>>> name_map; + for (auto &bound_order : bound_orders) { + ExpressionIterator::VisitExpressionClassMutable( + bound_order.expression, ExpressionClass::BOUND_COLUMN_REF, [&](unique_ptr &child) { + auto [it, inserted] = name_map.emplace(make_pair( + child->ToString(), make_pair(name_set.size(), vector>>()))); + it->second.second.push_back(child); + if (inserted) { + // Ensure we only add unique values + name_set.emplace_back(child->ToString()); + } + }); + } + const auto indices = ParseColumnsOrdered(name_set, bound_statement.names, loption); + D_ASSERT(name_set.size() == indices.size()); + for (const auto &name_value : name_set) { + auto name = name_value.ToString(); + const auto &[idx, expressions] = name_map[name]; + for (auto &expr : expressions) { + expr.get() = + make_uniq(Identifier(name), expr.get()->GetReturnType(), 
indices[idx]); + } + } + + return bound_orders; +} + } // namespace duckdb diff --git a/src/duckdb/src/common/box_renderer.cpp b/src/duckdb/src/common/box_renderer.cpp index a62492be6..b2b839a7b 100644 --- a/src/duckdb/src/common/box_renderer.cpp +++ b/src/duckdb/src/common/box_renderer.cpp @@ -682,7 +682,7 @@ void BoxRendererImplementation::FetchTopCollection(RenderDataCollection &top_col ConvertRenderVector(target_vector, render_lengths, insert_count, source_vector.GetType(), null_render_length); } - insert_result.SetCardinality(insert_count); + insert_result.SetChildCardinality(insert_count); // construct the render collection top_collection.render_values->Append(insert_result); @@ -718,7 +718,7 @@ void BoxRendererImplementation::FetchTopCollection(RenderDataCollection &top_col values.SetValue(0, Value(readable_numbers[c])); render_widths.SetValue(0, Value::UBIGINT(Utf8Proc::RenderWidth(readable_numbers[c]))); } - insert_result.SetCardinality(1); + insert_result.SetChildCardinality(1); top_collection.render_values->Append(insert_result); } else { config.large_number_rendering = LargeNumberRendering::NONE; @@ -801,7 +801,7 @@ void BoxRendererImplementation::FetchBottomCollection(RenderDataCollection &bott ConvertRenderVector(target_vector, render_lengths, insert_count, source_vector.GetType(), null_render_length); } - insert_result.SetCardinality(insert_count); + insert_result.SetChildCardinality(insert_count); // construct the render collection bottom_collection.render_values->Append(insert_result); } @@ -881,7 +881,7 @@ vector BoxRendererImplementation::PivotCollections(vector< } } } - row_chunk.SetCardinality(row_chunk.size() + 1); + // the appends above add exactly one row (per column) to the child vectors, growing row_chunk.size() if (row_chunk.size() == STANDARD_VECTOR_SIZE || c + 1 == column_names.size()) { res_coll.Append(append_state, row_chunk); row_chunk.Reset(); diff --git a/src/duckdb/src/common/box_renderer_context.cpp 
b/src/duckdb/src/common/box_renderer_context.cpp index d1e14d30b..6820e5273 100644 --- a/src/duckdb/src/common/box_renderer_context.cpp +++ b/src/duckdb/src/common/box_renderer_context.cpp @@ -5,11 +5,11 @@ namespace duckdb { -void BoxRendererContext::CastToVarchar(Vector &source, Vector &result, idx_t count) { +void BoxRendererContext::CastToVarchar(const Vector &source, Vector &result, idx_t count) { DataChunk source_chunk; source_chunk.InitializeEmpty({source.GetType()}); source_chunk.data[0].Reference(source); - source_chunk.SetCardinality(count); + source_chunk.SetChildCardinality(count); DataChunk result_chunk; result_chunk.Initialize(GetAllocator(), {LogicalType::VARCHAR}); diff --git a/src/duckdb/src/common/cgroups.cpp b/src/duckdb/src/common/cgroups.cpp index cb3a5e264..bfb39610c 100644 --- a/src/duckdb/src/common/cgroups.cpp +++ b/src/duckdb/src/common/cgroups.cpp @@ -251,7 +251,7 @@ idx_t CGroups::GetCPULimit(FileSystem &fs, idx_t physical_cores) { for (auto &controller : controller_list) { if (controller == "cpu") { cpu_entry = i; - continue; + break; } } } diff --git a/src/duckdb/src/common/checksum.cpp b/src/duckdb/src/common/checksum.cpp index 5448d239a..3bf01e1be 100644 --- a/src/duckdb/src/common/checksum.cpp +++ b/src/duckdb/src/common/checksum.cpp @@ -72,7 +72,7 @@ uint64_t Checksum(const uint8_t *buffer, size_t size) { for (i = 0; i < size / 8; i++) { result ^= Checksum(ptr[i]); } - if (size - i * 8 > 0) { + if (size > i * 8) { // the remaining 0-7 bytes we hash using a string hash result ^= ChecksumRemainder(buffer + i * 8, size - i * 8); } diff --git a/src/duckdb/src/common/clustered_aggregate.cpp b/src/duckdb/src/common/clustered_aggregate.cpp new file mode 100644 index 000000000..c9c063118 --- /dev/null +++ b/src/duckdb/src/common/clustered_aggregate.cpp @@ -0,0 +1,371 @@ +#include "duckdb/common/clustered_aggregate.hpp" + +#include "duckdb/common/common.hpp" +#include "duckdb/common/limits.hpp" +#include "duckdb/common/types/vector.hpp" 
+#include "duckdb/common/vector/flat_vector.hpp" +#include "duckdb/common/vector/dictionary_vector.hpp" + +#include + +namespace duckdb { + +// Why does clustering help performance so much in TPC-H q01? Yes, "in-register" sum (for example) is +// faster than load-add-store and avoids setting "bool is_set" every time - but this is the sub-plot only. +// The main reason is cache-consistency latency because multiple load-add-store patterns hit the +// same aggregate state in short succession. To illustrate that with TPC-H Q01, change SELECT +// l_returnflag,l_shipmode into SELECT CAST(l_quantity * 1.0 AS INT) .. GROUP BY ALL ORDER BY ALL. +// With the 1.0 constant there are 50 groups, but by decreasing it, one can scale flexibly down to +// fewer groups. Negative performance effects start around 25 and grow strongest at <= 5 groups. +// +// Is this only affecting aggregations with <25 groups? No, there are two more cases: +// - repeating keys: rolling up an ordered dataframe, or aggregating back to the PK of a FK-PK join. +// e.g. change SELECT l_returnflag, l_shipmode into SELECT l_orderkey>>2, .. GROUP BY ALL +// (orderkey repeats 4 times avg -- this is not enough, I put the threshold at 6, hence >>2) +// - (very) heavy hittng keys in a GROUP BY with potentially many groups, e.g. when 75% is 0: +// e.g. change SELECT l_returnflag, l_shipmode into SELECT ((l_orderkey>>1)&l_orderkey&1)*l_orderkey, +// +// We also can deploy clustering on top of the dictionary-lookup optimization (mostly the few keys case) +// e.g. change SELECT l_returnflag, l_shipmode into SELECT l_shipinstruct, .. GROUP BY ALL +// +// Strategy: +// - A short dual-cursor sample inserts keys via hash_store. miss counters per cursor say whether +// the input looks (near-)sequential. +// - Sequential -> process the rest (most) using a merge_loop (no hash lookups). 
+// - Otherwise -> continue dual hash_store until the total hot-key budget is exhausted, then +// mixed_loop per cursor (hot via HT, cold via seq region). +// - At the end, memmove st2's grouop_run slice down next to st1's so consumers see one contiguous range. +// (we construct separate group_run[] arrays for the two cursors but concatenate in them eventually) +// +// Eventually we declare success iff >75% of the inputs ends up in a large run (>= 6). + +// Slot bitfield widths come from ClusteredAggr (header). Pull them into file scope for brevity. +constexpr idx_t SLOT_GRP_BITS = ClusteredAggr::SLOT_GRP_BITS; +constexpr idx_t SLOT_CURSOR_BITS = ClusteredAggr::SLOT_CURSOR_BITS; +constexpr idx_t SLOT_GID_BITS = ClusteredAggr::SLOT_GID_BITS; +constexpr uint64_t GID_MASK = ClusteredAggr::MAX_GID_COUNT - 1; + +static_assert((1 << SLOT_GRP_BITS) >= STANDARD_VECTOR_SIZE, "Group IDs must be able to address vectorsize"); + +struct Slot { // really just a 64-bits integer + union { + struct { +#if !defined(__BYTE_ORDER__) || __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + uint64_t cursor : SLOT_CURSOR_BITS, idx : SLOT_GRP_BITS, gid : SLOT_GID_BITS; +#else // reverse order + uint64_t gid : SLOT_GID_BITS, idx : SLOT_GRP_BITS, cursor : SLOT_CURSOR_BITS; +#endif + } bitfields; + uint64_t i64; // allows to increment cursor without bit extraction logic + } val; + Slot() { + val.i64 = ClusteredAggr::FREE_SLOT; + } + Slot(uint64_t c, uint64_t i, uint64_t g) { + val.bitfields.cursor = c; + val.bitfields.idx = i; + val.bitfields.gid = g; + } + //! Sentinel value meaning "no slot allocated yet" (all bitfields are 1s). + static const Slot UNINITIALIZED; +}; +const Slot Slot::UNINITIALIZED {}; + +// Each cursor uses an independent slice of the shared arena and its own slice of group_runs[]. +// Each slice has a hot region first (MAX_HOT_PER_CURSOR lists of HALF_VEC slots) followed by a +// HALF_VEC-sized seq bump region. 
We keep a non-restricted slice base for offset math, plus +// restricted pointers to the two disjoint write regions. +struct SlotTab { + Slot *const tab; // [HASHTAB_SZ] hash table + sel_t *const slice; // this cursor's full arena slice base + sel_t *__restrict const arena; // hot region base inside this slice + sel_t *__restrict seq; // bump pointer in the disjoint sequential region + const idx_t group_base; // base index into group_runs[] for this cursor (0 or HALF_VEC) + idx_t n_runs; // total runs created so far in this cursor's slice + idx_t hot_n_runs; // frozen at the end of the hot-insertion phase +}; + +static inline uint64_t slot_hash(uint64_t gid) { + return ((gid * 3217161767) ^ (gid >> ClusteredAggr::HASHTAB_LOG2)) & (ClusteredAggr::HASHTAB_SZ - 1); +} + +bool ClusteredAggr::TryClustered(const uint64_t *group_ids, sel_t count, sel_t *arena, uint64_t *slots) { + constexpr sel_t HALF_VEC = STANDARD_VECTOR_SIZE / 2; + constexpr idx_t MAX_HOT_PER_CURSOR = MAX_HOTKEYS / 2; + idx_t tuples_in_large = 0; + + // Per-cursor arena slices: each has MAX_HOT_PER_CURSOR hot regions followed by 1 seq bump + // region, all HALF_VEC slots wide. Total per slice = (MAX_HOT_PER_CURSOR + 1) * HALF_VEC. 
+ constexpr idx_t SLICE_HALF_VECS = MAX_HOT_PER_CURSOR + 1; + sel_t *const slice1 = arena; + sel_t *const slice2 = arena + SLICE_HALF_VECS * HALF_VEC; + SlotTab st1 {reinterpret_cast(slots), slice1, slice1, slice1 + MAX_HOT_PER_CURSOR * HALF_VEC, 0, 0, 0}; + SlotTab st2 {reinterpret_cast(slots + HASHTAB_SZ), + slice2, + slice2, + slice2 + MAX_HOT_PER_CURSOR * HALF_VEC, + HALF_VEC, + 0, + 0}; + + auto finish_run = [&](SlotTab &st, const Slot s) -> uint64_t { + const idx_t cnt = static_cast((st.slice + s.val.bitfields.cursor) - group_runs[s.val.bitfields.idx].sel); + group_runs[s.val.bitfields.idx].count = cnt; + if (s.val.bitfields.idx >= st.group_base + st.hot_n_runs) { + st.seq += cnt; // sequential runs sequentially allocate from this cursor's seq region + } + if (cnt >= RUNLENGTH_THRESHOLD) { + tuples_in_large += cnt; + } + return s.val.bitfields.cursor; + }; + + auto free_slot = [&](SlotTab &st, uint64_t hash) { + if (st.tab[hash].val.i64 != FREE_SLOT) { + finish_run(st, st.tab[hash]); + st.tab[hash].val.i64 = FREE_SLOT; + } + }; + + auto new_run = [&](SlotTab &st, uint64_t gid, uint64_t list_start) -> Slot { + uint64_t idx = st.group_base + st.n_runs++; + group_runs[idx].sel = st.slice + list_start; + group_runs[idx].gid = gid; + return Slot(list_start, idx, gid); + }; + + auto flush = [&](SlotTab &st, Slot s) { + if (s.val.bitfields.idx < st.group_base + st.hot_n_runs) { // small groupnrs are hot keys + st.tab[slot_hash(s.val.bitfields.gid)] = s; // store back hot slot in HT + } else { + finish_run(st, s); // finish non-hot (sequential) run + } + }; + + auto hash_store = [&](SlotTab &st, uint64_t gid, sel_t pos) -> Slot { + uint64_t hash = slot_hash(gid) & GID_MASK; + Slot s = st.tab[hash]; + if (DUCKDB_UNLIKELY(s.val.i64 == FREE_SLOT)) { // first insert + s = new_run(st, gid, st.n_runs * HALF_VEC); + } else if (DUCKDB_UNLIKELY(s.val.bitfields.gid != gid)) { // collision + s = new_run(st, gid, finish_run(st, s)); // close old, new run continues in same slot + 
} + st.arena[s.val.bitfields.cursor] = pos; + s.val.i64++; // val.bitfields.cursor++ without bit-extraction overhead + return st.tab[hash] = s; + }; + + auto mixed_loop = [&](SlotTab &st, Slot cur, sel_t lo, sel_t hi) { + for (sel_t pos = lo; pos < hi; pos++) { + uint64_t gid = group_ids[pos] & GID_MASK; + if (DUCKDB_UNLIKELY(cur.val.bitfields.gid != gid)) { // no sequential match? + flush(st, cur); // flush previous list + Slot s = st.tab[slot_hash(gid)]; // hot key lookup + cur = (s.val.i64 != FREE_SLOT && s.val.bitfields.gid == gid) + ? s // hot key found + : new_run(st, gid, static_cast(st.seq - st.slice)); + } + st.slice[cur.val.bitfields.cursor] = pos; + cur.val.i64++; + } + if (lo < hi) { + flush(st, cur); + } + }; + + // merge_loop: seeded from sample's last cur per cursor; no HT use. Each gid change creates a + // new sequential run. We avoid an unconditional initial new_run so that the worst-case run + // count fits in the per-cursor HALF_VEC budget even when every position is a transition. 
+ auto merge_loop = [&](SlotTab &st, Slot cur, sel_t lo, sel_t hi) { + for (sel_t pos = lo; pos < hi; pos++) { + uint64_t gid = group_ids[pos] & GID_MASK; + if (DUCKDB_UNLIKELY(cur.val.bitfields.gid != gid)) { // just check sequentially + flush(st, cur); // flush previous list + cur = new_run(st, gid, static_cast(st.seq - st.slice)); + } + st.slice[cur.val.bitfields.cursor] = pos; + cur.val.i64++; + } + if (lo < hi) { + flush(st, cur); + } + }; + + auto free_table = [&](SlotTab &st) { + for (idx_t i = 0; i < st.hot_n_runs; i++) { + free_slot(st, slot_hash(group_runs[st.group_base + i].gid)); + } + }; + + const sel_t sample = SAMPLE_SIZE / 2; // dual cursor: half as many iterations as positions + const sel_t half = count / 2; + Slot cur1 = Slot::UNINITIALIZED; + Slot cur2 = Slot::UNINITIALIZED; + sel_t miss1 = 0, miss2 = 0, pos; + for (pos = 0; pos < sample && st1.n_runs < MAX_HOT_PER_CURSOR && st2.n_runs < MAX_HOT_PER_CURSOR; pos++) { + uint64_t gid1 = group_ids[pos] & GID_MASK; + uint64_t gid2 = group_ids[pos + half] & GID_MASK; + miss1 += (cur1.val.bitfields.gid != gid1); + miss2 += (cur2.val.bitfields.gid != gid2); + cur1 = hash_store(st1, gid1, pos); + cur2 = hash_store(st2, gid2, pos + half); + } + bool fully_clustered = (miss1 + miss2 <= (st1.n_runs + st2.n_runs) + 2); // +2 because at pos==0 both miss + + if (!fully_clustered) { // use hash strategy until we find too many distinct hot keys + for (; pos < half && st1.n_runs < MAX_HOT_PER_CURSOR && st2.n_runs < MAX_HOT_PER_CURSOR; pos++) { + cur1 = hash_store(st1, group_ids[pos] & GID_MASK, pos); + cur2 = hash_store(st2, group_ids[pos + half] & GID_MASK, pos + half); + } + } + // All key insertions are done; new runs created beyond this point are sequential. + st1.hot_n_runs = st1.n_runs; + st2.hot_n_runs = st2.n_runs; + + if (fully_clustered) { // Merge strategy: No HT lookups; just detect transitions. 
+ merge_loop(st1, cur1, pos, half); + merge_loop(st2, cur2, pos + half, count); + } else { // Mixed strategy: do HT lookup to find hot keys, cold gids are just matched sequentially + mixed_loop(st1, cur1, pos, half); + mixed_loop(st2, cur2, pos + half, count); + } + free_table(st1); + free_table(st2); + + if (st2.n_runs > 0 && st1.n_runs < HALF_VEC) { // make st2 groups consecutive with st1 + std::memmove(&group_runs[st1.n_runs], &group_runs[HALF_VEC], st2.n_runs * sizeof(GroupRun)); + } + n_group_runs = st1.n_runs + st2.n_runs; + cached_dict_sel = nullptr; + return (5 * tuples_in_large >= 4 * static_cast(count)); // require >=80% of tuples in long-enough runs +} + +void ClusteredAggr::SetSingleRun(data_ptr_t state, idx_t count) { + n_group_runs = 1; + group_runs[0].state = state; + group_runs[0].sel = nullptr; + group_runs[0].gid = 0; + group_runs[0].count = count; + cached_dict_sel = nullptr; +} + +void ClusteredAggr::AdvanceStates(idx_t payload_size) { + for (idx_t r = 0; r < n_group_runs; r++) { + group_runs[r].state += payload_size; + } +} + +const sel_t *ClusteredAggr::ClusterIter(const Vector &input, idx_t count) const { + if (input.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return nullptr; + } + auto &child = DictionaryVector::Child(input); + if (child.GetVectorType() != VectorType::FLAT_VECTOR) { + return nullptr; + } + auto &dict_sel = DictionaryVector::SelVector(input); + const sel_t *dict_data = dict_sel.data(); + if (dict_data == nullptr) { + return nullptr; + } + if (cached_dict_sel == dict_data) { + return composed_sel_data; + } + idx_t pos = 0; + for (idx_t r = 0; r < n_group_runs; r++) { + const auto *run_sel = group_runs[r].sel; + const auto run_count = group_runs[r].count; + for (idx_t k = 0; k < run_count; k++) { + auto idx = run_sel ? 
run_sel[k] : k; + composed_sel_data[pos + k] = dict_data[idx]; + } + pos += run_count; + } + cached_dict_sel = dict_data; + return composed_sel_data; +} + +void ClusteredAggrState::Initialize() { + // Two per-cursor slices: each has MAX_HOTKEYS/2 hot regions + 1 seq region of HALF_VEC each. + // Total = 2 * (MAX_HOTKEYS/2 + 1) * HALF_VEC = (MAX_HOTKEYS/2 + 1) * STANDARD_VECTOR_SIZE. + arena = make_unsafe_uniq_array_uninitialized((ClusteredAggr::MAX_HOTKEYS / 2 + 1) * STANDARD_VECTOR_SIZE); + slots = make_unsafe_uniq_array_uninitialized(2 * ClusteredAggr::HASHTAB_SZ); + std::fill_n(slots.get(), 2 * ClusteredAggr::HASHTAB_SZ, ClusteredAggr::FREE_SLOT); + skipped_opportunities = 0; + retry_backoff = 1; +} + +bool ClusteredAggrState::TryBuild(ClusteredAggr &clustered, const uint64_t *group_ids, idx_t count) { + if (!arena) { + return false; + } + if (skipped_opportunities > 0) { + skipped_opportunities--; + return false; + } + if (count >= ClusteredAggr::SAMPLE_SIZE && + clustered.TryClustered(group_ids, static_cast(count), arena.get(), slots.get())) { + clustered.state = this; + skipped_opportunities = 0; + retry_backoff = 1; + return true; + } + skipped_opportunities = retry_backoff; + retry_backoff = MinValue(NumericLimits::Maximum() / 2, retry_backoff) * 2; + return false; +} + +optional_ptr ClusteredAggrState::GetDictProps(const Vector &input) const { + if (input.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return nullptr; + } + const auto &dict_id = DictionaryVector::DictionaryId(input); + if (dict_id.empty()) { + return nullptr; + } + auto it = dict_props.find(dict_id); + if (it != dict_props.end()) { + return &it->second; + } + + auto &slot = dict_props[dict_id]; + const auto dict_size_opt = DictionaryVector::DictionarySize(input); + if (!dict_size_opt.IsValid()) { + return &slot; + } + const idx_t dict_size = dict_size_opt.GetIndex(); + auto &child = DictionaryVector::Child(input); + if (child.GetVectorType() != VectorType::FLAT_VECTOR) { + return 
&slot; + } + + auto &mask = FlatVector::Validity(child); + auto buf = make_unsafe_uniq_array_uninitialized(dict_size); + auto try_copy = [&](auto src) { + for (idx_t i = 0; i < dict_size; i++) { + if (!mask.RowIsValid(i)) { + buf[i] = 0; + continue; + } + const int64_t v = static_cast(src[i]); + if (!I64VectorSumSafe(v)) { + return false; + } + buf[i] = v; + } + return true; + }; + auto type = child.GetType().InternalType(); + if (type == PhysicalType::UINT32 || type == PhysicalType::INT32) { + if (try_copy(FlatVector::GetData(child))) { + slot = std::move(buf); + } + } else if (type == PhysicalType::UINT64 || type == PhysicalType::INT64) { + if (try_copy(FlatVector::GetData(child))) { + slot = std::move(buf); + } + } + return &slot; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/csv_writer.cpp b/src/duckdb/src/common/csv_writer.cpp index 8920aecc0..c9df0dd82 100644 --- a/src/duckdb/src/common/csv_writer.cpp +++ b/src/duckdb/src/common/csv_writer.cpp @@ -318,7 +318,7 @@ void CSVWriter::WriteQuotedString(WriteStream &writer, const char *str, idx_t le void CSVWriter::WriteChunk(DataChunk &input, MemoryStream &writer, CSVReaderOptions &options, bool &written_anything, CSVWriterOptions &writer_options) { vector> input_iterators; - for (auto &col : input.data) { + for (const auto &col : input.data) { input_iterators.emplace_back(col.Values()); } // now loop over the vectors and output the values diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index f6a69d7dc..187894805 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -10,6 +10,7 @@ #include "duckdb/common/enum_util.hpp" +#include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp" #include "duckdb/catalog/catalog_entry/table_column_type.hpp" #include "duckdb/common/box_renderer.hpp" @@ -47,7 +48,6 @@ #include "duckdb/common/enums/logical_operator_type.hpp" #include 
"duckdb/common/enums/memory_tag.hpp" #include "duckdb/common/enums/merge_action_type.hpp" -#include "duckdb/common/enums/metric_type.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" #include "duckdb/common/enums/operator_result_type.hpp" @@ -72,6 +72,7 @@ #include "duckdb/common/enums/stream_execution_result.hpp" #include "duckdb/common/enums/subquery_type.hpp" #include "duckdb/common/enums/tableref_type.hpp" +#include "duckdb/common/enums/task_scheduler_type.hpp" #include "duckdb/common/enums/thread_pin_mode.hpp" #include "duckdb/common/enums/trigger_type.hpp" #include "duckdb/common/enums/tuple_data_layout_enums.hpp" @@ -118,6 +119,8 @@ #include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/execution/operator/csv_scanner/csv_option.hpp" #include "duckdb/execution/operator/csv_scanner/csv_state.hpp" +#include "duckdb/execution/operator/join/join_filter_pushdown.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte_state.hpp" #include "duckdb/execution/physical_operator.hpp" #include "duckdb/execution/physical_table_scan_enum.hpp" #include "duckdb/execution/reservoir_sample.hpp" @@ -146,12 +149,15 @@ #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_helper.hpp" #include "duckdb/main/extension_install_info.hpp" -#include "duckdb/main/profiling_info.hpp" +#include "duckdb/main/gathered_metrics.hpp" #include "duckdb/main/query_parameters.hpp" #include "duckdb/main/query_profiler.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/main/setting_info.hpp" +#include "duckdb/optimizer/build_probe_side_optimizer.hpp" +#include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" #include "duckdb/optimizer/remove_unused_columns.hpp" #include "duckdb/parallel/async_result.hpp" #include "duckdb/parallel/interrupt.hpp" @@ -188,7 +194,7 @@ #include 
"duckdb/parser/tableref/showref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/planner/filter/selectivity_optional_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/buffer/buffer_pool_reservation.hpp" #include "duckdb/storage/caching_mode.hpp" @@ -198,10 +204,13 @@ #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/statistics/variant_stats.hpp" #include "duckdb/storage/storage_index.hpp" +#include "duckdb/storage/storage_info.hpp" +#include "duckdb/storage/storage_options.hpp" #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/table/column_data.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/row_group_reorderer.hpp" +#include "duckdb/storage/table/segment_tree.hpp" #include "duckdb/storage/table/table_index_list.hpp" #include "duckdb/storage/temporary_file_manager.hpp" @@ -393,6 +402,24 @@ AggregateOrderDependent EnumUtil::FromString(const char return static_cast(StringUtil::StringToEnum(GetAggregateOrderDependentValues(), 2, "AggregateOrderDependent", value)); } +const StringUtil::EnumStringLiteral *GetAggregateStateExportModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(AggregateStateExportMode::NONE), "NONE" }, + { static_cast(AggregateStateExportMode::STATE_EXPORT), "STATE_EXPORT" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(AggregateStateExportMode value) { + return StringUtil::EnumToString(GetAggregateStateExportModeValues(), 2, "AggregateStateExportMode", static_cast(value)); +} + +template<> +AggregateStateExportMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetAggregateStateExportModeValues(), 2, "AggregateStateExportMode", value)); +} + const StringUtil::EnumStringLiteral *GetAggregateTypeValues() { 
static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(AggregateType::NON_DISTINCT), "NON_DISTINCT" }, @@ -1260,6 +1287,25 @@ CompressedMaterializationDirection EnumUtil::FromString(StringUtil::StringToEnum(GetCompressedMaterializationDirectionValues(), 3, "CompressedMaterializationDirection", value)); } +const StringUtil::EnumStringLiteral *GetCompressedMaterializationTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(CompressedMaterializationType::INVALID), "INVALID" }, + { static_cast(CompressedMaterializationType::FUNCTION), "FUNCTION" }, + { static_cast(CompressedMaterializationType::CAST), "CAST" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(CompressedMaterializationType value) { + return StringUtil::EnumToString(GetCompressedMaterializationTypeValues(), 3, "CompressedMaterializationType", static_cast(value)); +} + +template<> +CompressedMaterializationType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetCompressedMaterializationTypeValues(), 3, "CompressedMaterializationType", value)); +} + const StringUtil::EnumStringLiteral *GetCompressionTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(CompressionType::COMPRESSION_AUTO), "AUTO" }, @@ -1597,19 +1643,20 @@ const StringUtil::EnumStringLiteral *GetDebugVerificationModeValues() { { static_cast(DebugVerificationMode::DEFAULT), "DEFAULT" }, { static_cast(DebugVerificationMode::NONE), "NONE" }, { static_cast(DebugVerificationMode::VERIFY_VECTORS), "VERIFY_VECTORS" }, - { static_cast(DebugVerificationMode::VERIFY_SERIALIZATION), "VERIFY_SERIALIZATION" } + { static_cast(DebugVerificationMode::VERIFY_SERIALIZATION), "VERIFY_SERIALIZATION" }, + { static_cast(DebugVerificationMode::VERIFY_FUNCTIONS), "VERIFY_FUNCTIONS" } }; return values; } template<> const char* EnumUtil::ToChars(DebugVerificationMode value) { - return 
StringUtil::EnumToString(GetDebugVerificationModeValues(), 4, "DebugVerificationMode", static_cast(value)); + return StringUtil::EnumToString(GetDebugVerificationModeValues(), 5, "DebugVerificationMode", static_cast(value)); } template<> DebugVerificationMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetDebugVerificationModeValues(), 4, "DebugVerificationMode", value)); + return static_cast(StringUtil::StringToEnum(GetDebugVerificationModeValues(), 5, "DebugVerificationMode", value)); } const StringUtil::EnumStringLiteral *GetDecimalBitWidthValues() { @@ -1727,6 +1774,26 @@ DestroyBufferUpon EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetDestroyBufferUponValues(), 3, "DestroyBufferUpon", value)); } +const StringUtil::EnumStringLiteral *GetDistinctCountSourceValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(DistinctCountSource::CARDINALITY), "CARDINALITY" }, + { static_cast(DistinctCountSource::MIN_MAX), "MIN_MAX" }, + { static_cast(DistinctCountSource::HLL), "HLL" }, + { static_cast(DistinctCountSource::EXACT), "EXACT" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(DistinctCountSource value) { + return StringUtil::EnumToString(GetDistinctCountSourceValues(), 4, "DistinctCountSource", static_cast(value)); +} + +template<> +DistinctCountSource EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetDistinctCountSourceValues(), 4, "DistinctCountSource", value)); +} + const StringUtil::EnumStringLiteral *GetDistinctTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(DistinctType::DISTINCT), "DISTINCT" }, @@ -2207,20 +2274,19 @@ const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() { { static_cast(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }, { 
static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" }, - { static_cast(ExtraTypeInfoType::GEO_TYPE_INFO), "GEO_TYPE_INFO" }, - { static_cast(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), "AGGREGATE_STATE_TYPE_INFO" } + { static_cast(ExtraTypeInfoType::GEO_TYPE_INFO), "GEO_TYPE_INFO" } }; return values; } template<> const char* EnumUtil::ToChars(ExtraTypeInfoType value) { - return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 15, "ExtraTypeInfoType", static_cast(value)); + return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", static_cast(value)); } template<> ExtraTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 15, "ExtraTypeInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", value)); } const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() { @@ -2301,6 +2367,25 @@ FileGlobOptions EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetFileGlobOptionsValues(), 3, "FileGlobOptions", value)); } +const StringUtil::EnumStringLiteral *GetFileIOModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(FileIOMode::BUFFERED_IO), "BUFFERED_IO" }, + { static_cast(FileIOMode::MMAP), "MMAP" }, + { static_cast(FileIOMode::DIRECT_IO), "DIRECT_IO" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(FileIOMode value) { + return StringUtil::EnumToString(GetFileIOModeValues(), 3, "FileIOMode", static_cast(value)); +} + +template<> +FileIOMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetFileIOModeValues(), 3, "FileIOMode", value)); +} + const StringUtil::EnumStringLiteral *GetFileLockTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(FileLockType::NO_LOCK), "NO_LOCK" }, @@ -2731,6 +2816,24 @@ 
InterruptMode EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetInterruptModeValues(), 3, "InterruptMode", value)); } +const StringUtil::EnumStringLiteral *GetJoinFilterPushdownModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION), "RECONSTRUCT_EXPRESSION" }, + { static_cast(JoinFilterPushdownMode::STORAGE_ONLY), "STORAGE_ONLY" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(JoinFilterPushdownMode value) { + return StringUtil::EnumToString(GetJoinFilterPushdownModeValues(), 2, "JoinFilterPushdownMode", static_cast(value)); +} + +template<> +JoinFilterPushdownMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetJoinFilterPushdownModeValues(), 2, "JoinFilterPushdownMode", value)); +} + const StringUtil::EnumStringLiteral *GetJoinRefTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(JoinRefType::REGULAR), "REGULAR" }, @@ -2879,23 +2982,42 @@ LimitNodeType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetLimitNodeTypeValues(), 5, "LimitNodeType", value)); } +const StringUtil::EnumStringLiteral *GetLimitValueTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(LimitValueType::ROW_COUNT), "ROW_COUNT" }, + { static_cast(LimitValueType::PERCENTAGE), "PERCENTAGE" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(LimitValueType value) { + return StringUtil::EnumToString(GetLimitValueTypeValues(), 2, "LimitValueType", static_cast(value)); +} + +template<> +LimitValueType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetLimitValueTypeValues(), 2, "LimitValueType", value)); +} + const StringUtil::EnumStringLiteral *GetLoadTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { 
static_cast(LoadType::LOAD), "LOAD" }, { static_cast(LoadType::INSTALL), "INSTALL" }, - { static_cast(LoadType::FORCE_INSTALL), "FORCE_INSTALL" } + { static_cast(LoadType::FORCE_INSTALL), "FORCE_INSTALL" }, + { static_cast(LoadType::LOAD_AS), "LOAD_AS" } }; return values; } template<> const char* EnumUtil::ToChars(LoadType value) { - return StringUtil::EnumToString(GetLoadTypeValues(), 3, "LoadType", static_cast(value)); + return StringUtil::EnumToString(GetLoadTypeValues(), 4, "LoadType", static_cast(value)); } template<> LoadType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLoadTypeValues(), 3, "LoadType", value)); + return static_cast(StringUtil::StringToEnum(GetLoadTypeValues(), 4, "LoadType", value)); } const StringUtil::EnumStringLiteral *GetLogContextScopeValues() { @@ -3040,6 +3162,8 @@ const StringUtil::EnumStringLiteral *GetLogicalOperatorTypeValues() { { static_cast(LogicalOperatorType::LOGICAL_LOAD), "LOGICAL_LOAD" }, { static_cast(LogicalOperatorType::LOGICAL_RESET), "LOGICAL_RESET" }, { static_cast(LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS), "LOGICAL_UPDATE_EXTENSIONS" }, + { static_cast(LogicalOperatorType::LOGICAL_CONNECT), "LOGICAL_CONNECT" }, + { static_cast(LogicalOperatorType::LOGICAL_DISCONNECT), "LOGICAL_DISCONNECT" }, { static_cast(LogicalOperatorType::LOGICAL_CREATE_SECRET), "LOGICAL_CREATE_SECRET" }, { static_cast(LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR), "LOGICAL_EXTENSION_OPERATOR" } }; @@ -3048,12 +3172,12 @@ const StringUtil::EnumStringLiteral *GetLogicalOperatorTypeValues() { template<> const char* EnumUtil::ToChars(LogicalOperatorType value) { - return StringUtil::EnumToString(GetLogicalOperatorTypeValues(), 63, "LogicalOperatorType", static_cast(value)); + return StringUtil::EnumToString(GetLogicalOperatorTypeValues(), 65, "LogicalOperatorType", static_cast(value)); } template<> LogicalOperatorType EnumUtil::FromString(const char *value) { - return 
static_cast(StringUtil::StringToEnum(GetLogicalOperatorTypeValues(), 63, "LogicalOperatorType", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalOperatorTypeValues(), 65, "LogicalOperatorType", value)); } const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { @@ -3110,20 +3234,19 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::LAMBDA), "LAMBDA" }, { static_cast(LogicalTypeId::UNION), "UNION" }, { static_cast(LogicalTypeId::ARRAY), "ARRAY" }, - { static_cast(LogicalTypeId::VARIANT), "VARIANT" }, - { static_cast(LogicalTypeId::AGGREGATE_STATE), "AGGREGATE_STATE" } + { static_cast(LogicalTypeId::VARIANT), "VARIANT" } }; return values; } template<> const char* EnumUtil::ToChars(LogicalTypeId value) { - return StringUtil::EnumToString(GetLogicalTypeIdValues(), 54, "LogicalTypeId", static_cast(value)); + return StringUtil::EnumToString(GetLogicalTypeIdValues(), 53, "LogicalTypeId", static_cast(value)); } template<> LogicalTypeId EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 54, "LogicalTypeId", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 53, "LogicalTypeId", value)); } const StringUtil::EnumStringLiteral *GetLookupResultTypeValues() { @@ -3276,119 +3399,6 @@ MetaPipelineType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetMetaPipelineTypeValues(), 2, "MetaPipelineType", value)); } -const StringUtil::EnumStringLiteral *GetMetricGroupValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricGroup::ALL), "ALL" }, - { static_cast(MetricGroup::CORE), "CORE" }, - { static_cast(MetricGroup::DEFAULT), "DEFAULT" }, - { static_cast(MetricGroup::EXECUTION), "EXECUTION" }, - { static_cast(MetricGroup::FILE), "FILE" }, - { static_cast(MetricGroup::OPERATOR), "OPERATOR" }, - { static_cast(MetricGroup::OPTIMIZER), "OPTIMIZER" }, - { 
static_cast(MetricGroup::PHASE_TIMING), "PHASE_TIMING" }, - { static_cast(MetricGroup::INVALID), "INVALID" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MetricGroup value) { - return StringUtil::EnumToString(GetMetricGroupValues(), 9, "MetricGroup", static_cast(value)); -} - -template<> -MetricGroup EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricGroupValues(), 9, "MetricGroup", value)); -} - -const StringUtil::EnumStringLiteral *GetMetricTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricType::CPU_TIME), "CPU_TIME" }, - { static_cast(MetricType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, - { static_cast(MetricType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, - { static_cast(MetricType::EXTRA_INFO), "EXTRA_INFO" }, - { static_cast(MetricType::LATENCY), "LATENCY" }, - { static_cast(MetricType::QUERY_NAME), "QUERY_NAME" }, - { static_cast(MetricType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, - { static_cast(MetricType::ROWS_RETURNED), "ROWS_RETURNED" }, - { static_cast(MetricType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, - { static_cast(MetricType::SYSTEM_PEAK_BUFFER_MEMORY), "SYSTEM_PEAK_BUFFER_MEMORY" }, - { static_cast(MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE), "SYSTEM_PEAK_TEMP_DIR_SIZE" }, - { static_cast(MetricType::TOTAL_MEMORY_ALLOCATED), "TOTAL_MEMORY_ALLOCATED" }, - { static_cast(MetricType::ATTACH_LOAD_STORAGE_LATENCY), "ATTACH_LOAD_STORAGE_LATENCY" }, - { static_cast(MetricType::ATTACH_REPLAY_WAL_LATENCY), "ATTACH_REPLAY_WAL_LATENCY" }, - { static_cast(MetricType::CHECKPOINT_LATENCY), "CHECKPOINT_LATENCY" }, - { static_cast(MetricType::COMMIT_LOCAL_STORAGE_LATENCY), "COMMIT_LOCAL_STORAGE_LATENCY" }, - { static_cast(MetricType::TOTAL_BYTES_READ), "TOTAL_BYTES_READ" }, - { static_cast(MetricType::TOTAL_BYTES_WRITTEN), "TOTAL_BYTES_WRITTEN" }, - { static_cast(MetricType::WAITING_TO_ATTACH_LATENCY), 
"WAITING_TO_ATTACH_LATENCY" }, - { static_cast(MetricType::WAL_REPLAY_ENTRY_COUNT), "WAL_REPLAY_ENTRY_COUNT" }, - { static_cast(MetricType::WRITE_TO_WAL_LATENCY), "WRITE_TO_WAL_LATENCY" }, - { static_cast(MetricType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, - { static_cast(MetricType::OPERATOR_NAME), "OPERATOR_NAME" }, - { static_cast(MetricType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, - { static_cast(MetricType::OPERATOR_TIMING), "OPERATOR_TIMING" }, - { static_cast(MetricType::OPERATOR_TYPE), "OPERATOR_TYPE" }, - { static_cast(MetricType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, - { static_cast(MetricType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, - { static_cast(MetricType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, - { static_cast(MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP), "OPTIMIZER_EMPTY_RESULT_PULLUP" }, - { static_cast(MetricType::OPTIMIZER_CTE_FILTER_PUSHER), "OPTIMIZER_CTE_FILTER_PUSHER" }, - { static_cast(MetricType::OPTIMIZER_REGEX_RANGE), "OPTIMIZER_REGEX_RANGE" }, - { static_cast(MetricType::OPTIMIZER_IN_CLAUSE), "OPTIMIZER_IN_CLAUSE" }, - { static_cast(MetricType::OPTIMIZER_JOIN_ORDER), "OPTIMIZER_JOIN_ORDER" }, - { static_cast(MetricType::OPTIMIZER_DELIMINATOR), "OPTIMIZER_DELIMINATOR" }, - { static_cast(MetricType::OPTIMIZER_UNNEST_REWRITER), "OPTIMIZER_UNNEST_REWRITER" }, - { static_cast(MetricType::OPTIMIZER_UNUSED_COLUMNS), "OPTIMIZER_UNUSED_COLUMNS" }, - { static_cast(MetricType::OPTIMIZER_STATISTICS_PROPAGATION), "OPTIMIZER_STATISTICS_PROPAGATION" }, - { static_cast(MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS), "OPTIMIZER_COMMON_SUBEXPRESSIONS" }, - { static_cast(MetricType::OPTIMIZER_COMMON_AGGREGATE), "OPTIMIZER_COMMON_AGGREGATE" }, - { static_cast(MetricType::OPTIMIZER_COLUMN_LIFETIME), "OPTIMIZER_COLUMN_LIFETIME" }, - { static_cast(MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, - { 
static_cast(MetricType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, - { static_cast(MetricType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, - { static_cast(MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, - { static_cast(MetricType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, - { static_cast(MetricType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, - { static_cast(MetricType::OPTIMIZER_SAMPLING_PUSHDOWN), "OPTIMIZER_SAMPLING_PUSHDOWN" }, - { static_cast(MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN), "OPTIMIZER_JOIN_FILTER_PUSHDOWN" }, - { static_cast(MetricType::OPTIMIZER_EXTENSION), "OPTIMIZER_EXTENSION" }, - { static_cast(MetricType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, - { static_cast(MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER), "OPTIMIZER_AGGREGATE_FUNCTION_REWRITER" }, - { static_cast(MetricType::OPTIMIZER_LATE_MATERIALIZATION), "OPTIMIZER_LATE_MATERIALIZATION" }, - { static_cast(MetricType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" }, - { static_cast(MetricType::OPTIMIZER_ROW_GROUP_PRUNER), "OPTIMIZER_ROW_GROUP_PRUNER" }, - { static_cast(MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION), "OPTIMIZER_TOP_N_WINDOW_ELIMINATION" }, - { static_cast(MetricType::OPTIMIZER_COMMON_SUBPLAN), "OPTIMIZER_COMMON_SUBPLAN" }, - { static_cast(MetricType::OPTIMIZER_JOIN_ELIMINATION), "OPTIMIZER_JOIN_ELIMINATION" }, - { static_cast(MetricType::OPTIMIZER_WINDOW_SELF_JOIN), "OPTIMIZER_WINDOW_SELF_JOIN" }, - { static_cast(MetricType::OPTIMIZER_PROJECTION_PULLUP), "OPTIMIZER_PROJECTION_PULLUP" }, - { static_cast(MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION), "OPTIMIZER_OUTER_JOIN_SIMPLIFICATION" }, - { static_cast(MetricType::OPTIMIZER_ROW_NUMBER_REWRITER), "OPTIMIZER_ROW_NUMBER_REWRITER" }, - { static_cast(MetricType::OPTIMIZER_PARTITIONED_EXECUTION), "OPTIMIZER_PARTITIONED_EXECUTION" }, - { static_cast(MetricType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, - { 
static_cast(MetricType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, - { static_cast(MetricType::PARSER), "PARSER" }, - { static_cast(MetricType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, - { static_cast(MetricType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, - { static_cast(MetricType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, - { static_cast(MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, - { static_cast(MetricType::PLANNER), "PLANNER" }, - { static_cast(MetricType::PLANNER_BINDING), "PLANNER_BINDING" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MetricType value) { - return StringUtil::EnumToString(GetMetricTypeValues(), 72, "MetricType", static_cast(value)); -} - -template<> -MetricType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricTypeValues(), 72, "MetricType", value)); -} - const StringUtil::EnumStringLiteral *GetMonotonicityValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(Monotonicity::UNKNOWN), "UNKNOWN" }, @@ -3669,19 +3679,21 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::PROJECTION_PULLUP), "PROJECTION_PULLUP" }, { static_cast(OptimizerType::OUTER_JOIN_SIMPLIFICATION), "OUTER_JOIN_SIMPLIFICATION" }, { static_cast(OptimizerType::ROW_NUMBER_REWRITER), "ROW_NUMBER_REWRITER" }, - { static_cast(OptimizerType::PARTITIONED_EXECUTION), "PARTITIONED_EXECUTION" } + { static_cast(OptimizerType::PARTITIONED_EXECUTION), "PARTITIONED_EXECUTION" }, + { static_cast(OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN), "PARTIAL_AGGREGATE_PUSHDOWN" }, + { static_cast(OptimizerType::REMOTE_PUSHDOWN), "REMOTE_PUSHDOWN" } }; return values; } template<> const char* EnumUtil::ToChars(OptimizerType value) { - return StringUtil::EnumToString(GetOptimizerTypeValues(), 38, "OptimizerType", static_cast(value)); + return 
StringUtil::EnumToString(GetOptimizerTypeValues(), 40, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 38, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 40, "OptimizerType", value)); } const StringUtil::EnumStringLiteral *GetOrderByColumnTypeValues() { @@ -3861,19 +3873,21 @@ const StringUtil::EnumStringLiteral *GetParseInfoTypeValues() { { static_cast(ParseInfoType::COMMENT_ON_INFO), "COMMENT_ON_INFO" }, { static_cast(ParseInfoType::COMMENT_ON_COLUMN_INFO), "COMMENT_ON_COLUMN_INFO" }, { static_cast(ParseInfoType::COPY_DATABASE_INFO), "COPY_DATABASE_INFO" }, - { static_cast(ParseInfoType::UPDATE_EXTENSIONS_INFO), "UPDATE_EXTENSIONS_INFO" } + { static_cast(ParseInfoType::UPDATE_EXTENSIONS_INFO), "UPDATE_EXTENSIONS_INFO" }, + { static_cast(ParseInfoType::CONNECT_INFO), "CONNECT_INFO" }, + { static_cast(ParseInfoType::DISCONNECT_INFO), "DISCONNECT_INFO" } }; return values; } template<> const char* EnumUtil::ToChars(ParseInfoType value) { - return StringUtil::EnumToString(GetParseInfoTypeValues(), 17, "ParseInfoType", static_cast(value)); + return StringUtil::EnumToString(GetParseInfoTypeValues(), 19, "ParseInfoType", static_cast(value)); } template<> ParseInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetParseInfoTypeValues(), 17, "ParseInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetParseInfoTypeValues(), 19, "ParseInfoType", value)); } const StringUtil::EnumStringLiteral *GetParseResultTypeValues() { @@ -3890,6 +3904,7 @@ const StringUtil::EnumStringLiteral *GetParseResultTypeValues() { { static_cast(ParseResultType::EXTENSION), "EXTENSION" }, { static_cast(ParseResultType::NUMBER), "NUMBER" }, { static_cast(ParseResultType::STRING), "STRING" }, + { static_cast(ParseResultType::END_OF_INPUT), "END_OF_INPUT" }, { 
static_cast(ParseResultType::INVALID), "INVALID" } }; return values; @@ -3897,12 +3912,12 @@ const StringUtil::EnumStringLiteral *GetParseResultTypeValues() { template<> const char* EnumUtil::ToChars(ParseResultType value) { - return StringUtil::EnumToString(GetParseResultTypeValues(), 13, "ParseResultType", static_cast(value)); + return StringUtil::EnumToString(GetParseResultTypeValues(), 14, "ParseResultType", static_cast(value)); } template<> ParseResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetParseResultTypeValues(), 13, "ParseResultType", value)); + return static_cast(StringUtil::StringToEnum(GetParseResultTypeValues(), 14, "ParseResultType", value)); } const StringUtil::EnumStringLiteral *GetParserExtensionResultTypeValues() { @@ -4068,6 +4083,8 @@ const StringUtil::EnumStringLiteral *GetPhysicalOperatorTypeValues() { { static_cast(PhysicalOperatorType::EXTENSION), "EXTENSION" }, { static_cast(PhysicalOperatorType::VERIFY_VECTOR), "VERIFY_VECTOR" }, { static_cast(PhysicalOperatorType::UPDATE_EXTENSIONS), "UPDATE_EXTENSIONS" }, + { static_cast(PhysicalOperatorType::CONNECT), "CONNECT" }, + { static_cast(PhysicalOperatorType::DISCONNECT), "DISCONNECT" }, { static_cast(PhysicalOperatorType::CREATE_SECRET), "CREATE_SECRET" } }; return values; @@ -4075,12 +4092,12 @@ const StringUtil::EnumStringLiteral *GetPhysicalOperatorTypeValues() { template<> const char* EnumUtil::ToChars(PhysicalOperatorType value) { - return StringUtil::EnumToString(GetPhysicalOperatorTypeValues(), 84, "PhysicalOperatorType", static_cast(value)); + return StringUtil::EnumToString(GetPhysicalOperatorTypeValues(), 86, "PhysicalOperatorType", static_cast(value)); } template<> PhysicalOperatorType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPhysicalOperatorTypeValues(), 84, "PhysicalOperatorType", value)); + return static_cast(StringUtil::StringToEnum(GetPhysicalOperatorTypeValues(), 86, 
"PhysicalOperatorType", value)); } const StringUtil::EnumStringLiteral *GetPhysicalTableScanExecutionStrategyValues() { @@ -4327,19 +4344,20 @@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" }, { static_cast(QueryNodeType::UPDATE_QUERY_NODE), "UPDATE_QUERY_NODE" }, { static_cast(QueryNodeType::DELETE_QUERY_NODE), "DELETE_QUERY_NODE" }, - { static_cast(QueryNodeType::INSERT_QUERY_NODE), "INSERT_QUERY_NODE" } + { static_cast(QueryNodeType::INSERT_QUERY_NODE), "INSERT_QUERY_NODE" }, + { static_cast(QueryNodeType::MERGE_QUERY_NODE), "MERGE_QUERY_NODE" } }; return values; } template<> const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 9, "QueryNodeType", static_cast(value)); + return StringUtil::EnumToString(GetQueryNodeTypeValues(), 10, "QueryNodeType", static_cast(value)); } template<> QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 9, "QueryNodeType", value)); + return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 10, "QueryNodeType", value)); } const StringUtil::EnumStringLiteral *GetQueryResultMemoryTypeValues() { @@ -4416,6 +4434,44 @@ RecoveryMode EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetRecoveryModeValues(), 2, "RecoveryMode", value)); } +const StringUtil::EnumStringLiteral *GetRecursiveCTEInlineStageTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(RecursiveCTEInlineStageType::EXECUTE), "EXECUTE" }, + { static_cast(RecursiveCTEInlineStageType::PREPARE_FINISH), "PREPARE_FINISH" }, + { static_cast(RecursiveCTEInlineStageType::FINISH), "FINISH" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(RecursiveCTEInlineStageType value) { + return StringUtil::EnumToString(GetRecursiveCTEInlineStageTypeValues(), 3, 
"RecursiveCTEInlineStageType", static_cast(value)); +} + +template<> +RecursiveCTEInlineStageType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetRecursiveCTEInlineStageTypeValues(), 3, "RecursiveCTEInlineStageType", value)); +} + +const StringUtil::EnumStringLiteral *GetRecursiveProbeSidePreferenceValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(RecursiveProbeSidePreference::NONE), "NONE" }, + { static_cast(RecursiveProbeSidePreference::KEEP), "KEEP" }, + { static_cast(RecursiveProbeSidePreference::SWAP), "SWAP" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(RecursiveProbeSidePreference value) { + return StringUtil::EnumToString(GetRecursiveProbeSidePreferenceValues(), 3, "RecursiveProbeSidePreference", static_cast(value)); +} + +template<> +RecursiveProbeSidePreference EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetRecursiveProbeSidePreferenceValues(), 3, "RecursiveProbeSidePreference", value)); +} + const StringUtil::EnumStringLiteral *GetRelationTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(RelationType::INVALID_RELATION), "INVALID_RELATION" }, @@ -4461,6 +4517,25 @@ RelationType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetRelationTypeValues(), 29, "RelationType", value)); } +const StringUtil::EnumStringLiteral *GetRemoteCapabilityValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(RemoteCapability::IS_REMOTE), "IS_REMOTE" }, + { static_cast(RemoteCapability::EXECUTE_QUERY_NODE), "EXECUTE_QUERY_NODE" }, + { static_cast(RemoteCapability::CONNECT), "CONNECT" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(RemoteCapability value) { + return StringUtil::EnumToString(GetRemoteCapabilityValues(), 3, "RemoteCapability", static_cast(value)); +} + +template<> +RemoteCapability 
EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetRemoteCapabilityValues(), 3, "RemoteCapability", value)); +} + const StringUtil::EnumStringLiteral *GetRenderModeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(RenderMode::ROWS), "ROWS" }, @@ -4506,7 +4581,7 @@ const StringUtil::EnumStringLiteral *GetResultModifierTypeValues() { { static_cast(ResultModifierType::LIMIT_MODIFIER), "LIMIT_MODIFIER" }, { static_cast(ResultModifierType::ORDER_MODIFIER), "ORDER_MODIFIER" }, { static_cast(ResultModifierType::DISTINCT_MODIFIER), "DISTINCT_MODIFIER" }, - { static_cast(ResultModifierType::LIMIT_PERCENT_MODIFIER), "LIMIT_PERCENT_MODIFIER" } + { static_cast(ResultModifierType::LEGACY_LIMIT_PERCENT_MODIFIER), "LEGACY_LIMIT_PERCENT_MODIFIER" } }; return values; } @@ -4671,6 +4746,24 @@ SecretSerializationType EnumUtil::FromString(const char return static_cast(StringUtil::StringToEnum(GetSecretSerializationTypeValues(), 2, "SecretSerializationType", value)); } +const StringUtil::EnumStringLiteral *GetSegmentTreeVerifyModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(SegmentTreeVerifyMode::CONTIGUOUS), "CONTIGUOUS" }, + { static_cast(SegmentTreeVerifyMode::NON_OVERLAPPING), "NON_OVERLAPPING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(SegmentTreeVerifyMode value) { + return StringUtil::EnumToString(GetSegmentTreeVerifyModeValues(), 2, "SegmentTreeVerifyMode", static_cast(value)); +} + +template<> +SegmentTreeVerifyMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetSegmentTreeVerifyModeValues(), 2, "SegmentTreeVerifyMode", value)); +} + const StringUtil::EnumStringLiteral *GetSelectivityOptionalFilterTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(SelectivityOptionalFilterType::MIN_MAX), "MIN_MAX" }, @@ -4713,6 +4806,49 @@ SequenceInfo 
EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetSequenceInfoValues(), 6, "SequenceInfo", value)); } +const StringUtil::EnumStringLiteral *GetSerializationVersionDeprecatedValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(SerializationVersionDeprecated::V0_10_0), "V0_10_0" }, + { static_cast(SerializationVersionDeprecated::V0_10_1), "V0_10_1" }, + { static_cast(SerializationVersionDeprecated::V0_10_2), "V0_10_2" }, + { static_cast(SerializationVersionDeprecated::V0_10_3), "V0_10_3" }, + { static_cast(SerializationVersionDeprecated::V1_0_0), "V1_0_0" }, + { static_cast(SerializationVersionDeprecated::V1_1_0), "V1_1_0" }, + { static_cast(SerializationVersionDeprecated::V1_1_1), "V1_1_1" }, + { static_cast(SerializationVersionDeprecated::V1_1_2), "V1_1_2" }, + { static_cast(SerializationVersionDeprecated::V1_1_3), "V1_1_3" }, + { static_cast(SerializationVersionDeprecated::V1_2_0), "V1_2_0" }, + { static_cast(SerializationVersionDeprecated::V1_2_1), "V1_2_1" }, + { static_cast(SerializationVersionDeprecated::V1_2_2), "V1_2_2" }, + { static_cast(SerializationVersionDeprecated::V1_3_0), "V1_3_0" }, + { static_cast(SerializationVersionDeprecated::V1_3_1), "V1_3_1" }, + { static_cast(SerializationVersionDeprecated::V1_3_2), "V1_3_2" }, + { static_cast(SerializationVersionDeprecated::V1_4_0), "V1_4_0" }, + { static_cast(SerializationVersionDeprecated::V1_4_1), "V1_4_1" }, + { static_cast(SerializationVersionDeprecated::V1_4_2), "V1_4_2" }, + { static_cast(SerializationVersionDeprecated::V1_4_3), "V1_4_3" }, + { static_cast(SerializationVersionDeprecated::V1_4_4), "V1_4_4" }, + { static_cast(SerializationVersionDeprecated::V1_5_0), "V1_5_0" }, + { static_cast(SerializationVersionDeprecated::V1_5_1), "V1_5_1" }, + { static_cast(SerializationVersionDeprecated::V1_5_2), "V1_5_2" }, + { static_cast(SerializationVersionDeprecated::V1_5_3), "V1_5_3" }, + { 
static_cast(SerializationVersionDeprecated::V2_0_0), "V2_0_0" }, + { static_cast(SerializationVersionDeprecated::LATEST), "LATEST" }, + { static_cast(SerializationVersionDeprecated::INVALID), "INVALID" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(SerializationVersionDeprecated value) { + return StringUtil::EnumToString(GetSerializationVersionDeprecatedValues(), 27, "SerializationVersionDeprecated", static_cast(value)); +} + +template<> +SerializationVersionDeprecated EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetSerializationVersionDeprecatedValues(), 27, "SerializationVersionDeprecated", value)); +} + const StringUtil::EnumStringLiteral *GetSetOperationTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(SetOperationType::NONE), "NONE" }, @@ -5028,19 +5164,21 @@ const StringUtil::EnumStringLiteral *GetStatementTypeValues() { { static_cast(StatementType::MULTI_STATEMENT), "MULTI_STATEMENT" }, { static_cast(StatementType::COPY_DATABASE_STATEMENT), "COPY_DATABASE_STATEMENT" }, { static_cast(StatementType::UPDATE_EXTENSIONS_STATEMENT), "UPDATE_EXTENSIONS_STATEMENT" }, - { static_cast(StatementType::MERGE_INTO_STATEMENT), "MERGE_INTO_STATEMENT" } + { static_cast(StatementType::MERGE_INTO_STATEMENT), "MERGE_INTO_STATEMENT" }, + { static_cast(StatementType::CONNECT_STATEMENT), "CONNECT_STATEMENT" }, + { static_cast(StatementType::DISCONNECT_STATEMENT), "DISCONNECT_STATEMENT" } }; return values; } template<> const char* EnumUtil::ToChars(StatementType value) { - return StringUtil::EnumToString(GetStatementTypeValues(), 31, "StatementType", static_cast(value)); + return StringUtil::EnumToString(GetStatementTypeValues(), 33, "StatementType", static_cast(value)); } template<> StatementType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatementTypeValues(), 31, "StatementType", value)); + return 
static_cast(StringUtil::StringToEnum(GetStatementTypeValues(), 33, "StatementType", value)); } const StringUtil::EnumStringLiteral *GetStatisticsTypeValues() { @@ -5126,6 +5264,89 @@ StorageIndexType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetStorageIndexTypeValues(), 2, "StorageIndexType", value)); } +const StringUtil::EnumStringLiteral *GetStorageVersionValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(StorageVersion::V0_0_4), "V0_0_4" }, + { static_cast(StorageVersion::V0_1_0), "V0_1_0" }, + { static_cast(StorageVersion::V0_1_1), "V0_1_1" }, + { static_cast(StorageVersion::V0_1_2), "V0_1_2" }, + { static_cast(StorageVersion::V0_1_3), "V0_1_3" }, + { static_cast(StorageVersion::V0_1_4), "V0_1_4" }, + { static_cast(StorageVersion::V0_1_5), "V0_1_5" }, + { static_cast(StorageVersion::V0_1_6), "V0_1_6" }, + { static_cast(StorageVersion::V0_1_7), "V0_1_7" }, + { static_cast(StorageVersion::V0_1_8), "V0_1_8" }, + { static_cast(StorageVersion::V0_1_9), "V0_1_9" }, + { static_cast(StorageVersion::V0_2_0), "V0_2_0" }, + { static_cast(StorageVersion::V0_2_1), "V0_2_1" }, + { static_cast(StorageVersion::V0_2_2), "V0_2_2" }, + { static_cast(StorageVersion::V0_2_3), "V0_2_3" }, + { static_cast(StorageVersion::V0_2_4), "V0_2_4" }, + { static_cast(StorageVersion::V0_2_5), "V0_2_5" }, + { static_cast(StorageVersion::V0_2_6), "V0_2_6" }, + { static_cast(StorageVersion::V0_2_7), "V0_2_7" }, + { static_cast(StorageVersion::V0_2_8), "V0_2_8" }, + { static_cast(StorageVersion::V0_2_9), "V0_2_9" }, + { static_cast(StorageVersion::V0_3_0), "V0_3_0" }, + { static_cast(StorageVersion::V0_3_1), "V0_3_1" }, + { static_cast(StorageVersion::V0_3_2), "V0_3_2" }, + { static_cast(StorageVersion::V0_3_3), "V0_3_3" }, + { static_cast(StorageVersion::V0_3_4), "V0_3_4" }, + { static_cast(StorageVersion::V0_3_5), "V0_3_5" }, + { static_cast(StorageVersion::V0_4_0), "V0_4_0" }, + { static_cast(StorageVersion::V0_5_0), 
"V0_5_0" }, + { static_cast(StorageVersion::V0_5_1), "V0_5_1" }, + { static_cast(StorageVersion::V0_6_0), "V0_6_0" }, + { static_cast(StorageVersion::V0_6_1), "V0_6_1" }, + { static_cast(StorageVersion::V0_7_0), "V0_7_0" }, + { static_cast(StorageVersion::V0_7_1), "V0_7_1" }, + { static_cast(StorageVersion::V0_8_0), "V0_8_0" }, + { static_cast(StorageVersion::V0_8_1), "V0_8_1" }, + { static_cast(StorageVersion::V0_9_0), "V0_9_0" }, + { static_cast(StorageVersion::V0_9_1), "V0_9_1" }, + { static_cast(StorageVersion::V0_9_2), "V0_9_2" }, + { static_cast(StorageVersion::V0_10_0), "V0_10_0" }, + { static_cast(StorageVersion::V0_10_1), "V0_10_1" }, + { static_cast(StorageVersion::V0_10_2), "V0_10_2" }, + { static_cast(StorageVersion::V0_10_3), "V0_10_3" }, + { static_cast(StorageVersion::V1_0_0), "V1_0_0" }, + { static_cast(StorageVersion::V1_1_0), "V1_1_0" }, + { static_cast(StorageVersion::V1_1_1), "V1_1_1" }, + { static_cast(StorageVersion::V1_1_2), "V1_1_2" }, + { static_cast(StorageVersion::V1_1_3), "V1_1_3" }, + { static_cast(StorageVersion::V1_2_0), "V1_2_0" }, + { static_cast(StorageVersion::V1_2_1), "V1_2_1" }, + { static_cast(StorageVersion::V1_2_2), "V1_2_2" }, + { static_cast(StorageVersion::V1_3_0), "V1_3_0" }, + { static_cast(StorageVersion::V1_3_1), "V1_3_1" }, + { static_cast(StorageVersion::V1_3_2), "V1_3_2" }, + { static_cast(StorageVersion::V1_4_0), "V1_4_0" }, + { static_cast(StorageVersion::V1_4_1), "V1_4_1" }, + { static_cast(StorageVersion::V1_4_2), "V1_4_2" }, + { static_cast(StorageVersion::V1_4_3), "V1_4_3" }, + { static_cast(StorageVersion::V1_4_4), "V1_4_4" }, + { static_cast(StorageVersion::V1_5_0), "V1_5_0" }, + { static_cast(StorageVersion::V1_5_1), "V1_5_1" }, + { static_cast(StorageVersion::V1_5_2), "V1_5_2" }, + { static_cast(StorageVersion::V1_5_3), "V1_5_3" }, + { static_cast(StorageVersion::V2_0_0), "V2_0_0" }, + { static_cast(StorageVersion::LATEST), "LATEST" }, + { static_cast(StorageVersion::DEPRECATED), "DEPRECATED" }, + { 
static_cast(StorageVersion::INVALID), "INVALID" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(StorageVersion value) { + return StringUtil::EnumToString(GetStorageVersionValues(), 67, "StorageVersion", static_cast(value)); +} + +template<> +StorageVersion EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetStorageVersionValues(), 67, "StorageVersion", value)); +} + const StringUtil::EnumStringLiteral *GetStrTimeSpecifierValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME), "ABBREVIATED_WEEKDAY_NAME" }, @@ -5272,19 +5493,19 @@ TableColumnType EnumUtil::FromString(const char *value) { const StringUtil::EnumStringLiteral *GetTableFilterTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TableFilterType::CONSTANT_COMPARISON), "CONSTANT_COMPARISON" }, - { static_cast(TableFilterType::IS_NULL), "IS_NULL" }, - { static_cast(TableFilterType::IS_NOT_NULL), "IS_NOT_NULL" }, - { static_cast(TableFilterType::CONJUNCTION_OR), "CONJUNCTION_OR" }, - { static_cast(TableFilterType::CONJUNCTION_AND), "CONJUNCTION_AND" }, - { static_cast(TableFilterType::STRUCT_EXTRACT), "STRUCT_EXTRACT" }, - { static_cast(TableFilterType::OPTIONAL_FILTER), "OPTIONAL_FILTER" }, - { static_cast(TableFilterType::IN_FILTER), "IN_FILTER" }, - { static_cast(TableFilterType::DYNAMIC_FILTER), "DYNAMIC_FILTER" }, + { static_cast(TableFilterType::LEGACY_CONSTANT_COMPARISON), "LEGACY_CONSTANT_COMPARISON" }, + { static_cast(TableFilterType::LEGACY_IS_NULL), "LEGACY_IS_NULL" }, + { static_cast(TableFilterType::LEGACY_IS_NOT_NULL), "LEGACY_IS_NOT_NULL" }, + { static_cast(TableFilterType::LEGACY_CONJUNCTION_OR), "LEGACY_CONJUNCTION_OR" }, + { static_cast(TableFilterType::LEGACY_CONJUNCTION_AND), "LEGACY_CONJUNCTION_AND" }, + { static_cast(TableFilterType::LEGACY_STRUCT_EXTRACT), "LEGACY_STRUCT_EXTRACT" }, + { 
static_cast(TableFilterType::LEGACY_OPTIONAL_FILTER), "LEGACY_OPTIONAL_FILTER" }, + { static_cast(TableFilterType::LEGACY_IN_FILTER), "LEGACY_IN_FILTER" }, + { static_cast(TableFilterType::LEGACY_DYNAMIC_FILTER), "LEGACY_DYNAMIC_FILTER" }, { static_cast(TableFilterType::EXPRESSION_FILTER), "EXPRESSION_FILTER" }, - { static_cast(TableFilterType::BLOOM_FILTER), "BLOOM_FILTER" }, - { static_cast(TableFilterType::PERFECT_HASH_JOIN_FILTER), "PERFECT_HASH_JOIN_FILTER" }, - { static_cast(TableFilterType::PREFIX_RANGE_FILTER), "PREFIX_RANGE_FILTER" } + { static_cast(TableFilterType::LEGACY_BLOOM_FILTER), "LEGACY_BLOOM_FILTER" }, + { static_cast(TableFilterType::LEGACY_PERFECT_HASH_JOIN_FILTER), "LEGACY_PERFECT_HASH_JOIN_FILTER" }, + { static_cast(TableFilterType::LEGACY_PREFIX_RANGE_FILTER), "LEGACY_PREFIX_RANGE_FILTER" } }; return values; } @@ -5405,6 +5626,25 @@ TaskExecutionResult EnumUtil::FromString(const char *value) return static_cast(StringUtil::StringToEnum(GetTaskExecutionResultValues(), 4, "TaskExecutionResult", value)); } +const StringUtil::EnumStringLiteral *GetTaskSchedulerTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(TaskSchedulerType::REGULAR), "REGULAR" }, + { static_cast(TaskSchedulerType::ASYNC), "ASYNC" }, + { static_cast(TaskSchedulerType::ENUM_SIZE), "ENUM_SIZE" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(TaskSchedulerType value) { + return StringUtil::EnumToString(GetTaskSchedulerTypeValues(), 3, "TaskSchedulerType", static_cast(value)); +} + +template<> +TaskSchedulerType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetTaskSchedulerTypeValues(), 3, "TaskSchedulerType", value)); +} + const StringUtil::EnumStringLiteral *GetTemporaryBufferSizeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(TemporaryBufferSize::INVALID), "INVALID" }, @@ -5496,6 +5736,7 @@ TimestampCastResult EnumUtil::FromString(const 
char *value) const StringUtil::EnumStringLiteral *GetTransactionInvalidationPolicyValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(TransactionInvalidationPolicy::STANDARD_POLICY), "STANDARD_POLICY" }, + { static_cast(TransactionInvalidationPolicy::SYNTACTIC_ERRORS_DO_NOT_INVALIDATE), "SYNTACTIC_ERRORS_DO_NOT_INVALIDATE" }, { static_cast(TransactionInvalidationPolicy::ALL_ERRORS_INVALIDATE_TRANSACTION), "ALL_ERRORS_INVALIDATE_TRANSACTION" } }; return values; @@ -5503,12 +5744,12 @@ const StringUtil::EnumStringLiteral *GetTransactionInvalidationPolicyValues() { template<> const char* EnumUtil::ToChars(TransactionInvalidationPolicy value) { - return StringUtil::EnumToString(GetTransactionInvalidationPolicyValues(), 2, "TransactionInvalidationPolicy", static_cast(value)); + return StringUtil::EnumToString(GetTransactionInvalidationPolicyValues(), 3, "TransactionInvalidationPolicy", static_cast(value)); } template<> TransactionInvalidationPolicy EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTransactionInvalidationPolicyValues(), 2, "TransactionInvalidationPolicy", value)); + return static_cast(StringUtil::StringToEnum(GetTransactionInvalidationPolicyValues(), 3, "TransactionInvalidationPolicy", value)); } const StringUtil::EnumStringLiteral *GetTransactionModifierTypeValues() { diff --git a/src/duckdb/src/common/enums/column_segment_info_scan_type.cpp b/src/duckdb/src/common/enums/column_segment_info_scan_type.cpp new file mode 100644 index 000000000..8a4aa6406 --- /dev/null +++ b/src/duckdb/src/common/enums/column_segment_info_scan_type.cpp @@ -0,0 +1,7 @@ +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" + +namespace duckdb { + +// ColumnSegmentInfoScanOptions is a plain struct; nothing to define here yet. 
+ +} // namespace duckdb diff --git a/src/duckdb/src/common/enums/compression_type.cpp b/src/duckdb/src/common/enums/compression_type.cpp index 15cfdac70..3d9dd02a4 100644 --- a/src/duckdb/src/common/enums/compression_type.cpp +++ b/src/duckdb/src/common/enums/compression_type.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/enums/compression_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/storage/storage_info.hpp" #include "duckdb/storage/storage_manager.hpp" namespace duckdb { @@ -20,21 +21,29 @@ vector ListCompressionTypes(void) { namespace { struct CompressionMethodRequirements { CompressionType type; - optional_idx minimum_storage_version; - optional_idx maximum_storage_version; + StorageVersion minimum_storage_version; + StorageVersion maximum_storage_version; }; } // namespace CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compression_type, optional_ptr storage_manager) { //! Max storage compatibility - vector candidates({{CompressionType::COMPRESSION_PATAS, optional_idx(), 0}, - {CompressionType::COMPRESSION_CHIMP, optional_idx(), 0}, - {CompressionType::COMPRESSION_DICTIONARY, 0, 4}, - {CompressionType::COMPRESSION_FSST, 0, 4}, - {CompressionType::COMPRESSION_DICT_FSST, 5, optional_idx()}}); + // if minimum version is StorageVersion::INVALID, this is an indication that the + // compression method is deprecated (e.g. 
chimp and patas) + vector candidates( + {{CompressionType::COMPRESSION_PATAS, StorageVersion::INVALID, StorageVersion::V0_10_2}, // phased out + {CompressionType::COMPRESSION_CHIMP, StorageVersion::INVALID, StorageVersion::V0_10_2}, // phased out + {CompressionType::COMPRESSION_DICTIONARY, StorageVersion::V0_10_2, StorageVersion::V1_2_0}, + {CompressionType::COMPRESSION_FSST, StorageVersion::V0_10_2, StorageVersion::V1_2_0}, + {CompressionType::COMPRESSION_ROARING, StorageVersion::V1_2_0, StorageVersion::LATEST}, + {CompressionType::COMPRESSION_ZSTD, StorageVersion::V1_2_0, StorageVersion::LATEST}, + {CompressionType::COMPRESSION_DICT_FSST, StorageVersion::V1_3_0, StorageVersion::LATEST}, + // Not implemented yet + {CompressionType::COMPRESSION_PFOR_DELTA, (StorageVersion)((int)StorageVersion::LATEST + 1), + StorageVersion::INVALID}}); - optional_idx current_storage_version; + StorageVersion current_storage_version = StorageVersion::INVALID; if (storage_manager && storage_manager->HasStorageVersion()) { current_storage_version = storage_manager->GetStorageVersion(); } @@ -46,24 +55,24 @@ CompressionAvailabilityResult CompressionTypeIsAvailable(CompressionType compres auto &min = candidate.minimum_storage_version; auto &max = candidate.maximum_storage_version; - if (!min.IsValid()) { + if (min == StorageVersion::INVALID) { //! Used to signal: always deprecated return CompressionAvailabilityResult::Deprecated(); } - if (!current_storage_version.IsValid()) { + if (current_storage_version == StorageVersion::INVALID) { //! Can't determine in this call whether it's available or not, default to available return CompressionAvailabilityResult(); } - auto current_version = current_storage_version.GetIndex(); - D_ASSERT(min.IsValid()); - if (min.GetIndex() > current_version) { + auto current_version = current_storage_version; + D_ASSERT(min != StorageVersion::INVALID); + if (min > current_version) { //! 
Minimum required storage version is higher than the current storage version, this method isn't available //! yet return CompressionAvailabilityResult::NotAvailableYet(); } - if (max.IsValid() && max.GetIndex() < current_version) { + if (max != StorageVersion::INVALID && max < current_version) { //! Maximum supported storage version is lower than the current storage version, this method is no longer //! available return CompressionAvailabilityResult::Deprecated(); diff --git a/src/duckdb/src/common/enums/logical_operator_type.cpp b/src/duckdb/src/common/enums/logical_operator_type.cpp index f3c5d3784..772e5f502 100644 --- a/src/duckdb/src/common/enums/logical_operator_type.cpp +++ b/src/duckdb/src/common/enums/logical_operator_type.cpp @@ -134,6 +134,10 @@ string LogicalOperatorToString(LogicalOperatorType type) { return "PIVOT"; case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS: return "UPDATE_EXTENSIONS"; + case LogicalOperatorType::LOGICAL_CONNECT: + return "CONNECT"; + case LogicalOperatorType::LOGICAL_DISCONNECT: + return "DISCONNECT"; } return "INVALID"; } diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp deleted file mode 100644 index 493ecf3fd..000000000 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ /dev/null @@ -1,479 +0,0 @@ -// This file is automatically generated by scripts/generate_metric_enums.py -// Do not edit this file manually, your changes will be overwritten - -#include "duckdb/common/enums/metric_type.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -profiler_settings_t MetricsUtils::GetAllMetrics() { - return { - MetricType::ALL_OPTIMIZERS, - MetricType::ATTACH_LOAD_STORAGE_LATENCY, - MetricType::ATTACH_REPLAY_WAL_LATENCY, - MetricType::BLOCKED_THREAD_TIME, - MetricType::CHECKPOINT_LATENCY, - MetricType::COMMIT_LOCAL_STORAGE_LATENCY, - MetricType::CPU_TIME, - MetricType::CUMULATIVE_CARDINALITY, - MetricType::CUMULATIVE_OPTIMIZER_TIMING, - 
MetricType::CUMULATIVE_ROWS_SCANNED, - MetricType::EXTRA_INFO, - MetricType::LATENCY, - MetricType::OPERATOR_CARDINALITY, - MetricType::OPERATOR_NAME, - MetricType::OPERATOR_ROWS_SCANNED, - MetricType::OPERATOR_TIMING, - MetricType::OPERATOR_TYPE, - MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER, - MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, - MetricType::OPTIMIZER_COLUMN_LIFETIME, - MetricType::OPTIMIZER_COMMON_AGGREGATE, - MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS, - MetricType::OPTIMIZER_COMMON_SUBPLAN, - MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION, - MetricType::OPTIMIZER_CTE_FILTER_PUSHER, - MetricType::OPTIMIZER_CTE_INLINING, - MetricType::OPTIMIZER_DELIMINATOR, - MetricType::OPTIMIZER_DUPLICATE_GROUPS, - MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP, - MetricType::OPTIMIZER_EXPRESSION_REWRITER, - MetricType::OPTIMIZER_EXTENSION, - MetricType::OPTIMIZER_FILTER_PULLUP, - MetricType::OPTIMIZER_FILTER_PUSHDOWN, - MetricType::OPTIMIZER_IN_CLAUSE, - MetricType::OPTIMIZER_JOIN_ELIMINATION, - MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN, - MetricType::OPTIMIZER_JOIN_ORDER, - MetricType::OPTIMIZER_LATE_MATERIALIZATION, - MetricType::OPTIMIZER_LIMIT_PUSHDOWN, - MetricType::OPTIMIZER_MATERIALIZED_CTE, - MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION, - MetricType::OPTIMIZER_PARTITIONED_EXECUTION, - MetricType::OPTIMIZER_PROJECTION_PULLUP, - MetricType::OPTIMIZER_REGEX_RANGE, - MetricType::OPTIMIZER_REORDER_FILTER, - MetricType::OPTIMIZER_ROW_GROUP_PRUNER, - MetricType::OPTIMIZER_ROW_NUMBER_REWRITER, - MetricType::OPTIMIZER_SAMPLING_PUSHDOWN, - MetricType::OPTIMIZER_STATISTICS_PROPAGATION, - MetricType::OPTIMIZER_TOP_N, - MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION, - MetricType::OPTIMIZER_UNNEST_REWRITER, - MetricType::OPTIMIZER_UNUSED_COLUMNS, - MetricType::OPTIMIZER_WINDOW_SELF_JOIN, - MetricType::PARSER, - MetricType::PHYSICAL_PLANNER, - MetricType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricType::PHYSICAL_PLANNER_CREATE_PLAN, - 
MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES, - MetricType::PLANNER, - MetricType::PLANNER_BINDING, - MetricType::QUERY_NAME, - MetricType::RESULT_SET_SIZE, - MetricType::ROWS_RETURNED, - MetricType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, - MetricType::TOTAL_BYTES_READ, - MetricType::TOTAL_BYTES_WRITTEN, - MetricType::TOTAL_MEMORY_ALLOCATED, - MetricType::WAITING_TO_ATTACH_LATENCY, - MetricType::WAL_REPLAY_ENTRY_COUNT, - MetricType::WRITE_TO_WAL_LATENCY, - }; -} - -profiler_settings_t MetricsUtils::GetMetricsByGroupType(MetricGroup type) { - switch(type) { - case MetricGroup::ALL: - return GetAllMetrics(); - case MetricGroup::CORE: - return GetCoreMetrics(); - case MetricGroup::DEFAULT: - return GetDefaultMetrics(); - case MetricGroup::EXECUTION: - return GetExecutionMetrics(); - case MetricGroup::FILE: - return GetFileMetrics(); - case MetricGroup::OPERATOR: - return GetOperatorMetrics(); - case MetricGroup::OPTIMIZER: - return GetOptimizerMetrics(); - case MetricGroup::PHASE_TIMING: - return GetPhaseTimingMetrics(); - default: - throw InternalException("The MetricGroup passed is invalid"); - } -} - -profiler_settings_t MetricsUtils::GetCoreMetrics() { - return { - MetricType::CPU_TIME, - MetricType::CUMULATIVE_CARDINALITY, - MetricType::CUMULATIVE_ROWS_SCANNED, - MetricType::EXTRA_INFO, - MetricType::LATENCY, - MetricType::QUERY_NAME, - MetricType::RESULT_SET_SIZE, - MetricType::ROWS_RETURNED, - }; -} - -bool MetricsUtils::IsCoreMetric(MetricType type) { - switch(type) { - case MetricType::CPU_TIME: - case MetricType::CUMULATIVE_CARDINALITY: - case MetricType::CUMULATIVE_ROWS_SCANNED: - case MetricType::EXTRA_INFO: - case MetricType::LATENCY: - case MetricType::QUERY_NAME: - case MetricType::RESULT_SET_SIZE: - case MetricType::ROWS_RETURNED: - return true; - default: - return false; - } -} - -profiler_settings_t MetricsUtils::GetDefaultMetrics() { - return { - MetricType::ATTACH_LOAD_STORAGE_LATENCY, - 
MetricType::ATTACH_REPLAY_WAL_LATENCY, - MetricType::BLOCKED_THREAD_TIME, - MetricType::CHECKPOINT_LATENCY, - MetricType::COMMIT_LOCAL_STORAGE_LATENCY, - MetricType::CPU_TIME, - MetricType::CUMULATIVE_CARDINALITY, - MetricType::CUMULATIVE_ROWS_SCANNED, - MetricType::EXTRA_INFO, - MetricType::LATENCY, - MetricType::OPERATOR_CARDINALITY, - MetricType::OPERATOR_NAME, - MetricType::OPERATOR_ROWS_SCANNED, - MetricType::OPERATOR_TIMING, - MetricType::OPERATOR_TYPE, - MetricType::QUERY_NAME, - MetricType::RESULT_SET_SIZE, - MetricType::ROWS_RETURNED, - MetricType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, - MetricType::TOTAL_BYTES_READ, - MetricType::TOTAL_BYTES_WRITTEN, - MetricType::TOTAL_MEMORY_ALLOCATED, - MetricType::WAITING_TO_ATTACH_LATENCY, - MetricType::WAL_REPLAY_ENTRY_COUNT, - MetricType::WRITE_TO_WAL_LATENCY, - }; -} - -bool MetricsUtils::IsDefaultMetric(MetricType type) { - switch(type) { - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - case MetricType::BLOCKED_THREAD_TIME: - case MetricType::CHECKPOINT_LATENCY: - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - case MetricType::CPU_TIME: - case MetricType::CUMULATIVE_CARDINALITY: - case MetricType::CUMULATIVE_ROWS_SCANNED: - case MetricType::EXTRA_INFO: - case MetricType::LATENCY: - case MetricType::OPERATOR_CARDINALITY: - case MetricType::OPERATOR_NAME: - case MetricType::OPERATOR_ROWS_SCANNED: - case MetricType::OPERATOR_TIMING: - case MetricType::OPERATOR_TYPE: - case MetricType::QUERY_NAME: - case MetricType::RESULT_SET_SIZE: - case MetricType::ROWS_RETURNED: - case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricType::TOTAL_BYTES_READ: - case MetricType::TOTAL_BYTES_WRITTEN: - case MetricType::TOTAL_MEMORY_ALLOCATED: - case MetricType::WAITING_TO_ATTACH_LATENCY: - case MetricType::WAL_REPLAY_ENTRY_COUNT: - case MetricType::WRITE_TO_WAL_LATENCY: - return true; - default: - return 
false; - } -} - -profiler_settings_t MetricsUtils::GetExecutionMetrics() { - return { - MetricType::BLOCKED_THREAD_TIME, - MetricType::SYSTEM_PEAK_BUFFER_MEMORY, - MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, - MetricType::TOTAL_MEMORY_ALLOCATED, - }; -} - -bool MetricsUtils::IsExecutionMetric(MetricType type) { - switch(type) { - case MetricType::BLOCKED_THREAD_TIME: - case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricType::TOTAL_MEMORY_ALLOCATED: - return true; - default: - return false; - } -} - -profiler_settings_t MetricsUtils::GetFileMetrics() { - return { - MetricType::ATTACH_LOAD_STORAGE_LATENCY, - MetricType::ATTACH_REPLAY_WAL_LATENCY, - MetricType::CHECKPOINT_LATENCY, - MetricType::COMMIT_LOCAL_STORAGE_LATENCY, - MetricType::TOTAL_BYTES_READ, - MetricType::TOTAL_BYTES_WRITTEN, - MetricType::WAITING_TO_ATTACH_LATENCY, - MetricType::WAL_REPLAY_ENTRY_COUNT, - MetricType::WRITE_TO_WAL_LATENCY, - }; -} - -bool MetricsUtils::IsFileMetric(MetricType type) { - switch(type) { - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - case MetricType::CHECKPOINT_LATENCY: - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - case MetricType::TOTAL_BYTES_READ: - case MetricType::TOTAL_BYTES_WRITTEN: - case MetricType::WAITING_TO_ATTACH_LATENCY: - case MetricType::WAL_REPLAY_ENTRY_COUNT: - case MetricType::WRITE_TO_WAL_LATENCY: - return true; - default: - return false; - } -} - -profiler_settings_t MetricsUtils::GetOperatorMetrics() { - return { - MetricType::OPERATOR_CARDINALITY, - MetricType::OPERATOR_NAME, - MetricType::OPERATOR_ROWS_SCANNED, - MetricType::OPERATOR_TIMING, - MetricType::OPERATOR_TYPE, - }; -} - -bool MetricsUtils::IsOperatorMetric(MetricType type) { - switch(type) { - case MetricType::OPERATOR_CARDINALITY: - case MetricType::OPERATOR_NAME: - case MetricType::OPERATOR_ROWS_SCANNED: - case MetricType::OPERATOR_TIMING: - case MetricType::OPERATOR_TYPE: - return true; - 
default: - return false; - } -} - -profiler_settings_t MetricsUtils::GetOptimizerMetrics() { - profiler_settings_t result; - for (auto metric = START_OPTIMIZER; metric <= END_OPTIMIZER; metric++) { - result.insert(static_cast(metric)); - } - return result; -} - -bool MetricsUtils::IsOptimizerMetric(MetricType type) { - return static_cast(type) >= START_OPTIMIZER && static_cast(type) <= END_OPTIMIZER; -} - - -MetricType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { - if (type == OptimizerType::INVALID) { - throw InternalException("Invalid OptimizerType: INVALID"); - } - - const auto base_opt = static_cast(OptimizerType::EXPRESSION_REWRITER); - const auto idx = static_cast(type) - base_opt; - - const auto metric_u8 = static_cast(START_OPTIMIZER + idx); - if (metric_u8 < START_OPTIMIZER || metric_u8 > END_OPTIMIZER) { - throw InternalException("OptimizerType out of MetricType optimizer range"); - } - return static_cast(metric_u8); -} - -OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricType type) { - const auto metric_u8 = static_cast(type); - if (!IsOptimizerMetric(type)) { - throw InternalException("MetricType is not an optimizer metric"); - } - - const auto idx = static_cast(metric_u8 - START_OPTIMIZER); - const auto result = static_cast(OptimizerType::EXPRESSION_REWRITER) + idx; - return static_cast(result); -} - -profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { - return { - MetricType::ALL_OPTIMIZERS, - MetricType::CUMULATIVE_OPTIMIZER_TIMING, - MetricType::PARSER, - MetricType::PHYSICAL_PLANNER, - MetricType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricType::PHYSICAL_PLANNER_CREATE_PLAN, - MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES, - MetricType::PLANNER, - MetricType::PLANNER_BINDING, - }; -} - -bool MetricsUtils::IsPhaseTimingMetric(MetricType type) { - switch(type) { - case MetricType::ALL_OPTIMIZERS: - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricType::PARSER: - case MetricType::PHYSICAL_PLANNER: - case 
MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: - case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricType::PLANNER: - case MetricType::PLANNER_BINDING: - return true; - default: - return false; - } -} - -profiler_settings_t MetricsUtils::GetRootScopeMetrics() { - return { - MetricType::ALL_OPTIMIZERS, - MetricType::ATTACH_LOAD_STORAGE_LATENCY, - MetricType::ATTACH_REPLAY_WAL_LATENCY, - MetricType::BLOCKED_THREAD_TIME, - MetricType::CHECKPOINT_LATENCY, - MetricType::COMMIT_LOCAL_STORAGE_LATENCY, - MetricType::CUMULATIVE_OPTIMIZER_TIMING, - MetricType::LATENCY, - MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER, - MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, - MetricType::OPTIMIZER_COLUMN_LIFETIME, - MetricType::OPTIMIZER_COMMON_AGGREGATE, - MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS, - MetricType::OPTIMIZER_COMMON_SUBPLAN, - MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION, - MetricType::OPTIMIZER_CTE_FILTER_PUSHER, - MetricType::OPTIMIZER_CTE_INLINING, - MetricType::OPTIMIZER_DELIMINATOR, - MetricType::OPTIMIZER_DUPLICATE_GROUPS, - MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP, - MetricType::OPTIMIZER_EXPRESSION_REWRITER, - MetricType::OPTIMIZER_EXTENSION, - MetricType::OPTIMIZER_FILTER_PULLUP, - MetricType::OPTIMIZER_FILTER_PUSHDOWN, - MetricType::OPTIMIZER_IN_CLAUSE, - MetricType::OPTIMIZER_JOIN_ELIMINATION, - MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN, - MetricType::OPTIMIZER_JOIN_ORDER, - MetricType::OPTIMIZER_LATE_MATERIALIZATION, - MetricType::OPTIMIZER_LIMIT_PUSHDOWN, - MetricType::OPTIMIZER_MATERIALIZED_CTE, - MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION, - MetricType::OPTIMIZER_PARTITIONED_EXECUTION, - MetricType::OPTIMIZER_PROJECTION_PULLUP, - MetricType::OPTIMIZER_REGEX_RANGE, - MetricType::OPTIMIZER_REORDER_FILTER, - MetricType::OPTIMIZER_ROW_GROUP_PRUNER, - MetricType::OPTIMIZER_ROW_NUMBER_REWRITER, - MetricType::OPTIMIZER_SAMPLING_PUSHDOWN, - MetricType::OPTIMIZER_STATISTICS_PROPAGATION, - 
MetricType::OPTIMIZER_TOP_N, - MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION, - MetricType::OPTIMIZER_UNNEST_REWRITER, - MetricType::OPTIMIZER_UNUSED_COLUMNS, - MetricType::OPTIMIZER_WINDOW_SELF_JOIN, - MetricType::PHYSICAL_PLANNER, - MetricType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricType::PHYSICAL_PLANNER_CREATE_PLAN, - MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES, - MetricType::PLANNER, - MetricType::PLANNER_BINDING, - MetricType::QUERY_NAME, - MetricType::ROWS_RETURNED, - MetricType::TOTAL_BYTES_READ, - MetricType::TOTAL_BYTES_WRITTEN, - MetricType::TOTAL_MEMORY_ALLOCATED, - MetricType::WAITING_TO_ATTACH_LATENCY, - MetricType::WAL_REPLAY_ENTRY_COUNT, - MetricType::WRITE_TO_WAL_LATENCY, - }; -} - -bool MetricsUtils::IsRootScopeMetric(MetricType type) { - switch(type) { - case MetricType::ALL_OPTIMIZERS: - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - case MetricType::BLOCKED_THREAD_TIME: - case MetricType::CHECKPOINT_LATENCY: - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricType::LATENCY: - case MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER: - case MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - case MetricType::OPTIMIZER_COLUMN_LIFETIME: - case MetricType::OPTIMIZER_COMMON_AGGREGATE: - case MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - case MetricType::OPTIMIZER_COMMON_SUBPLAN: - case MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - case MetricType::OPTIMIZER_CTE_FILTER_PUSHER: - case MetricType::OPTIMIZER_CTE_INLINING: - case MetricType::OPTIMIZER_DELIMINATOR: - case MetricType::OPTIMIZER_DUPLICATE_GROUPS: - case MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP: - case MetricType::OPTIMIZER_EXPRESSION_REWRITER: - case MetricType::OPTIMIZER_EXTENSION: - case MetricType::OPTIMIZER_FILTER_PULLUP: - case MetricType::OPTIMIZER_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_IN_CLAUSE: - case MetricType::OPTIMIZER_JOIN_ELIMINATION: - case 
MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_JOIN_ORDER: - case MetricType::OPTIMIZER_LATE_MATERIALIZATION: - case MetricType::OPTIMIZER_LIMIT_PUSHDOWN: - case MetricType::OPTIMIZER_MATERIALIZED_CTE: - case MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION: - case MetricType::OPTIMIZER_PARTITIONED_EXECUTION: - case MetricType::OPTIMIZER_PROJECTION_PULLUP: - case MetricType::OPTIMIZER_REGEX_RANGE: - case MetricType::OPTIMIZER_REORDER_FILTER: - case MetricType::OPTIMIZER_ROW_GROUP_PRUNER: - case MetricType::OPTIMIZER_ROW_NUMBER_REWRITER: - case MetricType::OPTIMIZER_SAMPLING_PUSHDOWN: - case MetricType::OPTIMIZER_STATISTICS_PROPAGATION: - case MetricType::OPTIMIZER_TOP_N: - case MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: - case MetricType::OPTIMIZER_UNNEST_REWRITER: - case MetricType::OPTIMIZER_UNUSED_COLUMNS: - case MetricType::OPTIMIZER_WINDOW_SELF_JOIN: - case MetricType::PHYSICAL_PLANNER: - case MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: - case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricType::PLANNER: - case MetricType::PLANNER_BINDING: - case MetricType::QUERY_NAME: - case MetricType::ROWS_RETURNED: - case MetricType::TOTAL_BYTES_READ: - case MetricType::TOTAL_BYTES_WRITTEN: - case MetricType::TOTAL_MEMORY_ALLOCATED: - case MetricType::WAITING_TO_ATTACH_LATENCY: - case MetricType::WAL_REPLAY_ENTRY_COUNT: - case MetricType::WRITE_TO_WAL_LATENCY: - return true; - default: - return false; - } -} - -} diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index 1493bd80b..adbee2538 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -50,6 +50,8 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"outer_join_simplification", OptimizerType::OUTER_JOIN_SIMPLIFICATION}, {"window_rewriter", OptimizerType::ROW_NUMBER_REWRITER}, 
{"partitioned_execution", OptimizerType::PARTITIONED_EXECUTION}, + {"partial_aggregate_pushdown", OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN}, + {"remote_pushdown", OptimizerType::REMOTE_PUSHDOWN}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/enums/physical_operator_type.cpp b/src/duckdb/src/common/enums/physical_operator_type.cpp index a4bd016c5..12f488a40 100644 --- a/src/duckdb/src/common/enums/physical_operator_type.cpp +++ b/src/duckdb/src/common/enums/physical_operator_type.cpp @@ -171,6 +171,10 @@ string PhysicalOperatorToString(PhysicalOperatorType type) { return "VERIFY_VECTOR"; case PhysicalOperatorType::UPDATE_EXTENSIONS: return "UPDATE_EXTENSIONS"; + case PhysicalOperatorType::CONNECT: + return "CONNECT"; + case PhysicalOperatorType::DISCONNECT: + return "DISCONNECT"; case PhysicalOperatorType::INVALID: break; } diff --git a/src/duckdb/src/common/enums/statement_type.cpp b/src/duckdb/src/common/enums/statement_type.cpp index 643df0076..5ceefeacb 100644 --- a/src/duckdb/src/common/enums/statement_type.cpp +++ b/src/duckdb/src/common/enums/statement_type.cpp @@ -67,6 +67,10 @@ string StatementTypeToString(StatementType type) { return "UPDATE_EXTENSIONS"; case StatementType::MERGE_INTO_STATEMENT: return "MERGE_INTO"; + case StatementType::CONNECT_STATEMENT: + return "CONNECT"; + case StatementType::DISCONNECT_STATEMENT: + return "DISCONNECT"; case StatementType::INVALID_STATEMENT: break; } diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp index 82dfd2195..34f179df8 100644 --- a/src/duckdb/src/common/exception.cpp +++ b/src/duckdb/src/common/exception.cpp @@ -334,6 +334,9 @@ SequenceException::SequenceException(const string &msg) : Exception(ExceptionTyp InterruptException::InterruptException() : Exception(ExceptionType::INTERRUPT, INTERRUPT_MESSAGE) { } +InterruptException::InterruptException(const string &message) : 
Exception(ExceptionType::INTERRUPT, message) { +} + FatalException::FatalException(ExceptionType type, const string &msg) : Exception(type, msg) { // FIXME: Make any log context available to add error logging. } diff --git a/src/duckdb/src/common/exception/binder_exception.cpp b/src/duckdb/src/common/exception/binder_exception.cpp index aa9a9459e..bbf88e475 100644 --- a/src/duckdb/src/common/exception/binder_exception.cpp +++ b/src/duckdb/src/common/exception/binder_exception.cpp @@ -11,38 +11,39 @@ BinderException::BinderException(const unordered_map &extra_info : Exception(extra_info, ExceptionType::BINDER, msg) { } -BinderException BinderException::ColumnNotFound(const string &name, const vector &similar_bindings, +BinderException BinderException::ColumnNotFound(const Identifier &name, const vector &similar_bindings, QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("COLUMN_NOT_FOUND", context.query_location); - string candidate_str = StringUtil::CandidatesMessage(similar_bindings, "Candidate bindings"); - extra_info["name"] = name; + string candidate_str = StringUtil::CandidatesMessage(IdentifiersToStrings(similar_bindings), "Candidate bindings"); + extra_info["name"] = name.GetIdentifierName(); if (!similar_bindings.empty()) { extra_info["candidates"] = StringUtil::Join(similar_bindings, ","); return BinderException(extra_info, StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", - name, candidate_str)); + name.GetIdentifierName(), candidate_str)); } else { return BinderException( - extra_info, - StringUtil::Format("Referenced column \"%s\" was not found because the FROM clause is missing", name)); + extra_info, StringUtil::Format("Referenced column \"%s\" was not found because the FROM clause is missing", + name.GetIdentifierName())); } } -BinderException BinderException::NoMatchingFunction(const string &catalog_name, const string &schema_name, - const string &name, const vector &arguments, +BinderException 
BinderException::NoMatchingFunction(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, const vector &arguments, + const vector> &named_arguments, const vector &candidates) { auto extra_info = Exception::InitializeExtraInfo("NO_MATCHING_FUNCTION", optional_idx()); // no matching function was found, throw an error - string call_str = Function::CallToString(catalog_name, schema_name, name, arguments); + string call_str = Function::CallToString(catalog_name, schema_name, name, arguments, named_arguments); string candidate_str; for (auto &candidate : candidates) { candidate_str += "\t" + candidate + "\n"; } - extra_info["name"] = name; + extra_info["name"] = name.GetIdentifierName(); if (!catalog_name.empty()) { - extra_info["catalog"] = catalog_name; + extra_info["catalog"] = catalog_name.GetIdentifierName(); } if (!schema_name.empty()) { - extra_info["schema"] = schema_name; + extra_info["schema"] = schema_name.GetIdentifierName(); } extra_info["call"] = call_str; if (!candidates.empty()) { diff --git a/src/duckdb/src/common/exception/catalog_exception.cpp b/src/duckdb/src/common/exception/catalog_exception.cpp index b1cd4caf7..f073239bc 100644 --- a/src/duckdb/src/common/exception/catalog_exception.cpp +++ b/src/duckdb/src/common/exception/catalog_exception.cpp @@ -40,32 +40,33 @@ CatalogException CatalogException::MissingEntry(const EntryLookupInfo &lookup_in version_info, did_you_mean)); } -CatalogException CatalogException::MissingEntry(CatalogType type, const string &name, const string &suggestion, +CatalogException CatalogException::MissingEntry(CatalogType type, const Identifier &name, const string &suggestion, QueryErrorContext context) { EntryLookupInfo lookup_info(type, name, context); return MissingEntry(lookup_info, suggestion); } -CatalogException CatalogException::MissingEntry(const string &type, const string &name, +CatalogException CatalogException::MissingEntry(const string &type, const Identifier &name, const 
vector &suggestions, QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("MISSING_ENTRY", context.query_location); extra_info["error_subtype"] = "MISSING_ENTRY"; - extra_info["name"] = name; + extra_info["name"] = name.GetIdentifierName(); extra_info["type"] = type; if (!suggestions.empty()) { extra_info["candidates"] = StringUtil::Join(suggestions, ", "); } - return CatalogException(extra_info, - StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, - StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean"))); + return CatalogException(extra_info, StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, + StringUtil::CandidatesErrorMessage( + suggestions, name.GetIdentifierName(), "Did you mean"))); } -CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const string &name, QueryErrorContext context) { +CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const Identifier &name, + QueryErrorContext context) { auto extra_info = Exception::InitializeExtraInfo("ENTRY_ALREADY_EXISTS", optional_idx()); - extra_info["name"] = name; + extra_info["name"] = name.GetIdentifierName(); extra_info["type"] = CatalogTypeToString(type); - return CatalogException(extra_info, - StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name)); + return CatalogException(extra_info, StringUtil::Format("%s with name \"%s\" already exists!", + CatalogTypeToString(type), name.GetIdentifierName())); } } // namespace duckdb diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp index c581f6e66..53665c14d 100644 --- a/src/duckdb/src/common/exception_format_value.cpp +++ b/src/duckdb/src/common/exception_format_value.cpp @@ -7,6 +7,7 @@ #include "duckdb/common/types/uhugeint.hpp" #include "duckdb/parser/keyword_helper.hpp" #include "duckdb/common/types/string.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb 
{ @@ -56,6 +57,12 @@ ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value return ExceptionFormatValue(value); } template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const Identifier &value) { + // An Identifier is a SQL identifier, so it formats with SQL identifier rules (quoted only when required) - + // callers should never need to wrap an Identifier in SQLIdentifier when printing. + return SQLIdentifier::ToString(value.GetIdentifierName()); +} +template <> ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const SQLString &value) { return SQLString::ToString(value.raw_string); } diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index 3e9f813c7..74ebe5fb1 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -243,86 +243,13 @@ shared_ptr StructTypeInfo::DeepCopy() const { return make_shared_ptr(std::move(copied_child_types)); } -//===--------------------------------------------------------------------===// -// Legacy Aggregate State Type Info -//===--------------------------------------------------------------------===// -LegacyAggregateStateTypeInfo::LegacyAggregateStateTypeInfo() - : ExtraTypeInfo(ExtraTypeInfoType::LEGACY_AGGREGATE_STATE_TYPE_INFO) { -} - -LegacyAggregateStateTypeInfo::LegacyAggregateStateTypeInfo(aggregate_state_t state_type_p) - : ExtraTypeInfo(ExtraTypeInfoType::LEGACY_AGGREGATE_STATE_TYPE_INFO), state_type(std::move(state_type_p)) { -} - -shared_ptr LegacyAggregateStateTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -bool LegacyAggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return state_type.function_name == other.state_type.function_name && - state_type.return_type == other.state_type.return_type && - state_type.bound_argument_types == other.state_type.bound_argument_types; -} - 
-//===--------------------------------------------------------------------===// -// Aggregate State Type Info -//===--------------------------------------------------------------------===// -/* - * NOTE: In types.json, AggregateStateTypeInfo inherits directly from ExtraTypeInfo - * instead of StructTypeInfo. This is intentional because of a bug in the generation script logic: - * the generation script produces invalid C++ when handling - * multi-level inheritance for these types (specifically, trying to access - * non-static members in static Deserialize methods). Flattening the JSON - * ensures the dispatch logic remains in ExtraTypeInfo::Deserialize where - * the 'type' property is readily available. - */ - -AggregateStateTypeInfo::AggregateStateTypeInfo() : StructTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO, {}) { -} - -AggregateStateTypeInfo::AggregateStateTypeInfo(aggregate_state_t state_type_p, child_list_t child_types_p) - : StructTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO, std::move(child_types_p)), - state_type(std::move(state_type_p)) { -} - -bool AggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return state_type.function_name == other.state_type.function_name && - state_type.return_type == other.state_type.return_type && - state_type.bound_argument_types == other.state_type.bound_argument_types && child_types == other.child_types; -} - -shared_ptr AggregateStateTypeInfo::Copy() const { - auto result = make_shared_ptr(state_type, child_types); - result->alias = alias; - return std::move(result); -} - -shared_ptr AggregateStateTypeInfo::DeepCopy() const { - child_list_t copied_child_types; - for (const auto &child_type : child_types) { - copied_child_types.emplace_back(child_type.first, child_type.second.DeepCopy()); - } - - vector copied_bound_arguments; - for (const auto &arg : state_type.bound_argument_types) { - copied_bound_arguments.push_back(arg.DeepCopy()); - } - 
aggregate_state_t copied_state_type(state_type.function_name, state_type.return_type.DeepCopy(), - std::move(copied_bound_arguments)); - auto result = make_shared_ptr(copied_state_type, copied_child_types); - result->alias = alias; - return std::move(result); -} - //===--------------------------------------------------------------------===// // User Type Info //===--------------------------------------------------------------------===// void UnboundTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); - if (serializer.ShouldSerialize(7)) { + if (serializer.ShouldSerialize(StorageVersion::V1_5_0)) { serializer.WritePropertyWithDefault>(204, "expr", expr); return; } @@ -331,13 +258,13 @@ void UnboundTypeInfo::Serialize(Serializer &serializer) const { if (expr->GetExpressionType() != ExpressionType::TYPE) { throw SerializationException( "Cannot serialize non-type type expression when targeting database storage version '%s'", - serializer.GetOptions().serialization_compatibility.duckdb_version); + serializer.GetOptions().storage_compatibility.duckdb_version); } auto &type_expr = expr->Cast(); - serializer.WritePropertyWithDefault(200, "name", type_expr.GetTypeName()); - serializer.WritePropertyWithDefault(201, "catalog", type_expr.GetCatalog()); - serializer.WritePropertyWithDefault(202, "schema", type_expr.GetSchema()); + serializer.WritePropertyWithDefault(200, "name", type_expr.GetTypeName().GetIdentifierName()); + serializer.WritePropertyWithDefault(201, "catalog", type_expr.GetCatalog().GetIdentifierName()); + serializer.WritePropertyWithDefault(202, "schema", type_expr.GetSchema().GetIdentifierName()); // Try to write the user type mods too vector user_type_mods; @@ -345,7 +272,7 @@ void UnboundTypeInfo::Serialize(Serializer &serializer) const { if (param->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { throw SerializationException( "Cannot serialize non-constant type parameter when targeting serialization version %s", 
- serializer.GetOptions().serialization_compatibility.duckdb_version); + serializer.GetOptions().storage_compatibility.duckdb_version); } auto &const_expr = param->Cast(); @@ -375,12 +302,29 @@ shared_ptr UnboundTypeInfo::Deserialize(Deserializer &deserialize user_type_mods.push_back(make_uniq_base(mod)); } - result->expr = make_uniq(catalog, schema, name, std::move(user_type_mods)); + result->expr = make_uniq(Identifier(catalog), Identifier(schema), Identifier(name), + std::move(user_type_mods)); } return std::move(result); } +//===--------------------------------------------------------------------===// +// Legacy Aggregate State Type Info +//===--------------------------------------------------------------------===// +LegacyAggregateStateTypeInfo::LegacyAggregateStateTypeInfo() + : ExtraTypeInfo(ExtraTypeInfoType::LEGACY_AGGREGATE_STATE_TYPE_INFO) { + throw InternalException("LegacyAggregateStateTypeInfo should no longer be getting constructed"); +} + +bool LegacyAggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + throw InternalException("LegacyAggregateStateTypeInfo should no longer be getting constructed"); +} + +shared_ptr LegacyAggregateStateTypeInfo::LegacyDeserialize() { + return make_shared_ptr(ExtraTypeInfoType::GENERIC_TYPE_INFO); +} + //===--------------------------------------------------------------------===// // Enum Type Info //===--------------------------------------------------------------------===// @@ -396,7 +340,7 @@ PhysicalType EnumTypeInfo::DictType(idx_t size) { } } -EnumTypeInfo::EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p) +EnumTypeInfo::EnumTypeInfo(const Vector &values_insert_order_p, idx_t dict_size_p) : ExtraTypeInfo(ExtraTypeInfoType::ENUM_TYPE_INFO), values_insert_order(Vector::Ref(values_insert_order_p)), dict_type(EnumDictType::VECTOR_DICT), dict_size(dict_size_p) { } @@ -413,7 +357,7 @@ const idx_t &EnumTypeInfo::GetDictSize() const { return dict_size; } -LogicalType 
EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { +LogicalType EnumTypeInfo::CreateType(const Vector &ordered_data, idx_t size) { // Generate EnumTypeInfo shared_ptr info; auto enum_internal_type = EnumTypeInfo::DictType(size); diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp index babda0c6e..ea609cf67 100644 --- a/src/duckdb/src/common/file_buffer.cpp +++ b/src/duckdb/src/common/file_buffer.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/storage_info.hpp" @@ -31,6 +32,7 @@ void FileBuffer::Init() { size = 0; internal_buffer = nullptr; internal_size = 0; + owns_internal_buffer = true; } FileBuffer::FileBuffer(FileBuffer &source, FileBufferType type_p, idx_t block_header_size) @@ -40,18 +42,23 @@ FileBuffer::FileBuffer(FileBuffer &source, FileBufferType type_p, idx_t block_he size = source.internal_size - block_header_size; internal_buffer = source.internal_buffer; internal_size = source.internal_size; + owns_internal_buffer = source.owns_internal_buffer; source.Init(); } FileBuffer::~FileBuffer() { - if (!internal_buffer) { + if (!internal_buffer || !owns_internal_buffer) { return; } allocator.FreeData(internal_buffer, internal_size); } void FileBuffer::ReallocBuffer(idx_t new_size) { + if (internal_buffer && !owns_internal_buffer) { + throw InternalException( + "Cannot resize a FileBuffer that does not own its internal buffer (it points into a memory-mapped region)"); + } data_ptr_t new_buffer; if (internal_buffer) { new_buffer = allocator.ReallocateData(internal_buffer, internal_size, new_size); @@ -66,6 +73,7 @@ void FileBuffer::ReallocBuffer(idx_t new_size) { internal_buffer = new_buffer; internal_size = new_size; + owns_internal_buffer = true; // The caller must 
update these. buffer = nullptr; @@ -108,11 +116,30 @@ void FileBuffer::Read(QueryContext context, FileHandle &handle, uint64_t locatio handle.Read(context, internal_buffer, internal_size, location); } +void FileBuffer::Read(QueryContext context, MemoryMappedFile &handle, uint64_t location) { + D_ASSERT(type != FileBufferType::TINY_BUFFER); + // Zero-copy: adopt a pointer into the mapping. + auto src = handle.GetDataMutable(location, internal_size); + const idx_t header_size = internal_size - size; + if (internal_buffer && owns_internal_buffer) { + allocator.FreeData(internal_buffer, internal_size); + } + internal_buffer = src; + buffer = internal_buffer + header_size; + owns_internal_buffer = false; +} + void FileBuffer::Write(QueryContext context, FileHandle &handle, const uint64_t location) { D_ASSERT(type != FileBufferType::TINY_BUFFER); handle.Write(context, internal_buffer, internal_size, location); } +void FileBuffer::Write(QueryContext context, MemoryMappedFile &handle, uint64_t location) { + D_ASSERT(type != FileBufferType::TINY_BUFFER); + auto dst = handle.GetDataMutable(location, internal_size); + memcpy(dst, internal_buffer, internal_size); +} + void FileBuffer::Clear() { memset(internal_buffer, 0, internal_size); } diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 1a8d50e71..94e4c3137 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -697,7 +697,7 @@ unique_ptr FileSystem::GlobFileList(const string &pattern, const return result; } } - if (input.behavior == FileGlobOptions::FALLBACK_GLOB || input.behavior == FileGlobOptions::DISALLOW_EMPTY) { + if (!input.AllowsEmpty()) { throw IOException("No files found that match the pattern \"%s\"", pattern); } } @@ -733,11 +733,19 @@ unique_ptr FileSystem::OpenCompressedFile(QueryContext context, uniq throw NotImplementedException("%s: OpenCompressedFile is not implemented!", GetName()); } +bool 
FileSystem::IsLocalFileSystem() const { + return false; +} + bool FileSystem::OnDiskFile(FileHandle &handle) { throw NotImplementedException("%s: OnDiskFile is not implemented!", GetName()); } // LCOV_EXCL_STOP +bool FileSystem::TryGetNetworkThroughput(FileHandle &handle, NetworkThroughputEstimate &result) { + return false; +} + FileHandle::FileHandle(FileSystem &file_system, string path_p, FileOpenFlags flags) : file_system(file_system), path(std::move(path_p)), flags(flags) { } @@ -751,7 +759,7 @@ int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { int64_t FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_READ, nr_bytes); + QueryProfiler::Get(*context.GetClientContext()).TrackBytesRead(nr_bytes); } return file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes)); @@ -767,7 +775,7 @@ int64_t FileHandle::Write(void *buffer, idx_t nr_bytes) { int64_t FileHandle::Write(QueryContext context, void *buffer, idx_t nr_bytes) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_WRITTEN, nr_bytes); + QueryProfiler::Get(*context.GetClientContext()).TrackBytesWritten(nr_bytes); } return file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes)); @@ -779,7 +787,7 @@ void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_READ, nr_bytes); + QueryProfiler::Get(*context.GetClientContext()).TrackBytesRead(nr_bytes); } file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); @@ -787,7 +795,7 @@ void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t void 
FileHandle::Write(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { if (context.GetClientContext() != nullptr) { - context.GetClientContext()->client_data->profiler->AddToCounter(MetricType::TOTAL_BYTES_WRITTEN, nr_bytes); + QueryProfiler::Get(*context.GetClientContext()).TrackBytesWritten(nr_bytes); } file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); @@ -849,6 +857,10 @@ bool FileHandle::OnDiskFile() { return file_system.OnDiskFile(*this); } +bool FileHandle::TryGetNetworkThroughput(NetworkThroughputEstimate &result) { + return file_system.TryGetNetworkThroughput(*this, result); +} + idx_t FileHandle::GetFileSize() { return NumericCast(file_system.GetFileSize(*this)); } diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp index d7820153e..fbda2c8d2 100644 --- a/src/duckdb/src/common/hive_partitioning.cpp +++ b/src/duckdb/src/common/hive_partitioning.cpp @@ -56,11 +56,11 @@ ConvertKnownColRefToConstants(ClientContext &context, unique_ptr &ex auto &bound_colref = expr->Cast(); // This bound column ref is for another table - if (table_index != bound_colref.binding.table_index) { + if (table_index != bound_colref.Binding().table_index) { return; } - auto lookup = known_column_values.find(bound_colref.binding.column_index); + auto lookup = known_column_values.find(bound_colref.Binding().column_index); if (lookup != known_column_values.end()) { auto &partition_val = lookup->second; Value result_val; @@ -238,12 +238,11 @@ static inline Value GetHiveKeyNullValue(const LogicalType &type) { } template -static void TemplatedGetHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { +static void TemplatedGetHivePartitionValues(const Vector &input, vector &keys, const idx_t col_idx) { auto entries = input.Values(); const auto &type = input.GetType(); - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < input.size(); i++) { auto &key = 
keys[i]; auto entry = entries[i]; if (entry.IsValid()) { @@ -254,66 +253,64 @@ static void TemplatedGetHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - for (idx_t i = 0; i < count; i++) { +static void GetNestedHivePartitionValues(const Vector &input, vector &keys, const idx_t col_idx) { + for (idx_t i = 0; i < input.size(); i++) { auto &key = keys[i]; key.values[col_idx] = input.GetValue(i); } } -static void GetHivePartitionValuesTypeSwitch(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { +static void GetHivePartitionValuesTypeSwitch(const Vector &input, vector &keys, const idx_t col_idx) { const auto &type = input.GetType(); switch (type.InternalType()) { case PhysicalType::BOOL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INT32: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INT128: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::UINT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::UINT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::UINT32: - 
TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::UINT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::UINT128: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::FLOAT: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::DOUBLE: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::INTERVAL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::VARCHAR: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); + TemplatedGetHivePartitionValues(input, keys, col_idx); break; case PhysicalType::STRUCT: case PhysicalType::LIST: - GetNestedHivePartitionValues(input, keys, col_idx, count); + GetNestedHivePartitionValues(input, keys, col_idx); break; default: throw InternalException("Unsupported type for HivePartitionedColumnData::ComputePartitionIndices"); @@ -327,8 +324,8 @@ void HivePartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataApp hashes_v.Flatten(); for (idx_t col_idx = 0; col_idx < group_by_columns.size(); col_idx++) { - auto &group_by_col = input.data[group_by_columns[col_idx]]; - GetHivePartitionValuesTypeSwitch(group_by_col, keys, col_idx, count); + const auto &group_by_col = input.data[group_by_columns[col_idx]]; + GetHivePartitionValuesTypeSwitch(group_by_col, keys, col_idx); } const auto hashes = FlatVector::GetData(hashes_v); diff --git a/src/duckdb/src/common/identifier.cpp b/src/duckdb/src/common/identifier.cpp new file mode 100644 index 000000000..84be4fad6 --- 
/dev/null +++ b/src/duckdb/src/common/identifier.cpp @@ -0,0 +1,37 @@ +#include "duckdb/common/identifier.hpp" + +#include "duckdb/common/string_util.hpp" + +#include + +namespace duckdb { + +hash_t Identifier::Hash() const { + return StringUtil::CIHash(value); +} + +bool operator==(const Identifier &a, const Identifier &b) { + return StringUtil::CIEquals(a.GetIdentifierName(), b.GetIdentifierName()); +} +bool operator==(const Identifier &a, const string &b) { + return StringUtil::CIEquals(a.GetIdentifierName(), b); +} +bool operator==(const string &a, const Identifier &b) { + return StringUtil::CIEquals(a, b.GetIdentifierName()); +} +bool operator==(const Identifier &a, const char *b) { + return StringUtil::CIEquals(a.GetIdentifierName(), string(b)); +} +bool operator==(const char *a, const Identifier &b) { + return StringUtil::CIEquals(string(a), b.GetIdentifierName()); +} + +bool operator<(const Identifier &a, const Identifier &b) { + return StringUtil::CILessThan(a.GetIdentifierName(), b.GetIdentifierName()); +} + +std::ostream &operator<<(std::ostream &os, const Identifier &id) { + return os << id.GetIdentifierName(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index 63d79a280..c5452a069 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/file_opener.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/windows.hpp" #include "duckdb/function/scalar/string_common.hpp" @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #else @@ -96,14 +98,70 @@ bool LocalFileSystem::IsPipe(const string &filename, optional_ptr op } #else + +// Maximum length of a path that does not require a long path prefix or +// LongPathsEnabled setting 
along with longPathAware manifest. +// For files the limit is 260 - 1 characters. +// Directories additionally must be able to create 8+3 files inside them. +// The limit applies to fully resolved absolute paths. +static const size_t WINDOWS_MAX_SHORT_PATH = 247; +static const size_t WINDOWS_MAX_LONG_PATH = 32000; +static const std::wstring WINDOWS_LOCAL_LONG_PATH_PREFIX = L"\\\\?\\"; +static const std::wstring WINDOWS_UNC_LONG_PATH_PREFIX = L"\\\\?\\UNC\\"; + static std::wstring ConvertPathToUnicode(const string &path) { return WindowsUtil::UTF8ToUnicode(path.c_str()); } +static std::wstring ConvertPathToNormalizedAbsolute(const std::wstring &path) { + // len includes NULL-terminator + DWORD len = GetFullPathNameW(path.c_str(), 0, nullptr, nullptr); + if (len == 0) { + string error = LocalFileSystem::GetLastErrorAsString(); + string utf8_path = WindowsUtil::UnicodeToUTF8(path.c_str()); + throw IOException("Failed to get length for normalized path \"%s\": %s", utf8_path, error); + } + + std::vector buf; + buf.resize(len); + + // written does NOT include NULL-terminator + DWORD written = GetFullPathNameW(path.c_str(), len, buf.data(), nullptr); + if (written == 0) { + string error = LocalFileSystem::GetLastErrorAsString(); + string utf8_path = WindowsUtil::UnicodeToUTF8(path.c_str()); + throw IOException("Failed to normalize path \"%s\", length: %lu: %s", utf8_path, len, error); + } + + return std::wstring(buf.data(), written); +} + static std::wstring NormalizePathAndConvertToUnicode(FileSystem &fs, const string &path, optional_ptr opener) { - auto normalized_path = fs.ExpandPath(path, opener); - return ConvertPathToUnicode(normalized_path); + string normalized_path = fs.ExpandPath(path, opener); + std::wstring unicode_path = ConvertPathToUnicode(normalized_path); + + // We need to get absolute path to check the length. Normalizing it (removing "." and "..", + // flipping forward slashes) is only required if the path is long. 
+ // We are doing it for all paths to not perform current working dir resolving twice. + std::wstring abs_path = ConvertPathToNormalizedAbsolute(unicode_path); + + if (abs_path.length() <= WINDOWS_MAX_SHORT_PATH || abs_path.find(WINDOWS_LOCAL_LONG_PATH_PREFIX) == 0 || + abs_path.find(WINDOWS_UNC_LONG_PATH_PREFIX) == 0) { + return abs_path; + } + + if (abs_path.length() > WINDOWS_MAX_LONG_PATH) { + string resolved_path = WindowsUtil::UnicodeToUTF8(abs_path.c_str()); + throw IOException("Path is longer than Windows MAX_PATH, length: %zu, beginning: %s[...]", abs_path.length(), + resolved_path.substr(0, 100)); + } + + if (abs_path.find(L"\\\\") == 0) { + return WINDOWS_UNC_LONG_PATH_PREFIX + abs_path; + } + + return WINDOWS_LOCAL_LONG_PATH_PREFIX + abs_path; } bool LocalFileSystem::FileExists(const string &filename, optional_ptr opener) { @@ -167,6 +225,13 @@ struct UnixFileHandle : public FileHandle { }; }; +static void CloseFileAndAppendError(int fd, string &extended_error) { + if (close(fd) == -1) { + extended_error += ". Also, failed closing file: "; + extended_error += strerror(errno); + } +} + static FileMetadata StatsFromStruct(struct stat s) { FileMetadata file_metadata; file_metadata.file_size = s.st_size; @@ -300,6 +365,64 @@ static string AdditionalProcessInfo(FileSystem &fs, pid_t pid) { } #endif +// Apply a fcntl advisory lock per flags.Lock(); throws (and closes fd) on failure. Shared +// by OpenFile and MemoryMapFile. +static void TryAcquireFileLock(FileSystem &fs, int fd, const string &path, FileOpenFlags flags) { + if (flags.Lock() == FileLockType::NO_LOCK) { + return; + } + // only attempt locking on regular files (not FIFOs/sockets) + auto file_type = StatsInternal(fd, path).file_type; + if (file_type == FileType::FILE_TYPE_FIFO || file_type == FileType::FILE_TYPE_SOCKET) { + return; + } + struct flock fl; + memset(&fl, 0, sizeof fl); + fl.l_type = flags.Lock() == FileLockType::READ_LOCK ? 
F_RDLCK : F_WRLCK; + fl.l_whence = SEEK_SET; + fl.l_start = 0; + fl.l_len = 0; + int rc = fcntl(fd, F_SETLK, &fl); + int retained_errno = errno; + bool has_error = rc == -1; + string extended_error; + if (has_error) { + if (retained_errno == ENOTSUP) { + if (flags.Lock() == FileLockType::READ_LOCK) { + // file lock not supported here; ignore for read-only + return; + } + extended_error = "File locks are not supported for this file system, cannot open the file in " + "read-write mode. Try opening the file in read-only mode"; + } + } + if (!has_error) { + return; + } + if (extended_error.empty()) { + // who is holding the lock? + rc = fcntl(fd, F_GETLK, &fl); + if (rc == -1) { + extended_error = strerror(errno); + } else { + extended_error = AdditionalProcessInfo(fs, fl.l_pid); + } + if (flags.Lock() == FileLockType::WRITE_LOCK) { + // could we get a read lock? + fl.l_type = F_RDLCK; + rc = fcntl(fd, F_SETLK, &fl); + if (rc != -1) { + extended_error += ". However, you would be able to open this database in read-only mode, e.g. by " + "using the -readonly parameter in the CLI"; + } + } + } + CloseFileAndAppendError(fd, extended_error); + extended_error += ". 
See also https://duckdb.org/docs/current/connect/concurrency"; + throw IOException({{"errno", std::to_string(retained_errno)}}, "Could not set lock on file \"%s\": %s", path, + extended_error); +} + bool LocalFileSystem::IsPrivateFile(const string &path_p, FileOpener *opener) { auto path = FileSystem::ExpandPath(path_p, opener); @@ -329,7 +452,6 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF flags.Verify(); int open_flags = 0; - int rc; bool open_read = flags.OpenForReading(); bool open_write = flags.OpenForWriting(); if (open_read && open_write) { @@ -394,72 +516,18 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF #if defined(__DARWIN__) || defined(__APPLE__) if (flags.DirectIO()) { // OSX requires fcntl for Direct IO - rc = fcntl(fd, F_NOCACHE, 1); + int rc = fcntl(fd, F_NOCACHE, 1); if (rc == -1) { - throw IOException("Could not enable direct IO for file \"%s\": %s", path, strerror(errno)); + int retained_errno = errno; + string extended_error = strerror(retained_errno); + CloseFileAndAppendError(fd, extended_error); + throw IOException({{"errno", std::to_string(retained_errno)}}, + "Could not enable direct IO for file \"%s\": %s", path, extended_error); } } #endif - if (flags.Lock() != FileLockType::NO_LOCK) { - // set lock on file - // but only if it is not an input/output stream - auto file_type = StatsInternal(fd, path_p).file_type; - if (file_type != FileType::FILE_TYPE_FIFO && file_type != FileType::FILE_TYPE_SOCKET) { - struct flock fl; - memset(&fl, 0, sizeof fl); - fl.l_type = flags.Lock() == FileLockType::READ_LOCK ? F_RDLCK : F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - rc = fcntl(fd, F_SETLK, &fl); - // Retain the original error. 
- int retained_errno = errno; - bool has_error = rc == -1; - string extended_error; - if (has_error) { - if (retained_errno == ENOTSUP) { - // file lock not supported for this file system - if (flags.Lock() == FileLockType::READ_LOCK) { - // for read-only, we ignore not-supported errors - has_error = false; - errno = 0; - } else { - extended_error = "File locks are not supported for this file system, cannot open the file in " - "read-write mode. Try opening the file in read-only mode"; - } - } - } - if (has_error) { - if (extended_error.empty()) { - // try to find out who is holding the lock using F_GETLK - rc = fcntl(fd, F_GETLK, &fl); - if (rc == -1) { // fnctl does not want to help us - extended_error = strerror(errno); - } else { - extended_error = AdditionalProcessInfo(*this, fl.l_pid); - } - if (flags.Lock() == FileLockType::WRITE_LOCK) { - // maybe we can get a read lock instead and tell this to the user. - fl.l_type = F_RDLCK; - rc = fcntl(fd, F_SETLK, &fl); - if (rc != -1) { // success! - extended_error += - ". However, you would be able to open this database in read-only mode, e.g. by " - "using the -readonly parameter in the CLI"; - } - } - } - rc = close(fd); - if (rc == -1) { - extended_error += ". Also, failed closing file"; - } - extended_error += ". 
See also https://duckdb.org/docs/current/connect/concurrency"; - throw IOException({{"errno", std::to_string(retained_errno)}}, "Could not set lock on file \"%s\": %s", - path, extended_error); - } - } - } + TryAcquireFileLock(*this, fd, path, flags); auto file_handle = make_uniq(*this, path, fd, flags); if (opener) { @@ -825,12 +893,146 @@ string LocalFileSystem::MakePathAbsolute(const string &path_p, optional_ptr(offset), + NumericCast(length)); + return rc == 0; #else + (void)offset; + (void)length; + return false; +#endif +} -constexpr char PIPE_PREFIX[] = "\\\\.\\pipe\\"; +void UnixMemoryMappedFile::Close() { + if (data != nullptr) { + munmap(data, size); + data = nullptr; + size = 0; + } + if (fd != -1) { + close(fd); + fd = -1; + } +} -static const std::wstring WINDOWS_LOCAL_LONG_PATH_PREFIX = L"\\\\?\\"; -static const std::wstring WINDOWS_UNC_LONG_PATH_PREFIX = L"\\\\?\\UNC\\"; +unique_ptr LocalFileSystem::MemoryMapFile(const OpenFileInfo &path_info, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener) { + auto path = ExpandPath(path_info.path, opener); + flags.Verify(); + + bool open_read = flags.OpenForReading(); + bool open_write = flags.OpenForWriting(); + int open_flags; + if (open_read && open_write) { + open_flags = O_RDWR; + } else if (open_read) { + open_flags = O_RDONLY; + } else { + throw InternalException("READ or READ+WRITE must be specified when memory-mapping a file"); + } + if (open_write && flags.CreateFileIfNotExists()) { + open_flags |= O_CREAT; + } + + int fd = open(path.c_str(), open_flags, 0666); + if (fd == -1) { + if (flags.ReturnNullIfNotExists() && errno == ENOENT) { + return nullptr; + } + throw IOException({{"errno", std::to_string(errno)}}, "Cannot open file \"%s\" for memory mapping: %s", path, + strerror(errno)); + } + + TryAcquireFileLock(*this, fd, path, flags); + + struct stat st; + if (fstat(fd, &st) != 0) { + int saved_errno = errno; + close(fd); + throw IOException({{"errno", 
std::to_string(saved_errno)}}, "Could not stat file \"%s\": %s", path, + strerror(saved_errno)); + } + auto initial_file_size = NumericCast(st.st_size); + + // Writable mappings sparsely extend the file to the reserve size so the region never + // needs to be remapped; read-only mappings just span the existing file. + idx_t mapping_size; + if (open_write) { + mapping_size = MaxValue(initial_file_size, options.reserve_size); + if (initial_file_size < mapping_size) { + if (ftruncate(fd, NumericCast(mapping_size)) != 0) { + int saved_errno = errno; + close(fd); + throw IOException({{"errno", std::to_string(saved_errno)}}, + "Could not ftruncate file \"%s\" to reserve size %llu: %s", path, mapping_size, + strerror(saved_errno)); + } + } + } else { + mapping_size = initial_file_size; + } + + data_ptr_t data = nullptr; + if (mapping_size > 0) { + int prot = PROT_READ | (open_write ? PROT_WRITE : 0); + void *mapped = mmap(nullptr, mapping_size, prot, MAP_SHARED, fd, 0); + if (mapped == MAP_FAILED) { + int saved_errno = errno; + close(fd); + throw IOException({{"errno", std::to_string(saved_errno)}}, "Could not mmap file \"%s\" (size=%llu): %s", + path, mapping_size, strerror(saved_errno)); + } + data = data_ptr_cast(mapped); + } + + return make_uniq(path, flags, fd, data, mapping_size); +} + +#else + +constexpr char PIPE_PREFIX[] = "\\\\.\\pipe\\"; // Returns the last Win32 error, in string format. Returns an empty string if there is no error. 
std::string LocalFileSystem::GetLastErrorAsString() { @@ -1096,7 +1298,8 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF if (!extended_error.empty()) { extended_error = "\n" + extended_error; } - throw IOException("Cannot open file \"%s\": %s%s", path.c_str(), error, extended_error); + auto abs_path = WindowsUtil::UnicodeToUTF8(unicode_path.c_str()); + throw IOException("Cannot open file \"%s\": %s%s", abs_path, error, extended_error); } auto handle = make_uniq(*this, path.c_str(), hFile, flags); if (flags.OpenForAppending()) { @@ -1259,7 +1462,8 @@ void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr LocalFileSystem::MemoryMapFile(const OpenFileInfo &path_info, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener) { + auto path = ExpandPath(path_info.path, opener); + flags.Verify(); + + bool open_read = flags.OpenForReading(); + bool open_write = flags.OpenForWriting(); + DWORD desired_access; + DWORD share_mode; + DWORD creation_disposition = OPEN_EXISTING; + + if (open_read && open_write) { + desired_access = GENERIC_READ | GENERIC_WRITE; + } else if (open_read) { + desired_access = GENERIC_READ; + } else { + throw InternalException("READ or READ+WRITE must be specified when memory-mapping a file"); + } + // share_mode mirrors OpenFile's handling: write lock excludes, read lock blocks writers. 
+ switch (flags.Lock()) { + case FileLockType::NO_LOCK: + share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE; + break; + case FileLockType::READ_LOCK: + share_mode = FILE_SHARE_READ; + break; + case FileLockType::WRITE_LOCK: + share_mode = 0; + break; + default: + throw InternalException("Unknown FileLockType"); + } + share_mode |= FILE_SHARE_DELETE; + if (open_write && flags.CreateFileIfNotExists()) { + creation_disposition = OPEN_ALWAYS; + } + + auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); + HANDLE file_handle = CreateFileW(unicode_path.c_str(), desired_access, share_mode, NULL, creation_disposition, + FILE_ATTRIBUTE_NORMAL, NULL); + if (file_handle == INVALID_HANDLE_VALUE) { + if (flags.ReturnNullIfNotExists() && GetLastError() == ERROR_FILE_NOT_FOUND) { + return nullptr; + } + throw IOException("Cannot open file \"%s\" for memory mapping: %s", path, + LocalFileSystem::GetLastErrorAsString()); + } + + LARGE_INTEGER size_li; + if (!GetFileSizeEx(file_handle, &size_li)) { + auto err = LocalFileSystem::GetLastErrorAsString(); + CloseHandle(file_handle); + throw IOException("Could not GetFileSizeEx for \"%s\": %s", path, err); + } + auto initial_file_size = NumericCast(size_li.QuadPart); + + // Writable mappings sparsely extend the file to the reserve size so the region never + // needs to be remapped; read-only mappings just span the existing file. + idx_t mapping_size; + if (open_write) { + mapping_size = MaxValue(initial_file_size, options.reserve_size); + if (initial_file_size < mapping_size) { + // Sparse so the SetEndOfFile below doesn't allocate the untouched range. 
+ DWORD bytes_returned = 0; + DeviceIoControl(file_handle, FSCTL_SET_SPARSE, NULL, 0, NULL, 0, &bytes_returned, NULL); + LARGE_INTEGER li; + li.QuadPart = static_cast(mapping_size); + if (!SetFilePointerEx(file_handle, li, NULL, FILE_BEGIN)) { + auto err = LocalFileSystem::GetLastErrorAsString(); + CloseHandle(file_handle); + throw IOException("Could not SetFilePointerEx for \"%s\": %s", path, err); + } + if (!SetEndOfFile(file_handle)) { + auto err = LocalFileSystem::GetLastErrorAsString(); + CloseHandle(file_handle); + throw IOException("Could not SetEndOfFile for \"%s\" (reserve_size=%llu): %s", path, mapping_size, err); + } + } + } else { + mapping_size = initial_file_size; + } + + HANDLE mapping_handle = nullptr; + data_ptr_t data = nullptr; + if (mapping_size > 0) { + DWORD protect = open_write ? PAGE_READWRITE : PAGE_READONLY; + DWORD desired_map_access = open_write ? FILE_MAP_WRITE : FILE_MAP_READ; + DWORD high = static_cast((static_cast(mapping_size) >> 32) & 0xFFFFFFFFULL); + DWORD low = static_cast(static_cast(mapping_size) & 0xFFFFFFFFULL); + mapping_handle = CreateFileMappingW(file_handle, NULL, protect, high, low, NULL); + if (mapping_handle == NULL) { + auto err = LocalFileSystem::GetLastErrorAsString(); + CloseHandle(file_handle); + throw IOException("Could not CreateFileMapping for \"%s\": %s", path, err); + } + void *mapped = MapViewOfFile(mapping_handle, desired_map_access, 0, 0, mapping_size); + if (mapped == NULL) { + auto err = LocalFileSystem::GetLastErrorAsString(); + CloseHandle(mapping_handle); + CloseHandle(file_handle); + throw IOException("Could not MapViewOfFile for \"%s\": %s", path, err); + } + data = data_ptr_cast(mapped); + } + + return make_uniq(path, flags, file_handle, mapping_handle, data, mapping_size); +} + #endif bool LocalFileSystem::CanSeek() { diff --git a/src/duckdb/src/common/memory_mapped_file.cpp b/src/duckdb/src/common/memory_mapped_file.cpp new file mode 100644 index 000000000..8c4824172 --- /dev/null +++ 
b/src/duckdb/src/common/memory_mapped_file.cpp @@ -0,0 +1,41 @@ +#include "duckdb/common/memory_mapped_file.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_system.hpp" + +namespace duckdb { + +// Platform-specific MemoryMappedFile subclasses and the LocalFileSystem::MemoryMapFile +// factory live in src/common/local_file_system.cpp. +MemoryMappedFile::MemoryMappedFile(string path_p, FileOpenFlags flags_p, data_ptr_t data_p, idx_t size_p) + : path(std::move(path_p)), flags(flags_p), data(data_p), size(size_p) { +} + +MemoryMappedFile::~MemoryMappedFile() { +} + +const_data_ptr_t MemoryMappedFile::GetData(idx_t location, idx_t nr_bytes) const { + if (location + nr_bytes > size) { + throw IOException("MemoryMappedFile::GetData: out-of-bounds access in file \"%s\" " + "(location=%llu, nr_bytes=%llu, size=%llu)", + path, location, nr_bytes, size); + } + return data + location; +} + +data_ptr_t MemoryMappedFile::GetDataMutable(idx_t location, idx_t nr_bytes) { + if (location + nr_bytes > size) { + throw IOException("MemoryMappedFile::GetDataMutable: out-of-bounds access in file \"%s\" " + "(location=%llu, nr_bytes=%llu, size=%llu)", + path, location, nr_bytes, size); + } + return data + location; +} + +unique_ptr FileSystem::MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, optional_ptr opener) { + throw NotImplementedException("Memory-mapped file access is not supported by this file system (%s) for path \"%s\"", + GetName(), path.path); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/multi_file/base_file_reader.cpp b/src/duckdb/src/common/multi_file/base_file_reader.cpp index 203e93e45..caf569afa 100644 --- a/src/duckdb/src/common/multi_file/base_file_reader.cpp +++ b/src/duckdb/src/common/multi_file/base_file_reader.cpp @@ -3,7 +3,7 @@ namespace duckdb { -unique_ptr BaseFileReader::GetStatistics(ClientContext &context, const string &name) { +unique_ptr 
BaseFileReader::GetStatistics(ClientContext &context, const Identifier &name) { return nullptr; } @@ -24,7 +24,7 @@ double BaseFileReader::GetProgressInFile(ClientContext &context) { return 0; } -unique_ptr BaseUnionData::GetStatistics(ClientContext &context, const string &name) { +unique_ptr BaseUnionData::GetStatistics(ClientContext &context, const Identifier &name) { return nullptr; } diff --git a/src/duckdb/src/common/multi_file/multi_file_adaptive_filter_cache.cpp b/src/duckdb/src/common/multi_file/multi_file_adaptive_filter_cache.cpp index 174634619..5496e744e 100644 --- a/src/duckdb/src/common/multi_file/multi_file_adaptive_filter_cache.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_adaptive_filter_cache.cpp @@ -20,4 +20,12 @@ void MultiFileAdaptiveFilterCache::InitializeAdaptiveFilter(const TableFilterSet filter->SetLogger(std::move(logger), file_path, source, identities); } +AdaptiveFilter &MultiFileAdaptiveFilterCache::GetAdaptiveFilter() const { + if (!filter) { + throw InternalException( + "Filter from MultiFileAdaptiveFilterCache must be initialized by 'InitializeAdaptiveFilter' first."); + } + return *filter; +} + } // namespace duckdb diff --git a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp index d517d24dc..269af0035 100644 --- a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp @@ -14,9 +14,8 @@ #include "duckdb/function/scalar/struct_utils.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" namespace duckdb { @@ -498,7 +497,7 @@ static ColumnMapResult MapColumnStruct(ClientContext &context, 
const MultiFileCo //! FIXME: the 'default_value' should only be used if the STRUCT's default value is not NULL if (child_map.default_value) { // found a default value for this child - emplace it - child_map.default_value->SetAlias(global_child.name); + child_map.default_value->SetAlias(Identifier(global_child.name)); default_expressions.push_back(std::move(child_map.default_value)); } } @@ -679,7 +678,7 @@ ResultColumnMapping MultiFileColumnMapper::CreateColumnMappingByMapper(const Col auto local_index = global_id.RemapRootIndex(local_id.GetId()); // add the virtual column to the reader - reader.columns.emplace_back(virtual_entry->second.name, virtual_column_type); + reader.columns.emplace_back(virtual_entry->second.name.GetIdentifierName(), virtual_column_type); reader.AddVirtualColumn(global_column_id); // set it as being projected in this spot @@ -700,7 +699,7 @@ ResultColumnMapping MultiFileColumnMapper::CreateColumnMappingByMapper(const Col // reader is responsible for converting types - perform a top-level match only auto entry = mapper.Find(global_column); if (!entry.IsValid()) { - ThrowColumnNotFoundError(global_column.name); + ThrowColumnNotFoundError(global_column.name.GetIdentifierName()); } MultiFileLocalColumnId local_id(entry.GetIndex()); auto local_index = global_id.RemapRootIndex(local_id.GetId()); @@ -776,74 +775,11 @@ ResultColumnMapping MultiFileColumnMapper::CreateColumnMapping(MultiFileColumnMa } } +static unique_ptr GetFilterExpression(const TableFilter &filter); + bool EvaluateTableFilterAgainstConstant(ClientContext &context, const TableFilter &filter, const Value &constant) { - if (filter.filter_type == TableFilterType::EXPRESSION_FILTER) { - auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "EvaluateTableFilterAgainstConstant"); - return expr_filter.EvaluateWithConstant(context, constant); - } - const auto type = filter.filter_type; - - switch (type) { - case TableFilterType::CONJUNCTION_OR: { - auto &or_filter = 
filter.Cast(); - for (auto &it : or_filter.child_filters) { - if (EvaluateTableFilterAgainstConstant(context, *it, constant)) { - return true; - } - } - return false; - } - case TableFilterType::CONJUNCTION_AND: { - auto &and_filter = filter.Cast(); - auto res = make_uniq(); - for (auto &it : and_filter.child_filters) { - if (!EvaluateTableFilterAgainstConstant(context, *it, constant)) { - return false; - } - } - return true; - } - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = filter.Cast(); - if (optional_filter.child_filter) { - return EvaluateTableFilterAgainstConstant(context, *optional_filter.child_filter, constant); - } - return true; - } - case TableFilterType::DYNAMIC_FILTER: { - auto &dynamic_filter = filter.Cast(); - if (!dynamic_filter.filter_data) { - //! No filter_data assigned (does this mean the DynamicFilter is broken??) - return true; - } - lock_guard lock(dynamic_filter.filter_data->lock); - if (!dynamic_filter.filter_data->initialized) { - //! Not initialized - return true; - } - if (constant.IsNull()) { - return false; - } - auto column = make_uniq(constant.type(), 0ULL); - auto expression = dynamic_filter.filter_data->ToExpression(*column); - return ExpressionFilter(std::move(expression)).EvaluateWithConstant(context, constant); - } - case TableFilterType::BLOOM_FILTER: { - auto &bloom_filter = filter.Cast(); - return bloom_filter.FilterValue(constant); - } - case TableFilterType::PERFECT_HASH_JOIN_FILTER: { - auto &perfect_hash_join_filter = filter.Cast(); - return perfect_hash_join_filter.FilterValue(constant); - } - case TableFilterType::PREFIX_RANGE_FILTER: { - auto &prefix_range_filter = filter.Cast(); - return prefix_range_filter.FilterValue(constant); - } - default: - throw NotImplementedException("Can't evaluate TableFilterType (%s) against a constant", - EnumUtil::ToString(type)); - } + auto expression = GetFilterExpression(filter); + return ExpressionFilter(std::move(expression)).EvaluateWithConstant(context, 
constant); } bool MultiFileColumnMapper::EvaluateFilterAgainstConstant(const TableFilter &filter, const Value &constant) { @@ -854,7 +790,7 @@ Value MultiFileColumnMapper::GetConstantValue(MultiFileGlobalIndex global_index) auto global_column_id = global_column_ids[global_index].GetPrimaryIndex(); auto &expr = reader_data.expressions[global_index]; if (expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - return expr->Cast().value; + return expr->Cast().GetValue(); } for (idx_t i = 0; i < reader_data.constant_map.size(); i++) { auto &constant_map_entry = reader_data.constant_map[MultiFileConstantMapIndex(i)]; @@ -898,6 +834,11 @@ static unique_ptr CreateReferenceExpression(const LogicalType &type) return make_uniq(type, 0ULL); } +static unique_ptr GetFilterExpression(const TableFilter &filter) { + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "GetFilterExpression"); + return expr_filter.expr->Copy(); +} + static unique_ptr CreateStructExtractExpression(unique_ptr source_expr, const LogicalType &source_type, idx_t child_idx) { vector> arguments; @@ -940,7 +881,7 @@ static RewrittenMappedExpression RewriteMappedValueExpression(const Expression & if (!TryGetStructExtractChildIndex(func, child_idx)) { return result; } - auto child_result = RewriteMappedValueExpression(*func.children[0], mapping, target_type); + auto child_result = RewriteMappedValueExpression(*func.GetChildren()[0], mapping, target_type); if (!child_result.expr || !child_result.mapping || !child_result.type || child_result.type->id() != LogicalTypeId::STRUCT) { return result; @@ -961,10 +902,59 @@ static RewrittenMappedExpression RewriteMappedValueExpression(const Expression & } } +static unique_ptr RewriteDynamicFilterExpression(const shared_ptr &filter_data, + const LogicalType &target_type) { + if (!filter_data || !filter_data->initialized) { + return nullptr; + } + lock_guard lock(filter_data->lock); + auto new_constant = filter_data->constant; + if 
(!TryCastConstant(new_constant, target_type)) { + return nullptr; + } + return BoundComparisonExpression::Create(filter_data->comparison_type, CreateReferenceExpression(target_type), + make_uniq(std::move(new_constant))); +} + static unique_ptr TryCastFilterExpression(const Expression &expr, const MultiFileIndexMapping &mapping, const LogicalType &target_type) { switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_FUNCTION: { + auto &func = expr.Cast(); + if (func.Function().GetName() == OptionalFilterScalarFun::NAME) { + if (!func.BindInfo()) { + return CreateOptionalFilterExpression(nullptr, target_type); + } + auto &data = func.BindInfo()->Cast(); + auto child_expr = data.child_filter_expr + ? TryCastFilterExpression(*data.child_filter_expr, mapping, target_type) + : nullptr; + if (data.child_filter_expr && !child_expr) { + return nullptr; + } + return CreateOptionalFilterExpression(std::move(child_expr), target_type); + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME) { + if (!func.BindInfo()) { + return CreateSelectivityOptionalFilterExpression(nullptr, target_type, 0.5f, idx_t(6)); + } + auto &data = func.BindInfo()->Cast(); + auto child_expr = data.child_filter_expr + ? 
TryCastFilterExpression(*data.child_filter_expr, mapping, target_type) + : nullptr; + if (data.child_filter_expr && !child_expr) { + return nullptr; + } + return CreateSelectivityOptionalFilterExpression(std::move(child_expr), target_type, + data.selectivity_threshold, data.n_vectors_to_check); + } + if (func.Function().GetName() == DynamicFilterScalarFun::NAME) { + if (!func.BindInfo()) { + return nullptr; + } + auto &data = func.BindInfo()->Cast(); + return RewriteDynamicFilterExpression(data.filter_data, target_type); + } if (!BoundComparisonExpression::IsComparison(expr)) { return nullptr; } @@ -978,7 +968,7 @@ static unique_ptr TryCastFilterExpression(const Expression &expr, co if (!lhs.expr || !lhs.type) { return nullptr; } - auto constant = right.Cast().value; + auto constant = right.Cast().GetValue(); if (!TryCastConstant(constant, *lhs.type)) { return nullptr; } @@ -988,12 +978,12 @@ static unique_ptr TryCastFilterExpression(const Expression &expr, co case ExpressionClass::BOUND_CONJUNCTION: { auto &conjunction = expr.Cast(); auto result = make_uniq(conjunction.GetExpressionType()); - for (auto &child : conjunction.children) { + for (auto &child : conjunction.GetChildren()) { auto rewritten_child = TryCastFilterExpression(*child, mapping, target_type); if (!rewritten_child) { return nullptr; } - result->children.push_back(std::move(rewritten_child)); + result->GetChildrenMutable().push_back(std::move(rewritten_child)); } return std::move(result); } @@ -1002,36 +992,36 @@ static unique_ptr TryCastFilterExpression(const Expression &expr, co switch (op.GetExpressionType()) { case ExpressionType::OPERATOR_IS_NULL: case ExpressionType::OPERATOR_IS_NOT_NULL: { - if (op.children.size() != 1) { + if (op.GetChildren().size() != 1) { return nullptr; } - auto child = RewriteMappedValueExpression(*op.children[0], mapping, target_type); + auto child = RewriteMappedValueExpression(*op.GetChildren()[0], mapping, target_type); if (!child.expr) { return nullptr; } auto 
result = make_uniq(op.GetExpressionType(), op.GetReturnType()); - result->children.push_back(std::move(child.expr)); + result->GetChildrenMutable().push_back(std::move(child.expr)); return std::move(result); } case ExpressionType::COMPARE_IN: { - if (op.children.empty()) { + if (op.GetChildren().empty()) { return nullptr; } - auto lhs = RewriteMappedValueExpression(*op.children[0], mapping, target_type); + auto lhs = RewriteMappedValueExpression(*op.GetChildren()[0], mapping, target_type); if (!lhs.expr || !lhs.type) { return nullptr; } auto result = make_uniq(op.GetExpressionType(), op.GetReturnType()); - result->children.push_back(std::move(lhs.expr)); - for (idx_t i = 1; i < op.children.size(); i++) { - if (op.children[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + result->GetChildrenMutable().push_back(std::move(lhs.expr)); + for (idx_t i = 1; i < op.GetChildren().size(); i++) { + if (op.GetChildren()[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { return nullptr; } - auto constant = op.children[i]->Cast().value; + auto constant = op.GetChildren()[i]->Cast().GetValue(); if (!TryCastConstant(constant, *lhs.type)) { return nullptr; } - result->children.push_back(make_uniq(std::move(constant))); + result->GetChildrenMutable().push_back(make_uniq(std::move(constant))); } return std::move(result); } @@ -1048,78 +1038,12 @@ static unique_ptr TryCastFilterExpression(const Expression &expr, co static unique_ptr TryCastTableFilter(const TableFilter &global_filter, MultiFileIndexMapping &mapping, const LogicalType &target_type) { - if (global_filter.filter_type == TableFilterType::EXPRESSION_FILTER) { - auto &expr_filter = ExpressionFilter::GetExpressionFilter(global_filter, "TryCastTableFilter"); - auto rewritten_expr = TryCastFilterExpression(*expr_filter.expr, mapping, target_type); - if (!rewritten_expr) { - return nullptr; - } - return make_uniq(std::move(rewritten_expr)); - } - auto type = global_filter.filter_type; - - switch (type) { 
- case TableFilterType::CONJUNCTION_OR: { - auto &or_filter = global_filter.Cast(); - auto res = make_uniq(); - for (auto &it : or_filter.child_filters) { - auto child_filter = TryCastTableFilter(*it, mapping, target_type); - if (!child_filter) { - return nullptr; - } - res->child_filters.push_back(std::move(child_filter)); - } - return std::move(res); - } - case TableFilterType::CONJUNCTION_AND: { - auto &and_filter = global_filter.Cast(); - auto res = make_uniq(); - for (auto &it : and_filter.child_filters) { - auto child_filter = TryCastTableFilter(*it, mapping, target_type); - if (!child_filter) { - return nullptr; - } - res->child_filters.push_back(std::move(child_filter)); - } - return std::move(res); - } - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = global_filter.Cast(); - auto child_result = TryCastTableFilter(*optional_filter.child_filter, mapping, target_type); - if (!child_result) { - return nullptr; - } - return make_uniq(std::move(child_result)); - } - case TableFilterType::DYNAMIC_FILTER: { - // we can't transfer dynamic filters over casts directly - // BUT we can copy the current state of the filter and push that - // FIXME: we could solve this in a different manner as well by pushing the dynamic filter directly - auto &dynamic_filter = global_filter.Cast(); - if (!dynamic_filter.filter_data) { - return nullptr; - } - if (!dynamic_filter.filter_data->initialized) { - return nullptr; - } - lock_guard lock(dynamic_filter.filter_data->lock); - auto new_constant = dynamic_filter.filter_data->constant; - if (!StatisticsPropagator::CanPropagateCast(new_constant.type(), target_type)) { - // type cannot be converted - abort - return nullptr; - } - if (!new_constant.DefaultTryCastAs(target_type)) { - return nullptr; - } - auto lhs = make_uniq(target_type, 0ULL); - auto rhs = make_uniq(std::move(new_constant)); - return make_uniq(BoundComparisonExpression::Create( - dynamic_filter.filter_data->comparison_type, std::move(lhs), 
std::move(rhs))); - } - default: - throw NotImplementedException("Can't convert TableFilterType (%s) from global to local indexes", - EnumUtil::ToString(type)); + auto filter_expression = GetFilterExpression(global_filter); + auto rewritten_expr = TryCastFilterExpression(*filter_expression, mapping, target_type); + if (!rewritten_expr) { + return nullptr; } + return make_uniq(std::move(rewritten_expr)); } static void SetIndexToZero(unique_ptr &root_expr) { @@ -1127,15 +1051,15 @@ static void SetIndexToZero(unique_ptr &root_expr) { optional_idx index; ExpressionIterator::VisitExpressionMutable(root_expr, [&](BoundReferenceExpression &ref, unique_ptr &expr) { - if (index.IsValid() && index.GetIndex() != ref.index) { + if (index.IsValid() && index.GetIndex() != ref.Index()) { throw InternalException("Expected an expression that only references a single column, but found multiple!"); } - index = ref.index; - ref.index = 0; + index = ref.Index(); + ref.IndexMutable() = 0; }); #else ExpressionIterator::VisitExpressionMutable( - root_expr, [&](BoundReferenceExpression &ref, unique_ptr &expr) { ref.index = 0; }); + root_expr, [&](BoundReferenceExpression &ref, unique_ptr &expr) { ref.IndexMutable() = 0; }); #endif } @@ -1151,7 +1075,7 @@ MultiFileColumnMapper::CreateFilters(map local_to_global; for (auto &it : filters) { auto &global_index = it.first; - auto &global_filter = it.second.get(); + auto &global_filter = it.second.get().Cast(); auto local_it = global_to_local.find(global_index); if (local_it == global_to_local.end()) { diff --git a/src/duckdb/src/common/multi_file/multi_file_list.cpp b/src/duckdb/src/common/multi_file/multi_file_list.cpp index 8d62b54c1..f29dd3e53 100644 --- a/src/duckdb/src/common/multi_file/multi_file_list.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_list.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/types.hpp" #include "duckdb/function/function_set.hpp" #include "duckdb/main/config.hpp" +#include 
"duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" @@ -20,7 +21,7 @@ MultiFilePushdownInfo::MultiFilePushdownInfo(LogicalGet &get) } } -MultiFilePushdownInfo::MultiFilePushdownInfo(TableIndex table_index, const vector &column_names, +MultiFilePushdownInfo::MultiFilePushdownInfo(TableIndex table_index, const vector &column_names, const vector &column_ids, ExtraOperatorInfo &extra_info) : table_index(table_index), column_names(column_names), column_ids(column_ids), extra_info(extra_info) { } @@ -33,7 +34,7 @@ bool PushdownInternal(ClientContext &context, const MultiFileOptions &options, M if (IsVirtualColumn(info.column_ids[i])) { continue; } - filter_info.column_map.insert({info.column_names[info.column_ids[i]], i}); + filter_info.column_map.insert({info.column_names[info.column_ids[i]].GetIdentifierName(), i}); } filter_info.hive_enabled = options.hive_partitioning; filter_info.filename_enabled = options.filename; @@ -48,7 +49,7 @@ bool PushdownInternal(ClientContext &context, const MultiFileOptions &options, M return false; } -bool PushdownInternal(ClientContext &context, const MultiFileOptions &options, const vector &names, +bool PushdownInternal(ClientContext &context, const MultiFileOptions &options, const vector &names, const vector &types, const vector &column_ids, const TableFilterSet &filters, vector &expanded_files) { TableIndex table_index(0); @@ -67,7 +68,8 @@ bool PushdownInternal(ClientContext &context, const MultiFileOptions &options, c } auto column_ref = make_uniq(types[column_idx], ColumnBinding(table_index, entry.GetIndex())); - auto filter_expr = entry.Filter().ToExpression(*column_ref); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(entry.Filter(), "MultiFilePushdownInfo::Pushdown"); + auto filter_expr = expr_filter.ToExpression(*column_ref); filter_expressions.push_back(std::move(filter_expr)); } @@ -200,7 +202,7 @@ 
unique_ptr MultiFileList::ComplexFilterPushdown(ClientContext &co } unique_ptr MultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileOptions &options, - const vector &names, + const vector &names, const vector &types, const vector &column_ids, TableFilterSet &filters) const { diff --git a/src/duckdb/src/common/multi_file/multi_file_reader.cpp b/src/duckdb/src/common/multi_file/multi_file_reader.cpp index a68b284d7..b2175e373 100644 --- a/src/duckdb/src/common/multi_file/multi_file_reader.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_reader.cpp @@ -46,7 +46,7 @@ unique_ptr MultiFileReader::Create(const TableFunction &table_f } unique_ptr MultiFileReader::Copy() const { - return CreateDefault(function_name); + return CreateDefault(function_name.GetIdentifierName()); } MultiFileBindData::~MultiFileBindData() { @@ -72,7 +72,7 @@ unique_ptr MultiFileBindData::Copy() const { unique_ptr MultiFileReader::CreateDefault(const string &function_name) { auto res = make_uniq(); - res->function_name = function_name; + res->function_name = Identifier(function_name); return res; } @@ -90,6 +90,7 @@ void MultiFileReader::AddParameters(TableFunction &table_function) { table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; table_function.named_parameters["hive_types"] = LogicalType::ANY; table_function.named_parameters["hive_types_autocast"] = LogicalType::BOOLEAN; + table_function.named_parameters["allow_empty"] = LogicalType::BOOLEAN; } vector MultiFileReader::ParsePaths(const Value &input) { @@ -119,7 +120,7 @@ vector MultiFileReader::ParsePaths(const Value &input) { shared_ptr MultiFileReader::CreateFileList(ClientContext &context, const vector &paths, const FileGlobInput &glob_input) { auto res = make_uniq(context, paths, glob_input); - if (res->GetExpandResult() == FileExpandResult::NO_FILES && glob_input.behavior != FileGlobOptions::ALLOW_EMPTY) { + if (res->GetExpandResult() == FileExpandResult::NO_FILES && 
!glob_input.AllowsEmpty()) { throw IOException("%s needs at least one file to read", function_name); } return std::move(res); @@ -161,6 +162,11 @@ bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFile throw InvalidInputException("Cannot use NULL as argument for \"%s\"", key); } options.union_by_name = BooleanValue::Get(val); + } else if (loption == "allow_empty") { + if (val.IsNull()) { + throw InvalidInputException("Cannot use NULL as argument for \"%s\"", key); + } + options.allow_empty = BooleanValue::Get(val); } else if (loption == "hive_types_autocast" || loption == "hive_type_autocast") { if (val.IsNull()) { throw InvalidInputException("Cannot use NULL as argument for \"%s\"", key); @@ -185,7 +191,7 @@ bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFile } // for every child of the struct, get the logical type LogicalType transformed_type = TransformStringToLogicalType(child.ToString(), context); - const string &name = StructType::GetChildName(val.type(), i); + const string &name = StructType::GetChildName(val.type(), i).GetIdentifierName(); options.hive_types_schema[name] = transformed_type; } D_ASSERT(!options.hive_types_schema.empty()); @@ -202,20 +208,23 @@ unique_ptr MultiFileReader::ComplexFilterPushdown(ClientContext & return files.ComplexFilterPushdown(context, options, info, filters); } -unique_ptr MultiFileReader::DynamicFilterPushdown( - ClientContext &context, const MultiFileList &files, const MultiFileOptions &options, const vector &names, - const vector &types, const vector &column_ids, TableFilterSet &filters) { +unique_ptr MultiFileReader::DynamicFilterPushdown(ClientContext &context, const MultiFileList &files, + const MultiFileOptions &options, + const vector &names, + const vector &types, + const vector &column_ids, + TableFilterSet &filters) { return files.DynamicFilterPushdown(context, options, names, types, column_ids, filters); } bool MultiFileReader::Bind(MultiFileOptions &options, 
MultiFileList &files, vector &return_types, - vector &names, MultiFileReaderBindData &bind_data) { + vector &names, MultiFileReaderBindData &bind_data) { // The Default MultiFileReader can not perform any binding as it uses MultiFileLists with no schema information. return false; } void MultiFileReader::BindOptions(MultiFileOptions &options, MultiFileList &files, vector &return_types, - vector &names, MultiFileReaderBindData &bind_data) { + vector &names, MultiFileReaderBindData &bind_data) { // Add generated constant column for filename if (options.filename) { if (std::find(names.begin(), names.end(), options.filename_column) != names.end()) { @@ -262,9 +271,8 @@ void MultiFileReader::BindOptions(MultiFileOptions &options, MultiFileList &file for (auto &part : partitions) { idx_t hive_partitioning_index; - auto lookup = std::find_if(names.begin(), names.end(), [&](const string &col_name) { - return StringUtil::CIEquals(col_name, part.first); - }); + auto lookup = std::find_if(names.begin(), names.end(), + [&](const Identifier &col_name) { return col_name == part.first; }); if (lookup != names.end()) { // hive partitioning column also exists in file - override auto idx = NumericCast(lookup - names.begin()); @@ -336,7 +344,7 @@ void MultiFileReader::FinalizeBind(MultiFileReaderData &reader_data, const Multi if (file_options.union_by_name) { for (idx_t col_idx = 0; col_idx < local_columns.size(); col_idx++) { auto &column = local_columns[col_idx]; - name_map[column.name] = col_idx; + name_map[column.name.GetIdentifierName()] = col_idx; } } std::map hive_partitions; @@ -360,7 +368,7 @@ void MultiFileReader::FinalizeBind(MultiFileReaderData &reader_data, const Multi auto &name = column.name; auto &type = column.type; - auto entry = name_map.find(name); + auto entry = name_map.find(name.GetIdentifierName()); bool not_present_in_file = entry == name_map.end(); if (not_present_in_file) { // we need to project a column with name \"global_name\" - but it does not exist in 
the current file @@ -407,14 +415,14 @@ static string GetExtendedMultiFileError(const MultiFileBindData &bind_data, cons return string(); } auto &cast_expr = expr.Cast(); - if (cast_expr.child->GetExpressionType() != ExpressionType::BOUND_REF) { + if (cast_expr.Child().GetExpressionType() != ExpressionType::BOUND_REF) { return string(); } - auto &ref = cast_expr.child->Cast(); + auto &ref = cast_expr.Child().Cast(); auto &source_type = ref.GetReturnType(); auto &target_type = cast_expr.GetReturnType(); auto &columns = reader.GetColumns(); - auto local_col_id = reader.column_indexes[ref.index].GetPrimaryIndex(); + auto local_col_id = reader.column_indexes[ref.Index()].GetPrimaryIndex(); auto &local_col = columns[local_col_id]; auto reader_type = reader.GetReaderType(); @@ -471,7 +479,6 @@ void MultiFileReader::FinalizeChunk(ClientContext &context, const MultiFileBindD first_message, original_error, extended_error); } } - output_chunk.SetCardinality(input_chunk.size()); } void MultiFileReader::GetPartitionData(ClientContext &context, const MultiFileReaderBindData &bind_data, @@ -518,7 +525,7 @@ TablePartitionInfo MultiFileReader::GetPartitionInfo(ClientContext &context, con } TableFunctionSet MultiFileReader::CreateFunctionSet(TableFunction table_function) { - TableFunctionSet function_set(table_function.name); + TableFunctionSet function_set {table_function.name}; function_set.AddFunction(table_function); D_ASSERT(!table_function.GetArguments().empty() && table_function.GetArguments()[0] == LogicalType::VARCHAR); table_function.GetArguments()[0] = LogicalType::LIST(LogicalType::VARCHAR); @@ -543,11 +550,11 @@ unique_ptr MultiFileReader::GetVirtualColumnExpression(ClientContext } MultiFileReaderBindData MultiFileReader::BindUnionReader(ClientContext &context, vector &return_types, - vector &names, MultiFileList &files, + vector &names, MultiFileList &files, MultiFileBindData &result, BaseFileReaderOptions &options, MultiFileOptions &file_options) { 
D_ASSERT(file_options.union_by_name); - vector union_col_names; + vector union_col_names; vector union_col_types; // obtain the set of union column names + types by unifying the types of all of the files @@ -569,7 +576,7 @@ MultiFileReaderBindData MultiFileReader::BindUnionReader(ClientContext &context, } MultiFileReaderBindData MultiFileReader::BindReader(ClientContext &context, vector &return_types, - vector &names, MultiFileList &files, + vector &names, MultiFileList &files, MultiFileBindData &result, BaseFileReaderOptions &options, MultiFileOptions &file_options) { if (file_options.union_by_name) { @@ -696,24 +703,26 @@ void MultiFileOptions::AddBatchInfo(BindInfo &bind_info) const { bind_info.InsertOption("auto_detect_hive_partitioning", Value::BOOLEAN(auto_detect_hive_partitioning)); bind_info.InsertOption("union_by_name", Value::BOOLEAN(union_by_name)); bind_info.InsertOption("hive_types_autocast", Value::BOOLEAN(hive_types_autocast)); + bind_info.InsertOption("allow_empty", Value::BOOLEAN(allow_empty)); } void UnionByName::CombineUnionTypes(const vector &col_names, const vector &sql_types, - vector &union_col_types, vector &union_col_names, - case_insensitive_map_t &union_names_map) { + vector &union_col_types, vector &union_col_names, + identifier_map_t &union_names_map) { D_ASSERT(col_names.size() == sql_types.size()); for (idx_t col = 0; col < col_names.size(); ++col) { - auto union_find = union_names_map.find(col_names[col]); + Identifier col_name(col_names[col]); + auto union_find = union_names_map.find(col_name); if (union_find != union_names_map.end()) { // given same name , union_col's type must compatible with col's type auto ¤t_type = union_col_types[union_find->second]; - auto compatible_type = LogicalType::ForceMaxLogicalType(current_type, sql_types[col]); + auto compatible_type = LogicalType::DefaultForceMaxLogicalType(current_type, sql_types[col]); union_col_types[union_find->second] = compatible_type; } else { - 
union_names_map[col_names[col]] = union_col_names.size(); - union_col_names.emplace_back(col_names[col]); + union_names_map[col_name] = union_col_names.size(); + union_col_names.emplace_back(col_name); union_col_types.emplace_back(sql_types[col]); } } diff --git a/src/duckdb/src/common/multi_file/union_by_name.cpp b/src/duckdb/src/common/multi_file/union_by_name.cpp index 2aae4acb6..2888345a8 100644 --- a/src/duckdb/src/common/multi_file/union_by_name.cpp +++ b/src/duckdb/src/common/multi_file/union_by_name.cpp @@ -33,10 +33,12 @@ class UnionByReaderTask : public BaseExecutorTask { MultiFileReaderInterface &interface; }; -vector> -UnionByName::UnionCols(ClientContext &context, const vector &files, vector &union_col_types, - vector &union_col_names, BaseFileReaderOptions &options, MultiFileOptions &file_options, - MultiFileReader &multi_file_reader, MultiFileReaderInterface &interface) { +vector> UnionByName::UnionCols(ClientContext &context, const vector &files, + vector &union_col_types, + vector &union_col_names, + BaseFileReaderOptions &options, MultiFileOptions &file_options, + MultiFileReader &multi_file_reader, + MultiFileReaderInterface &interface) { vector> union_readers; union_readers.resize(files.size()); @@ -51,7 +53,7 @@ UnionByName::UnionCols(ClientContext &context, const vector &files executor.WorkOnTasks(); // now combine the result schemas - case_insensitive_map_t union_names_map; + identifier_map_t union_names_map; for (auto &reader : union_readers) { auto &col_names = reader->names; auto &sql_types = reader->types; diff --git a/src/duckdb/src/common/opener_file_system.cpp b/src/duckdb/src/common/opener_file_system.cpp index 8f55d6898..f713cff06 100644 --- a/src/duckdb/src/common/opener_file_system.cpp +++ b/src/duckdb/src/common/opener_file_system.cpp @@ -1,10 +1,22 @@ #include "duckdb/common/opener_file_system.hpp" #include "duckdb/common/file_opener.hpp" +#include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/main/database.hpp" 
#include "duckdb/main/config.hpp" namespace duckdb { +unique_ptr OpenerFileSystem::MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener) { + VerifyNoOpener(opener); + VerifyCanAccessFile(path.path); + if (IsDuckDBExtensionName(path.path)) { + VerifyCanAccessExtension(path.path, flags); + } + return GetFileSystem().MemoryMapFile(path, flags, options, GetOpener()); +} + void OpenerFileSystem::VerifyNoOpener(optional_ptr opener) { if (opener) { throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp index 2e2c224a1..3151f8c02 100644 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -1175,6 +1175,11 @@ bool TryCast::Operation(timestamp_tz_t input, timestamp_tz_t &result, bool stric return TryCastTimebase(input, result, strict); } +template <> +bool TryCast::Operation(timestamp_tz_t input, timestamp_t &result, bool strict) { + return TryCastTimebase(input, result, strict); +} + template <> bool TryCast::Operation(timestamp_tz_ns_t input, timestamp_tz_ns_t &result, bool strict) { return TryCastTimebase(input, result, strict); @@ -1185,6 +1190,11 @@ bool TryCast::Operation(timestamp_ns_t input, timestamp_tz_ns_t &result, bool st return TryCastTimebase(input, result, strict); } +template <> +bool TryCast::Operation(timestamp_tz_ns_t input, timestamp_ns_t &result, bool strict) { + return TryCastTimebase(input, result, strict); +} + template <> bool TryCast::Operation(timestamp_t input, timestamp_tz_t &result, bool strict) { return TryCastTimebase(input, result, strict); @@ -1217,46 +1227,45 @@ duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_ns_t input, St } template <> duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_ms_t input, StringHeap &heap) { - return 
StringCast::Operation(CastTimestampMsToUs::Operation(input), heap); + return StringCast::Operation(Cast::Operation(input), heap); } template <> duckdb::string_t CastFromTimestampSec::Operation(duckdb::timestamp_sec_t input, StringHeap &heap) { - return StringCast::Operation(CastTimestampSecToUs::Operation(input), - heap); + return StringCast::Operation(Cast::Operation(input), heap); } template <> -timestamp_ms_t CastTimestampUsToMs::Operation(timestamp_t input) { +timestamp_ms_t Cast::Operation(timestamp_t input) { return CastTimebase(input); } template <> -timestamp_ns_t CastTimestampUsToNs::Operation(timestamp_t input) { +timestamp_ns_t Cast::Operation(timestamp_t input) { return CastTimebase(input); } template <> -timestamp_sec_t CastTimestampUsToSec::Operation(timestamp_t input) { +timestamp_sec_t Cast::Operation(timestamp_t input) { return CastTimebase(input); } template <> -timestamp_t CastTimestampMsToUs::Operation(timestamp_ms_t input) { +timestamp_t Cast::Operation(timestamp_ms_t input) { return CastTimebase(input); } template <> -date_t CastTimestampMsToDate::Operation(timestamp_ms_t input) { +date_t Cast::Operation(timestamp_ms_t input) { return Timestamp::GetDate(CastTimebase(input)); } template <> -dtime_t CastTimestampMsToTime::Operation(timestamp_ms_t input) { +dtime_t Cast::Operation(timestamp_ms_t input) { return Timestamp::GetTime(CastTimebase(input)); } template <> -timestamp_ns_t CastTimestampMsToNs::Operation(timestamp_ms_t input) { +timestamp_ns_t Cast::Operation(timestamp_ms_t input) { return CastTimebase(input); } @@ -1286,47 +1295,47 @@ bool TryCast::Operation(timestamp_sec_t input, timestamp_t &result, bool strict) } template <> -timestamp_t CastTimestampNsToUs::Operation(timestamp_ns_t input) { +timestamp_t Cast::Operation(timestamp_ns_t input) { return CastTimebase(input); } template <> -timestamp_t CastTimestampSecToUs::Operation(timestamp_sec_t input) { +timestamp_t Cast::Operation(timestamp_sec_t input) { return CastTimebase(input); } 
template <> -date_t CastTimestampNsToDate::Operation(timestamp_ns_t input) { +date_t Cast::Operation(timestamp_ns_t input) { return Timestamp::GetDate(CastTimebase(input)); } template <> -dtime_t CastTimestampNsToTime::Operation(timestamp_ns_t input) { +dtime_t Cast::Operation(timestamp_ns_t input) { return Timestamp::GetTime(CastTimebase(input)); } template <> -dtime_ns_t CastTimestampNsToTimeNs::Operation(timestamp_ns_t input) { +dtime_ns_t Cast::Operation(timestamp_ns_t input) { return Timestamp::GetTimeNs(input); } template <> -timestamp_ms_t CastTimestampSecToMs::Operation(timestamp_sec_t input) { +timestamp_ms_t Cast::Operation(timestamp_sec_t input) { return CastTimebase(input); } template <> -timestamp_ns_t CastTimestampSecToNs::Operation(timestamp_sec_t input) { +timestamp_ns_t Cast::Operation(timestamp_sec_t input) { return CastTimebase(input); } template <> -date_t CastTimestampSecToDate::Operation(timestamp_sec_t input) { +date_t Cast::Operation(timestamp_sec_t input) { return Timestamp::GetDate(CastTimebase(input)); } template <> -dtime_t CastTimestampSecToTime::Operation(timestamp_sec_t input) { +dtime_t Cast::Operation(timestamp_sec_t input) { return Timestamp::GetTime(CastTimebase(input)); } @@ -1334,12 +1343,7 @@ dtime_t CastTimestampSecToTime::Operation(timestamp_sec_t input) { // Cast To Timestamp //===--------------------------------------------------------------------===// template <> -bool TryCastToTimestampNS::Operation(string_t input, timestamp_ns_t &result, bool strict) { - return TryCast::Operation(input, result, strict); -} - -template <> -bool TryCastToTimestampMS::Operation(string_t input, timestamp_ms_t &result, bool strict) { +bool TryCast::Operation(string_t input, timestamp_ms_t &result, bool strict) { timestamp_t us; if (!TryCast::Operation(input, us, strict)) { return false; @@ -1348,7 +1352,7 @@ bool TryCastToTimestampMS::Operation(string_t input, timestamp_ms_t &result, boo } template <> -bool 
TryCastToTimestampSec::Operation(string_t input, timestamp_sec_t &result, bool strict) { +bool TryCast::Operation(string_t input, timestamp_sec_t &result, bool strict) { timestamp_t us; if (!TryCast::Operation(input, us, strict)) { return false; @@ -1357,7 +1361,7 @@ bool TryCastToTimestampSec::Operation(string_t input, timestamp_sec_t &result, b } template <> -bool TryCastToTimestampNS::Operation(date_t input, timestamp_ns_t &result, bool strict) { +bool TryCast::Operation(date_t input, timestamp_ns_t &result, bool strict) { timestamp_t us; if (!TryCast::Operation(input, us, strict)) { return false; @@ -1366,7 +1370,7 @@ bool TryCastToTimestampNS::Operation(date_t input, timestamp_ns_t &result, bool } template <> -bool TryCastToTimestampMS::Operation(date_t input, timestamp_ms_t &result, bool strict) { +bool TryCast::Operation(date_t input, timestamp_ms_t &result, bool strict) { timestamp_t us; if (!TryCast::Operation(input, us, strict)) { return false; @@ -1375,7 +1379,7 @@ bool TryCastToTimestampMS::Operation(date_t input, timestamp_ms_t &result, bool } template <> -bool TryCastToTimestampSec::Operation(date_t input, timestamp_sec_t &result, bool strict) { +bool TryCast::Operation(date_t input, timestamp_sec_t &result, bool strict) { timestamp_t us; if (!TryCast::Operation(input, us, strict)) { return false; @@ -2950,42 +2954,77 @@ bool IsRepresentableExactly(hugeint_t input, double dst) { return (input <= MAX_INT_REPRESENTABLE_IN_DOUBLE && input >= -MAX_INT_REPRESENTABLE_IN_DOUBLE); } -template -static SRC GetPowerOfTen(SRC input, uint8_t scale) { - return static_cast(NumericHelper::POWERS_OF_TEN[scale]); -} - -template <> -hugeint_t GetPowerOfTen(hugeint_t input, uint8_t scale) { - return Hugeint::POWERS_OF_TEN[scale]; +template +static bool CanUseDecimalFloatingPointFastPath(SRC input, uint8_t scale) { + return duckdb_fast_float::binary_format::min_exponent_fast_path() <= -scale && + -scale <= duckdb_fast_float::binary_format::max_exponent_fast_path() && + 
IsRepresentableExactly(input, DST(0.0)); +} + +template +static void FillDecimalDigits(UNSIGNED input, duckdb_fast_float::decimal &decimal) { + uint8_t digits[DecimalWidth::max]; + while (input > 0) { + digits[decimal.num_digits++] = UnsafeNumericCast(input % 10); + input /= 10; + } + for (uint32_t i = 0; i < decimal.num_digits; i++) { + decimal.digits[i] = digits[decimal.num_digits - i - 1]; + } } template -static void GetDivMod(SRC lhs, SRC rhs, SRC &div, SRC &mod) { - div = lhs / rhs; - mod = lhs % rhs; +static void FillDecimalDigits(SRC input, duckdb_fast_float::decimal &decimal, bool &negative) { + using UNSIGNED = typename MakeUnsigned::type; + if (input < 0) { + negative = true; + auto unsigned_input = UnsafeNumericCast(-(input + 1)); + FillDecimalDigits(unsigned_input + 1, decimal); + } else { + negative = false; + FillDecimalDigits(UnsafeNumericCast(input), decimal); + } } -template <> -void GetDivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &div, hugeint_t &mod) { - div = Hugeint::DivMod(lhs, rhs, mod); +static void FillDecimalDigits(hugeint_t input, duckdb_fast_float::decimal &decimal, bool &negative) { + if (input < 0) { + negative = true; + Hugeint::NegateInPlace(input); + } else { + negative = false; + } + uint8_t digits[DecimalWidth::max]; + while (input > 0) { + uint64_t remainder; + input = Hugeint::DivModPositive(input, 10, remainder); + digits[decimal.num_digits++] = UnsafeNumericCast(remainder); + } + for (uint32_t i = 0; i < decimal.num_digits; i++) { + decimal.digits[i] = digits[decimal.num_digits - i - 1]; + } } template -bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) { - if (IsRepresentableExactly(input, DST(0.0)) || scale == 0) { - // Fast path, integer is representable exactly as a float/double +bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t width, uint8_t scale) { + if (scale == 0 || CanUseDecimalFloatingPointFastPath(input, scale)) { + // Fast path, integer and decimal exponent are 
representable exactly as a float/double result = Cast::Operation(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); return true; } - auto power_of_ten = GetPowerOfTen(input, scale); - SRC div = 0; - SRC mod = 0; - GetDivMod(input, power_of_ten, div, mod); + duckdb_fast_float::decimal decimal; + bool negative; + FillDecimalDigits(input, decimal, negative); + decimal.decimal_point = UnsafeNumericCast(decimal.num_digits) - UnsafeNumericCast(scale); + while (decimal.num_digits > 0 && decimal.digits[decimal.num_digits - 1] == 0) { + decimal.num_digits--; + } + for (uint32_t i = decimal.num_digits; i < duckdb_fast_float::max_digit_without_overflow; i++) { + decimal.digits[i] = 0; + } - result = Cast::Operation(div) + - Cast::Operation(mod) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); + auto adjusted_mantissa = duckdb_fast_float::compute_float>(decimal); + duckdb_fast_float::detail::to_float(negative, adjusted_mantissa, result); return true; } @@ -2993,50 +3032,50 @@ bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) { template <> bool TryCastFromDecimal::Operation(int16_t input, float &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(int32_t input, float &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(int64_t input, float &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(hugeint_t input, float &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - 
return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } // DECIMAL -> DOUBLE template <> bool TryCastFromDecimal::Operation(int16_t input, double &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(int32_t input, double &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(int64_t input, double &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } template <> bool TryCastFromDecimal::Operation(hugeint_t input, double &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); + return TryCastDecimalToFloatingPoint(input, result, width, scale); } } // namespace duckdb diff --git a/src/duckdb/src/common/operator/string_cast.cpp b/src/duckdb/src/common/operator/string_cast.cpp index 714e0c77b..e4b29f5f2 100644 --- a/src/duckdb/src/common/operator/string_cast.cpp +++ b/src/duckdb/src/common/operator/string_cast.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/common/types/string_heap.hpp" #include "duckdb/common/types/date.hpp" @@ -85,10 +86,8 @@ duckdb::string_t StringCast::Operation(uhugeint_t input, StringHeap &heap) { template <> duckdb::string_t StringCast::Operation(date_t input, StringHeap &heap) { - if (input == date_t::infinity()) { - return 
heap.AddString(Date::PINF.str); - } else if (input == date_t::ninfinity()) { - return heap.AddString(Date::NINF.str); + if (!input.IsFinite()) { + return heap.AddString(Date::ToInfinity(input)); } int32_t date[3]; Date::Convert(input, date[0], date[1], date[2]); @@ -154,26 +153,23 @@ duckdb::string_t StringCast::Operation(dtime_ns_t input, StringHeap &heap) { return StringAsTime(input, heap); } -template -duckdb::string_t StringFromTimestamp(timestamp_t input, StringHeap &heap) { - if (input == timestamp_t::infinity()) { - return heap.AddString(Date::PINF.str); - } - if (input == timestamp_t::ninfinity()) { - return heap.AddString(Date::NINF.str); +template +duckdb::string_t StringFromTimestamp(const T &input, StringHeap &heap) { + if (!input.IsFinite()) { + return heap.AddString(Date::ToInfinity(input)); } date_t date_entry; dtime_t time_entry; int32_t picos = 0; - if (HAS_NANOS) { - timestamp_ns_t ns; - ns.value = input.value; + if (T::PRECISION == timestamp_ns_t::PRECISION) { + auto ns = Cast::Operation(input); Timestamp::Convert(ns, date_entry, time_entry, picos); // Use picoseconds so we have 6 digits picos *= 1000; } else { - Timestamp::Convert(input, date_entry, time_entry); + auto us = Cast::Operation(input); + Timestamp::Convert(us, date_entry, time_entry); } int32_t date[3], time[4]; @@ -197,7 +193,8 @@ duckdb::string_t StringFromTimestamp(timestamp_t input, StringHeap &heap) { nano_length = 6; nano_length -= NumericCast(TimeToStringCast::FormatMicros(picos, nano_buffer)); } - const idx_t length = date_length + 1 + time_length + nano_length; + const idx_t tz_length = T::TIMEZONED ? 
3 : 0; + const idx_t length = date_length + 1 + time_length + nano_length + tz_length; string_t result = heap.EmptyString(length); auto data = result.GetDataWriteable(); @@ -209,6 +206,13 @@ duckdb::string_t StringFromTimestamp(timestamp_t input, StringHeap &heap) { data += time_length; memcpy(data, nano_buffer, nano_length); D_ASSERT(data + nano_length <= result.GetDataWriteable() + length); + data += nano_length; + + if (T::TIMEZONED) { + *data++ = '+'; + *data++ = '0'; + *data++ = '0'; + } result.Finalize(); return result; @@ -216,14 +220,12 @@ duckdb::string_t StringFromTimestamp(timestamp_t input, StringHeap &heap) { template <> duckdb::string_t StringCast::Operation(timestamp_t input, StringHeap &heap) { - return StringFromTimestamp(input, heap); + return StringFromTimestamp(input, heap); } template <> duckdb::string_t StringCast::Operation(timestamp_ns_t input, StringHeap &heap) { - // FIXME: This is just horrible... - timestamp_t us(input.value); - return StringFromTimestamp(us, heap); + return StringFromTimestamp(input, heap); } template <> @@ -232,7 +234,7 @@ duckdb::string_t StringCast::Operation(duckdb::string_t input, StringHeap &heap) } template <> -string_t StringCastTZ::Operation(dtime_tz_t input, StringHeap &heap) { +string_t StringCast::Operation(dtime_tz_t input, StringHeap &heap) { int32_t time[4]; Time::Convert(input.time(), time[0], time[1], time[2], time[3]); @@ -293,88 +295,13 @@ string_t StringCastTZ::Operation(dtime_tz_t input, StringHeap &heap) { } template <> -string_t StringCastTZ::Operation(timestamp_t input, StringHeap &heap) { - if (input == timestamp_t::infinity()) { - return heap.AddString(Date::PINF.str); - } - if (input == timestamp_t::ninfinity()) { - return heap.AddString(Date::NINF.str); - } - - date_t date_entry; - dtime_t time_entry; - Timestamp::Convert(input, date_entry, time_entry); - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], 
time[3]); - - // format for timestamptz is DATE TIME+00 (separated by space) - idx_t year_length; - bool add_bc; - char micro_buffer[6] = {}; - const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); - const idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - const idx_t length = date_length + 1 + time_length + 3; - - string_t result = heap.EmptyString(length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - DateToStringCast::Format(data + pos, date, year_length, add_bc); - pos += date_length; - data[pos++] = ' '; - TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); - pos += time_length; - data[pos++] = '+'; - data[pos++] = '0'; - data[pos++] = '0'; - - result.Finalize(); - return result; +string_t StringCast::Operation(timestamp_tz_t input, StringHeap &heap) { + return StringFromTimestamp(input, heap); } template <> -string_t StringCastTZ::Operation(timestamp_ns_t input, StringHeap &heap) { - if (input == input.infinity()) { - return heap.AddString(Date::PINF.str); - } - if (input == input.ninfinity()) { - return heap.AddString(Date::NINF.str); - } - - date_t date_entry; - dtime_t time_entry; - int32_t nanos; - Timestamp::Convert(input, date_entry, time_entry, nanos); - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], time[3]); - - // format for timestamptzns is DATE TIME+00 (separated by space) - idx_t year_length; - bool add_bc; - char nano_buffer[9] = {}; - const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); - const idx_t time_length = TimeToStringCast::Length(time, nano_buffer, nanos); - const idx_t length = date_length + 1 + time_length + 3; - - string_t result = heap.EmptyString(length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - DateToStringCast::Format(data + pos, date, year_length, add_bc); - pos += date_length; - data[pos++] = ' '; - 
TimeToStringCast::Format(data + pos, time_length, time, nano_buffer); - pos += time_length; - data[pos++] = '+'; - data[pos++] = '0'; - data[pos++] = '0'; - - result.Finalize(); - return result; +string_t StringCast::Operation(timestamp_tz_ns_t input, StringHeap &heap) { + return StringFromTimestamp(input, heap); } } // namespace duckdb diff --git a/src/duckdb/src/common/printer.cpp b/src/duckdb/src/common/printer.cpp index 46187ac24..115c1cae0 100644 --- a/src/duckdb/src/common/printer.cpp +++ b/src/duckdb/src/common/printer.cpp @@ -73,8 +73,11 @@ idx_t Printer::TerminalWidth() { rows = csbi.srWindow.Right - csbi.srWindow.Left + 1; return rows; #else - struct winsize w; + struct winsize w = {}; ioctl(0, TIOCGWINSZ, &w); + if (w.ws_col == 0) { + return 120; + } return w.ws_col; #endif #else diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp index 6ccf4bd1d..695b36aa1 100644 --- a/src/duckdb/src/common/radix_partitioning.cpp +++ b/src/duckdb/src/common/radix_partitioning.cpp @@ -62,7 +62,7 @@ RETURN_TYPE RadixBitsSwitch(const idx_t radix_bits, ARGS &&... 
args) { struct SelectFunctor { template - static idx_t Operation(Vector &hashes, const SelectionVector *sel, const idx_t count, + static idx_t Operation(const Vector &hashes, const SelectionVector *sel, const idx_t count, const ValidityMask &partition_mask, SelectionVector *true_sel, SelectionVector *false_sel) { using CONSTANTS = RadixPartitioningConstants; return UnaryExecutor::Select( @@ -75,15 +75,15 @@ struct SelectFunctor { } }; -idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, const idx_t count, const idx_t radix_bits, - const ValidityMask &partition_mask, SelectionVector *true_sel, +idx_t RadixPartitioning::Select(const Vector &hashes, const SelectionVector *sel, const idx_t count, + const idx_t radix_bits, const ValidityMask &partition_mask, SelectionVector *true_sel, SelectionVector *false_sel) { return RadixBitsSwitch(radix_bits, hashes, sel, count, partition_mask, true_sel, false_sel); } struct ComputePartitionIndicesFunctor { template - static void Operation(Vector &hashes, Vector &partition_indices, const idx_t original_count, + static void Operation(const Vector &hashes, Vector &partition_indices, const idx_t original_count, const SelectionVector &append_sel, const idx_t append_count) { using CONSTANTS = RadixPartitioningConstants; if (!append_sel.IsSet() || hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -184,6 +184,48 @@ void RadixPartitionedTupleData::Initialize() { } } +void RadixPartitionedTupleData::ResetAppendState(PartitionedTupleDataAppendState &state, + const TupleDataPinProperties properties) const { + if (!state.partition_sel.data()) { + state.partition_sel.Initialize(); + } + if (!state.reverse_partition_sel.data()) { + state.reverse_partition_sel.Initialize(); + } + + const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); + if (state.partition_pin_states.size() != num_partitions) { + state.partition_pin_states.clear(); + state.partition_pin_states.reserve(num_partitions); + 
for (idx_t i = 0; i < num_partitions; i++) { + state.partition_pin_states.emplace_back(); + } + } + for (idx_t i = 0; i < num_partitions; i++) { + auto &pin_state = state.partition_pin_states[i]; + pin_state.row_handles.clear(); + pin_state.heap_handles.clear(); + partitions[i]->InitializeAppend(pin_state, properties); + } + + if (state.chunk_state.column_ids.empty()) { + auto column_count = layout.ColumnCount(); + vector column_ids; + column_ids.reserve(column_count); + for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { + column_ids.emplace_back(col_idx); + } + partitions[0]->InitializeChunkState(state.chunk_state, std::move(column_ids)); + } + state.chunk_state.chunk_lock = nullptr; + state.chunk_state.chunk_parts.clear(); + state.chunk_state.chunk_part_indices.clear(); + + if (state.fixed_partition_entries.size() != num_partitions) { + state.fixed_partition_entries.resize(num_partitions); + } +} + void RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, const TupleDataPinProperties properties) const { // Init pin state per partition diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp index f3bb9d54a..82da9454d 100644 --- a/src/duckdb/src/common/render_tree.cpp +++ b/src/duckdb/src/common/render_tree.cpp @@ -100,27 +100,13 @@ static unique_ptr CreateNode(const PipelineRenderNode &op) { } static unique_ptr CreateNode(const ProfilingNode &op) { - auto &info = op.GetProfilingInfo(); - InsertionOrderPreservingMap extra_info; - if (info.Enabled(info.settings, MetricType::EXTRA_INFO)) { - extra_info = op.GetProfilingInfo().GetMetricValue>(MetricType::EXTRA_INFO); - } - - string node_name = "QUERY"; - if (op.depth > 0) { - node_name = info.GetMetricAsString(MetricType::OPERATOR_TYPE); - } + auto &info = op.GetOperatorMetrics(); - auto result = make_uniq(node_name, extra_info); - if (info.Enabled(info.settings, MetricType::OPERATOR_CARDINALITY)) { - auto cardinality = 
info.GetMetricAsString(MetricType::OPERATOR_CARDINALITY); - result->extra_text[RenderTreeNode::CARDINALITY] = cardinality; - } - if (info.Enabled(info.settings, MetricType::OPERATOR_TIMING)) { - auto value = info.metrics.at(MetricType::OPERATOR_TIMING).GetValue(); - string timing = StringUtil::Format("%.2f", value); - result->extra_text[RenderTreeNode::TIMING] = timing + "s"; - } + auto &node_name = info.name; + auto result = make_uniq(node_name, info.GetExtraInfo()); + result->extra_text[RenderTreeNode::CARDINALITY] = to_string(info.elements_returned); + string timing = StringUtil::Format("%.2f", info.time); + result->extra_text[RenderTreeNode::TIMING] = timing + "s"; return result; } diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp index 340b84545..06e94422d 100644 --- a/src/duckdb/src/common/row_operations/row_aggregate.cpp +++ b/src/duckdb/src/common/row_operations/row_aggregate.cpp @@ -30,40 +30,35 @@ void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, } } -void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count) { +void RowOperations::DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses) { + auto count = addresses.size(); if (count == 0) { return; } // Move to the first aggregate state - VectorOperations::AddInPlace(addresses, UnsafeNumericCast(layout.GetAggrOffset()), count); + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(layout.GetAggrOffset())); for (const auto &aggr : layout.GetAggregates()) { if (aggr.function.HasStateDestructorCallback()) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + AggregateInputData aggr_input_data(aggr, state.allocator); aggr.function.GetStateDestructorCallback()(addresses, aggr_input_data, count); } // Move to the next aggregate state - VectorOperations::AddInPlace(addresses, 
UnsafeNumericCast(aggr.payload_size), count); + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggr.payload_size)); } } void RowOperations::UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, - DataChunk &payload, idx_t arg_idx, idx_t count) { - if (addresses.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (aggr.function.HasStateSimpleUpdateCallback()) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.GetStateSimpleUpdateCallback()(aggr.child_count == 0 ? nullptr : &payload.data[arg_idx], - aggr_input_data, aggr.child_count, - FlatVector::GetData(addresses)[0], payload.size()); - return; - } - // FIXME: these are expected to be flat in at least a few places: - // "Test Aggregate Functions C API" - // "Test String Aggregate Function" - addresses.Flatten(); + DataChunk &payload, idx_t arg_idx, optional_ptr clustered) { + auto count = addresses.size(); + AggregateInputData aggr_input_data(aggr, state.allocator); + auto cluster_update = aggr.function.GetStateClusterUpdateCallback(); + aggr_input_data.clustered = cluster_update ? clustered : nullptr; + auto inputs = aggr.child_count ? payload.data.data() + arg_idx : nullptr; + if (clustered && cluster_update) { + cluster_update(inputs, aggr_input_data, aggr.child_count, *clustered, count); + return; } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.GetStateUpdateCallback()(aggr.child_count == 0 ? 
nullptr : &payload.data[arg_idx], aggr_input_data, - aggr.child_count, addresses, count); + aggr.function.GetStateUpdateCallback()(inputs, aggr_input_data, aggr.child_count, addresses, count); } void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, @@ -76,39 +71,80 @@ void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFil Vector filtered_addresses(addresses, filter_data.true_sel, count); filtered_addresses.Flatten(); - UpdateStates(state, aggr, filtered_addresses, filter_data.filtered_payload, arg_idx, count); + UpdateStates(state, aggr, filtered_addresses, filter_data.filtered_payload, arg_idx); +} + +void RowOperations::UpdateStatesClustered(RowOperationsState &state, vector &aggregates, + AggregateFilterDataSet *filter_set, const unsafe_vector *filter, + Vector &addresses, DataChunk &payload, ClusteredAggr &clustered, + bool skip_addresses) { + idx_t filter_idx = 0; + idx_t payload_idx = 0; + for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { + auto &aggr = aggregates[aggr_idx]; + if (filter && (filter_idx >= filter->size() || aggr_idx < (*filter)[filter_idx])) { + // Skip all the aggregates that are not in the filter + payload_idx += aggr.child_count; + if (!skip_addresses) { + VectorOperations::AddInPlace(addresses, NumericCast(aggr.payload_size)); + } + clustered.AdvanceStates(aggr.payload_size); + continue; + } + if (filter) { + D_ASSERT(aggr_idx == (*filter)[filter_idx]); + } + + if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) { + D_ASSERT(filter_set); + RowOperations::UpdateFilteredStates(state, filter_set->GetFilterData(aggr_idx), aggr, addresses, payload, + payload_idx); + } else { + UpdateStates(state, aggr, addresses, payload, payload_idx, clustered); + } + + // Move to the next aggregate + payload_idx += aggr.child_count; + if (!skip_addresses) { + VectorOperations::AddInPlace(addresses, NumericCast(aggr.payload_size)); + } + 
clustered.AdvanceStates(aggr.payload_size); + if (filter) { + filter_idx++; + } + } } -void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, - idx_t count) { +void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, + Vector &targets) { + auto count = sources.size(); if (count == 0) { return; } // Move to the first aggregate states - VectorOperations::AddInPlace(sources, UnsafeNumericCast(layout.GetAggrOffset()), count); - VectorOperations::AddInPlace(targets, UnsafeNumericCast(layout.GetAggrOffset()), count); + VectorOperations::AddInPlace(sources, UnsafeNumericCast(layout.GetAggrOffset())); + VectorOperations::AddInPlace(targets, UnsafeNumericCast(layout.GetAggrOffset())); // Keep track of the offset idx_t offset = layout.GetAggrOffset(); for (auto &aggr : layout.GetAggregates()) { D_ASSERT(aggr.function.HasStateCombineCallback()); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); + AggregateInputData aggr_input_data(aggr, state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); aggr.function.GetStateCombineCallback()(sources, targets, aggr_input_data, count); // Move to the next aggregate states - VectorOperations::AddInPlace(sources, UnsafeNumericCast(aggr.payload_size), count); - VectorOperations::AddInPlace(targets, UnsafeNumericCast(aggr.payload_size), count); + VectorOperations::AddInPlace(sources, UnsafeNumericCast(aggr.payload_size)); + VectorOperations::AddInPlace(targets, UnsafeNumericCast(aggr.payload_size)); // Increment the offset offset += aggr.payload_size; } // Now subtract the offset to get back to the original position - VectorOperations::AddInPlace(sources, -UnsafeNumericCast(offset), count); - VectorOperations::AddInPlace(targets, -UnsafeNumericCast(offset), count); + VectorOperations::AddInPlace(sources, -UnsafeNumericCast(offset)); + 
VectorOperations::AddInPlace(targets, -UnsafeNumericCast(offset)); } void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, @@ -119,20 +155,21 @@ void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &l } auto &addresses_copy = *state.addresses; VectorOperations::Copy(addresses, addresses_copy, result.size(), 0, 0); + FlatVector::SetSize(addresses_copy, count_t(result.size())); // Move to the first aggregate state - VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(layout.GetAggrOffset()), result.size()); + VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(layout.GetAggrOffset())); auto &aggregates = layout.GetAggregates(); for (idx_t i = 0; i < aggregates.size(); i++) { auto &target = result.data[aggr_idx + i]; auto &aggr = aggregates[i]; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); + AggregateInputData aggr_input_data(aggr, state.allocator); aggr.function.GetStateFinalizeCallback()(addresses_copy, aggr_input_data, target, result.size(), 0); FlatVector::SetSize(target, count_t(result.size())); // Move to the next aggregate state - VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(aggr.payload_size), result.size()); + VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(aggr.payload_size)); } } diff --git a/src/duckdb/src/common/row_operations/row_matcher.cpp b/src/duckdb/src/common/row_operations/row_matcher.cpp index 09083a73e..6869fe97e 100644 --- a/src/duckdb/src/common/row_operations/row_matcher.cpp +++ b/src/duckdb/src/common/row_operations/row_matcher.cpp @@ -156,55 +156,55 @@ static idx_t StructMatchEquality(Vector &lhs_vector, const TupleDataVectorFormat } template -static idx_t SelectComparison(Vector &, Vector &, const SelectionVector &, idx_t, SelectionVector *, +static idx_t SelectComparison(const Vector &, const Vector &, const SelectionVector &, idx_t, SelectionVector *, SelectionVector *) { 
throw NotImplementedException("Unsupported list comparison operand for RowMatcher::GetMatchFunction"); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::NestedEquals(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::NestedNotEquals(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::DistinctFrom(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::NotDistinctFrom(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t 
SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, + idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::DistinctLessThan(left, right, &sel, count, true_sel, false_sel); } template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, +idx_t SelectComparison(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); } diff --git a/src/duckdb/src/common/serialization_compatibility.cpp b/src/duckdb/src/common/serialization_compatibility.cpp deleted file mode 100644 index c22644043..000000000 --- a/src/duckdb/src/common/serialization_compatibility.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "duckdb/common/serialization_compatibility.hpp" -#include "duckdb/storage/storage_info.hpp" -#include "duckdb/common/optional_idx.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/storage/storage_manager.hpp" - -namespace duckdb { - -SerializationCompatibility SerializationCompatibility::FromDatabase(AttachedDatabase &db) { - return FromIndex(db.GetStorageManager().GetStorageVersion()); -} - -SerializationCompatibility SerializationCompatibility::FromIndex(const idx_t version) { - SerializationCompatibility result; - result.duckdb_version = ""; - 
result.serialization_version = version; - result.manually_set = false; - return result; -} - -SerializationCompatibility SerializationCompatibility::FromString(const string &input) { - if (input.empty()) { - throw InvalidInputException("Version string can not be empty"); - } - - auto serialization_version = GetSerializationVersion(input.c_str()); - if (!serialization_version.IsValid()) { - auto candidates = GetSerializationCandidates(); - throw InvalidInputException("The version string '%s' is not a known DuckDB version, valid options are: %s", - input, StringUtil::Join(candidates, ", ")); - } - SerializationCompatibility result; - result.duckdb_version = input; - result.serialization_version = serialization_version.GetIndex(); - result.manually_set = true; - return result; -} - -SerializationCompatibility SerializationCompatibility::Default() { -#ifdef DUCKDB_ALTERNATIVE_VERIFY - auto res = FromString("latest"); - res.manually_set = false; - return res; -#else -#ifdef DUCKDB_LATEST_STORAGE - auto res = FromString("latest"); - res.manually_set = false; - return res; -#else - auto res = FromString("v0.10.2"); - res.manually_set = false; - return res; -#endif -#endif -} - -SerializationCompatibility SerializationCompatibility::Latest() { - auto res = FromString("latest"); - res.manually_set = false; - return res; -} - -bool SerializationCompatibility::Compare(idx_t property_version) const { - return property_version <= serialization_version; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/full_sort.cpp b/src/duckdb/src/common/sort/full_sort.cpp index e9455eb64..21c0e9ada 100644 --- a/src/duckdb/src/common/sort/full_sort.cpp +++ b/src/duckdb/src/common/sort/full_sort.cpp @@ -176,8 +176,6 @@ SinkResultType FullSort::Sink(ExecutionContext &context, DataChunk &input_chunk, ConstantVector::SetNull(vec, count_t(input_chunk.size())); } - payload_chunk.SetCardinality(input_chunk); - // OVER(ORDER BY...) 
auto &sort_local = lstate.sort_local; D_ASSERT(sort_local); @@ -250,7 +248,7 @@ FullSort::FullSort(ClientContext &client, const vector &order_ auto &expr = *order.expression; if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { auto &ref = expr.Cast(); - sort_ids.emplace_back(ref.index); + sort_ids.emplace_back(ref.Index()); continue; } diff --git a/src/duckdb/src/common/sort/hashed_sort.cpp b/src/duckdb/src/common/sort/hashed_sort.cpp index 97da8714d..ef1e75f16 100644 --- a/src/duckdb/src/common/sort/hashed_sort.cpp +++ b/src/duckdb/src/common/sort/hashed_sort.cpp @@ -346,12 +346,12 @@ void HashedSort::Synchronize(const GlobalSinkState &source, GlobalSinkState &tar } void HashedSortLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); D_ASSERT(group_chunk.ColumnCount() > 0); // OVER(PARTITION BY...) (hash grouping) group_chunk.Reset(); hash_exec.Execute(input_chunk, group_chunk); + const idx_t count = group_chunk.size(); VectorOperations::Hash(group_chunk.data[0], hash_vector, count); for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); @@ -388,7 +388,7 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun ConstantVector::SetNull(vec, count_t(input_chunk.size())); } - payload_chunk.SetCardinality(input_chunk); + payload_chunk.SetChildCardinality(input_chunk.size()); // OVER(PARTITION BY...) 
auto &hash_vector = payload_chunk.data.back(); @@ -542,7 +542,7 @@ HashedSort::HashedSort(ClientContext &client, const vector(); - sort_ids.emplace_back(ref.index); + sort_ids.emplace_back(ref.Index()); continue; } diff --git a/src/duckdb/src/common/sort/sort.cpp b/src/duckdb/src/common/sort/sort.cpp index 76d356a10..9ca8fa0fc 100644 --- a/src/duckdb/src/common/sort/sort.cpp +++ b/src/duckdb/src/common/sort/sort.cpp @@ -52,7 +52,8 @@ Sort::Sort(ClientContext &client_context_p, const vector &orde } ErrorData error; - create_sort_key = binder.BindScalarFunction(DEFAULT_SCHEMA, "create_sort_key", std::move(create_children), error); + create_sort_key = + binder.BindScalarFunction(Identifier::DefaultSchema(), "create_sort_key", std::move(create_children), error); if (!create_sort_key) { throw InternalException("Unable to bind create_sort_key in Sort::Sort"); } @@ -92,7 +93,7 @@ Sort::Sort(ClientContext &client_context_p, const vector &orde for (idx_t key_idx = 0; key_idx < orders.size(); key_idx++) { const auto &key_order_expr = *orders[key_idx].expression; if (key_order_expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { - input_column_to_key.emplace(key_order_expr.Cast().index, key_idx); + input_column_to_key.emplace(key_order_expr.Cast().Index(), key_idx); } } diff --git a/src/duckdb/src/common/sort/sorted_run.cpp b/src/duckdb/src/common/sort/sorted_run.cpp index c777d4560..0da36e0df 100644 --- a/src/duckdb/src/common/sort/sorted_run.cpp +++ b/src/duckdb/src/common/sort/sorted_run.cpp @@ -22,28 +22,27 @@ SortedRunScanState::SortedRunScanState(ClientContext &context, const Sort &sort_ decoded_key.Initialize(context, {sort.decode_sort_key->GetReturnType()}); } -void SortedRunScanState::Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, - DataChunk &chunk) { +void SortedRunScanState::Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, DataChunk &chunk) { const auto sort_key_type = 
sort.key_layout->GetSortKeyType(); switch (sort_key_type) { case SortKeyType::NO_PAYLOAD_FIXED_8: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::NO_PAYLOAD_FIXED_16: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::NO_PAYLOAD_FIXED_24: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::NO_PAYLOAD_FIXED_32: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::NO_PAYLOAD_VARIABLE_32: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::PAYLOAD_FIXED_16: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::PAYLOAD_FIXED_24: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::PAYLOAD_FIXED_32: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); case SortKeyType::PAYLOAD_VARIABLE_32: - return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + return TemplatedScan(sorted_run, sort_key_pointers, chunk); default: throw NotImplementedException("SortedRunMergerLocalState::ScanPartition for %s", EnumUtil::ToString(sort_key_type)); @@ -85,9 +84,9 @@ static void GetKeyAndPayload(SORT_KEY *const *const sort_keys, SORT_KEY *temp_ke } template -void SortedRunScanState::TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, - DataChunk &chunk) { +void 
SortedRunScanState::TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, DataChunk &chunk) { using SORT_KEY = SortKey; + const idx_t count = sort_key_pointers.size(); const auto &output_projection_columns = sort.output_projection_columns; idx_t opc_idx = 0; @@ -176,28 +175,29 @@ SortedRun::~SortedRun() { } template -static void TemplatedSetPayloadPointer(Vector &key_locations, Vector &payload_locations, const idx_t count) { +static void TemplatedSetPayloadPointer(const Vector &key_locations, const Vector &payload_locations) { using SORT_KEY = SortKey; + D_ASSERT(key_locations.size() == payload_locations.size()); - const auto key_locations_ptr = FlatVector::GetDataMutable(key_locations); + const auto key_locations_ptr = FlatVector::GetData(key_locations); const auto payload_locations_ptr = FlatVector::GetData(payload_locations); - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < key_locations.size(); i++) { key_locations_ptr[i]->SetPayload(payload_locations_ptr[i]); } } -static void SetPayloadPointer(Vector &key_locations, Vector &payload_locations, const idx_t count, +static void SetPayloadPointer(const Vector &key_locations, const Vector &payload_locations, const SortKeyType &sort_key_type) { switch (sort_key_type) { case SortKeyType::PAYLOAD_FIXED_16: - return TemplatedSetPayloadPointer(key_locations, payload_locations, count); + return TemplatedSetPayloadPointer(key_locations, payload_locations); case SortKeyType::PAYLOAD_FIXED_24: - return TemplatedSetPayloadPointer(key_locations, payload_locations, count); + return TemplatedSetPayloadPointer(key_locations, payload_locations); case SortKeyType::PAYLOAD_FIXED_32: - return TemplatedSetPayloadPointer(key_locations, payload_locations, count); + return TemplatedSetPayloadPointer(key_locations, payload_locations); case SortKeyType::PAYLOAD_VARIABLE_32: - return TemplatedSetPayloadPointer(key_locations, payload_locations, count); + return TemplatedSetPayloadPointer(key_locations, 
payload_locations); default: throw NotImplementedException("SetPayloadPointer for %s", EnumUtil::ToString(sort_key_type)); } @@ -210,7 +210,7 @@ void SortedRun::Sink(DataChunk &key, DataChunk &payload) { D_ASSERT(key.size() == payload.size()); payload_data->Append(payload_append_state, payload); SetPayloadPointer(key_append_state.chunk_state.row_locations, payload_append_state.chunk_state.row_locations, - key.size(), key_data->GetLayout().GetSortKeyType()); + key_data->GetLayout().GetSortKeyType()); } } diff --git a/src/duckdb/src/common/sort/sorted_run_merger.cpp b/src/duckdb/src/common/sort/sorted_run_merger.cpp index b3b681a2d..e66e4b576 100644 --- a/src/duckdb/src/common/sort/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sort/sorted_run_merger.cpp @@ -708,14 +708,16 @@ void SortedRunMergerLocalState::TemplatedScanPartition(SortedRunMergerGlobalStat // Grab pointers to sort keys const auto merged_partition_keys = reinterpret_cast(merged_partition.get()) + merged_partition_index; - const auto sort_keys = FlatVector::GetDataMutable(sort_key_pointers); - for (idx_t i = 0; i < count; i++) { - sort_keys[i] = &merged_partition_keys[i]; + { + auto writer = FlatVector::Writer(sort_key_pointers, count); + for (idx_t i = 0; i < count; i++) { + writer.WriteValue(&merged_partition_keys[i]); + } } merged_partition_index += count; // Scan - sorted_run_scan_state.Scan(*gstate.merger.sorted_runs[0], sort_key_pointers, count, chunk); + sorted_run_scan_state.Scan(*gstate.merger.sorted_runs[0], sort_key_pointers, chunk); } void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState &gstate) { diff --git a/src/duckdb/src/common/storage_compatibility.cpp b/src/duckdb/src/common/storage_compatibility.cpp new file mode 100644 index 000000000..b63bad87d --- /dev/null +++ b/src/duckdb/src/common/storage_compatibility.cpp @@ -0,0 +1,80 @@ +#include "duckdb/common/storage_compatibility.hpp" +#include "duckdb/storage/storage_info.hpp" +#include 
"duckdb/common/optional_idx.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/storage/storage_manager.hpp" + +namespace duckdb { + +StorageCompatibility StorageCompatibility::FromDatabase(AttachedDatabase &db) { + return FromIndex(db.GetStorageManager().GetStorageVersion()); +} + +StorageCompatibility StorageCompatibility::FromIndex(StorageVersion storage_version_p) { + StorageCompatibility result; + + result.storage_version = storage_version_p; + result.duckdb_version = StorageVersionInfo::GetStorageVersionString(storage_version_p); + result.manually_set = false; + return result; +} + +StorageCompatibility StorageCompatibility::FromString(const string &input) { + if (input.empty()) { + throw InvalidInputException("Version string can not be empty"); + } + + auto storage_version = GetStorageVersion(input.c_str()); + if (storage_version == StorageVersion::INVALID) { + auto candidates = GetStorageCandidates(); + throw InvalidInputException("The version string '%s' is not a known DuckDB version, valid options are: %s", + input, StringUtil::Join(candidates, ", ")); + } + StorageCompatibility result; + result.duckdb_version = input; + result.storage_version = storage_version; + result.manually_set = true; + return result; +} + +StorageCompatibility StorageCompatibility::Default() { +#ifdef DUCKDB_ALTERNATIVE_VERIFY + auto res = FromString("latest"); + res.duckdb_version = "latest"; + res.manually_set = false; + return res; +#else +#ifdef DUCKDB_LATEST_STORAGE + auto res = FromString("latest"); + res.manually_set = false; + return res; +#else + auto res = FromString("v0.10.2"); + res.duckdb_version = "latest"; + res.manually_set = false; + return res; +#endif +#endif +} + +StorageCompatibility StorageCompatibility::Latest() { + auto res = FromString("latest"); + res.manually_set = false; + return res; +} + +bool StorageCompatibility::Compare(StorageVersion property_version) const { + return property_version <= storage_version; +} + +bool 
StorageCompatibility::CompareVersionString(const string &property_version) const { + auto property_version_val = GetSerializationVersionDeprecated(property_version.c_str()); + auto deprecated_serialization_version = GetSerializationVersionDeprecated(duckdb_version.c_str()); + return property_version_val <= deprecated_serialization_version; +} + +StorageVersion StorageCompatibility::GetStorageVersionCompatibility() const { + return storage_version; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index bcf1d0725..f6a29d76f 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -258,6 +258,11 @@ vector StringUtil::SplitWithParentheses(const string &str, char delimite return result; } +string StringUtil::Join(const vector &input, const string &separator) { + return StringUtil::Join(input, input.size(), separator, + [](const Identifier &id) { return id.GetIdentifierName(); }); +} + string StringUtil::Join(const vector &input, const string &separator) { return StringUtil::Join(input, input.size(), separator, [](const string &s) { return s; }); } @@ -485,6 +490,15 @@ bool StringUtil::CILessThan(const string &s1, const string &s2) { return (charmap[u1] - charmap[u2]) < 0; } +idx_t StringUtil::CIFind(const vector &vector, const Identifier &search_string) { + for (idx_t i = 0; i < vector.size(); i++) { + if (vector[i] == search_string) { + return i; + } + } + return DConstants::INVALID_INDEX; +} + idx_t StringUtil::CIFind(vector &vector, const string &search_string) { for (idx_t i = 0; i < vector.size(); i++) { const auto &string = vector[i]; @@ -644,6 +658,10 @@ idx_t StringUtil::SimilarityScore(const string &s1, const string &s2) { return LevenshteinDistance(s1, s2, 3); } +double StringUtil::SimilarityRating(const Identifier &s1, const Identifier &s2) { + return SimilarityRating(s1.GetIdentifierName(), s2.GetIdentifierName()); +} + double 
StringUtil::SimilarityRating(const string &s1, const string &s2) { return duckdb_jaro_winkler::jaro_winkler_similarity(s1.data(), s1.data() + s1.size(), s2.data(), s2.data() + s2.size()); @@ -663,6 +681,11 @@ vector StringUtil::TopNLevenshtein(const vector &strings, const return TopNStrings(scores, n, threshold); } +vector StringUtil::TopNJaroWinkler(const vector &strings, const Identifier &target, idx_t n, + double threshold) { + return TopNJaroWinkler(strings, StringUtil::Lower(target.GetIdentifierName()), n, threshold); +} + vector StringUtil::TopNJaroWinkler(const vector &strings, const string &target, idx_t n, double threshold) { vector> scores; @@ -1074,7 +1097,7 @@ uint32_t StringUtil::StringToEnum(const EnumStringLiteral enum_list[], idx_t enu for (idx_t i = 0; i < enum_count; i++) { candidates.push_back(enum_list[i].string); } - auto closest_values = TopNJaroWinkler(candidates, str_value); + auto closest_values = TopNJaroWinkler(candidates, string(str_value)); auto message = CandidatesMessage(closest_values, "Candidates"); throw NotImplementedException("Enum value: unrecognized value \"%s\" for enum \"%s\"\n%s", str_value, enum_name, message); diff --git a/src/duckdb/src/common/thread_util.cpp b/src/duckdb/src/common/thread_util.cpp index 96a576e21..bc96a2022 100644 --- a/src/duckdb/src/common/thread_util.cpp +++ b/src/duckdb/src/common/thread_util.cpp @@ -1,16 +1,34 @@ #include "duckdb/common/thread.hpp" #include "duckdb/common/chrono.hpp" #include "duckdb/original/std/sstream.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { #ifndef DUCKDB_NO_THREADS -void ThreadUtil::SleepMs(idx_t sleep_ms) { - std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); +void ThreadUtil::SleepMs(idx_t sleep_ms, optional_ptr context) { + auto target_time = Timestamp::GetCurrentTimestamp(); + target_time.value += 
static_cast(sleep_ms) * Interval::MICROS_PER_MSEC; + static constexpr idx_t DEFAULT_SLEEP_INTERVAL_MS = 100; + + while (true) { + auto current_time = Timestamp::GetCurrentTimestamp(); + if (context && context->IsInterrupted()) { + throw InterruptException(); + } + if (current_time >= target_time) { + break; + } + auto remaining_ms = static_cast(target_time.value - current_time.value) / Interval::MICROS_PER_MSEC; + std::this_thread::sleep_for(milliseconds(MinValue(remaining_ms, DEFAULT_SLEEP_INTERVAL_MS))); + } } void ThreadUtil::SleepMicroSeconds(idx_t micros) { - std::this_thread::sleep_for(std::chrono::microseconds(micros)); + std::this_thread::sleep_for(microseconds(micros)); } thread_id ThreadUtil::GetThreadId() { @@ -25,7 +43,7 @@ string ThreadUtil::GetThreadIdString() { #else -void ThreadUtil::SleepMs(idx_t sleep_ms) { +void ThreadUtil::SleepMs(idx_t sleep_ms, optional_ptr) { throw InvalidInputException("ThreadUtil::SleepMs requires DuckDB to be compiled with thread support"); } diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 4ed07c859..d0b0867f3 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -17,6 +17,7 @@ #include "duckdb/common/types/vector.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/function/cast_rules.hpp" +#include "duckdb/function/cast/cast_function_set.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/common/types/type_manager.hpp" @@ -163,12 +164,6 @@ PhysicalType LogicalType::GetInternalType() { return PhysicalType::INVALID; case LogicalTypeId::UNBOUND: return PhysicalType::UNKNOWN; - case LogicalTypeId::LEGACY_AGGREGATE_STATE: { - // Legacy aggregate state - opaque BLOB, - return PhysicalType::VARCHAR; - } - case LogicalTypeId::AGGREGATE_STATE: - return PhysicalType::STRUCT; case LogicalTypeId::GEOMETRY: return PhysicalType::VARCHAR; default: @@ -508,12 +503,6 @@ string LogicalType::ToString() const { 
return expr->ToString(); } } - case LogicalTypeId::LEGACY_AGGREGATE_STATE: { - return LegacyAggregateStateType::GetTypeName(*this); - } - case LogicalTypeId::AGGREGATE_STATE: { - return AggregateStateType::GetTypeName(*this); - } case LogicalTypeId::SQLNULL: { return "\"NULL\""; } @@ -538,7 +527,7 @@ string LogicalType::ToString() const { // LCOV_EXCL_STOP LogicalTypeId TransformStringToLogicalTypeId(const string &str) { - auto type = DefaultTypeGenerator::GetDefaultType(str); + auto type = DefaultTypeGenerator::GetDefaultType(Identifier(str)); if (type == LogicalTypeId::INVALID) { // This is a User Type, at this point we don't know if its one of the User Defined Types or an error // It is checked in the binder @@ -891,50 +880,120 @@ LogicalType LogicalType::NormalizeType(const LogicalType &type) { } } -template -static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) { - D_ASSERT(right.id() != left.id()); - if (right.id() == LogicalTypeId::VARIANT) { - result = right; - return true; +class LogicalTypeResolver { +public: + explicit LogicalTypeResolver(optional_ptr context_p) : context(context_p) { } - if (left.id() == LogicalTypeId::VARIANT) { - result = left; + virtual ~LogicalTypeResolver() = default; + + virtual bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) = 0; + +public: + optional_ptr context; +}; + +static bool TryGetMaxLogicalTypeInternal(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result); + +class ForceLogicalTypeResolver : public LogicalTypeResolver { +public: + using LogicalTypeResolver::LogicalTypeResolver; + + bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) override { + result = context ? 
LogicalType::ForceMaxLogicalType(*context, left, right) + : LogicalType::DefaultForceMaxLogicalType(left, right); return true; } +}; + +class TryGetLogicalTypeResolver : public LogicalTypeResolver { +public: + using LogicalTypeResolver::LogicalTypeResolver; + + bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) override { + return TryGetMaxLogicalTypeInternal(*this, left, right, result); + } +}; - // left and right are not equal +static bool CombineNullOrUnknown(LogicalTypeResolver &, const LogicalType &left, const LogicalType &right, + LogicalType &result) { // NULL/unknown (parameter) types always take the other type - LogicalTypeId other_types[] = {LogicalTypeId::SQLNULL, LogicalTypeId::UNKNOWN}; - for (auto &other_type : other_types) { - if (left.id() == other_type) { - result = LogicalType::NormalizeType(right); - return true; - } else if (right.id() == other_type) { - result = LogicalType::NormalizeType(left); - return true; - } + if (left.id() == LogicalTypeId::SQLNULL || right.id() == LogicalTypeId::SQLNULL) { + result = LogicalType::NormalizeType(left.id() == LogicalTypeId::SQLNULL ? right : left); + } else { + result = LogicalType::NormalizeType(left.id() == LogicalTypeId::UNKNOWN ? 
right : left); } + return true; +} +static bool CombineEnum(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, const LogicalType &right, + LogicalType &result) { // for enums, match the varchar rules if (left.id() == LogicalTypeId::ENUM) { - return OP::Operation(LogicalType::VARCHAR, right, result); - } else if (right.id() == LogicalTypeId::ENUM) { - return OP::Operation(left, LogicalType::VARCHAR, result); + return logical_type_resolver.Operation(LogicalType::VARCHAR, right, result); } + return logical_type_resolver.Operation(left, LogicalType::VARCHAR, result); +} - // for everything but enums - string literals also take the other type - if (left.id() == LogicalTypeId::STRING_LITERAL) { - result = LogicalType::NormalizeType(right); - return true; - } else if (right.id() == LogicalTypeId::STRING_LITERAL) { - result = LogicalType::NormalizeType(left); - return true; - } +static bool CombineVariant(LogicalTypeResolver &, const LogicalType &left, const LogicalType &right, + LogicalType &result) { + result = right.id() == LogicalTypeId::VARIANT ? right : left; + return true; +} + +// for everything but enums - string literals also take the other type +static bool CombineStringLiteral(LogicalTypeResolver &, const LogicalType &left, const LogicalType &right, + LogicalType &result) { + result = LogicalType::NormalizeType(left.id() == LogicalTypeId::STRING_LITERAL ? 
right : left); + return true; +} + +using CombineTypesRuleFunction = bool (*)(LogicalTypeResolver &resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result); +struct CombineUnequalTypesRule { + bool (*matches)(const LogicalType &left, const LogicalType &right); // order-insensitive + CombineTypesRuleFunction function; +}; + +static bool MatchesVariant(const LogicalType &left, const LogicalType &right) { + return left.id() == LogicalTypeId::VARIANT || right.id() == LogicalTypeId::VARIANT; +} + +static bool MatchesNullOrUnknown(const LogicalType &left, const LogicalType &right) { + auto null_or_unknown = [](const LogicalType &type) { + return type.id() == LogicalTypeId::SQLNULL || type.id() == LogicalTypeId::UNKNOWN; + }; + return null_or_unknown(left) || null_or_unknown(right); +} + +static bool MatchesEnum(const LogicalType &left, const LogicalType &right) { + return left.id() == LogicalTypeId::ENUM || right.id() == LogicalTypeId::ENUM; +} - // for other types - use implicit cast rules to check if we can combine the types - auto left_to_right_cost = CastRules::ImplicitCast(left, right); - auto right_to_left_cost = CastRules::ImplicitCast(right, left); +static bool MatchesStringLiteral(const LogicalType &left, const LogicalType &right) { + return left.id() == LogicalTypeId::STRING_LITERAL || right.id() == LogicalTypeId::STRING_LITERAL; +} + +// Rules are matched in order, so the first matching rule wins (VARIANT and NULL/UNKNOWN are matched +// before ENUM, ENUM before STRING_LITERAL). Each `matches` predicate is order-insensitive in (left, right). +static const CombineUnequalTypesRule COMBINE_UNEQUAL_TYPES_RULES[] = { + {MatchesVariant, CombineVariant}, + {MatchesNullOrUnknown, CombineNullOrUnknown}, + {MatchesEnum, CombineEnum}, + {MatchesStringLiteral, CombineStringLiteral}, +}; + +// Generic combine engine: try the implicit cast rules in both directions and keep the cheaper one. 
With a +// context we go through CastFunctionSet so that any registered casts are honored; without a context we +// consult the built-in CastRules table directly. Returns false when neither direction can be cast. +static bool TryCombineViaImplicitCast(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + auto left_to_right_cost = logical_type_resolver.context + ? CastFunctionSet::ImplicitCastCost(*logical_type_resolver.context, left, right) + : CastRules::ImplicitCast(left, right); + auto right_to_left_cost = logical_type_resolver.context + ? CastFunctionSet::ImplicitCastCost(*logical_type_resolver.context, right, left) + : CastRules::ImplicitCast(right, left); if (left_to_right_cost >= 0 && (left_to_right_cost < right_to_left_cost || right_to_left_cost < 0)) { // we can implicitly cast left to right, return right //! Depending on the type, we might need to grow the `width` of the DECIMAL type @@ -965,12 +1024,28 @@ static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &righ } return true; } + return false; +} + +static bool CombineUnequalTypes(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + D_ASSERT(right.id() != left.id()); + for (auto &combine_rule : COMBINE_UNEQUAL_TYPES_RULES) { + if (combine_rule.matches(left, right)) { + return combine_rule.function(logical_type_resolver, left, right, result); + } + } + + // for other types - try to combine through the generic implicit cast rules + if (TryCombineViaImplicitCast(logical_type_resolver, left, right, result)) { + return true; + } // for integer literals - rerun the operation with the underlying type if (left.id() == LogicalTypeId::INTEGER_LITERAL) { - return OP::Operation(IntegerLiteral::GetType(left), right, result); + return logical_type_resolver.Operation(IntegerLiteral::GetType(left), right, result); } if (right.id() == 
LogicalTypeId::INTEGER_LITERAL) { - return OP::Operation(left, IntegerLiteral::GetType(right), result); + return logical_type_resolver.Operation(left, IntegerLiteral::GetType(right), result); } // for unsigned/signed comparisons we have a few fallbacks if (left.IsNumeric() && right.IsNumeric()) { @@ -988,8 +1063,8 @@ static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &righ return false; } -template -static bool CombineStructTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) { +static bool CombineStructTypes(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { auto &left_children = StructType::GetChildTypes(left); auto &right_children = StructType::GetChildTypes(right); @@ -1006,7 +1081,7 @@ static bool CombineStructTypes(const LogicalType &left, const LogicalType &right for (idx_t i = 0; i < left_children.size(); i++) { LogicalType child_type; - if (!OP::Operation(left_children[i].second, right_children[i].second, child_type)) { + if (!logical_type_resolver.Operation(left_children[i].second, right_children[i].second, child_type)) { return false; } auto &child_name = left_unnamed ? right_children[i].first : left_children[i].first; @@ -1018,7 +1093,7 @@ static bool CombineStructTypes(const LogicalType &left, const LogicalType &right // Create a super-set of the STRUCT fields. // First, create a name->index map of the right children. - InsertionOrderPreservingMap right_children_map; + InsertionOrderPreservingMap> right_children_map; for (idx_t i = 0; i < right_children.size(); i++) { auto &name = right_children[i].first; right_children_map[name] = i; @@ -1037,7 +1112,7 @@ static bool CombineStructTypes(const LogicalType &left, const LogicalType &right // We need to recurse to ensure the children have a maximum logical type. 
LogicalType child_type; auto &right_child = right_children[right_child_it->second]; - if (!OP::Operation(left_child.second, right_child.second, child_type)) { + if (!logical_type_resolver.Operation(left_child.second, right_child.second, child_type)) { return false; } child_types.emplace_back(left_child.first, std::move(child_type)); @@ -1054,99 +1129,128 @@ static bool CombineStructTypes(const LogicalType &left, const LogicalType &right return true; } -template -static bool CombineEqualTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) { - // Since both left and right are equal we get the left type as our type_id for checks - auto type_id = left.id(); - switch (type_id) { - case LogicalTypeId::STRING_LITERAL: - // two string literals convert to varchar - result = LogicalType::VARCHAR; - return true; - case LogicalTypeId::INTEGER_LITERAL: - // for two integer literals we unify the underlying types - return OP::Operation(IntegerLiteral::GetType(left), IntegerLiteral::GetType(right), result); - case LogicalTypeId::ENUM: - // If both types are different ENUMs we do a string comparison. - result = left == right ? 
left : LogicalType::VARCHAR; - return true; - case LogicalTypeId::VARCHAR: - // varchar: use type that has collation (if any) - if (StringType::GetCollation(right).empty()) { - result = left; - } else { - result = right; - } - return true; - case LogicalTypeId::DECIMAL: { - // unify the width/scale so that the resulting decimal always fits - // "width - scale" gives us the number of digits on the left side of the decimal point - // "scale" gives us the number of digits allowed on the right of the decimal point - // using the max of these of the two types gives us the new decimal size - auto extra_width_left = DecimalType::GetWidth(left) - DecimalType::GetScale(left); - auto extra_width_right = DecimalType::GetWidth(right) - DecimalType::GetScale(right); - auto extra_width = - MaxValue(NumericCast(extra_width_left), NumericCast(extra_width_right)); - auto scale = MaxValue(DecimalType::GetScale(left), DecimalType::GetScale(right)); - auto width = NumericCast(extra_width + scale); - if (width > DecimalType::MaxWidth()) { - // if the resulting decimal does not fit, we truncate the scale - width = DecimalType::MaxWidth(); - scale = NumericCast(width - extra_width); - } - result = LogicalType::DECIMAL(width, scale); - return true; - } - case LogicalTypeId::LIST: { - // list: perform max recursively on child type - LogicalType new_child; - if (!OP::Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { - return false; - } - result = LogicalType::LIST(new_child); - return true; +struct CombineEqualTypesRule { + LogicalTypeId type; + CombineTypesRuleFunction function; +}; + +static bool CombineEqualStringLiteral(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + result = LogicalType::VARCHAR; + return true; +} +static bool CombineEqualIntegerLiteral(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // 
for two integer literals we unify the underlying types + return logical_type_resolver.Operation(IntegerLiteral::GetType(left), IntegerLiteral::GetType(right), result); +} +static bool CombineEqualEnum(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // If both types are different ENUMs we do a string comparison. + result = left == right ? left : LogicalType::VARCHAR; + return true; +} +static bool CombineEqualVarchar(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // varchar: use type that has collation (if any) + if (StringType::GetCollation(right).empty()) { + result = left; + } else { + result = right; } - case LogicalTypeId::ARRAY: { - LogicalType new_child; - if (!OP::Operation(ArrayType::GetChildType(left), ArrayType::GetChildType(right), new_child)) { - return false; - } - auto new_size = MaxValue(ArrayType::GetSize(left), ArrayType::GetSize(right)); - result = LogicalType::ARRAY(new_child, new_size); - return true; + return true; +} +static bool CombineEqualDecimal(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // unify the width/scale so that the resulting decimal always fits + // "width - scale" gives us the number of digits on the left side of the decimal point + // "scale" gives us the number of digits allowed on the right of the decimal point + // using the max of these of the two types gives us the new decimal size + auto extra_width_left = DecimalType::GetWidth(left) - DecimalType::GetScale(left); + auto extra_width_right = DecimalType::GetWidth(right) - DecimalType::GetScale(right); + auto extra_width = + MaxValue(NumericCast(extra_width_left), NumericCast(extra_width_right)); + auto scale = MaxValue(DecimalType::GetScale(left), DecimalType::GetScale(right)); + auto width = NumericCast(extra_width + scale); + if (width > 
DecimalType::MaxWidth()) { + // if the resulting decimal does not fit, we truncate the scale + width = DecimalType::MaxWidth(); + scale = NumericCast(width - extra_width); + } + result = LogicalType::DECIMAL(width, scale); + return true; +} +static bool CombineEqualList(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // list: perform max recursively on child type + LogicalType new_child; + if (!logical_type_resolver.Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { + return false; } - case LogicalTypeId::MAP: { - // map: perform max recursively on child type - LogicalType new_child; - if (!OP::Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { - return false; - } - result = LogicalType::MAP(new_child); - return true; + result = LogicalType::LIST(new_child); + return true; +} +static bool CombineEqualArray(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + LogicalType new_child; + if (!logical_type_resolver.Operation(ArrayType::GetChildType(left), ArrayType::GetChildType(right), new_child)) { + return false; } - case LogicalTypeId::STRUCT: { - return CombineStructTypes(left, right, result); + auto new_size = MaxValue(ArrayType::GetSize(left), ArrayType::GetSize(right)); + result = LogicalType::ARRAY(new_child, new_size); + return true; +} +static bool CombineEqualMap(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + // map: perform max recursively on child type + LogicalType new_child; + if (!logical_type_resolver.Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { + return false; } - case LogicalTypeId::UNION: { - auto left_member_count = UnionType::GetMemberCount(left); - auto right_member_count = UnionType::GetMemberCount(right); - if 
(left_member_count != right_member_count) { - // return the "larger" type, with the most members - result = left_member_count > right_member_count ? left : right; - return true; - } - // otherwise, keep left, don't try to meld the two together. - result = left; + result = LogicalType::MAP(new_child); + return true; +} +static bool CombineEqualUnion(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + auto left_member_count = UnionType::GetMemberCount(left); + auto right_member_count = UnionType::GetMemberCount(right); + if (left_member_count != right_member_count) { + // return the "larger" type, with the most members + result = left_member_count > right_member_count ? left : right; return true; } - default: - result = left; - return true; + // otherwise, keep left, don't try to meld the two together. + result = left; + return true; +} + +static const CombineEqualTypesRule COMBINE_EQUAL_TYPES_RULES[] = { + {LogicalTypeId::STRING_LITERAL, CombineEqualStringLiteral}, + {LogicalTypeId::INTEGER_LITERAL, CombineEqualIntegerLiteral}, + {LogicalTypeId::ENUM, CombineEqualEnum}, + {LogicalTypeId::VARCHAR, CombineEqualVarchar}, + {LogicalTypeId::DECIMAL, CombineEqualDecimal}, + {LogicalTypeId::LIST, CombineEqualList}, + {LogicalTypeId::ARRAY, CombineEqualArray}, + {LogicalTypeId::MAP, CombineEqualMap}, + {LogicalTypeId::STRUCT, CombineStructTypes}, + {LogicalTypeId::UNION, CombineEqualUnion}, +}; + +static bool CombineEqualTypes(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + auto type_id = left.id(); + for (auto &combine_rule : COMBINE_EQUAL_TYPES_RULES) { + if (type_id == combine_rule.type) { + return combine_rule.function(logical_type_resolver, left, right, result); + } } + result = left; + return true; } -template -static bool TryGetMaxLogicalTypeInternal(const LogicalType &left, const LogicalType &right, LogicalType &result) { 
+static bool TryGetMaxLogicalTypeInternal(LogicalTypeResolver &logical_type_resolver, const LogicalType &left, + const LogicalType &right, LogicalType &result) { // we always prefer aliased types if (!left.GetAlias().empty()) { result = left; @@ -1157,37 +1261,31 @@ static bool TryGetMaxLogicalTypeInternal(const LogicalType &left, const LogicalT return true; } if (left.id() != right.id()) { - return CombineUnequalTypes(left, right, result); + return CombineUnequalTypes(logical_type_resolver, left, right, result); } else { - return CombineEqualTypes(left, right, result); + return CombineEqualTypes(logical_type_resolver, left, right, result); } } -struct TryGetTypeOperation { - static bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) { - return TryGetMaxLogicalTypeInternal(left, right, result); - } -}; - -struct ForceGetTypeOperation { - static bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) { - result = LogicalType::ForceMaxLogicalType(left, right); - return true; - } -}; - bool LogicalType::TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right, LogicalType &result) { if (Settings::Get(context)) { - result = LogicalType::ForceMaxLogicalType(left, right); + result = LogicalType::ForceMaxLogicalType(context, left, right); return true; } - return TryGetMaxLogicalTypeUnchecked(left, right, result); + return TryGetMaxLogicalTypeUnchecked(context, left, right, result); +} + +bool LogicalType::TryGetMaxLogicalTypeUnchecked(ClientContext &context, const LogicalType &left, + const LogicalType &right, LogicalType &result) { + TryGetLogicalTypeResolver logical_type_resolver(&context); + return TryGetMaxLogicalTypeInternal(logical_type_resolver, left, right, result); } -bool LogicalType::TryGetMaxLogicalTypeUnchecked(const LogicalType &left, const LogicalType &right, - LogicalType &result) { - return TryGetMaxLogicalTypeInternal(left, right, result); +bool 
LogicalType::DefaultTryGetMaxLogicalTypeUnchecked(const LogicalType &left, const LogicalType &right, + LogicalType &result) { + TryGetLogicalTypeResolver logical_type_resolver(nullptr); + return TryGetMaxLogicalTypeInternal(logical_type_resolver, left, right, result); } static idx_t GetLogicalTypeScore(const LogicalType &type) { @@ -1281,20 +1379,21 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 150; // weirdo types case LogicalTypeId::LAMBDA: - case LogicalTypeId::LEGACY_AGGREGATE_STATE: - case LogicalTypeId::AGGREGATE_STATE: case LogicalTypeId::POINTER: case LogicalTypeId::VALIDITY: case LogicalTypeId::UNBOUND: case LogicalTypeId::TYPE: + case LogicalTypeId::LEGACY_AGGREGATE_STATE: break; } return 1000; } -LogicalType LogicalType::ForceMaxLogicalType(const LogicalType &left, const LogicalType &right) { +static LogicalType ForceMaxLogicalTypeInternal(optional_ptr context, const LogicalType &left, + const LogicalType &right) { LogicalType result; - if (TryGetMaxLogicalTypeInternal(left, right, result)) { + ForceLogicalTypeResolver logical_type_resolver(context); + if (TryGetMaxLogicalTypeInternal(logical_type_resolver, left, right, result)) { return result; } // we prefer the type with the highest score @@ -1307,6 +1406,15 @@ LogicalType LogicalType::ForceMaxLogicalType(const LogicalType &left, const Logi } } +LogicalType LogicalType::ForceMaxLogicalType(ClientContext &context, const LogicalType &left, + const LogicalType &right) { + return ForceMaxLogicalTypeInternal(&context, left, right); +} + +LogicalType LogicalType::DefaultForceMaxLogicalType(const LogicalType &left, const LogicalType &right) { + return ForceMaxLogicalTypeInternal(nullptr, left, right); +} + LogicalType LogicalType::MaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right) { LogicalType result; if (!TryGetMaxLogicalType(context, left, right, result)) { @@ -1325,7 +1433,7 @@ void LogicalType::Verify() const { break; case 
LogicalTypeId::STRUCT: { // verify child types - case_insensitive_set_t child_names; + identifier_set_t child_names; bool all_empty = true; for (auto &entry : StructType::GetChildTypes(*this)) { if (entry.first.empty()) { @@ -1379,7 +1487,7 @@ bool ApproxEqual(double ldecimal, double rdecimal) { void LogicalType::Serialize(Serializer &serializer) const { // Serialize geometry as old extension geometry type if required - if (id_ == LogicalTypeId::GEOMETRY && !serializer.ShouldSerialize(7)) { + if (id_ == LogicalTypeId::GEOMETRY && !serializer.ShouldSerialize(StorageVersion::V1_5_0)) { // This will drop the CRS information, but that's better than throwing an error. auto legacy_geom = Geometry::GetSpatialGeometryType(); legacy_geom.Serialize(serializer); @@ -1390,7 +1498,7 @@ void LogicalType::Serialize(Serializer &serializer) const { // 1. try to default-bind into a concrete logical type, and serialize that // 2. if that fails, serialize normally, in which case the UNBOUND_TYPE_INFO will try to // write itself as an old-style USER type. 
- if (id_ == LogicalTypeId::UNBOUND && !serializer.ShouldSerialize(7)) { + if (id_ == LogicalTypeId::UNBOUND && !serializer.ShouldSerialize(StorageVersion::V1_5_0)) { try { auto bound_type = UnboundType::TryDefaultBind(*this); if (bound_type.id() != LogicalTypeId::INVALID && bound_type.id() != LogicalTypeId::UNBOUND) { @@ -1408,6 +1516,10 @@ void LogicalType::Serialize(Serializer &serializer) const { LogicalType LogicalType::Deserialize(Deserializer &deserializer) { auto id = deserializer.ReadProperty(100, "id"); auto type_info = deserializer.ReadPropertyWithDefault>(101, "type_info"); + if (id == LogicalTypeId::LEGACY_AGGREGATE_STATE) { + // convert legacy aggregate state to blob on deserialize + return LogicalType::BLOB; + } LogicalType result(id, std::move(type_info)); @@ -1554,67 +1666,12 @@ LogicalType LogicalType::LIST(const LogicalType &child) { return LogicalType(LogicalTypeId::LIST, std::move(info)); } -//===--------------------------------------------------------------------===// -// Legacy Aggregate State Type -//===--------------------------------------------------------------------===// -const aggregate_state_t &LegacyAggregateStateType::GetStateType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::LEGACY_AGGREGATE_STATE); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().state_type; -} - -const string LegacyAggregateStateType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::LEGACY_AGGREGATE_STATE); - auto info = type.AuxInfo(); - if (!info) { - return "LEGACY_AGGREGATE_STATE"; - } - auto aggr_state = info->Cast().state_type; - return "LEGACY_AGGREGATE_STATE<" + aggr_state.function_name + "(" + - StringUtil::Join(aggr_state.bound_argument_types, aggr_state.bound_argument_types.size(), ", ", - [](const LogicalType &arg_type) { return arg_type.ToString(); }) + - ")" + "::" + aggr_state.return_type.ToString() + ">"; -} - 
-//===--------------------------------------------------------------------===// -// Aggregate State Type (Struct Based) -//===--------------------------------------------------------------------===// - -const aggregate_state_t &AggregateStateType::GetStateType(const LogicalType &type) { - if (type.id() == LogicalTypeId::LEGACY_AGGREGATE_STATE) { - return LegacyAggregateStateType::GetStateType(type); - } - D_ASSERT(type.IsAggregateStateStructType()); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().state_type; -} - -const string AggregateStateType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.IsAggregateStateStructType()); - auto info = type.AuxInfo(); - if (!info) { - return "AGGREGATE_STATE"; - } - auto aggr_state = info->Cast().state_type; - auto struct_type = LogicalType::STRUCT(GetChildTypes(type)); - return "AGGREGATE_STATE<" + aggr_state.function_name + "(" + - StringUtil::Join(aggr_state.bound_argument_types, aggr_state.bound_argument_types.size(), ", ", - [](const LogicalType &arg_type) { return arg_type.ToString(); }) + - ")" + "::" + aggr_state.return_type.ToString() + ", " + struct_type.ToString() + ">"; -} - -bool LogicalType::IsAggregateStateStructType() const { - return id() == LogicalTypeId::AGGREGATE_STATE && InternalType() == PhysicalType::STRUCT; -} - //===--------------------------------------------------------------------===// // Struct Type //===--------------------------------------------------------------------===// const child_list_t &StructType::GetChildTypes(const LogicalType &type) { D_ASSERT(type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION || - type.id() == LogicalTypeId::VARIANT || type.IsAggregateStateStructType()); + type.id() == LogicalTypeId::VARIANT); auto info = type.AuxInfo(); D_ASSERT(info); @@ -1627,7 +1684,7 @@ const LogicalType &StructType::GetChildType(const LogicalType &type, idx_t index return child_types[index].second; } -const string &StructType::GetChildName(const 
LogicalType &type, idx_t index) { +const Identifier &StructType::GetChildName(const LogicalType &type, idx_t index) { auto &child_types = StructType::GetChildTypes(type); D_ASSERT(index < child_types.size()); return child_types[index].first; @@ -1636,7 +1693,7 @@ const string &StructType::GetChildName(const LogicalType &type, idx_t index) { idx_t StructType::GetChildIndexUnsafe(const LogicalType &type, const string &name) { auto &child_types = StructType::GetChildTypes(type); for (idx_t i = 0; i < child_types.size(); i++) { - if (StringUtil::CIEquals(child_types[i].first, name)) { + if (child_types[i].first == name) { return i; } } @@ -1659,17 +1716,6 @@ LogicalType LogicalType::STRUCT(child_list_t children) { return LogicalType(LogicalTypeId::STRUCT, std::move(info)); } -LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type, child_list_t struct_child_types) { - auto info = make_shared_ptr(std::move(state_type), std::move(struct_child_types)); - return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); -} - -LogicalType LogicalType::LEGACY_AGGREGATE_STATE(aggregate_state_t state_type) { - // Legacy BLOB aggregate state - auto info = make_shared_ptr(std::move(state_type)); - return LogicalType(LogicalTypeId::LEGACY_AGGREGATE_STATE, std::move(info)); -} - //===--------------------------------------------------------------------===// // Map Type //===--------------------------------------------------------------------===// @@ -1731,7 +1777,7 @@ const LogicalType &UnionType::GetMemberType(const LogicalType &type, idx_t index return child_types[index + 1].second; } -const string &UnionType::GetMemberName(const LogicalType &type, idx_t index) { +const Identifier &UnionType::GetMemberName(const LogicalType &type, idx_t index) { auto &child_types = StructType::GetChildTypes(type); D_ASSERT(index < child_types.size()); // skip the "tag" field @@ -1766,11 +1812,11 @@ LogicalType LogicalType::TYPE() { 
//===--------------------------------------------------------------------===// // Enum Type //===--------------------------------------------------------------------===// -LogicalType LogicalType::ENUM(Vector &ordered_data, idx_t size) { +LogicalType LogicalType::ENUM(const Vector &ordered_data, idx_t size) { return EnumTypeInfo::CreateType(ordered_data, size); } -LogicalType LogicalType::ENUM(const string &enum_name, Vector &ordered_data, idx_t size) { +LogicalType LogicalType::ENUM(const string &enum_name, const Vector &ordered_data, idx_t size) { return LogicalType::ENUM(ordered_data, size); } @@ -1813,6 +1859,9 @@ bool LogicalType::IsJSONType() const { return id() == LogicalTypeId::VARCHAR && HasAlias() && GetAlias() == JSON_TYPE_NAME; } +bool LogicalType::IsAggregateState() const { + return HasAlias() && GetAlias() == "AGGREGATE_STATE"; +} //===--------------------------------------------------------------------===// // Array Type //===--------------------------------------------------------------------===// @@ -2079,7 +2128,7 @@ static LogicalType TryDefaultBindTypeExpression(const ParsedExpression &expr) { } // Try to bind as far as we can - auto result = DefaultTypeGenerator::TryDefaultBind(name, bound_args); + auto result = DefaultTypeGenerator::TryDefaultBind(name.GetIdentifierName(), bound_args); if (result.id() != LogicalTypeId::INVALID) { return result; } diff --git a/src/duckdb/src/common/types/bignum.cpp b/src/duckdb/src/common/types/bignum.cpp index 8ea59791c..e5e54096f 100644 --- a/src/duckdb/src/common/types/bignum.cpp +++ b/src/duckdb/src/common/types/bignum.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/types/bignum.hpp" #include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/common/typedefs.hpp" #include @@ -136,7 +137,7 @@ bool Bignum::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &e } idx_t cur_pos = start_pos; // Verify 
all is numeric - while (cur_pos < end_pos && std::isdigit(int_value_char[cur_pos])) { + while (cur_pos < end_pos && StringUtil::CharacterIsDigit(int_value_char[cur_pos])) { cur_pos++; } if (cur_pos < end_pos) { @@ -149,7 +150,7 @@ bool Bignum::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &e } while (cur_pos < end_pos) { - if (std::isdigit(int_value_char[cur_pos])) { + if (StringUtil::CharacterIsDigit(int_value_char[cur_pos])) { cur_pos++; } else { // By now we can only have numbers, otherwise this is invalid. diff --git a/src/duckdb/src/common/types/bit.cpp b/src/duckdb/src/common/types/bit.cpp index dc70992e8..7fd24b723 100644 --- a/src/duckdb/src/common/types/bit.cpp +++ b/src/duckdb/src/common/types/bit.cpp @@ -130,8 +130,13 @@ string Bit::ToString(bitstring_t str) { } bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_message) { - auto data = const_data_ptr_cast(str.GetData()); auto len = str.GetSize(); + if (len == 0) { + str_len = ComputeBitstringLen(1); + return true; + } + + auto data = const_data_ptr_cast(str.GetData()); idx_t bit_count = 0; if (data[0] == 'x') { @@ -152,11 +157,15 @@ bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_messag bool Bit::ToBit(string_t str, bitstring_t &output_str, string *error_message) { auto data = str.GetData(); + auto len = str.GetSize(); + if (len == 0) { + Bit::SetEmptyBitString(output_str, 1); + return true; + } if (data[0] == 'x') { return HexToBit(str, output_str, error_message); } - auto len = str.GetSize(); auto output = output_str.GetDataWriteable(); char byte = 0; diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp index 2be91054a..2bd3d01b2 100644 --- a/src/duckdb/src/common/types/column/column_data_allocator.cpp +++ b/src/duckdb/src/common/types/column/column_data_allocator.cpp @@ -194,6 +194,40 @@ void ColumnDataAllocator::AllocateData(idx_t size, uint32_t 
&block_id, uint32_t } } +void ColumnDataAllocator::Reset() { + for (auto &block : blocks) { + block.size = 0; + } +} + +void ColumnDataAllocator::ResetPreserveLastBlock() { + if (blocks.empty()) { + return; + } + if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { + D_ASSERT(blocks.size() == allocated_data.size()); + if (blocks.size() > 1) { + auto last_block = std::move(blocks.back()); + auto last_allocated = std::move(allocated_data.back()); + blocks.clear(); + allocated_data.clear(); + blocks.push_back(std::move(last_block)); + allocated_data.push_back(std::move(last_allocated)); + } + } else if (blocks.size() > 1) { + auto last_block = std::move(blocks.back()); + blocks.clear(); + blocks.push_back(std::move(last_block)); + } + for (auto &block : blocks) { + block.size = 0; + } + allocated_size = 0; + for (auto &block : blocks) { + allocated_size += block.capacity; + } +} + void ColumnDataAllocator::Initialize(ColumnDataAllocator &other) { D_ASSERT(other.HasBlocks()); blocks.push_back(other.blocks.back()); diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp index d97acf49e..99aeee9e1 100644 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -1234,6 +1234,19 @@ void ColumnDataCollection::Combine(ColumnDataCollection &other) { Verify(); } +void ColumnDataCollection::Swap(ColumnDataCollection &other) { + if (types != other.types) { + throw InternalException("Attempting to swap ColumnDataCollections with mismatching types"); + } + std::swap(allocator, other.allocator); + std::swap(types, other.types); + std::swap(count, other.count); + std::swap(segments, other.segments); + std::swap(copy_functions, other.copy_functions); + std::swap(finished_append, other.finished_append); + std::swap(partition_index, other.partition_index); +} + 
//===--------------------------------------------------------------------===// // Fetch //===--------------------------------------------------------------------===// @@ -1309,6 +1322,24 @@ void ColumnDataCollection::Reset() { allocator = make_shared_ptr(*allocator); } +void ColumnDataCollection::ResetForReuse() { + count = 0; + finished_append = false; + if (segments.empty()) { + allocator->Reset(); + return; + } + + for (auto &segment : segments) { + segment->Reset(); + } + + if (segments.size() > 1) { + segments.resize(1); + } + allocator->ResetPreserveLastBlock(); +} + struct ValueResultEquals { bool operator()(const Value &a, const Value &b) const { return Value::DefaultValuesAreEqual(a, b); @@ -1426,7 +1457,7 @@ unique_ptr ColumnDataCollection::Deserialize(Deserializer for (idx_t c = 0; c < types.size(); c++) { chunk.data[c].Append(values[c][r]); } - chunk.SetCardinality(chunk.size() + 1); + // the appends above already grow the child vectors, so chunk.size() reflects the new cardinality if (chunk.size() == STANDARD_VECTOR_SIZE) { collection->Append(chunk); chunk.Reset(); diff --git a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp index 0f3f3aee0..3417f0e91 100644 --- a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp @@ -15,6 +15,16 @@ ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr(allocator->GetAllocator())) { } +void ColumnDataCollectionSegment::Reset() { + count = 0; + total_allocated = 0; + last_chunk_total_allocated = 0; + chunk_data.clear(); + vector_data.clear(); + child_indices.clear(); + heap->GetAllocator().Reset(); +} + idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) { return AlignValue(type_size * STANDARD_VECTOR_SIZE); } @@ -287,7 +297,7 @@ void ColumnDataCollectionSegment::ReadChunk(idx_t chunk_index, ChunkManagementSt 
D_ASSERT(vector_idx < chunk_meta.vector_data.size()); ReadVector(state, chunk_meta.vector_data[vector_idx], chunk.data[i]); } - chunk.SetCardinality(chunk_meta.count); + chunk.SetChildCardinality(chunk_meta.count); } idx_t ColumnDataCollectionSegment::ChunkCount() const { diff --git a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp index c9a8d4040..82b269bdb 100644 --- a/src/duckdb/src/common/types/data_chunk.cpp +++ b/src/duckdb/src/common/types/data_chunk.cpp @@ -23,7 +23,7 @@ namespace duckdb { -DataChunk::DataChunk() : count(0) { +DataChunk::DataChunk() { } DataChunk::~DataChunk() { @@ -91,8 +91,9 @@ idx_t DataChunk::GetAllocationSize() const { } void DataChunk::Reset() { - SetCardinality(0); + count = optional_idx(); if (data.empty() || vector_caches.empty()) { + count = 0; return; } if (vector_caches.size() != data.size()) { @@ -106,7 +107,7 @@ void DataChunk::Reset() { void DataChunk::Destroy() { data.clear(); vector_caches.clear(); - SetCardinality(0); + count = optional_idx(); } Value DataChunk::GetValue(idx_t col_idx, idx_t index) const { @@ -127,15 +128,32 @@ bool DataChunk::AllConstant() const { return true; } -void DataChunk::SetCardinality(idx_t count_p) { - this->count = count_p; +void DataChunk::CheckCardinality(idx_t count_p) { + for (auto &v : data) { + if (v.size() != count_p) { + throw InternalException("DataChunk::CheckCardinality - vector has size %d but expected size %d", v.size(), + count_p); + } + } + this->count = optional_idx(); } void DataChunk::SetChildCardinality(idx_t count_p) { - this->count = count_p; for (auto &v : data) { - FlatVector::SetSize(v, count_p); + // null-buffer placeholders (InitializeEmpty) and non-flat/constant vectors cannot be resized + if (!v.GetBufferRef()) { + continue; + } + auto vtype = v.GetVectorType(); + if (vtype == VectorType::FLAT_VECTOR || vtype == VectorType::CONSTANT_VECTOR) { + FlatVector::SetSize(v, count_p); + } else if (v.size() != count_p) { + throw 
InternalException("DataChunk::SetChildCardinality - vector has size %d but expected size %d - and " + "cannot change vector size because it is not a flat or constant vector", + v.size(), count_p); + } } + this->count = count_p; } void DataChunk::Reference(DataChunk &chunk) { @@ -143,14 +161,13 @@ void DataChunk::Reference(DataChunk &chunk) { for (idx_t i = 0; i < chunk.ColumnCount(); i++) { data[i].Reference(chunk.data[i]); } - SetCardinality(chunk); + count = chunk.count; } void DataChunk::Move(DataChunk &chunk) { - SetCardinality(chunk); data = std::move(chunk.data); vector_caches = std::move(chunk.vector_caches); - + count = chunk.count; chunk.Destroy(); } @@ -162,8 +179,9 @@ void DataChunk::Copy(DataChunk &other, idx_t offset) const { for (idx_t i = 0; i < ColumnCount(); i++) { D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); VectorOperations::Copy(data[i], other.data[i], size(), offset, 0); + FlatVector::SetSize(other.data[i], target_count); } - other.SetChildCardinality(target_count); + other.count = optional_idx(); } void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, const idx_t offset) const { @@ -175,8 +193,9 @@ void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t s for (idx_t i = 0; i < ColumnCount(); i++) { D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); VectorOperations::Copy(data[i], other.data[i], sel, source_count, offset, 0); + FlatVector::SetSize(other.data[i], target_count); } - other.SetChildCardinality(target_count); + other.count = optional_idx(); } void DataChunk::Split(DataChunk &other, idx_t split_idx) { @@ -192,7 +211,7 @@ void DataChunk::Split(DataChunk &other, idx_t split_idx) { data.pop_back(); vector_caches.pop_back(); } - other.SetCardinality(*this); + count = other.count; } void DataChunk::Fuse(DataChunk &other) { @@ -202,6 +221,7 @@ void DataChunk::Fuse(DataChunk &other) { data.emplace_back(std::move(other.data[col_idx])); 
vector_caches.emplace_back(std::move(other.vector_caches[col_idx])); } + count = other.count; other.Destroy(); } @@ -209,12 +229,12 @@ void DataChunk::ReferenceColumns(DataChunk &other, const vector &colum D_ASSERT(ColumnCount() == column_ids.size()); Reset(); for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { - auto &other_col = other.data[column_ids[col_idx]]; + const auto &other_col = other.data[column_ids[col_idx]]; auto &this_col = data[col_idx]; D_ASSERT(other_col.GetType() == this_col.GetType()); this_col.Reference(other_col); } - SetCardinality(other.size()); + count = other.count; } void DataChunk::Append(const DataChunk &other, VectorAppendMode append_mode) { @@ -226,20 +246,29 @@ void DataChunk::Append(const DataChunk &other, const SelectionVector &sel, idx_t if (sel_count == 0) { return; } - idx_t new_size = size() + sel_count; if (ColumnCount() != other.ColumnCount()) { throw InternalException("Column counts of appending chunk doesn't match!"); } + idx_t current_size = size(); for (idx_t i = 0; i < ColumnCount(); i++) { // ensure data[i] has the chunk's current size so the append computes new_size = current + append_size - FlatVector::SetSize(data[i], size()); + if (data[i].size() != current_size) { + throw InternalException("DataChunk::Append size mismatch - child has count %d but chunk has size %d", + data[i].size(), current_size); + } if (sel.IsSet()) { data[i].Append(other.data[i], sel, sel_count, append_mode); } else { data[i].Append(other.data[i], other.size(), append_mode); } } - SetCardinality(new_size); + CheckCardinality(current_size + sel_count); + if (ColumnCount() == 0) { + // a column-less chunk has no child vectors to carry the cardinality - record it explicitly + count = current_size + sel_count; + } else { + count = optional_idx(); + } } void DataChunk::Flatten() { @@ -283,7 +312,7 @@ void DataChunk::Serialize(Serializer &serializer, bool compressed_serialization) // Reference the vector to avoid potentially mutating it 
during serialization Vector serialized_vector(data[i].GetType()); serialized_vector.Reference(data[i]); - serialized_vector.Serialize(object, row_count, compressed_serialization); + serialized_vector.Serialize(object, compressed_serialization); }); }); } @@ -311,11 +340,11 @@ void DataChunk::Deserialize(Deserializer &deserializer) { } void DataChunk::Slice(const SelectionVector &sel_vector, idx_t count_p) { - this->count = count_p; SelCache merge_cache; for (idx_t c = 0; c < ColumnCount(); c++) { data[c].Slice(sel_vector, count_p, merge_cache); } + count = count_p; } void DataChunk::Slice(const DataChunk &other, idx_t offset, idx_t end) { @@ -325,12 +354,11 @@ void DataChunk::Slice(const DataChunk &other, idx_t offset, idx_t end) { for (idx_t c = 0; c < other.ColumnCount(); c++) { data[c].Slice(other.data[c], offset, end); } - SetCardinality(end - offset); + count = end - offset; } void DataChunk::Slice(const DataChunk &other, const SelectionVector &sel, idx_t count_p, idx_t col_offset) { D_ASSERT(other.ColumnCount() <= col_offset + ColumnCount()); - this->count = count_p; SelCache merge_cache; for (idx_t c = 0; c < other.ColumnCount(); c++) { if (other.data[c].GetVectorType() == VectorType::DICTIONARY_VECTOR) { @@ -341,6 +369,7 @@ void DataChunk::Slice(const DataChunk &other, const SelectionVector &sel, idx_t data[col_offset + c].Slice(other.data[c], sel, count_p); } } + count = count_p; } void DataChunk::Slice(idx_t offset, idx_t slice_count) { @@ -365,19 +394,20 @@ unsafe_unique_array DataChunk::ToUnifiedFormat() { void DataChunk::Hash(Vector &result) { D_ASSERT(result.GetType().id() == LogicalType::HASH); - VectorOperations::Hash(data[0], result, size()); + auto hash_count = size(); + VectorOperations::Hash(data[0], result, hash_count); for (idx_t i = 1; i < ColumnCount(); i++) { - VectorOperations::CombineHash(result, data[i], size()); + VectorOperations::CombineHash(result, data[i], hash_count); } } void DataChunk::Hash(vector &column_ids, Vector &result) { 
D_ASSERT(result.GetType().id() == LogicalType::HASH); D_ASSERT(!column_ids.empty()); - - VectorOperations::Hash(data[column_ids[0]], result, size()); + auto hash_count = size(); + VectorOperations::Hash(data[column_ids[0]], result, hash_count); for (idx_t i = 1; i < column_ids.size(); i++) { - VectorOperations::CombineHash(result, data[column_ids[i]], size()); + VectorOperations::CombineHash(result, data[column_ids[i]], hash_count); } } @@ -449,12 +479,12 @@ void DataChunk::VerifyInternal(DebugVerificationMode mode, optional_ptr(source, result, count, [&](const string_t &wkb) { @@ -1091,7 +1092,7 @@ bool Geometry::FromBinary(Vector &source, Vector &result, idx_t count, bool stri return all_ok; } -void Geometry::ToBinary(Vector &source, Vector &result, idx_t count) { +void Geometry::ToBinary(const Vector &source, Vector &result) { // We are currently using WKB internally, so just copy as-is! result.Reinterpret(source); } @@ -1276,7 +1277,7 @@ uint32_t Geometry::GetExtent(const string_t &wkb, GeometryExtent &extent, bool & //---------------------------------------------------------------------------------------------------------------------- template -static void ToPoints(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToPoints(const Vector &source_vec, Vector &target_vec, idx_t row_count) { const auto geom_data = source_vec.Values(); auto vert_writer = FlatVector::Writer(target_vec, row_count); @@ -1301,7 +1302,7 @@ static void ToPoints(Vector &source_vec, Vector &target_vec, idx_t row_count) { } template -static void FromPoints(Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromPoints(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto vert_iter = source_vec.Values(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { @@ -1331,7 +1332,7 @@ static void FromPoints(Vector 
&source_vec, Vector &target_vec, idx_t row_count, } template -static void ToLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToLineStrings(const Vector &source_vec, Vector &target_vec, idx_t row_count) { auto geom_data = source_vec.Values(); auto list_writer = FlatVector::Writer>(target_vec, row_count); @@ -1356,7 +1357,7 @@ static void ToLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_coun } template -static void FromLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromLineStrings(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto line_iter = source_vec.Values>(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); @@ -1390,7 +1391,7 @@ static void FromLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_co } template -static void ToPolygons(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToPolygons(const Vector &source_vec, Vector &target_vec, idx_t row_count) { auto geom_data = source_vec.Values(); auto poly_writer = FlatVector::Writer>>(target_vec, row_count); @@ -1414,7 +1415,7 @@ static void ToPolygons(Vector &source_vec, Vector &target_vec, idx_t row_count) } template -static void FromPolygons(Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromPolygons(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto poly_iter = source_vec.Values>>(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); @@ -1453,7 +1454,7 @@ static void FromPolygons(Vector &source_vec, Vector &target_vec, idx_t row_count } template -static void ToMultiPoints(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToMultiPoints(const Vector &source_vec, Vector &target_vec, idx_t row_count) { auto geom_data = source_vec.Values(); auto mult_writer = 
FlatVector::Writer>(target_vec, row_count); for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { @@ -1474,7 +1475,7 @@ static void ToMultiPoints(Vector &source_vec, Vector &target_vec, idx_t row_coun } template -static void FromMultiPoints(Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromMultiPoints(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto mult_iter = source_vec.Values>(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); @@ -1514,7 +1515,7 @@ static void FromMultiPoints(Vector &source_vec, Vector &target_vec, idx_t row_co } template -static void ToMultiLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToMultiLineStrings(const Vector &source_vec, Vector &target_vec, idx_t row_count) { source_vec.Flatten(); const auto geom_data = FlatVector::GetData(source_vec); auto mult_writer = @@ -1539,7 +1540,7 @@ static void ToMultiLineStrings(Vector &source_vec, Vector &target_vec, idx_t row } template -static void FromMultiLineStrings(Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromMultiLineStrings(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto mult_iter = source_vec.Values>>(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); @@ -1585,7 +1586,7 @@ static void FromMultiLineStrings(Vector &source_vec, Vector &target_vec, idx_t r } template -static void ToMultiPolygons(Vector &source_vec, Vector &target_vec, idx_t row_count) { +static void ToMultiPolygons(const Vector &source_vec, Vector &target_vec, idx_t row_count) { source_vec.Flatten(); const auto geom_data = FlatVector::GetData(source_vec); auto mult_writer = FlatVector::Writer>>>( @@ -1613,7 +1614,7 @@ static void ToMultiPolygons(Vector &source_vec, Vector &target_vec, idx_t row_co } template -static void FromMultiPolygons(Vector 
&source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { +static void FromMultiPolygons(const Vector &source_vec, Vector &target_vec, idx_t row_count, idx_t result_offset) { auto mult_iter = source_vec.Values>>>(); auto result_data = FlatVector::Writer(target_vec, row_count, result_offset); @@ -1665,7 +1666,7 @@ static void FromMultiPolygons(Vector &source_vec, Vector &target_vec, idx_t row_ } template -static void ToVectorizedFormatInternal(Vector &source, Vector &target, idx_t count, GeometryType geom_type) { +static void ToVectorizedFormatInternal(const Vector &source, Vector &target, idx_t count, GeometryType geom_type) { switch (geom_type) { case GeometryType::POINT: ToPoints(source, target, count); @@ -1690,7 +1691,7 @@ static void ToVectorizedFormatInternal(Vector &source, Vector &target, idx_t cou } } -void Geometry::ToVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryType geom_type, +void Geometry::ToVectorizedFormat(const Vector &source, Vector &target, idx_t count, GeometryType geom_type, VertexType vert_type) { switch (vert_type) { case VertexType::XY: @@ -1711,7 +1712,7 @@ void Geometry::ToVectorizedFormat(Vector &source, Vector &target, idx_t count, G } template -static void FromVectorizedFormatInternal(Vector &source, Vector &target, idx_t count, GeometryType geom_type, +static void FromVectorizedFormatInternal(const Vector &source, Vector &target, idx_t count, GeometryType geom_type, idx_t result_offset) { switch (geom_type) { case GeometryType::POINT: @@ -1737,7 +1738,7 @@ static void FromVectorizedFormatInternal(Vector &source, Vector &target, idx_t c } } -void Geometry::FromVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryType geom_type, +void Geometry::FromVectorizedFormat(const Vector &source, Vector &target, idx_t count, GeometryType geom_type, VertexType vert_type, idx_t result_offset) { switch (vert_type) { case VertexType::XY: @@ -2044,7 +2045,7 @@ void 
Geometry::FromSpatialGeometry(const string_t &source, string_t &target, Vec target = blob; } -void Geometry::FromSpatialGeometry(Vector &source_vec, Vector &target_vec, idx_t count, idx_t result_offset) { +void Geometry::FromSpatialGeometry(const Vector &source_vec, Vector &target_vec, idx_t count, idx_t result_offset) { auto entries = source_vec.Values(); auto target_data = FlatVector::GetDataMutable(target_vec); @@ -2370,7 +2371,7 @@ void Geometry::ToSpatialGeometry(const string_t &source, string_t &target, Vecto target = blob; } -void Geometry::ToSpatialGeometry(Vector &source, Vector &target, idx_t count) { +void Geometry::ToSpatialGeometry(const Vector &source, Vector &target, idx_t count) { UnaryExecutor::Execute(source, target, count, [&](const string_t &source) { string_t result; ToSpatialGeometry(source, result, target); @@ -2392,7 +2393,7 @@ void Geometry::ToSpatialGeometry(const string_t &source, string &target) { ToSpatialGeometryConvert(state, writer); } -void Geometry::ToVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryStorageType type) { +void Geometry::ToVectorizedFormat(const Vector &source, Vector &target, idx_t count, GeometryStorageType type) { if (type == GeometryStorageType::SPATIAL) { ToSpatialGeometry(source, target, count); return; @@ -2402,7 +2403,7 @@ void Geometry::ToVectorizedFormat(Vector &source, Vector &target, idx_t count, G ToVectorizedFormat(source, target, count, types.first, types.second); } -void Geometry::FromVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryStorageType type, +void Geometry::FromVectorizedFormat(const Vector &source, Vector &target, idx_t count, GeometryStorageType type, idx_t result_offset) { if (type == GeometryStorageType::SPATIAL) { FromSpatialGeometry(source, target, count, result_offset); diff --git a/src/duckdb/src/common/types/geometry_crs.cpp b/src/duckdb/src/common/types/geometry_crs.cpp index dd22b5deb..8916561ae 100644 --- 
a/src/duckdb/src/common/types/geometry_crs.cpp +++ b/src/duckdb/src/common/types/geometry_crs.cpp @@ -156,10 +156,10 @@ class WKTParser { bool TryMatchText(string &result) { const auto start = pos; // First character must be alphabetic or underscore - if (pos < end && (isalpha(*pos) || *pos == '_')) { + if (pos < end && (StringUtil::CharacterIsAlpha(*pos) || *pos == '_')) { pos++; // Subsequent characters can also include digits - while (pos < end && (isalnum(*pos) || *pos == '_')) { + while (pos < end && (StringUtil::CharacterIsAlphaNumeric(*pos) || *pos == '_')) { pos++; } } @@ -286,7 +286,7 @@ class WKTParser { } void SkipWhitespace() { - while (pos < end && isspace(*pos)) { + while (pos < end && StringUtil::CharacterIsSpace(*pos)) { pos++; } } @@ -686,8 +686,8 @@ unique_ptr CoordinateReferenceSystem::TryConvert(Clie } auto &catalog = Catalog::GetSystemCatalog(context); - auto entry = catalog.GetEntry(context, CatalogType::COORDINATE_SYSTEM_ENTRY, DEFAULT_SCHEMA, - source_crs.GetIdentifier(), OnEntryNotFound::RETURN_NULL); + auto entry = catalog.GetEntry(context, CatalogType::COORDINATE_SYSTEM_ENTRY, Identifier::DefaultSchema(), + Identifier(source_crs.GetIdentifier()), OnEntryNotFound::RETURN_NULL); if (!entry) { return nullptr; } @@ -702,7 +702,7 @@ unique_ptr CoordinateReferenceSystem::TryConvert(Clie return make_uniq(crs_entry.authority + ":" + crs_entry.code); } case CoordinateReferenceSystemType::SRID: { - return make_uniq(crs_entry.name); + return make_uniq(crs_entry.name.GetIdentifierName()); } case CoordinateReferenceSystemType::PROJJSON: { if (crs_entry.projjson_definition.empty()) { diff --git a/src/duckdb/src/common/types/hyperloglog.cpp b/src/duckdb/src/common/types/hyperloglog.cpp index 071a470a2..5978cccf7 100644 --- a/src/duckdb/src/common/types/hyperloglog.cpp +++ b/src/duckdb/src/common/types/hyperloglog.cpp @@ -171,7 +171,7 @@ class HLLV1 { }; void HyperLogLog::Serialize(Serializer &serializer) const { - if (serializer.ShouldSerialize(3)) { + 
if (serializer.ShouldSerialize(StorageVersion::V1_2_0)) { serializer.WriteProperty(100, "type", HLLStorageType::HLL_V2); serializer.WriteProperty(101, "data", k, sizeof(k)); } else { diff --git a/src/duckdb/src/common/types/interval.cpp b/src/duckdb/src/common/types/interval.cpp index 13a1a068d..68f7373b4 100644 --- a/src/duckdb/src/common/types/interval.cpp +++ b/src/duckdb/src/common/types/interval.cpp @@ -102,34 +102,34 @@ bool Interval::FromCString(const char *str, idx_t len, interval_t &result, strin // colon: we are parsing a time goto interval_parse_time; } else { - if (pos == start_pos) { - return false; - } - // finished the number, parse it from the string - string_t nr_string(str + start_pos, UnsafeNumericCast(pos - start_pos)); - number = Cast::Operation(nr_string); - fraction = 0; - if (c == '.') { - idx_t frac_start = 0; - for (++pos; pos < len && StringUtil::CharacterIsDigit(str[pos]); ++pos) { - if (frac_start == 0) { - frac_start = pos; - } - } - - if (frac_start != 0) { - string_t frac_string(str + frac_start - 1, UnsafeNumericCast(pos - frac_start + 1)); - fraction = Cast::Operation(frac_string); + break; + } + } + { + if (pos == start_pos) { + return false; + } + string_t nr_string(str + start_pos, UnsafeNumericCast(pos - start_pos)); + number = Cast::Operation(nr_string); + fraction = 0; + if (pos < len && str[pos] == '.') { + idx_t frac_start = 0; + for (++pos; pos < len && StringUtil::CharacterIsDigit(str[pos]); ++pos) { + if (frac_start == 0) { + frac_start = pos; } } - if (negative) { - number = -number; - fraction = -fraction; + if (frac_start != 0) { + string_t frac_string(str + frac_start - 1, UnsafeNumericCast(pos - frac_start + 1)); + fraction = Cast::Operation(frac_string); } - goto interval_parse_identifier; } + if (negative) { + number = -number; + fraction = -fraction; + } + goto interval_parse_identifier; } - goto end_of_string; interval_parse_time : { // parse the remainder of the time as a Time type dtime_t time; diff --git 
a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp index a6d832dce..1f7ad5fff 100644 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -41,6 +41,12 @@ void PartitionedTupleData::InitializeAppendState(PartitionedTupleDataAppendState InitializeAppendStateInternal(state, properties); } +void PartitionedTupleData::ResetAppendState(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties) const { + // Default: fall back to full re-initialization (subclasses can override with a faster path) + InitializeAppendState(state, properties); +} + void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, DataChunk &input, const SelectionVector &append_sel, const idx_t append_count) { TupleDataCollection::ToUnifiedFormat(state.chunk_state, input); diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp index ebed39d09..d2bd09f93 100644 --- a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp @@ -52,6 +52,14 @@ void TupleDataAllocator::SetDestroyBufferUponUnpin() { } } +void TupleDataAllocator::Reset() { + // Mark existing blocks for cleanup when they are unpinned, then drop our references. + // Avoids copy-constructing a fresh TupleDataAllocator object just to clear the block lists. 
+ SetDestroyBufferUponUnpin(); + row_blocks.clear(); + heap_blocks.clear(); +} + void TupleDataAllocator::DestroyRowBlocks(const idx_t row_block_begin, const idx_t row_block_end) { if (row_block_begin == row_block_end) { return; @@ -327,6 +335,7 @@ void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDa InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, chunk_state.chunk_parts, sort_key_payload_state); + FlatVector::SetSize(chunk_state.row_locations, chunk.count); chunk_state.chunk_lock = &chunk.lock.get(); } @@ -414,7 +423,8 @@ void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_sta if (sort_key_payload_state) { D_ASSERT(!layout.IsSortKeyLayout()); // This must be the payload collection - lock_guard guard(part.lock); + // SortKeySetPayload() guards sort-key payload mutations with the sort-key chunk lock. + // Avoid nesting part.lock with that lock, which can introduce lock-order inversions. SortKeySetPayload(row_locations, offset, next, *sort_key_payload_state); } diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index e51b5fa4f..ccdde70db 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -321,10 +321,11 @@ void TupleDataCollection::AppendUnified(TupleDataPinState &pin_state, TupleDataC } Build(pin_state, chunk_state, 0, actual_append_count); + FlatVector::SetSize(chunk_state.row_locations, actual_append_count); Scatter(chunk_state, new_chunk, append_sel, actual_append_count); } -static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector &vector, const idx_t count) { +static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, const Vector &vector) { vector.ToUnifiedFormat(format.unified); format.original_sel = format.unified.sel; 
format.original_owned_sel.Initialize(format.unified.owned_sel); @@ -333,14 +334,13 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector auto &entries = StructVector::GetEntries(vector); D_ASSERT(format.children.size() == entries.size()); for (idx_t struct_col_idx = 0; struct_col_idx < entries.size(); struct_col_idx++) { - ToUnifiedFormatInternal(format.children[struct_col_idx], entries[struct_col_idx], count); + ToUnifiedFormatInternal(format.children[struct_col_idx], entries[struct_col_idx]); } break; } case PhysicalType::LIST: D_ASSERT(format.children.size() == 1); - ToUnifiedFormatInternal(format.children[0], ListVector::GetChildMutable(vector), - ListVector::GetListSize(vector)); + ToUnifiedFormatInternal(format.children[0], ListVector::GetChild(vector)); break; case PhysicalType::ARRAY: { D_ASSERT(format.children.size() == 1); @@ -363,7 +363,7 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector } format.unified.data = reinterpret_cast(format.array_list_entries.get()); - ToUnifiedFormatInternal(format.children[0], ArrayVector::GetChildMutable(vector), child_array_total_size); + ToUnifiedFormatInternal(format.children[0], ArrayVector::GetChild(vector)); break; } default: @@ -374,7 +374,7 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector void TupleDataCollection::ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk) { D_ASSERT(chunk_state.vector_data.size() >= chunk_state.column_ids.size()); // Needs InitializeAppend for (const auto &col_idx : chunk_state.column_ids) { - ToUnifiedFormatInternal(chunk_state.vector_data[col_idx], new_chunk.data[col_idx], new_chunk.size()); + ToUnifiedFormatInternal(chunk_state.vector_data[col_idx], new_chunk.data[col_idx]); } } @@ -523,10 +523,38 @@ void TupleDataCollection::Combine(unique_ptr other) { void TupleDataCollection::Reset() { count = 0; data_size = 0; - segments.clear(); + // Find the first non-null 
segment. + // Segments may be null if they were moved out by a prior Combine() call. + idx_t live_idx = segments.size(); // sentinel: no live segment + for (idx_t i = 0; i < segments.size(); i++) { + if (segments[i]) { + live_idx = i; + break; + } + } - // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared_ptr(*allocator); + if (live_idx < segments.size()) { + // At least one live segment found. Keep exactly one live segment and make the collection-level allocator + // match that segment allocator before resetting to preserve segment/allocator consistency. + if (live_idx > 0) { + segments[0] = std::move(segments[live_idx]); + } + segments.resize(1); + allocator = segments[0]->allocator; + segments[0]->Reset(); + allocator->Reset(); + D_ASSERT(segments[0]->allocator.get() == allocator.get()); + D_ASSERT(segments[0]->count == 0); + D_ASSERT(segments[0]->chunks.empty()); + D_ASSERT(segments[0]->chunk_parts.empty()); + } else { + // All segments were null (moved out by Combine). The old allocator is still shared by + // those moved-out segments and must keep its row_blocks intact. Create a fresh allocator. 
+ segments.clear(); + + // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily + allocator = make_shared_ptr(*allocator); + } } void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { @@ -568,8 +596,7 @@ void TupleDataCollection::InitializeScan(TupleDataScanState &state, TupleDataPin void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector column_ids, TupleDataPinProperties properties) const { - state.pin_state.row_handles.clear(); - state.pin_state.heap_handles.clear(); + state.Reset(); state.pin_state.properties = properties; state.segment_index = 0; state.chunk_index = 0; @@ -624,7 +651,7 @@ bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { if (!segments.empty()) { FinalizePinState(state.pin_state, *segments[segment_index_before]); } - result.SetCardinality(0); + result.SetChildCardinality(0); return false; } if (segment_index_before != DConstants::INVALID_INDEX && segment_index != segment_index_before) { @@ -644,7 +671,7 @@ bool TupleDataCollection::Scan(TupleDataParallelScanState &gstate, TupleDataLoca if (!segments.empty()) { FinalizePinState(lstate.pin_state, *segments[segment_index_before]); } - result.SetCardinality(0); + result.SetChildCardinality(0); return false; } } diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp index 6ded0ace0..a4bbc9c5f 100644 --- a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp @@ -114,7 +114,7 @@ void TupleDataCollection::ComputeHeapSizes(TupleDataChunkState &chunk_state, con std::fill_n(heap_sizes, append_count, 0); for (idx_t col_idx = 0; col_idx < new_chunk.ColumnCount(); col_idx++) { - auto &source_v = new_chunk.data[col_idx]; + const auto &source_v = new_chunk.data[col_idx]; auto &source_format = chunk_state.vector_data[col_idx]; 
ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, append_count); } diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp index be6901670..f1d77a6b1 100644 --- a/src/duckdb/src/common/types/row/tuple_data_segment.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_segment.cpp @@ -104,12 +104,28 @@ TupleDataSegment::TupleDataSegment(shared_ptr allocator_p) } TupleDataSegment::~TupleDataSegment() { - lock_guard guard(pinned_handles_lock); if (allocator) { allocator->SetDestroyBufferUponUnpin(); // Prevent blocks from being added to eviction queue } - pinned_row_handles.clear(); - pinned_heap_handles.clear(); + // Avoid triggering the mutex init slow-path (on macOS, first lock on a fresh mutex is expensive). + // In the destructor we are the sole owner, so checking without the lock is safe. + if (!pinned_row_handles.empty() || !pinned_heap_handles.empty()) { + lock_guard guard(pinned_handles_lock); + pinned_row_handles.clear(); + pinned_heap_handles.clear(); + } +} + +void TupleDataSegment::Reset() { + count = 0; + data_size = 0; + // Release any remaining pin handles before clearing. For UNPIN_AFTER_DONE, pinned_row_handles / + // pinned_heap_handles may be non-empty due to resize-only calls in CreateRowBlock / CreateHeapBlock + // (actual BufferHandles are in the TupleDataPinState, not here). Unpin() safely clears them. + Unpin(); + // chunk/chunk_part destructors are trivial (only POD fields), so clear is cheap. 
+ chunks.clear(); + chunk_parts.clear(); } idx_t TupleDataSegment::ChunkCount() const { diff --git a/src/duckdb/src/common/types/time.cpp b/src/duckdb/src/common/types/time.cpp index 9da75007d..971996da6 100644 --- a/src/duckdb/src/common/types/time.cpp +++ b/src/duckdb/src/common/types/time.cpp @@ -175,6 +175,11 @@ bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t & return false; } + // Interval parsing accepts larger hour counts, but we must limit ourselves to <= 24:00:00 + if (time_part.micros > Interval::MICROS_PER_DAY) { + return false; + } + // skip optional whitespace before offset while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { pos++; diff --git a/src/duckdb/src/common/types/timestamp.cpp b/src/duckdb/src/common/types/timestamp.cpp index 08101d8c8..46e463bdc 100644 --- a/src/duckdb/src/common/types/timestamp.cpp +++ b/src/duckdb/src/common/types/timestamp.cpp @@ -26,6 +26,21 @@ static inline T TemporalRound(T value, T scale) { return UnsafeNumericCast((value + negative) / scale - negative); } +static bool CanStartTimestampSuffix(char c) { + return c == 'Z' || c == '+' || c == '-' || StringUtil::CharacterIsSpace(c); +} + +static idx_t TimestampTimeLength(const char *str, idx_t len) { + idx_t suffix_pos = 0; + while (suffix_pos < len && StringUtil::CharacterIsSpace(str[suffix_pos])) { + suffix_pos++; + } + while (suffix_pos < len && !CanStartTimestampSuffix(str[suffix_pos])) { + suffix_pos++; + } + return suffix_pos; +} + // timestamp/datetime uses 64 bits, high 32 bits for date and low 32 bits for time // string format is YYYY-MM-DDThh:mm:ssZ // T may be a space @@ -63,10 +78,10 @@ TimestampCastResult Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, pos++; } idx_t time_pos = 0; - // TryConvertTime may recursively call us, so we opt for a stricter - // operation. Note that we can't pass strict== true here because we - // want to process any suffix. 
- if (!Time::TryConvertInterval(str + pos, len - pos, time_pos, time, false, nanos)) { + // TryConvertTime may recursively call us, so we use TryConvertInterval. + // Note that we can't pass strict== true here because we want to process any suffix. + auto time_len = TimestampTimeLength(str + pos, len - pos); + if (!Time::TryConvertInterval(str + pos, time_len, time_pos, time, false, nanos) || time_pos != time_len) { return TimestampCastResult::ERROR_INCORRECT_FORMAT; } // We parsed an interval, so make sure it is in range. @@ -121,7 +136,6 @@ TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, t optional_ptr nanos, bool strict) { string_t tz(nullptr, 0); bool has_offset = false; - // We don't understand TZ without an extension, so fail if one was provided. auto success = TryConvertTimestampTZ(str, len, result, use_offset, has_offset, tz, nanos); if (success != TimestampCastResult::SUCCESS) { return success; @@ -144,7 +158,14 @@ TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, t return TimestampCastResult::SUCCESS; } } - return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; + // We don't understand TZ without an extension, so fail if one was provided AND we were supposed to use it. 
+ if (use_offset && !tz.Empty()) { + return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; + } else if (strict && !tz.Empty()) { + return TimestampCastResult::STRICT_UTC; + } else { + return TimestampCastResult::SUCCESS; + } } bool Timestamp::TryFromTimestampNanos(timestamp_t input, int32_t nanos, timestamp_ns_t &result) { @@ -180,7 +201,7 @@ TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, t string Timestamp::FormatError(const string &str) { return StringUtil::Format("invalid timestamp field format: \"%s\", " - "expected format is (YYYY-MM-DD HH:MM:SS[.US][±HH[:MM[:SS]]| ZONE])", + "expected format is (YYYY-MM-DD HH:MM[:SS[.US]][±HH[:MM[:SS]]| ZONE])", str); } @@ -315,11 +336,8 @@ timestamp_t Timestamp::FromString(const string &str, bool use_offset) { } string Timestamp::ToString(timestamp_t timestamp) { - if (timestamp == timestamp_t::infinity()) { - return Date::PINF.str; - } - if (timestamp == timestamp_t::ninfinity()) { - return Date::NINF.str; + if (!timestamp.IsFinite()) { + return Date::ToInfinity(timestamp); } date_t date; diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index 819d9b1bf..c9d83abb6 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -152,6 +152,9 @@ Value::Value(double val) : type_(LogicalType::DOUBLE), is_null(false) { value_.double_ = val; } +Value::Value(const Identifier &val) : Value(val.GetIdentifierName()) { +} + Value::Value(const char *val) : Value(val ? 
string(val) : string()) { } @@ -769,12 +772,6 @@ Value Value::TIMESTAMP(int32_t year, int32_t month, int32_t day, int32_t hour, i return val; } -Value Value::AGGREGATE_STATE(const LogicalType &type, vector underlying_struct_values) { - // We just wrap the STRUCT value as the Vector's values are the same underneath, and also the type is being injected - // We do it for consistency where all LogicalType has its Value constructor defined - return STRUCT(type, std::move(underlying_struct_values)); -} - Value Value::STRUCT(const LogicalType &type, vector struct_values) { Value result; auto child_types = StructType::GetChildTypes(type); @@ -855,8 +852,8 @@ Value Value::MAP(const LogicalType &key_type, const LogicalType &value_type, vec struct_types.reserve(2); new_children.reserve(2); - struct_types.push_back(make_pair("key", key_type)); - struct_types.push_back(make_pair("value", value_type)); + struct_types.emplace_back(make_pair("key", key_type)); + struct_types.emplace_back(make_pair("value", value_type)); auto key = keys[i].DefaultCastAs(key_type); MapKeyCheck(unique_keys, key); @@ -976,7 +973,7 @@ Value Value::GEOMETRY(const_data_ptr_t data, idx_t len) { Value Value::TYPE(const LogicalType &type) { MemoryStream stream; SerializationOptions options; - options.serialization_compatibility = SerializationCompatibility::Latest(); + options.storage_compatibility = StorageCompatibility::Latest(); BinarySerializer::Serialize(type, stream, options); auto data_ptr = const_char_ptr_cast(stream.GetData()); auto data_len = stream.GetPosition(); @@ -1687,7 +1684,7 @@ hash_t Value::Hash() const { } Vector input(*this, count_t(1)); Vector result(LogicalType::HASH, 1); - VectorOperations::Hash(input, result, 1); + VectorOperations::Hash(input, result); if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { return *ConstantVector::GetData(result); @@ -1802,6 +1799,24 @@ string Value::ToSQLString() const { ret += "]"; return ret; } + case LogicalTypeId::MAP: { + // A bare `MAP 
{...}` literal infers its element types from the entries + // (and `MAP {}` infers MAP(INTEGER, INTEGER)), so it does not faithfully + // round-trip on its own. Append an explicit cast to the real type + auto &entries = MapValue::GetChildren(*this); + string ret = "MAP {"; + for (idx_t i = 0; i < entries.size(); i++) { + auto &kv = StructValue::GetChildren(entries[i]); + if (i > 0) { + ret += ", "; + } + ret += kv[0].ToSQLString(); + ret += ": "; + ret += kv[1].ToSQLString(); + } + ret += "}::" + type_.ToString(); + return ret; + } case LogicalTypeId::UNION: { string ret = "union_value("; auto union_tag = UnionValue::GetTag(*this); @@ -2189,7 +2204,7 @@ void Value::SerializeChildren(Serializer &serializer, const vector &child } void Value::SerializeInternal(Serializer &serializer, bool serialize_type) const { - if (serialize_type || !serializer.ShouldSerialize(4)) { + if (serialize_type || !serializer.ShouldSerialize(StorageVersion::V1_2_0)) { // only the root value needs to serialize its type // for forwards compatibility reasons, we also serialize the type always when targeting versions < v1.2.0 serializer.WriteProperty(100, "type", type_); @@ -2256,7 +2271,7 @@ void Value::SerializeInternal(Serializer &serializer, bool serialize_type) const auto blob_str = Blob::ToString(StringValue::Get(*this)); serializer.WriteProperty(102, "value", blob_str); } else if (type_.id() == LogicalTypeId::GEOMETRY) { - if (!serializer.ShouldSerialize(7)) { + if (!serializer.ShouldSerialize(StorageVersion::V1_5_0)) { // Write as old-style SPATIAL format string blob; Geometry::ToSpatialGeometry(StringValue::Get(*this), blob); diff --git a/src/duckdb/src/common/types/variant/variant_value.cpp b/src/duckdb/src/common/types/variant/variant_value.cpp index 489236922..b2766a4bf 100644 --- a/src/duckdb/src/common/types/variant/variant_value.cpp +++ b/src/duckdb/src/common/types/variant/variant_value.cpp @@ -690,7 +690,7 @@ void VariantValue::ToVARIANT(vector &input, Vector &result) { 
analyze_offsets.Initialize( Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); - analyze_offsets.SetCardinality(count); + analyze_offsets.SetChildCardinality(count); variant::InitializeOffsets(analyze_offsets, count); for (idx_t i = 0; i < count; i++) { @@ -713,7 +713,7 @@ void VariantValue::ToVARIANT(vector &input, Vector &result) { conversion_offsets.Initialize( Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); - conversion_offsets.SetCardinality(count); + conversion_offsets.SetChildCardinality(count); variant::InitializeOffsets(conversion_offsets, count); VariantVectorData variant_data(result); diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index 591b0823c..8f8c7ec78 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -2,36 +2,23 @@ #include "duckdb/common/vector/constant_vector.hpp" #include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/vector/fsst_vector.hpp" #include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/map_vector.hpp" #include "duckdb/common/vector/sequence_vector.hpp" #include "duckdb/common/vector/shredded_vector.hpp" #include "duckdb/common/vector/string_vector.hpp" -#include "duckdb/common/vector/union_vector.hpp" -#include "duckdb/common/vector/variant_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" #include "duckdb/common/types/vector.hpp" - #include "duckdb/common/assert.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/fsst.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/type_visitor.hpp" -#include "duckdb/common/types/bit.hpp" #include 
"duckdb/common/types/null_value.hpp" #include "duckdb/common/types/sel_cache.hpp" #include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/common/types/bignum.hpp" -#include "duckdb/function/scalar/variant_utils.hpp" #include "duckdb/common/types/vector_cache.hpp" -#include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" -#include "duckdb/common/types/uuid.hpp" namespace duckdb { @@ -131,21 +118,70 @@ void Vector::ReferenceAndSetType(const Vector &other) { Reference(other); } +void CheckTypeIsReinterpretable(const LogicalType &a, const LogicalType &b) { + if (DBConfigOptions::global_verification_mode != DebugVerificationMode::VERIFY_VECTORS) { + return; + } + bool left_is_nested = a.IsNested(); + bool right_is_nested = b.IsNested(); + if (left_is_nested != right_is_nested) { + throw InternalException("Vector::Reinterpret (%s -> %s) - nested mismatch in reinterpret - either both need to " + "be nested or neither should be nested", + a, b); + } + auto left_internal = a.InternalType(); + auto right_internal = b.InternalType(); + if (!left_is_nested) { + // non-nested types - type size should be identical + if (GetTypeIdSize(left_internal) != GetTypeIdSize(right_internal)) { + throw InternalException( + "Vector::Reinterpret (%s -> %s) - attempting to reinterpret between types with different type sizes", a, + b); + } + return; + } + if (left_internal != right_internal) { + throw InternalException( + "Vector::Reinterpret (%s -> %s) - attempting to reinterpret between different nested types", a, b); + } + // recurse into children + switch (left_internal) { + case PhysicalType::STRUCT: { + auto &left_child_types = StructType::GetChildTypes(a); + auto &right_child_types = StructType::GetChildTypes(b); + if (left_child_types.size() != right_child_types.size()) { + throw InternalException( + "Vector::Reinterpret (%s -> %s) - 
attempting to reinterpret between struct types of different sizes", a, + b); + } + for (idx_t child_idx = 0; child_idx < left_child_types.size(); ++child_idx) { + CheckTypeIsReinterpretable(left_child_types[child_idx].second, right_child_types[child_idx].second); + } + break; + } + case PhysicalType::LIST: { + auto &left_child_type = ListType::GetChildType(a); + auto &right_child_type = ListType::GetChildType(b); + CheckTypeIsReinterpretable(left_child_type, right_child_type); + break; + } + case PhysicalType::ARRAY: { + auto &left_child_type = ArrayType::GetChildType(a); + auto &right_child_type = ArrayType::GetChildType(b); + CheckTypeIsReinterpretable(left_child_type, right_child_type); + break; + } + default: + throw InternalException("Unsupported nested type in CheckTypeIsReinterpretable"); + } +} + void Vector::Reinterpret(const Vector &other) { auto &this_type = GetType(); auto &other_type = other.GetType(); -#ifdef DEBUG - auto type_is_same = other_type == this_type; - bool this_is_nested = this_type.IsNested(); - bool other_is_nested = other_type.IsNested(); - - bool not_nested = this_is_nested == false && other_is_nested == false; - bool type_size_equal = GetTypeIdSize(this_type.InternalType()) == GetTypeIdSize(other_type.InternalType()); - //! Either the types are completely identical, or they are not nested and their physical type size is the same - //! The reason nested types are not allowed is because copying the auxiliary buffer does not happen recursively - //! 
e.g DOUBLE[] to BIGINT[], the type of the LIST would say BIGINT but the child Vector says DOUBLE - D_ASSERT((not_nested && type_size_equal) || type_is_same); -#endif + if (DBConfigOptions::global_verification_mode == DebugVerificationMode::VERIFY_VECTORS) { + CheckTypeIsReinterpretable(this_type, other_type); + } ConstReference(other); if (GetVectorType() == VectorType::DICTIONARY_VECTOR && other_type != this_type) { Vector new_vector(this_type, nullptr); @@ -469,13 +505,14 @@ void Vector::Shred(Vector &shredded_data, idx_t capacity) { } // FIXME: This should ideally be const -void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_serialization) { +void Vector::Serialize(Serializer &serializer, bool compressed_serialization) { auto &logical_type = GetType(); + auto count = size(); UnifiedVectorFormat vdata; // serialize compressed vectors to save space, but skip this if serializing into older versions - if (!serializer.ShouldSerialize(5)) { + if (!serializer.ShouldSerialize(StorageVersion::V1_3_0)) { compressed_serialization = false; } if (compressed_serialization) { @@ -508,14 +545,14 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri serializer.WriteProperty(90, "vector_type", VectorType::DICTIONARY_VECTOR); serializer.WriteProperty(91, "sel_vector", sel_data, sizeof(sel_t) * count); serializer.WriteProperty(92, "dict_count", used_count); - return dict.Serialize(serializer, used_count, false); + return dict.Serialize(serializer, false); } } } else if (vtype == VectorType::CONSTANT_VECTOR && count >= 1) { serializer.WriteProperty(90, "vector_type", VectorType::CONSTANT_VECTOR); // Resize to 1 so that size() == count == 1 during the recursive call, then restore FlatVector::SetSize(*this, 1); - Vector::Serialize(serializer, 1, false); // just serialize one value + Vector::Serialize(serializer, false); // just serialize one value FlatVector::SetSize(*this, count); return; } else if (vtype == 
VectorType::SEQUENCE_VECTOR) { @@ -550,13 +587,13 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri // constant size type: simple copy idx_t write_size = GetTypeIdSize(logical_type.InternalType()) * count; auto ptr = make_unsafe_uniq_array_uninitialized(write_size); - VectorOperations::WriteToStorage(*this, count, ptr.get()); + VectorOperations::WriteToStorage(*this, ptr.get()); serializer.WriteProperty(102, "data", ptr.get(), write_size); } else if (logical_type.id() == LogicalTypeId::GEOMETRY) { auto geoms = UnifiedVectorFormat::GetData(vdata); // Are we targeting an older serialization version? - if (!serializer.ShouldSerialize(7)) { + if (!serializer.ShouldSerialize(StorageVersion::V1_5_0)) { // Serialize data as old-style SPATIAL format string blob; serializer.WriteList(102, "data", count, [&](Serializer::List &list, idx_t i) { @@ -581,7 +618,7 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri case PhysicalType::VARCHAR: { auto strings = UnifiedVectorFormat::GetData(vdata); // new way to serialize strings, two blobs, first lengths, then string bytes - if (serializer.ShouldSerialize(8)) { + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { // we write all the lengths whether the string is null or not. lets not pull a parquet. 
auto length_data_length = sizeof(uint32_t) * count; auto length_data = make_unsafe_uniq_array_uninitialized(length_data_length); @@ -628,8 +665,7 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri // Serialize entries as a list serializer.WriteList(103, "children", entries.size(), [&](Serializer::List &list, idx_t i) { - list.WriteObject( - [&](Serializer &object) { entries[i].Serialize(object, count, compressed_serialization); }); + list.WriteObject([&](Serializer &object) { entries[i].Serialize(object, compressed_serialization); }); }); break; } @@ -658,9 +694,8 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri object.WriteProperty(101, "length", entries[i].length); }); }); - serializer.WriteObject(106, "child", [&](Serializer &object) { - child.Serialize(object, list_size, compressed_serialization); - }); + serializer.WriteObject(106, "child", + [&](Serializer &object) { child.Serialize(object, compressed_serialization); }); break; } case PhysicalType::ARRAY: { @@ -669,11 +704,9 @@ void Vector::Serialize(Serializer &serializer, idx_t count, bool compressed_seri auto &child = ArrayVector::GetChildMutable(serialized_vector); auto array_size = ArrayType::GetSize(serialized_vector.GetType()); - auto child_size = array_size * count; serializer.WriteProperty(103, "array_size", array_size); - serializer.WriteObject(104, "child", [&](Serializer &object) { - child.Serialize(object, child_size, compressed_serialization); - }); + serializer.WriteObject(104, "child", + [&](Serializer &object) { child.Serialize(object, compressed_serialization); }); break; } default: @@ -850,6 +883,18 @@ void Vector::SetVectorType(VectorType new_vector_type) { } } +void Vector::FlattenAndSetConstant() { + if (GetVectorType() != VectorType::FLAT_VECTOR && GetVectorType() != VectorType::CONSTANT_VECTOR) { + Flatten(); + } + // Struct buffers propagate vector type changes to their children. 
A flat or constant struct vector can still + // contain non-flat descendants, e.g. after slicing a list of structs, so normalize only those descendants first. + if (GetType().InternalType() == PhysicalType::STRUCT) { + BufferMutable().Cast().PrepareChildrenForSetConstant(); + } + SetVectorType(VectorType::CONSTANT_VECTOR); +} + void Vector::Verify(idx_t) const { Verify(); } @@ -874,7 +919,11 @@ void Vector::Verify(const SelectionVector &sel, idx_t count) const { buffer->Verify(GetType(), sel, count); } -void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { +void Vector::DebugTransformToDictionary(Vector &vector) { + const auto count = vector.size(); + if (count == 0) { + return; + } if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { // only supported for flat vectors currently return; @@ -915,13 +964,14 @@ void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { vector.Verify(); } -void Vector::DebugShuffleNestedVector(Vector &vector, idx_t count) { +void Vector::DebugShuffleNestedVector(Vector &vector) { + const auto count = vector.size(); switch (vector.GetType().id()) { case LogicalTypeId::STRUCT: { auto &entries = StructVector::GetEntries(vector); // recurse into child elements for (auto &entry : entries) { - Vector::DebugShuffleNestedVector(entry, count); + Vector::DebugShuffleNestedVector(entry); } break; } @@ -961,7 +1011,7 @@ void Vector::DebugShuffleNestedVector(Vector &vector, idx_t count) { ListVector::SetListSize(vector, child_count); // recurse into child elements - Vector::DebugShuffleNestedVector(child_vector, child_count); + Vector::DebugShuffleNestedVector(child_vector); break; } default: diff --git a/src/duckdb/src/common/types/vector_buffer.cpp b/src/duckdb/src/common/types/vector_buffer.cpp index dcd515289..b4a8ecaa0 100644 --- a/src/duckdb/src/common/types/vector_buffer.cpp +++ b/src/duckdb/src/common/types/vector_buffer.cpp @@ -90,6 +90,9 @@ void VectorBuffer::Verify(const LogicalType &type) const { } void 
VectorBuffer::Verify(const LogicalType &type, const SelectionVector &sel, idx_t count) const { + if (count == 0) { + return; + } if (vector_type == VectorType::CONSTANT_VECTOR) { SelectionVector owned_sel; VerifyInternal(type, *ConstantVector::ZeroSelectionVector(1ULL, owned_sel), 1ULL); diff --git a/src/duckdb/src/common/util/util_parsed_expression.cpp b/src/duckdb/src/common/util/util_parsed_expression.cpp new file mode 100644 index 000000000..724b834fb --- /dev/null +++ b/src/duckdb/src/common/util/util_parsed_expression.cpp @@ -0,0 +1,1006 @@ +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_util.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/expression_util.hpp" +#include "duckdb/common/types/hash.hpp" +#include "duckdb/common/value_operations/value_operations.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" + +namespace duckdb { + +bool ParsedExpression::Equals(const BaseExpression &other) const { + if (!BaseExpression::Equals(other)) { + return false; + } + return Equals(static_cast(other)); +} + +bool ParsedExpression::Equals(const ParsedExpression &other) const { + auto &other_p = other; + if (expression_class != other_p.expression_class) { + return false; + } + if (type != other_p.type) { + return false; + } + return true; +} + +hash_t ParsedExpression::Hash() const { + hash_t hash = duckdb::Hash(static_cast(expression_class)); + hash = CombineHash(hash, duckdb::Hash(static_cast(type))); + ParsedExpressionIterator::EnumerateChildren( + *this, [&](const ParsedExpression &child) { hash = CombineHash(hash, child.Hash()); }); + return hash; +} + +ConstChildrenView ParsedExpression::Children() 
const { + ConstChildrenView result; + switch (GetExpressionClass()) { + case ExpressionClass::BETWEEN: { + auto &cast_expr = Cast(); + result.Append(cast_expr.Input()); + result.Append(cast_expr.LowerBound()); + result.Append(cast_expr.UpperBound()); + break; + } + case ExpressionClass::CASE: { + auto &cast_expr = Cast(); + for (auto &check : cast_expr.CaseChecks()) { + result.Append(*check.when_expr); + result.Append(*check.then_expr); + } + result.Append(cast_expr.Else()); + break; + } + case ExpressionClass::CAST: { + auto &cast_expr = Cast(); + result.Append(cast_expr.Child()); + break; + } + case ExpressionClass::COLLATE: { + auto &cast_expr = Cast(); + result.Append(cast_expr.Child()); + break; + } + case ExpressionClass::COMPARISON: { + auto &cast_expr = Cast(); + result.Append(cast_expr.Left()); + result.Append(cast_expr.Right()); + break; + } + case ExpressionClass::CONJUNCTION: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildren()) { + result.Append(*child); + } + break; + } + case ExpressionClass::FUNCTION: { + auto &cast_expr = Cast(); + if (cast_expr.Filter()) { + result.Append(*cast_expr.Filter()); + } + if (cast_expr.OrderBy()) { + for (auto &order : cast_expr.OrderBy()->orders) { + result.Append(*order.expression); + } + } + for (auto &arg : cast_expr.GetArguments()) { + result.Append(arg.GetExpression()); + } + break; + } + case ExpressionClass::LAMBDA: { + auto &cast_expr = Cast(); + result.Append(cast_expr.Left()); + result.Append(cast_expr.Right()); + break; + } + case ExpressionClass::OPERATOR: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildren()) { + result.Append(*child); + } + break; + } + case ExpressionClass::STAR: { + auto &cast_expr = Cast(); + for (auto &item : cast_expr.ReplaceList()) { + result.Append(*item.second); + } + if (cast_expr.Expression()) { + result.Append(*cast_expr.Expression()); + } + break; + } + case ExpressionClass::SUBQUERY: { + auto &cast_expr = Cast(); + if 
(cast_expr.GetChild()) { + result.Append(*cast_expr.GetChild()); + } + break; + } + case ExpressionClass::WINDOW: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.Partitions()) { + result.Append(*child); + } + for (auto &order : cast_expr.OrderBy()) { + result.Append(*order.expression); + } + if (cast_expr.StartExpr()) { + result.Append(*cast_expr.StartExpr()); + } + if (cast_expr.EndExpr()) { + result.Append(*cast_expr.EndExpr()); + } + if (cast_expr.Filter()) { + result.Append(*cast_expr.Filter()); + } + for (auto &order : cast_expr.ArgOrders()) { + result.Append(*order.expression); + } + for (auto &arg : cast_expr.GetArguments()) { + result.Append(arg.GetExpression()); + } + break; + } + case ExpressionClass::TYPE: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildren()) { + result.Append(*child); + } + break; + } + case ExpressionClass::BOUND_EXPRESSION: + case ExpressionClass::COLUMN_REF: + case ExpressionClass::LAMBDA_REF: + case ExpressionClass::CONSTANT: + case ExpressionClass::DEFAULT: + case ExpressionClass::PARAMETER: + case ExpressionClass::POSITIONAL_REFERENCE: + // these node types have no children + break; + default: + throw NotImplementedException("Unimplemented expression class"); + } + return result; +} + +ChildrenView ParsedExpression::ChildrenMutable() { + ChildrenView result; + switch (GetExpressionClass()) { + case ExpressionClass::BETWEEN: { + auto &cast_expr = Cast(); + result.Append(cast_expr.InputMutable()); + result.Append(cast_expr.LowerBoundMutable()); + result.Append(cast_expr.UpperBoundMutable()); + break; + } + case ExpressionClass::CASE: { + auto &cast_expr = Cast(); + for (auto &check : cast_expr.CaseChecksMutable()) { + result.Append(check.when_expr); + result.Append(check.then_expr); + } + result.Append(cast_expr.ElseMutable()); + break; + } + case ExpressionClass::CAST: { + auto &cast_expr = Cast(); + result.Append(cast_expr.ChildMutable()); + break; + } + case ExpressionClass::COLLATE: { + auto 
&cast_expr = Cast(); + result.Append(cast_expr.ChildMutable()); + break; + } + case ExpressionClass::COMPARISON: { + auto &cast_expr = Cast(); + result.Append(cast_expr.LeftMutable()); + result.Append(cast_expr.RightMutable()); + break; + } + case ExpressionClass::CONJUNCTION: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildrenMutable()) { + result.Append(child); + } + break; + } + case ExpressionClass::FUNCTION: { + auto &cast_expr = Cast(); + if (cast_expr.FilterMutable()) { + result.Append(cast_expr.FilterMutable()); + } + if (cast_expr.OrderByMutable()) { + for (auto &order : cast_expr.OrderByMutable()->orders) { + result.Append(order.expression); + } + } + for (auto &arg : cast_expr.GetArgumentsMutable()) { + result.Append(arg.GetExpressionMutable()); + } + break; + } + case ExpressionClass::LAMBDA: { + auto &cast_expr = Cast(); + result.Append(cast_expr.LeftMutable()); + result.Append(cast_expr.RightMutable()); + break; + } + case ExpressionClass::OPERATOR: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildrenMutable()) { + result.Append(child); + } + break; + } + case ExpressionClass::STAR: { + auto &cast_expr = Cast(); + for (auto &item : cast_expr.ReplaceListMutable()) { + result.Append(item.second); + } + if (cast_expr.ExpressionMutable()) { + result.Append(cast_expr.ExpressionMutable()); + } + break; + } + case ExpressionClass::SUBQUERY: { + auto &cast_expr = Cast(); + if (cast_expr.GetChildMutable()) { + result.Append(cast_expr.GetChildMutable()); + } + break; + } + case ExpressionClass::WINDOW: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.PartitionsMutable()) { + result.Append(child); + } + for (auto &order : cast_expr.OrderByMutable()) { + result.Append(order.expression); + } + if (cast_expr.StartExprMutable()) { + result.Append(cast_expr.StartExprMutable()); + } + if (cast_expr.EndExprMutable()) { + result.Append(cast_expr.EndExprMutable()); + } + if (cast_expr.FilterMutable()) { + 
result.Append(cast_expr.FilterMutable()); + } + for (auto &order : cast_expr.ArgOrdersMutable()) { + result.Append(order.expression); + } + for (auto &arg : cast_expr.GetArgumentsMutable()) { + result.Append(arg.GetExpressionMutable()); + } + break; + } + case ExpressionClass::TYPE: { + auto &cast_expr = Cast(); + for (auto &child : cast_expr.GetChildren()) { + result.Append(child); + } + break; + } + case ExpressionClass::BOUND_EXPRESSION: + case ExpressionClass::COLUMN_REF: + case ExpressionClass::LAMBDA_REF: + case ExpressionClass::CONSTANT: + case ExpressionClass::DEFAULT: + case ExpressionClass::PARAMETER: + case ExpressionClass::POSITIONAL_REFERENCE: + // these node types have no children + break; + default: + throw NotImplementedException("Unimplemented expression class"); + } + return result; +} + +void ParsedExpression::CopyBase(const ParsedExpression &other) { + expression_class = other.expression_class; + type = other.type; + alias = other.alias; + query_location = other.query_location; +} + +bool BetweenExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::Equals(input, other_p.input)) { + return false; + } + if (!ParsedExpression::Equals(lower, other_p.lower)) { + return false; + } + if (!ParsedExpression::Equals(upper, other_p.upper)) { + return false; + } + return true; +} + + +unique_ptr BetweenExpression::Copy() const { + auto copy = duckdb::unique_ptr(new BetweenExpression()); + copy->input = input ? input->Copy() : nullptr; + copy->lower = lower ? lower->Copy() : nullptr; + copy->upper = upper ? 
upper->Copy() : nullptr; + copy->CopyBase(*this); + return std::move(copy); +} + +bool CaseExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (case_checks.size() != other_p.case_checks.size()) { + return false; + } + for (idx_t i = 0; i < case_checks.size(); i++) { + if (!case_checks[i].when_expr->Equals(*other_p.case_checks[i].when_expr)) { + return false; + } + if (!case_checks[i].then_expr->Equals(*other_p.case_checks[i].then_expr)) { + return false; + } + } + if (!ParsedExpression::Equals(else_expr, other_p.else_expr)) { + return false; + } + return true; +} + + +unique_ptr CaseExpression::Copy() const { + auto copy = duckdb::unique_ptr(new CaseExpression()); + for (auto &check : case_checks) { + CaseCheck new_check; + new_check.when_expr = check.when_expr->Copy(); + new_check.then_expr = check.then_expr->Copy(); + copy->case_checks.push_back(std::move(new_check)); + } + copy->else_expr = else_expr ? else_expr->Copy() : nullptr; + copy->CopyBase(*this); + return std::move(copy); +} + +bool CastExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::Equals(child, other_p.child)) { + return false; + } + if (cast_type != other_p.cast_type) { + return false; + } + if (try_cast != other_p.try_cast) { + return false; + } + return true; +} + +hash_t CastExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, cast_type.Hash()); + hash = CombineHash(hash, duckdb::Hash(try_cast)); + return hash; +} + +unique_ptr CastExpression::Copy() const { + auto copy = duckdb::unique_ptr(new CastExpression()); + copy->child = child ? 
child->Copy() : nullptr; + copy->cast_type = cast_type; + copy->try_cast = try_cast; + copy->CopyBase(*this); + return std::move(copy); +} + +bool CollateExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::Equals(child, other_p.child)) { + return false; + } + if (collation != other_p.collation) { + return false; + } + return true; +} + +hash_t CollateExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, duckdb::Hash(collation.c_str())); + return hash; +} + +unique_ptr CollateExpression::Copy() const { + auto copy = duckdb::unique_ptr(new CollateExpression()); + copy->child = child ? child->Copy() : nullptr; + copy->collation = collation; + copy->CopyBase(*this); + return std::move(copy); +} + +bool ColumnRefExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (column_names.size() != other_p.column_names.size()) { + return false; + } + for (idx_t i = 0; i < column_names.size(); i++) { + if (column_names[i] != other_p.column_names[i]) { + return false; + } + } + return true; +} + +hash_t ColumnRefExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + for (auto &s : column_names) { + hash = CombineHash(hash, s.Hash()); + } + return hash; +} + +unique_ptr ColumnRefExpression::Copy() const { + auto copy = duckdb::unique_ptr(new ColumnRefExpression()); + copy->column_names = column_names; + copy->CopyBase(*this); + return std::move(copy); +} + +bool LambdaRefExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (lambda_idx != other_p.lambda_idx) { + return false; + } + if (column_name != other_p.column_name) { + return false; + } + return true; +} + +hash_t LambdaRefExpression::Hash() 
const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, duckdb::Hash(lambda_idx)); + hash = CombineHash(hash, column_name.Hash()); + return hash; +} + +unique_ptr LambdaRefExpression::Copy() const { + throw InternalException("lambda reference expressions are transient, Copy should never be called"); +} + +bool ComparisonExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::Equals(left, other_p.left)) { + return false; + } + if (!ParsedExpression::Equals(right, other_p.right)) { + return false; + } + return true; +} + + +unique_ptr ComparisonExpression::Copy() const { + auto copy = duckdb::unique_ptr(new ComparisonExpression()); + copy->left = left ? left->Copy() : nullptr; + copy->right = right ? right->Copy() : nullptr; + copy->CopyBase(*this); + return std::move(copy); +} + +bool ConjunctionExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::ListEquals(children, other_p.children)) { + return false; + } + return true; +} + + +unique_ptr ConjunctionExpression::Copy() const { + auto copy = duckdb::unique_ptr(new ConjunctionExpression()); + for (auto &child : children) { + copy->children.push_back(child->Copy()); + } + copy->CopyBase(*this); + return std::move(copy); +} + +bool ConstantExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (value.type() != other_p.value.type() || ValueOperations::DistinctFrom(value, other_p.value)) { + return false; + } + return true; +} + +hash_t ConstantExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, value.Hash()); + return hash; +} + +unique_ptr ConstantExpression::Copy() const { + auto copy = duckdb::unique_ptr(new 
ConstantExpression()); + copy->value = value; + copy->CopyBase(*this); + return std::move(copy); +} + +bool DefaultExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + return true; +} + + +unique_ptr DefaultExpression::Copy() const { + auto copy = duckdb::unique_ptr(new DefaultExpression()); + copy->CopyBase(*this); + return std::move(copy); +} + +bool FunctionExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (function_name != other_p.function_name) { + return false; + } + if (schema != other_p.schema) { + return false; + } + if (!ParsedExpression::Equals(filter, other_p.filter)) { + return false; + } + if (!OrderModifier::Equals(order_bys, other_p.order_bys)) { + return false; + } + if (distinct != other_p.distinct) { + return false; + } + if (export_state != other_p.export_state) { + return false; + } + if (catalog != other_p.catalog) { + return false; + } + if (arguments.size() != other_p.arguments.size()) { + return false; + } + for (idx_t i = 0; i < arguments.size(); i++) { + if (!arguments[i].Equals(other_p.arguments[i])) { + return false; + } + } + return true; +} + +hash_t FunctionExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, function_name.Hash()); + hash = CombineHash(hash, schema.Hash()); + hash = CombineHash(hash, duckdb::Hash(distinct)); + hash = CombineHash(hash, duckdb::Hash(export_state)); + hash = CombineHash(hash, catalog.Hash()); + return hash; +} + +unique_ptr FunctionExpression::Copy() const { + auto copy = duckdb::unique_ptr(new FunctionExpression()); + copy->is_legacy_function_call = is_legacy_function_call; + copy->function_name = function_name; + copy->schema = schema; + copy->filter = filter ? filter->Copy() : nullptr; + copy->order_bys = order_bys ? 
unique_ptr_cast(order_bys->Copy()) : nullptr; + copy->distinct = distinct; + copy->is_operator = is_operator; + copy->export_state = export_state; + copy->catalog = catalog; + for (auto &arg : arguments) { + copy->arguments.emplace_back(arg.Copy()); + } + copy->CopyBase(*this); + return std::move(copy); +} + +bool LambdaExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::Equals(lhs, other_p.lhs)) { + return false; + } + if (!ParsedExpression::Equals(expr, other_p.expr)) { + return false; + } + return true; +} + + +unique_ptr LambdaExpression::Copy() const { + auto copy = duckdb::unique_ptr(new LambdaExpression()); + copy->lhs = lhs ? lhs->Copy() : nullptr; + copy->expr = expr ? expr->Copy() : nullptr; + copy->syntax_type = syntax_type; + copy->CopyBase(*this); + return std::move(copy); +} + +bool OperatorExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (!ParsedExpression::ListEquals(children, other_p.children)) { + return false; + } + return true; +} + + +unique_ptr OperatorExpression::Copy() const { + auto copy = duckdb::unique_ptr(new OperatorExpression()); + for (auto &child : children) { + copy->children.push_back(child->Copy()); + } + copy->CopyBase(*this); + return std::move(copy); +} + +bool ParameterExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (identifier != other_p.identifier) { + return false; + } + return true; +} + +hash_t ParameterExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, identifier.Hash()); + return hash; +} + +unique_ptr ParameterExpression::Copy() const { + auto copy = duckdb::unique_ptr(new ParameterExpression()); + copy->identifier = identifier; + 
copy->CopyBase(*this); + return std::move(copy); +} + +bool PositionalReferenceExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (index != other_p.index) { + return false; + } + return true; +} + +hash_t PositionalReferenceExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, duckdb::Hash(index)); + return hash; +} + +unique_ptr PositionalReferenceExpression::Copy() const { + auto copy = duckdb::unique_ptr(new PositionalReferenceExpression()); + copy->index = index; + copy->CopyBase(*this); + return std::move(copy); +} + +bool StarExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (relation_name != other_p.relation_name) { + return false; + } + if (exclude_list != other_p.exclude_list) { + return false; + } + if (replace_list.size() != other_p.replace_list.size()) { + return false; + } + for (auto &entry : replace_list) { + auto other_entry = other_p.replace_list.find(entry.first); + if (other_entry == other_p.replace_list.end()) { + return false; + } + if (!entry.second->Equals(*other_entry->second)) { + return false; + } + } + if (columns != other_p.columns) { + return false; + } + if (!ParsedExpression::Equals(expr, other_p.expr)) { + return false; + } + if (rename_list != other_p.rename_list) { + return false; + } + return true; +} + +hash_t StarExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, relation_name.Hash()); + hash = CombineHash(hash, duckdb::Hash(columns)); + return hash; +} + +unique_ptr StarExpression::Copy() const { + auto copy = duckdb::unique_ptr(new StarExpression()); + copy->relation_name = relation_name; + copy->exclude_list = exclude_list; + for (auto &entry : replace_list) { + copy->replace_list[entry.first] = entry.second->Copy(); + } + 
copy->columns = columns; + copy->expr = expr ? expr->Copy() : nullptr; + copy->rename_list = rename_list; + copy->CopyBase(*this); + return std::move(copy); +} + +bool SubqueryExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (subquery_type != other_p.subquery_type) { + return false; + } + if (!subquery || !other_p.subquery || !subquery->Equals(*other_p.subquery)) { + return false; + } + if (!ParsedExpression::Equals(child, other_p.child)) { + return false; + } + if (comparison_type != other_p.comparison_type) { + return false; + } + return true; +} + +hash_t SubqueryExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, duckdb::Hash(static_cast(subquery_type))); + hash = CombineHash(hash, duckdb::Hash(static_cast(comparison_type))); + return hash; +} + +unique_ptr SubqueryExpression::Copy() const { + auto copy = duckdb::unique_ptr(new SubqueryExpression()); + copy->subquery_type = subquery_type; + copy->subquery = subquery ? unique_ptr_cast(subquery->Copy()) : nullptr; + copy->child = child ? 
child->Copy() : nullptr; + copy->comparison_type = comparison_type; + copy->CopyBase(*this); + return std::move(copy); +} + +bool WindowExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (function_name != other_p.function_name) { + return false; + } + if (schema != other_p.schema) { + return false; + } + if (catalog != other_p.catalog) { + return false; + } + if (!ParsedExpression::ListEquals(partitions, other_p.partitions)) { + return false; + } + if (orders.size() != other_p.orders.size()) { + return false; + } + for (idx_t i = 0; i < orders.size(); i++) { + if (orders[i].type != other_p.orders[i].type) { + return false; + } + if (orders[i].null_order != other_p.orders[i].null_order) { + return false; + } + if (!orders[i].expression->Equals(*other_p.orders[i].expression)) { + return false; + } + } + if (start != other_p.start) { + return false; + } + if (end != other_p.end) { + return false; + } + if (!ParsedExpression::Equals(start_expr, other_p.start_expr)) { + return false; + } + if (!ParsedExpression::Equals(end_expr, other_p.end_expr)) { + return false; + } + if (ignore_nulls != other_p.ignore_nulls) { + return false; + } + if (!ParsedExpression::Equals(filter_expr, other_p.filter_expr)) { + return false; + } + if (exclude_clause != other_p.exclude_clause) { + return false; + } + if (distinct != other_p.distinct) { + return false; + } + if (arg_orders.size() != other_p.arg_orders.size()) { + return false; + } + for (idx_t i = 0; i < arg_orders.size(); i++) { + if (arg_orders[i].type != other_p.arg_orders[i].type) { + return false; + } + if (arg_orders[i].null_order != other_p.arg_orders[i].null_order) { + return false; + } + if (!arg_orders[i].expression->Equals(*other_p.arg_orders[i].expression)) { + return false; + } + } + if (has_ignore_nulls != other_p.has_ignore_nulls) { + return false; + } + if (arguments.size() != other_p.arguments.size()) { + 
return false; + } + for (idx_t i = 0; i < arguments.size(); i++) { + if (!arguments[i].Equals(other_p.arguments[i])) { + return false; + } + } + return true; +} + +hash_t WindowExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, function_name.Hash()); + hash = CombineHash(hash, schema.Hash()); + hash = CombineHash(hash, catalog.Hash()); + for (idx_t i = 0; i < orders.size(); i++) { + hash = CombineHash(hash, duckdb::Hash(static_cast(orders[i].type))); + hash = CombineHash(hash, duckdb::Hash(static_cast(orders[i].null_order))); + } + hash = CombineHash(hash, duckdb::Hash(static_cast(start))); + hash = CombineHash(hash, duckdb::Hash(static_cast(end))); + hash = CombineHash(hash, duckdb::Hash(ignore_nulls)); + hash = CombineHash(hash, duckdb::Hash(static_cast(exclude_clause))); + hash = CombineHash(hash, duckdb::Hash(distinct)); + for (idx_t i = 0; i < arg_orders.size(); i++) { + hash = CombineHash(hash, duckdb::Hash(static_cast(arg_orders[i].type))); + hash = CombineHash(hash, duckdb::Hash(static_cast(arg_orders[i].null_order))); + } + hash = CombineHash(hash, duckdb::Hash(has_ignore_nulls)); + return hash; +} + +unique_ptr WindowExpression::Copy() const { + auto copy = duckdb::unique_ptr(new WindowExpression()); + copy->is_legacy_function_call = is_legacy_function_call; + copy->function_name = function_name; + copy->schema = schema; + copy->catalog = catalog; + for (auto &child : partitions) { + copy->partitions.push_back(child->Copy()); + } + for (auto &order : orders) { + copy->orders.emplace_back(order.type, order.null_order, order.expression->Copy()); + } + copy->start = start; + copy->end = end; + copy->start_expr = start_expr ? start_expr->Copy() : nullptr; + copy->end_expr = end_expr ? end_expr->Copy() : nullptr; + copy->ignore_nulls = ignore_nulls; + copy->filter_expr = filter_expr ? 
filter_expr->Copy() : nullptr; + copy->exclude_clause = exclude_clause; + copy->distinct = distinct; + for (auto &order : arg_orders) { + copy->arg_orders.emplace_back(order.type, order.null_order, order.expression->Copy()); + } + copy->has_ignore_nulls = has_ignore_nulls; + for (auto &arg : arguments) { + copy->arguments.emplace_back(arg.Copy()); + } + copy->CopyBase(*this); + return std::move(copy); +} + +bool TypeExpression::Equals(const ParsedExpression &other) const { + if (!ParsedExpression::Equals(other)) { + return false; + } + auto &other_p = other.Cast(); + if (catalog != other_p.catalog) { + return false; + } + if (schema != other_p.schema) { + return false; + } + if (type_name != other_p.type_name) { + return false; + } + if (!ParsedExpression::ListEquals(children, other_p.children)) { + return false; + } + return true; +} + +hash_t TypeExpression::Hash() const { + hash_t hash = ParsedExpression::Hash(); + hash = CombineHash(hash, catalog.Hash()); + hash = CombineHash(hash, schema.Hash()); + hash = CombineHash(hash, type_name.Hash()); + return hash; +} + +unique_ptr TypeExpression::Copy() const { + auto copy = duckdb::unique_ptr(new TypeExpression()); + copy->catalog = catalog; + copy->schema = schema; + copy->type_name = type_name; + for (auto &child : children) { + copy->children.push_back(child->Copy()); + } + copy->CopyBase(*this); + return std::move(copy); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp index 471a355d7..6dbaf741a 100644 --- a/src/duckdb/src/common/value_operations/comparison_operations.cpp +++ b/src/duckdb/src/common/value_operations/comparison_operations.cpp @@ -80,7 +80,7 @@ static bool TemplatedBooleanOperation(const Value &left, const Value &right) { Value left_copy = left; Value right_copy = right; - auto comparison_type = LogicalType::ForceMaxLogicalType(left_type, right_type); + auto comparison_type = 
LogicalType::DefaultForceMaxLogicalType(left_type, right_type); if (!left_copy.DefaultTryCastAs(comparison_type) || !right_copy.DefaultTryCastAs(comparison_type)) { return false; } diff --git a/src/duckdb/src/common/vector/dictionary_vector.cpp b/src/duckdb/src/common/vector/dictionary_vector.cpp index b32ba8dae..70d7b5740 100644 --- a/src/duckdb/src/common/vector/dictionary_vector.cpp +++ b/src/duckdb/src/common/vector/dictionary_vector.cpp @@ -167,17 +167,17 @@ buffer_ptr DictionaryVector::CreateReusableDictionary(const Log return entry; } -const Vector &DictionaryVector::GetCachedHashes(Vector &input) { +const Vector &DictionaryVector::GetCachedHashes(const Vector &input) { D_ASSERT(CanCacheHashes(input)); - auto &entry = input.BufferMutable().Cast().GetEntry(); + const auto &entry = input.Buffer().Cast().GetEntry(); lock_guard guard(entry.cached_hashes_lock); if (!entry.cached_hashes) { // Uninitialized: hash the dictionary const auto dictionary_size = DictionarySize(input).GetIndex(); entry.cached_hashes = make_uniq(LogicalType::HASH, dictionary_size); - VectorOperations::Hash(entry.data, *entry.cached_hashes, dictionary_size); + VectorOperations::Hash(entry.data, *entry.cached_hashes); } return *entry.cached_hashes; } diff --git a/src/duckdb/src/common/vector/list_vector.cpp b/src/duckdb/src/common/vector/list_vector.cpp index ab3d61e37..127125e54 100644 --- a/src/duckdb/src/common/vector/list_vector.cpp +++ b/src/duckdb/src/common/vector/list_vector.cpp @@ -359,7 +359,7 @@ void ListVector::PushBack(Vector &target, const Value &insert) { target_buffer.PushBack(insert); } -idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { +idx_t ListVector::GetConsecutiveChildList(const Vector &list, Vector &result, idx_t offset, idx_t count) { auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); if (info.needs_slicing) { SelectionVector sel(info.child_list_info.length); @@ -371,7 +371,7 @@ idx_t 
ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t of return info.child_list_info.length; } -idx_t ListVector::GetTotalEntryCount(Vector &list, idx_t count) { +idx_t ListVector::GetTotalEntryCount(const Vector &list) { idx_t total_count = 0; for (auto entry : list.ValidValues()) { total_count += entry.GetValue().length; @@ -379,7 +379,7 @@ idx_t ListVector::GetTotalEntryCount(Vector &list, idx_t count) { return total_count; } -ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { +ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(const Vector &list, idx_t offset, idx_t count) { ConsecutiveChildListInfo info; auto list_data = list.Values(); @@ -432,7 +432,7 @@ ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, i return info; } -void ListVector::GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count) { +void ListVector::GetConsecutiveChildSelVector(const Vector &list, SelectionVector &sel, idx_t offset, idx_t count) { auto list_data = list.Values(); // SelectionVector child_sel(info.second.length); diff --git a/src/duckdb/src/common/vector/map_vector.cpp b/src/duckdb/src/common/vector/map_vector.cpp index 30277334c..2774c78a3 100644 --- a/src/duckdb/src/common/vector/map_vector.cpp +++ b/src/duckdb/src/common/vector/map_vector.cpp @@ -23,7 +23,7 @@ const Vector &MapVector::GetValues(const Vector &vector) { return GetValues((Vector &)vector); } -MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { +MapInvalidReason MapVector::CheckMapValidity(const Vector &map, idx_t count, const SelectionVector &sel) { D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); // unify the MAP vector, which is a physical LIST vector @@ -61,7 +61,7 @@ MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const Sel return MapInvalidReason::VALID; } -void 
MapVector::MapConversionVerify(Vector &vector, idx_t count) { +void MapVector::MapConversionVerify(const Vector &vector, idx_t count) { auto reason = MapVector::CheckMapValidity(vector, count); EvalMapInvalidReason(reason); } diff --git a/src/duckdb/src/common/vector/shredded_vector.cpp b/src/duckdb/src/common/vector/shredded_vector.cpp index a47f35281..a42fb0c13 100644 --- a/src/duckdb/src/common/vector/shredded_vector.cpp +++ b/src/duckdb/src/common/vector/shredded_vector.cpp @@ -44,8 +44,8 @@ Value ShreddedVectorBuffer::GetValue(const LogicalType &type, idx_t index) const auto unshredded_val = unshredded.GetValue(index); child_list_t shredded_subtypes; - shredded_subtypes.push_back(make_pair("unshredded", unshredded.GetType())); - shredded_subtypes.push_back(make_pair("shredded", shredded.GetType())); + shredded_subtypes.emplace_back(make_pair("unshredded", unshredded.GetType())); + shredded_subtypes.emplace_back(make_pair("shredded", shredded.GetType())); Vector new_shredded(LogicalType::STRUCT(std::move(shredded_subtypes))); StructVector::GetEntries(new_shredded)[0].Reference(unshredded_val, count_t(1)); StructVector::GetEntries(new_shredded)[1].Reference(shredded_val, count_t(1)); @@ -113,7 +113,7 @@ void ShreddedVector::Unshred(const Vector &vec, const SelectionVector &sel, idx_ vec.ConstReference(unshredded_vector); } -bool ShreddedVector::IsFullyShredded(Vector &vec) { +bool ShreddedVector::IsFullyShredded(const Vector &vec) { auto &unshredded_vector = GetUnshreddedVector(vec); if (unshredded_vector.GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(unshredded_vector)) { return true; diff --git a/src/duckdb/src/common/vector/struct_vector.cpp b/src/duckdb/src/common/vector/struct_vector.cpp index 2c1848ba0..48df3c1c5 100644 --- a/src/duckdb/src/common/vector/struct_vector.cpp +++ b/src/duckdb/src/common/vector/struct_vector.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/vector/struct_vector.hpp" +#include 
"duckdb/common/enums/vector_type.hpp" #include "duckdb/common/vector/constant_vector.hpp" #include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/common/vector/flat_vector.hpp" @@ -45,6 +46,17 @@ void VectorStructBuffer::SetVectorType(VectorType new_vector_type) { } } +void VectorStructBuffer::PrepareChildrenForSetConstant() { + for (auto &child : children) { + if (child.GetVectorType() != VectorType::FLAT_VECTOR && child.GetVectorType() != VectorType::CONSTANT_VECTOR) { + child.Flatten(); + } + if (child.GetType().InternalType() == PhysicalType::STRUCT) { + child.BufferMutable().Cast().PrepareChildrenForSetConstant(); + } + } +} + VectorStructBuffer::~VectorStructBuffer() { } @@ -223,9 +235,6 @@ Value VectorStructBuffer::GetValue(const LogicalType &type, idx_t index) const { for (idx_t i = 0; i < children.size(); i++) { child_values.push_back(children[i].GetValue(index)); } - if (type.id() == LogicalTypeId::AGGREGATE_STATE) { - return Value::AGGREGATE_STATE(type, std::move(child_values)); - } return Value::STRUCT(type, std::move(child_values)); } } @@ -279,8 +288,7 @@ buffer_ptr VectorStructBuffer::FlattenSliceInternal(const LogicalT vector &StructVector::GetEntries(Vector &vector) { D_ASSERT(vector.GetType().id() == LogicalTypeId::STRUCT || vector.GetType().id() == LogicalTypeId::UNION || - vector.GetType().id() == LogicalTypeId::VARIANT || - vector.GetType().id() == LogicalTypeId::AGGREGATE_STATE); + vector.GetType().id() == LogicalTypeId::VARIANT); if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { throw InternalException("Struct vectors cannot be dictionary vectors"); diff --git a/src/duckdb/src/common/vector_operations/boolean_operators.cpp b/src/duckdb/src/common/vector_operations/boolean_operators.cpp index ac45b1fa9..c2a80a23b 100644 --- a/src/duckdb/src/common/vector_operations/boolean_operators.cpp +++ b/src/duckdb/src/common/vector_operations/boolean_operators.cpp @@ -4,7 +4,6 @@ // operations AND OR ! 
//===--------------------------------------------------------------------===// -#include "duckdb/common/vector_operations/binary_executor.hpp" #include "duckdb/common/vector_operations/unary_executor.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ -15,7 +14,8 @@ namespace { // AND/OR //===--------------------------------------------------------------------===// template -void TemplatedBooleanNullmask(Vector &left, Vector &right, Vector &result, idx_t count) { +void TemplatedBooleanNullmask(const Vector &left, const Vector &right, Vector &result) { + const auto count = (left.GetVectorType() == VectorType::CONSTANT_VECTOR) ? right.size() : left.size(); D_ASSERT(left.GetType().id() == LogicalTypeId::BOOLEAN && right.GetType().id() == LogicalTypeId::BOOLEAN && result.GetType().id() == LogicalTypeId::BOOLEAN); @@ -40,22 +40,23 @@ void TemplatedBooleanNullmask(Vector &left, Vector &right, Vector &result, idx_t auto right_data = right.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::ScatterWriter(result); + auto result_data = FlatVector::Writer(result, count); if (left_data.CanHaveNull() || right_data.CanHaveNull()) { for (idx_t i = 0; i < count; i++) { auto left_entry = left_data[i]; auto right_entry = right_data[i]; + bool result_value = false; bool is_null = OP::Operation(left_entry.GetValueUnsafe() > 0, right_entry.GetValueUnsafe() > 0, - !left_entry.IsValid(), !right_entry.IsValid(), result_data[i]); + !left_entry.IsValid(), !right_entry.IsValid(), result_value); if (is_null) { - result_data.SetInvalid(i); + result_data.WriteNull(result_value); + } else { + result_data.WriteValue(result_value); } } } else { for (idx_t i = 0; i < count; i++) { - auto left_val = left_data.GetValueUnsafe(i); - auto right_val = right_data.GetValueUnsafe(i); - result_data[i] = OP::SimpleOperation(left_val, right_val); + result_data.WriteValue(OP::SimpleOperation(left_data.GetValueUnsafe(i), right_data.GetValueUnsafe(i))); } 
} } @@ -165,17 +166,29 @@ struct NotOperator { } // namespace -void VectorOperations::And(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); +void VectorOperations::And(const Vector &left, const Vector &right, Vector &result) { + const bool left_is_const = left.GetVectorType() == VectorType::CONSTANT_VECTOR; + const bool right_is_const = right.GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!left_is_const && !right_is_const && left.size() != right.size()) { + throw InternalException("Mismatch in input vector sizes for And - left has %d rows but right has %d", + left.size(), right.size()); + } + TemplatedBooleanNullmask(left, right, result); } -void VectorOperations::Or(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); +void VectorOperations::Or(const Vector &left, const Vector &right, Vector &result) { + const bool left_is_const = left.GetVectorType() == VectorType::CONSTANT_VECTOR; + const bool right_is_const = right.GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!left_is_const && !right_is_const && left.size() != right.size()) { + throw InternalException("Mismatch in input vector sizes for Or - left has %d rows but right has %d", + left.size(), right.size()); + } + TemplatedBooleanNullmask(left, right, result); } -void VectorOperations::Not(Vector &input, Vector &result, idx_t count) { +void VectorOperations::Not(const Vector &input, Vector &result) { D_ASSERT(input.GetType() == LogicalType::BOOLEAN && result.GetType() == LogicalType::BOOLEAN); - UnaryExecutor::Execute(input, result, count); + UnaryExecutor::Execute(input, result); } } // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/comparison_operators.cpp b/src/duckdb/src/common/vector_operations/comparison_operators.cpp index e4b4b0d9e..deae898ed 100644 --- a/src/duckdb/src/common/vector_operations/comparison_operators.cpp +++ 
b/src/duckdb/src/common/vector_operations/comparison_operators.cpp @@ -129,7 +129,7 @@ int8_t Comparator::Operation(const double &left, const double &right) { // Fast path: direct BinaryExecutor for primitive types (single pass, no intermediate vector) //===--------------------------------------------------------------------===// template -static bool TryPrimitiveComparisonExecute(Vector &left, Vector &right, Vector &result, idx_t count) { +static bool TryPrimitiveComparisonExecute(const Vector &left, const Vector &right, Vector &result) { #ifdef DUCKDB_SMALLER_BINARY return false; #else @@ -137,46 +137,46 @@ static bool TryPrimitiveComparisonExecute(Vector &left, Vector &right, Vector &r switch (left.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::INT16: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::INT32: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::INT64: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::UINT8: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::UINT16: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::UINT32: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::UINT64: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::INT128: - BinaryExecutor::Execute(left, right, result, count); + 
BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::UINT128: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::FLOAT: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::DOUBLE: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::INTERVAL: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; case PhysicalType::VARCHAR: - BinaryExecutor::Execute(left, right, result, count); + BinaryExecutor::Execute(left, right, result); return true; default: return false; @@ -185,10 +185,11 @@ static bool TryPrimitiveComparisonExecute(Vector &left, Vector &right, Vector &r } template -static void ComparatorToBoolean(Vector &left, Vector &right, Vector &result, idx_t count, PREDICATE predicate) { +static void ComparatorToBoolean(const Vector &left, const Vector &right, Vector &result, PREDICATE predicate) { D_ASSERT(result.GetType() == LogicalType::BOOLEAN); - Vector comparator_result(LogicalType::TINYINT, count); - VectorOperations::Comparator(left, right, comparator_result, count); + Vector comparator_result(LogicalType::TINYINT); + VectorOperations::Comparator(left, right, comparator_result); + const auto count = comparator_result.size(); auto cmp_data = comparator_result.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::Writer(result, count); @@ -202,60 +203,64 @@ static void ComparatorToBoolean(Vector &left, Vector &right, Vector &result, idx } } -void VectorOperations::Equals(Vector &left, Vector &right, Vector &result, idx_t count) { - if (TryPrimitiveComparisonExecute(left, right, result, count)) { - return; +static idx_t GetComparisonCount(const Vector &left, const Vector &right, const char *fname) { + 
const bool left_is_const = left.GetVectorType() == VectorType::CONSTANT_VECTOR; + const bool right_is_const = right.GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!left_is_const && !right_is_const && left.size() != right.size()) { + throw InternalException("Mismatch in input vector sizes for %s - left has %d rows but right has %d", fname, + left.size(), right.size()); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v == 0; }); + return left_is_const ? right.size() : left.size(); } -void VectorOperations::NotEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - if (TryPrimitiveComparisonExecute(left, right, result, count)) { - return; +void VectorOperations::Equals(const Vector &left, const Vector &right, Vector &result) { + if (!TryPrimitiveComparisonExecute(left, right, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v == 0; }); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v != 0; }); } -void VectorOperations::GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count) { - if (TryPrimitiveComparisonExecute(left, right, result, count)) { - return; +void VectorOperations::NotEquals(const Vector &left, const Vector &right, Vector &result) { + if (!TryPrimitiveComparisonExecute(left, right, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v != 0; }); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v > 0; }); } -void VectorOperations::GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - if (TryPrimitiveComparisonExecute(left, right, result, count)) { - return; +void VectorOperations::GreaterThan(const Vector &left, const Vector &right, Vector &result) { + if (!TryPrimitiveComparisonExecute(left, right, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v > 0; }); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v >= 0; }); } 
-void VectorOperations::LessThan(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::GreaterThanEquals(const Vector &left, const Vector &right, Vector &result) { + if (!TryPrimitiveComparisonExecute(left, right, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v >= 0; }); + } +} + +void VectorOperations::LessThan(const Vector &left, const Vector &right, Vector &result) { // NOLINTNEXTLINE: flip right / left (left < right is equal to right > left) - if (TryPrimitiveComparisonExecute(right, left, result, count)) { - return; + if (!TryPrimitiveComparisonExecute(right, left, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v < 0; }); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v < 0; }); } -void VectorOperations::LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::LessThanEquals(const Vector &left, const Vector &right, Vector &result) { // NOLINTNEXTLINE: flip right / left (left <= right is equal to right >= left) - if (TryPrimitiveComparisonExecute(right, left, result, count)) { - return; + if (!TryPrimitiveComparisonExecute(right, left, result)) { + ComparatorToBoolean(left, right, result, [](int8_t v) { return v <= 0; }); } - ComparatorToBoolean(left, right, result, count, [](int8_t v) { return v <= 0; }); } struct StandardComparatorExecute { template - static inline void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { + static inline void Execute(const Vector &left, const Vector &right, Vector &result, idx_t count) { BinaryExecutor::Execute(left, right, result, count); } }; struct DistinctComparatorExecute { template - static void Execute(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, + static void Execute(const Vector &left, const Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count) { 
UnifiedVectorFormat left_format, right_format; left.ToUnifiedFormat(left_format); @@ -273,7 +278,7 @@ struct DistinctComparatorExecute { }; // forward declaration - nested comparators call DistinctComparator recursively for children -static void DistinctComparatorTypeSwitch(Vector &left, Vector &right, int8_t *result_data, +static void DistinctComparatorTypeSwitch(const Vector &left, const Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count); @@ -288,8 +293,8 @@ static int8_t DistinctNullComparator(bool left_null, bool right_null) { return Comparator::RIGHT_IS_GREATER; } -static void StructComparator(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, - const SelectionVector &rhs_sel, idx_t sel_count, +static void StructComparator(const Vector &left, const Vector &right, int8_t *result_data, + const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count, optional_ptr result_validity = nullptr) { if (sel_count == 0) { return; @@ -365,12 +370,11 @@ static void StructComparator(Vector &left, Vector &right, int8_t *result_data, c } struct ListEntryAccessor { - static Vector &GetChild(Vector &vector) { - return ListVector::GetChildMutable(vector); + static const Vector &GetChild(const Vector &vector) { + return ListVector::GetChild(vector); } - static void FlattenChild(Vector &vector) { - auto &child = ListVector::GetChildMutable(vector); - child.Flatten(); + static void FlattenChild(const Vector &) { + // no-op: ToUnifiedFormat handles all vector types without needing to flatten first } static idx_t GetOffset(UnifiedVectorFormat &format, idx_t sel_idx) { auto entries = UnifiedVectorFormat::GetData(format); @@ -387,12 +391,11 @@ struct ListEntryAccessor { struct ArrayEntryAccessor { explicit ArrayEntryAccessor(idx_t array_size) : array_size(array_size) { } - Vector &GetChild(Vector &vector) { - return ArrayVector::GetChildMutable(vector); + static const 
Vector &GetChild(const Vector &vector) { + return ArrayVector::GetChild(vector); } - void FlattenChild(Vector &vector) { - auto &child = ArrayVector::GetChildMutable(vector); - child.Flatten(); + static void FlattenChild(const Vector &) { + // no-op: ToUnifiedFormat handles all vector types without needing to flatten first } idx_t GetOffset(UnifiedVectorFormat &format, idx_t sel_idx) { return format.sel->get_index(sel_idx) * array_size; @@ -404,13 +407,13 @@ struct ArrayEntryAccessor { }; template -static void ListOrArrayComparator(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, - const SelectionVector &rhs_sel, idx_t sel_count, ACCESSOR accessor, - optional_ptr result_validity = nullptr) { +static void ListOrArrayComparator(const Vector &left, const Vector &right, int8_t *result_data, + const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count, + ACCESSOR accessor, optional_ptr result_validity = nullptr) { if (sel_count == 0) { return; } - // recursively flatten child vectors so they can be indexed directly via selection vectors + // FlattenChild is a no-op; ToUnifiedFormat handles all vector types in the recursive comparators accessor.FlattenChild(left); accessor.FlattenChild(right); // step 1: handle top-level validity @@ -519,22 +522,22 @@ static void ListOrArrayComparator(Vector &left, Vector &right, int8_t *result_da } } -static void ListComparator(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, +static void ListComparator(const Vector &left, const Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count, optional_ptr result_validity = nullptr) { ListEntryAccessor accessor; ListOrArrayComparator(left, right, result_data, lhs_sel, rhs_sel, sel_count, accessor, result_validity); } -static void ArrayComparator(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, - const SelectionVector 
&rhs_sel, idx_t sel_count, +static void ArrayComparator(const Vector &left, const Vector &right, int8_t *result_data, + const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count, optional_ptr result_validity = nullptr) { ArrayEntryAccessor accessor(ArrayType::GetSize(left.GetType())); ListOrArrayComparator(left, right, result_data, lhs_sel, rhs_sel, sel_count, accessor, result_validity); } -static void VariantComparator(Vector &left, Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, - const SelectionVector &rhs_sel, idx_t sel_count, +static void VariantComparator(const Vector &left, const Vector &right, int8_t *result_data, + const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count, optional_ptr result_validity = nullptr) { RecursiveUnifiedVectorFormat left_recursive_data, right_recursive_data; Vector::RecursiveToUnifiedFormat(left, left_recursive_data); @@ -568,7 +571,7 @@ static void VariantComparator(Vector &left, Vector &right, int8_t *result_data, auto right_val = VariantUtils::ConvertVariantToValue(right_variant, rhs_sel.get_index(i), 0); LogicalType max_logical_type; - if (!LogicalType::TryGetMaxLogicalTypeUnchecked(left_val.type(), right_val.type(), max_logical_type)) { + if (!LogicalType::DefaultTryGetMaxLogicalTypeUnchecked(left_val.type(), right_val.type(), max_logical_type)) { throw InvalidInputException( "Can't compare values of type %s (%s) and type %s (%s) - an explicit cast is required", left_val.type().ToString(), left_val.ToString(), right_val.type().ToString(), right_val.ToString()); @@ -584,7 +587,7 @@ static void VariantComparator(Vector &left, Vector &right, int8_t *result_data, } } -static void DistinctComparatorTypeSwitchInternal(Vector &left, Vector &right, int8_t *result_data, +static void DistinctComparatorTypeSwitchInternal(const Vector &left, const Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count) { 
D_ASSERT(left.GetType().InternalType() == right.GetType().InternalType()); @@ -650,13 +653,13 @@ static void DistinctComparatorTypeSwitchInternal(Vector &left, Vector &right, in } } -static void DistinctComparatorTypeSwitch(Vector &left, Vector &right, int8_t *result_data, +static void DistinctComparatorTypeSwitch(const Vector &left, const Vector &right, int8_t *result_data, const SelectionVector &lhs_sel, const SelectionVector &rhs_sel, idx_t sel_count) { DistinctComparatorTypeSwitchInternal(left, right, result_data, lhs_sel, rhs_sel, sel_count); } -static void ComparatorTypeSwitch(Vector &left, Vector &right, Vector &result, idx_t count) { +static void ComparatorTypeSwitch(const Vector &left, const Vector &right, Vector &result, idx_t count) { D_ASSERT(left.GetType().InternalType() == right.GetType().InternalType() && result.GetType() == LogicalType::TINYINT); switch (left.GetType().InternalType()) { @@ -727,15 +730,21 @@ static void ComparatorTypeSwitch(Vector &left, Vector &right, Vector &result, id } } -void VectorOperations::Comparator(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::ComparatorFill(const Vector &left, const Vector &right, Vector &result, idx_t count) { ComparatorTypeSwitch(left, right, result, count); + FlatVector::SetSize(result, count); +} + +void VectorOperations::Comparator(const Vector &left, const Vector &right, Vector &result) { + const auto count = GetComparisonCount(left, right, "Comparator"); + ComparatorFill(left, right, result, count); } template static void DistinctExecuteGenericLoop(const T *__restrict ldata, const T *__restrict rdata, int8_t *__restrict result_data, const SelectionVector *__restrict lsel, - const SelectionVector *__restrict rsel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask) { + const SelectionVector *__restrict rsel, idx_t count, const ValidityMask &lmask, + const ValidityMask &rmask) { for (idx_t i = 0; i < count; i++) { auto lindex = lsel->get_index(i); auto 
rindex = rsel->get_index(i); @@ -745,7 +754,7 @@ static void DistinctExecuteGenericLoop(const T *__restrict ldata, const T *__res } template -static void DistinctExecuteConstant(Vector &left, Vector &right, Vector &result) { +static void DistinctExecuteConstant(const Vector &left, const Vector &right, Vector &result) { result.SetVectorType(VectorType::CONSTANT_VECTOR); auto ldata = ConstantVector::GetData(left); auto rdata = ConstantVector::GetData(right); @@ -755,7 +764,7 @@ static void DistinctExecuteConstant(Vector &left, Vector &right, Vector &result) } template -static void DistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { +static void DistinctExecute(const Vector &left, const Vector &right, Vector &result, idx_t count) { if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { DistinctExecuteConstant(left, right, result); } else { @@ -771,7 +780,8 @@ static void DistinctExecute(Vector &left, Vector &right, Vector &result, idx_t c } template -static bool TryPrimitiveDistinctComparatorExecute(Vector &left, Vector &right, Vector &result, idx_t count) { +static bool TryPrimitiveDistinctComparatorExecute(const Vector &left, const Vector &right, Vector &result, + idx_t count) { #ifdef DUCKDB_SMALLER_BINARY return false; #else @@ -826,25 +836,32 @@ static bool TryPrimitiveDistinctComparatorExecute(Vector &left, Vector &right, V #endif } -void VectorOperations::DistinctComparator(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::DistinctComparatorFill(const Vector &left, const Vector &right, Vector &result, idx_t count) { D_ASSERT(result.GetType() == LogicalType::TINYINT); - if (TryPrimitiveDistinctComparatorExecute(left, right, result, count)) { - return; + if (!TryPrimitiveDistinctComparatorExecute(left, right, result, count)) { + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetDataMutable(result); + auto &sel = 
*FlatVector::IncrementalSelectionVector(); + DistinctComparatorTypeSwitchInternal(left, right, result_data, sel, sel, count); } - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetDataMutable(result); - auto &sel = *FlatVector::IncrementalSelectionVector(); - DistinctComparatorTypeSwitchInternal(left, right, result_data, sel, sel, count); + FlatVector::SetSize(result, count); +} + +void VectorOperations::DistinctComparator(const Vector &left, const Vector &right, Vector &result) { + const auto count = GetComparisonCount(left, right, "DistinctComparator"); + DistinctComparatorFill(left, right, result, count); } -void VectorOperations::DistinctComparatorNullsFirst(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::DistinctComparatorNullsFirstFill(const Vector &left, const Vector &right, Vector &result, + idx_t count) { if (TryPrimitiveDistinctComparatorExecute(left, right, result, count)) { + FlatVector::SetSize(result, count); return; } // run the NULLS LAST comparator, then flip the sign for NULL-involving rows // note that even for NULLS FIRST, ONLY the top-level is NULLS FIRST, // i.e. 
within structs we still use NULLS LAST semantics - VectorOperations::DistinctComparator(left, right, result, count); + DistinctComparatorFill(left, right, result, count); result.Flatten(); auto result_data = FlatVector::GetDataMutable(result); auto left_validity = left.Validity(); @@ -861,4 +878,9 @@ void VectorOperations::DistinctComparatorNullsFirst(Vector &left, Vector &right, } } +void VectorOperations::DistinctComparatorNullsFirst(const Vector &left, const Vector &right, Vector &result) { + const auto count = GetComparisonCount(left, right, "DistinctComparatorNullsFirst"); + DistinctComparatorNullsFirstFill(left, right, result, count); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/generators.cpp b/src/duckdb/src/common/vector_operations/generators.cpp index e3c69b93f..d1c3d1837 100644 --- a/src/duckdb/src/common/vector_operations/generators.cpp +++ b/src/duckdb/src/common/vector_operations/generators.cpp @@ -7,7 +7,6 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/common/limits.hpp" -#include "duckdb/common/numeric_utils.hpp" namespace duckdb { diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp index 7f0407558..33e010f72 100644 --- a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp +++ b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp @@ -4,10 +4,11 @@ namespace duckdb { -void VectorOperations::DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::DistinctFrom(const Vector &left, const Vector &right, Vector &result) { D_ASSERT(result.GetType() == LogicalType::BOOLEAN); - Vector comparator_result(LogicalType::TINYINT, count); - VectorOperations::DistinctComparator(left, right, comparator_result, count); + Vector comparator_result(LogicalType::TINYINT); + VectorOperations::DistinctComparator(left, 
right, comparator_result); + const idx_t count = comparator_result.size(); auto cmp_data = comparator_result.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::Writer(result, count); @@ -16,10 +17,11 @@ void VectorOperations::DistinctFrom(Vector &left, Vector &right, Vector &result, } } -void VectorOperations::NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { +void VectorOperations::NotDistinctFrom(const Vector &left, const Vector &right, Vector &result) { D_ASSERT(result.GetType() == LogicalType::BOOLEAN); - Vector comparator_result(LogicalType::TINYINT, count); - VectorOperations::DistinctComparator(left, right, comparator_result, count); + Vector comparator_result(LogicalType::TINYINT); + VectorOperations::DistinctComparator(left, right, comparator_result); + const idx_t count = comparator_result.size(); auto cmp_data = comparator_result.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); auto result_data = FlatVector::Writer(result, count); @@ -28,12 +30,12 @@ void VectorOperations::NotDistinctFrom(Vector &left, Vector &right, Vector &resu } } -template -static idx_t DistinctComparatorSelect(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - COMPARATOR_FN comparator_fn, PREDICATE predicate) { +template +static idx_t DistinctComparatorSelect(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, FILL_FN fill_fn, PREDICATE predicate) { Vector comparator_result(LogicalType::TINYINT, count); - comparator_fn(left, right, comparator_result, count); + fill_fn(left, right, comparator_result, count); auto cmp_data = comparator_result.Values(); idx_t true_count = 0; @@ -56,86 +58,94 @@ static idx_t DistinctComparatorSelect(Vector &left, Vector &right, optional_ptr< } // true := A != B with nulls being equal -idx_t VectorOperations::DistinctFrom(Vector &left, Vector &right, 
optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v != 0; }); +idx_t VectorOperations::DistinctFrom(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel) { + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v != 0; }); } // true := A == B with nulls being equal -idx_t VectorOperations::NotDistinctFrom(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::NotDistinctFrom(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v == 0; }); + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v == 0; }); } // true := A > B with nulls being maximal -idx_t VectorOperations::DistinctGreaterThan(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::DistinctGreaterThan(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v > 0; }); + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v > 0; }); } // true := A > B with nulls being minimal -idx_t VectorOperations::DistinctGreaterThanNullsFirst(Vector &left, 
Vector &right, +idx_t VectorOperations::DistinctGreaterThanNullsFirst(const Vector &left, const Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, - VectorOperations::DistinctComparatorNullsFirst, [](int8_t v) { return v > 0; }); + VectorOperations::DistinctComparatorNullsFirstFill, [](int8_t v) { return v > 0; }); } // true := A >= B with nulls being maximal -idx_t VectorOperations::DistinctGreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::DistinctGreaterThanEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v >= 0; }); + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v >= 0; }); } // true := A < B with nulls being maximal -idx_t VectorOperations::DistinctLessThan(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::DistinctLessThan(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v < 0; }); + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v < 0; }); } // true := A < B with nulls being minimal -idx_t VectorOperations::DistinctLessThanNullsFirst(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr 
true_sel, +idx_t VectorOperations::DistinctLessThanNullsFirst(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, - VectorOperations::DistinctComparatorNullsFirst, [](int8_t v) { return v < 0; }); + VectorOperations::DistinctComparatorNullsFirstFill, [](int8_t v) { return v < 0; }); } // true := A <= B with nulls being maximal -idx_t VectorOperations::DistinctLessThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::DistinctLessThanEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v <= 0; }); + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v <= 0; }); } // true := A != B with nulls being equal, inputs selected -idx_t VectorOperations::NestedNotEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v != 0; }); +idx_t VectorOperations::NestedNotEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, + optional_ptr null_mask) { + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v != 0; }); } // true := A == B with nulls being equal, inputs selected -idx_t VectorOperations::NestedEquals(Vector &left, 
Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, VectorOperations::DistinctComparator, - [](int8_t v) { return v == 0; }); +idx_t VectorOperations::NestedEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask) { + return DistinctComparatorSelect(left, right, sel, count, true_sel, false_sel, + VectorOperations::DistinctComparatorFill, [](int8_t v) { return v == 0; }); } } // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/null_operations.cpp b/src/duckdb/src/common/vector_operations/null_operations.cpp index 42475df9e..4b34493d4 100644 --- a/src/duckdb/src/common/vector_operations/null_operations.cpp +++ b/src/duckdb/src/common/vector_operations/null_operations.cpp @@ -6,15 +6,15 @@ #include "duckdb/common/vector/constant_vector.hpp" #include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/exception.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" namespace duckdb { template -static void IsNullLoop(Vector &input, Vector &result, idx_t count) { +static void IsNullLoop(const Vector &input, Vector &result) { D_ASSERT(result.GetType() == LogicalType::BOOLEAN); + auto count = input.size(); if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); auto result_data = ConstantVector::GetData(result); @@ -29,15 +29,16 @@ static void IsNullLoop(Vector &input, Vector &result, idx_t count) { } } -void VectorOperations::IsNotNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, count); +void VectorOperations::IsNotNull(const Vector &input, Vector &result) { + IsNullLoop(input, result); } -void VectorOperations::IsNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, 
count); +void VectorOperations::IsNull(const Vector &input, Vector &result) { + IsNullLoop(input, result); } -bool VectorOperations::HasNotNull(Vector &input, idx_t count) { +bool VectorOperations::HasNotNull(const Vector &input) { + auto count = input.size(); if (count == 0) { return false; } @@ -57,7 +58,8 @@ bool VectorOperations::HasNotNull(Vector &input, idx_t count) { } } -bool VectorOperations::HasNull(Vector &input, idx_t count) { +bool VectorOperations::HasNull(const Vector &input) { + auto count = input.size(); if (count == 0) { return false; } @@ -77,30 +79,31 @@ bool VectorOperations::HasNull(Vector &input, idx_t count) { } } -idx_t VectorOperations::CountNotNull(Vector &input, const idx_t count) { - idx_t valid = 0; +idx_t VectorOperations::CountNotNull(const Vector &input) { + auto count = input.size(); - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(vdata); - if (vdata.validity.CannotHaveNull()) { - return count; - } switch (input.GetVectorType()) { case VectorType::FLAT_VECTOR: - valid += vdata.validity.CountValid(count); - break; + return FlatVector::Validity(input).CountValid(count); case VectorType::CONSTANT_VECTOR: - valid += vdata.validity.CountValid(1) * count; - break; - default: + if (!ConstantVector::IsNull(input)) { + return count; + } + return 0; + default: { + auto validity = input.Validity(); + if (validity.CannotHaveNull()) { + return count; + } + idx_t valid = 0; for (idx_t i = 0; i < count; ++i) { - const auto row_idx = vdata.sel->get_index(i); - valid += idx_t(vdata.validity.RowIsValid(row_idx)); + if (validity.IsValid(i)) { + valid++; + } } - break; + return valid; + } } - - return valid; } } // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp b/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp index 4db3b6e22..deda85aca 100644 --- a/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp +++ 
b/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp @@ -16,11 +16,12 @@ namespace duckdb { // In-Place Addition //===--------------------------------------------------------------------===// -void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { +void VectorOperations::AddInPlace(Vector &input, int64_t right) { D_ASSERT(input.GetType().id() == LogicalTypeId::POINTER); if (right == 0) { return; } + const idx_t count = input.size(); switch (input.GetVectorType()) { case VectorType::CONSTANT_VECTOR: { D_ASSERT(!ConstantVector::IsNull(input)); diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp index 6257334a5..e7d6c77df 100644 --- a/src/duckdb/src/common/vector_operations/vector_hash.cpp +++ b/src/duckdb/src/common/vector_operations/vector_hash.cpp @@ -68,7 +68,7 @@ void TightLoopHash(const T *__restrict ldata, hash_t *__restrict result_data, co } template -void TemplatedLoopHash(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { +void TemplatedLoopHash(const Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); FlatVector::SetSize(result, count_t(count)); @@ -97,8 +97,8 @@ void TemplatedLoopHash(Vector &input, Vector &result, const SelectionVector *rse } template -void StructLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - auto &children = StructVector::GetEntries(input); +void StructLoopHash(const Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { + const auto &children = StructVector::GetEntries(input); D_ASSERT(!children.empty()); idx_t col_no = 0; @@ -124,7 +124,7 @@ void StructLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, } template -void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector 
*rsel, idx_t count) { +void ListLoopHash(const Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { // FIXME: if we want to be more efficient we shouldn't flatten, but the logic here currently requires it hashes.Flatten(); auto hdata = FlatVector::GetDataMutable(hashes); @@ -134,7 +134,7 @@ void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, id const auto ldata = UnifiedVectorFormat::GetData(idata); // Hash the children into a temporary - auto &child = ListVector::GetChildMutable(input); + const auto &child = ListVector::GetChild(input); const auto child_count = ListVector::GetListSize(input); Vector child_hashes(LogicalType::HASH, child_count); @@ -215,7 +215,7 @@ void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, id } template -void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { +void ArrayLoopHash(const Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { hashes.Flatten(); auto hdata = FlatVector::GetDataMutable(hashes); @@ -223,7 +223,7 @@ void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, i input.ToUnifiedFormat(idata); // Hash the children into a temporary - auto &child = ArrayVector::GetChildMutable(input); + const auto &child = ArrayVector::GetChild(input); auto array_size = ArrayType::GetSize(input.GetType()); auto is_flat = input.GetVectorType() == VectorType::FLAT_VECTOR; @@ -288,7 +288,7 @@ void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, i } template -void HashTypeSwitch(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { +void HashTypeSwitch(const Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { D_ASSERT(result.GetType().id() == LogicalType::HASH); switch (input.GetType().InternalType()) { case PhysicalType::BOOL: @@ -351,7 +351,7 @@ void HashTypeSwitch(Vector &input, Vector &result, const SelectionVector 
*rsel, template void TightLoopCombineHashConstant(const T *__restrict ldata, hash_t constant_hash, hash_t *__restrict hash_data, const SelectionVector *rsel, idx_t count, - const SelectionVector *__restrict sel_vector, ValidityMask &mask) { + const SelectionVector *__restrict sel_vector, const ValidityMask &mask) { if (mask.CanHaveNull()) { for (idx_t i = 0; i < count; i++) { auto ridx = HAS_RSEL ? rsel->get_index(i) : i; @@ -393,7 +393,7 @@ static inline void TightLoopCombineHash(const T *__restrict ldata, hash_t *__res } template -void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { +void TemplatedLoopCombineHash(const Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { auto ldata = ConstantVector::GetData(input); auto hash_data = ConstantVector::GetData(hashes); @@ -409,6 +409,7 @@ void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVect auto constant_hash = *ConstantVector::GetData(hashes); // now re-initialize the hashes vector to an empty flat vector hashes.SetVectorType(VectorType::FLAT_VECTOR); + FlatVector::SetSize(hashes, count_t(count)); TightLoopCombineHashConstant( UnifiedVectorFormat::GetData(idata), constant_hash, FlatVector::GetDataMutable(hashes), rsel, count, idata.sel, idata.validity); @@ -428,7 +429,7 @@ void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVect } template -void CombineHashTypeSwitch(Vector &hashes, Vector &input, const SelectionVector *rsel, idx_t count) { +void CombineHashTypeSwitch(Vector &hashes, const Vector &input, const SelectionVector *rsel, idx_t count) { D_ASSERT(hashes.GetType().id() == LogicalType::HASH); switch (input.GetType().InternalType()) { case PhysicalType::BOOL: @@ -490,7 +491,7 @@ void CombineHashTypeSwitch(Vector &hashes, Vector &input, const SelectionVector } // namespace 
-void VectorOperations::Hash(Vector &input, Vector &result, idx_t count) { +void VectorOperations::Hash(const Vector &input, Vector &result, idx_t count) { if (input.GetVectorType() == VectorType::DICTIONARY_VECTOR && DictionaryVector::CanCacheHashes(input)) { Vector input_hashes(DictionaryVector::GetCachedHashes(input), DictionaryVector::SelVector(input), count); TemplatedLoopHash(input_hashes, result, nullptr, count); @@ -499,7 +500,7 @@ void VectorOperations::Hash(Vector &input, Vector &result, idx_t count) { } } -void VectorOperations::Hash(Vector &input, Vector &result, const SelectionVector &sel, idx_t count) { +void VectorOperations::Hash(const Vector &input, Vector &result, const SelectionVector &sel, idx_t count) { if (input.GetVectorType() == VectorType::DICTIONARY_VECTOR && DictionaryVector::CanCacheHashes(input)) { Vector input_hashes(DictionaryVector::GetCachedHashes(input), DictionaryVector::SelVector(input), count); TemplatedLoopHash(input_hashes, result, &sel, count); @@ -508,7 +509,11 @@ void VectorOperations::Hash(Vector &input, Vector &result, const SelectionVector } } -void VectorOperations::CombineHash(Vector &hashes, Vector &input, idx_t count) { +void VectorOperations::Hash(const Vector &input, Vector &result) { + Hash(input, result, input.size()); +} + +void VectorOperations::CombineHash(Vector &hashes, const Vector &input, idx_t count) { if (input.GetVectorType() == VectorType::DICTIONARY_VECTOR && DictionaryVector::CanCacheHashes(input)) { Vector input_hashes(DictionaryVector::GetCachedHashes(input), DictionaryVector::SelVector(input), count); TemplatedLoopCombineHash(input_hashes, hashes, nullptr, count); @@ -517,7 +522,17 @@ void VectorOperations::CombineHash(Vector &hashes, Vector &input, idx_t count) { } } -void VectorOperations::CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count) { +void VectorOperations::CombineHash(Vector &hashes, const Vector &input) { + const bool hashes_const = 
hashes.GetVectorType() == VectorType::CONSTANT_VECTOR; + const bool input_const = input.GetVectorType() == VectorType::CONSTANT_VECTOR; + if (!hashes_const && !input_const && hashes.size() != input.size()) { + throw InternalException("Mismatch in input vector sizes for CombineHash - hashes has %d rows but input has %d", + hashes.size(), input.size()); + } + CombineHash(hashes, input, hashes_const ? input.size() : hashes.size()); +} + +void VectorOperations::CombineHash(Vector &hashes, const Vector &input, const SelectionVector &rsel, idx_t count) { if (input.GetVectorType() == VectorType::DICTIONARY_VECTOR && DictionaryVector::CanCacheHashes(input)) { Vector input_hashes(DictionaryVector::GetCachedHashes(input), DictionaryVector::SelVector(input), count); TemplatedLoopCombineHash(input_hashes, hashes, &rsel, count); diff --git a/src/duckdb/src/common/vector_operations/vector_storage.cpp b/src/duckdb/src/common/vector_operations/vector_storage.cpp index 9aaf570ce..a575fbe14 100644 --- a/src/duckdb/src/common/vector_operations/vector_storage.cpp +++ b/src/duckdb/src/common/vector_operations/vector_storage.cpp @@ -9,8 +9,7 @@ namespace duckdb { namespace { template -void CopyToStorageLoop(Vector &source, idx_t count, data_ptr_t target) { - D_ASSERT(source.size() == count); +void CopyToStorageLoop(const Vector &source, data_ptr_t target) { auto result_data = (T *)target; for (auto entry : source.Values()) { if (!entry.IsValid()) { @@ -32,51 +31,51 @@ void ReadFromStorageLoop(data_ptr_t source, idx_t count, Vector &result) { } // namespace -void VectorOperations::WriteToStorage(Vector &source, idx_t count, data_ptr_t target) { - if (count == 0) { +void VectorOperations::WriteToStorage(const Vector &source, data_ptr_t target) { + if (source.size() == 0) { return; } switch (source.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::INT16: - 
CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::INT32: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::INT64: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::UINT8: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::UINT16: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::UINT32: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::UINT64: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::INT128: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::UINT128: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::FLOAT: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::DOUBLE: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; case PhysicalType::INTERVAL: - CopyToStorageLoop(source, count, target); + CopyToStorageLoop(source, target); break; default: throw NotImplementedException("Unimplemented type for WriteToStorage"); diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp index 1cf436ff7..d5f522c19 100644 --- a/src/duckdb/src/common/virtual_file_system.cpp +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/virtual_file_system.hpp" #include "duckdb/common/file_opener.hpp" +#include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/common/gzip_file_system.hpp" #include "duckdb/common/pipe_file_system.hpp" #include "duckdb/common/string_util.hpp" @@ -112,6 +113,22 @@ 
VirtualFileSystem::VirtualFileSystem(unique_ptr &&inner) VirtualFileSystem::~VirtualFileSystem() { } +unique_ptr VirtualFileSystem::MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener) { + auto registry = file_system_registry.atomic_load(); + auto &internal_filesystem = FindFileSystem(registry, path.path, opener); + return internal_filesystem.MemoryMapFile(path, flags, options, opener); +} + +FileSystem &VirtualFileSystem::GetDefaultFileSystem() { + auto &fs = *file_system_registry->default_fs->file_system; + if (SubSystemIsDisabled(fs.GetName())) { + throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); + } + return fs; +} + unique_ptr VirtualFileSystem::OpenFileExtended(const OpenFileInfo &file, FileOpenFlags flags, optional_ptr opener) { auto compression = flags.Compression(); diff --git a/src/duckdb/src/common/windows_util.cpp b/src/duckdb/src/common/windows_util.cpp index 8be7a3a04..9ddab8b6f 100644 --- a/src/duckdb/src/common/windows_util.cpp +++ b/src/duckdb/src/common/windows_util.cpp @@ -16,7 +16,11 @@ std::wstring WindowsUtil::UTF8ToUnicode(const char *input) { if (result_size == 0) { throw IOException("Failure in MultiByteToWideChar"); } - return std::wstring(buffer.get(), result_size); + // If cbMultiByte parameter is -1, the function processes the entire input string, + // including the terminating null character. Therefore, the resulting Unicode string + // has a terminating null character, and the length returned by the function + // includes this character. 
+ return std::wstring(buffer.get(), result_size - 1); } static string WideCharToMultiByteWrapper(LPCWSTR input, uint32_t code_page) { diff --git a/src/duckdb/src/execution/adaptive_filter.cpp b/src/duckdb/src/execution/adaptive_filter.cpp index ee5827445..3383f4f7e 100644 --- a/src/duckdb/src/execution/adaptive_filter.cpp +++ b/src/duckdb/src/execution/adaptive_filter.cpp @@ -14,17 +14,17 @@ namespace duckdb { AdaptiveFilter::AdaptiveFilter(const Expression &expr) : observe_interval(10), execute_interval(20), warmup(true) { auto &conj_expr = expr.Cast(); - D_ASSERT(conj_expr.children.size() > 1); - for (idx_t idx = 0; idx < conj_expr.children.size(); idx++) { + D_ASSERT(conj_expr.GetChildren().size() > 1); + for (idx_t idx = 0; idx < conj_expr.GetChildren().size(); idx++) { permutation.push_back(idx); - if (conj_expr.children[idx]->CanThrow()) { + if (conj_expr.GetChildren()[idx]->CanThrow()) { disable_permutations = true; } - if (idx != conj_expr.children.size() - 1) { + if (idx != conj_expr.GetChildren().size() - 1) { swap_likeliness.push_back(100); } } - right_random_border = 100 * (conj_expr.children.size() - 1); + right_random_border = 100 * (conj_expr.GetChildren().size() - 1); } AdaptiveFilter::AdaptiveFilter(const TableFilterSet &table_filters, vector filter_global_pos_p) diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp index f0c07fc93..8750fc6b8 100644 --- a/src/duckdb/src/execution/aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/aggregate_hashtable.cpp @@ -11,7 +11,6 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/ht_entry.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" - namespace duckdb { using ValidityBytes = TupleDataLayout::ValidityBytes; @@ -48,6 +47,12 @@ GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context_p, A : BaseAggregateHashTable(context_p, allocator, aggregate_objects_p, 
std::move(payload_types_p)), context(context_p), radix_bits(radix_bits), count(0), capacity(0), sink_count(0), skip_lookups(false), enable_hll(false), aggregate_allocator(make_shared_ptr(allocator)), state(*aggregate_allocator) { + clustered_state.all_clustered = AllAggregatesClustered(aggregate_objects_p); + clustered_state.n_clustered = CountAggregatesClustered(aggregate_objects_p); + if (clustered_state.n_clustered > 1) { + clustered_state.Initialize(); + } + // Append hash column to the end and initialise the row layout group_types_p.emplace_back(LogicalType::HASH); @@ -187,7 +192,7 @@ void GroupedAggregateHashTable::DestroyAggregateData(PartitionedTupleData &data, TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); auto &row_locations = iterator.GetChunkState().row_locations; do { - RowOperations::DestroyStates(state.row_state, *layout_ptr, row_locations, iterator.GetCurrentChunkCount()); + RowOperations::DestroyStates(state.row_state, *layout_ptr, row_locations); } while (iterator.Next()); data_collection->Reset(); } @@ -291,7 +296,6 @@ idx_t GroupedAggregateHashTable::GetHLLUpperBound() const { } void GroupedAggregateHashTable::Resize(idx_t size) { - D_ASSERT(size >= STANDARD_VECTOR_SIZE); D_ASSERT(IsPowerOfTwo(size)); if (Count() != 0 && size < capacity) { throw InternalException("Cannot downsize a non-empty hash table!"); @@ -373,7 +377,7 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups static constexpr idx_t MAX_DICTIONARY_SIZE_THRESHOLD = 20000; static constexpr idx_t DICTIONARY_THRESHOLD = 2; // dictionary vector - check if this is a duplicate eliminated dictionary from the storage - auto &dict_col = groups.data[0]; + const auto &dict_col = groups.data[0]; auto opt_dict_size = DictionaryVector::DictionarySize(dict_col); if (!opt_dict_size.IsValid()) { // dict size not known - this is not a dictionary that comes from the storage @@ -396,8 +400,8 @@ optional_idx 
GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups return optional_idx(); } } - auto &dictionary_vector = DictionaryVector::Child(dict_col); - auto &offsets = DictionaryVector::SelVector(dict_col); + const auto &dictionary_vector = DictionaryVector::Child(dict_col); + const auto &offsets = DictionaryVector::SelVector(dict_col); auto &dict_state = state.dict_state; if (dict_state.dictionary_id.empty() || dict_state.dictionary_id != dictionary_id) { // new dictionary - initialize the index state @@ -408,6 +412,10 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups } memset(dict_state.found_entry.get(), 0, dict_size * sizeof(bool)); dict_state.dictionary_id = dictionary_id; + dict_state.resolved_count = 0; + dict_state.address_high_bits = ~uint64_t(0); + dict_state.address_high_bits_uniform = + (sizeof(uintptr_t) == sizeof(uint64_t)); // only relevant on 64-bits systems } else if (dict_size > dict_state.capacity) { throw InternalException("AggregateHT - using cached dictionary data but dictionary has changed (dictionary id " "%s - dict size %d, current capacity %d)", @@ -419,11 +427,15 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups idx_t unique_count = 0; // for each of the dictionary entries - check if we have already done a look-up into the hash table // if we have, we can just use the cached group pointers - for (idx_t i = 0; i < groups.size(); i++) { - auto dict_idx = offsets.get_index(i); - unique_entries.set_index(unique_count, dict_idx); - unique_count += !found_entry[dict_idx]; - found_entry[dict_idx] = true; + // once every dict slot has been seen the walk would produce no new entries - skip it + if (dict_state.resolved_count < dict_size) { + for (idx_t i = 0; i < groups.size(); i++) { + auto dict_idx = offsets.get_index(i); + unique_entries.set_index(unique_count, dict_idx); + unique_count += !found_entry[dict_idx]; + found_entry[dict_idx] = true; + } + 
dict_state.resolved_count += unique_count; } auto &new_dictionary_pointers = dict_state.new_dictionary_pointers; idx_t new_group_count = 0; @@ -434,7 +446,7 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups } // slice the dictionary unique_values.data[0].Slice(dictionary_vector, unique_entries, unique_count); - unique_values.SetCardinality(unique_count); + unique_values.CheckCardinality(unique_count); // now we know which entries we are going to add - hash them auto &hashes = dict_state.hashes; unique_values.Hash(hashes); @@ -455,7 +467,17 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups auto dict_addresses = FlatVector::ScatterWriter(dictionary_addresses); for (idx_t i = 0; i < unique_count; i++) { auto dict_idx = unique_entries.get_index(i); - dict_addresses[dict_idx] = new_dict_addresses[i] + layout_ptr->GetAggrOffset(); + const uintptr_t addr = new_dict_addresses[i] + layout_ptr->GetAggrOffset(); + static constexpr uint64_t GID_HIGH_MASK = ~(ClusteredAggr::MAX_GID_COUNT - 1); + if (dict_state.address_high_bits_uniform) { // for clustered aggregation: check high bit uniformity + const uint64_t high_bits = static_cast(addr) & GID_HIGH_MASK; + if (dict_state.address_high_bits == ~uint64_t(0)) { // uninitialized + dict_state.address_high_bits = high_bits; + } else if (high_bits != dict_state.address_high_bits) { + dict_state.address_high_bits_uniform = false; + } + } + dict_addresses[dict_idx] = addr; } // now set up the addresses for the aggregates auto result_addresses = FlatVector::Writer(state.addresses, groups.size()); @@ -463,9 +485,10 @@ optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups auto dict_idx = offsets.get_index(i); result_addresses.WriteValue(dict_addresses[dict_idx]); } + FlatVector::SetSize(state.addresses, groups.size()); - // finally process the aggregates - UpdateAggregates(payload, filter); + // finally process the aggregates (ht_offsets are 
only valid for unique entries, not the full payload) + UpdateAggregates(payload, filter, groups.size(), false); return new_group_count; } @@ -478,6 +501,8 @@ optional_idx GroupedAggregateHashTable::TryAddConstantGroups(DataChunk &groups, return optional_idx(); } #endif + // Capture row_count before Reference+SetChildCardinality, which share the buffer and would corrupt groups.size() + const idx_t row_count = groups.size(); auto &dict_state = state.dict_state; auto &unique_values = dict_state.unique_values; if (unique_values.ColumnCount() == 0) { @@ -485,8 +510,11 @@ optional_idx GroupedAggregateHashTable::TryAddConstantGroups(DataChunk &groups, } // slice the dictionary unique_values.Reference(groups); - unique_values.SetCardinality(1); + unique_values.SetChildCardinality(1); unique_values.Flatten(); + // Restore the groups chunk's buffer v_size which was corrupted to 1 by SetChildCardinality(1) above. + // unique_values.Flatten() created new independent buffers, so restoring groups is safe. 
+ groups.SetChildCardinality(row_count); auto &hashes = dict_state.hashes; unique_values.Hash(hashes); @@ -504,14 +532,18 @@ optional_idx GroupedAggregateHashTable::TryAddConstantGroups(DataChunk &groups, // process the aggregates // FIXME: This should just be a CONSTANT_VECTOR but subsequent operations assume FLAT_VECTOR auto new_dict_addresses = FlatVector::GetData(new_dictionary_pointers); - auto result_addresses = FlatVector::Writer(state.addresses, payload.size()); + auto result_addresses = FlatVector::Writer(state.addresses, row_count); uintptr_t aggregate_address = new_dict_addresses[0] + layout_ptr->GetAggrOffset(); - for (idx_t i = 0; i < payload.size(); i++) { + static constexpr uint64_t GID_HIGH_MASK = ~(ClusteredAggr::MAX_GID_COUNT - 1); + dict_state.address_high_bits_uniform = (sizeof(uintptr_t) == sizeof(uint64_t)); + dict_state.address_high_bits = static_cast(aggregate_address) & GID_HIGH_MASK; + for (idx_t i = 0; i < row_count; i++) { result_addresses.WriteValue(aggregate_address); } - FlatVector::SetSize(state.addresses, payload.size()); + FlatVector::SetSize(state.addresses, row_count); state.addresses.SetVectorType(VectorType::CONSTANT_VECTOR); - UpdateAggregates(payload, filter); + // ht_offsets are only valid for the single constant group, not the full payload + UpdateAggregates(payload, filter, row_count, false); state.addresses.SetVectorType(VectorType::FLAT_VECTOR); return new_group_count; @@ -544,8 +576,57 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, return AddChunk(groups, state.hashes, payload, filter); } -void GroupedAggregateHashTable::UpdateAggregates(DataChunk &payload, const unsafe_vector &filter) { - // Now every cell has an entry, update the aggregates +bool GroupedAggregateHashTable::UpdateAggregatesClustered(DataChunk &payload, const unsafe_vector &filter, + idx_t count, bool ht_offsets_valid) { + if (skip_lookups) { + return false; + } + if (sizeof(uintptr_t) < sizeof(uint64_t)) { + return 
false; + } + ClusteredAggr clustered; + if (ht_offsets_valid) { + if (capacity >= ClusteredAggr::MAX_GID_COUNT) { + return false; + } + if (!clustered_state.TryBuild(clustered, FlatVector::GetData(state.ht_offsets), count)) { + return false; + } + const auto aggr_offset = layout_ptr->GetAggrOffset(); + clustered.InitializeStates([&](uint64_t gid) { + auto slot = static_cast(gid); + return entries[slot].GetPointer() + aggr_offset; + }); + } else { + // dictionary path: addresses are already resolved pointers — and gids are not even set then + // But, we can use the "addresses as gids". ClusteredAggr only triggers on low gid counts (see above) and + // for speed the TryClustered method exploits that by ignoring high gid bits. When using "addresses as gids" + // we thus need to know that all high bits are the same (this is very typically the case). + if (!state.dict_state.address_high_bits_uniform) { + return false; + } + auto addrs = FlatVector::GetData(state.addresses); + if (!clustered_state.TryBuild(clustered, addrs, count)) { + return false; + } + clustered.InitializeStates( + [&](uint64_t gid) { return reinterpret_cast(state.dict_state.address_high_bits | gid); }); + } + + const bool skip_addresses = clustered_state.all_clustered; + auto &aggregates = layout_ptr->GetAggregates(); + RowOperations::UpdateStatesClustered(state.row_state, aggregates, &filter_set, &filter, state.addresses, payload, + clustered, skip_addresses); + return true; +} + +void GroupedAggregateHashTable::UpdateAggregates(DataChunk &payload, const unsafe_vector &filter, idx_t count, + bool ht_offsets_valid) { + if (UpdateAggregatesClustered(payload, filter, count, ht_offsets_valid)) { + Verify(); + return; + } + auto &aggregates = layout_ptr->GetAggregates(); idx_t filter_idx = 0; idx_t payload_idx = 0; @@ -554,7 +635,7 @@ void GroupedAggregateHashTable::UpdateAggregates(DataChunk &payload, const unsaf if (filter_idx >= filter.size() || i < filter[filter_idx]) { // Skip all the aggregates 
that are not in the filter payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size)); continue; } D_ASSERT(i == filter[filter_idx]); @@ -563,12 +644,12 @@ void GroupedAggregateHashTable::UpdateAggregates(DataChunk &payload, const unsaf RowOperations::UpdateFilteredStates(state.row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload, payload_idx); } else { - RowOperations::UpdateStates(state.row_state, aggr, state.addresses, payload, payload_idx, payload.size()); + RowOperations::UpdateStates(state.row_state, aggr, state.addresses, payload, payload_idx); } // Move to the next aggregate payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size)); filter_idx++; } @@ -589,9 +670,9 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe #endif const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups); - VectorOperations::AddInPlace(state.addresses, NumericCast(layout_ptr->GetAggrOffset()), payload.size()); + VectorOperations::AddInPlace(state.addresses, NumericCast(layout_ptr->GetAggrOffset())); - UpdateAggregates(payload, filter); + UpdateAggregates(payload, filter, groups.size()); return new_group_count; } @@ -605,7 +686,7 @@ void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &re } #endif - result.SetCardinality(groups); + result.SetChildCardinality(groups.size()); if (groups.size() == 0) { return; } @@ -681,7 +762,6 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]); } state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v); - 
state.group_chunk.SetCardinality(groups); // convert all vectors to unified format TupleDataCollection::ToUnifiedFormat(state.partitioned_append_state.chunk_state, state.group_chunk); @@ -711,6 +791,7 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V addresses[i] = row_location; } count += chunk_size; + FlatVector::SetSize(addresses_v, chunk_size); return chunk_size; } @@ -822,6 +903,7 @@ idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, V throw InternalException("Maximum outer iteration count reached in GroupedAggregateHashTable"); } + FlatVector::SetSize(addresses_v, chunk_size); count += new_group_count; return new_group_count; } @@ -909,13 +991,11 @@ void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data, optiona while (fm_state.Scan()) { // Check for interrupts with each chunk context.InterruptCheck(); - const auto input_chunk_size = fm_state.groups.size(); FindOrCreateGroups(fm_state.groups, fm_state.hashes, fm_state.group_addresses, fm_state.new_groups_sel); RowOperations::CombineStates(state.row_state, *layout_ptr, fm_state.scan_state.chunk_state.row_locations, - fm_state.group_addresses, input_chunk_size); + fm_state.group_addresses); if (layout_ptr->HasDestructor()) { - RowOperations::DestroyStates(state.row_state, *layout_ptr, fm_state.scan_state.chunk_state.row_locations, - input_chunk_size); + RowOperations::DestroyStates(state.row_state, *layout_ptr, fm_state.scan_state.chunk_state.row_locations); } if (progress) { @@ -960,4 +1040,67 @@ bool GroupedAggregateHashTable::Scan(AggregateHTScanState &scan_state, DataChunk } } } + +void GroupedAggregateHashTable::ResetForNewIteration(idx_t initial_capacity, idx_t radix_bits_p) { + // Save the previous iteration's group count before destroying aggregate states. 
+ // This lets us size the pointer table based on actual prior data rather than the + // global sink capacity, which is typically much larger than recursive iteration sizes. + const auto prev_count = count; + Destroy(); + aggregate_allocator->Reset(); + stored_allocators.clear(); + + const auto reuse_partitioned_data = partitioned_data && RadixPartitioning::RadixBitsOfPowerOfTwo( + partitioned_data->PartitionCount()) == radix_bits_p; + const auto reuse_unpartitioned_data = + radix_bits_p >= UNPARTITIONED_RADIX_BITS_THRESHOLD && unpartitioned_data && + RadixPartitioning::RadixBitsOfPowerOfTwo(unpartitioned_data->PartitionCount()) == 0; + + radix_bits = radix_bits_p; + if (reuse_partitioned_data) { + if (partitioned_data->Count() != 0) { + partitioned_data->Reset(); + } + partitioned_data->ResetAppendState(state.partitioned_append_state, + TupleDataPinProperties::KEEP_EVERYTHING_PINNED); + } else { + InitializePartitionedData(); + } + if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD) { + if (reuse_unpartitioned_data) { + if (unpartitioned_data->Count() != 0) { + unpartitioned_data->Reset(); + } + unpartitioned_data->ResetAppendState(state.unpartitioned_append_state, + TupleDataPinProperties::KEEP_EVERYTHING_PINNED); + } else { + InitializeUnpartitionedData(); + } + } else { + unpartitioned_data.reset(); + } + + count = 0; + sink_count = 0; + skip_lookups = false; + enable_hll = false; + hll = HyperLogLog(); + state.dict_state.dictionary_id = string(); + + // Compute effective capacity based on the previous iteration's actual group count. + // This avoids clearing a large pointer table (O(max_capacity)) when prior iterations + // processed few rows. The HT will grow during the iteration if more groups are encountered. + const auto effective_capacity = GetCapacityForCount(prev_count); + if (!hash_map.IsSet() || capacity < effective_capacity) { + Resize(effective_capacity); + } else { + // Logically shrink to effective_capacity: only zero the portion we need. 
+ // Entries beyond effective_capacity are unreachable (bitmask limits all accesses), + // so leaving them uncleared is safe. The physical buffer is reused in place. + capacity = effective_capacity; + entries = reinterpret_cast(hash_map.get()); + bitmask = capacity - 1; + ClearPointerTable(); + } +} } // namespace duckdb diff --git a/src/duckdb/src/execution/base_aggregate_hashtable.cpp b/src/duckdb/src/execution/base_aggregate_hashtable.cpp index b558a326c..79d3a7d5f 100644 --- a/src/duckdb/src/execution/base_aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/base_aggregate_hashtable.cpp @@ -12,4 +12,26 @@ BaseAggregateHashTable::BaseAggregateHashTable(ClientContext &context, Allocator filter_set.Initialize(context, aggregates, payload_types); } +bool BaseAggregateHashTable::AllAggregatesClustered(const vector &aggregates) { + if (aggregates.empty()) { + return false; + } + for (auto &aggregate : aggregates) { + if (aggregate.filter || !aggregate.function.GetStateClusterUpdateCallback()) { + return false; + } + } + return true; +} + +idx_t BaseAggregateHashTable::CountAggregatesClustered(const vector &aggregates) { + idx_t count = 0; + for (auto &aggregate : aggregates) { + if (!aggregate.filter && aggregate.function.GetStateClusterUpdateCallback()) { + count++; + } + } + return count; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/column_binding_resolver.cpp b/src/duckdb/src/execution/column_binding_resolver.cpp index 993e09c8a..da74e1cfe 100644 --- a/src/duckdb/src/execution/column_binding_resolver.cpp +++ b/src/duckdb/src/execution/column_binding_resolver.cpp @@ -189,21 +189,22 @@ void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { unique_ptr ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { - D_ASSERT(expr.depth == 0); + D_ASSERT(expr.Depth() == 0); // check the current set of column bindings to see which index corresponds to the column reference for (idx_t i = 0; i < bindings.size(); 
i++) { - if (expr.binding == bindings[i]) { + if (expr.Binding() == bindings[i]) { if (!types.empty()) { if (bindings.size() != types.size()) { throw InternalException( "Failed to bind column reference \"%s\" [%d.%d]: inequal num bindings/types (%llu != %llu)", - expr.GetAlias(), expr.binding.table_index.index, expr.binding.column_index, bindings.size(), + expr.GetAlias(), expr.Binding().table_index.index, expr.Binding().column_index, bindings.size(), types.size()); } if (expr.GetReturnType() != types[i]) { throw InternalException("Failed to bind column reference \"%s\" [%d.%d]: inequal types (%s != %s)", - expr.GetAlias(), expr.binding.table_index.index, expr.binding.column_index, - expr.GetReturnType().ToString(), types[i].ToString()); + expr.GetAlias(), expr.Binding().table_index.index, + expr.Binding().column_index, expr.GetReturnType().ToString(), + types[i].ToString()); } } if (verify_only) { @@ -217,7 +218,7 @@ unique_ptr ColumnBindingResolver::VisitReplace(BoundColumnRefExpress // could not bind the column reference, this should never happen and indicates a bug in the code // generate an error message throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)", expr.GetAlias(), - expr.binding.table_index.index, expr.binding.column_index, + expr.Binding().table_index.index, expr.Binding().column_index, LogicalOperator::ColumnBindingsToString(bindings)); // LCOV_EXCL_STOP } diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index cde2751d8..e8a49fade 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -89,7 +89,7 @@ void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) { for (idx_t i = 0; i < expressions.size(); i++) { ExecuteExpression(i, result.data[i]); } - result.SetCardinality(input ? input->size() : 1); + result.SetChildCardinality(input ? 
input->size() : 1); result.Verify(context); } @@ -162,13 +162,12 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co expr.GetVerificationStats()->Verify(vector, count); } if (debug_vector_verification == DebugVectorVerification::DICTIONARY_EXPRESSION) { - Vector::DebugTransformToDictionary(vector, count); + Vector::DebugTransformToDictionary(vector); } if (debug_vector_verification == DebugVectorVerification::VARIANT_VECTOR) { if (TypeVisitor::Contains(vector.GetType(), [](const LogicalType &type) { if (type.IsJSONType() || type.id() == LogicalTypeId::VARIANT || type.id() == LogicalTypeId::UNION || - type.id() == LogicalTypeId::ENUM || type.id() == LogicalTypeId::LEGACY_AGGREGATE_STATE || - type.id() == LogicalTypeId::AGGREGATE_STATE || type.id() == LogicalTypeId::TYPE) { + type.id() == LogicalTypeId::ENUM || type.id() == LogicalTypeId::TYPE) { return true; } if (type.id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(type)) { @@ -286,7 +285,7 @@ void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, default: throw InternalException("Attempting to execute expression of unknown type!"); } - if (expr.GetExpressionClass() != ExpressionClass::BOUND_REF) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_REF && result.size() != count_t(count)) { // BoundReferenceExpression shares buffer with its source - we cannot resize it // all other expressions produce a fresh result vector that we own FlatVector::SetSize(result, count_t(count)); diff --git a/src/duckdb/src/execution/expression_executor/execute_case.cpp b/src/duckdb/src/execution/expression_executor/execute_case.cpp index b266e0a17..05667af50 100644 --- a/src/duckdb/src/execution/expression_executor/execute_case.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_case.cpp @@ -21,11 +21,11 @@ struct CaseExpressionState : public ExpressionState { unique_ptr ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, 
ExpressionExecutorState &root) { auto result = make_uniq(expr, root); - for (auto &case_check : expr.case_checks) { + for (auto &case_check : expr.CaseChecks()) { result->AddChild(*case_check.when_expr); result->AddChild(*case_check.then_expr); } - result->AddChild(*expr.else_expr); + result->AddChild(expr.Else()); result->Finalize(); return std::move(result); @@ -42,8 +42,8 @@ void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionStat auto current_false_sel = &state.false_sel; auto current_sel = sel; idx_t current_count = count; - for (idx_t i = 0; i < expr.case_checks.size(); i++) { - auto &case_check = expr.case_checks[i]; + for (idx_t i = 0; i < expr.CaseChecks().size(); i++) { + auto &case_check = expr.CaseChecks()[i]; auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; auto check_state = state.child_states[i * 2].get(); auto then_state = state.child_states[i * 2 + 1].get(); @@ -77,13 +77,13 @@ void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionStat auto else_state = state.child_states.back().get(); if (current_count == count) { // everything was false, we can just evaluate the else expression directly - Execute(*expr.else_expr, else_state, sel, count, result); + Execute(expr.Else(), else_state, sel, count, result); return; } else { - auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; + auto &intermediate_result = state.intermediate_chunk.data[expr.CaseChecks().size() * 2]; D_ASSERT(current_sel); - Execute(*expr.else_expr, else_state, current_sel, current_count, intermediate_result); + Execute(expr.Else(), else_state, current_sel, current_count, intermediate_result); FillSwitch(intermediate_result, result, *current_sel, NumericCast(current_count)); } } @@ -93,7 +93,7 @@ void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionStat } template -void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { 
+void TemplatedFillLoop(const Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { result.SetVectorType(VectorType::FLAT_VECTOR); auto res = FlatVector::GetDataMutable(result); auto &result_mask = FlatVector::ValidityMutable(result); @@ -122,7 +122,7 @@ void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &se } } -void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { +void ValidityFillLoop(const Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { result.SetVectorType(VectorType::FLAT_VECTOR); auto &result_mask = FlatVector::ValidityMutable(result); if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -144,7 +144,7 @@ void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel } } -void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { +void ExpressionExecutor::FillSwitch(const Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { switch (result.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: diff --git a/src/duckdb/src/execution/expression_executor/execute_cast.cpp b/src/duckdb/src/execution/expression_executor/execute_cast.cpp index be3b9bb26..a75a6201f 100644 --- a/src/duckdb/src/execution/expression_executor/execute_cast.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_cast.cpp @@ -1,7 +1,5 @@ #include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar_function.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" namespace duckdb { @@ -9,13 +7,13 @@ namespace duckdb { unique_ptr ExpressionExecutor::InitializeState(const BoundCastExpression &expr, ExpressionExecutorState &root) { auto result = make_uniq(expr, root); - result->AddChild(*expr.child); + 
result->AddChild(expr.Child()); result->Finalize(); - if (expr.bound_cast.HasInitLocalState()) { + if (expr.GetBoundCast().HasInitLocalState()) { auto context_ptr = root.executor->HasContext() ? &root.executor->GetContext() : nullptr; - CastLocalStateParameters parameters(context_ptr, expr.bound_cast.GetCastData()); - result->local_state = expr.bound_cast.InitLocalState(parameters); + CastLocalStateParameters parameters(context_ptr, expr.GetBoundCast().GetCastData()); + result->local_state = expr.GetBoundCast().InitLocalState(parameters); } return std::move(result); } @@ -30,13 +28,13 @@ void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionStat auto &child = state->intermediate_chunk.data[0]; auto child_state = state->child_states[0].get(); - Execute(*expr.child, child_state, sel, count, child); + Execute(expr.Child(), child_state, sel, count, child); string error_message; - auto error_ref = expr.try_cast ? &error_message : nullptr; - CastParameters parameters(expr.bound_cast.GetCastData(), false, error_ref, lstate); + auto error_ref = expr.IsTryCast() ? 
&error_message : nullptr; + CastParameters parameters(expr.GetBoundCast().GetCastData(), false, error_ref, lstate); parameters.query_location = expr.GetQueryLocation(); - parameters.cast_source = expr.child.get(); + parameters.cast_source = &expr.Child(); parameters.cast_target = expr; idx_t cast_count = count; bool all_constant = child.GetVectorType() == VectorType::CONSTANT_VECTOR; @@ -52,7 +50,7 @@ void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionStat FlatVector::SetSize(child, 1ULL); cast_count = 1; } - expr.bound_cast.Cast(child, result, cast_count, parameters); + expr.GetBoundCast().Cast(child, result, cast_count, parameters); if (all_constant) { if (child.GetVectorType() != VectorType::CONSTANT_VECTOR) { child.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -60,11 +58,7 @@ void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionStat // restore the size of the input vector FlatVector::SetSize(child, count); // ensure the result type is constant - if (result.GetVectorType() != VectorType::FLAT_VECTOR && - result.GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.Flatten(); - } - result.SetVectorType(VectorType::CONSTANT_VECTOR); + result.FlattenAndSetConstant(); } FlatVector::SetSize(result, count_t(count)); } diff --git a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp index 3b5cefc27..30029c3e5 100644 --- a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp @@ -15,10 +15,10 @@ namespace duckdb { // Select comparisons //===--------------------------------------------------------------------===// template -static bool TryPrimitiveSelectOperation(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, optional_ptr null_mask, - idx_t &result) { +static bool 
TryPrimitiveSelectOperation(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, + optional_ptr null_mask, idx_t &result) { #ifdef DUCKDB_SMALLER_BINARY return false; #else @@ -101,12 +101,12 @@ static bool TryPrimitiveSelectOperation(Vector &left, Vector &right, optional_pt } template -static idx_t ComparatorSelectOperation(Vector &left, Vector &right, optional_ptr sel, +static idx_t ComparatorSelectOperation(const Vector &left, const Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask, PREDICATE predicate) { Vector comparator_result(LogicalType::TINYINT, count); - VectorOperations::Comparator(left, right, comparator_result, count); + VectorOperations::ComparatorFill(left, right, comparator_result, count); auto cmp_data = comparator_result.Values(); if (!sel) { @@ -141,9 +141,9 @@ static idx_t ComparatorSelectOperation(Vector &left, Vector &right, optional_ptr return true_count; } -idx_t VectorOperations::Equals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { +idx_t VectorOperations::Equals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask) { idx_t result; if (TryPrimitiveSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask, result)) { return result; @@ -152,9 +152,9 @@ idx_t VectorOperations::Equals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { +idx_t VectorOperations::NotEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask) { idx_t result; if (TryPrimitiveSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask, result)) { @@ 
-164,9 +164,9 @@ idx_t VectorOperations::NotEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { +idx_t VectorOperations::GreaterThan(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask) { idx_t result; if (TryPrimitiveSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask, result)) { @@ -176,8 +176,9 @@ idx_t VectorOperations::GreaterThan(Vector &left, Vector &right, optional_ptr 0; }); } -idx_t VectorOperations::GreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, +idx_t VectorOperations::GreaterThanEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { idx_t result; @@ -189,9 +190,9 @@ idx_t VectorOperations::GreaterThanEquals(Vector &left, Vector &right, optional_ [](int8_t v) { return v >= 0; }); } -idx_t VectorOperations::LessThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { +idx_t VectorOperations::LessThan(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask) { idx_t result; if (TryPrimitiveSelectOperation(right, left, sel, count, true_sel, false_sel, null_mask, result)) { @@ -201,7 +202,7 @@ idx_t VectorOperations::LessThan(Vector &left, Vector &right, optional_ptr sel, +idx_t VectorOperations::LessThanEquals(const Vector &left, const Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask) { idx_t result; diff --git a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp index 
d45601e15..e3bb5d953 100644 --- a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp @@ -21,7 +21,7 @@ struct ConjunctionState : public ExpressionState { unique_ptr ExpressionExecutor::InitializeState(const BoundConjunctionExpression &expr, ExpressionExecutorState &root) { auto result = make_uniq(expr, root); - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { result->AddChild(*child); } @@ -33,9 +33,9 @@ void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, Express const SelectionVector *sel, idx_t count, Vector &result) { // execute the children state->intermediate_chunk.Reset(); - for (idx_t i = 0; i < expr.children.size(); i++) { + for (idx_t i = 0; i < expr.GetChildren().size(); i++) { auto ¤t_result = state->intermediate_chunk.data[i]; - Execute(*expr.children[i], state->child_states[i].get(), sel, count, current_result); + Execute(*expr.GetChildren()[i], state->child_states[i].get(), sel, count, current_result); if (i == 0) { // move the result result.Reference(current_result); @@ -44,10 +44,10 @@ void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, Express // AND/OR together switch (expr.GetExpressionType()) { case ExpressionType::CONJUNCTION_AND: - VectorOperations::And(current_result, result, intermediate, count); + VectorOperations::And(current_result, result, intermediate); break; case ExpressionType::CONJUNCTION_OR: - VectorOperations::Or(current_result, result, intermediate, count); + VectorOperations::Or(current_result, result, intermediate); break; default: throw InternalException("Unknown conjunction type!"); @@ -78,9 +78,9 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express temp_true = make_uniq(STANDARD_VECTOR_SIZE); true_sel = temp_true.get(); } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[permutation[i]], 
state.child_states[permutation[i]].get(), current_sel, - current_count, true_sel, temp_false.get()); + for (idx_t i = 0; i < expr.GetChildren().size(); i++) { + idx_t tcount = Select(*expr.GetChildren()[permutation[i]], state.child_states[permutation[i]].get(), + current_sel, current_count, true_sel, temp_false.get()); idx_t fcount = current_count - tcount; if (fcount > 0 && false_sel) { // move failing tuples into the false_sel @@ -119,9 +119,9 @@ idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, Express temp_false = make_uniq(STANDARD_VECTOR_SIZE); false_sel = temp_false.get(); } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[permutation[i]], state.child_states[permutation[i]].get(), current_sel, - current_count, temp_true.get(), false_sel); + for (idx_t i = 0; i < expr.GetChildren().size(); i++) { + idx_t tcount = Select(*expr.GetChildren()[permutation[i]], state.child_states[permutation[i]].get(), + current_sel, current_count, temp_true.get(), false_sel); if (tcount > 0) { if (true_sel) { // tuples passed, move them into the actual result vector diff --git a/src/duckdb/src/execution/expression_executor/execute_constant.cpp b/src/duckdb/src/execution/expression_executor/execute_constant.cpp index 7f3e6b095..00cdbaa8b 100644 --- a/src/duckdb/src/execution/expression_executor/execute_constant.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_constant.cpp @@ -13,8 +13,8 @@ unique_ptr ExpressionExecutor::InitializeState(const BoundConst void ExpressionExecutor::Execute(const BoundConstantExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.value.type() == expr.GetReturnType()); - result.Reference(expr.value, count_t(count)); + D_ASSERT(expr.GetValue().type() == expr.GetReturnType()); + result.Reference(expr.GetValue(), count_t(count)); } } // namespace duckdb diff --git 
a/src/duckdb/src/execution/expression_executor/execute_function.cpp b/src/duckdb/src/execution/expression_executor/execute_function.cpp index 11979802d..992376729 100644 --- a/src/duckdb/src/execution/expression_executor/execute_function.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_function.cpp @@ -1,8 +1,8 @@ -#include "duckdb/common/type_visitor.hpp" +#include "duckdb/common/enums/debug_verification_mode.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/config.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/types/uuid.hpp" namespace duckdb { @@ -21,7 +21,7 @@ ExecuteFunctionState::ExecuteFunctionState(const Expression &expr, ExpressionExe switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_FUNCTION: { auto &bound_function = expr.Cast(); - auto &children = bound_function.children; + auto &children = bound_function.GetChildren(); for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { auto &child = *children[child_idx]; if (child.IsFoldable()) { @@ -102,11 +102,11 @@ bool ExecuteFunctionState::TryExecuteDictionaryExpression(const BoundFunctionExp // Offset the input dictionary Vector offset_input(DictionaryVector::Child(unary_input), offset, offset + count); input_chunk.data[input_col_idx.GetIndex()].Reference(offset_input); - input_chunk.SetCardinality(count); + input_chunk.SetChildCardinality(count); // Execute, storing the result in an intermediate vector, and copying it to the output dictionary Vector output_intermediate(result.GetType()); - expr.function.GetFunctionCallback()(input_chunk, state, output_intermediate); + expr.Function().GetFunctionCallback()(input_chunk, state, output_intermediate); VectorOperations::Copy(output_intermediate, output_dictionary->data, count, 0, offset); } } @@ -130,27 +130,29 @@ void ExecuteFunctionState::ResetDictionaryStates() { unique_ptr 
ExpressionExecutor::InitializeState(const BoundFunctionExpression &expr, ExpressionExecutorState &root) { auto result = make_uniq(expr, root); - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { result->AddChild(*child); } result->Finalize(); - if (expr.function.HasInitStateCallback()) { - result->local_state = expr.function.GetInitStateCallback()(*result, expr, expr.bind_info.get()); + if (expr.Function().HasInitStateCallback()) { + result->local_state = expr.Function().GetInitStateCallback()(*result, expr, expr.BindInfo().get()); } return std::move(result); } static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { -#ifdef DEBUG - if (args.data.empty() || expr.function.GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING) { + if (DBConfigOptions::global_verification_mode != DebugVerificationMode::VERIFY_FUNCTIONS) { + return; + } + if (args.data.empty() || expr.Function().GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING) { return; } // Combine all the argument validity masks into a flat validity mask idx_t count = args.size(); ValidityMask combined_mask(count); - for (auto &arg : args.data) { + for (const auto &arg : args.data) { auto entries = arg.Validity(); if (!entries.CanHaveNull()) { continue; @@ -165,53 +167,31 @@ static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &a // Default is that if any of the arguments are NULL, the result is also NULL auto result_validity = result.Validity(); for (idx_t i = 0; i < count; i++) { - if (!combined_mask.RowIsValid(i)) { - D_ASSERT(!result_validity.IsValid(i)); + if (!combined_mask.RowIsValid(i) && result_validity.IsValid(i)) { + throw InternalException( + "VerifyNullHandling failed for scalar function \"%s\": row %d has a NULL argument but the result is " + "not NULL - functions with default NULL handling should return NULL for any NULL input", + expr.Function().GetName(), i); } } -#endif } 
-static void ExecuteSelectFunction(const BoundFunctionExpression &expr, DataChunk &args, ExpressionState &state, - Vector &result) { - if (expr.GetReturnType() != LogicalType::BOOLEAN) { - throw InvalidInputException("Function %s only has a select callback but returns %s", expr.function.GetName(), - expr.GetReturnType().ToString()); - } - if (expr.function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING) { - throw InvalidInputException("Function %s only has a select callback with SPECIAL_HANDLING but projected " - "execution requires a scalar callback to produce NULL results", - expr.function.GetName()); - } +static idx_t SelectBooleanResult(Vector &result, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, + SelectionVector *false_sel) { + return UnaryExecutor::Select( + result, sel, count, [](bool value) { return value; }, true_sel, false_sel); +} - result.SetVectorType(VectorType::FLAT_VECTOR); - auto count = args.size(); - auto result_data = FlatVector::GetDataMutable(result); - for (idx_t i = 0; i < count; i++) { - result_data[i] = false; - } +static void ExecuteConstantSelectFunction(const BoundFunctionExpression &expr, DataChunk &args, ExpressionState &state, + Vector &result) { + D_ASSERT(args.size() == 1); + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::Validity(result).SetAllValid(1); - auto &result_validity = FlatVector::ValidityMutable(result); - result_validity.SetAllValid(count); - D_ASSERT(expr.function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING); - for (auto &arg : args.data) { - auto entries = arg.Validity(); - if (!entries.CanHaveNull()) { - continue; - } - for (idx_t i = 0; i < count; i++) { - if (!entries.IsValid(i)) { - result_validity.SetInvalid(i); - } - } - } - - SelectionVector true_sel(count); - auto true_count = - expr.function.GetSelectCallback()(args, state, FlatVector::IncrementalSelectionVector(), &true_sel, nullptr); - for (idx_t i = 0; i < true_count; i++) { - 
result_data[true_sel.get_index(i)] = true; - } + SelectionVector true_sel(1); + SelectionVector false_sel(1); + auto true_count = expr.Function().GetSelectCallback()(args, state, nullptr, &true_sel, &false_sel); + *ConstantVector::GetData(result) = true_count == 1; } void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, ExpressionState *state, @@ -220,15 +200,15 @@ void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, Expression auto &arguments = state->intermediate_chunk; // if the input is constant and there function is non-volatile we only need to run it on one value bool all_constant = true; - if (expr.function.GetStability() == FunctionStability::VOLATILE) { + if (expr.Function().GetStability() == FunctionStability::VOLATILE) { // we cannot optimize away constant vectors for volatile functions all_constant = false; } - auto default_null_handling = expr.function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING; + auto default_null_handling = expr.Function().GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING; if (!state->types.empty()) { - for (idx_t i = 0; i < expr.children.size(); i++) { - D_ASSERT(state->types[i] == expr.children[i]->GetReturnType()); - Execute(*expr.children[i], state->child_states[i].get(), sel, count, arguments.data[i]); + for (idx_t i = 0; i < expr.GetChildren().size(); i++) { + D_ASSERT(state->types[i] == expr.GetChildren()[i]->GetReturnType()); + Execute(*expr.GetChildren()[i], state->child_states[i].get(), sel, count, arguments.data[i]); if (arguments.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { all_constant = false; } else if (default_null_handling && ConstantVector::IsNull(arguments.data[i])) { @@ -242,21 +222,20 @@ void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, Expression // if all arguments are constant temporarily set the child cardinality to 1 arguments.SetChildCardinality(1ULL); } else { - arguments.SetCardinality(count); + 
arguments.SetChildCardinality(count); } arguments.Verify(context); auto &execute_function_state = state->Cast(); - auto dictionary_executed = expr.function.HasFunctionCallback() && !all_constant && + auto dictionary_executed = expr.Function().HasFunctionCallback() && !all_constant && execute_function_state.TryExecuteDictionaryExpression(expr, arguments, *state, result); if (!dictionary_executed) { - if (expr.function.HasFunctionCallback()) { - expr.function.GetFunctionCallback()(arguments, *state, result); - } else if (expr.function.HasSelectCallback()) { - ExecuteSelectFunction(expr, arguments, *state, result); + if (expr.Function().HasFunctionCallback()) { + expr.Function().GetFunctionCallback()(arguments, *state, result); + } else if (all_constant && expr.Function().HasSelectCallback()) { + ExecuteConstantSelectFunction(expr, arguments, *state, result); } else { - throw InternalException("Scalar function %s has neither an execution nor a select callback", - expr.function.GetName()); + throw InternalException("Scalar function %s has no execution callback", expr.Function().GetName()); } } if (all_constant) { @@ -266,11 +245,7 @@ void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, Expression } arguments.SetChildCardinality(count); // ensure the result type is constant - if (result.GetVectorType() != VectorType::FLAT_VECTOR && - result.GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.Flatten(); - } - result.SetVectorType(VectorType::CONSTANT_VECTOR); + result.FlattenAndSetConstant(); } FlatVector::SetSize(result, count_t(count)); @@ -281,19 +256,59 @@ void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, Expression idx_t ExpressionExecutor::Select(const BoundFunctionExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - if (!expr.function.HasSelectCallback()) { + if (!expr.Function().HasSelectCallback()) { return DefaultSelect(expr, 
state, sel, count, true_sel, false_sel); } - // FIXME: push constant handling in here similar to Execute state->intermediate_chunk.Reset(); auto &arguments = state->intermediate_chunk; - for (idx_t i = 0; i < expr.children.size(); i++) { - D_ASSERT(state->types[i] == expr.children[i]->GetReturnType()); - Execute(*expr.children[i], state->child_states[i].get(), sel, count, arguments.data[i]); + bool all_constant = true; + if (expr.Function().GetStability() == FunctionStability::VOLATILE) { + all_constant = false; + } + auto default_null_handling = expr.Function().GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING; + for (idx_t i = 0; i < expr.GetChildren().size(); i++) { + D_ASSERT(state->types[i] == expr.GetChildren()[i]->GetReturnType()); + Execute(*expr.GetChildren()[i], state->child_states[i].get(), sel, count, arguments.data[i]); + if (arguments.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { + all_constant = false; + } else if (default_null_handling && ConstantVector::IsNull(arguments.data[i])) { + Vector result(LogicalType::BOOLEAN); + ConstantVector::SetNull(result, count_t(1)); + return SelectBooleanResult(result, sel, count, true_sel, false_sel); + } + } + if (all_constant) { + // if all arguments are constant we only need to run the function on one value + arguments.SetChildCardinality(1ULL); + } else { + arguments.SetChildCardinality(count); } - arguments.SetCardinality(count); arguments.Verify(context); - return expr.function.GetSelectCallback()(arguments, *state, sel, true_sel, false_sel); + if (all_constant) { + Vector result(LogicalType::BOOLEAN); + if (expr.Function().HasFunctionCallback()) { + expr.Function().GetFunctionCallback()(arguments, *state, result); + result.FlattenAndSetConstant(); + } else { + ExecuteConstantSelectFunction(expr, arguments, *state, result); + } + // restore the input cardinality + for (auto &arg : arguments.data) { + arg.SetVectorType(VectorType::CONSTANT_VECTOR); + } + 
arguments.SetChildCardinality(count); + return SelectBooleanResult(result, sel, count, true_sel, false_sel); + } + auto &execute_function_state = state->Cast(); + if (expr.Function().HasFunctionCallback()) { + Vector result(LogicalType::BOOLEAN); + auto dictionary_executed = + execute_function_state.TryExecuteDictionaryExpression(expr, arguments, *state, result); + if (dictionary_executed) { + return SelectBooleanResult(result, sel, count, true_sel, false_sel); + } + } + return expr.Function().GetSelectCallback()(arguments, *state, sel, true_sel, false_sel); } } // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_operator.cpp b/src/duckdb/src/execution/expression_executor/execute_operator.cpp index 65bb8f586..921e03f65 100644 --- a/src/duckdb/src/execution/expression_executor/execute_operator.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_operator.cpp @@ -7,7 +7,7 @@ namespace duckdb { unique_ptr ExpressionExecutor::InitializeState(const BoundOperatorExpression &expr, ExpressionExecutorState &root) { auto result = make_uniq(expr, root); - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { result->AddChild(*child); } @@ -21,13 +21,13 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression // IN has n children auto expression_type = expr.GetExpressionType(); if (expression_type == ExpressionType::COMPARE_IN || expression_type == ExpressionType::COMPARE_NOT_IN) { - if (expr.children.size() < 2) { + if (expr.GetChildren().size() < 2) { throw InvalidInputException("IN needs at least two children"); } - Vector left(expr.children[0]->GetReturnType()); + Vector left(expr.GetChildren()[0]->GetReturnType()); // eval left side - Execute(*expr.children[0], state->child_states[0].get(), sel, count, left); + Execute(*expr.GetChildren()[0], state->child_states[0].get(), sel, count, left); // init result to false Vector intermediate(LogicalType::BOOLEAN); @@ -36,12 +36,12 @@ 
void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression // in rhs is a list of constants // for every child, OR the result of the comparison with the left // to get the overall result. - for (idx_t child = 1; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->GetReturnType()); + for (idx_t child = 1; child < expr.GetChildren().size(); child++) { + Vector vector_to_check(expr.GetChildren()[child]->GetReturnType()); Vector comp_res(LogicalType::BOOLEAN); - Execute(*expr.children[child], state->child_states[child].get(), sel, count, vector_to_check); - VectorOperations::Equals(left, vector_to_check, comp_res, count); + Execute(*expr.GetChildren()[child], state->child_states[child].get(), sel, count, vector_to_check); + VectorOperations::Equals(left, vector_to_check, comp_res); if (child == 1) { // first child: move to result @@ -49,13 +49,13 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression } else { // otherwise OR together Vector new_result(LogicalType::BOOLEAN); - VectorOperations::Or(intermediate, comp_res, new_result, count); + VectorOperations::Or(intermediate, comp_res, new_result); intermediate.Reference(new_result); } } if (expression_type == ExpressionType::COMPARE_NOT_IN) { // NOT IN: invert result - VectorOperations::Not(intermediate, result, count); + VectorOperations::Not(intermediate, result); } else { // directly use the result result.Reference(intermediate); @@ -69,9 +69,9 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression const SelectionVector *current_sel = sel; idx_t remaining_count = count; idx_t next_count; - for (idx_t child = 0; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->GetReturnType()); - Execute(*expr.children[child], state->child_states[child].get(), current_sel, remaining_count, + for (idx_t child = 0; child < expr.GetChildren().size(); child++) { + Vector 
vector_to_check(expr.GetChildren()[child]->GetReturnType()); + Execute(*expr.GetChildren()[child], state->child_states[child].get(), current_sel, remaining_count, vector_to_check); auto entries = vector_to_check.Validity(); @@ -111,7 +111,7 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression auto &child_state = *state->child_states[0]; Vector try_result(result.GetType()); try { - Execute(*expr.children[0], &child_state, sel, count, try_result); + Execute(*expr.GetChildren()[0], &child_state, sel, count, try_result); if (try_result.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.Reference(try_result); return; @@ -132,7 +132,7 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression intermediate.Initialize(GetAllocator(), {result.GetType()}, 1); for (idx_t i = 0; i < count; i++) { intermediate.Reset(); - intermediate.SetCardinality(1); + intermediate.SetChildCardinality(1); // Make sure to clear any dictionary states in the child expression, so that it actually // gets executed anew for every row @@ -141,7 +141,7 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression selvec.set_index(0, sel ? 
sel->get_index(i) : i); Value val(result.GetType()); try { - Execute(*expr.children[0], &child_state, &selvec, 1, intermediate.data[0]); + Execute(*expr.GetChildren()[0], &child_state, &selvec, 1, intermediate.data[0]); val = intermediate.GetValue(0, 0); } catch (std::exception &ex) { ErrorData error(ex); @@ -155,22 +155,22 @@ void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, Expression if (count == 1) { result.SetVectorType(VectorType::CONSTANT_VECTOR); } - } else if (expr.children.size() == 1) { + } else if (expr.GetChildren().size() == 1) { state->intermediate_chunk.Reset(); auto &child = state->intermediate_chunk.data[0]; - Execute(*expr.children[0], state->child_states[0].get(), sel, count, child); + Execute(*expr.GetChildren()[0], state->child_states[0].get(), sel, count, child); switch (expr.GetExpressionType()) { case ExpressionType::OPERATOR_NOT: { - VectorOperations::Not(child, result, count); + VectorOperations::Not(child, result); break; } case ExpressionType::OPERATOR_IS_NULL: { - VectorOperations::IsNull(child, result, count); + VectorOperations::IsNull(child, result); break; } case ExpressionType::OPERATOR_IS_NOT_NULL: { - VectorOperations::IsNotNull(child, result, count); + VectorOperations::IsNotNull(child, result); break; } default: diff --git a/src/duckdb/src/execution/expression_executor/execute_parameter.cpp b/src/duckdb/src/execution/expression_executor/execute_parameter.cpp index 5cb4971c8..94f4d9cc2 100644 --- a/src/duckdb/src/execution/expression_executor/execute_parameter.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_parameter.cpp @@ -13,10 +13,10 @@ unique_ptr ExpressionExecutor::InitializeState(const BoundParam void ExpressionExecutor::Execute(const BoundParameterExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.parameter_data); - D_ASSERT(expr.parameter_data->return_type == expr.GetReturnType()); - 
D_ASSERT(expr.parameter_data->GetValue().type() == expr.GetReturnType()); - result.Reference(expr.parameter_data->GetValue(), count_t(count)); + D_ASSERT(expr.ParameterData()); + D_ASSERT(expr.ParameterData()->return_type == expr.GetReturnType()); + D_ASSERT(expr.ParameterData()->GetValue().type() == expr.GetReturnType()); + result.Reference(expr.ParameterData()->GetValue(), count_t(count)); } } // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_reference.cpp b/src/duckdb/src/execution/expression_executor/execute_reference.cpp index 88fdfa63d..ff8882c20 100644 --- a/src/duckdb/src/execution/expression_executor/execute_reference.cpp +++ b/src/duckdb/src/execution/expression_executor/execute_reference.cpp @@ -12,13 +12,13 @@ unique_ptr ExpressionExecutor::InitializeState(const BoundRefer void ExpressionExecutor::Execute(const BoundReferenceExpression &expr, ExpressionState *state, const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.index != DConstants::INVALID_INDEX); - D_ASSERT(expr.index < chunk->ColumnCount()); + D_ASSERT(expr.Index() != DConstants::INVALID_INDEX); + D_ASSERT(expr.Index() < chunk->ColumnCount()); if (sel) { - result.Slice(chunk->data[expr.index], *sel, count); + result.Slice(chunk->data[expr.Index()], *sel, count); } else { - result.Reference(chunk->data[expr.index]); + result.Reference(chunk->data[expr.Index()]); } } diff --git a/src/duckdb/src/execution/expression_executor_state.cpp b/src/duckdb/src/execution/expression_executor_state.cpp index 9f212fc4b..8e592ac58 100644 --- a/src/duckdb/src/execution/expression_executor_state.cpp +++ b/src/duckdb/src/execution/expression_executor_state.cpp @@ -35,7 +35,7 @@ bool ExpressionState::HasContext() { ClientContext &ExpressionState::GetContext() { if (!HasContext()) { throw BinderException("Cannot use %s in this context", - (expr.Cast()).function.GetName()); + (expr.Cast()).Function().GetName()); } return root.executor->GetContext(); } diff 
--git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index df9e34d16..bde663f37 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -27,6 +27,7 @@ #include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/storage/storage_manager.hpp" namespace duckdb { @@ -45,7 +46,7 @@ struct ARTIndexScanState : public IndexScanState { // ART //===--------------------------------------------------------------------===// -ART::ART(const string &name, const IndexConstraintType index_constraint_type, const vector &column_ids, +ART::ART(const Identifier &name, const IndexConstraintType index_constraint_type, const vector &column_ids, TableIOManager &table_io_manager, const vector> &unbound_expressions, AttachedDatabase &db, const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, @@ -117,12 +118,12 @@ ART::ART(const string &name, const IndexConstraintType index_constraint_type, co auto it = info.options.find("storage_version"); if (it != info.options.end()) { // If this is an existing index with a saved storage version, use it. - storage_version = it->second.GetValue(); + storage_version = static_cast(it->second.GetValue()); } else { // Otherwise, this must be an existing index without a saved storage version. // We started saving the storage version in v1.5.0, so if it is not present, // we can not make any general assumptions about the exact storage version. 
- storage_version = optional_idx::Invalid(); + storage_version = StorageVersion::INVALID; } } @@ -174,7 +175,7 @@ unique_ptr ART::TryInitializeScan(const Expression &expr, const // bindings[1] = the index expression // bindings[2] = the constant auto &comparison = bindings[0].get().Cast(); - auto constant_value = bindings[2].get().Cast().value; + auto constant_value = bindings[2].get().Cast().GetValue(); auto comparison_type = comparison.GetExpressionType(); auto &left = BoundComparisonExpression::Left(comparison); @@ -214,10 +215,10 @@ unique_ptr ART::TryInitializeScan(const Expression &expr, const auto lower_inclusive = BoundBetweenExpression::LowerInclusive(between); auto upper_inclusive = BoundBetweenExpression::UpperInclusive(between); - low_value = lower_bound.Cast().value; + low_value = lower_bound.Cast().GetValue(); low_comparison_type = lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO : ExpressionType::COMPARE_GREATERTHAN; - high_value = (upper_bound.Cast()).value; + high_value = (upper_bound.Cast()).GetValue(); high_comparison_type = upper_inclusive ? 
ExpressionType::COMPARE_LESSTHANOREQUALTO : ExpressionType::COMPARE_LESSTHAN; } @@ -254,7 +255,8 @@ unique_ptr ART::InitializeFullScan() { //===--------------------------------------------------------------------===// template -static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { +static void TemplatedGenerateKeys(ArenaAllocator &allocator, const Vector &input, unsafe_vector &keys) { + const idx_t count = input.size(); D_ASSERT(keys.size() >= count); UnifiedVectorFormat data; @@ -274,7 +276,8 @@ static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_ } template -static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { +static void ConcatenateKeys(ArenaAllocator &allocator, const Vector &input, unsafe_vector &keys) { + const idx_t count = input.size(); UnifiedVectorFormat data; input.ToUnifiedFormat(data); auto input_data = UnifiedVectorFormat::GetData(data); @@ -309,46 +312,46 @@ template void GenerateKeysInternal(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { switch (input.data[0].GetType().InternalType()) { case PhysicalType::BOOL: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::INT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::INT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::INT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::INT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case 
PhysicalType::INT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::UINT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::UINT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::UINT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::UINT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::UINT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::FLOAT: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::DOUBLE: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; case PhysicalType::VARCHAR: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); + TemplatedGenerateKeys(allocator, input.data[0], keys); break; default: throw InternalException("Invalid type for index"); @@ -358,46 +361,46 @@ void GenerateKeysInternal(ArenaAllocator &allocator, DataChunk &input, unsafe_ve for (idx_t i = 1; i < input.ColumnCount(); i++) { switch (input.data[i].GetType().InternalType()) { case PhysicalType::BOOL: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::INT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], 
keys); break; case PhysicalType::INT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::INT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::INT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::INT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::UINT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::UINT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::UINT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::UINT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::UINT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::FLOAT: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::DOUBLE: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; case PhysicalType::VARCHAR: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); + ConcatenateKeys(allocator, input.data[i], keys); break; default: throw InternalException("Invalid type for index"); @@ -415,10 +418,10 @@ void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, unsafe GenerateKeysInternal(allocator, 
input, keys); } -static bool KeyInputNeedConversion(const vector &types, optional_idx storage_version) { +static bool KeyInputNeedConversion(const vector &types, StorageVersion storage_version) { // We only started tracking the storage version of the index in v1.5.0. // Old GEOMETRY columns (pre v1.5.0) had a different internal representation. - if (!storage_version.IsValid() || (storage_version.GetIndex() < 7)) { + if (storage_version == StorageVersion::INVALID || (storage_version < StorageVersion::V1_5_0)) { for (auto &type : types) { // ART does not support nested types, so we only need to check the top-level type. if (type.id() == LogicalTypeId::GEOMETRY) { @@ -452,12 +455,10 @@ static void ConvertKeyInput(DataChunk &input, DataChunk &result) { result.data[i].Reference(input.data[i]); } } - - result.SetCardinality(input.size()); } -void ART::GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, - unsafe_vector &row_id_keys) { +void ART::GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, const Vector &row_ids, + unsafe_vector &keys, unsafe_vector &row_id_keys) { auto key_input = &input; DataChunk converted_chunk; @@ -473,7 +474,6 @@ void ART::GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector row_id_chunk.Initialize(Allocator::DefaultAllocator(), vector {LogicalType::ROW_TYPE}, key_input->size()); row_id_chunk.data[0].Reference(row_ids); - row_id_chunk.SetCardinality(key_input->size()); GenerateKeys<>(allocator, row_id_chunk, row_id_keys); } @@ -1018,8 +1018,8 @@ IndexStorageInfo ART::PrepareSerialize(const case_insensitive_map_t &opti info.options = options; // It never hurts to serialize the storage version, even to older formats - if (storage_version.IsValid()) { - info.options["storage_version"] = Value::UBIGINT(storage_version.GetIndex()); + if (storage_version != StorageVersion::INVALID) { + info.options["storage_version"] = Value::UBIGINT(static_cast(storage_version)); } for 
(auto &allocator : *allocators) { diff --git a/src/duckdb/src/execution/index/art/art_key.cpp b/src/duckdb/src/execution/index/art/art_key.cpp index 091b6e039..c1ec907d1 100644 --- a/src/duckdb/src/execution/index/art/art_key.cpp +++ b/src/duckdb/src/execution/index/art/art_key.cpp @@ -63,11 +63,12 @@ void ARTKey::CreateARTKey(ArenaAllocator &allocator, ARTKey &key, const char *va ARTKey::CreateARTKey(allocator, key, string_t(value, UnsafeNumericCast(strlen(value)))); } -ARTKey ARTKey::CreateKey(ArenaAllocator &allocator, Value &value, optional_idx storage_version) { +ARTKey ARTKey::CreateKey(ArenaAllocator &allocator, Value &value, StorageVersion storage_version) { const auto &type = value.type(); D_ASSERT(type.InternalType() == value.type().InternalType()); - if (type.id() == LogicalTypeId::GEOMETRY && (!storage_version.IsValid() || (storage_version.GetIndex() < 7))) { + if (type.id() == LogicalTypeId::GEOMETRY && + (storage_version == StorageVersion::INVALID || (storage_version < StorageVersion::V1_5_0))) { // Convert to old-style geometry for older storage versions. string buffer; Geometry::ToSpatialGeometry(value.GetValueUnsafe(), buffer); diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp index 4a9332fc9..fce383d01 100644 --- a/src/duckdb/src/execution/index/art/base_leaf.cpp +++ b/src/duckdb/src/execution/index/art/base_leaf.cpp @@ -86,13 +86,12 @@ void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byt // Get the remaining row ID. remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; remainder |= UnsafeNumericCast(n7.key[0]); - - // Free the prefix (nodes) and inline the remainder. - if (prefix.GetType() == NType::PREFIX) { - Node::FreeTree(art, prefix); - Leaf::New(prefix, UnsafeNumericCast(remainder)); - return; - } + } + // Free the prefix (nodes) and inline the remainder. 
+ if (prefix.GetType() == NType::PREFIX) { + Node::FreeTree(art, prefix); + Leaf::New(prefix, UnsafeNumericCast(remainder)); + return; } // Free the Node7Leaf and inline the remainder. Node::FreeNode(art, node); diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp index 11662fb7e..0adb8f58f 100644 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ b/src/duckdb/src/execution/index/bound_index.cpp @@ -15,7 +15,7 @@ namespace duckdb { // Bound index //------------------------------------------------------------------------------- -BoundIndex::BoundIndex(const string &name, const string &index_type, IndexConstraintType index_constraint_type, +BoundIndex::BoundIndex(const Identifier &name, const string &index_type, IndexConstraintType index_constraint_type, const vector &column_ids, TableIOManager &table_io_manager, const vector> &unbound_expressions_p, AttachedDatabase &db) : Index(column_ids, table_io_manager, db), name(name), index_type(index_type), @@ -148,7 +148,7 @@ unique_ptr BoundIndex::BindExpression(unique_ptr root_ex ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &bound_colref, unique_ptr &expr) { expr = make_uniq(expr->GetReturnType(), - column_ids[bound_colref.binding.column_index]); + column_ids[bound_colref.Binding().column_index]); }); return root_expr; } @@ -249,7 +249,6 @@ void BoundIndex::ApplyBufferedReplays(const vector &table_types, Bu table_chunk.data[col_id].Reference(state.current_chunk.data[col_idx]); table_chunk.data[col_id].Slice(sel, rows_to_process); } - table_chunk.SetCardinality(rows_to_process); Vector row_ids(state.current_chunk.data.back(), sel, rows_to_process); if (replay_range.type == BufferedIndexReplay::INSERT_ENTRY) { diff --git a/src/duckdb/src/execution/index/index_type_set.cpp b/src/duckdb/src/execution/index/index_type_set.cpp index 6e265a19e..8219093cf 100644 --- a/src/duckdb/src/execution/index/index_type_set.cpp +++ 
b/src/duckdb/src/execution/index/index_type_set.cpp @@ -11,7 +11,7 @@ IndexTypeSet::IndexTypeSet() { optional_ptr IndexTypeSet::FindByName(const string &name) { lock_guard g(lock); - auto entry = functions.find(name); + auto entry = functions.find(Identifier(name)); if (entry == functions.end()) { return nullptr; } @@ -20,10 +20,10 @@ optional_ptr IndexTypeSet::FindByName(const string &name) { void IndexTypeSet::RegisterIndexType(const IndexType &index_type) { lock_guard g(lock); - if (functions.find(index_type.name) != functions.end()) { + if (functions.find(Identifier(index_type.name)) != functions.end()) { throw CatalogException("Index type with name \"%s\" already exists!", index_type.name.c_str()); } - functions[index_type.name] = index_type; + functions[Identifier(index_type.name)] = index_type; } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/unbound_index.cpp b/src/duckdb/src/execution/index/unbound_index.cpp index 1ed582f6b..9eefdab13 100644 --- a/src/duckdb/src/execution/index/unbound_index.cpp +++ b/src/duckdb/src/execution/index/unbound_index.cpp @@ -56,7 +56,6 @@ void UnboundIndex::BufferChunk(DataChunk &index_column_chunk, Vector &row_ids, combined_chunk.data[i].Reference(index_column_chunk.data[i]); } combined_chunk.data.back().Reference(row_ids); - combined_chunk.SetCardinality(index_column_chunk.size()); auto &buffer = buffered_replays.GetBuffer(replay_type); if (buffer == nullptr) { diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index 978f11a08..0ee6540f9 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -1,6 +1,7 @@ #include "duckdb/execution/join_hashtable.hpp" #include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ 
-10,6 +11,7 @@ #include "duckdb/execution/operator/join/physical_hash_join.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/storage/buffer_manager.hpp" namespace duckdb { @@ -28,6 +30,11 @@ JoinHashTable::ProbeState::ProbeState() non_empty_sel(STANDARD_VECTOR_SIZE) { } +JoinHashTable::ProbeDictionaryState::ProbeDictionaryState() + : hashes(LogicalType::HASH), new_dictionary_pointers(LogicalType::POINTER), unique_entries(STANDARD_VECTOR_SIZE), + match_sel(STANDARD_VECTOR_SIZE) { +} + JoinHashTable::InsertState::InsertState(const JoinHashTable &ht) : SharedState(), remaining_sel(STANDARD_VECTOR_SIZE), key_match_sel(STANDARD_VECTOR_SIZE), rhs_row_locations(LogicalType::POINTER) { @@ -149,10 +156,15 @@ void JoinHashTable::Merge(JoinHashTable &other) { } } + if (bloom_filter.IsInitialized() && other.bloom_filter.IsInitialized()) { + bloom_filter.Merge(other.bloom_filter); + } + sink_collection->Combine(*other.sink_collection); } -static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, Vector &salt_v, const idx_t &count, const idx_t &bitmask) { +static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, Vector &salt_v, const idx_t &bitmask) { + const idx_t count = hashes_v.size(); if (hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) { auto &hash = *ConstantVector::GetData(hashes_v); salt_v.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -164,17 +176,17 @@ static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, Vector &salt_v, const hashes_v.Flatten(); } else { hashes_v.Flatten(); - auto salts = FlatVector::GetDataMutable(salt_v); + auto salt_data = FlatVector::Writer(salt_v, count_t(count)); auto hashes = FlatVector::GetDataMutable(hashes_v); for (idx_t i = 0; i < count; i++) { - salts[i] = ht_entry_t::ExtractSalt(hashes[i]); + salt_data.WriteValue(ht_entry_t::ExtractSalt(hashes[i])); hashes[i] &= bitmask; } } } template -idx_t 
GetOptionalIndex(const SelectionVector *sel, const idx_t idx) { +idx_t GetOptionalIndex(optional_ptr sel, const idx_t idx) { return HAS_SEL ? sel->get_index(idx) : idx; } @@ -194,8 +206,9 @@ static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry } template -static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, - Vector &pointers_result_v, const SelectionVector *row_sel, idx_t &count) { +static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHashTable &ht, + unsafe_optional_ptr entries, Vector &pointers_result_v, + optional_ptr row_sel, idx_t &count) { auto hashes_dense = FlatVector::GetDataMutable(state.hashes_dense_v); idx_t keys_to_compare_count = 0; @@ -207,7 +220,7 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash if (USE_SALTS) { // increment the ht_offset of the entry as long as the next entry is occupied and salt does not match while (true) { - const ht_entry_t entry = entries[row_ht_offset]; + const ht_entry_t entry = entries.get()[row_ht_offset]; const bool occupied = entry.IsOccupied(); // the entry is empty -> no match possible @@ -229,7 +242,7 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash IncrementAndWrap(row_ht_offset, ht.bitmask); } } else { - const ht_entry_t entry = entries[row_ht_offset]; + const ht_entry_t entry = entries.get()[row_ht_offset]; const bool occupied = entry.IsOccupied(); if (occupied) { // the entry is occupied -> compare the keys @@ -248,9 +261,9 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash /// b) an entry is found where (and the salt matches if USE_SALTS is true) /// -> match, add to compare sel and increase found count template -static idx_t ProbeForPointers(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, - Vector &pointers_result_v, const SelectionVector *row_sel, idx_t count, - const bool 
has_row_sel) { +static idx_t ProbeForPointers(JoinHashTable::ProbeState &state, JoinHashTable &ht, + unsafe_optional_ptr entries, Vector &pointers_result_v, + optional_ptr row_sel, idx_t count, const bool has_row_sel) { if (has_row_sel) { return ProbeForPointersInternal(state, ht, entries, pointers_result_v, row_sel, count); } else { @@ -262,9 +275,9 @@ static idx_t ProbeForPointers(JoinHashTable::ProbeState &state, JoinHashTable &h //! vector and the count argument to the number and position of the matches template static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_state, JoinHashTable::ProbeState &state, - Vector &hashes_v, const SelectionVector *row_sel, idx_t &count, JoinHashTable &ht, - ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel, - bool has_row_sel) { + Vector &hashes_v, optional_ptr row_sel, idx_t &count, + JoinHashTable &ht, unsafe_optional_ptr entries, + Vector &pointers_result_v, SelectionVector &match_sel, bool has_row_sel) { // densify hashes: If there is no sel, flatten the hashes, else densify via UnifiedVectorFormat if (has_row_sel) { auto hashes_unified = hashes_v.Values(); @@ -320,7 +333,7 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta } // in the next iteration, we have a selection vector with the keys that do not match - row_sel = &state.keys_no_match_sel; + row_sel = state.keys_no_match_sel; has_row_sel = true; elements_to_probe_count = keys_no_match_count; @@ -337,8 +350,9 @@ inline bool JoinHashTable::UseSalt() const { } void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, - const SelectionVector *sel, idx_t &count, Vector &pointers_result_v, + optional_ptr sel, idx_t &count, Vector &pointers_result_v, SelectionVector &match_sel, const bool has_sel) { + auto entries = GetEntries(); if (UseSalt()) { GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, *this, entries, 
pointers_result_v, match_sel, has_sel); @@ -351,9 +365,9 @@ void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_sta void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { if (count == keys.size()) { // no null values are filtered: use regular hash functions - VectorOperations::Hash(keys.data[0], hashes, keys.size()); + VectorOperations::Hash(keys.data[0], hashes, count); for (idx_t i = 1; i < equality_types.size(); i++) { - VectorOperations::CombineHash(hashes, keys.data[i], keys.size()); + VectorOperations::CombineHash(hashes, keys.data[i], count); } } else { // null values were filtered: use selection vector @@ -394,14 +408,14 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu for (idx_t i = 0; i < info.correlated_types.size(); i++) { info.group_chunk.data[i].Reference(keys.data[i]); } - info.group_chunk.SetCardinality(keys); + info.group_chunk.SetChildCardinality(keys.size()); if (info.correlated_payload.data.empty()) { vector types; types.push_back(keys.data[info.correlated_types.size()].GetType()); info.correlated_payload.InitializeEmpty(types); } info.correlated_payload.data[0].Reference(keys.data[info.correlated_types.size()]); - info.correlated_payload.SetCardinality(keys); + info.correlated_payload.SetChildCardinality(keys.size()); info.correlated_counts->AddChunk(info.group_chunk, info.correlated_payload, AggregateType::NON_DISTINCT); } @@ -424,13 +438,13 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu } Vector hash_values(LogicalType::HASH); source_chunk.data[col_offset].Reference(hash_values); - source_chunk.SetCardinality(keys); + source_chunk.SetChildCardinality(keys.size()); // ToUnifiedFormat the source chunk TupleDataCollection::ToUnifiedFormat(append_state.chunk_state, source_chunk); // prepare the keys for processing - const SelectionVector *current_sel; + optional_ptr current_sel; SelectionVector 
sel(STANDARD_VECTOR_SIZE); idx_t added_count = PrepareKeys(keys, append_state.chunk_state.vector_data, current_sel, sel, true); if (added_count < keys.size()) { @@ -443,6 +457,9 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu // hash the keys and obtain an entry in the list // note that we only hash the keys used in the equality comparison Hash(keys, *current_sel, added_count, hash_values); + if (bloom_filter.IsInitialized()) { + bloom_filter.InsertHashes(hash_values); + } // Re-reference and ToUnifiedFormat the hash column after computing it source_chunk.data[col_offset].Reference(hash_values); @@ -453,7 +470,8 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu } idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector &vector_data, - const SelectionVector *¤t_sel, SelectionVector &sel, bool build_side) { + optional_ptr ¤t_sel, SelectionVector &sel, + bool build_side) { // figure out which keys are NULL, and create a selection vector out of them current_sel = FlatVector::IncrementalSelectionVector(); idx_t added_count = keys.size(); @@ -476,7 +494,7 @@ idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector } added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel); // null values are NOT equal for this column, filter them out - current_sel = &sel; + current_sel = sel; } return added_count; } @@ -533,7 +551,7 @@ static inline void PerformKeyComparison(JoinHashTable::InsertState &state, JoinH const idx_t count, idx_t &key_match_count, idx_t &key_no_match_count) { // Get the data for the rows that need to be compared state.lhs_data.Reset(); - state.lhs_data.SetCardinality(count); // the right size + state.lhs_data.SetChildCardinality(count); // the right size // The target selection vector says where to write the results into the lhs_data, we just want to write // sequentially as otherwise we trigger a bug in the Gather function @@ -556,11 +574,15 @@ static inline void 
PerformKeyComparison(JoinHashTable::InsertState &state, JoinH } template -static inline void InsertMatchesAndIncrementMisses(atomic entries[], JoinHashTable::InsertState &state, - JoinHashTable &ht, const data_ptr_t lhs_row_locations[], - idx_t ht_offsets[], const hash_t hash_salts[], - const idx_t capacity_mask, const idx_t key_match_count, - const idx_t key_no_match_count) { +static inline void InsertMatchesAndIncrementMisses(unsafe_optional_ptr> entries, + JoinHashTable::InsertState &state, JoinHashTable &ht, + const data_ptr_t lhs_row_locations[], idx_t ht_offsets[], + const hash_t hash_salts[], const idx_t capacity_mask, + const idx_t key_match_count, const idx_t key_no_match_count) { + if (key_match_count != 0) { + ht.chains_longer_than_one = true; + } + // Insert the rows that match if (ht.insert_duplicate_keys) { if (key_match_count != 0) { @@ -571,7 +593,7 @@ static inline void InsertMatchesAndIncrementMisses(atomic entries[], const auto entry_index = state.keys_to_compare_sel.get_index(need_compare_idx); const auto &ht_offset = ht_offsets[entry_index]; - auto &entry = entries[ht_offset]; + auto &entry = entries.get()[ht_offset]; const auto row_ptr_to_insert = lhs_row_locations[entry_index]; const auto salt = hash_salts[entry_index]; @@ -592,11 +614,12 @@ static inline void InsertMatchesAndIncrementMisses(atomic entries[], } template -static void InsertHashesLoop(atomic entries[], Vector &row_locations, Vector &hashes_v, const idx_t &count, +static void InsertHashesLoop(unsafe_optional_ptr> entries, Vector &row_locations, Vector &hashes_v, JoinHashTable::InsertState &state, const TupleDataCollection &data_collection, JoinHashTable &ht) { D_ASSERT(hashes_v.GetType().id() == LogicalType::HASH); - ApplyBitmaskAndGetSaltBuild(hashes_v, state.salt_v, count, ht.bitmask); + ApplyBitmaskAndGetSaltBuild(hashes_v, state.salt_v, ht.bitmask); + const idx_t count = hashes_v.size(); const auto &layout = data_collection.GetLayout(); @@ -610,7 +633,7 @@ static void 
InsertHashesLoop(atomic entries[], Vector &row_locations // we start off with the entire chunk idx_t remaining_count = count; - const auto *remaining_sel = FlatVector::IncrementalSelectionVector(); + optional_ptr remaining_sel = FlatVector::IncrementalSelectionVector(); const auto all_valid = layout.CannotHaveNull(); const auto column_count = layout.ColumnCount(); @@ -641,7 +664,7 @@ static void InsertHashesLoop(atomic entries[], Vector &row_locations } } remaining_count = new_remaining_count; - remaining_sel = &state.remaining_sel; + remaining_sel = state.remaining_sel; } } @@ -660,7 +683,7 @@ static void InsertHashesLoop(atomic entries[], Vector &row_locations ht_entry_t entry; bool occupied; while (true) { - atomic &atomic_entry = entries[ht_offset]; + atomic &atomic_entry = entries.get()[ht_offset]; entry = atomic_entry.load(std::memory_order_relaxed); occupied = entry.IsOccupied(); @@ -676,7 +699,7 @@ static void InsertHashesLoop(atomic entries[], Vector &row_locations } if (!occupied) { // insert into free - auto &atomic_entry = entries[ht_offset]; + auto &atomic_entry = entries.get()[ht_offset]; const auto row_ptr_to_insert = lhs_row_locations[row_index]; const auto potential_collided_ptr = InsertRowToEntry(atomic_entry, row_ptr_to_insert, salt, ht.pointer_offset); @@ -713,23 +736,23 @@ static void InsertHashesLoop(atomic entries[], Vector &row_locations // update the overall selection vector to only point the entries that still need to be inserted // as there was no match found for them yet - remaining_sel = &state.remaining_sel; + remaining_sel = state.remaining_sel; remaining_count = key_no_match_count; } } -void JoinHashTable::InsertHashes(Vector &hashes_v, const idx_t count, TupleDataChunkState &chunk_state, - InsertState &insert_state, bool parallel) { +void JoinHashTable::InsertHashes(Vector &hashes_v, TupleDataChunkState &chunk_state, InsertState &insert_state, + bool parallel) { // Insert Hashes into the BF if (bloom_filter.IsInitialized()) { - 
bloom_filter.InsertHashes(hashes_v, count); + bloom_filter.InsertHashes(hashes_v); } - auto atomic_entries = reinterpret_cast *>(this->entries); + auto atomic_entries = GetAtomicEntries(); auto &row_locations = chunk_state.row_locations; if (parallel) { - InsertHashesLoop(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this); + InsertHashesLoop(atomic_entries, row_locations, hashes_v, insert_state, *data_collection, *this); } else { - InsertHashesLoop(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this); + InsertHashesLoop(atomic_entries, row_locations, hashes_v, insert_state, *data_collection, *this); } } @@ -745,7 +768,7 @@ void JoinHashTable::InsertPrefixRangeChunk(TupleDataChunkState &chunk_state, idx auto &sel = *FlatVector::IncrementalSelectionVector(); data_collection->Gather(chunk_state.row_locations, sel, count, 0, build_keys, sel, nullptr); FlatVector::SetSize(build_keys, count_t(count)); - prefix_range_filter->InsertKeys(build_keys, count, state); + prefix_range_filter->InsertKeys(build_keys, state); } void JoinHashTable::MergePrefixRangeBuildState(PrefixRangeFilter::BuildState &state) { @@ -762,8 +785,9 @@ void JoinHashTable::AllocatePointerTable() { throw InternalException("Hashtable capacity exceeds 48-bit limit (2^48 - 1)"); } - if (should_build_bloom_filter) { - bloom_filter.Initialize(context, Count()); + if (should_build_bloom_filter && !bloom_filter.IsInitialized()) { + bloom_filter_init_count = MaxValue(Count(), 1); + bloom_filter.Initialize(context, bloom_filter_init_count); } if (hash_map.get()) { @@ -772,7 +796,6 @@ void JoinHashTable::AllocatePointerTable() { if (capacity > current_capacity) { // Need more space hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); } else { // Just use the current hash map capacity = current_capacity; @@ -780,7 +803,6 @@ void JoinHashTable::AllocatePointerTable() 
{ } else { // Allocate a hash map hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); } D_ASSERT(hash_map.GetSize() == capacity * sizeof(ht_entry_t)); @@ -791,9 +813,40 @@ void JoinHashTable::AllocatePointerTable() { {"size", to_string(data_collection->SizeInBytes() + hash_map.GetSize())}}); } +void JoinHashTable::PrepareBuildBloomFilter(idx_t estimated_row_count) { + should_build_bloom_filter = true; + if (!bloom_filter.IsInitialized()) { + bloom_filter_init_count = MaxValue(estimated_row_count, idx_t(1)); + bloom_filter.Initialize(context, bloom_filter_init_count); + } +} + +void JoinHashTable::PrepareBloomFilterForFinalize() { + if (!should_build_bloom_filter) { + return; + } + + // Finalize scans every build tuple and inserts its hash into the bloom filter. + // Resize here if the planner estimate was too low, then let finalize populate it. + const auto build_count = Count(); + const auto actual_init_count = MaxValue(build_count, 1); + static constexpr double REBUILD_UNDERESTIMATE_THRESHOLD = 2.0; + const auto estimated_too_low = bloom_filter_init_count == 0 || + static_cast(build_count) > + static_cast(bloom_filter_init_count) * REBUILD_UNDERESTIMATE_THRESHOLD; + if (bloom_filter.IsInitialized() && !estimated_too_low) { + return; + } + + bloom_filter.Reset(); + bloom_filter_init_count = actual_init_count; + bloom_filter.Initialize(context, bloom_filter_init_count); +} + void JoinHashTable::InitializePointerTable(idx_t entry_idx_from, idx_t entry_idx_to) { // initialize HT with all-zero entries - std::fill_n(entries + entry_idx_from, entry_idx_to - entry_idx_from, ht_entry_t()); + auto entries = GetEntries(); + std::fill_n(entries.get() + entry_idx_from, entry_idx_to - entry_idx_from, ht_entry_t()); } void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel, @@ -810,13 +863,13 @@ void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, 
bool para InsertState insert_state(*this); do { const auto count = iterator.GetCurrentChunkCount(); - auto hash_data = FlatVector::GetDataMutable(hashes); + auto hash_data = FlatVector::Writer(hashes, count_t(count)); for (idx_t i = 0; i < count; i++) { - hash_data[i] = Load(row_locations[i] + pointer_offset); + hash_data.WriteValue(Load(row_locations[i] + pointer_offset)); } TupleDataChunkState &chunk_state = iterator.GetChunkState(); - InsertHashes(hashes, count, chunk_state, insert_state, parallel); + InsertHashes(hashes, chunk_state, insert_state, parallel); if (prefix_range_state) { InsertPrefixRangeChunk(chunk_state, count, *prefix_range_state); } @@ -824,7 +877,8 @@ void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool para } void JoinHashTable::InitializeScanStructure(ScanStructure &scan_structure, DataChunk &keys, - TupleDataChunkState &key_state, const SelectionVector *¤t_sel) { + TupleDataChunkState &key_state, + optional_ptr ¤t_sel) { D_ASSERT(Count() > 0); // should be handled before D_ASSERT(finalized); @@ -846,9 +900,34 @@ void JoinHashTable::InitializeScanStructure(ScanStructure &scan_structure, DataC } } +//! Per-chunk fast-reject so non-dictionary chunks skip the TryProbeDictionary call entirely +static bool LHSChunkIsDictionaryEligible(const Vector &lhs_key) { + if (lhs_key.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return false; + } + return DictionaryVector::DictionarySize(lhs_key).IsValid(); +} + +//! Per-chunk fast-reject so non-constant chunks skip the TryProbeConstant call entirely +static bool LHSChunkIsConstant(const Vector &lhs_key) { + return lhs_key.GetVectorType() == VectorType::CONSTANT_VECTOR; +} + void JoinHashTable::Probe(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, ProbeState &probe_state, optional_ptr precomputed_hashes) { - const SelectionVector *current_sel; + // dispatch into the compressed-vector fast paths when the operator gated them on. 
precomputed_hashes + // is incompatible: the caller hashed the full N-row chunk; the fast paths re-hash a sliced D-row chunk. + if (probe_state.dict_state && !precomputed_hashes) { + auto &lhs_key = keys.data[0]; + if (LHSChunkIsConstant(lhs_key) && TryProbeConstant(scan_structure, keys, key_state, probe_state)) { + return; + } + if (LHSChunkIsDictionaryEligible(lhs_key) && TryProbeDictionary(scan_structure, keys, key_state, probe_state)) { + return; + } + } + + optional_ptr current_sel; InitializeScanStructure(scan_structure, keys, key_state, current_sel); if (scan_structure.count == 0) { return; @@ -867,6 +946,216 @@ void JoinHashTable::Probe(ScanStructure &scan_structure, DataChunk &keys, TupleD } } +bool JoinHashTable::TryProbeDictionary(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state) { + static constexpr idx_t MAX_DICTIONARY_SIZE_THRESHOLD = 20000; + static constexpr idx_t DICTIONARY_THRESHOLD = 2; + + D_ASSERT(equality_types.size() == 1); + D_ASSERT(Count() > 0); + D_ASSERT(finalized); + + auto &dict_col = keys.data[0]; + if (dict_col.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return false; + } + const auto opt_dict_size = DictionaryVector::DictionarySize(dict_col); + if (!opt_dict_size.IsValid()) { + // dict size not known - this is not a dictionary that comes from the storage + return false; + } + const idx_t dict_size = opt_dict_size.GetIndex(); + const auto &dictionary_id = DictionaryVector::DictionaryId(dict_col); + if (dictionary_id.empty()) { + // no id - can't cache across vectors, only use dict path if it is much smaller than the chunk + if (dict_size * DICTIONARY_THRESHOLD >= keys.size()) { + return false; + } + } else if (dict_size >= MAX_DICTIONARY_SIZE_THRESHOLD) { + return false; + } + if (dict_size == 0) { + // the cache below is never allocated for an empty dictionary - avoid dereferencing a null Vector + return false; + } + + auto &dictionary_vector = 
DictionaryVector::Child(dict_col); + auto &offsets = DictionaryVector::SelVector(dict_col); + D_ASSERT(probe_state.dict_state); + auto &dict_state = *probe_state.dict_state; + + if (dict_state.dictionary_id.empty() || dict_state.dictionary_id != dictionary_id) { + // new dictionary - initialize the per-id cache + if (dict_size > dict_state.capacity) { + dict_state.dictionary_pointers = make_uniq(LogicalType::POINTER, dict_size); + dict_state.found_entry = make_unsafe_uniq_array(dict_size); + dict_state.capacity = dict_size; + } + memset(dict_state.found_entry.get(), 0, dict_size * sizeof(bool)); + // zero the cache - an unresolved slot reads back as nullptr, the miss sentinel + memset(FlatVector::GetDataMutable(*dict_state.dictionary_pointers), 0, dict_size * sizeof(data_ptr_t)); + dict_state.dictionary_id = dictionary_id; + dict_state.resolved_count = 0; + // storage dictionaries reserve a NULL sentinel slot that non-NULL probe rows never + // reference; pre-mark NULL child slots resolved so resolved_count can still reach dict_size. + // Equality joins only: PrepareKeys drops NULL probe rows, so the NULL slot is never probed. + // Under NOT DISTINCT FROM those rows are kept and the walk below must probe it itself. 
+ if (!null_values_are_equal[0] && dictionary_vector.GetVectorType() == VectorType::FLAT_VECTOR) { + const auto &child_validity = FlatVector::Validity(dictionary_vector); + if (child_validity.CanHaveNull()) { + for (idx_t i = 0; i < dict_size; i++) { + if (!child_validity.RowIsValid(i)) { + dict_state.found_entry[i] = true; + dict_state.resolved_count++; + } + } + } + } + } else if (dict_size > dict_state.capacity) { + throw InternalException("JoinHashTable - using cached dictionary data but dictionary has changed (dictionary " + "id %s - dict size %d, current capacity %d)", + dict_state.dictionary_id, dict_size, dict_state.capacity); + } + + // collect each dict slot once - skip the walk entirely once every slot is resolved + auto &found_entry = dict_state.found_entry; + auto &unique_entries = dict_state.unique_entries; + idx_t unique_count = 0; + if (dict_state.resolved_count < dict_size) { + for (idx_t i = 0; i < keys.size(); i++) { + const auto dict_idx = offsets.get_index(i); + unique_entries.set_index(unique_count, dict_idx); + unique_count += !found_entry[dict_idx]; + found_entry[dict_idx] = true; + } + dict_state.resolved_count += unique_count; + } + + if (unique_count > 0) { + auto &unique_values = dict_state.unique_values; + if (unique_values.ColumnCount() == 0) { + unique_values.InitializeEmpty(vector {equality_types[0]}); + TupleDataCollection::InitializeChunkState(dict_state.unique_key_state, {equality_types[0]}); + } + unique_values.data[0].Slice(dictionary_vector, unique_entries, unique_count); + unique_values.CheckCardinality(unique_count); + + TupleDataCollection::ToUnifiedFormat(dict_state.unique_key_state, unique_values); + + auto &hashes = dict_state.hashes; + VectorOperations::Hash(unique_values.data[0], hashes, unique_count); + + idx_t count = unique_count; + GetRowPointers(unique_values, dict_state.unique_key_state, probe_state, hashes, nullptr, count, + dict_state.new_dictionary_pointers, dict_state.match_sel, false); + + // remap hits from 
position to dict slot; missed slots stay nullptr from the rebind memset + const auto new_dict_ptrs = FlatVector::GetData(dict_state.new_dictionary_pointers); + auto unique_dict_ptrs = FlatVector::GetDataMutable(*dict_state.dictionary_pointers); + for (idx_t i = 0; i < count; i++) { + const auto position = dict_state.match_sel.get_index(i); + const auto dict_idx = unique_entries.get_index(position); + unique_dict_ptrs[dict_idx] = new_dict_ptrs[position]; + } + } + + // fan out the cache into the scan_structure shape that Probe() produces + scan_structure.is_null = false; + scan_structure.finished = false; + if (join_type != JoinType::INNER) { + memset(scan_structure.found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); + } + TupleDataCollection::ToUnifiedFormat(key_state, keys); + optional_ptr current_sel; + const idx_t prepared_count = + PrepareKeys(keys, key_state.vector_data, current_sel, scan_structure.sel_vector, false); + scan_structure.has_null_value_filter = prepared_count < keys.size(); + + auto scan_structure_ptrs = FlatVector::GetDataMutable(scan_structure.pointers); + const auto unique_dict_ptrs = FlatVector::GetData(*dict_state.dictionary_pointers); + idx_t kept = 0; + for (idx_t i = 0; i < prepared_count; i++) { + const auto row_index = current_sel->get_index(i); + const auto dict_idx = offsets.get_index(row_index); + const auto ptr = unique_dict_ptrs[dict_idx]; + scan_structure_ptrs[row_index] = ptr; + scan_structure.sel_vector.set_index(kept, row_index); + // miss rows leave a stale pointer, but Next() only reads pointers at positions in sel_vector[0..count) + kept += (ptr != nullptr); + } + scan_structure.count = kept; + return true; +} + +bool JoinHashTable::TryProbeConstant(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state) { + D_ASSERT(equality_types.size() == 1); + D_ASSERT(Count() > 0); + D_ASSERT(finalized); + + auto &constant_col = keys.data[0]; + if (constant_col.GetVectorType() 
!= VectorType::CONSTANT_VECTOR) { + return false; + } +#ifndef DEBUG + if (keys.size() <= 1) { + // resolving once only pays off when the result fans out to multiple probe rows + return false; + } +#endif + + // constant vectors carry no dictionary id, so there is no cross-chunk cache - resolve the + // single distinct key with one hash-table lookup per chunk + D_ASSERT(probe_state.dict_state); + auto &dict_state = *probe_state.dict_state; + auto &unique_values = dict_state.unique_values; + if (unique_values.ColumnCount() == 0) { + unique_values.InitializeEmpty(vector {equality_types[0]}); + TupleDataCollection::InitializeChunkState(dict_state.unique_key_state, {equality_types[0]}); + } + unique_values.data[0].Reference(constant_col); + unique_values.SetChildCardinality(1); + unique_values.Flatten(); + + TupleDataCollection::ToUnifiedFormat(dict_state.unique_key_state, unique_values); + + auto &hashes = dict_state.hashes; + VectorOperations::Hash(unique_values.data[0], hashes, 1); + + // resolved_ptr is a chain-head, not a confirmed match - a real key-miss is still rejected by + // the chain walk in Next(). A NULL constant key resolves correctly via RowMatcher null equality. + idx_t count = 1; + GetRowPointers(unique_values, dict_state.unique_key_state, probe_state, hashes, nullptr, count, + dict_state.new_dictionary_pointers, dict_state.match_sel, false); + const auto resolved_ptr = + count == 0 ? 
nullptr : FlatVector::GetData(dict_state.new_dictionary_pointers)[0]; + + // fan out the resolved pointer into the scan_structure shape that Probe() produces + scan_structure.is_null = false; + scan_structure.finished = false; + if (join_type != JoinType::INNER) { + memset(scan_structure.found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); + } + TupleDataCollection::ToUnifiedFormat(key_state, keys); + optional_ptr current_sel; + const idx_t prepared_count = + PrepareKeys(keys, key_state.vector_data, current_sel, scan_structure.sel_vector, false); + scan_structure.has_null_value_filter = prepared_count < keys.size(); + + auto scan_structure_ptrs = FlatVector::GetDataMutable(scan_structure.pointers); + idx_t kept = 0; + if (resolved_ptr) { + // PrepareKeys dropped NULL probe rows - the rest all share the one resolved chain-head + for (idx_t i = 0; i < prepared_count; i++) { + const auto row_index = current_sel->get_index(i); + scan_structure_ptrs[row_index] = resolved_ptr; + scan_structure.sel_vector.set_index(kept++, row_index); + } + } + scan_structure.count = kept; + return true; +} + ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state_p) : key_state(key_state_p), pointers(LogicalType::POINTER), count(0), sel_vector(STANDARD_VECTOR_SIZE), chain_match_sel_vector(STANDARD_VECTOR_SIZE), chain_no_match_sel_vector(STANDARD_VECTOR_SIZE), @@ -913,6 +1202,14 @@ ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state } } +void ScanStructure::Reset() { + count = 0; + finished = false; + is_null = true; + has_null_value_filter = false; + last_match_count = 0; +} + void ScanStructure::Next(DataChunk &keys, DataChunk &probe_data, DataChunk &result) { D_ASSERT(keys.size() == probe_data.size()); @@ -964,7 +1261,7 @@ bool ScanStructure::PointersExhausted() const { } idx_t ScanStructure::ResolvePredicates(DataChunk &keys, DataChunk &probe_data, SelectionVector &match_sel, - SelectionVector *no_match_sel) { + optional_ptr 
no_match_sel) { // Initialize the found_match array to the current sel_vector for (idx_t i = 0; i < this->count; ++i) { match_sel.set_index(i, this->sel_vector.get_index(i)); @@ -979,8 +1276,8 @@ idx_t ScanStructure::ResolvePredicates(DataChunk &keys, DataChunk &probe_data, S // we need to only use the vectors with the indices of the columns that are used in the probe phase, namely // the non-equality columns - result_count = - matcher->Match(keys, key_state.vector_data, match_sel, this->count, pointers, no_match_sel, no_match_count); + result_count = matcher->Match(keys, key_state.vector_data, match_sel, this->count, pointers, no_match_sel.get(), + no_match_count); } else { // no match sel is the opposite of match sel result_count = this->count; @@ -999,13 +1296,13 @@ idx_t ScanStructure::ResolvePredicates(DataChunk &keys, DataChunk &probe_data, S } idx_t ScanStructure::ApplyResidualPredicate(DataChunk &probe_data, SelectionVector &match_sel, idx_t match_count, - SelectionVector *no_match_sel, idx_t no_match_offset) { + optional_ptr no_match_sel, idx_t no_match_offset) { D_ASSERT(residual_state); D_ASSERT(residual_executor); // reset chunks for reuse (no reallocation!) 
residual_state->eval_chunk.Reset(); - residual_state->eval_chunk.SetCardinality(match_count); + residual_state->eval_chunk.SetChildCardinality(match_count); // copy probe columns at their ORIGINAL positions for (const auto &entry : ht.residual_info->probe_input_to_probe_map) { @@ -1189,8 +1486,6 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &probe_data, DataCh idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Slice(probe_data.data[probe_col_idx], chain_match_sel_vector, result_count); } - result.SetCardinality(result_count); - // on the RHS, we need to fetch the data from the hash table ht.GatherRHS(pointers, chain_match_sel_vector, result_count, result, ht.lhs_output_in_probe.size()); @@ -1212,8 +1507,6 @@ void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &probe_data, DataCh idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Slice(probe_data.data[probe_col_idx], lhs_sel_vector, base_count); } - result.SetCardinality(base_count); - // 2) gather RHS vectors ht.GatherRHS(rhs_pointers, *FlatVector::IncrementalSelectionVector(), base_count, result, ht.lhs_output_in_probe.size()); @@ -1262,7 +1555,6 @@ void ScanStructure::NextSemiOrAntiJoin(DataChunk &keys, DataChunk &probe_data, D idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Slice(probe_data.data[probe_col_idx], sel, result_count); } - result.SetCardinality(result_count); } else { D_ASSERT(result.size() == 0); } @@ -1346,7 +1638,6 @@ void ScanStructure::NextRightSemiOrAntiJoin(DataChunk &keys, DataChunk &probe_da void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &probe_data, DataChunk &result) { // extract OUTPUT columns from probe_data - result.SetCardinality(probe_data.size()); for (idx_t i = 0; i < ht.lhs_output_in_probe.size(); i++) { idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Reference(probe_data.data[probe_col_idx]); @@ -1356,6 +1647,10 @@ void ScanStructure::ConstructMarkJoinResult(DataChunk 
&join_keys, DataChunk &pro mark_vector.SetVectorType(VectorType::FLAT_VECTOR); FlatVector::SetSize(mark_vector, count_t(probe_data.size())); + // set the cardinality AFTER referencing the probe columns: a referenced probe column may be a constant vector + // (e.g. a constant join key) whose size must be aligned to the chunk cardinality + result.SetChildCardinality(probe_data.size()); + // first we set the NULL values from the join keys // if there is any NULL in the keys, the result is NULL auto bool_result = FlatVector::GetDataMutable(mark_vector); @@ -1412,11 +1707,11 @@ void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &probe_data, DataChu for (idx_t i = 0; i < info.group_chunk.ColumnCount(); i++) { info.group_chunk.data[i].Reference(keys.data[i]); } - info.group_chunk.SetCardinality(keys); + info.group_chunk.SetChildCardinality(keys.size()); info.correlated_counts->FetchAggregates(info.group_chunk, info.result_chunk); // extract OUTPUT columns from probe_data - result.SetCardinality(probe_data.size()); + result.SetChildCardinality(probe_data.size()); for (idx_t i = 0; i < ht.lhs_output_in_probe.size(); i++) { idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Reference(probe_data.data[probe_col_idx]); @@ -1496,8 +1791,6 @@ void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &probe_data, DataChu idx_t probe_col_idx = ht.lhs_output_in_probe[i]; result.data[i].Slice(probe_data.data[probe_col_idx], sel, remaining_count); } - result.SetCardinality(remaining_count); - // now set the right side to NULL for (idx_t i = ht.lhs_output_in_probe.size(); i < result.ColumnCount(); i++) { Vector &vec = result.data[i]; @@ -1549,7 +1842,7 @@ void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &probe_data, DataC GatherResult(vector, result_sel, result_sel, result_count, output_col_idx); FlatVector::SetSize(vector, count_t(probe_data.size())); } - result.SetCardinality(probe_data.size()); + result.SetChildCardinality(probe_data.size()); // 
like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk finished = true; @@ -1619,7 +1912,7 @@ void ScanStructure::NextUniqueLeftJoin(DataChunk &keys, DataChunk &probe_data, D } // single pass - done - result.SetCardinality(probe_data.size()); + result.SetChildCardinality(probe_data.size()); finished = true; } @@ -1664,8 +1957,6 @@ void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, Dat if (found_entries == 0) { return; } - result.SetCardinality(found_entries); - idx_t left_column_count = result.ColumnCount() - output_columns.size(); if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { left_column_count = 0; @@ -1744,13 +2035,13 @@ idx_t JoinHashTable::GetTotalSize(const vector &partition_sizes, const ve return total_size + PointerTableSize(total_count); } -idx_t JoinHashTable::GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, +idx_t JoinHashTable::GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, idx_t &max_partition_count) const { const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); vector partition_sizes(num_partitions, 0); vector partition_counts(num_partitions, 0); for (auto &ht : local_hts) { - ht->GetSinkCollection().GetSizesAndCounts(partition_sizes, partition_counts); + ht.get().GetSinkCollection().GetSizesAndCounts(partition_sizes, partition_counts); } return GetTotalSize(partition_sizes, partition_counts, max_partition_size, max_partition_count); @@ -1846,6 +2137,61 @@ void JoinHashTable::Reset() { finalized = false; } +static void ResetCorrelatedMarkJoinInfo(JoinHashTable &ht) { + auto &info = ht.correlated_mark_join_info; + if (info.correlated_types.empty()) { + return; + } + vector correlated_aggregates; + vector payload_types; + correlated_aggregates.reserve(info.correlated_aggregates.size()); + payload_types.reserve(info.correlated_aggregates.size()); + for (auto &expr : 
info.correlated_aggregates) { + auto &aggr = expr->Cast(); + correlated_aggregates.emplace_back(aggr); + payload_types.push_back(aggr.GetReturnType()); + } + auto &allocator = BufferAllocator::Get(ht.context); + info.correlated_counts = make_uniq(ht.context, allocator, info.correlated_types, + payload_types, std::move(correlated_aggregates)); + info.group_chunk.Reset(); + info.correlated_payload.Reset(); + info.result_chunk.Reset(); +} + +void JoinHashTable::ResetForNewIterationSinglePartition() { + data_collection->Reset(); + // Always use a single partition (radix_bits=0) to avoid per-iteration overhead of resetting + // and re-creating many radix partitions when only one thread builds the hash table. + if (radix_bits != 0) { + radix_bits = 0; + sink_collection = make_uniq(buffer_manager, layout_ptr, MemoryTag::HASH_TABLE, + idx_t(0), layout_ptr->ColumnCount() - 1); + } else { + sink_collection->Reset(); + } + InitializePartitionMasks(); + capacity = DConstants::INVALID_INDEX; + bitmask = DConstants::INVALID_INDEX; + finalized = false; + has_null = false; + chains_longer_than_one = false; + total_probe_matches = 0; + load_factor = DEFAULT_LOAD_FACTOR; + should_build_bloom_filter = false; + bloom_filter.Reset(); + bloom_filter_init_count = 0; + prefix_range_filter.reset(); + should_build_prefix_range_filter = false; + ResetCorrelatedMarkJoinInfo(*this); + // The next iteration may rebuild a different small build side, so this iteration's dictionary + // state is stale. Keep in lock-step with BuildDictionaryArrays, which sets these four fields. 
+ dict_arrays.clear(); + aux_next_ptrs.Reset(); + aux_next_ptrs_data = nullptr; + use_dict_emission = false; +} + bool JoinHashTable::PrepareExternalFinalize(const idx_t max_ht_size) { if (finalized) { Reset(); @@ -1935,7 +2281,7 @@ void JoinHashTable::ProbeAndSpill(ScanStructure &scan_structure, DataChunk &prob probe_keys.Slice(true_sel, true_count); probe_chunk.Slice(true_sel, true_count); - const SelectionVector *current_sel; + optional_ptr current_sel; InitializeScanStructure(scan_structure, probe_keys, key_state, current_sel); if (scan_structure.count == 0) { return; diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp index c818fd0df..714e99547 100644 --- a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp +++ b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp @@ -6,8 +6,8 @@ namespace duckdb { struct InitialNestedLoopJoin { template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { + static idx_t Operation(const Vector &left, const Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, + idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { using MATCH_OP = ComparisonOperationWrapper; // initialize phase of nested loop join @@ -42,8 +42,8 @@ struct InitialNestedLoopJoin { struct RefineNestedLoopJoin { template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { + static idx_t Operation(const Vector &left, const Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, + idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { using MATCH_OP = 
ComparisonOperationWrapper; auto left_entries = left.Values(); @@ -73,8 +73,8 @@ struct RefineNestedLoopJoin { }; template -static idx_t NestedLoopJoinTypeSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, +static idx_t NestedLoopJoinTypeSwitch(const Vector &left, const Vector &right, idx_t left_size, idx_t right_size, + idx_t &lpos, idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { switch (left.GetType().InternalType()) { case PhysicalType::BOOL: @@ -126,8 +126,8 @@ static idx_t NestedLoopJoinTypeSwitch(Vector &left, Vector &right, idx_t left_si } template -idx_t NestedLoopJoinComparisonSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, +idx_t NestedLoopJoinComparisonSwitch(const Vector &left, const Vector &right, idx_t left_size, idx_t right_size, + idx_t &lpos, idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count, ExpressionType comparison_type) { D_ASSERT(left.GetType() == right.GetType()); switch (comparison_type) { diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp index 735274e66..c877201dd 100644 --- a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp +++ b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp @@ -6,7 +6,7 @@ namespace duckdb { template -static void TemplatedMarkJoin(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { +static void TemplatedMarkJoin(const Vector &left, const Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { using MATCH_OP = ComparisonOperationWrapper; auto left_entries = left.Values(); @@ -35,7 +35,7 @@ static void TemplatedMarkJoin(Vector &left, Vector &right, idx_t lcount, 
idx_t r } } -static void MarkJoinNested(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], +static void MarkJoinNested(const Vector &left, const Vector &right, idx_t lcount, idx_t rcount, bool found_match[], ExpressionType comparison_type) { Vector left_reference(left.GetType()); for (idx_t i = 0; i < lcount; i++) { @@ -79,7 +79,7 @@ static void MarkJoinNested(Vector &left, Vector &right, idx_t lcount, idx_t rcou } template -static void MarkJoinSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { +static void MarkJoinSwitch(const Vector &left, const Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { switch (left.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: @@ -113,8 +113,8 @@ static void MarkJoinSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcou } } -static void MarkJoinComparisonSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], - ExpressionType comparison_type) { +static void MarkJoinComparisonSwitch(const Vector &left, const Vector &right, idx_t lcount, idx_t rcount, + bool found_match[], ExpressionType comparison_type) { switch (left.GetType().InternalType()) { case PhysicalType::STRUCT: case PhysicalType::LIST: diff --git a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp index 913b40c84..a928e0fbd 100644 --- a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp +++ b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp @@ -7,24 +7,27 @@ namespace duckdb { AggregateObject::AggregateObject(BoundAggregateFunction function, FunctionData *bind_data, idx_t child_count, idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, - Expression *filter) + optional_ptr filter) : function(std::move(function)), bind_data_wrapper(bind_data ? 
make_shared_ptr(bind_data->Copy()) : nullptr), child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), filter(filter) { } -AggregateObject::AggregateObject(BoundAggregateExpression *aggr) - : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), - AlignValue(aggr->function.GetStateSizeCallback()(aggr->function)), aggr->aggr_type, - aggr->GetReturnType().InternalType(), aggr->filter.get()) { +AggregateObject::AggregateObject(BoundAggregateExpression &aggr) + : AggregateObject(aggr.Function(), aggr.BindInfoMutable().get(), aggr.GetChildren().size(), + AlignValue(aggr.Function().GetStateSizeCallback()(aggr.Function())), aggr.GetAggregateType(), + aggr.GetReturnType().InternalType(), aggr.GetFilter().get()) { +} + +AggregateObject::AggregateObject(BoundAggregateExpression *aggr) : AggregateObject(*aggr) { } AggregateObject::AggregateObject(const BoundWindowExpression &window) - : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), - AlignValue(window.aggregate->GetStateSizeCallback()(*window.aggregate)), - window.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT, - window.GetReturnType().InternalType(), window.filter_expr.get()) { + : AggregateObject(*window.AggregateFunction(), window.BindInfo().get(), window.GetChildren().size(), + AlignValue(window.AggregateFunction()->GetStateSizeCallback()(*window.AggregateFunction())), + window.Distinct() ? 
AggregateType::DISTINCT : AggregateType::NON_DISTINCT, + window.GetReturnType().InternalType(), window.Filter().get()) { } vector AggregateObject::CreateAggregateObjects(const vector &bindings) { @@ -36,7 +39,7 @@ vector AggregateObject::CreateAggregateObjects(const vector &payload_types) : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { if (payload_types.empty()) { diff --git a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp index a1242ae42..809f6a2ca 100644 --- a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp +++ b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp @@ -21,7 +21,7 @@ DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vectorCast(); // Initialize the child executor and get the payload types for every aggregate - for (auto &child : aggregate.children) { + for (auto &child : aggregate.GetChildren()) { child_executor.AddExpression(*child); } if (!aggregate.IsDistinct()) { @@ -97,7 +97,7 @@ DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionIn grouping_set.insert(group); } idx_t group_by_size = group_expressions ? 
group_expressions->size() : 0; - for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) { + for (idx_t set_idx = 0; set_idx < aggregate.GetChildren().size(); set_idx++) { grouping_set.insert(ProjectionIndex(set_idx + group_by_size)); } // Create the hashtable for the aggregate @@ -108,7 +108,7 @@ DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionIn // Fill the chunk_types (only contains the payload of the distinct aggregates) vector chunk_types; - for (auto &child_p : aggregate.children) { + for (auto &child_p : aggregate.GetChildren()) { chunk_types.push_back(child_p->GetReturnType()); } } @@ -122,16 +122,16 @@ struct FindMatchingAggregate { bool operator()(const aggr_ref_t other_r) { auto &other = other_r.get(); auto &aggr = aggr_r.get(); - if (other.children.size() != aggr.children.size()) { + if (other.GetChildren().size() != aggr.GetChildren().size()) { return false; } - if (!Expression::Equals(aggr.filter, other.filter)) { + if (!Expression::Equals(aggr.GetFilterMutable(), other.GetFilterMutable())) { return false; } - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &other_child = other.children[i]->Cast(); - auto &aggr_child = aggr.children[i]->Cast(); - if (other_child.index != aggr_child.index) { + for (idx_t i = 0; i < aggr.GetChildren().size(); i++) { + auto &other_child = other.GetChildren()[i]->Cast(); + auto &aggr_child = aggr.GetChildren()[i]->Cast(); + if (other_child.Index() != aggr_child.Index()) { return false; } } diff --git a/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp index 068a34b72..04e9350c8 100644 --- a/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp +++ b/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp @@ -26,15 +26,15 @@ void GroupedAggregateData::InitializeGroupby(vector> grou bindings.push_back(&aggr); 
aggregate_return_types.push_back(aggr.GetReturnType()); - for (auto &child : aggr.children) { + for (auto &child : aggr.GetChildren()) { payload_types.push_back(child->GetReturnType()); } - if (aggr.filter) { + if (aggr.GetFilter()) { filter_count++; - payload_types_filters.push_back(aggr.filter->GetReturnType()); + payload_types_filters.push_back(aggr.GetFilter()->GetReturnType()); } - if (!aggr.function.HasStateCombineCallback()) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.GetName()); + if (!aggr.Function().HasStateCombineCallback()) { + throw InternalException("Aggregate function %s is missing a combine method", aggr.Function().GetName()); } aggregates.push_back(std::move(expr)); } @@ -54,17 +54,17 @@ void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggr // bindings.push_back(&aggr); filter_count = 0; aggregate_return_types.push_back(aggr.GetReturnType()); - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &child = aggr.children[i]; + for (idx_t i = 0; i < aggr.GetChildren().size(); i++) { + auto &child = aggr.GetChildren()[i]; group_types.push_back(child->GetReturnType()); groups.push_back(child->Copy()); payload_types.push_back(child->GetReturnType()); - if (aggr.filter) { + if (aggr.GetFilter()) { filter_count++; } } - if (!aggr.function.HasStateCombineCallback()) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.GetName()); + if (!aggr.Function().HasStateCombineCallback()) { + throw InternalException("Aggregate function %s is missing a combine method", aggr.Function().GetName()); } } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp index 4cd67b246..7afb9a1c6 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp @@ -84,13 
+84,13 @@ static vector CreateGroupChunkTypes(vector> for (auto &group : groups) { D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref = group->Cast(); - group_indices.insert(bound_ref.index); + group_indices.insert(bound_ref.Index()); } idx_t highest_index = *group_indices.rbegin(); vector types(highest_index + 1, LogicalType::SQLNULL); for (auto &group : groups) { auto &bound_ref = group->Cast(); - types[bound_ref.index] = bound_ref.GetReturnType(); + types[bound_ref.Index()] = bound_ref.GetReturnType(); } return types; } @@ -155,10 +155,10 @@ PhysicalHashAggregate::PhysicalHashAggregate(PhysicalPlan &physical_plan, Client for (idx_t i = 0; i < aggregates.size(); i++) { auto &aggregate = aggregates[i]; auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); - if (aggr.aggr_type == AggregateType::DISTINCT) { + aggregate_input_idx += aggr.GetChildren().size(); + if (aggr.GetAggregateType() == AggregateType::DISTINCT) { distinct_filter.push_back(i); - } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) { + } else if (aggr.GetAggregateType() == AggregateType::NON_DISTINCT) { non_distinct_filter.push_back(i); } else { // LCOV_EXCL_START throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate"); @@ -168,12 +168,13 @@ PhysicalHashAggregate::PhysicalHashAggregate(PhysicalPlan &physical_plan, Client for (idx_t i = 0; i < aggregates.size(); i++) { auto &aggregate = aggregates[i]; auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - if (!filter_indexes.count(aggr.filter.get())) { + if (aggr.GetFilter()) { + auto &filter_ref = *aggr.GetFilter(); + auto &bound_ref_expr = filter_ref.Cast(); + if (!filter_indexes.count(filter_ref)) { // Replace the bound reference expression's index with the corresponding index of the payload chunk - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx; + 
filter_indexes[filter_ref] = bound_ref_expr.Index(); + bound_ref_expr.IndexMutable() = aggregate_input_idx; } aggregate_input_idx++; } @@ -192,7 +193,7 @@ PhysicalHashAggregate::PhysicalHashAggregate(PhysicalPlan &physical_plan, Client //===--------------------------------------------------------------------===// class HashAggregateGlobalSinkState : public GlobalSinkState { public: - HashAggregateGlobalSinkState(const PhysicalHashAggregate &op, ClientContext &context) { + HashAggregateGlobalSinkState(const PhysicalHashAggregate &op, ClientContext &context) : op(op) { grouping_states.reserve(op.groupings.size()); for (idx_t i = 0; i < op.groupings.size(); i++) { auto &grouping = op.groupings[i]; @@ -201,26 +202,53 @@ class HashAggregateGlobalSinkState : public GlobalSinkState { vector filter_types; for (auto &aggr : op.grouped_aggregate_data.aggregates) { auto &aggregate = aggr->Cast(); - for (auto &child : aggregate.children) { + for (auto &child : aggregate.GetChildren()) { payload_types.push_back(child->GetReturnType()); } - if (aggregate.filter) { - filter_types.push_back(aggregate.filter->GetReturnType()); + if (aggregate.GetFilter()) { + filter_types.push_back(aggregate.GetFilter()->GetReturnType()); } } payload_types.reserve(payload_types.size() + filter_types.size()); payload_types.insert(payload_types.end(), filter_types.begin(), filter_types.end()); } + const PhysicalHashAggregate &op; vector grouping_states; vector payload_types; //! 
Whether or not the aggregate is finished bool finished = false; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + auto &grouping = op.groupings[grouping_idx]; + auto &grouping_state = grouping_states[grouping_idx]; + grouping.table_data.ResetGlobalSinkState(context, *grouping_state.table_state); + if (!grouping.HasDistinct()) { + continue; + } + auto &distinct_data = *grouping.distinct_data; + auto &distinct_state = *grouping_state.distinct_state; + for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { + auto &radix_table = distinct_data.radix_tables[table_idx]; + if (!radix_table) { + continue; + } + radix_table->ResetGlobalSinkState(context, *distinct_state.radix_states[table_idx]); + } + } + finished = false; + GlobalSinkState::Reset(context); + } }; class HashAggregateLocalSinkState : public LocalSinkState { public: - HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) { + HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) : op(op) { auto &payload_types = op.grouped_aggregate_data.payload_types; if (!payload_types.empty()) { aggregate_input_chunk.InitializeEmpty(payload_types); @@ -241,9 +269,41 @@ class HashAggregateLocalSinkState : public LocalSinkState { filter_set.Initialize(context.client, aggregate_objects, payload_types); } + const PhysicalHashAggregate &op; DataChunk aggregate_input_chunk; vector grouping_states; AggregateFilterDataSet filter_set; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + // Sink repopulates every aggregate-input column by reference before use, so we just clear it here. 
+ // Use Reset() rather than SetChildCardinality(0): the chunk may still hold non-flat (e.g. dictionary) + // references from a previous iteration that SetChildCardinality cannot resize. + aggregate_input_chunk.Reset(); + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + auto &grouping = op.groupings[grouping_idx]; + auto &grouping_gstate = gstate.grouping_states[grouping_idx]; + auto &grouping_state = grouping_states[grouping_idx]; + grouping.table_data.ResetLocalSinkState(context, *grouping_gstate.table_state, *grouping_state.table_state); + if (!grouping.HasDistinct()) { + continue; + } + auto &distinct_data = *grouping.distinct_data; + auto &distinct_gstate = *grouping_gstate.distinct_state; + for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { + auto &radix_table = distinct_data.radix_tables[table_idx]; + if (!radix_table) { + continue; + } + radix_table->ResetLocalSinkState(context, *distinct_gstate.radix_states[table_idx], + *grouping_state.distinct_states[table_idx]); + } + } + } }; void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) { @@ -297,18 +357,18 @@ void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, Data InterruptState interrupt_state; OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, interrupt_state}; - if (aggregate.filter) { + if (aggregate.GetFilter()) { DataChunk filter_chunk; auto &filtered_data = sink.filter_set.GetFilterData(idx); filter_chunk.InitializeEmpty(filtered_data.filtered_payload.GetTypes()); // Add the filter Vector (BOOL) - auto it = filter_indexes.find(aggregate.filter.get()); + auto &filter_ref = *aggregate.GetFilter(); + auto it = filter_indexes.find(filter_ref); D_ASSERT(it != filter_indexes.end()); D_ASSERT(it->second < chunk.data.size()); - auto &filter_bound_ref = aggregate.filter->Cast(); - filter_chunk.data[filter_bound_ref.index].Reference(chunk.data[it->second]); - 
filter_chunk.SetCardinality(chunk.size()); + auto &filter_bound_ref = filter_ref.Cast(); + filter_chunk.data[filter_bound_ref.Index()].Reference(chunk.data[it->second]); // We cant use the AggregateFilterData::ApplyFilter method, because the chunk we need to // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those. @@ -327,18 +387,17 @@ void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, Data for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) { auto &group = grouped_aggregate_data.groups[group_idx]; auto &bound_ref = group->Cast(); - auto &col = filtered_input.data[bound_ref.index]; - col.Reference(chunk.data[bound_ref.index]); + auto &col = filtered_input.data[bound_ref.Index()]; + col.Reference(chunk.data[bound_ref.Index()]); col.Slice(sel_vec, count); } - for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) { - auto &child = aggregate.children[child_idx]; + for (idx_t child_idx = 0; child_idx < aggregate.GetChildren().size(); child_idx++) { + auto &child = aggregate.GetChildren()[child_idx]; auto &bound_ref = child->Cast(); - auto &col = filtered_input.data[bound_ref.index]; - col.Reference(chunk.data[bound_ref.index]); + auto &col = filtered_input.data[bound_ref.Index()]; + col.Reference(chunk.data[bound_ref.Index()]); col.Slice(sel_vec, count); } - filtered_input.SetCardinality(count); radix_table.Sink(context, filtered_input, sink_input, empty_chunk, empty_filter); } else { @@ -373,25 +432,25 @@ SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk // Populate the aggregate child vectors for (auto &aggregate : aggregates) { auto &aggr = aggregate->Cast(); - for (auto &child_expr : aggr.children) { + for (auto &child_expr : aggr.GetChildren()) { D_ASSERT(child_expr->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref_expr = child_expr->Cast(); - D_ASSERT(bound_ref_expr.index < 
chunk.data.size()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); + D_ASSERT(bound_ref_expr.Index() < chunk.data.size()); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.Index()]); } } // Populate the filter vectors for (auto &aggregate : aggregates) { auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); + if (aggr.GetFilter()) { + auto it = filter_indexes.find(*aggr.GetFilter()); D_ASSERT(it != filter_indexes.end()); D_ASSERT(it->second < chunk.data.size()); aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); } } - aggregate_input_chunk.SetCardinality(chunk.size()); + aggregate_input_chunk.SetChildCardinality(chunk.size()); aggregate_input_chunk.Verify(context.client.db); // For every grouping set there is one radix_table @@ -691,7 +750,7 @@ TaskExecutionResult HashAggregateDistinctFinalizeTask::AggregateDistinctGrouping if (!blocked) { // Forward the payload idx payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); + next_payload_idx = payload_idx + aggregate.GetChildren().size(); } // If aggregate is not distinct, skip it @@ -734,15 +793,14 @@ TaskExecutionResult HashAggregateDistinctFinalizeTask::AggregateDistinctGrouping for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) { auto &group = grouped_aggregate_data.groups[group_idx]; auto &bound_ref_expr = group->Cast(); - group_chunk.data[bound_ref_expr.index].Reference(output_chunk.data[group_idx]); + group_chunk.data[bound_ref_expr.Index()].Reference(output_chunk.data[group_idx]); } - group_chunk.SetCardinality(output_chunk); for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size; child_idx++) { aggregate_input_chunk.data[payload_idx + child_idx].Reference( output_chunk.data[group_by_size + child_idx]); } - 
aggregate_input_chunk.SetCardinality(output_chunk); + aggregate_input_chunk.SetChildCardinality(output_chunk.size()); // Sink it into the main ht grouping_data.table_data.Sink(execution_context, group_chunk, sink_input, aggregate_input_chunk, {agg_idx}); @@ -806,11 +864,12 @@ SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &even //===--------------------------------------------------------------------===// class HashAggregateGlobalSourceState : public GlobalSourceState { public: - HashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) : op(op), state_index(0) { + HashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) : op(op) { for (auto &grouping : op.groupings) { auto &rt = grouping.table_data; radix_states.push_back(rt.GetGlobalSourceState(context)); } + ResetState(context); } const PhysicalHashAggregate &op; @@ -818,6 +877,15 @@ class HashAggregateGlobalSourceState : public GlobalSourceState { vector> radix_states; +private: + void ResetState(ClientContext &context) { + state_index = 0; + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + op.groupings[grouping_idx].table_data.ResetGlobalSourceState(context, *radix_states[grouping_idx]); + } + GlobalSourceState::Reset(context); + } + public: idx_t MaxThreads() override { // If there are no tables, we only need one thread. 
@@ -834,6 +902,14 @@ class HashAggregateGlobalSourceState : public GlobalSourceState { } return MaxValue(1, threads); } + + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + ResetState(context); + } }; unique_ptr PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const { @@ -842,20 +918,41 @@ unique_ptr PhysicalHashAggregate::GetGlobalSourceState(Client class HashAggregateLocalSourceState : public LocalSourceState { public: - explicit HashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) { + explicit HashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op, + GlobalSourceState &gstate) + : op(op) { for (auto &grouping : op.groupings) { auto &rt = grouping.table_data; radix_states.push_back(rt.GetLocalSourceState(context)); } + ResetState(); } + const PhysicalHashAggregate &op; optional_idx radix_idx; vector> radix_states; + +private: + void ResetState() { + radix_idx.SetInvalid(); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSourceState &gstate) override { + ResetState(); + for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { + op.groupings[grouping_idx].table_data.ResetLocalSourceState(context, *radix_states[grouping_idx]); + } + } }; unique_ptr PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(context, *this); + return make_uniq(context, *this, gstate); } SourceResultType PhysicalHashAggregate::GetDataInternal(ExecutionContext &context, DataChunk &chunk, @@ -931,8 +1028,8 @@ InsertionOrderPreservingMap PhysicalHashAggregate::ParamsToString() cons aggregate_info += "\n"; } aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); + if (aggregate.GetFilter()) { + 
aggregate_info += " Filter: " + aggregate.GetFilter()->GetName(); } } result["Aggregates"] = aggregate_info; diff --git a/src/duckdb/src/execution/operator/aggregate/physical_limited_distinct.cpp b/src/duckdb/src/execution/operator/aggregate/physical_limited_distinct.cpp index d4d6dd5d5..b5726cfdf 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_limited_distinct.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_limited_distinct.cpp @@ -11,7 +11,7 @@ static vector GetPayloadTypes(const vector> vector payload_types; for (auto &aggr_expr : aggregates) { auto &aggr = aggr_expr->Cast(); - for (auto &child : aggr.children) { + for (auto &child : aggr.GetChildren()) { payload_types.push_back(child->GetReturnType()); } } @@ -69,7 +69,7 @@ struct LimitedDistinctLocalSinkState : public LocalSinkState { // Build payload executor for aggregate children for (auto &aggr_expr : op.aggregates) { auto &aggr = aggr_expr->Cast(); - for (auto &child : aggr.children) { + for (auto &child : aggr.GetChildren()) { payload_executor.AddExpression(*child); } } @@ -216,7 +216,6 @@ SourceResultType PhysicalLimitedDistinct::GetDataInternal(ExecutionContext &cont chunk.data[col].Reference(gstate_source.payload_chunk.data[i]); col++; } - chunk.SetCardinality(gstate_source.group_chunk.size()); return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp index 553e42409..33c9b3a62 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp @@ -217,8 +217,8 @@ InsertionOrderPreservingMap PhysicalPartitionedAggregate::ParamsToString aggregate_info += "\n"; } aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); + if (aggregate.GetFilter()) { + aggregate_info += " Filter: " + aggregate.GetFilter()->GetName(); } } result["Aggregates"] = aggregate_info; diff --git a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp index f08f178a4..180e50129 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -37,12 +37,12 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(PhysicalPlan &physica bindings.push_back(&aggr); D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.HasStateCombineCallback()); - for (auto &child : aggr.children) { + D_ASSERT(aggr.Function().HasStateCombineCallback()); + for (auto &child : aggr.GetChildren()) { payload_types.push_back(child->GetReturnType()); } - if (aggr.filter) { - payload_types_filters.push_back(aggr.filter->GetReturnType()); + if (aggr.GetFilter()) { + payload_types_filters.push_back(aggr.GetFilter()->GetReturnType()); } } for (const auto &pay_filters : payload_types_filters) { @@ -54,16 +54,17 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(PhysicalPlan &physica idx_t aggregate_input_idx = 0; for (auto 
&aggregate : aggregates) { auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); + aggregate_input_idx += aggr.GetChildren().size(); } for (auto &aggregate : aggregates) { auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - auto it = filter_indexes.find(aggr.filter.get()); + if (aggr.GetFilter()) { + auto &filter_ref = *aggr.GetFilter(); + auto &bound_ref_expr = filter_ref.Cast(); + auto it = filter_indexes.find(filter_ref); if (it == filter_indexes.end()) { - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx++; + filter_indexes[filter_ref] = bound_ref_expr.Index(); + bound_ref_expr.IndexMutable() = aggregate_input_idx++; } else { ++aggregate_input_idx; } @@ -126,30 +127,26 @@ SinkResultType PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, Dat auto &group = groups[group_idx]; D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref_expr = group->Cast(); - group_chunk.data[group_idx].Reference(chunk.data[bound_ref_expr.index]); + group_chunk.data[group_idx].Reference(chunk.data[bound_ref_expr.Index()]); } idx_t aggregate_input_idx = 0; for (auto &aggregate : aggregates) { auto &aggr = aggregate->Cast(); - for (auto &child_expr : aggr.children) { + for (auto &child_expr : aggr.GetChildren()) { D_ASSERT(child_expr->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref_expr = child_expr->Cast(); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); + aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.Index()]); } } for (auto &aggregate : aggregates) { auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); + if (aggr.GetFilter()) { + auto it = filter_indexes.find(*aggr.GetFilter()); D_ASSERT(it != filter_indexes.end()); 
aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); } } - group_chunk.SetCardinality(chunk.size()); - - aggregate_input_chunk.SetCardinality(chunk.size()); - group_chunk.Verify(context.client.db); aggregate_input_chunk.Verify(context.client.db); D_ASSERT(aggregate_input_chunk.ColumnCount() == 0 || group_chunk.size() == aggregate_input_chunk.size()); @@ -220,8 +217,8 @@ InsertionOrderPreservingMap PhysicalPerfectHashAggregate::ParamsToString } aggregate_info += aggregates[i]->GetName(); auto &aggregate = aggregates[i]->Cast(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); + if (aggregate.GetFilter()) { + aggregate_info += " Filter: " + aggregate.GetFilter()->GetName(); } } result["Aggregates"] = aggregate_info; diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp index e3d8d00f1..b13e08b03 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -38,13 +38,13 @@ class StreamingWindowState : public OperatorState { statev(LogicalType::POINTER, data_ptr_cast(&state_ptr), 1ULL), hashes(LogicalType::HASH), addresses(LogicalType::POINTER) { D_ASSERT(wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE); - auto &aggregate = *wexpr.aggregate; - bind_data = wexpr.bind_info.get(); + auto &aggregate = *wexpr.AggregateFunction(); + bind_data = wexpr.BindInfo().get(); dtor = aggregate.GetCallbacks().GetStateDestructorCallback(); state.resize(aggregate.GetCallbacks().GetStateSizeCallback()(aggregate)); state_ptr = state.data(); aggregate.GetCallbacks().GetStateInitCallback()(aggregate, state.data()); - for (auto &child : wexpr.children) { + for (auto &child : wexpr.GetChildren()) { arg_types.push_back(child->GetReturnType()); executor.AddExpression(*child); } @@ -52,11 +52,11 
@@ class StreamingWindowState : public OperatorState { arg_chunk.Initialize(allocator, arg_types); arg_cursor.Initialize(allocator, arg_types); } - if (wexpr.filter_expr) { - filter_executor.AddExpression(*wexpr.filter_expr); + if (wexpr.Filter()) { + filter_executor.AddExpression(*wexpr.Filter()); filter_sel.Initialize(); } - if (wexpr.distinct) { + if (wexpr.Distinct()) { distinct = make_uniq(client, allocator, arg_types); distinct_args.Initialize(allocator, arg_types); distinct_sel.Initialize(); @@ -65,7 +65,7 @@ class StreamingWindowState : public OperatorState { ~AggregateState() override { if (dtor) { - AggregateInputData aggr_input_data(bind_data, arena_allocator); + AggregateInputData aggr_input_data(*wexpr.AggregateFunction(), bind_data, arena_allocator); state_ptr = state.data(); dtor(statev, aggr_input_data, 1); } @@ -129,8 +129,8 @@ class StreamingWindowState : public OperatorState { auto &fstate = states[expr_idx]; if (expr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE) { fstate = make_uniq(client, wexpr, allocator); - } else if (wexpr.window && wexpr.window->HasStreamingStateCallback()) { - fstate = wexpr.window->GetStreamingState(client, input, wexpr); + } else if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasStreamingStateCallback()) { + fstate = wexpr.WindowFunction()->GetStreamingState(client, input, wexpr); } else { throw NotImplementedException("GetStreamingState for %s", ExpressionTypeToString(expr.GetExpressionType())); @@ -148,7 +148,7 @@ class StreamingWindowState : public OperatorState { initialized = true; } - static inline void Reset(DataChunk &chunk) { + static inline void ResetChunk(DataChunk &chunk) { chunk.Reset(); } @@ -166,27 +166,30 @@ class StreamingWindowState : public OperatorState { DataChunk delayed; //! A buffer for shifting delayed input DataChunk shifted; + //! Set when `delayed` has been flushed into the output by reference - we must defer resetting (resizing) it + //! 
until the next call, by which point the referencing output chunk has been consumed. + bool flushed_delayed = false; }; StreamingWindowGlobalState::StreamingWindowGlobalState(ClientContext &client) { local_state = make_uniq(client); } -bool PhysicalStreamingWindow::IsStreamingFunction(ClientContext &client, unique_ptr &expr) { - auto &wexpr = expr->Cast(); - if (!wexpr.partitions.empty() || !wexpr.orders.empty() || !wexpr.arg_orders.empty() || - wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { +bool PhysicalStreamingWindow::IsStreamingFunction(ClientContext &client, BoundWindowExpression &wexpr) { + if (!wexpr.Partitions().empty() || !wexpr.OrderBy().empty() || !wexpr.ArgOrders().empty() || + wexpr.WindowExclude() != WindowExcludeMode::NO_OTHER) { return false; } if (wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE) { // Aggregates with destructors (e.g., quantile) are too slow to repeatedly update/finalize - if (wexpr.aggregate->HasStateDestructorCallback()) { + if (wexpr.AggregateFunction()->HasStateDestructorCallback()) { return false; } // We can stream aggregates if they are "running totals" - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; - } else if (wexpr.window && wexpr.window->HasCanStreamCallback()) { - return wexpr.window->CanStream(client, wexpr, StreamingWindowState::MAX_BUFFER); + return wexpr.WindowStart() == WindowBoundary::UNBOUNDED_PRECEDING && + wexpr.WindowEnd() == WindowBoundary::CURRENT_ROW_ROWS; + } else if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasCanStreamCallback()) { + return wexpr.WindowFunction()->CanStream(client, wexpr, StreamingWindowState::MAX_BUFFER); } else { return false; } @@ -199,7 +202,7 @@ unique_ptr PhysicalStreamingWindow::GetGlobalOperatorState( void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, DataChunk &input, Vector &result) { // Establish the aggregation environment const idx_t count = input.size(); - 
auto &aggregate = *wexpr.aggregate; + auto &aggregate = *wexpr.AggregateFunction(); auto &aggr_state = *this; auto &statev = aggr_state.statev; @@ -207,7 +210,7 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da ValidityMask filter_mask; auto filtered = count; auto &filter_sel = aggr_state.filter_sel; - if (wexpr.filter_expr) { + if (wexpr.Filter()) { filtered = filter_executor.SelectExpression(input, filter_sel); if (filtered < count) { filter_mask.Initialize(count); @@ -219,7 +222,7 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da } // Check for COUNT(*) - if (wexpr.children.empty()) { + if (wexpr.GetChildren().empty()) { D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); auto data = FlatVector::Writer(result, count); auto &unfiltered = aggr_state.unfiltered; @@ -241,7 +244,7 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da if (aggr_state.distinct) { auto &distinct_args = aggr_state.distinct_args; distinct_args.Reference(arg_chunk); - if (wexpr.filter_expr) { + if (wexpr.Filter()) { distinct_args.Slice(filter_sel, filtered); } idx_t distinct = 0; @@ -277,7 +280,7 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da // is not copied to the children when you slice vector structs; for (column_t col_idx = 0; col_idx < arg_chunk.ColumnCount(); ++col_idx) { - auto &col_vec = arg_cursor.data[col_idx]; + const auto &col_vec = arg_cursor.data[col_idx]; if (col_vec.GetType().InternalType() == PhysicalType::STRUCT) { structs.emplace_back(col_idx); } else { @@ -286,7 +289,7 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da } // Update the state and finalize it one row at a time. 
- AggregateInputData aggr_input_data(wexpr.bind_info.get(), aggr_state.arena_allocator); + AggregateInputData aggr_input_data(*wexpr.AggregateFunction(), wexpr.BindInfo().get(), aggr_state.arena_allocator); for (idx_t i = 0; i < count; ++i) { sel.set_index(0, i); for (const auto struct_idx : structs) { @@ -315,8 +318,8 @@ void PhysicalStreamingWindow::ExecuteFunctions(ExecutionContext &context, DataCh auto &fstate = *state.states[expr_idx]; if (expr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE) { fstate.Cast().Execute(context, output, result); - } else if (wexpr.window && wexpr.window->HasStreamingDataCallback()) { - wexpr.window->GetStreamingData(context, output, delayed, state.delayed_capacity, result, fstate); + } else if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasStreamingDataCallback()) { + wexpr.WindowFunction()->GetStreamingData(context, output, delayed, state.delayed_capacity, result, fstate); } else { throw NotImplementedException("GetStreamingData for %s", ExpressionTypeToString(expr.GetExpressionType())); } @@ -343,7 +346,6 @@ void PhysicalStreamingWindow::ExecuteInput(ExecutionContext &context, DataChunk output.data[col_idx].Reference(input.data[col_idx]); FlatVector::SetSize(output.data[col_idx], count_t(count)); } - output.SetCardinality(count); ExecuteFunctions(context, output, state.delayed, gstate_p); } @@ -359,10 +361,10 @@ void PhysicalStreamingWindow::ExecuteShifted(ExecutionContext &context, DataChun idx_t delay = delayed.size(); D_ASSERT(out <= delay); - state.Reset(shifted); + state.ResetChunk(shifted); // shifted = delayed delayed.Copy(shifted); - state.Reset(delayed); + state.ResetChunk(delayed); const idx_t new_delayed_count = delay - out + in; for (idx_t col_idx = 0; col_idx < delayed.data.size(); ++col_idx) { // output[0:out] = delayed[0:out] @@ -374,7 +376,6 @@ void PhysicalStreamingWindow::ExecuteShifted(ExecutionContext &context, DataChun VectorOperations::Copy(input.data[col_idx], delayed.data[col_idx], in, 0, 
delay - out); FlatVector::SetSize(delayed.data[col_idx], count_t(new_delayed_count)); } - delayed.SetCardinality(new_delayed_count); ExecuteFunctions(context, output, delayed, gstate_p); } @@ -387,7 +388,6 @@ void PhysicalStreamingWindow::ExecuteDelayed(ExecutionContext &context, DataChun output.data[col_idx].Reference(delayed.data[col_idx]); FlatVector::SetSize(output.data[col_idx], count_t(count)); } - output.SetCardinality(count); ExecuteFunctions(context, output, input, gstate_p); } @@ -401,29 +401,38 @@ OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, D } auto &delayed = state.delayed; - // We can Reset delayed now that no one can be referencing it. - if (!delayed.size()) { - state.Reset(delayed); + if (!state.lead_count) { + // Without LEAD nothing is ever delayed (the delayed buffer is not even initialized), so emit the input + // directly. + ExecuteInput(context, delayed, input, output, gstate_p); + return OperatorResultType::NEED_MORE_INPUT; + } + // We can Reset delayed now that no one can be referencing it (the previous output has been consumed). + if (state.flushed_delayed) { + state.ResetChunk(delayed); + state.flushed_delayed = false; + } else if (!delayed.size()) { + state.ResetChunk(delayed); } if (delayed.size() < state.lead_count) { // If we don't have enough to produce a single row, // then just delay more rows, return nothing // and ask for more data. 
delayed.Append(input); - output.SetCardinality(0); return OperatorResultType::NEED_MORE_INPUT; } else if (input.size() < delayed.size()) { // If we can't consume all of the delayed values, // we need to split them instead of referencing them all - output.SetCardinality(input.size()); + output.SetChildCardinality(input.size()); ExecuteShifted(context, delayed, input, output, gstate_p); // We delayed the unused input so ask for more return OperatorResultType::NEED_MORE_INPUT; } else if (delayed.size()) { // We have enough delayed rows so flush them ExecuteDelayed(context, delayed, input, output, gstate_p); - // Defer resetting delayed as it may be referenced. - delayed.SetCardinality(0); + // delayed has been flushed into the output by reference. Defer resetting it until the next call (when the + // output has been consumed), since resizing its buffers now would corrupt the referencing output chunk. + state.flushed_delayed = true; // Come back to process the input return OperatorResultType::HAVE_MORE_OUTPUT; } else { @@ -442,11 +451,11 @@ OperatorFinalizeResultType PhysicalStreamingWindow::FinalExecute(ExecutionContex auto &delayed = state.delayed; // There are no more input rows auto &input = state.shifted; - state.Reset(input); + state.ResetChunk(input); if (delayed.size() > STANDARD_VECTOR_SIZE) { // More than one output buffer was delayed, so shift in what we can - output.SetCardinality(STANDARD_VECTOR_SIZE); + output.SetChildCardinality(STANDARD_VECTOR_SIZE); ExecuteShifted(context, delayed, input, output, gstate_p); return OperatorFinalizeResultType::HAVE_MORE_OUTPUT; } diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index 664bb5da0..2a3098c1d 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -2,6 +2,8 @@ #include 
"duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/common/algorithm.hpp" +#include "duckdb/common/clustered_aggregate.hpp" +#include "duckdb/common/enums/debug_verification_mode.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ -10,6 +12,7 @@ #include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" #include "duckdb/execution/radix_partitioned_hashtable.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/config.hpp" #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/interrupt.hpp" #include "duckdb/parallel/thread_context.hpp" @@ -39,41 +42,62 @@ PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(PhysicalPlan &physical_pl //===--------------------------------------------------------------------===// // Ungrouped Aggregate State //===--------------------------------------------------------------------===// -UngroupedAggregateState::UngroupedAggregateState(const vector> &aggregate_expressions) - : aggregate_expressions(aggregate_expressions) { - counts = make_uniq_array>(aggregate_expressions.size()); +UngroupedAggregateState::UngroupedAggregateState(const vector> &aggregate_expressions) { + counts = make_uniq_array(aggregate_expressions.size()); for (idx_t i = 0; i < aggregate_expressions.size(); i++) { auto &aggregate = aggregate_expressions[i]; D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = aggregate->Cast(); - auto state = make_unsafe_uniq_array_uninitialized(aggr.function.GetStateSizeCallback()(aggr.function)); - aggr.function.GetStateInitCallback()(aggr.function, state.get()); + auto state = + make_unsafe_uniq_array_uninitialized(aggr.Function().GetStateSizeCallback()(aggr.Function())); + aggr.Function().GetStateInitCallback()(aggr.Function(), state.get()); aggregate_data.push_back(std::move(state)); - 
bind_data.push_back(aggr.bind_info.get()); - destructors.push_back(aggr.function.GetStateDestructorCallback()); + bind_data.push_back(aggr.BindInfo() ? aggr.BindInfo()->Copy() : nullptr); + functions.push_back(aggr.Function()); + aggregate_types.push_back(aggr.GetAggregateType()); + argument_counts.push_back(aggr.GetChildren().size()); #ifdef DEBUG counts[i] = 0; #endif } } + +UngroupedAggregateState::UngroupedAggregateState(const UngroupedAggregateState &global_state) + : functions(global_state.functions), aggregate_types(global_state.aggregate_types), + argument_counts(global_state.argument_counts) { + counts = make_uniq_array(functions.size()); + for (idx_t i = 0; i < functions.size(); i++) { + auto &func = functions[i]; + auto state = make_unsafe_uniq_array_uninitialized(func.GetStateSizeCallback()(func)); + func.GetStateInitCallback()(func, state.get()); + aggregate_data.push_back(std::move(state)); + bind_data.push_back(global_state.bind_data[i] ? global_state.bind_data[i]->Copy() : nullptr); + counts[i] = 0; + } +} + UngroupedAggregateState::~UngroupedAggregateState() { - D_ASSERT(destructors.size() == aggregate_data.size()); - for (idx_t i = 0; i < destructors.size(); i++) { - if (!destructors[i]) { + D_ASSERT(functions.size() == aggregate_data.size()); + for (idx_t i = 0; i < functions.size(); i++) { + auto destructor = functions[i].GetStateDestructorCallback(); + if (!destructor) { continue; } Vector state_vector(Value::POINTER(CastPointerToValue(aggregate_data[i].get())), count_t(1)); state_vector.SetVectorType(VectorType::FLAT_VECTOR); ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(bind_data[i], allocator); - destructors[i](state_vector, aggr_input_data, 1); + AggregateInputData aggr_input_data(functions[i], bind_data[i].get(), allocator); + destructor(state_vector, aggr_input_data, 1); } } void UngroupedAggregateState::Move(UngroupedAggregateState &other) { other.aggregate_data = std::move(aggregate_data); 
- other.destructors = std::move(destructors); + other.bind_data = std::move(bind_data); + other.functions = std::move(functions); + other.aggregate_types = std::move(aggregate_types); + other.argument_counts = std::move(argument_counts); } //===--------------------------------------------------------------------===// @@ -104,51 +128,44 @@ ArenaAllocator &GlobalUngroupedAggregateState::CreateAllocator() const { void GlobalUngroupedAggregateState::Combine(LocalUngroupedAggregateState &other) { lock_guard glock(lock); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - - if (aggregate.IsDistinct()) { + for (idx_t aggr_idx = 0; aggr_idx < state.functions.size(); aggr_idx++) { + if (state.aggregate_types[aggr_idx] == AggregateType::DISTINCT) { continue; } + auto &func = state.functions[aggr_idx]; Vector source_state(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get())), count_t(1)); Vector dest_state(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get())), count_t(1)); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, + AggregateInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - if (!aggregate.function.HasStateCombineCallback()) { - throw InternalException("Aggregate function " + aggregate.function.GetName() + - " does not support combining of states"); + if (!func.HasStateCombineCallback()) { + throw InternalException("Aggregate function " + func.GetName() + " does not support combining of states"); } - aggregate.function.GetStateCombineCallback()(source_state, dest_state, aggr_input_data, 1); -#ifdef DEBUG + func.GetStateCombineCallback()(source_state, dest_state, aggr_input_data, 1); state.counts[aggr_idx] += other.state.counts[aggr_idx]; -#endif } } void GlobalUngroupedAggregateState::CombineDistinct(LocalUngroupedAggregateState 
&other, DistinctAggregateData &distinct_data) { lock_guard glock(lock); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { + for (idx_t aggr_idx = 0; aggr_idx < state.functions.size(); aggr_idx++) { if (!distinct_data.IsDistinct(aggr_idx)) { continue; } - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, + auto &func = state.functions[aggr_idx]; + AggregateInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); Vector state_vec(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get())), count_t(1)); Vector combined_vec(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get())), count_t(1)); - if (!aggregate.function.HasStateCombineCallback()) { - throw InternalException("Aggregate function " + aggregate.function.GetName() + - " does not support combining of states"); + if (!func.HasStateCombineCallback()) { + throw InternalException("Aggregate function " + func.GetName() + " does not support combining of states"); } - aggregate.function.GetStateCombineCallback()(state_vec, combined_vec, aggr_input_data, 1); -#ifdef DEBUG + func.GetStateCombineCallback()(state_vec, combined_vec, aggr_input_data, 1); state.counts[aggr_idx] += other.state.counts[aggr_idx]; -#endif } } @@ -166,7 +183,7 @@ UngroupedAggregateExecuteState::UngroupedAggregateExecuteState(ClientContext &co D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = aggregate->Cast(); // initialize the payload chunk - for (auto &child : aggr.children) { + for (auto &child : aggr.GetChildren()) { payload_types.push_back(child->GetReturnType()); child_executor.AddExpression(*child); } @@ -192,32 +209,32 @@ void UngroupedAggregateExecuteState::Sink(LocalUngroupedAggregateState &state, D auto &aggregate = aggregates[aggr_idx]->Cast(); payload_idx = 
next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); + next_payload_idx = payload_idx + aggregate.GetChildren().size(); if (aggregate.IsDistinct()) { continue; } idx_t payload_cnt = 0; + idx_t chunk_count; // resolve the filter (if any) - if (aggregate.filter) { + if (aggregate.GetFilter()) { auto &filtered_data = filter_set.GetFilterData(aggr_idx); - auto count = filtered_data.ApplyFilter(input); + chunk_count = filtered_data.ApplyFilter(input); child_executor.SetChunk(filtered_data.filtered_payload); - payload_chunk.SetCardinality(count); } else { + chunk_count = input.size(); child_executor.SetChunk(input); - payload_chunk.SetCardinality(input); } // resolve the child expressions of the aggregate (if any) - for (idx_t i = 0; i < aggregate.children.size(); ++i) { + for (idx_t i = 0; i < aggregate.GetChildren().size(); ++i) { child_executor.ExecuteExpression(payload_idx + payload_cnt, payload_chunk.data[payload_idx + payload_cnt]); payload_cnt++; } - state.Sink(payload_chunk, payload_idx, aggr_idx); + state.Sink(payload_chunk, payload_idx, aggr_idx, chunk_count); } } @@ -225,7 +242,7 @@ void UngroupedAggregateExecuteState::Sink(LocalUngroupedAggregateState &state, D // Local State //===--------------------------------------------------------------------===// LocalUngroupedAggregateState::LocalUngroupedAggregateState(GlobalUngroupedAggregateState &gstate) - : allocator(gstate.CreateAllocator()), state(gstate.state.aggregate_expressions) { + : allocator(gstate.CreateAllocator()), state(gstate.state), repeated_state_vector(LogicalType::POINTER) { } class UngroupedAggregateLocalSinkState : public LocalSinkState { @@ -277,7 +294,7 @@ class UngroupedAggregateLocalSinkState : public LocalSinkState { bool PhysicalUngroupedAggregate::SinkOrderDependent() const { for (auto &expr : aggregates) { auto &aggr = expr->Cast(); - if (aggr.function.GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT) { + if (aggr.Function().GetOrderDependent() 
== AggregateOrderDependent::ORDER_DEPENDENT) { return true; } } @@ -321,14 +338,13 @@ void PhysicalUngroupedAggregate::SinkDistinct(ExecutionContext &context, DataChu auto &radix_local_sink = *sink.radix_states[table_idx]; OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, input.interrupt_state}; - if (aggregate.filter) { + if (aggregate.GetFilter()) { // The hashtable can apply a filter, but only on the payload // And in our case, we need to filter the groups (the distinct aggr children) // Apply the filter before inserting into the hashtable auto &filtered_data = sink.execute_state.filter_set.GetFilterData(idx); - idx_t count = filtered_data.ApplyFilter(chunk); - filtered_data.filtered_payload.SetCardinality(count); + filtered_data.ApplyFilter(chunk); radix_table.Sink(context, filtered_data.filtered_payload, sink_input, empty_chunk, distinct_filter); } else { @@ -352,17 +368,27 @@ SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataC return SinkResultType::NEED_MORE_INPUT; } -void LocalUngroupedAggregateState::Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx) { -#ifdef DEBUG - state.counts[aggr_idx] += payload_chunk.size(); -#endif - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - idx_t payload_cnt = aggregate.children.size(); +void LocalUngroupedAggregateState::Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx, idx_t count) { + state.counts[aggr_idx] += count; + auto &func = state.functions[aggr_idx]; + idx_t payload_cnt = state.argument_counts[aggr_idx]; D_ASSERT(payload_idx + payload_cnt <= payload_chunk.data.size()); auto start_of_input = payload_cnt == 0 ? 
nullptr : &payload_chunk.data[payload_idx]; - AggregateInputData aggr_input_data(state.bind_data[aggr_idx], allocator); - aggregate.function.GetStateSimpleUpdateCallback()(start_of_input, aggr_input_data, payload_cnt, - state.aggregate_data[aggr_idx].get(), payload_chunk.size()); + AggregateInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator); + auto cluster_update = func.GetStateClusterUpdateCallback(); + if (cluster_update) { + ClusteredAggr clustered; + clustered.SetSingleRun(state.aggregate_data[aggr_idx].get(), count); + aggr_input_data.clustered = &clustered; + cluster_update(start_of_input, aggr_input_data, payload_cnt, clustered, count); + } else { + auto state_ptr = CastPointerToValue(state.aggregate_data[aggr_idx].get()); + auto state_data = FlatVector::Writer(repeated_state_vector, count); + for (idx_t i = 0; i < count; i++) { + state_data.WriteValue(state_ptr); + } + func.GetStateUpdateCallback()(start_of_input, aggr_input_data, payload_cnt, repeated_state_vector, count); + } } //===--------------------------------------------------------------------===// @@ -485,7 +511,7 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() { // Forward the payload idx payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); + next_payload_idx = payload_idx + aggregate.GetChildren().size(); // If aggregate is not distinct, skip it if (!distinct_data.IsDistinct(agg_idx)) { @@ -565,7 +591,7 @@ TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() DataChunk payload_chunk; payload_chunk.InitializeEmpty(distinct_data.grouped_aggregate_data[table_idx]->group_types); - payload_chunk.SetCardinality(0); + payload_chunk.SetChildCardinality(0); while (true) { output_chunk.Reset(); @@ -580,14 +606,13 @@ TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() } // We dont need to resolve the filter, we already did this in Sink - idx_t payload_cnt = 
aggregate.children.size(); + idx_t payload_cnt = aggregate.GetChildren().size(); for (idx_t i = 0; i < payload_cnt; i++) { payload_chunk.data[i].Reference(output_chunk.data[i]); } - payload_chunk.SetCardinality(output_chunk); // Update the aggregate state - state.Sink(payload_chunk, 0, agg_idx); + state.Sink(payload_chunk, 0, agg_idx, output_chunk.size()); } blocked = false; } @@ -632,29 +657,34 @@ SinkFinalizeType PhysicalUngroupedAggregate::Finalize(Pipeline &pipeline, Event //===--------------------------------------------------------------------===// void VerifyNullHandling(DataChunk &chunk, UngroupedAggregateState &state, const vector> &aggregates) { -#ifdef DEBUG + if (DBConfigOptions::global_verification_mode != DebugVerificationMode::VERIFY_FUNCTIONS) { + return; + } for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { auto &aggr = aggregates[aggr_idx]->Cast(); if (state.counts[aggr_idx] == 0 && - aggr.function.GetProperties().GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { + aggr.Function().GetProperties().GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { // Default is when 0 values go in, NULL comes out UnifiedVectorFormat vdata; chunk.data[aggr_idx].ToUnifiedFormat(vdata); - D_ASSERT(!vdata.validity.RowIsValid(vdata.sel->get_index(0))); + if (vdata.validity.RowIsValid(vdata.sel->get_index(0))) { + throw InternalException( + "VerifyNullHandling failed for aggregate function \"%s\": no rows were aggregated but the result " + "is not NULL - aggregates with default NULL handling should return NULL when no rows are " + "aggregated", + aggr.Function().GetName()); + } } } -#endif } void GlobalUngroupedAggregateState::Finalize(DataChunk &result, idx_t column_offset) { - result.SetCardinality(1); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - + result.SetChildCardinality(1); + for (idx_t aggr_idx = 0; 
aggr_idx < state.functions.size(); aggr_idx++) { + auto &func = state.functions[aggr_idx]; Vector state_vector(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get())), count_t(1)); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); - aggregate.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, - result.data[column_offset + aggr_idx], 1, 0); + AggregateInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator); + func.GetStateFinalizeCallback()(state_vector, aggr_input_data, result.data[column_offset + aggr_idx], 1, 0); FlatVector::SetSize(result.data[column_offset + aggr_idx], count_t(1)); } } @@ -680,8 +710,8 @@ InsertionOrderPreservingMap PhysicalUngroupedAggregate::ParamsToString() aggregate_info += "\n"; } aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); + if (aggregate.GetFilter()) { + aggregate_info += " Filter: " + aggregate.GetFilter()->GetName(); } } result["Aggregates"] = aggregate_info; diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 16cb6823e..956a50d43 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/sorting/sort_strategy.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" #include "duckdb/common/types/row/tuple_data_iterator.hpp" +#include "duckdb/common/types/value_map.hpp" #include "duckdb/function/window/window_aggregate_function.hpp" #include "duckdb/function/window/window_executor.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" @@ -49,7 +50,8 @@ class WindowHashGroup { return ((n + (val - 1)) / val); } - WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p); + 
WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, idx_t group_idx, const Value &sink_idx, + idx_t bin_idx); void AllocateMasks(); void ComputeMasks(const idx_t begin_idx, const idx_t end_idx); @@ -136,7 +138,7 @@ class WindowHashGroup { task.stage = WindowGroupStage(next_task / group_threads); if (task.stage == group_stage) { task.thread_idx = next_task % group_threads; - task.group_idx = hash_bin; + task.group_idx = group_idx; task.begin_idx = task.thread_idx * per_thread; task.end_idx = MinValue(task.begin_idx + per_thread, ChunkCount()); ++next_task; @@ -156,7 +158,7 @@ class WindowHashGroup { idx_t blocks = 0; //! The partition boundary mask ValidityMask partition_mask; - //! The order boundary mask + //! The order boundary mask (maps sort column END to mask) OrderMasks order_masks; //! The fully materialised data collection unique_ptr collection; @@ -167,8 +169,12 @@ class WindowHashGroup { //! Executor local states, one per thread ThreadLocalStates thread_states; - //! The bin number - idx_t hash_bin; + //! The global group number (so we can find the task + const idx_t group_idx; + //! The sink index (so we can find the sink state) + const Value sink_idx; + //! The sink bin (so we can find the bin within the sink state) + const idx_t bin_idx; //! Single threading lock mutex lock; //! The the number of blocks per thread. @@ -195,72 +201,145 @@ class WindowHashGroup { class WindowGlobalSinkState : public GlobalSinkState { public: + using GlobalStatePtr = unique_ptr; using ExecutorPtr = unique_ptr; using Executors = vector; WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); - SinkFinalizeType Finalize(ClientContext &client, InterruptState &interrupt_state) { - OperatorSinkFinalizeInput finalize {*strategy_sink, interrupt_state}; - auto result = sort_strategy->Finalize(client, finalize); - - return result; - } - //! Parent operator const PhysicalWindow &op; //! Client context ClientContext &client; //! 
The partitioned sunk data unique_ptr sort_strategy; - //! The partitioned sunk data - unique_ptr strategy_sink; + //! The partitioned sunk data. With partitioning there may be more than one + mutex lock; + value_map_t strategy_sinks; //! The number of sunk rows (for progress) atomic count; //! The execution functions Executors executors; //! The shared expressions library WindowSharedExpressions shared; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + // The sort strategy, executors, and shared expression layout are iteration-invariant. Only the + // sort sink holds per-iteration materialized data, so replace that while preserving the setup. + lock_guard sinks_guard(lock); + strategy_sinks.clear(); + if (!op.partition_info.RequiresPartitionColumns()) { + strategy_sinks.insert(make_pair(Value(), sort_strategy->GetGlobalSinkState(context))); + } + count = 0; + GlobalSinkState::Reset(context); + } + + optional_ptr GetOrCreatePartition(ClientContext &client, const Value &partition) { + lock_guard l(lock); + // find the state that corresponds to this partition and combine + auto entry = strategy_sinks.find(partition); + if (entry != strategy_sinks.end()) { + return entry->second.get(); + } + // no state yet for this partition - allocate a new one + auto new_global_state = sort_strategy->GetGlobalSinkState(client); + auto result = new_global_state.get(); + strategy_sinks.insert(make_pair(partition, std::move(new_global_state))); + return result; + } }; // Per-thread sink state class WindowLocalSinkState : public LocalSinkState { public: + using GlobalStatePtr = optional_ptr; + using LocalStatePtr = unique_ptr; + WindowLocalSinkState(ExecutionContext &context, const WindowGlobalSinkState &gstate) : local_group(gstate.sort_strategy->GetLocalSinkState(context)) { } - unique_ptr local_group; + explicit WindowLocalSinkState(ExecutionContext &context) { + } + + LocalStatePtr local_group; + + // Partitioning state 
+ Value current_partition; + GlobalStatePtr partition_group; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + partition_group = nullptr; + if (local_group) { + local_group = gstate.sort_strategy->GetLocalSinkState(context); + } else { + local_group.reset(); + } + } + + SinkCombineResultType Combine(ExecutionContext &context, const WindowGlobalSinkState &gstate, + InterruptState &interrupt_state) { + if (!partition_group) { + return SinkCombineResultType::FINISHED; + } + + // flush the local state + OperatorSinkCombineInput hcombine {*partition_group, *local_group, interrupt_state}; + auto result = gstate.sort_strategy->Combine(context, hcombine); + + // Start a new state pair + partition_group = nullptr; + local_group.reset(); + + return result; + } }; // this implements a sorted window functions variant PhysicalWindow::PhysicalWindow(PhysicalPlan &physical_plan, vector types, vector> select_list_p, idx_t estimated_cardinality, - PhysicalOperatorType type) - : PhysicalOperator(physical_plan, type, std::move(types), estimated_cardinality), - select_list(std::move(select_list_p)), order_idx(0), is_order_dependent(false) { + vector partitions) + : PhysicalOperator(physical_plan, TYPE, std::move(types), estimated_cardinality), + select_list(std::move(select_list_p)), order_idx(0), is_order_dependent(false), + partition_info(std::move(partitions)) { idx_t max_orders = 0; for (idx_t i = 0; i < select_list.size(); ++i) { auto &expr = select_list[i]; D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); auto &bound_window = expr->Cast(); - if (bound_window.partitions.empty() && bound_window.orders.empty()) { + if (bound_window.Partitions().empty() && bound_window.OrderBy().empty()) { is_order_dependent = true; } - if (bound_window.orders.size() > max_orders) { + if (bound_window.OrderBy().size() > max_orders) { order_idx = i; - 
max_orders = bound_window.orders.size(); + max_orders = bound_window.OrderBy().size(); } } } +PhysicalWindow::PhysicalWindow(PhysicalPlan &physical_plan, vector types, + vector> select_list, idx_t estimated_cardinality) + : PhysicalWindow(physical_plan, std::move(types), std::move(select_list), estimated_cardinality, + vector()) { +} + static unique_ptr WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &client, WindowSharedExpressions &shared) { if (wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE) { return make_uniq(wexpr, client, shared); } else { - if (!wexpr.window.get()) { + if (!wexpr.WindowFunction().get()) { throw InternalException("Window expression type %s", ExpressionTypeToString(wexpr.GetExpressionType())); } return make_uniq(wexpr, shared); @@ -279,9 +358,42 @@ WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientCon executors.emplace_back(std::move(wexec)); } - sort_strategy = SortStrategy::Factory(client, wexpr.partitions, wexpr.orders, op.children[0].get().GetTypes(), - wexpr.partitions_stats, op.estimated_cardinality); - strategy_sink = sort_strategy->GetGlobalSinkState(client); + if (!op.partition_info.RequiresPartitionColumns()) { + sort_strategy = + SortStrategy::Factory(client, wexpr.Partitions(), wexpr.OrderBy(), op.children[0].get().GetTypes(), + wexpr.PartitionsStats(), op.estimated_cardinality); + GetOrCreatePartition(client, Value()); + } else { + // Pipeline does the partitioning for us, so leave them out + vector> unpartitioned; + vector> partitions_stats; + sort_strategy = SortStrategy::Factory(client, unpartitioned, wexpr.OrderBy(), op.children[0].get().GetTypes(), + partitions_stats, op.estimated_cardinality); + } +} + +OperatorPartitionInfo PhysicalWindow::RequiredPartitionInfo() const { + if (!partition_info.RequiresPartitionColumns()) { + return PhysicalOperator::RequiredPartitionInfo(); + } + + return partition_info; +} + 
+//===--------------------------------------------------------------------===// +// NextBatch +//===--------------------------------------------------------------------===// +SinkNextBatchType PhysicalWindow::NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &batch) const { + if (!partition_info.RequiresPartitionColumns()) { + return PhysicalOperator::NextBatch(context, batch); + } + + auto &gstate = batch.global_state.Cast(); + auto &lstate = batch.local_state.Cast(); + + (void)lstate.Combine(context, gstate, batch.interrupt_state); + + return SinkNextBatchType::READY; } //===--------------------------------------------------------------------===// @@ -292,19 +404,54 @@ SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, auto &lstate = sink.local_state.Cast(); gstate.count += chunk.size(); - OperatorSinkInput hsink {*gstate.strategy_sink, *lstate.local_group, sink.interrupt_state}; + if (partition_info.RequiresPartitionColumns()) { + if (!lstate.partition_group) { + // the local state is not yet initialized for this partition + // initialize the partition + child_list_t partition_values; + const auto &partition_columns = partition_info.partition_columns; + for (idx_t partition_idx = 0; partition_idx < partition_columns.size(); partition_idx++) { + auto column_name = to_string(partition_idx); + auto &partition = lstate.partition_info.partition_data[partition_idx]; + D_ASSERT(Value::NotDistinctFrom(partition.min_val, partition.max_val)); + partition_values.emplace_back(make_pair(std::move(column_name), partition.min_val)); + } + lstate.current_partition = Value::STRUCT(std::move(partition_values)); + + // initialize the state + lstate.partition_group = gstate.GetOrCreatePartition(context.client, lstate.current_partition); + lstate.local_group = gstate.sort_strategy->GetLocalSinkState(context); + } + } else if (!lstate.partition_group) { + lstate.partition_group = gstate.GetOrCreatePartition(context.client, 
lstate.current_partition); + } + + OperatorSinkInput hsink {*lstate.partition_group, *lstate.local_group, sink.interrupt_state}; return gstate.sort_strategy->Sink(context, chunk, hsink); } +//===--------------------------------------------------------------------===// +// Combine +//===--------------------------------------------------------------------===// SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { auto &gstate = combine.global_state.Cast(); auto &lstate = combine.local_state.Cast(); - OperatorSinkCombineInput hcombine {*gstate.strategy_sink, *lstate.local_group, combine.interrupt_state}; + if (partition_info.RequiresPartitionColumns()) { + return lstate.Combine(context, gstate, combine.interrupt_state); + } else if (!lstate.partition_group) { + lstate.partition_group = gstate.GetOrCreatePartition(context.client, lstate.current_partition); + } + + OperatorSinkCombineInput hcombine {*lstate.partition_group, *lstate.local_group, combine.interrupt_state}; return gstate.sort_strategy->Combine(context, hcombine); } unique_ptr PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const { + if (partition_info.RequiresPartitionColumns()) { + return make_uniq(context); + } + auto &gstate = sink_state->Cast(); return make_uniq(context, gstate); } @@ -317,16 +464,24 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie OperatorSinkFinalizeInput &input) const { auto &gsink = input.global_state.Cast(); auto &sort_strategy = *gsink.sort_strategy; - auto &strategy_sink = *gsink.strategy_sink; - - OperatorSinkFinalizeInput hfinalize {strategy_sink, input.interrupt_state}; - return sort_strategy.Finalize(client, hfinalize); + SinkFinalizeType result = SinkFinalizeType::READY; + lock_guard sinks_guard(gsink.lock); + for (auto &strategy_sink : gsink.strategy_sinks) { + OperatorSinkFinalizeInput hfinalize {*strategy_sink.second, input.interrupt_state}; + result = 
sort_strategy.Finalize(client, hfinalize); + } + return result; } ProgressData PhysicalWindow::GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, const ProgressData source_progress) const { auto &gsink = gstate.Cast(); - return gsink.sort_strategy->GetSinkProgress(context, *gsink.strategy_sink, source_progress); + auto progress = source_progress; + lock_guard sinks_guard(gsink.lock); + for (auto &strategy_sink : gsink.strategy_sinks) { + progress.Add(gsink.sort_strategy->GetSinkProgress(context, *strategy_sink.second, progress)); + } + return progress; } //===--------------------------------------------------------------------===// @@ -334,6 +489,7 @@ ProgressData PhysicalWindow::GetSinkProgress(ClientContext &context, GlobalSinkS //===--------------------------------------------------------------------===// class WindowGlobalSourceState : public GlobalSourceState { public: + using GlobalStatePtr = unique_ptr; using WindowHashGroupPtr = unique_ptr; using ScannerPtr = unique_ptr; using Task = WindowSourceTask; @@ -356,8 +512,8 @@ class WindowGlobalSourceState : public GlobalSourceState { ClientContext &client; //! All the sunk data WindowGlobalSinkState &gsink; - //! The hashed sort global source state for delayed sorting - unique_ptr hashed_source; + //! The hashed sort global source states for delayed sorting + value_map_t hashed_sources; //! The sorted hash groups vector window_hash_groups; //! 
The total number of blocks to process; @@ -396,21 +552,27 @@ class WindowGlobalSourceState : public GlobalSourceState { WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &client, WindowGlobalSinkState &gsink_p) : client(client), gsink(gsink_p), next_group(0), locals(0), started(0), finished(0), stopped(false), completed(0) { auto &sort_strategy = *gsink.sort_strategy; - hashed_source = sort_strategy.GetGlobalSourceState(client, *gsink.strategy_sink); - auto &hash_groups = sort_strategy.GetHashGroups(*hashed_source); - window_hash_groups.resize(hash_groups.size()); - for (idx_t group_idx = 0; group_idx < hash_groups.size(); ++group_idx) { - const auto block_count = hash_groups[group_idx].chunks; - if (!block_count) { - continue; - } + for (auto &strategy_sink : gsink.strategy_sinks) { + auto hashed_source = sort_strategy.GetGlobalSourceState(client, *strategy_sink.second); + auto &hash_groups = sort_strategy.GetHashGroups(*hashed_source); + + for (idx_t bin_idx = 0; bin_idx < hash_groups.size(); ++bin_idx) { + const auto block_count = hash_groups[bin_idx].chunks; + if (!block_count) { + continue; + } - auto window_hash_group = make_uniq(gsink, hash_groups[group_idx], group_idx); - window_hash_group->batch_base = total_blocks; - total_blocks += block_count; + const idx_t group_idx = window_hash_groups.size(); + auto window_hash_group = + make_uniq(gsink, hash_groups[bin_idx], group_idx, strategy_sink.first, bin_idx); + window_hash_group->batch_base = total_blocks; + total_blocks += block_count; + + window_hash_groups.emplace_back(std::move(window_hash_group)); + } - window_hash_groups[group_idx] = std::move(window_hash_group); + hashed_sources.insert(make_pair(strategy_sink.first, std::move(hashed_source))); } CreateTaskList(); @@ -454,9 +616,11 @@ void WindowGlobalSourceState::CreateTaskList() { } } -WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, const idx_t hash_bin_p) 
+WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, const ChunkRow &chunk_row, idx_t group_idx, + const Value &sink_idx, idx_t bin_idx) : gsink(gsink), count(chunk_row.count), blocks(chunk_row.chunks), stage(WindowGroupStage::SORT), - hash_bin(hash_bin_p), sorted(0), materialized(0), masked(0), sunk(0), finalized(0), completed(0), batch_base(0) { + group_idx(group_idx), sink_idx(sink_idx), bin_idx(bin_idx), sorted(0), materialized(0), masked(0), sunk(0), + finalized(0), completed(0), batch_base(0) { // There are three types of partitions: // 1. No partition (no sorting) // 2. One partition (sorting, but no hashing) @@ -503,7 +667,9 @@ void WindowHashGroup::AllocateMasks() { const auto &executors = gsink.executors; for (auto &wexec : executors) { auto &wexpr = wexec->wexpr; - auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; + + const auto order_begin = gsink.op.partition_info.RequiresPartitionColumns() ? 0 : wexpr.Partitions().size(); + auto &order_mask = order_masks[order_begin + wexpr.OrderBy().size()]; if (order_mask.IsMaskSet()) { continue; } @@ -538,13 +704,12 @@ void WindowHashGroup::ComputeMasks(const idx_t block_begin, const idx_t block_en // If we are not sorting, then only the partition boundaries are needed. const auto &wexpr = gsink.op.select_list[gsink.op.order_idx]->Cast(); - auto &partitions = wexpr.partitions; - if (partitions.empty() && wexpr.orders.empty()) { - return; - } // Set up the partition compare structs - const auto key_count = partitions.size(); + const auto order_begin = gsink.op.partition_info.RequiresPartitionColumns() ? 
0 : wexpr.Partitions().size(); + if (!order_begin && wexpr.OrderBy().empty()) { + return; + } // Set up the order data structures auto &collection = *rows; @@ -552,15 +717,15 @@ void WindowHashGroup::ComputeMasks(const idx_t block_begin, const idx_t block_en WindowCollectionChunkScanner scanner(collection, scan_cols, block_begin); unordered_map prefixes; for (auto &order_mask : order_masks) { - D_ASSERT(order_mask.first >= key_count); - auto order_type = scanner.PrefixStructType(order_mask.first, key_count); + D_ASSERT(order_mask.first >= order_begin); + auto order_type = scanner.PrefixStructType(order_mask.first, order_begin); vector types(2, order_type); auto &keys = prefixes[order_mask.first]; // We can't use InitializeEmpty here because it doesn't set up all of the STRUCT internals... keys.Initialize(collection.GetAllocator(), types); } - WindowDeltaScanner(collection, block_begin, block_end, scan_cols, key_count, + WindowDeltaScanner(collection, block_begin, block_end, scan_cols, order_begin, [&](const idx_t row_idx, DataChunk &prev, DataChunk &curr, const idx_t ndistinct, SelectionVector &distinct, const SelectionVector &matching) { // Process the partition boundaries @@ -582,19 +747,19 @@ void WindowHashGroup::ComputeMasks(const idx_t block_begin, const idx_t block_en for (auto &order_mask : order_masks) { // If there are no order columns, then all the partition elements are peers and we are // done - if (partitions.size() == order_mask.first) { + if (order_begin == order_mask.first) { continue; } auto &prefix = prefixes[order_mask.first]; prefix.Reset(); auto &order_prev = prefix.data[0]; auto &order_curr = prefix.data[1]; - scanner.ReferenceStructColumns(prev, order_prev, order_mask.first, partitions.size()); - scanner.ReferenceStructColumns(curr, order_curr, order_mask.first, partitions.size()); + scanner.ReferenceStructColumns(prev, order_prev, order_mask.first, order_begin); + scanner.ReferenceStructColumns(curr, order_curr, order_mask.first, 
order_begin); if (ndistinct) { prefix.Slice(matching, nmatch); } else { - prefix.SetCardinality(nmatch); + prefix.SetChildCardinality(nmatch); } const auto m = VectorOperations::DistinctFrom(order_curr, order_prev, nullptr, nmatch, &distinct, nullptr); @@ -685,8 +850,8 @@ void WindowLocalSourceState::Sort(ExecutionContext &context, InterruptState &int auto &gsink = gsource.gsink; auto &sort_strategy = *gsink.sort_strategy; - OperatorSinkFinalizeInput finalize {*gsink.strategy_sink, interrupt}; - sort_strategy.SortColumnData(context, task_local.group_idx, finalize); + OperatorSinkFinalizeInput finalize {*gsink.strategy_sinks[window_hash_group->sink_idx], interrupt}; + sort_strategy.SortColumnData(context, window_hash_group->bin_idx, finalize); // Mark this range as done window_hash_group->sorted += (task->end_idx - task->begin_idx); @@ -698,10 +863,11 @@ void WindowLocalSourceState::Materialize(ExecutionContext &context, InterruptSta D_ASSERT(task->stage == WindowGroupStage::MATERIALIZE); auto unused = make_uniq(); - OperatorSourceInput source {*gsource.hashed_source, *unused, interrupt}; + auto &hashed_source = gsource.hashed_sources[window_hash_group->sink_idx]; + OperatorSourceInput source {*hashed_source, *unused, interrupt}; auto &gsink = gsource.gsink; auto &sort_strategy = *gsink.sort_strategy; - sort_strategy.MaterializeColumnData(context, task_local.group_idx, source); + sort_strategy.MaterializeColumnData(context, window_hash_group->bin_idx, source); // Mark this range as done window_hash_group->materialized += (task->end_idx - task->begin_idx); @@ -712,7 +878,7 @@ void WindowLocalSourceState::Materialize(ExecutionContext &context, InterruptSta if (window_hash_group->materialized >= window_hash_group->blocks) { lock_guard prepare_guard(window_hash_group->lock); if (!window_hash_group->rows) { - window_hash_group->rows = sort_strategy.GetColumnData(task_local.group_idx, source); + window_hash_group->rows = 
sort_strategy.GetColumnData(window_hash_group->bin_idx, source); } } } @@ -739,7 +905,8 @@ WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::GetGlobalStates(ClientCo // These can be large so we defer building them until we are ready. for (auto &wexec : executors) { auto &wexpr = wexec->wexpr; - auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; + const auto order_begin = gsink.op.partition_info.RequiresPartitionColumns() ? 0 : wexpr.Partitions().size(); + auto &order_mask = order_masks[order_begin + wexpr.OrderBy().size()]; gestates.emplace_back(wexec->GetGlobalState(client, count, partition_mask, order_mask)); } @@ -784,7 +951,7 @@ void WindowLocalSourceState::Sink(ExecutionContext &context, InterruptState &int // Compute fully materialised expressions if (coll_chunk.data.empty()) { - coll_chunk.SetCardinality(input_chunk); + coll_chunk.SetChildCardinality(input_chunk.size()); } else { coll_chunk.Reset(); coll_exec.Execute(input_chunk, coll_chunk); @@ -798,7 +965,7 @@ void WindowLocalSourceState::Sink(ExecutionContext &context, InterruptState &int // Compute sink expressions if (sink_chunk.data.empty()) { - sink_chunk.SetCardinality(input_chunk); + sink_chunk.SetChildCardinality(input_chunk.size()); } else { sink_chunk.Reset(); sink_exec.Execute(input_chunk, sink_chunk); @@ -994,19 +1161,20 @@ void WindowLocalSourceState::GetData(ExecutionContext &context, DataChunk &resul auto &executor = *executors[expr_idx]; auto &result = output_chunk.data[expr_idx]; if (eval_chunk.data.empty()) { - eval_chunk.SetCardinality(input_chunk); + eval_chunk.SetChildCardinality(input_chunk.size()); } else { eval_chunk.Reset(); eval_exec.Execute(input_chunk, eval_chunk); } OperatorSinkInput sink {*gestates[expr_idx], *local_states[expr_idx], interrupt}; - executor.Evaluate(context, position, eval_chunk, result, sink); + executor.Evaluate(context, position, eval_chunk, result, sink, input_chunk.size()); } - output_chunk.SetCardinality(input_chunk); 
output_chunk.Verify(context.client.db); idx_t out_idx = 0; - result.SetCardinality(input_chunk); + // input_chunk's cardinality is authoritative (its column buffers may still be a larger, stale capacity), so set the + // result cardinality explicitly here - before the references below alias input_chunk's/output_chunk's buffers. + result.SetChildCardinality(input_chunk.size()); for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); col_idx++) { result.data[out_idx++].Reference(input_chunk.data[col_idx]); } @@ -1038,17 +1206,17 @@ bool PhysicalWindow::SupportsPartitioning(const OperatorPartitionInfo &partition // We can only preserve order for single partitioning // or work stealing causes out of order batch numbers auto &wexpr = select_list[order_idx]->Cast(); - return wexpr.partitions.empty(); // NOLINT + return wexpr.Partitions().empty(); // NOLINT } OrderPreservationType PhysicalWindow::SourceOrder() const { auto &wexpr = select_list[order_idx]->Cast(); - if (!wexpr.partitions.empty()) { + if (!wexpr.Partitions().empty()) { // if we have partitions the window order is not defined return OrderPreservationType::NO_ORDER; } // without partitions we can maintain order - if (wexpr.orders.empty()) { + if (wexpr.OrderBy().empty()) { // if we have no orders we maintain insertion order return OrderPreservationType::INSERTION_ORDER; } diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp index 145d5b554..676be10a0 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp @@ -1,5 +1,7 @@ #include "duckdb/execution/operator/csv_scanner/csv_schema.hpp" +#include "duckdb/common/array.hpp" + namespace duckdb { struct TypeIdxPair { @@ -25,17 +27,26 @@ bool CSVSchema::CanWeCastIt(LogicalTypeId source, LogicalTypeId destination) { case LogicalTypeId::TINYINT: return destination == 
LogicalTypeId::SMALLINT || destination == LogicalTypeId::INTEGER || destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || + destination == LogicalTypeId::HUGEINT || destination == LogicalTypeId::BIGNUM || destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; case LogicalTypeId::SMALLINT: return destination == LogicalTypeId::INTEGER || destination == LogicalTypeId::BIGINT || - destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || + destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::HUGEINT || + destination == LogicalTypeId::BIGNUM || destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; case LogicalTypeId::INTEGER: return destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || + destination == LogicalTypeId::HUGEINT || destination == LogicalTypeId::BIGNUM || destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; case LogicalTypeId::BIGINT: - return destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || + return destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::HUGEINT || + destination == LogicalTypeId::BIGNUM || destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UHUGEINT: + return destination == LogicalTypeId::BIGNUM || destination == LogicalTypeId::DOUBLE; + case LogicalTypeId::BIGNUM: + return destination == LogicalTypeId::DOUBLE; case LogicalTypeId::FLOAT: return destination == LogicalTypeId::DOUBLE; default: @@ -45,8 +56,9 @@ bool CSVSchema::CanWeCastIt(LogicalTypeId source, LogicalTypeId destination) { void CSVSchema::MergeSchemas(CSVSchema &other, bool null_padding) { // TODO: We could also merge names, maybe by giving preference to non-generated names? 
- const vector candidates_by_specificity = {LogicalType::BOOLEAN, LogicalType::BIGINT, - LogicalType::DOUBLE, LogicalType::VARCHAR}; + const array candidates_by_specificity = {LogicalType::BOOLEAN, LogicalType::BIGINT, + LogicalType::HUGEINT, LogicalType::BIGNUM, + LogicalType::DOUBLE, LogicalType::VARCHAR}; for (idx_t i = 0; i < columns.size() && i < other.columns.size(); i++) { auto this_type = columns[i].type.id(); auto other_type = other.columns[i].type.id(); @@ -76,13 +88,14 @@ void CSVSchema::MergeSchemas(CSVSchema &other, bool null_padding) { } } -CSVSchema::CSVSchema(const vector &names, const vector &types, const string &file_path, +CSVSchema::CSVSchema(const vector &names, const vector &types, const string &file_path, idx_t rows_read_p, const bool empty_p) : rows_read(rows_read_p), empty(empty_p) { Initialize(names, types, file_path); } -void CSVSchema::Initialize(const vector &names, const vector &types, const string &file_path_p) { +void CSVSchema::Initialize(const vector &names, const vector &types, + const string &file_path_p) { if (!columns.empty()) { throw InternalException("CSV Schema is already populated, this should not happen."); } @@ -100,7 +113,7 @@ void CSVSchema::Initialize(const vector &names, const vector CSVSchema::GetNames() const { vector names; for (auto &column : columns) { - names.push_back(column.name); + names.push_back(column.name.GetIdentifierName()); } return names; } @@ -149,7 +162,7 @@ bool CSVSchema::SchemasMatch(string &error_message, SnifferResult &sniffer_resul for (idx_t i = 0; i < sniffer_result.names.size(); i++) { // Populate our little schema - current_schema[sniffer_result.names[i]] = {sniffer_result.return_types[i], i}; + current_schema[sniffer_result.names[i].GetIdentifierName()] = {sniffer_result.return_types[i], i}; } if (is_minimal_sniffer) { auto min_sniffer = static_cast(sniffer_result); @@ -157,7 +170,7 @@ bool CSVSchema::SchemasMatch(string &error_message, SnifferResult &sniffer_resul bool min_sniff_match = 
true; // If we don't have more than one row, either the names must match or the types must match. for (auto &column : columns) { - if (current_schema.find(column.name) == current_schema.end()) { + if (current_schema.find(column.name.GetIdentifierName()) == current_schema.end()) { min_sniff_match = false; break; } @@ -182,7 +195,7 @@ bool CSVSchema::SchemasMatch(string &error_message, SnifferResult &sniffer_resul // If we got here, we have the right types but the wrong names, lets fix the names idx_t sniff_name_idx = 0; for (auto &column : columns) { - sniffer_result.names[sniff_name_idx++] = column.name; + sniffer_result.names[sniff_name_idx++] = Identifier(column.name); } return true; } @@ -200,15 +213,15 @@ bool CSVSchema::SchemasMatch(string &error_message, SnifferResult &sniffer_resul error << "Current file: " << cur_file_path << "\n"; for (auto &column : columns) { - if (current_schema.find(column.name) == current_schema.end()) { - error << "Column with name: \"" << column.name << "\" is missing" + if (current_schema.find(column.name.GetIdentifierName()) == current_schema.end()) { + error << "Column with name: \"" << column.name.GetIdentifierName() << "\" is missing" << "\n"; match = false; } else { - if (!CanWeCastIt(current_schema[column.name].type.id(), column.type.id())) { - error << "Column with name: \"" << column.name + if (!CanWeCastIt(current_schema[column.name.GetIdentifierName()].type.id(), column.type.id())) { + error << "Column with name: \"" << column.name.GetIdentifierName() << "\" is expected to have type: " << column.type.ToString(); - error << " But has type: " << current_schema[column.name].type.ToString() << "\n"; + error << " But has type: " << current_schema[column.name.GetIdentifierName()].type.ToString() << "\n"; match = false; } } diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp index fdb70cc72..903a23273 100644 
--- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp @@ -1,8 +1,11 @@ #include "duckdb/execution/operator/csv_scanner/string_value_scanner.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/operator/decimal_cast_operators.hpp" #include "duckdb/common/operator/double_cast_operator.hpp" #include "duckdb/common/operator/integer_cast_operator.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/execution/operator/csv_scanner/csv_casting.hpp" @@ -21,17 +24,17 @@ constexpr idx_t StringValueScanner::LINE_FINDER_ID; StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_machine, const shared_ptr &buffer_handle, Allocator &buffer_allocator, idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_handler_p, - CSVIterator &iterator_p, bool store_line_size_p, - shared_ptr csv_file_scan_p, idx_t &lines_read_p, bool sniffing_p, - const string &path_p, idx_t scan_id, bool &used_unstrictness) + CSVIterator &iterator_p, shared_ptr csv_file_scan_p, + idx_t &lines_read_p, bool sniffing_p, const string &path_p, idx_t scan_id, + bool &used_unstrictness) : ScannerResult(states, state_machine, result_size_p), number_of_columns(NumericCast(state_machine.dialect_options.num_cols)), null_padding(state_machine.options.null_padding), ignore_errors(state_machine.options.ignore_errors.GetValue()), extra_delimiter_bytes(state_machine.dialect_options.state_machine_options.delimiter.GetValue().empty() ? 
0 : state_machine.dialect_options.state_machine_options.delimiter.GetValue().size() - 1), - error_handler(error_handler_p), iterator(iterator_p), store_line_size(store_line_size_p), - csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), used_unstrictness(used_unstrictness), + error_handler(error_handler_p), iterator(iterator_p), csv_file_scan(std::move(csv_file_scan_p)), + lines_read(lines_read_p), used_unstrictness(used_unstrictness), current_errors(scan_id, state_machine.options.IgnoreErrors()), sniffing(sniffing_p), path(path_p) { // Vector information D_ASSERT(number_of_columns > 0); @@ -300,7 +303,9 @@ void StringValueResult::AddValueToVector(const char *value_ptr, idx_t size, bool } bool success = true; string strip_thousands; - if (LogicalType::IsNumeric(parse_types[chunk_col_id].type_id) && + // Whether current type is logically numeric (BIGNUM is physically stored as VARCHAR instead of numerical value). + if ((LogicalType::IsNumeric(parse_types[chunk_col_id].type_id) || + parse_types[chunk_col_id].type_id == LogicalTypeId::BIGNUM) && state_machine.options.thousands_separator != '\0') { // If we have a thousands separator we should try to use that strip_thousands = BaseScanner::RemoveSeparator(value_ptr, size, state_machine.options.thousands_separator); @@ -328,6 +333,12 @@ void StringValueResult::AddValueToVector(const char *value_ptr, idx_t size, bool success = TrySimpleIntegerCast(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); break; + case LogicalTypeId::HUGEINT: { + auto &result_value = static_cast(vector_ptr[chunk_col_id])[number_of_rows]; + success = TryCast::Operation(string_t(value_ptr, NumericCast(size)), + result_value, false); + break; + } case LogicalTypeId::UTINYINT: success = TrySimpleIntegerCast( value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); @@ -344,6 +355,29 @@ void StringValueResult::AddValueToVector(const char *value_ptr, idx_t size, bool success = 
TrySimpleIntegerCast( value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); break; + case LogicalTypeId::UHUGEINT: { + auto &result_value = static_cast(vector_ptr[chunk_col_id])[number_of_rows]; + success = TryCast::Operation(string_t(value_ptr, NumericCast(size)), + result_value, false); + break; + } + case LogicalTypeId::BIGNUM: { + try { + while (size > 0 && StringUtil::CharacterIsSpace(*value_ptr)) { + value_ptr++; + size--; + } + while (size > 0 && StringUtil::CharacterIsSpace(value_ptr[size - 1])) { + size--; + } + auto bignum = Bignum::VarcharToBignum(string_t(value_ptr, NumericCast(size))); + static_cast(vector_ptr[chunk_col_id])[number_of_rows] = + StringVector::AddStringOrBlob(parse_chunk.data[chunk_col_id], bignum.data(), bignum.size()); + } catch (const ConversionException &) { + success = false; + } + break; + } case LogicalTypeId::DOUBLE: success = TryDoubleCast(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], @@ -494,7 +528,10 @@ DataChunk &StringValueResult::ToChunk() { throw InternalException("CSVScanner: ToChunk() function. 
Has a negative number of rows, this indicates an " "issue with the error handler."); } - parse_chunk.SetCardinality(static_cast(number_of_rows)); + const idx_t n = static_cast(number_of_rows); + if (parse_chunk.size() != n) { + parse_chunk.SetChildCardinality(n); + } return parse_chunk; } @@ -821,9 +858,6 @@ void StringValueResult::NullPaddingQuotedNewlineCheck() const { bool StringValueResult::AddRowInternal() { LinePosition current_line_start = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, buffer_size}; idx_t current_line_size = current_line_start - current_line_position.end; - if (store_line_size) { - error_handler.NewMaxLineSize(current_line_size); - } current_line_position.begin = current_line_position.end; current_line_position.end = current_line_start; if (current_line_size > state_machine.options.maximum_line_size.GetValue()) { @@ -1004,8 +1038,7 @@ StringValueScanner::StringValueScanner(idx_t scanner_idx_p, const shared_ptrcontext), result_size, - iterator.pos.buffer_pos, *error_handler, iterator, - buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, + iterator.pos.buffer_pos, *error_handler, iterator, csv_file_scan, lines_read, sniffing, buffer_manager->GetFilePath(), scanner_idx_p, used_unstrictness), start_pos(0) { if (scanner_idx == 0 && csv_file_scan) { @@ -1021,8 +1054,7 @@ StringValueScanner::StringValueScanner(const shared_ptr &buffe const CSVIterator &boundary) : BaseScanner(buffer_manager, state_machine, error_handler, false, nullptr, boundary), scanner_idx(0), result(states, *state_machine, cur_buffer_handle, Allocator::DefaultAllocator(), result_size, - iterator.pos.buffer_pos, *error_handler, iterator, - buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, + iterator.pos.buffer_pos, *error_handler, iterator, csv_file_scan, lines_read, sniffing, buffer_manager->GetFilePath(), 0, used_unstrictness), start_pos(0) { if (scanner_idx == 0 && 
csv_file_scan) { @@ -1074,7 +1106,7 @@ void StringValueScanner::Flush(DataChunk &insert_chunk) { return; } // convert the columns in the parsed chunk to the types of the table - insert_chunk.SetCardinality(parse_chunk); + insert_chunk.SetChildCardinality(parse_chunk.size()); // We keep track of the borked lines, in case we are ignoring errors D_ASSERT(csv_file_scan); @@ -1766,10 +1798,13 @@ bool StringValueScanner::CanDirectlyCast(const LogicalType &type, bool icu_loade case LogicalTypeId::SMALLINT: case LogicalTypeId::INTEGER: case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: case LogicalTypeId::UTINYINT: case LogicalTypeId::USMALLINT: case LogicalTypeId::UINTEGER: case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::BIGNUM: case LogicalTypeId::DOUBLE: case LogicalTypeId::FLOAT: case LogicalTypeId::DATE: @@ -1837,9 +1872,6 @@ ValidRowInfo StringValueScanner::TryRow(CSVState state, idx_t start_pos, idx_t e void StringValueScanner::SetStart() { start_pos = iterator.GetGlobalCurrentPos(); if (iterator.first_one) { - if (result.store_line_size) { - result.error_handler.NewMaxLineSize(iterator.pos.buffer_pos); - } return; } if (iterator.GetEndPos() > cur_buffer_handle->actual_size) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp index d523f2b53..a7b0dc082 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp @@ -80,7 +80,7 @@ void CSVSniffer::SetResultOptions() const { AdaptiveSnifferResult CSVSniffer::MinimalSniff() { if (set_columns.IsSet()) { // Nothing to see here - return AdaptiveSnifferResult(*set_columns.types, *set_columns.names, true); + return AdaptiveSnifferResult(*set_columns.types, StringsToIdentifiers(*set_columns.names), true); } // Return Types detected vector return_types; @@ -140,7 +140,7 @@ 
AdaptiveSnifferResult CSVSniffer::MinimalSniff() { } detected_types.push_back(d_type); } - return {detected_types, names, sniffed_column_counts.result_position > 1}; + return {detected_types, StringsToIdentifiers(names), sniffed_column_counts.result_position > 1}; } SnifferResult CSVSniffer::AdaptiveSniff(const CSVSchema &file_schema) { @@ -253,9 +253,9 @@ SnifferResult CSVSniffer::SniffCSV(const bool force_match) { } options.was_type_manually_set = manually_set; if (set_columns.IsSet()) { - return SnifferResult(*set_columns.types, *set_columns.names); + return SnifferResult(*set_columns.types, StringsToIdentifiers(*set_columns.names)); } - return SnifferResult(detected_types, names); + return SnifferResult(detected_types, StringsToIdentifiers(names)); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp index e9f8c0ebd..5a3b53d7e 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp @@ -1,8 +1,10 @@ #include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/operator/decimal_cast_operators.hpp" #include "duckdb/common/operator/double_cast_operator.hpp" #include "duckdb/common/operator/integer_cast_operator.hpp" #include "duckdb/common/string.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" @@ -103,6 +105,45 @@ bool CSVSniffer::EmptyOrOnlyHeader() const { return (single_row_file && best_candidate->state_machine->dialect_options.header.GetValue()) || lines_sniffed == 0; } +bool CSVSniffer::CanYouCastBignum(const char *value_ptr, idx_t value_size) { + while (value_size > 0 && StringUtil::CharacterIsSpace(*value_ptr)) { + value_ptr++; + value_size--; + } + 
if (value_size == 0 || *value_ptr == '+') { + return false; + } + auto trimmed_size = value_size; + while (trimmed_size > 0 && StringUtil::CharacterIsSpace(value_ptr[trimmed_size - 1])) { + trimmed_size--; + } + if (trimmed_size == 0) { + return false; + } + idx_t digit_pos = value_ptr[0] == '-' ? 1 : 0; + if (digit_pos == trimmed_size) { + return false; + } + if (!StringUtil::CharacterIsDigit(value_ptr[digit_pos])) { + return false; + } + if (digit_pos + 1 < trimmed_size && value_ptr[digit_pos] == '0' && + StringUtil::CharacterIsDigit(value_ptr[digit_pos + 1])) { + return false; + } + for (idx_t pos = digit_pos + 1; pos < trimmed_size; pos++) { + if (!StringUtil::CharacterIsDigit(value_ptr[pos])) { + return false; + } + } + try { + Bignum::VarcharToBignum(string_t(value_ptr, NumericCast(trimmed_size))); + return true; + } catch (const ConversionException &) { + return false; + } +} + bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, const LogicalType &type, const DialectOptions &dialect_options, const bool is_null, const char decimal_separator, const char thousands_separator) { @@ -112,7 +153,7 @@ bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, cons auto value_ptr = value.GetData(); auto value_size = value.GetSize(); string strip_thousands; - if (type.IsNumeric() && thousands_separator != '\0') { + if ((type.IsNumeric() || type.id() == LogicalTypeId::BIGNUM) && thousands_separator != '\0') { // If we have a thousands separator we should try to use that strip_thousands = BaseScanner::RemoveSeparator(value_ptr, value_size, thousands_separator); value_ptr = strip_thousands.c_str(); @@ -139,6 +180,11 @@ bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, cons int64_t dummy_value; return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); } + case LogicalTypeId::HUGEINT: { + hugeint_t dummy_value; + return TryCast::Operation(string_t(value_ptr, NumericCast(value_size)), + 
dummy_value, true); + } case LogicalTypeId::UTINYINT: { uint8_t dummy_value; return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); @@ -155,6 +201,13 @@ bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, cons uint64_t dummy_value; return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); } + case LogicalTypeId::UHUGEINT: { + uhugeint_t dummy_value; + return TryCast::Operation(string_t(value_ptr, NumericCast(value_size)), + dummy_value, true); + } + case LogicalTypeId::BIGNUM: + return CanYouCastBignum(value_ptr, value_size); case LogicalTypeId::DOUBLE: { double dummy_value; return TryDoubleCast(value_ptr, value_size, dummy_value, true, decimal_separator); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp index 5ab92c32b..0d6d8282f 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp @@ -38,6 +38,27 @@ bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const Logica return CSVCast::TryCastDecimalVectorCommaSeparated(options, parse_chunk_col, dummy_result, size, parameters, sql_type, line_error); } + if (sql_type.id() == LogicalTypeId::BIGNUM) { + auto vector_data = FlatVector::GetData(parse_chunk_col); + auto &validity = FlatVector::ValidityMutable(parse_chunk_col); + string strip_thousands; + for (idx_t row_idx = 0; row_idx < size; row_idx++) { + if (validity.RowIsValid(row_idx)) { + auto value_ptr = vector_data[row_idx].GetData(); + auto value_size = vector_data[row_idx].GetSize(); + if (sniffing_state_machine.options.thousands_separator != '\0') { + strip_thousands = BaseScanner::RemoveSeparator(value_ptr, value_size, + sniffing_state_machine.options.thousands_separator); + value_ptr = strip_thousands.c_str(); + value_size = strip_thousands.size(); + } + if 
(!CSVSniffer::CanYouCastBignum(value_ptr, value_size)) { + return false; + } + } + } + return true; + } // target type is not varchar: perform a cast string error_message; return VectorOperations::DefaultTryCast(parse_chunk_col, dummy_result, size, &error_message, true); diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp index bf6f85888..6dfca9190 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp @@ -13,7 +13,7 @@ void CSVSniffer::ReplaceTypes() { // types supplied as name -> value map idx_t found = 0; for (idx_t i = 0; i < names.size(); i++) { - auto it = sniffing_state_machine.options.sql_types_per_column.find(names[i]); + auto it = sniffing_state_machine.options.sql_types_per_column.find(Identifier(names[i])); if (it != sniffing_state_machine.options.sql_types_per_column.end()) { best_sql_types_candidates_per_column_idx[i] = { sniffing_state_machine.options.sql_type_list[it->second]}; @@ -23,7 +23,7 @@ void CSVSniffer::ReplaceTypes() { } } if (!file_options.union_by_name && found < sniffing_state_machine.options.sql_types_per_column.size()) { - auto error_msg = CSVError::ColumnTypesError(options.sql_types_per_column, names); + auto error_msg = CSVError::ColumnTypesError(options.sql_types_per_column, StringsToIdentifiers(names)); error_handler->Error(error_msg); } return; diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp index e293c7337..bc4a873a0 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp @@ -7,7 +7,7 @@ namespace duckdb { CSVFileScan::CSVFileScan(ClientContext 
&context, const OpenFileInfo &file_p, CSVReaderOptions options_p, - const MultiFileOptions &file_options, const vector &names, + const MultiFileOptions &file_options, const vector &names, const vector &types, CSVSchema &file_schema, bool per_file_single_threaded, shared_ptr buffer_manager_p, bool fixed_schema) : BaseFileReader(file_p), buffer_manager(std::move(buffer_manager_p)), @@ -94,10 +94,10 @@ void CSVFileScan::SetStart() { start_iterator = skip_scanner.GetIterator(); } -void CSVFileScan::SetNamesAndTypes(const vector &names_p, const vector &types_p) { - names = names_p; +void CSVFileScan::SetNamesAndTypes(const vector &names_p, const vector &types_p) { + names = IdentifiersToStrings(names_p); types = types_p; - columns = MultiFileColumnDefinition::ColumnsFromNamesAndTypes(names, types); + columns = MultiFileColumnDefinition::ColumnsFromNamesAndTypes(names_p, types); } void CSVFileScan::InitializeFileNamesTypes() { diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp index ada129dcb..01a63e9a3 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_multi_file_info.cpp @@ -41,7 +41,7 @@ void CSVMultiFileInfo::FinalizeCopyBind(ClientContext &context, BaseFileReaderOp csv_options.sql_type_list = expected_types; csv_options.columns_set = true; for (idx_t i = 0; i < expected_types.size(); i++) { - csv_options.sql_types_per_column[expected_names[i]] = i; + csv_options.sql_types_per_column[Identifier(expected_names[i])] = i; } } @@ -60,7 +60,7 @@ unique_ptr CSVMultiFileInfo::InitializeBindData(MultiFileBind //! 
Function to do schema discovery over one CSV file or a list/glob of CSV files CSVSchema CSVSchemaDiscovery::SchemaDiscovery(ClientContext &context, shared_ptr &buffer_manager, CSVReaderOptions &options, const MultiFileOptions &file_options, - vector &return_types, vector &names, + vector &return_types, vector &names, MultiFileList &multi_file_list) { vector schemas; const auto option_og = options; @@ -140,7 +140,7 @@ CSVSchema CSVSchemaDiscovery::SchemaDiscovery(ClientContext &context, shared_ptr best_schema.ReplaceNullWithVarchar(); if (names.empty()) { - names = best_schema.GetNames(); + names = StringsToIdentifiers(best_schema.GetNames()); return_types = best_schema.GetTypes(); } if (only_header_or_empty_files == current_file && !options.columns_set) { @@ -159,7 +159,7 @@ CSVSchema CSVSchemaDiscovery::SchemaDiscovery(ClientContext &context, shared_ptr return best_schema; } -void CSVMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, +void CSVMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) { auto &csv_data = bind_data.bind_data->Cast(); auto &multi_file_list = *bind_data.file_list; @@ -176,7 +176,7 @@ void CSVMultiFileInfo::BindReader(ClientContext &context, vector &r "read_csv_auto or set read_csv(..., " "AUTO_DETECT=TRUE) to automatically guess columns."); } - names = options.name_list; + names = StringsToIdentifiers(options.name_list); return_types = options.sql_type_list; } if (return_types.size() != names.size()) { @@ -219,19 +219,19 @@ void CSVMultiFileInfo::FinalizeBindData(MultiFileBindData &multi_file_data) { auto &options = csv_data.options; if (!options.force_not_null_names.empty()) { // Let's first check all column names match - duckdb::unordered_set column_names; + identifier_set_t column_names; for (auto &name : names) { column_names.insert(name); } for (auto &force_name : options.force_not_null_names) { - if (column_names.find(force_name) == 
column_names.end()) { + if (column_names.find(Identifier(force_name)) == column_names.end()) { throw BinderException("\"force_not_null\" expected to find %s, but it was not found in the table", force_name); } } D_ASSERT(options.force_not_null.empty()); for (idx_t i = 0; i < names.size(); i++) { - if (options.force_not_null_names.find(names[i]) != options.force_not_null_names.end()) { + if (options.force_not_null_names.find(names[i].GetIdentifierName()) != options.force_not_null_names.end()) { options.force_not_null.push_back(true); } else { options.force_not_null.push_back(false); @@ -301,8 +301,8 @@ shared_ptr CSVMultiFileInfo::CreateReader(ClientContext &context options.auto_detect = false; D_ASSERT(csv_data.csv_schema.Empty()); return make_shared_ptr(context, union_data.GetFileName(), std::move(options), bind_data.file_options, - csv_names, csv_types, csv_data.csv_schema, gstate.SingleThreadedRead(), nullptr, - false); + StringsToIdentifiers(csv_names), csv_types, csv_data.csv_schema, + gstate.SingleThreadedRead(), nullptr, false); } shared_ptr CSVMultiFileInfo::CreateReader(ClientContext &context, GlobalTableFunctionState &gstate_p, diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp index 990fb8828..d4b40722d 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp @@ -48,10 +48,6 @@ unique_ptr CSVGlobalState::Next(shared_ptr &cur // initialize the boundary for this file current_boundary = current_file.start_iterator; current_boundary.SetCurrentBoundaryToPosition(single_threaded, current_file.options); - if (current_boundary.done && context.client_data->debug_set_max_line_length) { - context.client_data->debug_max_line_length = - MaxValue(context.client_data->debug_max_line_length, 
current_boundary.pos.buffer_pos); - } current_buffer_in_use = make_shared_ptr(*current_file.buffer_manager, current_boundary.GetBufferIdx()); initialized = true; @@ -100,10 +96,6 @@ void CSVGlobalState::FinishFile(CSVFileScan &scan) { } scan.error_handler->ErrorIfAny(); FillRejectsTable(scan); - if (context.client_data->debug_set_max_line_length) { - context.client_data->debug_max_line_length = - MaxValue(context.client_data->debug_max_line_length, scan.error_handler->GetMaxLineLength()); - } } void FillScanErrorTable(InternalAppender &scan_appender, idx_t scan_idx, idx_t file_idx, CSVFileScan &file) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp index 07558c6bc..429621a83 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp @@ -1,4 +1,6 @@ #include "duckdb/execution/operator/csv_scanner/csv_error.hpp" + +#include "utf8proc_wrapper.hpp" #include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/function/table/read_csv.hpp" @@ -231,14 +233,14 @@ void CSVErrorHandler::FillRejectsTable(InternalAppender &errors_appender, const break; case CSVErrorType::TOO_FEW_COLUMNS: if (col_idx + 1 < bind_data.names.size()) { - errors_appender.Append(string_t(bind_data.names[col_idx + 1])); + errors_appender.Append(string_t(bind_data.names[col_idx + 1].GetIdentifierName())); } else { errors_appender.Append(Value()); } break; default: if (col_idx < bind_data.names.size()) { - errors_appender.Append(string_t(bind_data.names[col_idx])); + errors_appender.Append(string_t(bind_data.names[col_idx].GetIdentifierName())); } else { errors_appender.Append(Value()); } @@ -283,9 +285,11 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx if (reader_options.ignore_errors.GetValue()) { RemoveNewLine(error_message); 
} - // Let's cap the csv row to 10k bytes. For performance reasons. - if (csv_row.size() > 10000) { - csv_row.erase(csv_row.begin() + 10000, csv_row.end()); + // Cap the csv row for performance reasons. + if (reader_options.rejects_line_size_limit > 0 && csv_row.size() > reader_options.rejects_line_size_limit) { + csv_row.erase(csv_row.begin() + NumericCast(reader_options.rejects_line_size_limit), csv_row.end()); + // truncating might add invalid UTF8 at the tail end - make it valid again + Utf8Proc::MakeValid(csv_row.data(), csv_row.size(), '.'); } error << error_message << '\n'; error << fixes << '\n'; @@ -294,7 +298,7 @@ CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx full_error_message = error.str(); } -CSVError CSVError::ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names) { +CSVError CSVError::ColumnTypesError(identifier_map_t sql_types_per_column, const vector &names) { for (idx_t i = 0; i < names.size(); i++) { auto it = sql_types_per_column.find(names[i]); if (it != sql_types_per_column.end()) { diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp index 171ddd7ef..80c813c93 100644 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp @@ -357,6 +357,12 @@ void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, throw BinderException("Unsupported parameter for REJECTS_LIMIT: cannot be negative"); } rejects_limit = NumericCast(limit); + } else if (loption == "rejects_line_size_limit") { + auto limit = ParseInteger(value, loption); + if (limit < 0) { + throw BinderException("Unsupported parameter for REJECTS_LINE_SIZE_LIMIT: cannot be negative"); + } + rejects_line_size_limit = NumericCast(limit); } else if (loption == "encoding") { encoding = ParseString(value, 
loption); } else if (loption == "thousands") { @@ -549,13 +555,14 @@ static Value StringVectorToValue(const vector &vec) { static uint8_t GetCandidateSpecificity(const LogicalType &candidate_type) { //! Const ht with accepted auto_types and their weights in specificity const duckdb::unordered_map auto_type_candidates_specificity { - {static_cast(LogicalTypeId::VARCHAR), 0}, {static_cast(LogicalTypeId::DOUBLE), 1}, - {static_cast(LogicalTypeId::FLOAT), 2}, {static_cast(LogicalTypeId::DECIMAL), 3}, - {static_cast(LogicalTypeId::BIGINT), 4}, {static_cast(LogicalTypeId::INTEGER), 5}, - {static_cast(LogicalTypeId::SMALLINT), 6}, {static_cast(LogicalTypeId::TINYINT), 7}, - {static_cast(LogicalTypeId::TIMESTAMP_TZ), 8}, {static_cast(LogicalTypeId::TIMESTAMP), 9}, - {static_cast(LogicalTypeId::DATE), 10}, {static_cast(LogicalTypeId::TIME), 11}, - {static_cast(LogicalTypeId::BOOLEAN), 12}, {static_cast(LogicalTypeId::SQLNULL), 13}}; + {static_cast(LogicalTypeId::VARCHAR), 0}, {static_cast(LogicalTypeId::DOUBLE), 1}, + {static_cast(LogicalTypeId::FLOAT), 2}, {static_cast(LogicalTypeId::DECIMAL), 3}, + {static_cast(LogicalTypeId::BIGNUM), 4}, {static_cast(LogicalTypeId::HUGEINT), 5}, + {static_cast(LogicalTypeId::BIGINT), 6}, {static_cast(LogicalTypeId::INTEGER), 7}, + {static_cast(LogicalTypeId::SMALLINT), 8}, {static_cast(LogicalTypeId::TINYINT), 9}, + {static_cast(LogicalTypeId::TIMESTAMP_TZ), 10}, {static_cast(LogicalTypeId::TIMESTAMP), 11}, + {static_cast(LogicalTypeId::DATE), 12}, {static_cast(LogicalTypeId::TIME), 13}, + {static_cast(LogicalTypeId::BOOLEAN), 14}, {static_cast(LogicalTypeId::SQLNULL), 15}}; auto id = static_cast(candidate_type.id()); auto it = auto_type_candidates_specificity.find(id); @@ -636,11 +643,11 @@ string CSVReaderOptions::GetUserDefinedParameters() const { void CSVReaderOptions::FromNamedParameters(const named_parameter_map_t &in, ClientContext &context, MultiFileOptions &file_options) { for (auto &kv : in) { - auto loption = 
StringUtil::Lower(kv.first); + auto loption = StringUtil::Lower(kv.first.GetIdentifierName()); if (MultiFileReader().ParseOption(loption, kv.second, file_options, context)) { continue; } - ParseOption(context, kv.first, kv.second); + ParseOption(context, kv.first.GetIdentifierName(), kv.second); } } @@ -665,11 +672,11 @@ void CSVReaderOptions::ParseOption(ClientContext &context, const string &key, co // Parse into temporary lists first vector parsed_names; vector parsed_types; - case_insensitive_map_t parsed_types_per_column; + identifier_map_t parsed_types_per_column; for (idx_t i = 0; i < struct_children.size(); i++) { auto &name = StructType::GetChildName(child_type, i); auto &val = struct_children[i]; - parsed_names.push_back(name); + parsed_names.emplace_back(name); if (val.type().id() != LogicalTypeId::VARCHAR) { throw BinderException("read_csv requires a type specification as string"); } @@ -754,7 +761,7 @@ void CSVReaderOptions::ParseOption(ClientContext &context, const string &key, co // Parse into temporary lists first vector sql_type_names; - case_insensitive_map_t parsed_types_per_column; + identifier_map_t parsed_types_per_column; if (child_type.id() == LogicalTypeId::STRUCT) { auto &struct_children = StructValue::GetChildren(val); D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp index 9c0f925d6..0a9ee17ac 100644 --- a/src/duckdb/src/execution/operator/filter/physical_filter.cpp +++ b/src/duckdb/src/execution/operator/filter/physical_filter.cpp @@ -16,7 +16,7 @@ PhysicalFilter::PhysicalFilter(PhysicalPlan &physical_plan, vector // Create a conjunction from the select list. 
auto conjunction = make_uniq(ExpressionType::CONJUNCTION_AND); for (auto &expr : select_list) { - conjunction->children.push_back(std::move(expr)); + conjunction->GetChildrenMutable().push_back(std::move(expr)); } expression = std::move(conjunction); } @@ -34,6 +34,14 @@ class FilterState : public CachingOperatorState { void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); } + + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + ResetCachingState(); + } }; unique_ptr PhysicalFilter::GetOperatorState(ExecutionContext &context) const { @@ -55,7 +63,7 @@ OperatorResultType PhysicalFilter::ExecuteInternal(ExecutionContext &context, Da InsertionOrderPreservingMap PhysicalFilter::ParamsToString() const { InsertionOrderPreservingMap result; - result["__expression__"] = expression->GetName(); + result["__expression__"] = expression->GetName().GetIdentifierName(); SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp index e79a4d044..ef3e7122b 100644 --- a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp @@ -33,8 +33,8 @@ SinkFinalizeType PhysicalBatchCollector::Finalize(Pipeline &pipeline, Event &eve auto &gstate = input.global_state.Cast(); auto collection = gstate.data.FetchCollection(); D_ASSERT(collection); - auto result = make_uniq(statement_type, properties, names, std::move(collection), - context.GetClientProperties()); + auto result = make_uniq(statement_type, properties, IdentifiersToStrings(names), + std::move(collection), context.GetClientProperties()); gstate.result = std::move(result); return SinkFinalizeType::READY; } diff --git 
a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp index 9e3caab6c..a433a5049 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp @@ -100,8 +100,8 @@ unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(C unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); auto cc = gstate.context.lock(); - auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), - gstate.buffered_data); + auto result = make_uniq(statement_type, properties, types, IdentifiersToStrings(names), + cc->GetClientProperties(), gstate.buffered_data); return std::move(result); } diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp index 8ee1e1617..0804f9b92 100644 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp @@ -62,8 +62,8 @@ unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &st lock_guard l(gstate.glock); // FIXME: maybe we want to check if the execution was successful before creating the StreamQueryResult ? 
auto cc = gstate.context.lock(); - auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), - gstate.buffered_data); + auto result = make_uniq(statement_type, properties, types, IdentifiersToStrings(names), + cc->GetClientProperties(), gstate.buffered_data); return std::move(result); } diff --git a/src/duckdb/src/execution/operator/helper/physical_connect.cpp b/src/duckdb/src/execution/operator/helper/physical_connect.cpp new file mode 100644 index 000000000..3bfdff497 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_connect.cpp @@ -0,0 +1,45 @@ +#include "duckdb/execution/operator/helper/physical_connect.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/storage/storage_extension.hpp" + +namespace duckdb { + +SourceResultType PhysicalConnect::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + // The 4 grammar forms are accepted at parse time. This commit implements only + // `CONNECT ` (attached-db identifier); the other forms stay as per-form + // NotImplementedException placeholders and get swapped in by follow-up PRs. + if (info->target_is_local) { + throw NotImplementedException("CONNECT LOCAL is not yet implemented"); + } + if (info->name.empty()) { + throw NotImplementedException("CONNECT with no target is not yet implemented"); + } + if (info->name_is_string_literal) { + throw NotImplementedException("CONNECT '' is not yet implemented"); + } + + auto &client = context.client; + + // At most one active connection; only DISCONNECT clears it (even if the target was detached + // elsewhere, so the user explicitly acknowledges the broken connection). 
+ if (client.IsConnected()) { + auto current = client.TryGetConnectedCatalog(); + throw InvalidInputException("Already connected to \"%s\"; DISCONNECT first before issuing another CONNECT", + current ? current->GetName().GetIdentifierName() : ""); + } + + auto target = DatabaseManager::Get(client).GetDatabase(info->name); + if (!target) { + throw InvalidInputException("Database \"%s\" is not attached", info->name); + } + // ConnectToCatalog resolves and caches the dispatch function name; throws if the catalog returns + // empty from GetConnectFunctionName (i.e. CONNECT is not supported in this context). + client.ConnectToCatalog(target); + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp b/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp index 002fad108..4c83e27c2 100644 --- a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp @@ -13,7 +13,6 @@ SourceResultType PhysicalCreateSecret::GetDataInternal(ExecutionContext &context secret_manager.CreateSecret(client, create_input); chunk.data[0].Append(Value::BOOLEAN(true)); - chunk.SetCardinality(1); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_disconnect.cpp b/src/duckdb/src/execution/operator/helper/physical_disconnect.cpp new file mode 100644 index 000000000..b3b213125 --- /dev/null +++ b/src/duckdb/src/execution/operator/helper/physical_disconnect.cpp @@ -0,0 +1,18 @@ +#include "duckdb/execution/operator/helper/physical_disconnect.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" + +namespace duckdb { + +SourceResultType PhysicalDisconnect::GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const { + auto &client = context.client; + // IsConnected() so DISCONNECT still works when the 
target was detached elsewhere. + if (!client.IsConnected()) { + throw InvalidInputException("DISCONNECT: no active CONNECT (already on LOCAL)"); + } + client.DisconnectFromCatalog(); + return SourceResultType::FINISHED; +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp index 09a83deea..86f6b14fc 100644 --- a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp @@ -39,7 +39,6 @@ SourceResultType PhysicalExplainAnalyze::GetDataInternal(ExecutionContext &conte chunk.data[0].Append(Value("analyzed_plan")); chunk.data[1].Append(Value(gstate.analyzed_plan)); - chunk.SetCardinality(1); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_limit.cpp index 8fb52af84..b57b132aa 100644 --- a/src/duckdb/src/execution/operator/helper/physical_limit.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_limit.cpp @@ -123,7 +123,8 @@ SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, } auto max_cardinality = max_element - state.current_offset; if (max_cardinality < chunk.size()) { - chunk.SetCardinality(max_cardinality); + // truncate the chunk to the first max_cardinality rows + chunk.Slice(0, max_cardinality); } state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); state.current_offset += chunk.size(); @@ -236,7 +237,6 @@ Value PhysicalLimit::GetDelimiter(ExecutionContext &context, DataChunk &input, c for (idx_t c = 0; c < input.ColumnCount(); c++) { ConstantVector::Reference(single_row_input.data[c], count_t(1), input.data[c], 0, input.size()); } - single_row_input.SetCardinality(1); limit_executor.Execute(single_row_input, limit_chunk); auto limit_value = limit_chunk.GetValue(0, 0); return 
limit_value; diff --git a/src/duckdb/src/execution/operator/helper/physical_load.cpp b/src/duckdb/src/execution/operator/helper/physical_load.cpp index d6bcda546..132f08853 100644 --- a/src/duckdb/src/execution/operator/helper/physical_load.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_load.cpp @@ -1,5 +1,8 @@ #include "duckdb/execution/operator/helper/physical_load.hpp" #include "duckdb/main/extension_helper.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -39,7 +42,12 @@ SourceResultType PhysicalLoad::GetDataInternal(ExecutionContext &context, DataCh } } else { - ExtensionHelper::LoadExternalExtension(context.client, info->filename); + ExtensionLoadOptions options; + options.extension_name = info->filename; + options.alias = info->alias; + ExtensionHelper::LoadExternalExtension(context.client, options); + // adds an explicitly set extension schema to the search path + ExtensionLoader::RefreshSearchPath(context.client); } return SourceResultType::FINISHED; diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp index a42e7a4a8..665f91679 100644 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -67,8 +67,9 @@ unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState if (!gstate.collection) { gstate.collection = CreateCollection(*gstate.context); } - auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), - gstate.context->GetClientProperties()); + auto result = + make_uniq(statement_type, properties, IdentifiersToStrings(names), + std::move(gstate.collection), gstate.context->GetClientProperties()); return std::move(result); } diff --git 
a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp index 7603b6e87..11b914b52 100644 --- a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp @@ -8,7 +8,7 @@ SourceResultType PhysicalPrepare::GetDataInternal(ExecutionContext &context, Dat auto &client = context.client; // store the prepared statement in the context - ClientData::Get(client).prepared_statements[name.ToStdString()] = prepared; + ClientData::Get(client).prepared_statements[Identifier(name.ToStdString())] = prepared; return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp index d549d633b..9eec62b40 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -9,10 +9,11 @@ namespace duckdb { void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig &config, ExtensionOption &extension_option) const { + auto effective_scope = scope == SetScope::AUTOMATIC ? 
extension_option.default_scope : scope; if (extension_option.set_function) { - extension_option.set_function(context.client, scope, extension_option.default_value); + extension_option.set_function(context.client, effective_scope, extension_option.default_value); } - if (scope == SetScope::GLOBAL) { + if (effective_scope == SetScope::GLOBAL) { config.ResetOption(extension_option); } else { auto &client_config = ClientConfig::GetConfig(context.client); diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp index c9d65340e..5415f670b 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp @@ -3,17 +3,12 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { void PhysicalSet::SetGenericVariable(ClientContext &context, idx_t setting_index, SetScope scope, Value target_value) { - if (scope == SetScope::GLOBAL) { - auto &config = DBConfig::GetConfig(context); - config.SetOption(setting_index, std::move(target_value)); - } else { - auto &client_config = ClientConfig::GetConfig(context); - client_config.user_settings.SetUserSetting(setting_index, std::move(target_value)); - } + Settings::Set(context, setting_index, scope, std::move(target_value)); } void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const String &name, diff --git a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp index 7813ef0e2..a518846a5 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp @@ -4,10 +4,11 @@ namespace duckdb { -PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, const 
string &name_p, idx_t estimated_cardinality) +PhysicalSetVariable::PhysicalSetVariable(PhysicalPlan &physical_plan, const Identifier &name_p, + idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET_VARIABLE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(physical_plan.ArenaRef().MakeString(name_p)) { + name(physical_plan.ArenaRef().MakeString(name_p.GetIdentifierName())) { } SourceResultType PhysicalSetVariable::GetDataInternal(ExecutionContext &context, DataChunk &chunk, diff --git a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp b/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp index 456086500..d4232942a 100644 --- a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp @@ -37,7 +37,6 @@ SourceResultType PhysicalUpdateExtensions::GetDataInternal(ExecutionContext &con data.offset++; count++; } - chunk.SetCardinality(count); return data.offset >= data.update_result_entries.size() ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; @@ -52,7 +51,7 @@ unique_ptr PhysicalUpdateExtensions::GetGlobalSourceState(Cli } else { // Update extensions in extensions_to_update for (const auto &ext : info->extensions_to_update) { - res->update_result_entries.emplace_back(ExtensionHelper::UpdateExtension(context, ext)); + res->update_result_entries.emplace_back(ExtensionHelper::UpdateExtension(context, ext.GetIdentifierName())); } } diff --git a/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp b/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp index 2b79482b5..7bf7bacb1 100644 --- a/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp @@ -35,7 +35,6 @@ OperatorResultType VerifyEmitConstantVectors(const DataChunk &input, DataChunk & for (idx_t c = 0; c < chunk.ColumnCount(); c++) { ConstantVector::Reference(chunk.data[c], count_t(1), copied_input.data[c], state.const_idx, 1); } - chunk.SetCardinality(1); state.const_idx++; if (state.const_idx >= copied_input.size()) { state.const_idx = 0; @@ -47,7 +46,7 @@ OperatorResultType VerifyEmitConstantVectors(const DataChunk &input, DataChunk & OperatorResultType VerifyEmitDictionaryVectors(const DataChunk &input, DataChunk &chunk, OperatorState &state) { input.Copy(chunk); for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - Vector::DebugTransformToDictionary(chunk.data[c], chunk.size()); + Vector::DebugTransformToDictionary(chunk.data[c]); } return OperatorResultType::NEED_MORE_INPUT; } @@ -202,7 +201,6 @@ OperatorResultType VerifyEmitSequenceVector(const DataChunk &input_p, DataChunk chunk.data[c].Sequence(start, increment, max_length); } } - chunk.SetCardinality(max_length); state.const_idx += max_length; if (state.const_idx >= input.size()) { state.const_idx = 0; @@ -214,7 +212,7 @@ OperatorResultType VerifyEmitSequenceVector(const DataChunk &input_p, DataChunk 
OperatorResultType VerifyEmitNestedShuffleVector(const DataChunk &input, DataChunk &chunk, OperatorState &state) { input.Copy(chunk); for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - Vector::DebugShuffleNestedVector(chunk.data[c], chunk.size()); + Vector::DebugShuffleNestedVector(chunk.data[c]); } return OperatorResultType::NEED_MORE_INPUT; } diff --git a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp index 66a1aa035..7fc3badab 100644 --- a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp +++ b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp @@ -18,6 +18,9 @@ void OuterJoinMarker::Reset() { if (!enabled) { return; } + if (count == 0) { + return; + } memset(found_match.get(), 0, sizeof(bool) * count); } @@ -97,7 +100,6 @@ void OuterJoinMarker::Scan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanS result.data[col_idx].Slice(lstate.scan_chunk.data[col_idx - left_column_count], lstate.match_sel, result_count); } - result.SetCardinality(result_count); return; } } diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp index 85f1d7cb1..68e95e145 100644 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -189,7 +189,7 @@ bool PerfectHashJoinExecutor::FullScanHashTable() { return true; } -bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, +bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(const Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, idx_t count) { switch (source.GetType().InternalType()) { case PhysicalType::INT8: @@ -218,7 +218,7 @@ bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(Vector &source, Sel } template -bool 
PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, +bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(const Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, idx_t count) { if (perfect_join_statistics.build_min.IsNull() || perfect_join_statistics.build_max.IsNull()) { return false; @@ -283,7 +283,7 @@ OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionConte state.join_keys.Reset(); state.probe_executor.Execute(input, state.join_keys); // select the keys that are in the min-max range - auto &keys_vec = state.join_keys.data[0]; + const auto &keys_vec = state.join_keys.data[0]; auto keys_count = state.join_keys.size(); // todo: add check for fast pass when probe is part of build domain FillSelectionVectorSwitchProbe(keys_vec, keys_count, state.probe_sel_vec, probe_sel_count, &state.build_sel_vec); @@ -304,7 +304,7 @@ OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionConte return OperatorResultType::NEED_MORE_INPUT; } -void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, const idx_t &count, +void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, optional_ptr build_sel_vec) const { if (build_sel_vec) { @@ -315,7 +315,7 @@ void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, con } template -void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, const idx_t &count, +void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, SelectionVector *build_sel_vec) const { D_ASSERT(BUILD_SEL_VEC == static_cast(build_sel_vec)); @@ -366,7 +366,7 @@ void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, con } template -void 
PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source, const idx_t &count, +void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, SelectionVector *build_sel_vec) const { D_ASSERT(probe_sel_count == 0); diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index 9611a7c3a..1ae267e1d 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -239,16 +239,18 @@ class AsOfPayloadScanner { using BLOCK_ITERATOR = block_iterator_t; BLOCK_ITERATOR itr(block_state, chunk_idx, 0); - const auto sort_keys = FlatVector::GetDataMutable(sort_key_pointers); const auto result_count = NextSize(); - for (idx_t i = 0; i < result_count; ++i) { - const auto idx = block_state.GetIndex(chunk_idx, i); - sort_keys[i] = &itr[idx]; + { + auto writer = FlatVector::Writer(sort_key_pointers, result_count); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = block_state.GetIndex(chunk_idx, i); + writer.WriteValue(&itr[idx]); + } } // Scan scan_chunk.Reset(); - scan_state.Scan(sorted_run, sort_key_pointers, result_count, scan_chunk); + scan_state.Scan(sorted_run, sort_key_pointers, scan_chunk); return scan_chunk.size() > 0; } @@ -1025,7 +1027,7 @@ void AsOfProbeBuffer::ScanLeft() { const auto count = lhs_payload.size(); lhs_valid_mask.Reset(); for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; + const auto &col = lhs_keys.data[col_idx]; lhs_valid_mask.Combine(col, count); } @@ -1196,11 +1198,10 @@ void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &c // Reference the projected right payload into the result for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { const auto rhs_idx = 
op.right_projection_map[col_idx]; - auto &source = rhs_input.data[rhs_idx]; + const auto &source = rhs_input.data[rhs_idx]; auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; target.Reference(source); } - chunk.SetChildCardinality(lhs_match_count); // Update the match masks for the rows we ended up with left_outer.Reset(); @@ -1627,7 +1628,6 @@ void AsOfLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChunk const auto rhs_idx = op.right_projection_map[col_idx]; chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); } - chunk.SetCardinality(result_count); return; } diff --git a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp index b056dd09d..9c4993c74 100644 --- a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp @@ -216,7 +216,7 @@ OperatorResultType PhysicalBlockwiseNLJoin::ExecuteInternal(ExecutionContext &co InsertionOrderPreservingMap PhysicalBlockwiseNLJoin::ParamsToString() const { InsertionOrderPreservingMap result; result["Join Type"] = EnumUtil::ToString(join_type); - result["Condition"] = condition->GetName(); + result["Condition"] = condition->GetName().GetIdentifierName(); SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp index bd5884fd6..56c32317a 100644 --- a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp @@ -36,9 +36,9 @@ InsertionOrderPreservingMap PhysicalComparisonJoin::ParamsToString() con condition_info += "\n"; } D_ASSERT(join_condition.IsComparison()); - condition_info += StringUtil::Format("%s %s %s", 
join_condition.GetLHS().GetName(), + condition_info += StringUtil::Format("%s %s %s", join_condition.GetLHS().GetName().GetIdentifierName(), ExpressionTypeToOperator(join_condition.GetComparisonType()), - join_condition.GetRHS().GetName()); + join_condition.GetRHS().GetName().GetIdentifierName()); } if (predicate) { @@ -123,7 +123,7 @@ void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool h // MARK join with empty hash table D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); // for every data vector, we just reference the child chunk - result.SetCardinality(input); + result.SetChildCardinality(input.size()); for (idx_t i = 0; i < input.ColumnCount(); i++) { result.data[i].Reference(input.data[i]); } @@ -146,7 +146,7 @@ void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool h } else if (join_type == JoinType::LEFT || join_type == JoinType::OUTER || join_type == JoinType::SINGLE) { // LEFT/FULL OUTER/SINGLE join and build side is empty // for the LHS we reference the data - result.SetCardinality(input.size()); + result.SetChildCardinality(input.size()); for (idx_t i = 0; i < input.ColumnCount(); i++) { result.data[i].Reference(input.data[i]); } diff --git a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp index dc3f704ef..822f6342e 100644 --- a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp +++ b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp @@ -21,7 +21,23 @@ class CrossProductGlobalState : public GlobalSinkState { public: explicit CrossProductGlobalState(ClientContext &context, const PhysicalCrossProduct &op) : rhs_materialized(context, op.children[1].get().GetTypes()) { + ResetState(context); + } + +private: + void ResetState(ClientContext &context) { + rhs_materialized.ResetForReuse(); rhs_materialized.InitializeAppend(append_state); + GlobalSinkState::Reset(context); + } + +public: + 
bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + ResetState(context); } ColumnDataCollection rhs_materialized; @@ -29,10 +45,21 @@ class CrossProductGlobalState : public GlobalSinkState { mutex rhs_lock; }; +class CrossProductLocalSinkState : public LocalSinkState { +public: + bool SupportsReuse() const override { + return true; + } +}; + unique_ptr PhysicalCrossProduct::GetGlobalSinkState(ClientContext &context) const { return make_uniq(context, *this); } +unique_ptr PhysicalCrossProduct::GetLocalSinkState(ExecutionContext &context) const { + return make_uniq(); +} + SinkResultType PhysicalCrossProduct::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &sink = input.global_state.Cast(); lock_guard client_guard(sink.rhs_lock); @@ -43,9 +70,9 @@ SinkResultType PhysicalCrossProduct::Sink(ExecutionContext &context, DataChunk & //===--------------------------------------------------------------------===// // Operator //===--------------------------------------------------------------------===// -CrossProductExecutor::CrossProductExecutor(ColumnDataCollection &rhs) - : rhs(rhs), position_in_chunk(0), initialized(false), finished(false) { +CrossProductExecutor::CrossProductExecutor(ColumnDataCollection &rhs) : rhs(rhs) { rhs.InitializeScanChunk(scan_chunk); + Reset(); } void CrossProductExecutor::Reset(const DataChunk &input, DataChunk &output) { @@ -57,6 +84,15 @@ void CrossProductExecutor::Reset(const DataChunk &input, DataChunk &output) { scan_chunk.Reset(); } +void CrossProductExecutor::Reset() { + initialized = false; + finished = false; + scan_input_chunk = false; + position_in_chunk = 0; + scan_state = ColumnDataScanState(); + scan_chunk.Reset(); +} + bool CrossProductExecutor::NextValue(const DataChunk &input, DataChunk &output) { if (!initialized) { // not initialized yet: initialize the scan @@ -99,7 +135,7 @@ OperatorResultType CrossProductExecutor::Execute(const 
DataChunk &input, DataChu auto &constant_chunk = scan_input_chunk ? scan_chunk : input; auto col_count = constant_chunk.ColumnCount(); auto col_offset = scan_input_chunk ? input.ColumnCount() : 0; - output.SetCardinality(constant_chunk.size()); + output.SetChildCardinality(constant_chunk.size()); for (idx_t i = 0; i < col_count; i++) { output.data[col_offset + i].Reference(constant_chunk.data[i]); } @@ -121,6 +157,15 @@ class CrossProductOperatorState : public CachingOperatorState { } CrossProductExecutor executor; + + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + ResetCachingState(); + executor.Reset(); + } }; unique_ptr PhysicalCrossProduct::GetOperatorState(ExecutionContext &context) const { diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index aec5331b4..7b608a3fa 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/types.hpp" #include "duckdb/common/types/value_map.hpp" #include "duckdb/common/uhugeint.hpp" +#include "duckdb/common/types/uhugeint.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/join_hashtable.hpp" #include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" @@ -23,14 +24,13 @@ #include "duckdb/parallel/pipeline.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include 
"duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" -#include "duckdb/planner/filter/selectivity_optional_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/planner/table_filter_set.hpp" #include "duckdb/storage/buffer_manager.hpp" @@ -111,7 +111,7 @@ void PhysicalHashJoin::ExtractResidualPredicateColumns(unique_ptr &p ExpressionIterator::EnumerateExpression(predicate, [&](unique_ptr &expr) { if (expr->GetExpressionClass() == ExpressionClass::BOUND_REF) { auto &ref = expr->Cast(); - idx_t col_idx = ref.index; + idx_t col_idx = ref.Index(); if (col_idx < probe_column_count) { // Probe (LHS) column @@ -179,7 +179,7 @@ void PhysicalHashJoin::InitializeBuildSide(const vector &lhs_input_ idx_t cond_idx = 0; for (auto &condition : conditions) { if (condition.GetRHS().GetExpressionClass() == ExpressionClass::BOUND_REF) { - auto build_input_idx = condition.GetRHS().Cast().index; + auto build_input_idx = condition.GetRHS().Cast().Index(); build_columns_in_conditions.emplace(build_input_idx, cond_idx); } cond_idx++; @@ -268,6 +268,18 @@ JoinFilterGlobalState::~JoinFilterGlobalState() { JoinFilterLocalState::~JoinFilterLocalState() { } +static bool CanUsePerfectHashJoin(const PhysicalHashJoin &op, PerfectHashJoinExecutor &perfect_join_executor) { + if (op.conditions.size() != 1 || !op.conditions[0].GetRightStats()) { + return false; + } + const auto &right_stats = *op.conditions[0].GetRightStats(); + if (!TypeIsIntegral(right_stats.GetType().InternalType()) || !NumericStats::HasMinMax(right_stats)) { + return false; + } + return perfect_join_executor.CanDoPerfectHashJoin(op, NumericStats::Min(right_stats), + NumericStats::Max(right_stats)); +} + unique_ptr JoinFilterPushdownInfo::GetGlobalState(ClientContext &context, const PhysicalOperator &op) const { // clear any previously set 
filters @@ -293,14 +305,7 @@ class HashJoinGlobalSinkState : public GlobalSinkState { // For perfect hash join perfect_join_executor = make_uniq(op, *hash_table); - bool use_perfect_hash = false; - if (op.conditions.size() == 1 && op.conditions[0].GetRightStats()) { - const auto &right_stats = *op.conditions[0].GetRightStats(); - if (TypeIsIntegral(right_stats.GetType().InternalType()) && NumericStats::HasMinMax(right_stats)) { - use_perfect_hash = perfect_join_executor->CanDoPerfectHashJoin(op, NumericStats::Min(right_stats), - NumericStats::Max(right_stats)); - } - } + auto use_perfect_hash = CanUsePerfectHashJoin(op, *perfect_join_executor); // For external hash join external = Settings::Get(context); // Set probe types @@ -324,6 +329,40 @@ class HashJoinGlobalSinkState : public GlobalSinkState { void ScheduleFinalize(Pipeline &pipeline, Event &event); void InitializeProbeSpill(); + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + // Use single-partition mode on subsequent iterations to avoid 16-partition overhead for small joins. 
+ hash_table->ResetForNewIterationSinglePartition(); + perfect_join_executor = make_uniq(op, *hash_table); + auto use_perfect_hash = CanUsePerfectHashJoin(op, *perfect_join_executor); + finalized = false; + active_local_states = 0; + external = Settings::Get(context); + total_size = 0; + max_partition_size = 0; + max_partition_count = 0; + probe_side_requirement = 0; + local_hash_tables.clear(); + owned_local_hash_tables.clear(); + probe_spill.reset(); + scanned_data = false; + preserve_build_for_reuse = false; + skip_filter_pushdown = false; + global_filter_state.reset(); + temporary_memory_state->SetZero(); + keep_local_hash_tables = true; + if (op.filter_pushdown) { + if (op.filter_pushdown->probe_info.empty() && use_perfect_hash) { + skip_filter_pushdown = true; + } + global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + } + GlobalSinkState::Reset(context); + } + public: ClientContext &context; const PhysicalHashJoin &op; @@ -350,8 +389,10 @@ class HashJoinGlobalSinkState : public GlobalSinkState { idx_t max_partition_count; idx_t probe_side_requirement; - //! Hash tables built by each thread - vector> local_hash_tables; + //! Hash tables built by each thread (owned by the corresponding local sink states) + vector> local_hash_tables; + //! Local hash tables whose ownership has been transferred during the normal one-shot finalize path + vector> owned_local_hash_tables; //! Excess probe data gathered during Sink vector probe_types; @@ -359,9 +400,12 @@ class HashJoinGlobalSinkState : public GlobalSinkState { //! Whether or not we have started scanning data using GetData atomic scanned_data; + //! 
Preserve the finalized build-side state across recursive CTE iterations + bool preserve_build_for_reuse = false; bool skip_filter_pushdown = false; unique_ptr global_filter_state; + bool keep_local_hash_tables = false; }; unique_ptr JoinFilterPushdownInfo::GetLocalState(JoinFilterGlobalState &gstate) const { @@ -373,7 +417,7 @@ unique_ptr JoinFilterPushdownInfo::GetLocalState(JoinFilte class HashJoinLocalSinkState : public LocalSinkState { public: HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context, HashJoinGlobalSinkState &gstate) - : join_key_executor(context) { + : op(op), join_key_executor(context) { auto &allocator = BufferAllocator::Get(context); for (auto &cond : op.conditions) { @@ -385,8 +429,9 @@ class HashJoinLocalSinkState : public LocalSinkState { payload_chunk.Initialize(allocator, op.payload_columns.col_types); } - hash_table = op.InitializeHashTable(context, gstate.initial_radix_bits); + hash_table = op.InitializeHashTable(context, gstate.hash_table->GetRadixBits()); hash_table->GetSinkCollection().InitializeAppendState(append_state); + keep_hash_table = gstate.keep_local_hash_tables; gstate.active_local_states++; @@ -396,6 +441,7 @@ class HashJoinLocalSinkState : public LocalSinkState { } public: + const PhysicalHashJoin &op; PartitionedTupleDataAppendState append_state; ExpressionExecutor join_key_executor; @@ -405,16 +451,77 @@ class HashJoinLocalSinkState : public LocalSinkState { //! 
Thread-local HT unique_ptr hash_table; + bool keep_hash_table = false; unique_ptr local_filter_state; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + join_keys.Reset(); + payload_chunk.Reset(); + if (hash_table) { + hash_table->ResetForNewIterationSinglePartition(); + } else { + hash_table = op.InitializeHashTable(context.client, gstate.hash_table->GetRadixBits()); + } + hash_table->GetSinkCollection().ResetAppendState(append_state); + keep_hash_table = gstate.keep_local_hash_tables; + gstate.active_local_states++; + if (op.filter_pushdown) { + local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); + } else { + local_filter_state.reset(); + } + } }; +static bool ShouldPrepareBloomFilterBuild(const PhysicalHashJoin &op) { + if (!op.filter_pushdown || op.filter_pushdown->probe_info.empty()) { + return false; + } + idx_t equality_column_count = 0; + for (auto &cond : op.conditions) { + auto cmp = cond.GetComparisonType(); + if (cmp == ExpressionType::COMPARE_EQUAL || cmp == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + equality_column_count++; + } + } + if (equality_column_count != 1) { + return false; + } + auto probe_estimated_cardinality = op.children[0].get().estimated_cardinality; + auto build_estimated_cardinality = op.children[1].get().estimated_cardinality; + if (probe_estimated_cardinality == 0 || build_estimated_cardinality == 0) { + return false; + } + static constexpr double BUILD_TO_PROBE_RATIO_THRESHOLD = 1.0; + const double build_to_probe_ratio = + static_cast(build_estimated_cardinality) / static_cast(probe_estimated_cardinality); + if (build_to_probe_ratio > BUILD_TO_PROBE_RATIO_THRESHOLD) { + return false; + } + static constexpr double NON_FILTERING_RATIO_THRESHOLD = 0.1; + static constexpr idx_t NON_FILTERING_BUILD_SIDE_THRESHOLD = 4194304; + if (!op.filter_pushdown->build_side_has_filter && 
build_to_probe_ratio > NON_FILTERING_RATIO_THRESHOLD && + build_estimated_cardinality > NON_FILTERING_BUILD_SIDE_THRESHOLD) { + return false; + } + return true; +} + unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &context, const idx_t initial_radix_bits) const { auto result = make_uniq(context, *this, conditions, payload_columns.col_types, join_type, initial_radix_bits, rhs_output_columns.col_idxs, residual_info ? residual_info->Copy() : nullptr, predicate ? predicate.get() : nullptr, lhs_output_in_probe); + if (ShouldPrepareBloomFilterBuild(*this)) { + result->PrepareBuildBloomFilter(children[1].get().estimated_cardinality); + } if (!delim_types.empty() && join_type == JoinType::MARK) { // correlated MARK join @@ -430,7 +537,7 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c auto &info = result->correlated_mark_join_info; vector delim_payload_types; - vector correlated_aggregates; + vector correlated_aggregates; unique_ptr aggr; // jury-rigging the GroupedAggregateHashTable @@ -439,7 +546,7 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c FunctionBinder function_binder(context); aggr = function_binder.BindAggregateFunction(CountStarFun::GetFunction(), {}, nullptr, AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); + correlated_aggregates.emplace_back(*aggr); delim_payload_types.push_back(aggr->GetReturnType()); info.correlated_aggregates.push_back(std::move(aggr)); @@ -449,13 +556,13 @@ unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &c children.push_back(make_uniq_base(count_fun.GetReturnType(), 0U)); aggr = function_binder.BindAggregateFunction(count_fun, std::move(children), nullptr, AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); + correlated_aggregates.emplace_back(*aggr); delim_payload_types.push_back(aggr->GetReturnType()); info.correlated_aggregates.push_back(std::move(aggr)); auto &allocator = BufferAllocator::Get(context); - 
info.correlated_counts = make_uniq(context, allocator, delim_types, - delim_payload_types, correlated_aggregates); + info.correlated_counts = make_uniq( + context, allocator, delim_types, delim_payload_types, std::move(correlated_aggregates)); info.correlated_types = delim_types; info.group_chunk.Initialize(allocator, delim_types); info.result_chunk.Initialize(allocator, delim_payload_types); @@ -473,13 +580,32 @@ unique_ptr PhysicalHashJoin::GetLocalSinkState(ExecutionContext return make_uniq(*this, context.client, gstate); } +void PhysicalHashJoin::SetPreserveBuildForRecursiveReuse(bool preserve) const { + if (!sink_state) { + return; + } + auto &state = sink_state->Cast(); + state.preserve_build_for_reuse = preserve; + if (preserve) { + state.scanned_data = false; + } +} + +bool PhysicalHashJoin::CanPreserveBuildForRecursiveReuse() const { + if (!sink_state) { + return false; + } + auto &state = sink_state->Cast(); + return state.preserve_build_for_reuse && !state.external; +} + void JoinFilterPushdownInfo::Sink(DataChunk &chunk, JoinFilterLocalState &lstate) const { // if we are pushing any filters into a probe-side, compute the min/max over the columns that we are pushing for (idx_t pushdown_idx = 0; pushdown_idx < join_condition.size(); pushdown_idx++) { auto join_condition_idx = join_condition[pushdown_idx]; for (idx_t i = 0; i < 2; i++) { idx_t aggr_idx = pushdown_idx * 2 + i; - lstate.local_aggregate_state->Sink(chunk, join_condition_idx, aggr_idx); + lstate.local_aggregate_state->Sink(chunk, join_condition_idx, aggr_idx, chunk.size()); } } } @@ -497,7 +623,7 @@ SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chun } if (payload_columns.col_types.empty()) { // there are only keys: place an empty chunk in the payload - lstate.payload_chunk.SetCardinality(chunk.size()); + lstate.payload_chunk.SetChildCardinality(chunk.size()); } else { // there are payload columns lstate.payload_chunk.ReferenceColumns(chunk, 
payload_columns.col_idxs); } @@ -518,7 +644,12 @@ SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, Opera lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); annotated_lock_guard guard(gstate.lock); - gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); + if (lstate.keep_hash_table) { + gstate.local_hash_tables.push_back(*lstate.hash_table); + } else { + gstate.owned_local_hash_tables.push_back(std::move(lstate.hash_table)); + gstate.local_hash_tables.push_back(*gstate.owned_local_hash_tables.back()); + } if (gstate.local_hash_tables.size() == gstate.active_local_states) { // Set to 0 until PrepareFinalize gstate.temporary_memory_state->SetZero(); @@ -620,6 +751,26 @@ void PhysicalHashJoin::PrepareFinalize(ClientContext &context, GlobalSinkState & gstate.temporary_memory_state->SetRemainingSize(gstate.total_size); } +static void ExecuteHashJoinTableInitTask(HashJoinGlobalSinkState &sink, idx_t entry_idx_from, idx_t entry_idx_to) { + sink.hash_table->InitializePointerTable(entry_idx_from, entry_idx_to); +} + +static void ExecuteHashJoinFinalizeTask(HashJoinGlobalSinkState &sink, optional_idx partition_idx, + optional_ptr prefix_range_state) { + const auto &data_collection = sink.hash_table->GetDataCollection(); + if (!partition_idx.IsValid() || sink.hash_table->GetRadixBits() == 0) { + // Unpartitioned builds still finalize over the full chunk range even if the scheduler created a + // single "partition 0" task, because tuple-data segments are not tagged with partition ids there. 
+ sink.hash_table->Finalize(0U, data_collection.ChunkCount(), false, prefix_range_state); + } else { + // Parallel finalize - each thread processes one partition + const auto chunk_ranges = data_collection.GetChunkRangesForPartition(partition_idx.GetIndex()); + for (auto &chunk_range : chunk_ranges) { + sink.hash_table->Finalize(chunk_range.first, chunk_range.second, true, prefix_range_state); + } + } +} + class HashJoinTableInitTask : public ExecutorTask { public: HashJoinTableInitTask(shared_ptr event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, @@ -629,7 +780,7 @@ class HashJoinTableInitTask : public ExecutorTask { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - sink.hash_table->InitializePointerTable(entry_idx_from, entry_idx_to); + ExecuteHashJoinTableInitTask(sink, entry_idx_from, entry_idx_to); event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; } @@ -653,29 +804,35 @@ class HashJoinTableInitEvent : public BasePipelineEvent { HashJoinGlobalSinkState &sink; public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - vector> finalize_tasks; + vector> GetTasks() { auto &ht = *sink.hash_table; const auto entry_count = ht.capacity; - auto num_threads = NumericCast(sink.num_threads); - - // we don't have to check whether it is too skewed here, as we only initialize the pointer table + auto &context = pipeline->GetClientContext(); + vector> finalize_tasks; if (FinalizeSingleThreaded(sink, false)) { - // Single-threaded memset finalize_tasks.push_back( make_uniq(shared_from_this(), context, sink, 0U, entry_count, sink.op)); - } else { - // have 4 times more tasks than threads, but bound the to a minimum - const idx_t entries_per_task = MaxValue(entry_count / num_threads / 4, MINIMUM_ENTRIES_PER_TASK); - // Parallel memset - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx += entries_per_task) { - auto entry_idx_to = MinValue(entry_idx + entries_per_task, entry_count); - 
finalize_tasks.push_back(make_uniq(shared_from_this(), context, sink, entry_idx, - entry_idx_to, sink.op)); - } + return finalize_tasks; + } + + auto num_threads = NumericCast(sink.num_threads); + // have 4 times more tasks than threads, but bound the to a minimum + const idx_t entries_per_task = MaxValue(entry_count / num_threads / 4, MINIMUM_ENTRIES_PER_TASK); + // Parallel memset + for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx += entries_per_task) { + auto entry_idx_to = MinValue(entry_idx + entries_per_task, entry_count); + finalize_tasks.push_back( + make_uniq(shared_from_this(), context, sink, entry_idx, entry_idx_to, sink.op)); } - SetTasks(std::move(finalize_tasks)); + return finalize_tasks; + } + + void ExecuteDirectly() { + ExecuteHashJoinTableInitTask(sink, 0U, sink.hash_table->capacity); + } + + void Schedule() override { + SetTasks(GetTasks()); } static constexpr const idx_t MINIMUM_ENTRIES_PER_TASK = 131072; }; @@ -689,17 +846,7 @@ class HashJoinFinalizeTask : public ExecutorTask { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - const auto &data_collection = sink.hash_table->GetDataCollection(); - if (!partition_idx.IsValid()) { - // Single-threaded finalize - sink.hash_table->Finalize(0U, data_collection.ChunkCount(), false, prefix_range_state); - } else { - // Parallel finalize - each thread processes one partition - const auto chunk_ranges = data_collection.GetChunkRangesForPartition(partition_idx.GetIndex()); - for (auto &chunk_range : chunk_ranges) { - sink.hash_table->Finalize(chunk_range.first, chunk_range.second, true, prefix_range_state); - } - } + ExecuteHashJoinFinalizeTask(sink, partition_idx, prefix_range_state); event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; } @@ -722,40 +869,57 @@ class HashJoinFinalizeEvent : public BasePipelineEvent { HashJoinGlobalSinkState &sink; public: - void Schedule() override { + vector> GetTasks() { auto &ht = *sink.hash_table; const auto 
build_prefix_range_filter = ht.ShouldBuildPrefixRangeFilter(); - const bool finalize_single_threaded = FinalizeSingleThreaded(sink, false); - vector> finalize_tasks; - if (finalize_single_threaded) { - // Single-threaded finalize + if (FinalizeSingleThreaded(sink, false)) { auto prefix_range_state = build_prefix_range_filter ? RegisterPrefixRangeState(ht) : nullptr; finalize_tasks.push_back( make_uniq(sink, shared_from_this(), optional_idx(), prefix_range_state)); - } else { - // Parallel finalize - const auto num_partitions = RadixPartitioning::NumberOfPartitions(ht.GetRadixBits()); - const auto ¤t_partitions = ht.GetCurrentPartitions(); - for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { - if (sink.external && !current_partitions.RowIsValidUnsafe(partition_idx)) { - continue; // Partition is not being built on - } - auto prefix_range_state = build_prefix_range_filter ? RegisterPrefixRangeState(ht) : nullptr; - finalize_tasks.push_back( - make_uniq(sink, shared_from_this(), partition_idx, prefix_range_state)); + return finalize_tasks; + } + + // Parallel finalize + const auto num_partitions = RadixPartitioning::NumberOfPartitions(ht.GetRadixBits()); + const auto ¤t_partitions = ht.GetCurrentPartitions(); + for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { + if (sink.external && !current_partitions.RowIsValidUnsafe(partition_idx)) { + continue; // Partition is not being built on } + auto prefix_range_state = build_prefix_range_filter ? RegisterPrefixRangeState(ht) : nullptr; + finalize_tasks.push_back( + make_uniq(sink, shared_from_this(), partition_idx, prefix_range_state)); } - SetTasks(std::move(finalize_tasks)); + return finalize_tasks; + } + + void ExecuteDirectly() { + auto &ht = *sink.hash_table; + auto prefix_range_state = ht.ShouldBuildPrefixRangeFilter() ? 
RegisterPrefixRangeState(ht) : nullptr; + ExecuteHashJoinFinalizeTask(sink, optional_idx(), prefix_range_state); + FinishTasks(); + } + + void Schedule() override { + SetTasks(GetTasks()); } void FinishEvent() override { + FinishTasks(); + } + + static constexpr idx_t CHUNKS_PER_TASK = 64; + +private: + void FinishTasks() { for (auto &prefix_range_state : prefix_range_states) { sink.hash_table->MergePrefixRangeBuildState(*prefix_range_state); } sink.hash_table->GetDataCollection().VerifyEverythingPinned(); - // chains are final; materialize dict_arrays and overwrite NEXT_PTR with the dict index + // Both finalize paths finish writing the chains before reaching here, + // so dictionary emission is safe on either path if (sink.hash_table->CanUseDictionaryEmission(sink.op, sink.external, sink.op.children[0].get().estimated_cardinality)) { sink.hash_table->BuildDictionaryArrays(sink.op); @@ -764,9 +928,6 @@ class HashJoinFinalizeEvent : public BasePipelineEvent { sink.hash_table->finalized = true; } - static constexpr idx_t CHUNKS_PER_TASK = 64; - -private: optional_ptr RegisterPrefixRangeState(JoinHashTable &ht) { prefix_range_states.push_back(ht.InitializePrefixRangeBuildState()); return *prefix_range_states.back(); @@ -783,6 +944,12 @@ void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) hash_table->AllocatePointerTable(); auto new_init_event = make_shared_ptr(pipeline, *this); + if (FinalizeSingleThreaded(*this, false)) { + new_init_event->ExecuteDirectly(); + auto new_finalize_event = make_shared_ptr(pipeline, *this); + new_finalize_event->ExecuteDirectly(); + return; + } event.InsertEvent(new_init_event); auto new_finalize_event = make_shared_ptr(pipeline, *this); @@ -821,13 +988,13 @@ class HashJoinRepartitionTask : public ExecutorTask { class HashJoinRepartitionEvent : public BasePipelineEvent { public: HashJoinRepartitionEvent(Pipeline &pipeline_p, const PhysicalHashJoin &op_p, HashJoinGlobalSinkState &sink, - vector> &local_hts) + 
vector> &local_hts) : BasePipelineEvent(pipeline_p), op(op_p), sink(sink), local_hts(local_hts) { } const PhysicalHashJoin &op; HashJoinGlobalSinkState &sink; - vector> &local_hts; + vector> &local_hts; public: void Schedule() override { @@ -837,7 +1004,7 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { idx_t total_size = 0; idx_t total_count = 0; for (auto &local_ht : local_hts) { - auto &sink_collection = local_ht->GetSinkCollection(); + auto &sink_collection = local_ht.get().GetSinkCollection(); total_size += sink_collection.SizeInBytes(); total_count += sink_collection.Count(); } @@ -854,9 +1021,11 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { if (repartition_threads < local_hts.size()) { // Limit the number of threads working on repartitioning based on our memory reservation for (idx_t thread_idx = repartition_threads; thread_idx < local_hts.size(); thread_idx++) { - local_hts[thread_idx % repartition_threads]->Merge(*local_hts[thread_idx]); + local_hts[thread_idx % repartition_threads].get().Merge(local_hts[thread_idx].get()); } - local_hts.resize(repartition_threads); + auto repartition_end = + local_hts.begin() + static_cast>::difference_type>(repartition_threads); + local_hts.erase(repartition_end, local_hts.end()); } auto &context = pipeline->GetClientContext(); @@ -865,13 +1034,14 @@ class HashJoinRepartitionEvent : public BasePipelineEvent { partition_tasks.reserve(local_hts.size()); for (auto &local_ht : local_hts) { partition_tasks.push_back( - make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht, op)); + make_uniq(shared_from_this(), context, *sink.hash_table, local_ht.get(), op)); } SetTasks(std::move(partition_tasks)); } void FinishEvent() override { local_hts.clear(); + sink.owned_local_hash_tables.clear(); // Minimum reservation is now the new smallest partition size const auto num_partitions = RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits()); @@ -937,7 +1107,8 @@ void 
JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, // the IN-list is expensive to execute otherwise auto in_expr = ExpressionFilter::CreateInExpression( make_uniq(info.columns[filter_idx].storage_type, idx_t(0)), std::move(in_list)); - auto filter = make_uniq(make_uniq(std::move(in_expr))); + auto filter = make_uniq( + CreateOptionalFilterExpression(std::move(in_expr), info.columns[filter_idx].storage_type)); info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(filter)); } @@ -968,8 +1139,12 @@ bool JoinFilterPushdownInfo::CanUseBloomFilter(const ClientContext &context, con // building the bloom filter is costly on the build to make probing faster, // so only use it if there are less build tuples than probing tuples static constexpr double BUILD_TO_PROBE_RATIO_THRESHOLD = 1.0; + auto build_count = ht->Count(); + if (build_count == 0) { + build_count = ht->GetSinkCollection().Count(); + } const double build_to_probe_ratio = - static_cast(ht->Count()) / static_cast(op.children[0].get().estimated_cardinality); + static_cast(build_count) / static_cast(op.children[0].get().estimated_cardinality); if (build_to_probe_ratio > BUILD_TO_PROBE_RATIO_THRESHOLD) { return false; } @@ -978,7 +1153,7 @@ bool JoinFilterPushdownInfo::CanUseBloomFilter(const ClientContext &context, con static constexpr double NON_FILTERING_RATIO_THRESHOLD = 0.1; static constexpr idx_t NON_FILTERING_BUILD_SIDE_THRESHOLD = 4194304; if (!build_side_has_filter && build_to_probe_ratio > NON_FILTERING_RATIO_THRESHOLD && - ht->Count() > NON_FILTERING_BUILD_SIDE_THRESHOLD) { + build_count > NON_FILTERING_BUILD_SIDE_THRESHOLD) { return false; } @@ -991,6 +1166,9 @@ bool JoinFilterPushdownInfo::CanUsePrefixRangeFilter(ClientContext &context, opt if (!CanUseBloomFilter(context, op, cmp, ht)) { return false; } + if (ht->Count() == 0) { + return false; + } if (ht->NullValuesAreEqual(0)) { // TODO: Support "A is B" type joins @@ -1018,39 +1196,81 @@ bool 
JoinFilterPushdownInfo::CanUsePrefixRangeFilter(ClientContext &context, opt } const auto &key_type = ht->conditions[0].GetLHS().GetReturnType(); - return PrefixRangeTableFilter::SupportedType(key_type); + return PrefixRangeFilter::SupportedType(key_type); +} + +static unique_ptr CreateSelectivityOptionalExpressionFilter(unique_ptr child_expr, + const LogicalType &column_type, + SelectivityOptionalFilterType type); + +static LogicalType GetRuntimeFilterInputType(const JoinFilterPushdownColumn &column, const LogicalType &runtime_type) { + return column.runtime_filter_type.IsValid() ? column.runtime_filter_type : runtime_type; +} + +static unique_ptr CreateRuntimeFilterInputExpression(ClientContext &context, + const JoinFilterPushdownColumn &column, + const LogicalType &runtime_type) { + D_ASSERT(column.storage_type.IsValid()); + auto input_type = GetRuntimeFilterInputType(column, runtime_type); + unique_ptr input = make_uniq(column.storage_type, idx_t(0)); + if (column.storage_type != input_type) { + input = BoundCastExpression::AddCastToType(context, std::move(input), input_type); + } + return input; } -void JoinFilterPushdownInfo::PushBloomFilter(const PhysicalOperator &op, JoinHashTable &ht, - const JoinFilterPushdownFilter &info, +void JoinFilterPushdownInfo::PushBloomFilter(ClientContext &context, const PhysicalOperator &op, JoinHashTable &ht, + const JoinFilterPushdownFilter &info, idx_t filter_idx, ProjectionIndex filter_col_idx) const { // If the nulls are equal, we let nulls pass. 
If not, we filter them auto filters_null_values = !ht.NullValuesAreEqual(0); const auto key_name = ht.conditions[0].GetRHS().ToString(); const auto key_type = ht.conditions[0].GetLHS().GetReturnType(); + auto filter_input_type = GetRuntimeFilterInputType(info.columns[filter_idx], key_type); ht.SetBuildBloomFilter(true); - auto filter = - make_uniq_base(ht.GetBloomFilter(), filters_null_values, key_name, key_type); - filter = make_uniq(std::move(filter), SelectivityOptionalFilterType::BF); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(filter)); + float selectivity_threshold; + idx_t n_vectors_to_check; + GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType::BF, selectivity_threshold, n_vectors_to_check); + vector> children; + children.push_back(CreateRuntimeFilterInputExpression(context, info.columns[filter_idx], key_type)); + auto filter_expr = make_uniq( + BoundScalarFunction(BloomFilterScalarFun::GetFunction(filter_input_type)), std::move(children), + make_uniq(ht.GetBloomFilter(), filters_null_values, key_name, key_type, + selectivity_threshold, n_vectors_to_check)); + info.dynamic_filters->PushFilter(op, filter_col_idx, + CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), + info.columns[filter_idx].storage_type, + SelectivityOptionalFilterType::BF)); } -void JoinFilterPushdownInfo::PushPerfectHashJoinFilter(const PhysicalOperator &op, +void JoinFilterPushdownInfo::PushPerfectHashJoinFilter(ClientContext &context, const PhysicalOperator &op, PerfectHashJoinExecutor &perfect_join_executor, - const JoinFilterPushdownFilter &info, + const JoinFilterPushdownFilter &info, idx_t filter_idx, ProjectionIndex filter_col_idx) const { const auto key_name = op.Cast().conditions[0].GetRHS().ToString(); - auto filter = make_uniq_base(perfect_join_executor, key_name, - perfect_join_executor.GetKeyType()); - filter = make_uniq(std::move(filter), SelectivityOptionalFilterType::PHJ); - info.dynamic_filters->PushFilter(op, filter_col_idx, 
std::move(filter)); + const auto &key_type = perfect_join_executor.GetKeyType(); + auto filter_input_type = GetRuntimeFilterInputType(info.columns[filter_idx], key_type); + float selectivity_threshold; + idx_t n_vectors_to_check; + GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType::PHJ, selectivity_threshold, n_vectors_to_check); + vector> children; + children.push_back(CreateRuntimeFilterInputExpression(context, info.columns[filter_idx], key_type)); + auto filter_expr = make_uniq( + BoundScalarFunction(PerfectHashJoinScalarFun::GetFunction(filter_input_type)), std::move(children), + make_uniq(perfect_join_executor, key_name, selectivity_threshold, + n_vectors_to_check)); + info.dynamic_filters->PushFilter(op, filter_col_idx, + CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), + info.columns[filter_idx].storage_type, + SelectivityOptionalFilterType::PHJ)); } void JoinFilterPushdownInfo::RegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, - JoinHashTable &ht, const PhysicalOperator &op, + JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx, const Value &min_val, const Value &max_val) const { const auto key_type = ht.conditions[0].GetLHS().GetReturnType(); + auto filter_input_type = GetRuntimeFilterInputType(info.columns[filter_idx], key_type); if (!ht.GetPrefixRangeFilter()) { auto prefix_filter = PrefixRangeFilter::CreatePrefixRangeFilter(key_type); prefix_filter->Initialize(context, ht.Count(), min_val, max_val); @@ -1059,10 +1279,19 @@ void JoinFilterPushdownInfo::RegisterPrefixRangeFilter(const JoinFilterPushdownF } const auto key_name = ht.conditions[0].GetRHS().ToString(); - auto rf_filter = make_uniq(ht.GetPrefixRangeFilter(), key_name, key_type); - - auto opt_rf_filter = make_uniq(std::move(rf_filter), SelectivityOptionalFilterType::PRF); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(opt_rf_filter)); + float selectivity_threshold; + 
idx_t n_vectors_to_check; + GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType::PRF, selectivity_threshold, n_vectors_to_check); + vector> children; + children.push_back(CreateRuntimeFilterInputExpression(context, info.columns[filter_idx], key_type)); + auto filter_expr = make_uniq( + BoundScalarFunction(PrefixRangeScalarFun::GetFunction(filter_input_type)), std::move(children), + make_uniq(ht.GetPrefixRangeFilter(), key_name, key_type, selectivity_threshold, + n_vectors_to_check)); + info.dynamic_filters->PushFilter(op, filter_col_idx, + CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), + info.columns[filter_idx].storage_type, + SelectivityOptionalFilterType::PRF)); } unique_ptr JoinFilterPushdownInfo::FinalizeMinMax(JoinFilterGlobalState &gstate) const { @@ -1078,23 +1307,33 @@ unique_ptr JoinFilterPushdownInfo::FinalizeMinMax(JoinFilterGlobalSta return final_min_max; } +static unique_ptr CreateSelectivityOptionalExpressionFilter(unique_ptr child_expr, + const LogicalType &column_type, + SelectivityOptionalFilterType type) { + float selectivity_threshold; + idx_t n_vectors_to_check; + GetThresholdAndVectorsToCheck(type, selectivity_threshold, n_vectors_to_check); + return make_uniq(CreateSelectivityOptionalFilterExpression( + std::move(child_expr), column_type, selectivity_threshold, n_vectors_to_check)); +} + static void CreateDynamicMinMaxFilter(const PhysicalComparisonJoin &op, const JoinFilterPushdownFilter &info, - const ProjectionIndex &filter_col_idx, unique_ptr filter) { - info.dynamic_filters->PushFilter( - op, filter_col_idx, - make_uniq(std::move(filter), SelectivityOptionalFilterType::MIN_MAX)); + const ProjectionIndex &filter_col_idx, unique_ptr filter_expr, + const LogicalType &column_type) { + info.dynamic_filters->PushFilter(op, filter_col_idx, + CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), column_type, + SelectivityOptionalFilterType::MIN_MAX)); } -static unique_ptr 
CreateComparisonExpressionFilter(ExpressionType comparison_type, const Value &constant, - const LogicalType &column_type) { +static unique_ptr CreateComparisonExpressionFilter(ExpressionType comparison_type, const Value &constant, + const LogicalType &column_type) { auto constant_value = constant; if (!constant_value.IsNull()) { constant_value.DefaultTryCastAs(column_type); } auto column = make_uniq(column_type, 0ULL); - auto filter_expr = BoundComparisonExpression::Create(comparison_type, std::move(column), - make_uniq(std::move(constant_value))); - return make_uniq(std::move(filter_expr)); + return BoundComparisonExpression::Create(comparison_type, std::move(column), + make_uniq(std::move(constant_value))); } unique_ptr @@ -1137,7 +1376,14 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo continue; } - auto condition_type = pushdown_column.storage_type; + auto condition_type = min_val.type(); + auto runtime_filter_input_type = GetRuntimeFilterInputType(pushdown_column, condition_type); + bool can_emit_runtime_filters = pushdown_column.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION; + if (can_emit_runtime_filters && perfect_join_executor) { + can_emit_runtime_filters = runtime_filter_input_type == perfect_join_executor->GetKeyType(); + } else if (can_emit_runtime_filters && ht) { + can_emit_runtime_filters = runtime_filter_input_type == ht->conditions[0].GetLHS().GetReturnType(); + } // if the HT is small we can generate a complete "OR" filter // but only if the join condition is equality. @@ -1148,8 +1394,9 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo // min = max - single value // generate a "one-sided" comparison filter for the LHS // Note that this also works for equalities. 
- info.dynamic_filters->PushFilter(op, filter_col_idx, - CreateComparisonExpressionFilter(cmp, min_val, condition_type)); + info.dynamic_filters->PushFilter( + op, filter_col_idx, + make_uniq(CreateComparisonExpressionFilter(cmp, min_val, condition_type))); } else { // min != max - generate a range filter or bloom filter + optional range filter // for non-equalities, the range must be half-open @@ -1161,7 +1408,8 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo CreateDynamicMinMaxFilter( op, info, filter_col_idx, CreateComparisonExpressionFilter(ExpressionType::COMPARE_GREATERTHANOREQUALTO, min_val, - condition_type)); + condition_type), + condition_type); break; } default: @@ -1173,21 +1421,23 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo case ExpressionType::COMPARE_LESSTHANOREQUALTO: { CreateDynamicMinMaxFilter(op, info, filter_col_idx, CreateComparisonExpressionFilter( - ExpressionType::COMPARE_LESSTHANOREQUALTO, max_val, condition_type)); + ExpressionType::COMPARE_LESSTHANOREQUALTO, max_val, condition_type), + condition_type); break; } default: break; } - if (perfect_join_executor) { - PushPerfectHashJoinFilter(op, *perfect_join_executor, info, filter_col_idx); - } else if (CanUsePrefixRangeFilter(context, ht, op, cmp, min_val_before_cast, max_val_before_cast)) { + if (can_emit_runtime_filters && perfect_join_executor) { + PushPerfectHashJoinFilter(context, op, *perfect_join_executor, info, filter_idx, filter_col_idx); + } else if (can_emit_runtime_filters && + CanUsePrefixRangeFilter(context, ht, op, cmp, min_val_before_cast, max_val_before_cast)) { // It's important that these get the min/max val before casting - RegisterPrefixRangeFilter(info, context, *ht, op, filter_col_idx, min_val_before_cast, + RegisterPrefixRangeFilter(info, context, *ht, op, filter_idx, filter_col_idx, min_val_before_cast, max_val_before_cast); - } else if (ht && CanUseBloomFilter(context, op, cmp, ht)) { - 
PushBloomFilter(op, *ht, info, filter_col_idx); + } else if (can_emit_runtime_filters && ht && CanUseBloomFilter(context, op, cmp, ht)) { + PushBloomFilter(context, op, *ht, info, filter_idx, filter_col_idx); } } } @@ -1235,6 +1485,10 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl {{"external", to_string(sink.external)}}); if (sink.external) { // External Hash Join + // Recursive preserved-build reuse only applies to the in-memory finalized HT. External hash join + // runs through repeated build/probe rounds with partition/probe-spill state, so it cannot be + // treated as a stable materialized build that later recursive iterations may skip entirely. + sink.preserve_build_for_reuse = false; sink.perfect_join_executor.reset(); const auto max_partition_ht_size = sink.max_partition_size + ht.PointerTableSize(sink.max_partition_count); @@ -1254,9 +1508,14 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl } else { // No repartitioning! 
for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); + ht.Merge(local_ht.get()); } sink.local_hash_tables.clear(); + sink.owned_local_hash_tables.clear(); + if (filter_pushdown && !sink.skip_filter_pushdown && ht.GetSinkCollection().Count() > 0) { + auto filter_min_max = filter_pushdown->FinalizeMinMax(*sink.global_filter_state); + filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, nullptr); + } D_ASSERT(sink.temporary_memory_state->GetReservation() >= sink.probe_side_requirement); sink.hash_table->PrepareExternalFinalize(sink.temporary_memory_state->GetReservation() - sink.probe_side_requirement); @@ -1268,9 +1527,10 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl // In-memory Hash Join for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); + ht.Merge(local_ht.get()); } sink.local_hash_tables.clear(); + sink.owned_local_hash_tables.clear(); ht.Unpartition(); Value min; @@ -1300,6 +1560,9 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl if (filter_min_max) { filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, sink.perfect_join_executor); + if (!use_perfect_hash) { + ht.PrepareBloomFilterForFinalize(); + } } // In case of a large build side or duplicates, use regular hash join @@ -1316,12 +1579,40 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl //===--------------------------------------------------------------------===// // Operator //===--------------------------------------------------------------------===// +//! Structural gate for the compressed-vector probe paths; evaluated once per operator state. +//! Decides whether to allocate ProbeState::dict_state; Probe() dispatches into the fast paths +//! whenever that allocation is present. 
+static bool CanUseCompressedProbe(const HashJoinGlobalSinkState &sink, const vector &conditions) { + if (sink.external) { + // external joins re-finalize the HT mid-probe, invalidating the per-id pointer cache + return false; + } + if (sink.perfect_join_executor) { + return false; + } + if (conditions.size() != 1) { + return false; + } + const auto cmp = conditions[0].GetComparisonType(); + if (cmp != ExpressionType::COMPARE_EQUAL && cmp != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + return false; + } + if (conditions[0].GetLHS().GetExpressionClass() != ExpressionClass::BOUND_REF) { + return false; + } + if (conditions[0].GetLHS().GetReturnType().InternalType() == PhysicalType::STRUCT) { + return false; + } + return true; +} + class HashJoinOperatorState : public CachingOperatorState { public: - explicit HashJoinOperatorState(ClientContext &context, HashJoinGlobalSinkState &sink) - : probe_executor(context), scan_structure(*sink.hash_table, join_key_state) { + HashJoinOperatorState(ClientContext &context, const PhysicalHashJoin &op_p, HashJoinGlobalSinkState &sink) + : op(op_p), probe_executor(context), scan_structure(*sink.hash_table, join_key_state) { } + const PhysicalHashJoin &op; DataChunk lhs_join_keys; TupleDataChunkState join_key_state; DataChunk lhs_probe_data; @@ -1339,12 +1630,37 @@ class HashJoinOperatorState : public CachingOperatorState { void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); } + + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + auto &sink = op.sink_state->Cast(); + ResetCachingState(); + lhs_join_keys.Reset(); + lhs_probe_data.Reset(); + scan_structure.Reset(); + perfect_hash_join_state.reset(); + spill_state = JoinHashTable::ProbeSpillLocalAppendState(); + TupleDataCollection::InitializeChunkState(join_key_state, op.condition_types); + if (spill_chunk.ColumnCount() == 0) { + 
spill_chunk.Initialize(BufferAllocator::Get(sink.context), sink.probe_types); + } else { + spill_chunk.Reset(); + } + // discard cached build-side pointers; the HT may have been reset between iterations (e.g. recursive CTE) + if (probe_state.dict_state) { + probe_state.dict_state = make_uniq(); + } + // perfect_hash_join_state will be lazily initialized on first Execute when we have a real ExecutionContext + } }; unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { auto &allocator = BufferAllocator::Get(context.client); auto &sink = sink_state->Cast(); - auto state = make_uniq(context.client, sink); + auto state = make_uniq(context.client, *this, sink); state->lhs_join_keys.Initialize(allocator, condition_types); // initialize probe data with ALL probe columns (output + predicate) @@ -1352,13 +1668,13 @@ unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &c state->lhs_probe_data.Initialize(allocator, lhs_probe_columns.col_types); } + for (auto &cond : conditions) { + state->probe_executor.AddExpression(cond.GetLHS()); + } + TupleDataCollection::InitializeChunkState(state->join_key_state, condition_types); + if (sink.perfect_join_executor) { state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); - } else { - for (auto &cond : conditions) { - state->probe_executor.AddExpression(cond.GetLHS()); - } - TupleDataCollection::InitializeChunkState(state->join_key_state, condition_types); } if (sink.external) { @@ -1366,6 +1682,11 @@ unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &c sink.InitializeProbeSpill(); } + if (CanUseCompressedProbe(sink, conditions)) { + // non-null dict_state is the enabling signal for the compressed-vector probe paths in Probe() + state->probe_state.dict_state = make_uniq(); + } + return std::move(state); } @@ -1388,6 +1709,10 @@ OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, if (sink.perfect_join_executor) { 
D_ASSERT(!sink.external); + // Lazily initialize perfect_hash_join_state on first use if needed + if (!state.perfect_hash_join_state) { + state.perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); + } // for perfect hash join, when predicate is NULL, only output columns are needed state.lhs_probe_data.ReferenceColumns(input, lhs_output_columns.col_idxs); return sink.perfect_join_executor->ProbePerfectHashTable(context, input, state.lhs_probe_data, chunk, @@ -1438,7 +1763,7 @@ class HashJoinLocalSourceState; class HashJoinGlobalSourceState : public GlobalSourceState { public: - HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context); + HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context); //! Initialize this source state using the info in the sink void Initialize(HashJoinGlobalSinkState &sink); @@ -1466,6 +1791,10 @@ class HashJoinGlobalSourceState : public GlobalSourceState { return count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count); } + bool SupportsReuse() const override { + return true; + } + public: const PhysicalHashJoin &op; @@ -1493,11 +1822,36 @@ class HashJoinGlobalSourceState : public GlobalSourceState { idx_t full_outer_chunks_per_thread = DConstants::INVALID_INDEX; vector blocked_tasks; + +private: + void ResetState(ClientContext &context) { + global_stage = HashJoinSourceStage::INIT; + build_chunk_idx = DConstants::INVALID_INDEX; + build_chunk_count = 0; + build_chunk_done = 0; + build_chunks_per_thread = DConstants::INVALID_INDEX; + probe_chunk_count = 0; + probe_chunk_done = 0; + probe_count = op.children[0].get().estimated_cardinality; + parallel_scan_chunk_count = context.config.verify_parallelism ? 
1 : 120; + full_outer_chunk_idx = DConstants::INVALID_INDEX; + full_outer_chunk_count = 0; + full_outer_chunk_done = 0; + full_outer_chunks_per_thread = DConstants::INVALID_INDEX; + blocked_tasks.clear(); + GlobalSourceState::Reset(context); + } + +public: + void Reset(ClientContext &context) override { + ResetState(context); + } }; class HashJoinLocalSourceState : public LocalSourceState { public: - HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, Allocator &allocator); + HashJoinLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate, const PhysicalHashJoin &op, + const HashJoinGlobalSinkState &sink, Allocator &allocator); //! Do the work this thread has been assigned void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); @@ -1509,6 +1863,7 @@ class HashJoinLocalSourceState : public LocalSourceState { void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); public: + const PhysicalHashJoin &op; //! The stage that this thread was assigned work for HashJoinSourceStage local_stage; //! 
Vector with pointers here so we don't have to re-initialize @@ -1536,6 +1891,35 @@ class HashJoinLocalSourceState : public LocalSourceState { idx_t full_outer_chunk_idx_from = DConstants::INVALID_INDEX; idx_t full_outer_chunk_idx_to = DConstants::INVALID_INDEX; unique_ptr full_outer_scan_state; + +private: + void ResetState() { + local_stage = HashJoinSourceStage::INIT; + build_chunk_idx_from = DConstants::INVALID_INDEX; + build_chunk_idx_to = DConstants::INVALID_INDEX; + probe_local_scan.allocator = nullptr; + probe_local_scan.current_chunk_state.handles.clear(); + probe_local_scan.current_chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; + probe_local_scan.chunk_index = DConstants::INVALID_INDEX; + lhs_probe_chunk.Reset(); + lhs_join_keys.Reset(); + lhs_probe_data.Reset(); + TupleDataCollection::InitializeChunkState(join_key_state, op.condition_types); + scan_structure.Reset(); + empty_ht_probe_in_progress = false; + full_outer_chunk_idx_from = DConstants::INVALID_INDEX; + full_outer_chunk_idx_to = DConstants::INVALID_INDEX; + full_outer_scan_state.reset(); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSourceState &gstate_p) override { + ResetState(); + } }; unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const { @@ -1544,15 +1928,12 @@ unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientConte unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(*this, sink_state->Cast(), + return make_uniq(context, gstate, *this, sink_state->Cast(), BufferAllocator::Get(context.client)); } -HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context) - : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), - probe_chunk_done(0), 
probe_count(op.children[0].get().estimated_cardinality), - parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120), full_outer_chunk_count(0), - full_outer_chunk_done(0) { +HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context) : op(op) { + ResetState(context); } void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { @@ -1715,13 +2096,11 @@ bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJo return false; } -HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, +HashJoinLocalSourceState::HashJoinLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate, + const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, Allocator &allocator) - : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER), lhs_join_key_executor(sink.context), + : op(op), addresses(LogicalType::POINTER), lhs_join_key_executor(sink.context), scan_structure(*sink.hash_table, join_key_state) { - auto &chunk_state = probe_local_scan.current_chunk_state; - chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - lhs_probe_chunk.Initialize(allocator, sink.probe_types); lhs_join_keys.Initialize(allocator, op.condition_types); @@ -1733,6 +2112,7 @@ HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, c for (auto &cond : op.conditions) { lhs_join_key_executor.AddExpression(cond.GetLHS()); } + ResetState(); } void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, @@ -1850,8 +2230,12 @@ SourceResultType PhysicalHashJoin::GetDataInternal(ExecutionContext &context, Da annotated_lock_guard guard(gstate.lock); if (gstate.global_stage != HashJoinSourceStage::DONE) { gstate.global_stage = HashJoinSourceStage::DONE; - sink.hash_table->Reset(); - sink.temporary_memory_state->SetZero(); + if 
(sink.preserve_build_for_reuse) { + sink.scanned_data = false; + } else { + sink.hash_table->Reset(); + sink.temporary_memory_state->SetZero(); + } } return SourceResultType::FINISHED; } @@ -1925,9 +2309,9 @@ InsertionOrderPreservingMap PhysicalHashJoin::ParamsToString() const { if (i > 0) { condition_info += "\n"; } - condition_info += StringUtil::Format("%s %s %s", join_condition.GetLHS().GetName(), + condition_info += StringUtil::Format("%s %s %s", join_condition.GetLHS().GetName().GetIdentifierName(), ExpressionTypeToOperator(join_condition.GetComparisonType()), - join_condition.GetRHS().GetName()); + join_condition.GetRHS().GetName().GetIdentifierName()); } if (predicate) { diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp index 875d625cf..347138baf 100644 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -374,7 +374,7 @@ class IEJoinCursor { //! Read a typed cell const T &operator[](idx_t row_idx) { auto index = Seek(row_idx); - auto &source = chunk.data[0]; + const auto &source = chunk.data[0]; const auto data_ptr = reinterpret_cast(FlatVector::GetData(source)); return data_ptr[index]; } @@ -614,7 +614,6 @@ idx_t IEJoinUnion::AppendKey(ExecutionContext &context, InterruptState &interrup // Mark the rid column payload.data[0].Sequence(rid, increment, scan_count); - payload.SetCardinality(scan_count); keys.Fuse(payload); rid += increment * UnsafeNumericCast(scan_count); @@ -1053,20 +1052,28 @@ class IEJoinLocalSourceState : public LocalSourceState { void SplitPayloads(DataChunk &chunk); // Apply any arbitrary predicate and return the new SV const SelectionVector *ApplyArbitraryPredicate(DataChunk &chunk); + // resolve joins that can potentially output N elements (SEMI, ANTI, MARK) + void ResolveSimpleJoin(ExecutionContext &context); // Use the given SV to strip out LHS duplicates. 
Return the number of unique rows idx_t FilterSemiJoin(const SelectionVector *sel); + // Build a set of values and matches fror ANTI and MARK joins + idx_t FindSimpleMatches(ExecutionContext &context, bool *found_match = nullptr); // Resolve SEMI joins void ResolveSemiJoin(ExecutionContext &context, DataChunk &result); // Resolve ANTI joins void ResolveAntiJoin(ExecutionContext &context, DataChunk &result); + // Resolve MARK joins + void ResolveMarkJoin(ExecutionContext &context, DataChunk &result); // resolve joins that can potentially output N*M elements (INNER, LEFT, RIGHT, FULL) void ResolveComplexJoin(ExecutionContext &context, DataChunk &result); // Resolve left join results void ExecuteLeftTask(ExecutionContext &context, DataChunk &result); // Resolve right join results void ExecuteRightTask(ExecutionContext &context, DataChunk &result); - // Resolve anti join results from NULL columns. + // Resolve ANTI join results from NULL columns. void ExecuteAntiTask(ExecutionContext &context, DataChunk &result); + // Resolve MARK join results from NULL columns. 
+ void ExecuteMarkTask(ExecutionContext &context, DataChunk &result); // Execute the current task void ExecuteTask(ExecutionContext &context, DataChunk &result, InterruptState &interrupt); @@ -1314,7 +1321,11 @@ void IEJoinLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &r } break; case IEJoinSourceStage::ANTI: - ExecuteAntiTask(context, result); + if (gsource.op.join_type == JoinType::ANTI) { + ExecuteAntiTask(context, result); + } else { + ExecuteMarkTask(context, result); + } break; } } @@ -1328,6 +1339,9 @@ void IEJoinLocalSourceState::ResolveInnerJoin(ExecutionContext &context, DataChu case JoinType::ANTI: ResolveAntiJoin(context, result); break; + case JoinType::MARK: + ResolveMarkJoin(context, result); + break; case JoinType::LEFT: case JoinType::INNER: case JoinType::RIGHT: @@ -1399,7 +1413,7 @@ void IEJoinLocalSourceState::MergePayloads(DataChunk &chunk) { chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); } } - chunk.SetCardinality(lpayload.size()); + chunk.SetChildCardinality(lpayload.size()); } void IEJoinLocalSourceState::SplitPayloads(DataChunk &chunk) { @@ -1415,8 +1429,6 @@ void IEJoinLocalSourceState::SplitPayloads(DataChunk &chunk) { rpayload.data[col_idx - left_cols].Reference(chunk.data[col_idx - left_cols]); } } - lpayload.SetCardinality(chunk.size()); - rpayload.SetCardinality(chunk.size()); } const SelectionVector *IEJoinLocalSourceState::ApplyArbitraryPredicate(DataChunk &chunk) { @@ -1459,7 +1471,7 @@ idx_t IEJoinLocalSourceState::FilterSemiJoin(const SelectionVector *sel) { return unique_count; } -void IEJoinLocalSourceState::ResolveSemiJoin(ExecutionContext &context, DataChunk &result) { +void IEJoinLocalSourceState::ResolveSimpleJoin(ExecutionContext &context) { auto &op = gsource.op; const auto &conditions = op.conditions; D_ASSERT(conditions.size() >= 2); @@ -1467,6 +1479,8 @@ void IEJoinLocalSourceState::ResolveSemiJoin(ExecutionContext &context, DataChun auto &ie_sink = op.sink_state->Cast(); auto 
&left_table = *ie_sink.tables[0]; + lpayload.Reset(); + do { auto result_count = joiner->JoinBlocks(lsel, rsel); if (result_count == 0) { @@ -1533,7 +1547,6 @@ void IEJoinLocalSourceState::ResolveSemiJoin(ExecutionContext &context, DataChun left_table.Repin(*left_iterator); op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, lsel, *left_scan_state); - lpayload.SetCardinality(result_count); } // Handle chunk boundaries: If we found a match for the last value @@ -1542,11 +1555,63 @@ void IEJoinLocalSourceState::ResolveSemiJoin(ExecutionContext &context, DataChun if (joiner->lrid > 0 && lsel[result_count - 1] == UnsafeNumericCast(joiner->lrid - 1)) { joiner->FinishRow(); } + } while (lpayload.size() == 0); +} - // SEMI JOINs return all the columns from the LHS - result.Reference(lpayload); - result.Verify(context.client.db); - } while (result.size() == 0); +void IEJoinLocalSourceState::ResolveSemiJoin(ExecutionContext &context, DataChunk &result) { + // SEMI JOINs return all the columns from the LHS + ResolveSimpleJoin(context); + result.Reference(lpayload); + result.Verify(context.client.db); +} + +idx_t IEJoinLocalSourceState::FindSimpleMatches(ExecutionContext &context, bool *found_match) { + // Merge lsel and unmatched LHS rids into (the unused) rsel, tracking the matches + auto &msel = rsel; + idx_t result_count = 0; + msel.resize(STANDARD_VECTOR_SIZE); + + // Scan through Li, merging anything in lsel. 
+ // They are both ordered in the same way, so we can ratchet + const auto i = MinValue(joiner->i, joiner->n); + auto &anti_i = joiner->anti_i; + auto &p = joiner->p; + auto &li = joiner->li; + + // Scan through the SEMI JOIN rids + // Note that if there are no more matches, then lsel will be empty() + // and i will be at the end of the range + while (anti_i < i) { + // Get the next lrid + auto pos = p[anti_i++]; + auto rid = li[pos]; + if (rid <= 0) { + continue; + } + const auto lrid = UnsafeNumericCast(rid - 1); + + // If the lrid is in the SEMI JOIN, then merge it and mark the row as matched. + if (anti_lsel < lsel.size() && lsel[anti_lsel] == lrid) { + ++anti_lsel; + // If we are not tracking matches, then only keep the unmatched rows + if (!found_match) { + continue; + } + found_match[result_count] = true; + } else if (found_match) { + // Found an unmatched rid, so mark it as unmatched + found_match[result_count] = false; + } + msel[result_count++] = lrid; + + // If we have a full vector, stop + if (result_count >= STANDARD_VECTOR_SIZE) { + break; + } + } + msel.resize(result_count); + + return result_count; } void IEJoinLocalSourceState::ResolveAntiJoin(ExecutionContext &context, DataChunk &result) { @@ -1560,49 +1625,12 @@ void IEJoinLocalSourceState::ResolveAntiJoin(ExecutionContext &context, DataChun // when we run out of anti-matches do { if (anti_lsel >= lsel.size()) { - ResolveSemiJoin(context, result); - result.Reset(); + ResolveSimpleJoin(context); anti_lsel = 0; } - // Scan through Li, skipping anything in lsel. 
- // They are both ordered in the same way, so we can ratchet - const auto i = MinValue(joiner->i, joiner->n); - auto &anti_i = joiner->anti_i; - auto &p = joiner->p; - auto &li = joiner->li; - - // Put the results in rsel (as that is no longer needed) - idx_t result_count = 0; - rsel.resize(STANDARD_VECTOR_SIZE); - - // Scan through the SEMI JOIN rids - // Note that if there are no more matches, then lsel will be empty() - // and i will be at the end of the range - while (anti_i < i) { - // Get the next lrid - auto pos = p[anti_i++]; - auto rid = li[pos]; - if (rid <= 0) { - continue; - } - const auto lrid = UnsafeNumericCast(rid - 1); - - // If the lrid is in the SEMI JOIN, then skip it - if (anti_lsel < lsel.size() && lsel[anti_lsel] == lrid) { - ++anti_lsel; - continue; - } - - // Found an unmatched rid, so remember it - rsel[result_count++] = lrid; - - // If we have a full vector, stop - if (result_count >= STANDARD_VECTOR_SIZE) { - break; - } - } - rsel.resize(result_count); + // Put all the non-matches into rsel + const idx_t result_count = FindSimpleMatches(context); if (result_count == 0) { // If there were no SEMI rows and we didn't find any past the last one, @@ -1617,13 +1645,47 @@ void IEJoinLocalSourceState::ResolveAntiJoin(ExecutionContext &context, DataChun left_table.Repin(*left_iterator); op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, rsel, *left_scan_state); - lpayload.SetCardinality(result_count); result.Reference(lpayload); result.Verify(context.client.db); } while (result.size() == 0); } +void IEJoinLocalSourceState::ResolveMarkJoin(ExecutionContext &context, DataChunk &result) { + // MARK JOINs return all the columns from the LHS plus the mark column + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + auto &left_table = *ie_sink.tables[0]; + auto &right_table = *ie_sink.tables[1]; + + // We need to have _all_ the rows, so we perform the SEMI JOIN, + // then go back and merge in 
the ANTI JOIN rows. + // This means we might not consume all the SEMI rows, so only get more when we have run out. + if (anti_lsel >= lsel.size()) { + ResolveSimpleJoin(context); + anti_lsel = 0; + } + + // Merge lsel and unmatched LHS rids into (the unused) rsel, tracking the matches + bool found_match[STANDARD_VECTOR_SIZE]; + FindSimpleMatches(context, found_match); + + // Read the lhs rows + left_table.Repin(*left_iterator); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, rsel, + *left_scan_state); + + // Compute the residual keys + // We know here that the two sort columns are not NULL, so the tail columns are all we need + left_executor.Execute(&lpayload, left_keys); + + // Now hand it off to code that knows the rules... + // Note that it can handle left_keys.ColumnCount().empty() + PhysicalJoin::ConstructMarkJoinResult(left_keys, lpayload, result, found_match, right_table.has_null); + + result.Verify(context.client.db); +} + void IEJoinLocalSourceState::ResolveComplexJoin(ExecutionContext &context, DataChunk &result) { auto &op = gsource.op; auto &ie_sink = op.sink_state->Cast(); @@ -1733,7 +1795,7 @@ void IEJoinGlobalSourceState::Initialize() { stage_tasks.emplace_back(left_outers + right_outers); // ANTI - if (op.join_type == JoinType::ANTI) { + if (op.join_type == JoinType::ANTI || op.join_type == JoinType::MARK) { auto &left_table = *gsink.tables[0]; const auto null_block = (left_table.count - left_table.has_null) / STANDARD_VECTOR_SIZE; stage_tasks.emplace_back(left_blocks - null_block); @@ -1972,7 +2034,6 @@ void IEJoinLocalSourceState::ExecuteLeftTask(ExecutionContext &context, DataChun } op.ProjectResult(chunk, result); - result.SetCardinality(count); result.Verify(context.client.db); } @@ -2005,7 +2066,6 @@ void IEJoinLocalSourceState::ExecuteRightTask(ExecutionContext &context, DataChu } op.ProjectResult(chunk, result); - result.SetCardinality(count); result.Verify(context.client.db); } @@ -2034,6 
+2094,43 @@ void IEJoinLocalSourceState::ExecuteAntiTask(ExecutionContext &context, DataChun result.Verify(context.client.db); } +void IEJoinLocalSourceState::ExecuteMarkTask(ExecutionContext &context, DataChunk &result) { + auto &op = gsource.op; + auto &ie_sink = op.sink_state->Cast(); + + outer_sel.reserve(STANDARD_VECTOR_SIZE); + outer_sel.resize(0); + while (outer_idx < outer_count) { + outer_sel.emplace_back(outer_idx++); + if (outer_sel.size() >= STANDARD_VECTOR_SIZE) { + break; + } + } + const idx_t count = outer_sel.size(); + if (!count) { + return; + } + + auto &left_table = *ie_sink.tables[0]; + left_table.Repin(*left_iterator); + op.SliceSortedPayload(lpayload, left_table, *left_iterator, left_chunk_state, left_block_index, outer_sel, + *left_scan_state); + + // for the initial set of columns we just reference the left side + result.SetChildCardinality(lpayload.size()); + for (idx_t i = 0; i < lpayload.ColumnCount(); i++) { + result.data[i].Reference(lpayload.data[i]); + } + + // One of the two main keys is NULL, so the result is NULL + auto &mark_vector = result.data.back(); + mark_vector.SetVectorType(VectorType::FLAT_VECTOR); + auto &mask = FlatVector::ValidityMutable(mark_vector); + mask.SetAllInvalid(count); + + result.Verify(context.client.db); +} + //===--------------------------------------------------------------------===// // Pipeline Construction //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp index 55a085dfb..6bc30050d 100644 --- a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp @@ -9,6 +9,29 @@ namespace duckdb { +static void LeftDelimBindScanCollection(const PhysicalLeftDelimJoin &delim_join, ColumnDataCollection &collection) { + auto &cast_cached_scan = 
delim_join.join.children[0].get().Cast(); + cast_cached_scan.collection = &collection; +} + +static void LeftDelimResetSinkState(const PhysicalOperator &op, ClientContext &context, + unique_ptr &state) { + if (state && state->SupportsReuse()) { + state->Reset(context); + return; + } + state = op.GetGlobalSinkState(context); +} + +static void LeftDelimResetLocalSinkState(const PhysicalOperator &op, ExecutionContext &context, GlobalSinkState &gstate, + unique_ptr &state) { + if (state && state->SupportsReuse()) { + state->Reset(context, gstate); + return; + } + state = op.GetLocalSinkState(context); +} + PhysicalLeftDelimJoin::PhysicalLeftDelimJoin(PhysicalPlan &physical_plan, PhysicalPlanGenerator &planner, vector types, PhysicalOperator &original_join, PhysicalOperator &distinct, @@ -37,19 +60,40 @@ PhysicalLeftDelimJoin::PhysicalLeftDelimJoin(PhysicalPlan &physical_plan, Physic //===--------------------------------------------------------------------===// class LeftDelimJoinGlobalState : public GlobalSinkState { public: - explicit LeftDelimJoinGlobalState(ClientContext &context, const PhysicalLeftDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0].get().GetTypes()) { - D_ASSERT(!delim_join.delim_scans.empty()); - // set up the delim join chunk to scan in the original join - auto &cast_cached_scan = delim_join.join.children[0].get().Cast(); - cast_cached_scan.collection = &lhs_data; + explicit LeftDelimJoinGlobalState(ClientContext &context, const PhysicalLeftDelimJoin &delim_join_p) + : delim_join(delim_join_p), lhs_data(context, delim_join_p.children[0].get().GetTypes()) { + D_ASSERT(!delim_join_p.delim_scans.empty()); + LeftDelimBindScanCollection(delim_join_p, lhs_data); } + const PhysicalLeftDelimJoin &delim_join; ColumnDataCollection lhs_data; mutex lhs_lock; + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + lhs_data.ResetForReuse(); + LeftDelimBindScanCollection(delim_join, 
lhs_data); + LeftDelimResetSinkState(delim_join.distinct, context, delim_join.distinct.sink_state); + if (delim_join.delim_scans.size() > 1) { + PhysicalHashAggregate::SetMultiScan(*delim_join.distinct.sink_state); + } + GlobalSinkState::Reset(context); + } + void Merge(ColumnDataCollection &input) { + if (input.Count() == 0) { + return; + } lock_guard guard(lhs_lock); + if (lhs_data.Count() == 0) { + lhs_data.Swap(input); + LeftDelimBindScanCollection(delim_join, lhs_data); + return; + } lhs_data.Combine(input); } }; @@ -57,14 +101,26 @@ class LeftDelimJoinGlobalState : public GlobalSinkState { class LeftDelimJoinLocalState : public LocalSinkState { public: explicit LeftDelimJoinLocalState(ClientContext &context, const PhysicalLeftDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0].get().GetTypes()) { + : delim_join(delim_join), lhs_data(context, delim_join.children[0].get().GetTypes()) { lhs_data.InitializeAppend(append_state); } + const PhysicalLeftDelimJoin &delim_join; + unique_ptr distinct_state; ColumnDataCollection lhs_data; ColumnDataAppendState append_state; + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate) override { + lhs_data.ResetForReuse(); + lhs_data.InitializeAppend(append_state); + LeftDelimResetLocalSinkState(delim_join.distinct, context, *delim_join.distinct.sink_state, distinct_state); + } + void Append(DataChunk &input) { lhs_data.Append(input); } diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp index 8f1cf3841..60ca5845f 100644 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -60,7 +60,6 @@ static void ConstructSemiOrAntiJoinResult(DataChunk &left, DataChunk &result, bo // reference the columns of the left side from the result result.Slice(left, 
sel, result_count); } else { - result.SetCardinality(0); } } @@ -75,7 +74,7 @@ void PhysicalJoin::ConstructAntiJoinResult(DataChunk &left, DataChunk &result, b void PhysicalJoin::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &left, DataChunk &result, bool found_match[], bool has_null) { // for the initial set of columns we just reference the left side - result.SetCardinality(left); + result.SetChildCardinality(left.size()); for (idx_t i = 0; i < left.ColumnCount(); i++) { result.data[i].Reference(left.data[i]); } @@ -149,15 +148,13 @@ bool PhysicalNestedLoopJoin::IsSupported(const vector &conditions class NestedLoopJoinGlobalState : public GlobalSinkState { public: explicit NestedLoopJoinGlobalState(ClientContext &context, const PhysicalNestedLoopJoin &op) - : right_payload_data(context, op.children[1].get().GetTypes()), + : op(op), right_payload_data(context, op.children[1].get().GetTypes()), right_condition_data(context, op.GetJoinTypes()), has_null(false), right_outer(PropagatesBuildSide(op.join_type)) { - if (op.filter_pushdown) { - skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); - global_filter_state = op.filter_pushdown->GetGlobalState(context, op); - } + ResetState(context); } + const PhysicalNestedLoopJoin &op; mutex nj_lock; //! Materialized data of the RHS ColumnDataCollection right_payload_data; @@ -171,31 +168,74 @@ class NestedLoopJoinGlobalState : public GlobalSinkState { bool skip_filter_pushdown = false; //! 
The global filter states to push down (if any) unique_ptr global_filter_state; + +private: + void ResetState(ClientContext &context) { + right_payload_data.Reset(); + right_condition_data.Reset(); + has_null = false; + right_outer.Reset(); + if (op.filter_pushdown) { + skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); + global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + } else { + skip_filter_pushdown = false; + global_filter_state.reset(); + } + GlobalSinkState::Reset(context); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + ResetState(context); + } }; class NestedLoopJoinLocalState : public LocalSinkState { public: - explicit NestedLoopJoinLocalState(ClientContext &context, const PhysicalNestedLoopJoin &op, + explicit NestedLoopJoinLocalState(ExecutionContext &context, const PhysicalNestedLoopJoin &op, NestedLoopJoinGlobalState &gstate) - : rhs_executor(context) { + : op(op), rhs_executor(context.client) { vector condition_types; for (auto &cond : op.conditions) { rhs_executor.AddExpression(cond.GetRHS()); condition_types.push_back(cond.GetRHS().GetReturnType()); } - right_condition.Initialize(Allocator::Get(context), condition_types); - - if (op.filter_pushdown) { - local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); - } + right_condition.Initialize(Allocator::Get(context.client), condition_types); + ResetState(gstate); } + const PhysicalNestedLoopJoin &op; //! The chunk holding the right condition DataChunk right_condition; //! The executor of the RHS condition ExpressionExecutor rhs_executor; //! 
Local state for accumulating filter statistics unique_ptr local_filter_state; + +private: + void ResetState(NestedLoopJoinGlobalState &gstate) { + right_condition.Reset(); + if (op.filter_pushdown) { + local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); + } else { + local_filter_state.reset(); + } + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + ResetState(gstate); + } }; vector PhysicalNestedLoopJoin::GetJoinTypes() const { @@ -268,7 +308,7 @@ unique_ptr PhysicalNestedLoopJoin::GetGlobalSinkState(ClientCon unique_ptr PhysicalNestedLoopJoin::GetLocalSinkState(ExecutionContext &context) const { auto &gstate = sink_state->Cast(); - return make_uniq(context.client, *this, gstate); + return make_uniq(context, *this, gstate); } //===--------------------------------------------------------------------===// @@ -278,8 +318,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { public: PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, const vector &conditions) - : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)), pred_executor(context) { + : lhs_executor(context), left_outer(IsLeftOuterJoin(op.join_type)), pred_executor(context) { vector condition_types; for (auto &cond : conditions) { lhs_executor.AddExpression(cond.GetLHS()); @@ -295,6 +334,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { pred_executor.AddExpression(*op.predicate); pred_matches.Initialize(); } + ResetState(); } bool fetch_next_left; @@ -317,10 +357,33 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { ExpressionExecutor pred_executor; SelectionVector pred_matches; +private: + void ResetState() { + ResetCachingState(); + fetch_next_left = true; + 
fetch_next_right = false; + left_condition.Reset(); + condition_scan_state = ColumnDataScanState(); + payload_scan_state = ColumnDataScanState(); + right_condition.Reset(); + right_payload.Reset(); + left_tuple = 0; + right_tuple = 0; + left_outer.Reset(); + } + public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); } + + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + ResetState(); + } }; unique_ptr PhysicalNestedLoopJoin::GetOperatorState(ExecutionContext &context) const { @@ -487,40 +550,72 @@ OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext & //===--------------------------------------------------------------------===// class NestedLoopJoinGlobalScanState : public GlobalSourceState { public: - explicit NestedLoopJoinGlobalScanState(const PhysicalNestedLoopJoin &op) : op(op) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(sink.right_payload_data, scan_state); + explicit NestedLoopJoinGlobalScanState(const PhysicalNestedLoopJoin &op, ClientContext &context) : op(op) { + ResetState(context); } const PhysicalNestedLoopJoin &op; OuterJoinGlobalScanState scan_state; +private: + void ResetState(ClientContext &context) { + auto &sink = op.sink_state->Cast(); + sink.right_outer.InitializeScan(sink.right_payload_data, scan_state); + GlobalSourceState::Reset(context); + } + public: idx_t MaxThreads() override { auto &sink = op.sink_state->Cast(); return sink.right_outer.MaxThreads(); } + + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + ResetState(context); + } }; class NestedLoopJoinLocalScanState : public LocalSourceState { public: - explicit NestedLoopJoinLocalScanState(const PhysicalNestedLoopJoin &op, NestedLoopJoinGlobalScanState &gstate) { - D_ASSERT(op.sink_state); + explicit 
NestedLoopJoinLocalScanState(ExecutionContext &context, const PhysicalNestedLoopJoin &op, + NestedLoopJoinGlobalScanState &gstate) + : op(op) { + ResetState(gstate); + } + + const PhysicalNestedLoopJoin &op; + OuterJoinLocalScanState scan_state; + +private: + void ResetState(NestedLoopJoinGlobalScanState &gstate) { auto &sink = op.sink_state->Cast(); sink.right_outer.InitializeScan(gstate.scan_state, scan_state); } - OuterJoinLocalScanState scan_state; +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSourceState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + // The source only scans unmatched RHS rows for RIGHT/FULL OUTER joins. The materialized RHS payload and + // match bitmap live in the sink state, so a reused local source state must rebind its scan to that data. + ResetState(gstate); + } }; unique_ptr PhysicalNestedLoopJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); + return make_uniq(*this, context); } unique_ptr PhysicalNestedLoopJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { - return make_uniq(*this, gstate.Cast()); + return make_uniq(context, *this, gstate.Cast()); } SourceResultType PhysicalNestedLoopJoin::GetDataInternal(ExecutionContext &context, DataChunk &chunk, diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp index 5f5b970a6..4a9fdb22c 100644 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -63,15 +63,12 @@ class MergeJoinGlobalState : public GlobalSinkState { using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - MergeJoinGlobalState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) { + MergeJoinGlobalState(ClientContext &client, const 
PhysicalPiecewiseMergeJoin &op) : op(op) { const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); table = make_uniq(client, rhs_order, rhs_types, op); - if (op.filter_pushdown) { - skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); - global_filter_state = op.filter_pushdown->GetGlobalState(client, op); - } + ResetState(client); } inline idx_t Count() const { @@ -80,12 +77,35 @@ class MergeJoinGlobalState : public GlobalSinkState { void Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate); + const PhysicalPiecewiseMergeJoin &op; //! The sorted table unique_ptr table; //! Should we not bother pushing down filters? bool skip_filter_pushdown = false; //! The global filter states to push down (if any) unique_ptr global_filter_state; + +private: + void ResetState(ClientContext &client) { + table->ResetForReuse(client); + if (op.filter_pushdown) { + skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); + global_filter_state = op.filter_pushdown->GetGlobalState(client, op); + } else { + skip_filter_pushdown = false; + global_filter_state.reset(); + } + GlobalSinkState::Reset(client); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &client) override { + ResetState(client); + } }; class MergeJoinLocalState : public LocalSinkState { @@ -95,16 +115,34 @@ class MergeJoinLocalState : public LocalSinkState { MergeJoinLocalState(ExecutionContext &context, MergeJoinGlobalState &gstate, const idx_t child) : table(context, *gstate.table, child) { - auto &op = gstate.table->op; - if (op.filter_pushdown) { - local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); - } + ResetState(gstate); } //! The local sort state LocalSortedTable table; //! 
Local state for accumulating filter statistics unique_ptr local_filter_state; + +private: + void ResetState(MergeJoinGlobalState &gstate) { + auto &op = gstate.table->op; + if (op.filter_pushdown) { + local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); + } else { + local_filter_state.reset(); + } + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate_p) override { + auto &gstate = gstate_p.Cast(); + table.ResetForReuse(context); + ResetState(gstate); + } }; unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(ClientContext &context) const { @@ -198,6 +236,7 @@ class PiecewiseMergeJoinState : public CachingOperatorState { // Sort on the first column lhs_order.emplace_back(op.lhs_orders[0].Copy()); + lhs_global_table = make_uniq(client, lhs_order, op.children[0].get().GetTypes(), op); // Set up shared data for multiple predicates sel.Initialize(STANDARD_VECTOR_SIZE); @@ -210,13 +249,7 @@ class PiecewiseMergeJoinState : public CachingOperatorState { rhs_input.Initialize(client, op.children[1].get().GetTypes()); auto &gsink = op.sink_state->Cast(); - auto &rhs_table = *gsink.table; - rhs_iterator = rhs_table.CreateIteratorState(); - rhs_table.InitializePayloadState(rhs_chunk_state); - rhs_scan_state = rhs_table.CreateScanState(client); - - // Since we have now materialized the payload, the keys will not have payloads? 
- sort_key_type = rhs_table.GetSortKeyType(); + InitializeRightScanState(gsink); if (op.predicate) { pred_executor.AddExpression(*op.predicate); @@ -263,11 +296,50 @@ class PiecewiseMergeJoinState : public CachingOperatorState { SelectionVector pred_matches; public: + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + ResetCachingState(); + left_outer.Reset(); + lhs_payload.Reset(); + lhs_scan.Reset(); + left_position = 0; + first_fetch = true; + finished = true; + lhs_iterator.reset(); + rhs_iterator.reset(); + right_position = 0; + right_chunk_index = 0; + right_base = 0; + prev_left_index = 0; + rhs_chunk_state.ResetForScan(); + rhs_keys.Reset(); + rhs_input.Reset(); + lhs_global_table->ResetForReuse(client); + } + + void InitializeRightScanState(MergeJoinGlobalState &gsink) { + auto &rhs_table = *gsink.table; + rhs_iterator = rhs_table.CreateIteratorState(); + rhs_chunk_state.ResetForScan(); + rhs_table.InitializePayloadState(rhs_chunk_state); + if (!rhs_scan_state) { + rhs_scan_state = rhs_table.CreateScanState(client); + } + // Since we have now materialized the payload, the keys will not have payloads? 
+ sort_key_type = rhs_table.GetSortKeyType(); + } + void ResolveJoinKeys(ExecutionContext &context, DataChunk &input) { // sort by join key - const auto &lhs_types = lhs_payload.GetTypes(); - lhs_global_table = make_uniq(context.client, lhs_order, lhs_types, op); - lhs_local_table = make_uniq(context, *lhs_global_table, 0U); + lhs_global_table->ResetForReuse(context.client); + if (!lhs_local_table) { + lhs_local_table = make_uniq(context, *lhs_global_table, 0U); + } else { + lhs_local_table->ResetForReuse(context); + } lhs_local_table->Sink(context, input); lhs_global_table->Combine(context, *lhs_local_table); @@ -575,6 +647,7 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte do { if (state.first_fetch) { state.ResolveJoinKeys(context, input); + state.InitializeRightScanState(gstate); state.lhs_payload.Verify(context.client.db); state.right_chunk_index = 0; @@ -635,7 +708,6 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte chunk.data[col_idx].Reference(state.rhs_input.data[col_idx - left_cols]); } } - chunk.SetCardinality(result_count); auto sel = FlatVector::IncrementalSelectionVector(); if (tail_cols) { @@ -672,7 +744,6 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte } } } - chunk.SetCardinality(result_count); // Apply any arbitrary predicate if (predicate) { @@ -831,7 +902,6 @@ SourceResultType PhysicalPiecewiseMergeJoin::GetDataInternal(ExecutionContext &c for (idx_t col_idx = 0; col_idx < right_column_count; ++col_idx) { result.data[left_column_count + col_idx].Slice(rhs_chunk.data[col_idx], rsel, result_count); } - result.SetCardinality(result_count); break; } } diff --git a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp index de4178ae0..035858e18 100644 --- a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp +++ 
b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp @@ -20,7 +20,8 @@ PhysicalPositionalJoin::PhysicalPositionalJoin(PhysicalPlan &physical_plan, vect class PositionalJoinGlobalState : public GlobalSinkState { public: explicit PositionalJoinGlobalState(ClientContext &context, const PhysicalPositionalJoin &op) - : rhs(context, op.children[1].get().GetTypes()), initialized(false), source_offset(0), exhausted(false) { + : rhs(context, op.children[1].get().GetTypes()), initialized(false), source_count(0), source_offset(0), + exhausted(false) { rhs.InitializeAppend(append_state); } @@ -31,6 +32,7 @@ class PositionalJoinGlobalState : public GlobalSinkState { bool initialized; ColumnDataScanState scan_state; DataChunk source; + idx_t source_count; idx_t source_offset; bool exhausted; @@ -66,18 +68,21 @@ void PositionalJoinGlobalState::InitializeScan() { } idx_t PositionalJoinGlobalState::Refill() { - if (source_offset >= source.size()) { + // Use source_count (not source.size()) to avoid corruption from shared buffers after Reference+SetChildCardinality + if (source_offset >= source_count) { if (!exhausted) { source.Reset(); rhs.Scan(scan_state, source); + source_count = source.size(); } source_offset = 0; } - const auto available = source.size() - source_offset; + const auto available = source_count - source_offset; if (!available) { if (!exhausted) { source.Reset(); + source_count = 0; for (idx_t i = 0; i < source.ColumnCount(); ++i) { auto &vec = source.data[i]; ConstantVector::SetNull(vec, count_t(STANDARD_VECTOR_SIZE)); @@ -90,7 +95,7 @@ idx_t PositionalJoinGlobalState::Refill() { } idx_t PositionalJoinGlobalState::CopyData(DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { + if (!source_offset && (source_count >= count || exhausted)) { // Fast track: aligned and has enough data for (idx_t i = 0; i < source.ColumnCount(); ++i) { output.data[col_offset + 
i].Reference(source.data[i]); @@ -100,15 +105,15 @@ idx_t PositionalJoinGlobalState::CopyData(DataChunk &output, const idx_t count, // Copy data for (idx_t target_offset = 0; target_offset < count;) { const auto needed = count - target_offset; - const auto available = exhausted ? needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; + const auto available = exhausted ? needed : (source_count - source_offset); + const auto copy_count = MinValue(needed, available); + const auto source_end = source_offset + copy_count; for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, + VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_end, source_offset, target_offset); } - target_offset += copy_size; - source_offset += copy_size; + target_offset += copy_count; + source_offset += copy_count; Refill(); } } @@ -153,13 +158,12 @@ void PositionalJoinGlobalState::GetData(DataChunk &output) { // LHS exhausted if (exhausted) { // RHS exhausted too, so we are done - output.SetCardinality(0); return; } // LHS is all NULL const auto col_offset = output.ColumnCount() - source.ColumnCount(); - const auto count = MinValue(STANDARD_VECTOR_SIZE, source.size() - source_offset); + const auto count = MinValue(STANDARD_VECTOR_SIZE, source_count - source_offset); for (idx_t i = 0; i < col_offset; ++i) { auto &vec = output.data[i]; ConstantVector::SetNull(vec, count_t(count)); diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp index 70dba3524..63073c204 100644 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -39,6 +39,14 @@ PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ExecutionContext &context, 
sort_chunk.InitializeEmpty(types); } +void PhysicalRangeJoin::LocalSortedTable::ResetForReuse(ExecutionContext &context) { + local_sink = global_table.sort->GetLocalSinkState(context); + has_null = 0; + count = 0; + keys.Reset(); + sort_chunk.Reset(); +} + void PhysicalRangeJoin::LocalSortedTable::Sink(ExecutionContext &context, DataChunk &input) { // Obtain sorting columns keys.Reset(); @@ -62,8 +70,7 @@ void PhysicalRangeJoin::LocalSortedTable::Sink(ExecutionContext &context, DataCh for (column_t col_idx = 0; col_idx < input.ColumnCount(); ++col_idx) { sort_chunk.data[col_idx + 1].Reference(input.data[col_idx]); } - sort_chunk.SetCardinality(input); - + sort_chunk.SetChildCardinality(input.size()); // Sink the data into the local sort state InterruptState interrupt; OperatorSinkInput sink {*global_table.global_sink, *local_sink, interrupt}; @@ -105,6 +112,16 @@ void PhysicalRangeJoin::GlobalSortedTable::Combine(ExecutionContext &context, Lo count += ltable.count; } +void PhysicalRangeJoin::GlobalSortedTable::ResetForReuse(ClientContext &client) { + global_sink = sort->GetGlobalSinkState(client); + has_null = 0; + count = 0; + tasks_completed = 0; + global_source.reset(); + sorted.reset(); + found_match.reset(); +} + void PhysicalRangeJoin::GlobalSortedTable::Finalize(ClientContext &client, InterruptState &interrupt) { OperatorSinkFinalizeInput finalize {*global_sink, interrupt}; sort->Finalize(client, finalize); @@ -373,7 +390,7 @@ idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(Vector &primary, const vec const auto count = keys.size(); size_t all_constant = 0; - for (auto &v : keys.data) { + for (const auto &v : keys.data) { if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { ++all_constant; } @@ -385,7 +402,7 @@ idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(Vector &primary, const vec if (conditions[c].GetComparisonType() == ExpressionType::COMPARE_DISTINCT_FROM) { continue; } - auto &v = keys.data[c]; + const auto &v = keys.data[c]; if 
(ConstantVector::IsNull(v)) { ConstantVector::SetNull(primary, count_t(count)); return count; @@ -404,7 +421,7 @@ idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(Vector &primary, const vec continue; } // ToUnifiedFormat the rest, as the sort code will do this anyway. - auto &v = keys.data[c]; + const auto &v = keys.data[c]; UnifiedVectorFormat vdata; v.ToUnifiedFormat(vdata); auto &vvalidity = vdata.validity; @@ -442,7 +459,7 @@ idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(Vector &primary, const vec } return count - pvalidity.CountValid(count); } else { - return count - VectorOperations::CountNotNull(primary, count); + return count - VectorOperations::CountNotNull(primary); } } @@ -455,7 +472,6 @@ void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const for (idx_t i = 0; i < right_projection_map.size(); ++i) { result.data[left_projected + i].Reference(chunk.data[left_width + right_projection_map[i]]); } - result.SetCardinality(chunk); } template @@ -467,17 +483,18 @@ static void TemplatedSliceSortedPayload(DataChunk &chunk, const SortedRun &sorte using BLOCK_ITERATOR = block_iterator_t; BLOCK_ITERATOR itr(state, chunk_idx, 0); - const auto sort_keys = FlatVector::GetDataMutable(sort_key_pointers); const auto result_size = NumericCast(result.size()); - - for (idx_t i = 0; i < result_size; ++i) { - const auto idx = state.GetIndex(chunk_idx, result[i]); - sort_keys[i] = &itr[idx]; + { + auto writer = FlatVector::Writer(sort_key_pointers, result_size); + for (idx_t i = 0; i < result_size; ++i) { + const auto idx = state.GetIndex(chunk_idx, result[i]); + writer.WriteValue(&itr[idx]); + } } // Scan chunk.Reset(); - scan_state.Scan(sorted_run, sort_key_pointers, result_size, chunk); + scan_state.Scan(sorted_run, sort_key_pointers, chunk); } void PhysicalRangeJoin::SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, @@ -530,7 +547,7 @@ void PhysicalRangeJoin::SliceSortedPayload(DataChunk &chunk, GlobalSortedTable & } } 
-idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, +idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, const Vector &left, const Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel) { switch (condition) { case ExpressionType::COMPARE_NOTEQUAL: diff --git a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp index 8103902b2..68fdf3baf 100644 --- a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp @@ -10,6 +10,24 @@ namespace duckdb { +static void RightDelimResetSinkState(const PhysicalOperator &op, ClientContext &context, + unique_ptr &state) { + if (state && state->SupportsReuse()) { + state->Reset(context); + return; + } + state = op.GetGlobalSinkState(context); +} + +static void RightDelimResetLocalSinkState(const PhysicalOperator &op, ExecutionContext &context, + GlobalSinkState &gstate, unique_ptr &state) { + if (state && state->SupportsReuse()) { + state->Reset(context, gstate); + return; + } + state = op.GetLocalSinkState(context); +} + PhysicalRightDelimJoin::PhysicalRightDelimJoin(PhysicalPlan &physical_plan, PhysicalPlanGenerator &planner, vector types, PhysicalOperator &original_join, PhysicalOperator &distinct, @@ -29,16 +47,48 @@ PhysicalRightDelimJoin::PhysicalRightDelimJoin(PhysicalPlan &physical_plan, Phys //===--------------------------------------------------------------------===// // Sink //===--------------------------------------------------------------------===// -class RightDelimJoinGlobalState : public GlobalSinkState {}; +class RightDelimJoinGlobalState : public GlobalSinkState { +public: + explicit RightDelimJoinGlobalState(const PhysicalRightDelimJoin &op) : op(op) { + } + + const PhysicalRightDelimJoin &op; + + bool SupportsReuse() const override { + return 
true; + } + + void Reset(ClientContext &context) override { + RightDelimResetSinkState(op.join, context, op.join.sink_state); + RightDelimResetSinkState(op.distinct, context, op.distinct.sink_state); + if (op.delim_scans.size() > 1) { + PhysicalHashAggregate::SetMultiScan(*op.distinct.sink_state); + } + GlobalSinkState::Reset(context); + } +}; class RightDelimJoinLocalState : public LocalSinkState { public: + explicit RightDelimJoinLocalState(const PhysicalRightDelimJoin &op) : op(op) { + } + + const PhysicalRightDelimJoin &op; unique_ptr join_state; unique_ptr distinct_state; + + bool SupportsReuse() const override { + return true; + } + + void Reset(ExecutionContext &context, GlobalSinkState &gstate) override { + RightDelimResetLocalSinkState(op.join, context, *op.join.sink_state, join_state); + RightDelimResetLocalSinkState(op.distinct, context, *op.distinct.sink_state, distinct_state); + } }; unique_ptr PhysicalRightDelimJoin::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); + auto state = make_uniq(*this); join.sink_state = join.GetGlobalSinkState(context); distinct.sink_state = distinct.GetGlobalSinkState(context); if (delim_scans.size() > 1) { @@ -48,7 +98,7 @@ unique_ptr PhysicalRightDelimJoin::GetGlobalSinkState(ClientCon } unique_ptr PhysicalRightDelimJoin::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); + auto state = make_uniq(*this); state->join_state = join.GetLocalSinkState(context); state->distinct_state = distinct.GetLocalSinkState(context); return std::move(state); diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp index ef2611c7e..912474904 100644 --- a/src/duckdb/src/execution/operator/order/physical_top_n.cpp +++ b/src/duckdb/src/execution/operator/order/physical_top_n.cpp @@ -6,7 +6,7 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/create_sort_key.hpp" #include 
"duckdb/storage/data_table.hpp" -#include "duckdb/planner/filter/dynamic_filter.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" namespace duckdb { @@ -124,8 +124,8 @@ class TopNHeap { void Scan(TopNScanState &state, DataChunk &chunk, idx_t &pos); bool CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload, TopNBoundaryValue &boundary_val); - void AddSmallHeap(DataChunk &input, Vector &sort_keys_vec); - void AddLargeHeap(DataChunk &input, Vector &sort_keys_vec); + void AddSmallHeap(DataChunk &input, const Vector &sort_keys_vec); + void AddLargeHeap(DataChunk &input, const Vector &sort_keys_vec); public: idx_t ReduceThreshold() const { @@ -198,7 +198,7 @@ TopNHeap::TopNHeap(ExecutionContext &context, const vector &payload : TopNHeap(context.client, BufferAllocator::Get(context.client), payload_types, orders, limit, offset) { } -void TopNHeap::AddSmallHeap(DataChunk &input, Vector &sort_keys_vec) { +void TopNHeap::AddSmallHeap(DataChunk &input, const Vector &sort_keys_vec) { // insert the sort keys into the priority queue constexpr idx_t BASE_INDEX = NumericLimits::Maximum(); @@ -240,7 +240,7 @@ void TopNHeap::AddSmallHeap(DataChunk &input, Vector &sort_keys_vec) { heap_data.Append(input, matching_sel, match_count, VectorAppendMode::ALLOW_RESIZE); } -void TopNHeap::AddLargeHeap(DataChunk &input, Vector &sort_keys_vec) { +void TopNHeap::AddLargeHeap(DataChunk &input, const Vector &sort_keys_vec) { auto sort_key_values = FlatVector::GetData(sort_keys_vec); idx_t base_index = heap_data.size(); idx_t match_count = 0; @@ -281,7 +281,7 @@ bool TopNHeap::CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload, To col.SetVectorType(VectorType::CONSTANT_VECTOR); } } - boundary_values.SetCardinality(sort_chunk.size()); + boundary_values.SetChildCardinality(sort_chunk.size()); // we have boundary values // from these boundary values, determine which values we should insert (if any) diff --git 
a/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp b/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp index 33435736b..80c5cc687 100644 --- a/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp +++ b/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp @@ -8,14 +8,16 @@ namespace duckdb { TableCatalogEntry &CSVRejectsTable::GetErrorsTable(ClientContext &context) { - auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, errors_table); + auto &temp_catalog = Catalog::GetCatalog(context, Identifier::TempCatalog()); + auto &table_entry = temp_catalog.GetEntry(context, Identifier::TempCatalog(), + Identifier::DefaultSchema(), Identifier(errors_table)); return table_entry; } TableCatalogEntry &CSVRejectsTable::GetScansTable(ClientContext &context) { - auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, scan_table); + auto &temp_catalog = Catalog::GetCatalog(context, Identifier::TempCatalog()); + auto &table_entry = temp_catalog.GetEntry(context, Identifier::TempCatalog(), + Identifier::DefaultSchema(), Identifier(scan_table)); return table_entry; } @@ -34,14 +36,16 @@ shared_ptr CSVRejectsTable::GetOrCreate(ClientContext &context, throw BinderException("The names of the rejects scan and rejects error tables can't be the same. 
Use different " "names for these tables."); } - auto key = - "CSV_REJECTS_TABLE_CACHE_ENTRY_" + StringUtil::Upper(rejects_scan) + "_" + StringUtil::Upper(rejects_error); + auto key = StringUtil::Format("CSV_REJECTS_TABLE_CACHE_ENTRY_%s_%s", StringUtil::Upper(rejects_scan), + StringUtil::Upper(rejects_error)); auto &cache = ObjectCache::GetObjectCache(context); - auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto rejects_scan_exist = catalog.GetEntry(context, DEFAULT_SCHEMA, rejects_scan, - OnEntryNotFound::RETURN_NULL) != nullptr; - auto rejects_error_exist = catalog.GetEntry(context, DEFAULT_SCHEMA, rejects_error, - OnEntryNotFound::RETURN_NULL) != nullptr; + auto &catalog = Catalog::GetCatalog(context, Identifier::TempCatalog()); + auto rejects_scan_exist = + catalog.GetEntry(context, Identifier::DefaultSchema(), Identifier(rejects_scan), + OnEntryNotFound::RETURN_NULL) != nullptr; + auto rejects_error_exist = + catalog.GetEntry(context, Identifier::DefaultSchema(), Identifier(rejects_error), + OnEntryNotFound::RETURN_NULL) != nullptr; if ((rejects_scan_exist || rejects_error_exist) && !cache.Get(key)) { std::ostringstream error; if (rejects_scan_exist) { @@ -59,7 +63,7 @@ shared_ptr CSVRejectsTable::GetOrCreate(ClientContext &context, void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData &data) { // (Re)Create the temporary rejects table - auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); + auto &catalog = Catalog::GetCatalog(context, Identifier::TempCatalog()); // Create CSV_ERROR_TYPE ENUM string enum_name = "CSV_ERROR_TYPE"; @@ -81,7 +85,8 @@ void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData // Create Rejects Scans Table { - auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, scan_table); + auto info = + make_uniq(Identifier::TempCatalog(), Identifier::DefaultSchema(), Identifier(scan_table)); info->temporary = true; info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; // 
0. Scan ID @@ -114,7 +119,8 @@ void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData } { // Create Rejects Error Table - auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, errors_table); + auto info = make_uniq(Identifier::TempCatalog(), Identifier::DefaultSchema(), + Identifier(errors_table)); info->temporary = true; info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; // 0. Scan ID diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp index 4e9114ace..542471a4b 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp @@ -43,7 +43,7 @@ PhysicalBatchCopyToFile::PhysicalBatchCopyToFile(PhysicalPlan &physical_plan, ve InsertionOrderPreservingMap PhysicalBatchCopyToFile::ParamsToString() const { InsertionOrderPreservingMap result; - result["FORMAT"] = StringUtil::Upper(function.name); + result["FORMAT"] = StringUtil::Upper(function.name.GetIdentifierName()); return result; } @@ -692,7 +692,6 @@ SourceResultType PhysicalBatchCopyToFile::GetDataInternal(ExecutionContext &cont switch (return_type) { case CopyFunctionReturnType::CHANGED_ROWS: chunk.data[0].Append(Value::BIGINT(NumericCast(g.rows_copied.load()))); - chunk.SetCardinality(1); break; case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: { vector file_list; @@ -701,14 +700,12 @@ SourceResultType PhysicalBatchCopyToFile::GetDataInternal(ExecutionContext &cont } chunk.data[0].Append(Value::BIGINT(NumericCast(g.rows_copied.load()))); chunk.data[1].Append(Value::LIST(LogicalType::VARCHAR, std::move(file_list))); - chunk.SetCardinality(1); break; } case CopyFunctionReturnType::WRITTEN_FILE_STATISTICS: { if (g.written_file_info) { g.written_file_info->file_path = std::move(fp); PhysicalCopyToFile::ReturnStatistics(chunk, *g.written_file_info); - 
chunk.SetCardinality(1); } break; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp index 898ce81ef..1c3b31357 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -97,9 +97,9 @@ class CollectionMerger { if (scan_chunk.size() == 0) { break; } - auto new_row_group = result_collection.Append(scan_chunk, append_state); - if (new_row_group) { - writer.WriteNewRowGroup(optimistic_collection); + auto flushed_row_group_idx = result_collection.Append(scan_chunk, append_state); + if (flushed_row_group_idx.IsValid()) { + writer.WriteNewRowGroup(optimistic_collection, flushed_row_group_idx.GetIndex()); } } data_table.ResetOptimisticCollection(context, collection_indexes[i]); @@ -546,10 +546,10 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &i auto &optimistic_collection = table.GetStorage().GetOptimisticCollection(context.client, lstate.collection_index); auto &collection = *optimistic_collection.collection; - auto new_row_group = collection.Append(insert_chunk, lstate.current_append_state); - if (new_row_group) { + auto flushed_row_group_idx = collection.Append(insert_chunk, lstate.current_append_state); + if (flushed_row_group_idx.IsValid()) { // we have already written to disk - flush the next row group as well - lstate.optimistic_writer->WriteNewRowGroup(optimistic_collection); + lstate.optimistic_writer->WriteNewRowGroup(optimistic_collection, flushed_row_group_idx.GetIndex()); } return SinkResultType::NEED_MORE_INPUT; } @@ -691,7 +691,6 @@ SourceResultType PhysicalBatchInsert::GetDataInternal(ExecutionContext &context, OperatorSourceInput &input) const { auto &insert_gstate = sink_state->Cast(); - chunk.SetCardinality(1); chunk.data[0].Append(Value::BIGINT(NumericCast(insert_gstate.insert_count))); return 
SourceResultType::FINISHED; diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp index bf1b950e2..7b91065fc 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -49,6 +49,21 @@ struct VectorOfValuesEquality { } }; +struct VectorOfValuesLess { + bool operator()(const vector &a, const vector &b) const { + const auto count = MinValue(a.size(), b.size()); + for (idx_t i = 0; i < count; i++) { + if (ValueOperations::DistinctLessThan(a[i], b[i])) { + return true; + } + if (ValueOperations::DistinctLessThan(b[i], a[i])) { + return false; + } + } + return a.size() < b.size(); + } +}; + template using vector_of_value_map_t = unordered_map, T, VectorOfValuesHashFunction, VectorOfValuesEquality>; @@ -125,8 +140,11 @@ struct GlobalFileState { idx_t num_batches DUCKDB_GUARDED_BY(lock); }; +static bool PhysicalCopyRotateNow(const PhysicalCopyToFile &op, GlobalFileState &global_state) + DUCKDB_REQUIRES(global_state.lock); + //===--------------------------------------------------------------------===// -// CopyToFileGlobalState hpp +// Copy State Declarations //===--------------------------------------------------------------------===// class PartitionedCopy; @@ -139,9 +157,12 @@ class CopyToFileGlobalState : public GlobalSinkState { void Initialize(); void CreateDir(const string &dir_path) DUCKDB_REQUIRES(lock); + unique_ptr CreateFileStateLocked(string output_path = string(), + optional_ptr> partition_values = nullptr) + DUCKDB_REQUIRES(lock); unique_ptr CreateFileState(string output_path = string(), optional_ptr> partition_values = nullptr) - DUCKDB_REQUIRES(lock); + DUCKDB_EXCLUDES(lock); unique_ptr FinalizeFileStateLocked(unique_ptr file_state) DUCKDB_REQUIRES(lock); void FinalizeFileState(unique_ptr file_state) DUCKDB_EXCLUDES(lock); @@ -173,6 +194,7 @@ class 
CopyToFileGlobalState : public GlobalSinkState { unique_ptr global_state; //! Lambda to create a new global file state const std::function()> create_file_state_fun; + unordered_set *> creating_file_states DUCKDB_GUARDED_BY(lock); //! The final batch mutable annotated_mutex last_batch_lock; @@ -194,7 +216,7 @@ class CopyToFileGlobalState : public GlobalSinkState { }; //===--------------------------------------------------------------------===// -// CopyToFileLocalState hpp +// Copy Local State Declaration //===--------------------------------------------------------------------===// class PartitionedCopyLocalState; @@ -223,9 +245,10 @@ class CopyToFileLocalState : public LocalSinkState { }; //===--------------------------------------------------------------------===// -// PartitionedCopy hpp +// Partitioned Copy Declarations //===--------------------------------------------------------------------===// -enum class PartitionedCopyStage : uint8_t { SORT, MATERIALIZE, MASK, BATCH, FLUSH, DONE }; +enum class PartitionedCopyStage : uint8_t { SORT, MATERIALIZE, MASK, BATCH, PREPARE, FLUSH, DONE }; +enum class FileCreationReason : uint8_t { NORMAL, SORTED_RUN_BOUNDARY, ROTATION }; struct PartitionedCopyTask { PartitionedCopyStage stage = PartitionedCopyStage::DONE; @@ -248,19 +271,299 @@ struct PartitionedCopyBatch { }; struct PartitionWriteInfo { + //! Serializes operations that need a complete partition writer run boundary. 
+ annotated_mutex lock; unique_ptr file_state; idx_t active_writes = 0; }; -struct CreatePartitionFileStateResult { - unique_ptr file_state; +struct PartitionFileStateReservation { vector> files_to_finalize; + idx_t offset = 0; +}; + +enum class PartitionedCopyBatchMode : uint8_t { BUFFERING, PREPARING, DELAYED, PREPARED }; + +enum class PartitionedCopyCollectionSchema : uint8_t { RAW_SCHEMA, WRITE_SCHEMA }; + +struct PartitionedCopyCollection { + PartitionedCopyCollection() = default; + explicit PartitionedCopyCollection(PartitionedCopyCollectionSchema schema_p) : schema(schema_p) { + } + PartitionedCopyCollection(PartitionedCopyCollectionSchema schema_p, unique_ptr collection_p) + : schema(schema_p), collection(std::move(collection_p)) { + } + + idx_t Count() const { + return collection ? collection->Count() : 0; + } + + PartitionedCopyCollectionSchema schema = PartitionedCopyCollectionSchema::WRITE_SCHEMA; + unique_ptr collection; }; struct PartitionedCopyBatchState { + void SetValues(vector values_p) { + D_ASSERT(values.empty() || values == values_p); + if (values.empty()) { + values = std::move(values_p); + } + } + + idx_t AddCollectionSlot(PartitionedCopyCollectionSchema schema, idx_t row_count) { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); + collections.emplace_back(schema); + if (mode == PartitionedCopyBatchMode::PREPARING && batches.size() < collections.size()) { + batches.emplace_back(); + } + count += row_count; + return collections.size() - 1; + } + + void StoreCollection(idx_t batch_idx, unique_ptr collection) { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(batch_idx < collections.size()); + collections[batch_idx].collection = std::move(collection); + } + + bool CanStartPreparing(idx_t flush_threshold, bool has_delayed_partition) const { + return mode == PartitionedCopyBatchMode::BUFFERING && count >= flush_threshold && 
!has_delayed_partition; + } + + void StartPreparing() { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING); + mode = PartitionedCopyBatchMode::PREPARING; + batches.resize(collections.size()); + } + + bool TryReserveWriteInfo() { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + if (write_info || write_info_requested) { + return false; + } + write_info_requested = true; + return true; + } + + void SetWriteInfo(PartitionWriteInfo &write_info_p) { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(!write_info || write_info == optional_ptr(write_info_p)); + write_info = write_info_p; + write_info_requested = false; + batches.resize(collections.size()); + } + + void EnsurePreparingBatchSlots() { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(write_info); + batches.resize(collections.size()); + } + + void MarkDelayed() { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING); + mode = PartitionedCopyBatchMode::DELAYED; + } + + void MarkPrepared() { + if (mode == PartitionedCopyBatchMode::PREPARING) { + mode = PartitionedCopyBatchMode::PREPARED; + } + } + + bool SkipsPrepare() const { + return mode == PartitionedCopyBatchMode::DELAYED || mode == PartitionedCopyBatchMode::PREPARED; + } + + bool NeedsWriteInfo() const { + return mode == PartitionedCopyBatchMode::PREPARING && !write_info; + } + + bool IsWriteInfoReserved() const { + return write_info_requested; + } + + bool NeedsPrepare(idx_t batch_idx) const { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(write_info); + D_ASSERT(batch_idx < collections.size()); + return collections[batch_idx].collection && (batch_idx >= batches.size() || !batches[batch_idx]); + } + + PartitionedCopyCollection TakeCollection(idx_t batch_idx) { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(batch_idx < collections.size()); + return std::move(collections[batch_idx]); + } + + void StorePreparedBatch(idx_t batch_idx, unique_ptr batch) { + 
D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING || mode == PartitionedCopyBatchMode::PREPARED); + D_ASSERT(batch_idx < batches.size()); + batches[batch_idx] = std::move(batch); + } + + vector TakeDelayedCollections() { + D_ASSERT(mode == PartitionedCopyBatchMode::DELAYED); + return std::move(collections); + } + + vector> TakePreparedBatches() { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARED); + return std::move(batches); + } + vector values; optional_ptr write_info; + vector collections; vector> batches; + idx_t count = 0; + bool write_info_requested = false; + PartitionedCopyBatchMode mode = PartitionedCopyBatchMode::BUFFERING; +}; + +struct DelayedPartitionFlush { + vector values; + PartitionedCopyCollection data; +}; + +struct DelayedPartitionState { + bool HasBufferedData() const { + return buffer.Count() > 0; + } + + bool InFlight() const { + return in_flight > 0; + } + + void Buffer(PartitionedCopyCollection data) { + if (!HasBufferedData()) { + buffer = std::move(data); + return; + } + D_ASSERT(buffer.schema == data.schema); + buffer.collection->Combine(*data.collection); + } + + bool Ready(idx_t flush_threshold, bool force) const { + return HasBufferedData() && !InFlight() && (force || buffer.Count() >= flush_threshold); + } + + DelayedPartitionFlush Take(const vector &values) { + D_ASSERT(HasBufferedData()); + D_ASSERT(!InFlight()); + DelayedPartitionFlush result; + result.values = values; + result.data = std::move(buffer); + in_flight++; + return result; + } + + void Complete() { + D_ASSERT(in_flight > 0); + in_flight--; + } + + bool Active() const { + return HasBufferedData() || InFlight(); + } + + PartitionedCopyCollection buffer; + idx_t in_flight = 0; +}; + +class DelayedPartitionBuffers { +public: + bool Has(const vector &values) DUCKDB_EXCLUDES(lock) { + annotated_lock_guard guard(lock); + return partitions.find(values) != partitions.end(); + } + + bool Empty() const DUCKDB_EXCLUDES(lock) { + annotated_lock_guard guard(lock); + return 
partitions.empty(); + } + + optional BufferOrTakeReady(const vector &values, PartitionedCopyCollection data, + const idx_t flush_threshold, const bool force = false) + DUCKDB_EXCLUDES(lock) { + if (!data.collection || data.collection->Count() == 0) { + return nullopt; + } + + DelayedPartitionFlush result; + bool has_result = false; + { + annotated_lock_guard guard(lock); + auto &state = partitions[values]; + state.Buffer(std::move(data)); + if (state.Ready(flush_threshold, force)) { + result = state.Take(values); + has_result = true; + } + EraseIfInactive(values, state); + } + + if (!has_result) { + return nullopt; + } + return std::move(result); + } + + optional TakeNext() DUCKDB_EXCLUDES(lock) { + DelayedPartitionFlush result; + { + annotated_lock_guard guard(lock); + auto next = partitions.end(); + for (auto entry = partitions.begin(); entry != partitions.end(); entry++) { + auto &state = entry->second; + if (!state.HasBufferedData() || state.InFlight()) { + continue; + } + if (next == partitions.end() || VectorOfValuesLess()(entry->first, next->first)) { + next = entry; + } + } + if (next != partitions.end()) { + result = next->second.Take(next->first); + return std::move(result); + } + } + return nullopt; + } + + optional Complete(const vector &values, const idx_t flush_threshold, + const bool take_ready) DUCKDB_EXCLUDES(lock) { + DelayedPartitionFlush result; + bool has_result = false; + annotated_lock_guard guard(lock); + auto entry = partitions.find(values); + D_ASSERT(entry != partitions.end()); + if (entry == partitions.end()) { + return nullopt; + } + entry->second.Complete(); + if (take_ready && entry->second.Ready(flush_threshold, false)) { + // Data for the same partition reached the threshold while an earlier flush was in-flight. + // Hand it back to the caller immediately so the partition does not wait for a global drain. 
+ result = entry->second.Take(values); + has_result = true; + } else if (!entry->second.Active()) { + partitions.erase(entry); + } + if (!has_result) { + return nullopt; + } + return std::move(result); + } + +private: + void EraseIfInactive(const vector &values, const DelayedPartitionState &state) DUCKDB_REQUIRES(lock) { + if (!state.Active()) { + partitions.erase(values); + } + } + + mutable annotated_mutex lock; + vector_of_value_map_t partitions DUCKDB_GUARDED_BY(lock); }; //! Manages a single partitioned COPY hash bin @@ -287,6 +590,7 @@ class PartitionedCopyHashGroup { bool TryPrepareNextStage() DUCKDB_REQUIRES(lock); optional TryNextTask() DUCKDB_REQUIRES(lock); optional TryNextBatchTask() DUCKDB_REQUIRES(lock); + optional TryNextPrepareTask() DUCKDB_REQUIRES(lock); optional TryNextFlushTask() DUCKDB_REQUIRES(lock); void Sort(ExecutionContext &context, GlobalSinkState &sink, InterruptState &interrupt, @@ -295,7 +599,10 @@ class PartitionedCopyHashGroup { const PartitionedCopyTask &task); void Mask(const PartitionedCopyTask &task); void Batch(const PartitionedCopyTask &task); - void Flush(const PartitionedCopyTask &task); + void Prepare(ExecutionContext &execution_context, InterruptState &interrupt_state, const PartitionedCopyTask &task); + void Flush(ExecutionContext &execution_context, InterruptState &interrupt_state, const PartitionedCopyTask &task); + void PrepareBatchStates() DUCKDB_REQUIRES(lock); + void CompletePreparedBatchStates() DUCKDB_REQUIRES(lock); public: //! The PartitionedCopy that this hash group belongs to @@ -342,6 +649,12 @@ class PartitionedCopyHashGroup { vector> batch_states DUCKDB_GUARDED_BY(lock); //! Count of batched rows atomic batched; + //! Count of rows that either have a prepared batch or do not need one + atomic prepared; + + //! The current partition/batch index for preparing + idx_t prepare_partition_idx DUCKDB_GUARDED_BY(lock) = 0; + idx_t prepare_batch_idx DUCKDB_GUARDED_BY(lock) = 0; //! 
The current partition index for flushing idx_t flush_partition_idx DUCKDB_GUARDED_BY(lock) = 0; @@ -359,6 +672,7 @@ class PartitionedCopyState { public: bool ShouldInitiateFlush(const idx_t &local_append_count); + bool IsCombineComplete() const DUCKDB_REQUIRES(lock); void CreateTaskList() DUCKDB_REQUIRES(lock); bool HasCompleted() const; optional TryAssignTask(); @@ -417,10 +731,38 @@ class PartitionedCopy { void FinalizeState(PartitionedCopyState &state, InterruptState &interrupt_state) DUCKDB_REQUIRES(state.lock); PartitionWriteInfo &GetPartitionWriteInfo(const vector &values) DUCKDB_EXCLUDES(active_writes_lock); - unique_ptr CreatePartitionFileState(const vector &values) + void ReleasePartitionWriteInfo(PartitionWriteInfo &write_info) DUCKDB_EXCLUDES(active_writes_lock); + bool HasDelayedPartition(const vector &values); + bool HasDelayedPartitions() const; + optional BufferOrTakeReadyPartition(const vector &values, + PartitionedCopyCollection data, bool force = false); + optional TakeNextDelayedPartition(); + optional CompleteDelayedPartition(const vector &values, bool take_ready); + bool DrainDelayedPartitions(ExecutionContext &execution_context, InterruptState &interrupt_state); + unique_ptr SortPartitionCollection(ExecutionContext &execution_context, + InterruptState &interrupt_state, + unique_ptr collection); + PartitionedCopyCollectionSchema GetPartitionCollectionSchema() const; + const vector &GetPartitionCollectionTypes(PartitionedCopyCollectionSchema schema) const; + unique_ptr PrepareCollectionForWrite(PartitionedCopyCollection data); + unique_ptr ProjectToWriteColumns(unique_ptr collection); + unique_ptr PreparePartitionBatch(const vector &values, PartitionWriteInfo &write_info, + PartitionedCopyCollection data); + void FlushPreparedPartitionRun(const vector &values, PartitionWriteInfo &write_info, + vector> batches); + void FlushPreparedPartitionBatch(const vector &values, PartitionWriteInfo &write_info, + unique_ptr batch); + void 
FlushPartitionCollection(ExecutionContext &execution_context, InterruptState &interrupt_state, + DelayedPartitionFlush flush); + unique_ptr CreatePartitionFileState(const vector &values, + FileCreationReason reason = FileCreationReason::NORMAL) DUCKDB_EXCLUDES(active_writes_lock); - CreatePartitionFileStateResult CreatePartitionFileStateLocked(const vector &values) + PartitionFileStateReservation + ReservePartitionFileStateLocked(const vector &values, FileCreationReason reason = FileCreationReason::NORMAL) DUCKDB_REQUIRES(active_writes_lock); + unique_ptr CreatePartitionFileStateFromReservation(const vector &values, idx_t offset) + DUCKDB_EXCLUDES(active_writes_lock); + void FinalizeActiveWrites() DUCKDB_EXCLUDES(active_writes_lock); void FinalizeFileStates(vector> files_to_finalize) DUCKDB_EXCLUDES(active_writes_lock); string GetOrCreateDirectory(string path, const vector &values) DUCKDB_REQUIRES(copy_gstate.lock); @@ -428,6 +770,25 @@ class PartitionedCopy { unique_ptr ConstructSortStrategy() const; void CreateNextState(); bool ShouldStopFlushing() const; + bool RequiresSerializedPartitionWrites() const; + void EnsureFreshPartitionFileForSortedRun(PartitionWriteInfo &write_info, const vector &values) + DUCKDB_EXCLUDES(active_writes_lock); + void EnsureFreshPartitionFileForRotation(PartitionWriteInfo &write_info, const vector &values) + DUCKDB_EXCLUDES(active_writes_lock); + //! Swaps write_info.file_state after temporarily dropping copy_gstate.lock to initialize the replacement file. + //! Callers that can reach the swap path must serialize the full partition writer run for this write_info. 
+ void EnsureFreshPartitionFile(PartitionWriteInfo &write_info, const vector &values, + FileCreationReason reason) DUCKDB_EXCLUDES(active_writes_lock); + template + void WithSerializedPartitionWriteRun(PartitionWriteInfo &write_info, FUNC &&func) { + annotated_unique_lock run_guard(write_info.lock, std::defer_lock); + if (RequiresSerializedPartitionWrites()) { + run_guard.lock(); + } + func(); + } + void FlushDelayedPartitionRun(const vector &values, PartitionWriteInfo &write_info, + ColumnDataCollection &collection); public: const PhysicalCopyToFile &op; @@ -437,6 +798,7 @@ class PartitionedCopy { //! Which columns/types to write to the file vector write_columns; vector write_types; + vector raw_columns; //! Partition/sort strategy with PhysicalOperator-like interface const unique_ptr sort_strategy; @@ -462,12 +824,69 @@ class PartitionedCopy { vector_of_value_map_t> active_writes DUCKDB_GUARDED_BY(active_writes_lock); vector_of_value_map_t previous_partitions DUCKDB_GUARDED_BY(active_writes_lock); idx_t global_offset DUCKDB_GUARDED_BY(active_writes_lock) = 0; + + //! 
Delayed below-threshold partitions + DelayedPartitionBuffers delayed_partition_buffers; }; +//===--------------------------------------------------------------------===// +// Partitioned Copy Scoped Guards +//===--------------------------------------------------------------------===// +class PartitionWriteInfoGuard { +public: + PartitionWriteInfoGuard(PartitionedCopy &partitioned_copy_p, PartitionWriteInfo &write_info_p) + : partitioned_copy(partitioned_copy_p), write_info(write_info_p) { + } + + ~PartitionWriteInfoGuard(); + +private: + PartitionedCopy &partitioned_copy; + optional_ptr write_info; +}; + +class DelayedPartitionFlushGuard { +public: + DelayedPartitionFlushGuard(PartitionedCopy &partitioned_copy_p, const vector &values_p) + : partitioned_copy(partitioned_copy_p), values(values_p) { + } + + ~DelayedPartitionFlushGuard(); + optional Complete(); + +private: + PartitionedCopy &partitioned_copy; + const vector &values; + bool active = true; +}; + +PartitionWriteInfoGuard::~PartitionWriteInfoGuard() { + if (write_info) { + partitioned_copy.ReleasePartitionWriteInfo(*write_info); + } +} + +DelayedPartitionFlushGuard::~DelayedPartitionFlushGuard() { + if (active) { + partitioned_copy.CompleteDelayedPartition(values, false); + } +} + +optional DelayedPartitionFlushGuard::Complete() { + if (!active) { + return nullopt; + } + active = false; + return partitioned_copy.CompleteDelayedPartition(values, true); +} + +//===--------------------------------------------------------------------===// +// Partitioned Hash Group Implementation +//===--------------------------------------------------------------------===// PartitionedCopyHashGroup::PartitionedCopyHashGroup(PartitionedCopy &partitioned_copy, const ChunkRow &chunk_row, idx_t group_idx_p) : partitioned_copy(partitioned_copy), count(chunk_row.count), blocks(chunk_row.chunks), group_idx(group_idx_p), - stage(PartitionedCopyStage::SORT), sorted(0), materialized(0), masked(0), batched(0), flushed(0) { + 
stage(PartitionedCopyStage::SORT), sorted(0), materialized(0), masked(0), batched(0), prepared(0), flushed(0) { } PartitionedCopyStage PartitionedCopyHashGroup::GetStage() const { @@ -507,6 +926,14 @@ bool PartitionedCopyHashGroup::TryPrepareNextStage() { return false; case PartitionedCopyStage::BATCH: if (batched == count) { + PrepareBatchStates(); + stage = PartitionedCopyStage::PREPARE; + return true; + } + return false; + case PartitionedCopyStage::PREPARE: + if (prepared == count) { + CompletePreparedBatchStates(); stage = PartitionedCopyStage::FLUSH; return true; } @@ -530,6 +957,8 @@ optional PartitionedCopyHashGroup::TryNextTask() { switch (group_stage) { case PartitionedCopyStage::BATCH: return TryNextBatchTask(); + case PartitionedCopyStage::PREPARE: + return TryNextPrepareTask(); case PartitionedCopyStage::FLUSH: return TryNextFlushTask(); default: @@ -564,7 +993,7 @@ optional PartitionedCopyHashGroup::TryNextBatchTask() { task.stage = PartitionedCopyStage::BATCH; task.group_idx = group_idx; task.thread_idx = batch_states.size() - 1; - task.batch_idx = batch_state.batches.size(); + task.batch_idx = batch_state.collections.size(); task.begin_idx = batch_row_idx; // Find the end_idx @@ -630,11 +1059,58 @@ optional PartitionedCopyHashGroup::TryNextBatchTask() { task.end_idx = batch_row_idx; // Update partition/batch counters - batch_state.batches.emplace_back(); + const auto batch_idx = + batch_state.AddCollectionSlot(partitioned_copy.GetPartitionCollectionSchema(), task.end_idx - task.begin_idx); + D_ASSERT(batch_idx == task.batch_idx); return task; } +optional PartitionedCopyHashGroup::TryNextPrepareTask() { + while (prepare_partition_idx < batch_states.size()) { + auto &batch_state = *batch_states[prepare_partition_idx]; + if (batch_state.SkipsPrepare()) { + ++prepare_partition_idx; + prepare_batch_idx = 0; + continue; + } + D_ASSERT(batch_state.mode == PartitionedCopyBatchMode::PREPARING); + if (batch_state.NeedsWriteInfo()) { + if 
(!batch_state.IsWriteInfoReserved() && batch_state.TryReserveWriteInfo()) { + PartitionedCopyTask task; + task.stage = PartitionedCopyStage::PREPARE; + task.group_idx = group_idx; + task.thread_idx = prepare_partition_idx; + task.batch_idx = DConstants::INVALID_INDEX; + return task; + } + return nullopt; + } + while (prepare_batch_idx < batch_state.collections.size()) { + if (!batch_state.NeedsPrepare(prepare_batch_idx)) { + ++prepare_batch_idx; + continue; + } + break; + } + if (prepare_batch_idx == batch_state.collections.size()) { + ++prepare_partition_idx; + prepare_batch_idx = 0; + continue; + } + + PartitionedCopyTask task; + task.stage = PartitionedCopyStage::PREPARE; + task.group_idx = group_idx; + task.thread_idx = prepare_partition_idx; + task.batch_idx = prepare_batch_idx++; + + return task; + } + + return nullopt; +} + optional PartitionedCopyHashGroup::TryNextFlushTask() { if (flush_partition_idx == batch_states.size()) { return nullopt; @@ -662,9 +1138,12 @@ void PartitionedCopyHashGroup::Materialize(ExecutionContext &execution_context, auto unused = make_uniq(); OperatorSourceInput source_input {source, *unused, interrupt}; partitioned_copy.sort_strategy->MaterializeColumnData(execution_context, group_idx, source_input); - materialized += (task.end_idx - task.begin_idx); + // Read `blocks` before incrementing `materialized`: once materialized == blocks another thread can advance + // the stage all the way to FLUSH completion and delete this hash group, making `blocks` a dangling read. 
+ const idx_t local_blocks = blocks; + const auto new_materialized = (materialized += (task.end_idx - task.begin_idx)); - if (materialized >= blocks) { + if (new_materialized >= local_blocks) { annotated_lock_guard guard(lock); if (!collection) { collection = partitioned_copy.sort_strategy->GetColumnData(group_idx, source_input); @@ -714,7 +1193,11 @@ void PartitionedCopyHashGroup::Mask(const PartitionedCopyTask &task) { void PartitionedCopyHashGroup::Batch(const PartitionedCopyTask &task) { D_ASSERT(task.stage == PartitionedCopyStage::BATCH); - const auto &op = partitioned_copy.op; + const auto collection_schema = partitioned_copy.GetPartitionCollectionSchema(); + const auto &batch_types = partitioned_copy.GetPartitionCollectionTypes(collection_schema); + const auto &batch_columns = collection_schema == PartitionedCopyCollectionSchema::RAW_SCHEMA + ? partitioned_copy.raw_columns + : partitioned_copy.write_columns; // Initialize the scan ColumnDataScanState scan_state; @@ -724,12 +1207,12 @@ void PartitionedCopyHashGroup::Batch(const PartitionedCopyTask &task) { collection->Seek(task.begin_idx, scan_state, scan_chunk); D_ASSERT(task.begin_idx >= scan_state.current_row_index); - // Initialize the write chunk - DataChunk write_chunk; - write_chunk.Initialize(partitioned_copy.context, partitioned_copy.write_types); + // Initialize the batch chunk + DataChunk batch_chunk; + batch_chunk.Initialize(partitioned_copy.context, batch_types); // Initialize the append - auto batch = make_uniq(partitioned_copy.context, partitioned_copy.write_types); + auto batch = make_uniq(partitioned_copy.context, batch_types); ColumnDataAppendState append_state; batch->InitializeAppend(append_state); @@ -746,62 +1229,206 @@ void PartitionedCopyHashGroup::Batch(const PartitionedCopyTask &task) { } } - write_chunk.ReferenceColumns(scan_chunk, partitioned_copy.write_columns); - batch->Append(append_state, write_chunk); + batch_chunk.ReferenceColumns(scan_chunk, batch_columns); + 
batch->Append(append_state, batch_chunk); collection->Scan(scan_state, scan_chunk); } - // Get pointer to batch state and initialize file if needed (under lock) + // Get pointer to batch state (under lock) optional_ptr batch_state; + optional_ptr write_info; + vector partition_values; + bool acquire_write_info = false; + bool prepare_batch = false; { annotated_lock_guard guard(lock); batch_state = batch_states[task.thread_idx]; - D_ASSERT(batch_state->values.empty() || batch_state->values == values); - if (batch_state->values.empty()) { - batch_state->values = std::move(values); - batch_state->write_info = partitioned_copy.GetPartitionWriteInfo(batch_state->values); + batch_state->SetValues(std::move(values)); + + if (batch_state->mode == PartitionedCopyBatchMode::BUFFERING) { + const auto flush_threshold = Settings::Get(partitioned_copy.context); + const auto has_delayed_partition = partitioned_copy.HasDelayedPartition(batch_state->values); + if (batch_state->CanStartPreparing(flush_threshold, has_delayed_partition)) { + batch_state->StartPreparing(); + acquire_write_info = batch_state->TryReserveWriteInfo(); + } + } + + if (batch_state->mode == PartitionedCopyBatchMode::PREPARING && batch_state->write_info) { + write_info = batch_state->write_info; + partition_values = batch_state->values; + prepare_batch = true; + } else if (acquire_write_info) { + partition_values = batch_state->values; + prepare_batch = true; } } - // Prepare the batch (lock-free) - const auto create_file_state_fun = [&]() { - return partitioned_copy.CreatePartitionFileState(batch_state->values); - }; - auto [batch_analyzer, prepared_batch] = - op.PrepareBatch(partitioned_copy.context, partitioned_copy.copy_gstate, batch_state->write_info->file_state, - create_file_state_fun, std::move(batch)); + const auto row_count = task.end_idx - task.begin_idx; + if (!prepare_batch) { + { + annotated_lock_guard guard(lock); + auto ¤t_batch_state = *batch_states[task.thread_idx]; + 
current_batch_state.StoreCollection(task.batch_idx, std::move(batch)); + } + batched += row_count; + return; + } - annotated_lock_guard guard(lock); - batch_state->batches[task.batch_idx] = make_uniq(batch_analyzer, std::move(prepared_batch)); - // TODO: if this batch is an entire partition, we could immediately write it without moving it + if (acquire_write_info) { + auto &partition_write_info = partitioned_copy.GetPartitionWriteInfo(partition_values); + { + annotated_lock_guard guard(lock); + auto ¤t_batch_state = *batch_states[task.thread_idx]; + current_batch_state.SetWriteInfo(partition_write_info); + write_info = current_batch_state.write_info; + } + } + D_ASSERT(write_info); + auto prepared_batch = partitioned_copy.PreparePartitionBatch( + partition_values, *write_info, PartitionedCopyCollection(collection_schema, std::move(batch))); + + { + annotated_lock_guard guard(lock); + auto ¤t_batch_state = *batch_states[task.thread_idx]; + current_batch_state.StorePreparedBatch(task.batch_idx, std::move(prepared_batch)); + } - batched += (task.end_idx - task.begin_idx); + prepared += row_count; + batched += row_count; } -void PartitionedCopyHashGroup::Flush(const PartitionedCopyTask &task) { - D_ASSERT(task.stage == PartitionedCopyStage::FLUSH); +void PartitionedCopyHashGroup::PrepareBatchStates() { + const auto flush_threshold = Settings::Get(partitioned_copy.context); + for (auto &batch_state : batch_states) { + D_ASSERT(batch_state); + D_ASSERT(!batch_state->values.empty()); + D_ASSERT(batch_state->count > 0); - optional_ptr batch_state; + if (batch_state->mode == PartitionedCopyBatchMode::PREPARING) { + D_ASSERT(batch_state->write_info); + batch_state->EnsurePreparingBatchSlots(); + continue; + } + D_ASSERT(batch_state->mode == PartitionedCopyBatchMode::BUFFERING); + + if (batch_state->count < flush_threshold || partitioned_copy.HasDelayedPartition(batch_state->values)) { + batch_state->MarkDelayed(); + prepared += batch_state->count; + continue; + } + + 
batch_state->StartPreparing(); + } +} + +void PartitionedCopyHashGroup::CompletePreparedBatchStates() { + for (auto &batch_state : batch_states) { + D_ASSERT(batch_state); + batch_state->MarkPrepared(); + } +} + +void PartitionedCopyHashGroup::Prepare(ExecutionContext &execution_context, InterruptState &interrupt_state, + const PartitionedCopyTask &task) { + D_ASSERT(task.stage == PartitionedCopyStage::PREPARE); + + vector values; + optional_ptr write_info; + PartitionedCopyCollection data; + bool acquire_write_info = false; { annotated_lock_guard guard(lock); - batch_state = batch_states[task.thread_idx]; + auto &batch_state = *batch_states[task.thread_idx]; + D_ASSERT(batch_state.mode == PartitionedCopyBatchMode::PREPARING); + values = batch_state.values; + if (task.batch_idx == DConstants::INVALID_INDEX) { + D_ASSERT(batch_state.NeedsWriteInfo()); + D_ASSERT(batch_state.IsWriteInfoReserved()); + acquire_write_info = true; + } else { + D_ASSERT(batch_state.write_info); + D_ASSERT(task.batch_idx < batch_state.collections.size()); + write_info = batch_state.write_info; + data = batch_state.TakeCollection(task.batch_idx); + } } - auto &op = partitioned_copy.op; - const auto create_file_state_fun = [&]() { - return partitioned_copy.CreatePartitionFileState(batch_state->values); - }; - for (auto &batch : batch_state->batches) { - op.FlushBatch(partitioned_copy.context, partitioned_copy.copy_gstate, batch_state->write_info->file_state, - create_file_state_fun, batch->batch_analyzer, std::move(batch->prepared_batch)); + if (acquire_write_info) { + auto &partition_write_info = partitioned_copy.GetPartitionWriteInfo(values); + annotated_lock_guard guard(lock); + auto &batch_state = *batch_states[task.thread_idx]; + batch_state.SetWriteInfo(partition_write_info); + return; } + D_ASSERT(data.collection); + const auto row_count = data.Count(); + + auto prepared_batch = partitioned_copy.PreparePartitionBatch(values, *write_info, std::move(data)); { - annotated_lock_guard 
guard(partitioned_copy.active_writes_lock); - batch_state->write_info->active_writes--; + annotated_lock_guard guard(lock); + auto &batch_state = *batch_states[task.thread_idx]; + batch_state.StorePreparedBatch(task.batch_idx, std::move(prepared_batch)); + } + + prepared += row_count; +} + +void PartitionedCopyHashGroup::Flush(ExecutionContext &execution_context, InterruptState &interrupt_state, + const PartitionedCopyTask &task) { + D_ASSERT(task.stage == PartitionedCopyStage::FLUSH); + + vector values; + vector collections; + vector> batches; + optional_ptr write_info; + PartitionedCopyBatchMode mode; + { + annotated_lock_guard guard(lock); + auto &batch_state = *batch_states[task.thread_idx]; + values = batch_state.values; + mode = batch_state.mode; + if (mode == PartitionedCopyBatchMode::DELAYED) { + collections = batch_state.TakeDelayedCollections(); + } else { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARED); + write_info = batch_state.write_info; + batches = batch_state.TakePreparedBatches(); + } } + D_ASSERT(!values.empty()); + + if (mode == PartitionedCopyBatchMode::DELAYED) { + const auto collection_schema = partitioned_copy.GetPartitionCollectionSchema(); + const auto &collection_types = partitioned_copy.GetPartitionCollectionTypes(collection_schema); + auto collection = make_uniq(partitioned_copy.context, collection_types); + for (auto &batch : collections) { + D_ASSERT(batch.schema == collection_schema); + if (batch.collection) { + collection->Combine(*batch.collection); + } + } + if (collection->Count() == 0) { + return; + } + + auto data = PartitionedCopyCollection(collection_schema, std::move(collection)); + auto ready_partition = partitioned_copy.BufferOrTakeReadyPartition(values, std::move(data), false); + if (ready_partition) { + partitioned_copy.FlushPartitionCollection(execution_context, interrupt_state, std::move(*ready_partition)); + } + return; + } + + D_ASSERT(write_info); + PartitionWriteInfoGuard write_guard(partitioned_copy, 
*write_info); + partitioned_copy.FlushPreparedPartitionRun(values, *write_info, std::move(batches)); } +//===--------------------------------------------------------------------===// +// Partitioned Copy State Implementation +//===--------------------------------------------------------------------===// PartitionedCopyState::PartitionedCopyState(PartitionedCopy &partitioned_copy_p, unique_ptr global_sink_state_p) : partitioned_copy(partitioned_copy_p), @@ -816,7 +1443,8 @@ bool PartitionedCopyState::ShouldInitiateFlush(const idx_t &local_append_count) } // CAS so only one thread increments "next_flush_check" - const auto desired = expected + Settings::Get(partitioned_copy.context); + const auto flush_threshold = Settings::Get(partitioned_copy.context); + const auto desired = expected + flush_threshold; const auto exchanged = next_flush_check.compare_exchange_strong(expected, desired, std::memory_order_relaxed, std::memory_order_relaxed); if (!exchanged) { @@ -829,7 +1457,11 @@ bool PartitionedCopyState::ShouldInitiateFlush(const idx_t &local_append_count) unique_count = MaxValue(unique_count, 1); // If we have enough rows per partition on average, we can start flushing - return total_count / unique_count >= Settings::Get(partitioned_copy.context); + return total_count / unique_count >= 2 * flush_threshold; +} + +bool PartitionedCopyState::IsCombineComplete() const { + return combined == locals; } void PartitionedCopyState::CreateTaskList() { @@ -902,7 +1534,6 @@ optional PartitionedCopyState::TryAssignTask() { while (next_group < partition_blocks.size()) { const auto group_idx = partition_blocks[next_group++].second; active_groups.emplace_back(group_idx); - auto &hash_group = hash_groups[group_idx]; annotated_lock_guard guard(hash_group->lock); hash_group->TryPrepareNextStage(); @@ -933,8 +1564,11 @@ void PartitionedCopyState::ExecuteTask(ExecutionContext &execution_context, cons case PartitionedCopyStage::BATCH: hash_group.Batch(task); break; + case 
PartitionedCopyStage::PREPARE: + hash_group.Prepare(execution_context, interrupt, task); + break; case PartitionedCopyStage::FLUSH: - hash_group.Flush(task); + hash_group.Flush(execution_context, interrupt, task); break; default: throw InternalException("Invalid PartitionedCopyStage in PartitionedCopyState::ExecuteTask"); @@ -967,12 +1601,16 @@ class PartitionedCopyLocalState : public LocalSinkState { idx_t append_count = 0; }; +//===--------------------------------------------------------------------===// +// Partitioned Copy Lifecycle +//===--------------------------------------------------------------------===// PartitionedCopy::PartitionedCopy(const PhysicalCopyToFile &op_p, ClientContext &context_p, CopyToFileGlobalState ©_gstate_p) : op(op_p), context(context_p), copy_gstate(copy_gstate_p), sort_strategy(ConstructSortStrategy()), flushing(false), locals(0), combined(0), finalized(false) { unordered_set part_col_set(op.partition_columns.begin(), op.partition_columns.end()); for (idx_t col_idx = 0; col_idx < op.expected_types.size(); col_idx++) { + raw_columns.push_back(col_idx); if (op.write_partition_columns || part_col_set.find(col_idx) == part_col_set.end()) { write_columns.push_back(col_idx); write_types.push_back(op.expected_types[col_idx]); @@ -985,11 +1623,10 @@ unique_ptr PartitionedCopy::ConstructSortStrategy() const { for (auto &col : op.partition_columns) { partition_bys.push_back(make_uniq(op.expected_types[col], col)); } - vector order_bys; vector> partition_stats; - return SortStrategy::Factory(context, partition_bys, order_bys, op.children[0].get().GetTypes(), partition_stats, - op.children[0].get().estimated_cardinality); + return SortStrategy::Factory(context, partition_bys, op.order_columns, op.expected_types, partition_stats, + op.children.empty() ? 
0 : op.children[0].get().estimated_cardinality); } void PartitionedCopy::CreateNextState() { @@ -1004,6 +1641,12 @@ bool PartitionedCopy::ShouldStopFlushing() const { locals.load(std::memory_order_relaxed) == combined.load(std::memory_order_relaxed); } +bool PartitionedCopy::RequiresSerializedPartitionWrites() const { + // A full partition writer run must remain serialized when the run boundary has file-state semantics: + // ORDER BY starts a fresh file per sorted run, and rotation can start a fresh file between batches. + return !op.order_columns.empty() || op.Rotate(); +} + void PartitionedCopy::InitializeFlush() { if (flushing) { return; @@ -1120,9 +1763,15 @@ class PartitionedCopyFinalizeTask : public ExecutorTask { ExecutionContext execution_context(event->GetClientContext(), *thread_context, event->Cast().pipeline); InterruptState interrupt_state(shared_from_this()); - while (partitioned_copy.flushing.load(std::memory_order_relaxed)) { + while (partitioned_copy.flushing.load(std::memory_order_relaxed) || partitioned_copy.HasDelayedPartitions()) { event->GetClientContext().InterruptCheck(); - partitioned_copy.Flush(execution_context, interrupt_state); + if (partitioned_copy.flushing.load(std::memory_order_relaxed)) { + partitioned_copy.Flush(execution_context, interrupt_state); + } else { + if (!partitioned_copy.DrainDelayedPartitions(execution_context, interrupt_state)) { + TaskScheduler::YieldThread(); + } + } } event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; @@ -1148,13 +1797,18 @@ class PartitionedCopyFinalizeEvent : public BasePipelineEvent { if (!partitioned_copy.flushing_state) { partitioned_copy.InitializeFlush(); } - auto &flushing_state = *partitioned_copy.flushing_state; - - annotated_lock_guard state_guard(flushing_state.lock); - D_ASSERT(flushing_state.combined == flushing_state.locals); + idx_t task_count; + if (partitioned_copy.flushing_state) { + auto &flushing_state = *partitioned_copy.flushing_state; + annotated_lock_guard 
state_guard(flushing_state.lock); + D_ASSERT(flushing_state.combined == flushing_state.locals); + task_count = flushing_state.locals; + } else { + task_count = MaxValue(partitioned_copy.locals.load(std::memory_order_relaxed), 1); + } vector> tasks; - for (idx_t i = 0; i < flushing_state.locals; i++) { + for (idx_t i = 0; i < task_count; i++) { tasks.push_back( make_shared_ptr(GetClientContext(), shared_from_this(), partitioned_copy)); } @@ -1162,15 +1816,21 @@ class PartitionedCopyFinalizeEvent : public BasePipelineEvent { } void FinishEvent() override { - annotated_lock_guard global_guard(partitioned_copy.lock); - if (!partitioned_copy.sinking_state) { - partitioned_copy.copy_gstate.TryFinalizeOwnedFileState(); - return; // Done + bool done; + { + annotated_lock_guard global_guard(partitioned_copy.lock); + done = !partitioned_copy.sinking_state; + if (!done) { + auto partitioned_copy_finalize_event = + make_shared_ptr(*pipeline, partitioned_copy); + InsertEvent(std::move(partitioned_copy_finalize_event)); + } } - auto partitioned_copy_finalize_event = - make_shared_ptr(*pipeline, partitioned_copy); - InsertEvent(std::move(partitioned_copy_finalize_event)); + if (done) { + partitioned_copy.FinalizeActiveWrites(); + partitioned_copy.copy_gstate.TryFinalizeOwnedFileState(); + } } private: @@ -1178,19 +1838,28 @@ class PartitionedCopyFinalizeEvent : public BasePipelineEvent { }; void PartitionedCopy::Finalize(Pipeline &pipeline, Event &event, InterruptState &interrupt_state) { - annotated_lock_guard global_guard(lock); - finalized = true; - if (!sinking_state && !flushing_state) { - return; - } + bool should_finalize_writes; + { + annotated_lock_guard global_guard(lock); + finalized = true; + if (!sinking_state && !flushing_state && !HasDelayedPartitions()) { + should_finalize_writes = true; + } else { + should_finalize_writes = false; + if (sinking_state) { + annotated_lock_guard guard(sinking_state->lock); + FinalizeState(*sinking_state, interrupt_state); + } - if 
(sinking_state) { - annotated_lock_guard guard(sinking_state->lock); - FinalizeState(*sinking_state, interrupt_state); + auto partitioned_copy_finalize_event = make_shared_ptr(pipeline, *this); + event.InsertEvent(std::move(partitioned_copy_finalize_event)); + } } - auto partitioned_copy_finalize_event = make_shared_ptr(pipeline, *this); - event.InsertEvent(std::move(partitioned_copy_finalize_event)); + if (should_finalize_writes) { + FinalizeActiveWrites(); + copy_gstate.TryFinalizeOwnedFileState(); + } } void PartitionedCopy::Flush(ExecutionContext &execution_context, InterruptState &interrupt_state) { @@ -1206,9 +1875,18 @@ void PartitionedCopy::Flush(ExecutionContext &execution_context, InterruptState { annotated_lock_guard guard(flushing_state_copy->lock); + D_ASSERT(flushing_state_copy->combined <= flushing_state_copy->locals); if (!flushing_state_copy->global_source_state) { - return; // Finalization not yet complete, nothing to do + if (!flushing_state_copy->IsCombineComplete()) { + return; // Combine not complete yet + } + D_ASSERT(flushing_state_copy->IsCombineComplete()); + FinalizeState(*flushing_state_copy, interrupt_state); + if (!flushing_state_copy->global_source_state) { + return; // Finalization not yet complete, nothing to do + } } + D_ASSERT(flushing_state_copy->global_source_state); } if (ShouldStopFlushing()) { @@ -1237,21 +1915,243 @@ void PartitionedCopy::Flush(ExecutionContext &execution_context, InterruptState should_finalize_writes = finalized && !sinking_state; } if (should_finalize_writes) { - vector> files_to_finalize; - { - annotated_lock_guard aw_guard(active_writes_lock); - for (auto &entry : active_writes) { - files_to_finalize.push_back(std::move(entry.second->file_state)); + DrainDelayedPartitions(execution_context, interrupt_state); + } +} + +//===--------------------------------------------------------------------===// +// Delayed Partition Buffering 
+//===--------------------------------------------------------------------===// +bool PartitionedCopy::HasDelayedPartition(const vector &values) { + return delayed_partition_buffers.Has(values); +} + +bool PartitionedCopy::HasDelayedPartitions() const { + return !delayed_partition_buffers.Empty(); +} + +optional +PartitionedCopy::BufferOrTakeReadyPartition(const vector &values, PartitionedCopyCollection data, bool force) { + const auto flush_threshold = Settings::Get(context); + return delayed_partition_buffers.BufferOrTakeReady(values, std::move(data), flush_threshold, force); +} + +optional PartitionedCopy::TakeNextDelayedPartition() { + return delayed_partition_buffers.TakeNext(); +} + +optional PartitionedCopy::CompleteDelayedPartition(const vector &values, + const bool take_ready) { + const auto flush_threshold = Settings::Get(context); + return delayed_partition_buffers.Complete(values, flush_threshold, take_ready); +} + +bool PartitionedCopy::DrainDelayedPartitions(ExecutionContext &execution_context, InterruptState &interrupt_state) { + bool drained = false; + while (auto partition = TakeNextDelayedPartition()) { + drained = true; + FlushPartitionCollection(execution_context, interrupt_state, std::move(*partition)); + } + return drained; +} + +//===--------------------------------------------------------------------===// +// Partitioned Collection Helpers +//===--------------------------------------------------------------------===// +unique_ptr PartitionedCopy::SortPartitionCollection(ExecutionContext &execution_context, + InterruptState &interrupt_state, + unique_ptr collection) { + D_ASSERT(collection); + D_ASSERT(collection->Count() > 0); + + auto global_sink_state = sort_strategy->GetGlobalSinkState(context); + auto local_sink_state = sort_strategy->GetLocalSinkState(execution_context); + + ColumnDataScanState scan_state; + collection->InitializeScan(scan_state); + DataChunk scan_chunk; + collection->InitializeScanChunk(scan_state, scan_chunk); + + 
while (collection->Scan(scan_state, scan_chunk)) { + OperatorSinkInput sink_input {*global_sink_state, *local_sink_state, interrupt_state}; + sort_strategy->Sink(execution_context, scan_chunk, sink_input); + } + + OperatorSinkCombineInput combine_input {*global_sink_state, *local_sink_state, interrupt_state}; + sort_strategy->Combine(execution_context, combine_input); + + OperatorSinkFinalizeInput finalize_input {*global_sink_state, interrupt_state}; + sort_strategy->Finalize(context, finalize_input); + + auto global_source_state = sort_strategy->GetGlobalSourceState(context, *global_sink_state); + const auto &hash_groups = sort_strategy->GetHashGroups(*global_source_state); + + auto local_source_state = sort_strategy->GetLocalSourceState(execution_context, *global_source_state); + OperatorSourceInput source_input {*global_source_state, *local_source_state, interrupt_state}; + for (idx_t group_idx = 0; group_idx < hash_groups.size(); group_idx++) { + if (!hash_groups[group_idx].count) { + continue; + } + + sort_strategy->SortColumnData(execution_context, group_idx, finalize_input); + while (sort_strategy->MaterializeColumnData(execution_context, group_idx, source_input) == + SourceResultType::HAVE_MORE_OUTPUT) { + } + + auto result = sort_strategy->GetColumnData(group_idx, source_input); + if (!result || result->Count() != collection->Count()) { + throw InternalException("Failed to materialize delayed partition data"); + } + return result; + } + + throw InternalException("Failed to find delayed partition data after sorting"); +} + +PartitionedCopyCollectionSchema PartitionedCopy::GetPartitionCollectionSchema() const { + return op.order_columns.empty() ? PartitionedCopyCollectionSchema::WRITE_SCHEMA + : PartitionedCopyCollectionSchema::RAW_SCHEMA; +} + +const vector &PartitionedCopy::GetPartitionCollectionTypes(PartitionedCopyCollectionSchema schema) const { + return schema == PartitionedCopyCollectionSchema::RAW_SCHEMA ? 
op.expected_types : write_types; +} + +unique_ptr PartitionedCopy::PrepareCollectionForWrite(PartitionedCopyCollection data) { + D_ASSERT(data.collection); + if (data.schema == PartitionedCopyCollectionSchema::WRITE_SCHEMA) { + return std::move(data.collection); + } + return ProjectToWriteColumns(std::move(data.collection)); +} + +unique_ptr PartitionedCopy::ProjectToWriteColumns(unique_ptr collection) { + D_ASSERT(collection); + + auto result = make_uniq(context, write_types); + ColumnDataAppendState append_state; + result->InitializeAppend(append_state); + + ColumnDataScanState scan_state; + collection->InitializeScan(scan_state); + DataChunk scan_chunk; + collection->InitializeScanChunk(scan_state, scan_chunk); + + DataChunk write_chunk; + write_chunk.Initialize(context, write_types); + + while (collection->Scan(scan_state, scan_chunk)) { + write_chunk.ReferenceColumns(scan_chunk, write_columns); + result->Append(append_state, write_chunk); + } + + return result; +} + +//===--------------------------------------------------------------------===// +// Partitioned Write Helpers +//===--------------------------------------------------------------------===// +unique_ptr PartitionedCopy::PreparePartitionBatch(const vector &values, + PartitionWriteInfo &write_info, + PartitionedCopyCollection data) { + auto collection = PrepareCollectionForWrite(std::move(data)); + const auto create_file_state_fun = [&]() { + return CreatePartitionFileState(values); + }; + auto [batch_analyzer, prepared_batch] = + op.PrepareBatch(context, copy_gstate, write_info.file_state, create_file_state_fun, std::move(collection)); + return make_uniq(batch_analyzer, std::move(prepared_batch)); +} + +void PartitionedCopy::FlushPreparedPartitionRun(const vector &values, PartitionWriteInfo &write_info, + vector> batches) { + WithSerializedPartitionWriteRun(write_info, [&]() { + EnsureFreshPartitionFileForSortedRun(write_info, values); + for (auto &batch : batches) { + D_ASSERT(batch); + 
EnsureFreshPartitionFileForRotation(write_info, values); + FlushPreparedPartitionBatch(values, write_info, std::move(batch)); + } + }); +} + +void PartitionedCopy::FlushPreparedPartitionBatch(const vector &values, PartitionWriteInfo &write_info, + unique_ptr batch) { + D_ASSERT(batch); + const auto create_file_state_fun = [&]() { + return CreatePartitionFileState(values); + }; + op.FlushBatch(context, copy_gstate, write_info.file_state, create_file_state_fun, batch->batch_analyzer, + std::move(batch->prepared_batch)); +} + +void PartitionedCopy::FlushDelayedPartitionRun(const vector &values, PartitionWriteInfo &write_info, + ColumnDataCollection &collection) { + WithSerializedPartitionWriteRun(write_info, [&]() { + EnsureFreshPartitionFileForSortedRun(write_info, values); + + DataChunk scan_chunk; + ColumnDataScanState scan_state; + collection.InitializeScan(scan_state); + collection.InitializeScanChunk(scan_state, scan_chunk); + + auto batch = make_uniq(context, write_types); + ColumnDataAppendState append_state; + batch->InitializeAppend(append_state); + + const auto flush_batch = [&]() { + if (batch->Count() == 0) { + return; + } + EnsureFreshPartitionFileForRotation(write_info, values); + auto prepared_batch = PreparePartitionBatch( + values, write_info, + PartitionedCopyCollection(PartitionedCopyCollectionSchema::WRITE_SCHEMA, std::move(batch))); + FlushPreparedPartitionBatch(values, write_info, std::move(prepared_batch)); + + batch = make_uniq(context, write_types); + batch->InitializeAppend(append_state); + }; + + while (collection.Scan(scan_state, scan_chunk)) { + batch->Append(append_state, scan_chunk); + const CopyFunctionBatchAnalyzer batch_analyzer(*batch, op.batch_size, op.batch_size_bytes); + if (batch_analyzer.MeetsFlushCriteria()) { + flush_batch(); } - active_writes.clear(); } - FinalizeFileStates(std::move(files_to_finalize)); + flush_batch(); + }); +} + +void PartitionedCopy::FlushPartitionCollection(ExecutionContext &execution_context, 
InterruptState &interrupt_state, + DelayedPartitionFlush flush) { + while (true) { + DelayedPartitionFlushGuard delayed_guard(*this, flush.values); + if (flush.data.collection && flush.data.collection->Count() > 0) { + if (flush.data.schema == PartitionedCopyCollectionSchema::RAW_SCHEMA) { + flush.data.collection = + SortPartitionCollection(execution_context, interrupt_state, std::move(flush.data.collection)); + flush.data.collection = ProjectToWriteColumns(std::move(flush.data.collection)); + flush.data.schema = PartitionedCopyCollectionSchema::WRITE_SCHEMA; + } + D_ASSERT(flush.data.schema == PartitionedCopyCollectionSchema::WRITE_SCHEMA); + + auto &write_info = GetPartitionWriteInfo(flush.values); + PartitionWriteInfoGuard write_guard(*this, write_info); + FlushDelayedPartitionRun(flush.values, write_info, *flush.data.collection); + } + + auto next = delayed_guard.Complete(); + if (!next) { + return; + } + flush = std::move(*next); } } PartitionWriteInfo &PartitionedCopy::GetPartitionWriteInfo(const vector &values) { PartitionWriteInfo *result; - vector> files_to_finalize; { annotated_lock_guard guard(active_writes_lock); // check if we have already started writing this partition @@ -1262,9 +2162,6 @@ PartitionWriteInfo &PartitionedCopy::GetPartitionWriteInfo(const vector & result = active_write_entry->second.get(); } else { auto info = make_uniq(); - auto create_result = CreatePartitionFileStateLocked(values); - info->file_state = std::move(create_result.file_state); - files_to_finalize = std::move(create_result.files_to_finalize); result = info.get(); info->active_writes = 1; @@ -1273,30 +2170,36 @@ PartitionWriteInfo &PartitionedCopy::GetPartitionWriteInfo(const vector & } } - FinalizeFileStates(std::move(files_to_finalize)); return *result; } -unique_ptr PartitionedCopy::CreatePartitionFileState(const vector &values) { - CreatePartitionFileStateResult result; +void PartitionedCopy::ReleasePartitionWriteInfo(PartitionWriteInfo &write_info) { + 
annotated_lock_guard guard(active_writes_lock); + D_ASSERT(write_info.active_writes > 0); + write_info.active_writes--; +} + +unique_ptr PartitionedCopy::CreatePartitionFileState(const vector &values, + FileCreationReason reason) { + PartitionFileStateReservation reservation; { annotated_lock_guard guard(active_writes_lock); - result = CreatePartitionFileStateLocked(values); + reservation = ReservePartitionFileStateLocked(values, reason); } - FinalizeFileStates(std::move(result.files_to_finalize)); - return std::move(result.file_state); + FinalizeFileStates(std::move(reservation.files_to_finalize)); + return CreatePartitionFileStateFromReservation(values, reservation.offset); } -CreatePartitionFileStateResult PartitionedCopy::CreatePartitionFileStateLocked(const vector &values) { - CreatePartitionFileStateResult result; - idx_t offset = 0; +PartitionFileStateReservation PartitionedCopy::ReservePartitionFileStateLocked(const vector &values, + FileCreationReason reason) { + PartitionFileStateReservation reservation; // check if we need to close any writers before we can continue if (active_writes.size() >= Settings::Get(context)) { // we need to! 
try to close writers for (auto it = active_writes.begin(); it != active_writes.end(); ++it) { if (it->second->active_writes == 0) { // we can evict this entry - evict the partition - result.files_to_finalize.push_back(std::move(it->second->file_state)); + reservation.files_to_finalize.push_back(std::move(it->second->file_state)); ++previous_partitions[it->first]; active_writes.erase(it); break; @@ -1305,15 +2208,24 @@ CreatePartitionFileStateResult PartitionedCopy::CreatePartitionFileStateLocked(c } if (op.hive_file_pattern) { + if (reason == FileCreationReason::SORTED_RUN_BOUNDARY || reason == FileCreationReason::ROTATION) { + ++previous_partitions[values]; + } auto prev_offset = previous_partitions.find(values); if (prev_offset != previous_partitions.end()) { - offset = prev_offset->second; + reservation.offset = prev_offset->second; } } else { - offset = global_offset++; + reservation.offset = global_offset++; } - // Access global state under lock + return reservation; +} + +unique_ptr PartitionedCopy::CreatePartitionFileStateFromReservation(const vector &values, + idx_t offset) { + // The reservation/eviction decision has already been made under active_writes_lock. This section only + // serializes global file bookkeeping: directory tracking, filename registration, and writer initialization. 
annotated_lock_guard guard(copy_gstate.lock); // Create a writer for the current file @@ -1331,8 +2243,76 @@ CreatePartitionFileStateResult PartitionedCopy::CreatePartitionFileStateLocked(c } } // Initialize write - result.file_state = copy_gstate.CreateFileState(full_path, values); - return result; + return copy_gstate.CreateFileStateLocked(full_path, values); +} + +void PartitionedCopy::EnsureFreshPartitionFileForSortedRun(PartitionWriteInfo &write_info, + const vector &values) { + if (op.order_columns.empty()) { + return; + } + EnsureFreshPartitionFile(write_info, values, FileCreationReason::SORTED_RUN_BOUNDARY); +} + +void PartitionedCopy::EnsureFreshPartitionFileForRotation(PartitionWriteInfo &write_info, const vector &values) { + if (!op.Rotate()) { + return; + } + EnsureFreshPartitionFile(write_info, values, FileCreationReason::ROTATION); +} + +void PartitionedCopy::EnsureFreshPartitionFile(PartitionWriteInfo &write_info, const vector &values, + FileCreationReason reason) { + D_ASSERT(reason == FileCreationReason::SORTED_RUN_BOUNDARY || reason == FileCreationReason::ROTATION); + D_ASSERT(RequiresSerializedPartitionWrites()); + + optional_ptr old_file_state_ptr; + { + annotated_lock_guard global_guard(copy_gstate.lock); + if (!write_info.file_state) { + return; + } + old_file_state_ptr = write_info.file_state.get(); + annotated_lock_guard file_guard(old_file_state_ptr->lock); + if (reason == FileCreationReason::SORTED_RUN_BOUNDARY && old_file_state_ptr->num_batches == 0) { + return; + } + if (reason == FileCreationReason::ROTATION && !PhysicalCopyRotateNow(op, *old_file_state_ptr)) { + return; + } + } + + auto new_file_state = CreatePartitionFileState(values, reason); + + unique_ptr old_file_state; + { + annotated_lock_guard global_guard(copy_gstate.lock); + D_ASSERT(write_info.file_state); + D_ASSERT(RefersToSameObject(*old_file_state_ptr.get(), *write_info.file_state)); + annotated_lock_guard file_guard(write_info.file_state->lock); + if (reason == 
FileCreationReason::SORTED_RUN_BOUNDARY) { + D_ASSERT(write_info.file_state->num_batches > 0); + } else { + D_ASSERT(PhysicalCopyRotateNow(op, *write_info.file_state)); + } + + old_file_state = std::move(write_info.file_state); + write_info.file_state = std::move(new_file_state); + } + + copy_gstate.FinalizeFileState(std::move(old_file_state)); +} + +void PartitionedCopy::FinalizeActiveWrites() { + vector> files_to_finalize; + { + annotated_lock_guard aw_guard(active_writes_lock); + for (auto &entry : active_writes) { + files_to_finalize.push_back(std::move(entry.second->file_state)); + } + active_writes.clear(); + } + FinalizeFileStates(std::move(files_to_finalize)); } void PartitionedCopy::FinalizeFileStates(vector> files_to_finalize) { @@ -1351,7 +2331,7 @@ string PartitionedCopy::GetOrCreateDirectory(string path, const vector &v const auto &partition_col_name = op.names[op.partition_columns[i]]; const auto &partition_value = values[i]; string p_dir; - p_dir += HivePartitioning::Escape(partition_col_name); + p_dir += HivePartitioning::Escape(partition_col_name.GetIdentifierName()); p_dir += "="; if (partition_value.IsNull()) { p_dir += "__HIVE_DEFAULT_PARTITION__"; @@ -1366,11 +2346,11 @@ string PartitionedCopy::GetOrCreateDirectory(string path, const vector &v } //===--------------------------------------------------------------------===// -// CopyToFileGlobalState cpp +// Copy Global State Implementation //===--------------------------------------------------------------------===// CopyToFileGlobalState::CopyToFileGlobalState(const PhysicalCopyToFile &op_p, ClientContext &context_p) : op(op_p), context(context_p), initialized(false), finalized(false), prepare_global_state(nullptr), - create_file_state_fun([&]() DUCKDB_REQUIRES(lock) { return CreateFileState(); }), rows_copied(0), + create_file_state_fun([&]() DUCKDB_EXCLUDES(lock) { return CreateFileState(); }), rows_copied(0), last_file_offset(0) { } @@ -1399,7 +2379,7 @@ void 
CopyToFileGlobalState::Initialize() { return; } // initialize writing to the file - global_state = CreateFileState(op.file_path); + global_state = CreateFileStateLocked(op.file_path); initialized = true; } @@ -1415,8 +2395,8 @@ void CopyToFileGlobalState::CreateDir(const string &dir_path) { created_directories.insert(dir_path); } -unique_ptr CopyToFileGlobalState::CreateFileState(string output_path, - optional_ptr> partition_values) { +unique_ptr +CopyToFileGlobalState::CreateFileStateLocked(string output_path, optional_ptr> partition_values) { auto &fs = FileSystem::GetFileSystem(context); if (output_path.empty()) { output_path = op.filename_pattern.CreateFilename(fs, op.file_path, op.file_extension, last_file_offset++); @@ -1459,6 +2439,12 @@ unique_ptr CopyToFileGlobalState::CreateFileState(string output return res; } +unique_ptr CopyToFileGlobalState::CreateFileState(string output_path, + optional_ptr> partition_values) { + annotated_lock_guard guard(lock); + return CreateFileStateLocked(std::move(output_path), partition_values); +} + optional_ptr CopyToFileGlobalState::AddFile(const string &file_name) { auto file_info = make_uniq(file_name); optional_ptr result; @@ -1508,7 +2494,7 @@ void CopyToFileGlobalState::TryFinalizeOwnedFileState() { } //===--------------------------------------------------------------------===// -// CopyToFileLocalState hpp +// Copy Local State Implementation //===--------------------------------------------------------------------===// CopyToFileLocalState::CopyToFileLocalState(const PhysicalCopyToFile &op_p, ExecutionContext &context_p, CopyToFileGlobalState &gstate_p) @@ -1530,7 +2516,7 @@ PhysicalCopyToFile::PhysicalCopyToFile(PhysicalPlan &physical_plan, vector PhysicalCopyToFile::ParamsToString() const { InsertionOrderPreservingMap result; - result["FORMAT"] = StringUtil::Upper(function.name); + result["FORMAT"] = StringUtil::Upper(function.name.GetIdentifierName()); return result; } @@ -1625,9 +2611,9 @@ unique_ptr 
PhysicalCopyToFile::GetGlobalSinkState(ClientContext } auto state = make_uniq(*this, context); - if (!per_thread_output && Rotate() && write_empty_file) { + if (!partition_output && !per_thread_output && Rotate() && write_empty_file) { annotated_lock_guard guard(state->lock); - state->global_state = state->CreateFileState(); + state->global_state = state->CreateFileStateLocked(); } if (partition_output) { @@ -1767,7 +2753,7 @@ SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, if (NumericCast(gstate.rows_copied.load()) == 0 && sink_state != nullptr) { // no rows from source, write schema to file annotated_lock_guard guard(gstate.lock); - gstate.global_state = gstate.CreateFileState(); + gstate.global_state = gstate.CreateFileStateLocked(); } } @@ -1841,6 +2827,52 @@ static void FlushLegacyCopyBatch(ClientContext &context, const CopyFunction &fun //===--------------------------------------------------------------------===// // Prepare/Flush Batch //===--------------------------------------------------------------------===// +static void EnsureFileState(CopyToFileGlobalState &gstate, unique_ptr &file_state_ptr, + const std::function()> &create_file_state_fun) { + auto file_state_key = &file_state_ptr; + while (true) { + bool create_file_state = false; + { + annotated_lock_guard guard(gstate.lock); + if (file_state_ptr) { + return; + } + if (gstate.creating_file_states.insert(file_state_key).second) { + create_file_state = true; + } + } + + if (!create_file_state) { + TaskScheduler::YieldThread(); + continue; + } + + unique_ptr new_file_state; + try { + new_file_state = create_file_state_fun(); + } catch (...) 
{ + annotated_lock_guard guard(gstate.lock); + gstate.creating_file_states.erase(file_state_key); + throw; + } + + unique_ptr unused_file_state; + { + annotated_lock_guard guard(gstate.lock); + if (!file_state_ptr) { + file_state_ptr = std::move(new_file_state); + } else { + unused_file_state = std::move(new_file_state); + } + gstate.creating_file_states.erase(file_state_key); + } + if (unused_file_state) { + gstate.FinalizeFileState(std::move(unused_file_state)); + } + return; + } +} + pair> PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate_p, unique_ptr &file_state_ptr, @@ -1857,13 +2889,9 @@ PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate auto prepare_global_state = gstate.prepare_global_state.load(std::memory_order_acquire); if (!prepare_global_state) { D_ASSERT(!file_state_ptr); - annotated_unique_lock global_guard(gstate.lock); + EnsureFileState(gstate, file_state_ptr, create_file_state_fun); prepare_global_state = gstate.prepare_global_state.load(std::memory_order_acquire); - if (!prepare_global_state) { - file_state_ptr = create_file_state_fun(); - prepare_global_state = gstate.prepare_global_state.load(std::memory_order_acquire); - D_ASSERT(prepare_global_state); - } + D_ASSERT(prepare_global_state); } // Prepare the batch @@ -1878,20 +2906,25 @@ void PhysicalCopyToFile::FlushBatch(ClientContext &context, GlobalSinkState &gst auto &gstate = gstate_p.Cast(); while (true) { + EnsureFileState(gstate, file_state_ptr, create_file_state_fun); + // Decide which file to flush to annotated_unique_lock global_guard(gstate.lock); if (!file_state_ptr) { - file_state_ptr = create_file_state_fun(); + global_guard.unlock(); + TaskScheduler::YieldThread(); + continue; } annotated_unique_lock file_guard(file_state_ptr->lock); if (PhysicalCopyRotateNow(*this, *file_state_ptr)) { // Global state must be rotated. 
Move to local scope, create an new one, and immediately release global lock auto owned_file_state = std::move(file_state_ptr); - file_state_ptr = create_file_state_fun(); file_guard.unlock(); global_guard.unlock(); + EnsureFileState(gstate, file_state_ptr, create_file_state_fun); + // Finalize this file! gstate.FinalizeFileState(std::move(owned_file_state)); } else { @@ -1947,13 +2980,11 @@ SourceResultType PhysicalCopyToFile::GetDataInternal(ExecutionContext &context, } ReturnStatistics(chunk, file_entry); } - chunk.SetCardinality(count); source_state.offset += count; return source_state.offset < gstate.written_files.size() ? SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; } - chunk.SetCardinality(1); switch (return_type) { case CopyFunctionReturnType::CHANGED_ROWS: chunk.data[0].Append(Value::BIGINT(NumericCast(gstate.rows_copied.load()))); diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp index 50f10c1f0..893ad1026 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp @@ -169,7 +169,6 @@ SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk, for (idx_t i = table.ColumnCount(); i < l_state.delete_chunk.ColumnCount(); i++) { l_state.delete_chunk.data[i].Reference(row_ids); } - l_state.delete_chunk.SetCardinality(chunk.size()); // Slice down to the claimed rows (RETURNING only) if (return_chunk && delete_count != chunk.size()) { @@ -262,7 +261,6 @@ SourceResultType PhysicalDelete::GetDataInternal(ExecutionContext &context, Data auto &state = input.global_state.Cast(); auto &g = sink_state->Cast(); if (!return_chunk) { - chunk.SetCardinality(1); chunk.data[0].Append(Value::BIGINT(NumericCast(g.deleted_count.load()))); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp 
b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp index 1ad0b843c..9f0ab53e1 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -140,7 +140,6 @@ void CheckOnConflictCondition(ExecutionContext &context, DataChunk &conflicts, c ExpressionExecutor executor(context.client, *condition); result.Initialize(context.client, {LogicalType::BOOLEAN}); executor.Execute(conflicts, result); - result.SetCardinality(conflicts.size()); } static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_chunk, DataChunk &input_chunk, @@ -152,7 +151,6 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch // We have not scanned the initial table, so we can just duplicate the initial chunk result.Initialize(client, input_chunk.GetTypes()); result.Reference(input_chunk); - result.SetCardinality(input_chunk); return; } vector combined_types; @@ -165,7 +163,7 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch // Add the VALUES list for (idx_t i = 0; i < insert_types.size(); i++) { idx_t col_idx = i; - auto &other_col = input_chunk.data[i]; + const auto &other_col = input_chunk.data[i]; auto &this_col = result.data[col_idx]; D_ASSERT(other_col.GetType() == this_col.GetType()); this_col.Reference(other_col); @@ -173,7 +171,7 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch // Add the columns from the original conflicting tuples for (idx_t i = 0; i < types_to_fetch.size(); i++) { idx_t col_idx = i + insert_types.size(); - auto &other_col = scan_chunk.data[i]; + const auto &other_col = scan_chunk.data[i]; auto &this_col = result.data[col_idx]; D_ASSERT(other_col.GetType() == this_col.GetType()); this_col.Reference(other_col); @@ -184,7 +182,6 @@ static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_ch // We can have a SET expression without a 
conflict target ONLY if there is only 1 Index on the table // In which case this also can't cause a discrepancy between existing tuple count and insert tuple count D_ASSERT(input_chunk.size() == scan_chunk.size()); - result.SetCardinality(input_chunk.size()); } static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, DuckTableEntry &table, Vector &row_ids, @@ -199,7 +196,6 @@ static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, DuckT do_update_filter_result.Initialize(context.client, {LogicalType::BOOLEAN}); ExpressionExecutor where_executor(context.client, *do_update_condition); where_executor.Execute(chunk, do_update_filter_result); - do_update_filter_result.SetCardinality(chunk.size()); do_update_filter_result.Flatten(); SelectionVector sel(chunk.size()); @@ -215,7 +211,6 @@ static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, DuckT if (count != chunk.size()) { // Filter any conflicts not meeting the condition. chunk.Slice(sel, count); - chunk.SetCardinality(count); row_ids.Slice(sel, count); row_ids.Flatten(); } @@ -224,7 +219,7 @@ static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, DuckT if (chunk.size() == 0) { auto initialize = vector(set_types.size(), false); update_chunk.Initialize(context.client, set_types, initialize, chunk.size()); - update_chunk.SetCardinality(chunk); + update_chunk.SetChildCardinality(chunk.size()); return; } @@ -232,7 +227,6 @@ static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, DuckT update_chunk.Initialize(context.client, set_types, chunk.size()); ExpressionExecutor executor(context.client, set_expressions); executor.Execute(chunk, update_chunk); - update_chunk.SetCardinality(chunk); } template @@ -256,7 +250,6 @@ static idx_t PerformOnConflictAction(InsertLocalState &lstate, InsertGlobalState // Arrange the columns in the standard table order. 
DataChunk &append_chunk = lstate.append_chunk; - append_chunk.SetCardinality(update_chunk); for (idx_t i = 0; i < append_chunk.ColumnCount(); i++) { append_chunk.data[i].Reference(chunk.data[i]); } @@ -351,9 +344,9 @@ static void PrepareSortKeys(DataChunk &input, unordered_map(LogicalType::BLOB); - CreateSortKeyHelpers::CreateSortKey(column, input.size(), order_modifiers, *sort_key); + CreateSortKeyHelpers::CreateSortKey(column, order_modifiers, *sort_key); } } @@ -487,7 +480,6 @@ static idx_t HandleInsertConflicts(DuckTableEntry &table, ExecutionContext &cont conflict_chunk.InitializeEmpty(tuples.GetTypes()); conflict_chunk.Reference(tuples); conflict_chunk.Slice(conflict_manager.GetInvertedSel(), conflict_count); - conflict_chunk.SetCardinality(conflict_count); // Contains the conflict chunk and the scanned chunk (wide). DataChunk combined_chunk; @@ -511,7 +503,6 @@ static idx_t HandleInsertConflicts(DuckTableEntry &table, ExecutionContext &cont auto &inverted_sel = conflict_manager.GetInvertedSel(); auto new_size = SelectionVector::Inverted(inverted_sel, sel_vec, conflict_count, tuples.size()); tuples.Slice(sel_vec, new_size); - tuples.SetCardinality(new_size); return affected_tuples; } @@ -600,11 +591,9 @@ idx_t PhysicalInsert::OnConflictHandling(DuckTableEntry &table, ExecutionContext lstate.update_chunk.Reference(insert_chunk); lstate.update_chunk.Slice(last_occurrences, last_occurrences_count); - lstate.update_chunk.SetCardinality(last_occurrences_count); } insert_chunk.Slice(sel, sel_count); - insert_chunk.SetCardinality(sel_count); } // Check whether any conflicts arise, and if they all meet the conflict_target + condition @@ -667,9 +656,9 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &insert auto &optimistic_collection = data_table.GetOptimisticCollection(context.client, lstate.collection_index); auto &collection = *optimistic_collection.collection; - auto new_row_group = collection.Append(insert_chunk, 
lstate.local_append_state); - if (new_row_group) { - lstate.optimistic_writer->WriteNewRowGroup(optimistic_collection); + auto flushed_row_group_idx = collection.Append(insert_chunk, lstate.local_append_state); + if (flushed_row_group_idx.IsValid()) { + lstate.optimistic_writer->WriteNewRowGroup(optimistic_collection, flushed_row_group_idx.GetIndex()); } return SinkResultType::NEED_MORE_INPUT; } @@ -711,8 +700,6 @@ SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, Operato storage.FinalizeLocalAppend(append_state); } else { // we have written rows to disk optimistically - merge directly into the transaction-local storage - lstate.optimistic_writer->WriteUnflushedRowGroups(optimistic_collection); - lstate.optimistic_writer->FinalFlush(); gstate.table.GetStorage().LocalMerge(context.client, gstate.table, optimistic_collection); auto &optimistic_writer = gstate.table.GetStorage().GetOptimisticWriter(context.client); optimistic_writer.Merge(*lstate.optimistic_writer); @@ -751,7 +738,6 @@ SourceResultType PhysicalInsert::GetDataInternal(ExecutionContext &context, Data auto &state = input.global_state.Cast(); auto &insert_gstate = sink_state->Cast(); if (!return_chunk) { - chunk.SetCardinality(1); chunk.data[0].Append(Value::BIGINT(NumericCast(insert_gstate.insert_count))); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp index 030479df9..094224692 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp @@ -1,6 +1,7 @@ #include "duckdb/execution/operator/persistent/physical_merge_into.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" namespace duckdb { @@ -193,7 +194,7 @@ class 
MergeIntoGlobalState : public GlobalSinkState { if (action->action_type == MergeActionType::MERGE_ERROR) { // abort - generate an error message string merge_condition; - merge_condition += MergeIntoStatement::ActionConditionToString(range.condition); + merge_condition += MergeQueryNode::ActionConditionToString(range.condition); if (action->condition) { merge_condition += " AND " + action->condition->ToString(); } @@ -456,7 +457,6 @@ SourceResultType PhysicalMergeInto::GetDataInternal(ExecutionContext &context, D OperatorSourceInput &input) const { auto &g = sink_state->Cast(); if (!return_chunk) { - chunk.SetCardinality(1); chunk.data[0].Append(Value::BIGINT(NumericCast(g.merged_count.load()))); return SourceResultType::FINISHED; } @@ -503,7 +503,6 @@ SourceResultType PhysicalMergeInto::GetDataInternal(ExecutionContext &context, D } Value merge_action(merge_action_name); chunk.data.back().Reference(merge_action, count_t(lstate.scan_chunk.size())); - chunk.SetCardinality(lstate.scan_chunk.size()); } if (result != SourceResultType::FINISHED) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp index 27c8dedc5..36fcfc822 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -114,7 +114,6 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, DataChunk &update_chunk = l_state.update_chunk; update_chunk.Reset(); - update_chunk.SetCardinality(chunk); for (idx_t i = 0; i < expressions.size(); i++) { // Default expression, set to the default value of the column. 
@@ -125,7 +124,7 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, D_ASSERT(expressions[i]->GetExpressionType() == ExpressionType::BOUND_REF); auto &binding = expressions[i]->Cast(); - update_chunk.data[i].Reference(chunk.data[binding.index]); + update_chunk.data[i].Reference(chunk.data[binding.Index()]); } auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; @@ -134,10 +133,12 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, // Regular in-place update. if (!update_is_del_and_insert) { if (return_chunk) { - mock_chunk.SetCardinality(update_chunk); + // (re)reference all output columns first, then validate + set the cardinality. mock_chunk is not reset + // here, but with return_chunk the update projects every table column, so all columns are referenced. for (idx_t i = 0; i < columns.size(); i++) { mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); } + mock_chunk.CheckCardinality(update_chunk.size()); } auto &update_state = l_state.GetUpdateState(table, tableref, context.client); table.Update(update_state, context.client, tableref, row_ids, columns, update_chunk); @@ -177,7 +178,6 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, auto &delete_chunk = index_update ? l_state.delete_chunk : l_state.mock_chunk; delete_chunk.Reset(); - delete_chunk.SetCardinality(update_count); if (index_update) { auto &transaction = DuckTransaction::Get(context.client, table.db); @@ -193,11 +193,12 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, auto &delete_state = l_state.GetDeleteState(table, tableref, context.client); table.Delete(delete_state, context.client, tableref, del_row_ids, update_count); - // Arrange the columns in the standard table order. - mock_chunk.SetCardinality(update_count); + // Arrange the columns in the standard table order, then validate + set the cardinality from the referenced + // columns. 
The del+insert path projects every table column (it re-inserts the full row), so all are referenced. for (idx_t i = 0; i < columns.size(); i++) { mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); } + mock_chunk.CheckCardinality(update_count); table.LocalAppend(tableref, context.client, mock_chunk, bound_constraints, del_row_ids, delete_chunk); if (return_chunk) { @@ -249,7 +250,6 @@ SourceResultType PhysicalUpdate::GetDataInternal(ExecutionContext &context, Data auto &state = input.global_state.Cast(); auto &g = sink_state->Cast(); if (!return_chunk) { - chunk.SetCardinality(1); chunk.data[0].Append(Value::BIGINT(NumericCast(g.updated_count.load()))); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp index e4cdcaf6f..52188c8a8 100644 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -23,12 +23,12 @@ PhysicalPivot::PhysicalPivot(PhysicalPlan &physical_plan, vector ty for (auto &aggr_expr : bound_pivot.aggregates) { auto &aggr = aggr_expr->Cast(); // for each aggregate, initialize an empty aggregate state and finalize it immediately - auto state = make_unsafe_uniq_array(aggr.function.GetStateSizeCallback()(aggr.function)); - aggr.function.GetStateInitCallback()(aggr.function, state.get()); + auto state = make_unsafe_uniq_array(aggr.Function().GetStateSizeCallback()(aggr.Function())); + aggr.Function().GetStateInitCallback()(aggr.Function(), state.get()); Vector state_vector(Value::POINTER(CastPointerToValue(state.get())), count_t(1)); Vector result_vector(aggr_expr->GetReturnType()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), physical_plan.ArenaRef()); - aggr.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, result_vector, 1, 0); + AggregateInputData aggr_input_data(aggr, physical_plan.ArenaRef()); 
+ aggr.Function().GetStateFinalizeCallback()(state_vector, aggr_input_data, result_vector, 1, 0); empty_aggregates.push_back(result_vector.GetValue(0)); } } @@ -79,7 +79,6 @@ OperatorResultType PhysicalPivot::Execute(ExecutionContext &context, DataChunk & } } } - chunk.SetCardinality(input.size()); return OperatorResultType::NEED_MORE_INPUT; } diff --git a/src/duckdb/src/execution/operator/projection/physical_projection.cpp b/src/duckdb/src/execution/operator/projection/physical_projection.cpp index b87f8cfad..72bd21aed 100644 --- a/src/duckdb/src/execution/operator/projection/physical_projection.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_projection.cpp @@ -17,6 +17,10 @@ class ProjectionState : public OperatorState { void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); } + + bool SupportsReuse() const override { + return true; + } }; PhysicalProjection::PhysicalProjection(PhysicalPlan &physical_plan, vector types, diff --git a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp index 32dc4818e..1e2564de5 100644 --- a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp @@ -112,7 +112,6 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context ConstantVector::Reference(state.input_chunk.data[col_idx], count_t(1), input.data[col_idx], state.row_index, input.size()); } - state.input_chunk.SetCardinality(1); state.row_index++; state.new_row = false; state.current_ordinality_idx = 1; @@ -153,7 +152,7 @@ InsertionOrderPreservingMap PhysicalTableInOutFunction::ParamsToString() result[it.first] = it.second; } } else { - result["Name"] = function.name; + result["Name"] = function.name.GetIdentifierName(); } SetEstimatedCardinality(result, 
estimated_cardinality); return result; diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp index 0ba97b0cb..9b9731fee 100644 --- a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp @@ -22,8 +22,8 @@ class UnnestOperatorState : public OperatorState { for (auto &exp : select_list) { D_ASSERT(exp->GetExpressionType() == ExpressionType::BOUND_UNNEST); auto &bue = exp->Cast(); - list_data_types.push_back(bue.child->GetReturnType()); - executor.AddExpression(*bue.child.get()); + list_data_types.push_back(bue.Child()->GetReturnType()); + executor.AddExpression(*bue.Child().get()); unnest_sels.emplace_back(STANDARD_VECTOR_SIZE); null_sels.emplace_back(STANDARD_VECTOR_SIZE); @@ -53,12 +53,12 @@ class UnnestOperatorState : public OperatorState { public: //! Reset the fields of the unnest operator state - void Reset(); + void ResetState(); //! 
Prepare the input for the next unnest void PrepareInput(DataChunk &input, const vector> &select_list); }; -void UnnestOperatorState::Reset() { +void UnnestOperatorState::ResetState() { current_row = 0; list_position = 0; first_fetch = true; @@ -88,16 +88,16 @@ void UnnestOperatorState::PrepareInput(DataChunk &input, const vector= input.size()) { - state.Reset(); + state.ResetState(); return OperatorResultType::NEED_MORE_INPUT; } @@ -217,7 +217,7 @@ OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, Da } } idx_t col_offset = 0; - chunk.SetCardinality(result_length); + chunk.SetChildCardinality(result_length); if (result_length == 0) { // nothing to unnest - skip column processing entirely continue; @@ -236,7 +236,7 @@ OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, Da col_offset = input.ColumnCount(); } for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { - auto &list_vector = state.list_data.data[col_idx]; + const auto &list_vector = state.list_data.data[col_idx]; auto &result_vector = chunk.data[col_offset + col_idx]; if (state.list_data.data[col_idx].GetType() == LogicalType::SQLNULL || ListType::GetChildType(state.list_data.data[col_idx].GetType()) == LogicalType::SQLNULL || @@ -246,7 +246,7 @@ OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, Da ConstantVector::SetNull(result_vector, count_t(result_length)); continue; } - auto &child_vector = ListVector::GetChild(list_vector); + const auto &child_vector = ListVector::GetChild(list_vector); result_vector.Slice(child_vector, state.unnest_sels[col_idx], result_length); if (state.null_counts[col_idx] > 0) { // we have NULL values that we need to set - flatten diff --git a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp index 05598dc0c..411a733d8 100644 --- a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp +++ 
b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp @@ -7,7 +7,6 @@ SourceResultType PhysicalDummyScan::GetDataInternal(ExecutionContext &context, D OperatorSourceInput &input) const { // return a single row on the first call to the dummy scan FlatVector::SetSize(chunk.data[0], 1); - chunk.SetCardinality(1); return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp index 676886709..199f9c1f7 100644 --- a/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp @@ -7,14 +7,30 @@ namespace duckdb { class ExpressionScanState : public OperatorState { public: - explicit ExpressionScanState(Allocator &allocator, const PhysicalExpressionScan &op) : expression_index(0) { + explicit ExpressionScanState(Allocator &allocator, const PhysicalExpressionScan &op) { temp_chunk.Initialize(allocator, op.GetTypes()); + ResetState(); } //! The current position in the scan idx_t expression_index; //! 
Temporary chunk for evaluating expressions DataChunk temp_chunk; + +private: + void ResetState() { + expression_index = 0; + temp_chunk.Reset(); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset() override { + ResetState(); + } }; unique_ptr PhysicalExpressionScan::GetOperatorState(ExecutionContext &context) const { diff --git a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp index 205b0c552..b152d097a 100644 --- a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp @@ -55,22 +55,24 @@ class PositionalScanGlobalSourceState : public GlobalSourceState { class PositionalTableScanner { public: PositionalTableScanner(ExecutionContext &context, PhysicalOperator &table_p, GlobalSourceState &gstate_p) - : table(table_p), global_state(gstate_p), source_offset(0), exhausted(false) { + : table(table_p), global_state(gstate_p), source_count(0), source_offset(0), exhausted(false) { local_state = table.GetLocalSourceState(context, gstate_p); source.Initialize(Allocator::Get(context.client), table.types); } idx_t Refill(ExecutionContext &context) { - if (source_offset >= source.size()) { + if (source_offset >= source_count) { if (!exhausted) { source.Reset(); + source_count = 0; InterruptState interrupt_state; OperatorSourceInput source_input {global_state, *local_state, interrupt_state}; auto source_result = SourceResultType::HAVE_MORE_OUTPUT; - while (source_result == SourceResultType::HAVE_MORE_OUTPUT && source.size() == 0) { + while (source_result == SourceResultType::HAVE_MORE_OUTPUT && source_count == 0) { // TODO: this could as well just be propagated further, but for now iterating it is source_result = table.GetData(context, source, source_input); + source_count = source.size(); if (source_result == SourceResultType::BLOCKED) { throw 
NotImplementedException( "Unexpected interrupt from table Source in PositionalTableScanner refill"); @@ -80,10 +82,11 @@ class PositionalTableScanner { source_offset = 0; } - const auto available = source.size() - source_offset; + const auto available = source_count - source_offset; if (!available) { if (!exhausted) { source.Reset(); + source_count = 0; for (idx_t i = 0; i < source.ColumnCount(); ++i) { auto &vec = source.data[i]; ConstantVector::SetNull(vec, count_t(STANDARD_VECTOR_SIZE)); @@ -96,7 +99,7 @@ class PositionalTableScanner { } idx_t CopyData(ExecutionContext &context, DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { + if (!source_offset && (source_count >= count || exhausted)) { // Fast track: aligned and has enough data for (idx_t i = 0; i < source.ColumnCount(); ++i) { output.data[col_offset + i].Reference(source.data[i]); @@ -106,15 +109,15 @@ class PositionalTableScanner { // Copy data for (idx_t target_offset = 0; target_offset < count;) { const auto needed = count - target_offset; - const auto available = exhausted ? needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; + const auto available = exhausted ? 
needed : (source_count - source_offset); + const auto copy_count = MinValue(needed, available); + const auto source_end = source_offset + copy_count; for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, + VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_end, source_offset, target_offset); } - target_offset += copy_size; - source_offset += copy_size; + target_offset += copy_count; + source_offset += copy_count; Refill(context); } } @@ -130,6 +133,7 @@ class PositionalTableScanner { GlobalSourceState &global_state; unique_ptr local_state; DataChunk source; + idx_t source_count; idx_t source_offset; bool exhausted; }; diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index c986196f0..cd1524cf3 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -59,7 +59,6 @@ class TableScanGlobalSourceState : public GlobalSourceState { for (idx_t c = 0; c < op.parameters.size(); c++) { input_chunk.data[c].Reference(op.parameters[c], count_t(1)); } - input_chunk.SetCardinality(1); } } @@ -187,12 +186,15 @@ SourceResultType PhysicalTableScan::GetDataInternal(ExecutionContext &context, D switch (output_async_result) { case AsyncResultType::BLOCKED: { D_ASSERT(data.async_result.HasTasks()); - annotated_lock_guard guard(g_state.lock); - if (g_state.CanBlock()) { - data.async_result.ScheduleTasks(input.interrupt_state, context.pipeline->executor); - return SourceResultType::BLOCKED; + { + annotated_lock_guard guard(g_state.lock); + if (g_state.CanBlock()) { + data.async_result.ScheduleTasks(input.interrupt_state, context.pipeline->executor); + return SourceResultType::BLOCKED; + } } - return SourceResultType::FINISHED; + data.async_result.ExecuteTasksSynchronously(); + return 
SourceResultType::HAVE_MORE_OUTPUT; } case AsyncResultType::IMPLICIT: if (chunk.size() > 0) { @@ -313,7 +315,7 @@ string PhysicalTableScan::GetFilterInfo(const TableFilterSet &filter_set) const bool first_item = true; for (auto &f : filter_set) { auto filter_idx = f.GetIndex(); - auto &filter = f.Filter(); + auto &filter = f.Filter().Cast(); if (filter_idx < names.size()) { if (!first_item) { filters_info += "\n"; @@ -327,7 +329,7 @@ string PhysicalTableScan::GetFilterInfo(const TableFilterSet &filter_set) const if (entry == virtual_columns.end()) { throw InternalException("Virtual column not found"); } - filters_info += filter.ToString(entry->second.name); + filters_info += filter.ToString(entry->second.name.GetIdentifierName()); } else { auto column_name = column_id.GetName(names[col_id]); filters_info += filter.ToString(column_name); @@ -346,7 +348,7 @@ InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { result[it.first] = it.second; } } else { - result["Function"] = StringUtil::Upper(function.name); + result["Function"] = StringUtil::Upper(function.name.GetIdentifierName()); } if (function.projection_pushdown) { string projections; @@ -415,25 +417,16 @@ TableFunctionParallelism PhysicalTableScan::SourceParallelism() const { return function.parallelism; } -InsertionOrderPreservingMap PhysicalTableScan::ExtraSourceParams(GlobalSourceState &gstate_p, - LocalSourceState &lstate) const { - if (!function.dynamic_to_string) { - return InsertionOrderPreservingMap(); +void PhysicalTableScan::GetMetrics(ClientContext &context, GlobalSourceState &gstate_p, LocalSourceState &lstate, + OperatorMetrics &operator_metrics) const { + if (!function.get_metrics) { + return; } auto &gstate = gstate_p.Cast(); auto &state = lstate.Cast(); - TableFunctionDynamicToStringInput input(function, bind_data.get(), state.local_state.get(), - gstate.global_state.get()); - return function.dynamic_to_string(input); -} - -optional_idx 
PhysicalTableScan::GetRowsScanned(GlobalSourceState &gstate_p, LocalSourceState &lstate) const { - if (function.rows_scanned) { - auto &gstate = gstate_p.Cast(); - auto &state = lstate.Cast(); - return function.rows_scanned(*gstate.global_state, *state.local_state); - } - return optional_idx(); + TableFunctionGetMetricsInput input(context, bind_data.get(), state.local_state.get(), gstate.global_state.get(), + operator_metrics); + function.get_metrics(input); } } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp index 638591741..c78e94091 100644 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -28,7 +28,7 @@ SourceResultType PhysicalAttach::GetDataInternal(ExecutionContext &context, Data } if (name.empty()) { auto &fs = FileSystem::GetFileSystem(context.client); - name = AttachedDatabase::ExtractDatabaseName(path, fs); + name = Identifier(AttachedDatabase::ExtractDatabaseName(path, fs)); } // check ATTACH IF NOT EXISTS diff --git a/src/duckdb/src/execution/operator/schema/physical_create_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_index.cpp index 9d318f903..89929be4b 100644 --- a/src/duckdb/src/execution/operator/schema/physical_create_index.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_create_index.cpp @@ -91,9 +91,8 @@ SinkResultType PhysicalCreateIndex::Sink(ExecutionContext &context, DataChunk &c // Check for NULLs, if we are creating a PRIMARY KEY. // FIXME: Later, we want to ensure that we skip the NULL check for any non-PK alter. 
if (alter_table_info) { - auto row_count = lstate.key_chunk.size(); for (idx_t i = 0; i < lstate.key_chunk.ColumnCount(); i++) { - if (VectorOperations::HasNull(lstate.key_chunk.data[i], row_count)) { + if (VectorOperations::HasNull(lstate.key_chunk.data[i])) { throw ConstraintException("NOT NULL constraint failed: %s", info->index_name); } } diff --git a/src/duckdb/src/execution/operator/schema/physical_detach.cpp b/src/duckdb/src/execution/operator/schema/physical_detach.cpp index 1a2ff0700..576f170e0 100644 --- a/src/duckdb/src/execution/operator/schema/physical_detach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_detach.cpp @@ -1,10 +1,12 @@ #include "duckdb/execution/operator/schema/physical_detach.hpp" #include "duckdb/parser/parsed_data/detach_info.hpp" #include "duckdb/catalog/catalog.hpp" +#include "duckdb/common/exception/transaction_exception.hpp" #include "duckdb/main/database_manager.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/storage_extension.hpp" +#include "duckdb/transaction/meta_transaction.hpp" namespace duckdb { @@ -14,6 +16,21 @@ namespace duckdb { SourceResultType PhysicalDetach::GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { auto &db_manager = DatabaseManager::Get(context.client); + + // Reject detach if the current transaction already has outstanding work on the database. 
+ { + auto attached_db = db_manager.GetDatabase(info->name); + if (attached_db) { + auto &meta_transaction = MetaTransaction::Get(context.client); + if (meta_transaction.TryGetTransaction(*attached_db)) { + throw TransactionException( + "Cannot detach database \"%s\" because the current transaction has outstanding " + "work on it - commit or rollback first", + info->name); + } + } + } + db_manager.DetachDatabase(context.client, info->name, info->if_not_found); return SourceResultType::FINISHED; diff --git a/src/duckdb/src/execution/operator/schema/physical_drop.cpp b/src/duckdb/src/execution/operator/schema/physical_drop.cpp index e6a8cbe99..dbed2efa7 100644 --- a/src/duckdb/src/execution/operator/schema/physical_drop.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_drop.cpp @@ -50,7 +50,7 @@ SourceResultType PhysicalDrop::GetDataInternal(ExecutionContext &context, DataCh auto &extra_info = info->extra_drop_info->Cast(); SecretManager::Get(context.client) .DropSecretByName(context.client, info->name, info->if_not_found, extra_info.persist_mode, - extra_info.secret_storage); + Identifier(extra_info.secret_storage)); break; } case CatalogType::TRIGGER_ENTRY: { diff --git a/src/duckdb/src/execution/operator/set/physical_cte.cpp b/src/duckdb/src/execution/operator/set/physical_cte.cpp index 7434b3141..6c4f7c8ed 100644 --- a/src/duckdb/src/execution/operator/set/physical_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_cte.cpp @@ -6,8 +6,9 @@ namespace duckdb { -PhysicalCTE::PhysicalCTE(PhysicalPlan &physical_plan, string ctename, TableIndex table_index, vector types, - PhysicalOperator &top, PhysicalOperator &bottom, idx_t estimated_cardinality) +PhysicalCTE::PhysicalCTE(PhysicalPlan &physical_plan, Identifier ctename, TableIndex table_index, + vector types, PhysicalOperator &top, PhysicalOperator &bottom, + idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::CTE, std::move(types), estimated_cardinality), 
table_index(table_index), ctename(std::move(ctename)) { children.push_back(top); @@ -22,12 +23,31 @@ PhysicalCTE::~PhysicalCTE() { //===--------------------------------------------------------------------===// class CTEGlobalState : public GlobalSinkState { public: - explicit CTEGlobalState(ClientContext &context, const PhysicalCTE &op) : working_table_ref(op.working_table.get()) { + explicit CTEGlobalState(ClientContext &context, const PhysicalCTE &op) + : op(op), working_table_ref(op.working_table.get()) { + ResetState(context); } + const PhysicalCTE &op; optional_ptr working_table_ref; mutex lhs_lock; +private: + void ResetState(ClientContext &context) { + op.working_table->Reset(); + working_table_ref = op.working_table.get(); + GlobalSinkState::Reset(context); + } + +public: + bool SupportsReuse() const override { + return true; + } + + void Reset(ClientContext &context) override { + ResetState(context); + } + void MergeIT(ColumnDataCollection &input) { lock_guard guard(lhs_lock); working_table_ref->Combine(input); @@ -46,12 +66,11 @@ class CTELocalState : public LocalSinkState { ColumnDataAppendState append_state; void Append(DataChunk &input) { - lhs_data.Append(input); + lhs_data.Append(append_state, input); } }; unique_ptr PhysicalCTE::GetGlobalSinkState(ClientContext &context) const { - working_table->Reset(); return make_uniq(context, *this); } @@ -118,10 +137,25 @@ vector> PhysicalCTE::GetSources() const { InsertionOrderPreservingMap PhysicalCTE::ParamsToString() const { InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; + result["CTE Name"] = ctename.GetIdentifierName(); result["Table Index"] = StringUtil::Format("%llu", table_index.index); SetEstimatedCardinality(result, estimated_cardinality); return result; } +ProgressData PhysicalCTE::GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, + const ProgressData source_progress) const { + auto &state = gstate.Cast(); + lock_guard guard(state.lhs_lock); + if 
(!state.working_table_ref) { + return ProgressData {0, 1, true}; + } + auto &working_table = *state.working_table_ref; + auto count = double(working_table.Count()); + ProgressData progress; + progress.done = count; + progress.total = count + source_progress.total; + return progress; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp index 1260393a1..1537721f6 100644 --- a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp @@ -1,25 +1,19 @@ -#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte_state.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/execution/executor.hpp" #include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/parallel/pipeline_executor.hpp" #include "duckdb/storage/buffer_manager.hpp" -#include +#include "duckdb/main/settings.hpp" namespace duckdb { -PhysicalRecursiveCTE::PhysicalRecursiveCTE(PhysicalPlan &physical_plan, string ctename, TableIndex table_index, +PhysicalRecursiveCTE::PhysicalRecursiveCTE(PhysicalPlan &physical_plan, Identifier ctename, TableIndex table_index, vector types, bool union_all, PhysicalOperator &top, PhysicalOperator &bottom, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality), - ctename(std::move(ctename)), table_index(table_index), union_all(union_all) { + ctename(std::move(ctename)), 
table_index(table_index), union_all(union_all), + shared_executor_pool(make_shared_ptr()) { children.push_back(top); children.push_back(bottom); } @@ -28,53 +22,194 @@ PhysicalRecursiveCTE::~PhysicalRecursiveCTE() { } //===--------------------------------------------------------------------===// -// Sink +// Sink State //===--------------------------------------------------------------------===// -class RecursiveCTEState : public GlobalSinkState { -public: - explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) - : executor(context), intermediate_table(context, op.using_key ? op.internal_types : op.GetTypes()), - new_groups(STANDARD_VECTOR_SIZE) { - vector aggr_input_types; - vector payload_aggregates_ptr; - for (idx_t i = 0; i < op.payload_aggregates.size(); i++) { - D_ASSERT(op.payload_aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &bound_aggr_expr = op.payload_aggregates[i]->Cast(); - for (auto &child_expr : bound_aggr_expr.children) { - executor.AddExpression(*child_expr); - aggr_input_types.push_back(child_expr->GetReturnType()); - } - payload_aggregates_ptr.push_back(&bound_aggr_expr); +RecursiveCTEState::RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) + : op(op), executor(context), allow_executor_reuse(Settings::Get(context)), + executor_pool(op.shared_executor_pool), + intermediate_table(context, op.using_key ? 
op.internal_types : op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE), + dummy_addresses(LogicalType::POINTER) { + vector aggr_input_types; + vector payload_aggregates; + for (idx_t i = 0; i < op.payload_aggregates.size(); i++) { + D_ASSERT(op.payload_aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); + auto &bound_aggr_expr = op.payload_aggregates[i]->Cast(); + for (auto &child_expr : bound_aggr_expr.GetChildren()) { + executor.AddExpression(*child_expr); + aggr_input_types.push_back(child_expr->GetReturnType()); } + payload_aggregates.emplace_back(bound_aggr_expr); + } + + payload_rows.Initialize(Allocator::Get(context), aggr_input_types); + + ht = make_uniq(context, BufferAllocator::Get(context), op.distinct_types, + op.payload_types, std::move(payload_aggregates)); + if (op.using_key) { + distinct_rows.Initialize(Allocator::DefaultAllocator(), op.distinct_types); + source_distinct_rows.Initialize(Allocator::DefaultAllocator(), op.distinct_types); + source_payload_rows.Initialize(Allocator::DefaultAllocator(), op.payload_types); + } + source_result.Initialize(Allocator::DefaultAllocator(), op.GetTypes()); + InitializeIntermediateAppend(); + op.working_table->InitializeAppend(working_append_state); + if (op.recurring_table) { + op.recurring_table->InitializeAppend(recurring_append_state); + } +} + +RecursiveCTEState::~RecursiveCTEState() { + ClearCachedExecutors(); +} + +void RecursiveCTEState::InitializeIntermediateAppend() { + intermediate_table.InitializeAppend(intermediate_append_state); +} + +void RecursiveCTEState::ResetRecurringTable() { + D_ASSERT(op.recurring_table); + op.recurring_table->Reset(); + op.recurring_table->InitializeAppend(recurring_append_state); +} + +ColumnDataCollection &RecursiveCTEState::CurrentOutputTable() { + if (op.using_key || !output_is_working) { + return intermediate_table; + } + D_ASSERT(op.working_table); + return *op.working_table; +} + +ColumnDataCollection &RecursiveCTEState::CurrentInputTable() { + 
if (op.using_key) { + D_ASSERT(op.working_table); + return *op.working_table; + } + if (output_is_working) { + return intermediate_table; + } + D_ASSERT(op.working_table); + return *op.working_table; +} + +const ColumnDataCollection &RecursiveCTEState::CurrentInputTable() const { + if (op.using_key) { + D_ASSERT(op.working_table); + return *op.working_table; + } + if (output_is_working) { + return intermediate_table; + } + D_ASSERT(op.working_table); + return *op.working_table; +} + +ColumnDataAppendState &RecursiveCTEState::CurrentOutputAppendState() { + if (op.using_key || !output_is_working) { + return intermediate_append_state; + } + return working_append_state; +} - payload_rows.Initialize(Allocator::Get(context), aggr_input_types); +void RecursiveCTEState::AdvanceIterationBuffers() { + if (!op.using_key) { + output_is_working = !output_is_working; + } +} - ht = make_uniq(context, BufferAllocator::Get(context), op.distinct_types, - op.payload_types, payload_aggregates_ptr); +void RecursiveCTEState::ResetCurrentOutputTableForReuse() { + auto &output = CurrentOutputTable(); + output.ResetForReuse(); + if (op.using_key || !output_is_working) { + InitializeIntermediateAppend(); + } else { + D_ASSERT(op.working_table); + op.working_table->InitializeAppend(working_append_state); } +} - unique_ptr ht; - ExpressionExecutor executor; - DataChunk payload_rows; +void RecursiveCTEState::RebindRecursiveScans() { + if (op.using_key) { + return; + } + auto &input_table = CurrentInputTable(); + for (auto &scan_ref : op.recursive_scans) { + auto &scan = scan_ref.get(); + scan.collection = input_table; + } +} - mutex intermediate_table_lock; - ColumnDataCollection intermediate_table; - ColumnDataScanState scan_state; - bool initialized = false; - bool finished_scan = false; - SelectionVector new_groups; - AggregateHTScanState ht_scan_state; -}; +vector> &RecursiveCTEState::GetCachedExecutors(Pipeline &pipeline, idx_t max_threads) { + // Ordinary pipelines build 
PipelineExecutors once and discard them after the query finishes. + // Recursive CTEs re-enter the same pipelines many times, and correlated recursive invocations can + // construct several RecursiveCTEState objects for the same physical plan. Keep a state-local cache + // for cheap per-iteration reuse, and spill back into a shared pool so later states can recycle the + // already-initialized executors instead of reconstructing them from scratch. + lock_guard guard(cached_executor_lock); + auto entry = cached_executors.find(pipeline); + if (entry == cached_executors.end()) { + entry = cached_executors.emplace(reference(pipeline), vector>()).first; + } + auto &executors = entry->second; + if (!allow_executor_reuse) { + while (executors.size() < max_threads) { + executors.push_back(make_uniq(pipeline.GetClientContext(), pipeline)); + } + return executors; + } + D_ASSERT(executor_pool); + // Lock order: cached_executor_lock -> executor_pool->lock. + lock_guard pool_guard(executor_pool->lock); + auto pool_entry = executor_pool->executors.find(pipeline); + if (pool_entry == executor_pool->executors.end()) { + pool_entry = + executor_pool->executors.emplace(reference(pipeline), vector>()) + .first; + } + auto &shared_executors = pool_entry->second; + while (executors.size() < max_threads) { + if (!shared_executors.empty()) { + executors.push_back(std::move(shared_executors.back())); + shared_executors.pop_back(); + } else { + executors.push_back(make_uniq(pipeline.GetClientContext(), pipeline)); + } + } + return executors; +} + +void RecursiveCTEState::ClearCachedExecutors() { + lock_guard guard(cached_executor_lock); + if (cached_executors.empty()) { + return; + } + if (!allow_executor_reuse) { + cached_executors.clear(); + return; + } + D_ASSERT(executor_pool); + // Lock order: cached_executor_lock -> executor_pool->lock. 
+ lock_guard pool_guard(executor_pool->lock); + for (auto &entry : cached_executors) { + auto pool_entry = executor_pool->executors.find(entry.first.get()); + if (pool_entry == executor_pool->executors.end()) { + pool_entry = executor_pool->executors.emplace(entry.first, vector>()).first; + } + auto &shared_executors = pool_entry->second; + for (auto &executor : entry.second) { + shared_executors.push_back(std::move(executor)); + } + } + cached_executors.clear(); +} unique_ptr PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const { return make_uniq(context, *this); } idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const { - Vector dummy_addresses(LogicalType::POINTER); - // Use the HT to eliminate duplicate rows - idx_t new_group_count = state.ht->FindOrCreateGroups(chunk, dummy_addresses, state.new_groups); + idx_t new_group_count = state.ht->FindOrCreateGroups(chunk, state.dummy_addresses, state.new_groups); // we only return entries we have not seen before (i.e. 
new groups) chunk.Slice(state.new_groups, new_group_count); @@ -82,20 +217,18 @@ idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) return new_group_count; } -static void PopulateChunk(DataChunk &group_chunk, DataChunk &input_chunk, const vector &idx_set, - bool reference) { +static void GatherChunk(DataChunk &output_chunk, DataChunk &input_chunk, const vector &idx_set) { idx_t chunk_index = 0; - // Populate the group_chunk for (auto &group_idx : idx_set) { - if (reference) { - // Reference from input_chunk[chunk_index] -> group_chunk[group_idx] - group_chunk.data[chunk_index++].Reference(input_chunk.data[group_idx]); - } else { - // Reference from input_chunk[group.index] -> group_chunk[chunk_index] - group_chunk.data[group_idx].Reference(input_chunk.data[chunk_index++]); - } + output_chunk.data[chunk_index++].Reference(input_chunk.data[group_idx]); + } +} + +static void ScatterChunk(DataChunk &output_chunk, DataChunk &input_chunk, const vector &idx_set) { + idx_t chunk_index = 0; + for (auto &group_idx : idx_set) { + output_chunk.data[group_idx].Reference(input_chunk.data[chunk_index++]); } - group_chunk.SetCardinality(input_chunk.size()); } SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { @@ -103,22 +236,23 @@ SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk & lock_guard guard(gstate.intermediate_table_lock); if (!using_key) { + auto &output = gstate.CurrentOutputTable(); + auto &append_state = gstate.CurrentOutputAppendState(); if (!union_all) { idx_t match_count = ProbeHT(chunk, gstate); if (match_count > 0) { - gstate.intermediate_table.Append(chunk); + output.Append(append_state, chunk); } } else { - gstate.intermediate_table.Append(chunk); + output.Append(append_state, chunk); } } else { - // Split incoming DataChunk into payload and keys - DataChunk distinct_rows; - distinct_rows.Initialize(Allocator::DefaultAllocator(), 
distinct_types); - PopulateChunk(distinct_rows, chunk, distinct_idx, true); + // Split incoming DataChunk into payload and keys using the cached distinct_rows chunk + gstate.distinct_rows.Reset(); + GatherChunk(gstate.distinct_rows, chunk, distinct_idx); // Add result of recursive anchor to intermediate table - gstate.intermediate_table.Append(chunk); + gstate.intermediate_table.Append(gstate.intermediate_append_state, chunk); // Execute aggregate expressions on chunk if any if (!gstate.executor.expressions.empty()) { @@ -127,7 +261,7 @@ SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk & } // Add the result of the executed expressions to the hash table - gstate.ht->AddChunk(distinct_rows, gstate.payload_rows, AggregateType::NON_DISTINCT); + gstate.ht->AddChunk(gstate.distinct_rows, gstate.payload_rows, AggregateType::NON_DISTINCT); } return SinkResultType::NEED_MORE_INPUT; @@ -141,7 +275,7 @@ SourceResultType PhysicalRecursiveCTE::GetDataInternal(ExecutionContext &context auto &gstate = sink_state->Cast(); if (!gstate.initialized) { if (!using_key) { - gstate.intermediate_table.InitializeScan(gstate.scan_state); + gstate.CurrentOutputTable().InitializeScan(gstate.scan_state); } else { gstate.ht->InitializeScan(gstate.ht_scan_state); recurring_table->InitializeScan(gstate.scan_state); @@ -153,7 +287,7 @@ SourceResultType PhysicalRecursiveCTE::GetDataInternal(ExecutionContext &context if (!gstate.finished_scan) { if (!using_key) { // scan any chunks we have collected so far - gstate.intermediate_table.Scan(gstate.scan_state, chunk); + gstate.CurrentOutputTable().Scan(gstate.scan_state, chunk); } if (chunk.size() == 0) { gstate.finished_scan = true; @@ -164,80 +298,85 @@ SourceResultType PhysicalRecursiveCTE::GetDataInternal(ExecutionContext &context // we have run out of chunks // now we need to recurse // we set up the working table as the data we gathered in this iteration of the recursion + auto ¤t_output = 
gstate.CurrentOutputTable(); // After an iteration, we reset the recurring table // and fill it up with the new hash table rows for the next iteration. - if (using_key && ref_recurring && gstate.intermediate_table.Count() != 0) { - recurring_table->Reset(); + if (using_key && ref_recurring && current_output.Count() != 0) { + gstate.ResetRecurringTable(); AggregateHTScanState scan_state; gstate.ht->InitializeScan(scan_state); - - // Initialise the DataChunks to read the resulting rows. - // One DataChunk for the payload, one for the keys. - // Create a new DataChunk to store the result. - DataChunk result; - DataChunk payload_rows; - DataChunk distinct_rows; - distinct_rows.Initialize(Allocator::DefaultAllocator(), distinct_types); - if (!payload_types.empty()) { - payload_rows.Initialize(Allocator::DefaultAllocator(), payload_types); - } - result.Initialize(Allocator::DefaultAllocator(), chunk.GetTypes()); + auto &result = gstate.source_result; + auto &payload_rows = gstate.source_payload_rows; + auto &distinct_rows = gstate.source_distinct_rows; while (gstate.ht->Scan(scan_state, distinct_rows, payload_rows)) { + result.Reset(); // Populate the result DataChunk with the keys and the payload. - PopulateChunk(result, distinct_rows, distinct_idx, false); - PopulateChunk(result, payload_rows, payload_idx, false); + ScatterChunk(result, distinct_rows, distinct_idx); + ScatterChunk(result, payload_rows, payload_idx); // Append the result to the recurring table. 
- recurring_table->Append(result); + recurring_table->Append(gstate.recurring_append_state, result); } - } else if (ref_recurring && gstate.intermediate_table.Count() != 0) { + } else if (ref_recurring && current_output.Count() != 0) { // we need to populate the recurring table from the intermediate table // careful: we can not just use Combine here, because this destroys the intermediate table // instead we need to scan and append to create a copy // Note: as we are in the "normal" recursion case here, not the USING KEY case, // we can just scan the intermediate table directly, instead of going through the HT ColumnDataScanState scan_state; - gstate.intermediate_table.InitializeScan(scan_state); - DataChunk result; - result.Initialize(Allocator::DefaultAllocator(), chunk.GetTypes()); - - while (gstate.intermediate_table.Scan(scan_state, result)) { - recurring_table->Append(result); + current_output.InitializeScan(scan_state); + while (current_output.Scan(scan_state, gstate.source_result)) { + recurring_table->Append(gstate.recurring_append_state, gstate.source_result); } } - working_table->Reset(); - working_table->Combine(gstate.intermediate_table); - // and we clear the intermediate table gstate.finished_scan = false; - gstate.intermediate_table.Reset(); + if (!using_key) { + gstate.AdvanceIterationBuffers(); + gstate.ResetCurrentOutputTableForReuse(); + gstate.RebindRecursiveScans(); + } else { + working_table->Reset(); + working_table->Combine(gstate.intermediate_table); + gstate.InitializeIntermediateAppend(); + } + + // Pre-grow the dedup HT to avoid costly Resize + ReinsertTuples during the next Sink phase. + // current_output.Count() is the count of rows output in the previous iteration — an upper bound + // on the number of new unique rows the next iteration can add (since the recursion is converging). 
+ if (!union_all) { + const idx_t expected_new = current_output.Count(); + if (expected_new > 0) { + const idx_t desired_capacity = + GroupedAggregateHashTable::GetCapacityForCount(gstate.ht->Count() + expected_new); + if (desired_capacity > gstate.ht->Capacity()) { + gstate.ht->Resize(desired_capacity); + } + } + } + // now we need to re-execute all of the pipelines that depend on the recursion ExecuteRecursivePipelines(context); // check if we obtained any results // if not, we are done - if (gstate.intermediate_table.Count() == 0) { + if (gstate.CurrentOutputTable().Count() == 0) { gstate.finished_scan = true; if (using_key) { - // Initialise the DataChunks to read the ht. - // One DataChunk for payload, one for keys. - DataChunk payload_rows; - DataChunk distinct_rows; - distinct_rows.Initialize(Allocator::DefaultAllocator(), distinct_types); - if (!payload_types.empty()) { - payload_rows.Initialize(Allocator::DefaultAllocator(), payload_types); - } - + auto &payload_rows = gstate.source_payload_rows; + auto &distinct_rows = gstate.source_distinct_rows; + distinct_rows.Reset(); + payload_rows.Reset(); gstate.ht->Scan(gstate.ht_scan_state, distinct_rows, payload_rows); - PopulateChunk(chunk, distinct_rows, distinct_idx, false); - PopulateChunk(chunk, payload_rows, payload_idx, false); + ScatterChunk(chunk, distinct_rows, distinct_idx); + ScatterChunk(chunk, payload_rows, payload_idx); } break; } if (!using_key) { // set up the scan again - gstate.intermediate_table.InitializeScan(gstate.scan_state); + gstate.CurrentOutputTable().InitializeScan(gstate.scan_state); } } } @@ -245,108 +384,13 @@ SourceResultType PhysicalRecursiveCTE::GetDataInternal(ExecutionContext &context return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; } -void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { - if (!recursive_meta_pipeline) { - throw InternalException("Missing meta pipeline for recursive CTE"); - } - D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); - - // get and reset pipelines - vector> pipelines; - recursive_meta_pipeline->GetPipelines(pipelines, true); - for (auto &pipeline : pipelines) { - auto sink = pipeline->GetSink(); - if (sink.get() != this) { - sink->sink_state.reset(); - } - for (auto &op_ref : pipeline->GetOperators()) { - auto &op = op_ref.get(); - op.op_state.reset(); - } - pipeline->ClearSource(); - } - - // get the MetaPipelines in the recursive_meta_pipeline and reschedule them - vector> meta_pipelines; - recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); - auto &executor = recursive_meta_pipeline->GetExecutor(); - vector> events; - executor.ReschedulePipelines(meta_pipelines, events); - - while (true) { - executor.WorkOnTasks(); - if (executor.HasError()) { - executor.ThrowException(); - } - bool finished = true; - for (auto &event : events) { - if (!event->IsFinished()) { - finished = false; - break; - } - } - if (finished) { - // all pipelines finished: done! 
- break; - } - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// - -static void GatherColumnDataScans(const PhysicalOperator &op, vector> &delim_scans) { - if (op.type == PhysicalOperatorType::DELIM_SCAN || op.type == PhysicalOperatorType::CTE_SCAN) { - delim_scans.push_back(op); - } - for (auto &child : op.children) { - GatherColumnDataScans(child, delim_scans); - } -} - -void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - recursive_meta_pipeline.reset(); - - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - auto &executor = meta_pipeline.GetExecutor(); - executor.AddRecursiveCTE(*this); - - // the LHS of the recursive CTE is our initial state - auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - initial_state_pipeline.Build(children[0]); - - // the RHS is the recursive pipeline - recursive_meta_pipeline = make_shared_ptr(executor, state, this); - recursive_meta_pipeline->SetRecursiveCTE(); - recursive_meta_pipeline->Build(children[1]); - - vector> ops; - GatherColumnDataScans(children[1], ops); - - for (auto op : ops) { - auto entry = state.cte_dependencies.find(op); - if (entry == state.cte_dependencies.end()) { - continue; - } - // this chunk scan introduces a dependency to the current pipeline - // namely a dependency on the CTE pipeline to finish - auto cte_dependency = entry->second.get().shared_from_this(); - current.AddDependency(cte_dependency); - } -} - vector> PhysicalRecursiveCTE::GetSources() const { return {*this}; } InsertionOrderPreservingMap PhysicalRecursiveCTE::ParamsToString() const { InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; + result["CTE Name"] = ctename.GetIdentifierName(); result["Table Index"] = StringUtil::Format("%llu", 
table_index.index); SetEstimatedCardinality(result, estimated_cardinality); return result; diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte_runtime.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte_runtime.cpp new file mode 100644 index 000000000..d2e299b4b --- /dev/null +++ b/src/duckdb/src/execution/operator/set/physical_recursive_cte_runtime.cpp @@ -0,0 +1,1144 @@ +#include "duckdb/execution/operator/set/physical_recursive_cte_state.hpp" + +#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" +#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" +#include "duckdb/execution/operator/join/physical_delim_join.hpp" +#include "duckdb/execution/operator/join/physical_hash_join.hpp" +#include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" +#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" +#include "duckdb/parallel/pipeline.hpp" +#include "duckdb/parallel/pipeline_complete_event.hpp" +#include "duckdb/parallel/pipeline_executor.hpp" +#include "duckdb/parallel/pipeline_finish_event.hpp" +#include "duckdb/parallel/pipeline_prepare_finish_event.hpp" +#include "duckdb/parallel/task_scheduler.hpp" + +#include "duckdb/main/settings.hpp" + +namespace duckdb { + +class RecursiveCTEPipelineEvent; + +struct RecursiveCTEEventStack { + RecursiveCTEEventStack(shared_ptr pipeline_event_p, + shared_ptr pipeline_prepare_finish_event_p, shared_ptr pipeline_finish_event_p, + shared_ptr pipeline_complete_event_p) + : pipeline_event(std::move(pipeline_event_p)), + pipeline_prepare_finish_event(std::move(pipeline_prepare_finish_event_p)), + pipeline_finish_event(std::move(pipeline_finish_event_p)), + pipeline_complete_event(std::move(pipeline_complete_event_p)) { + } + + 
shared_ptr pipeline_event; + shared_ptr pipeline_prepare_finish_event; + shared_ptr pipeline_finish_event; + shared_ptr pipeline_complete_event; +}; + +using recursive_cte_event_map_t = reference_map_t; + +struct RecursiveCTEInlineStageStack { + RecursiveCTEInlineStageStack(idx_t execute_stage_p, idx_t prepare_finish_stage_p, idx_t finish_stage_p) + : execute_stage(execute_stage_p), prepare_finish_stage(prepare_finish_stage_p), finish_stage(finish_stage_p) { + } + + idx_t execute_stage; + idx_t prepare_finish_stage; + idx_t finish_stage; +}; + +using recursive_cte_inline_stage_map_t = reference_map_t; + +enum class RecursiveCTEMetaPipelineEntryType : uint8_t { + BASE, + SHARED_FINISH_GROUP, + HAS_FINISH_EVENT, + SHARED_BASE_FINISH +}; + +struct RecursiveCTEMetaPipelinePlanEntry { + RecursiveCTEMetaPipelinePlanEntry(Pipeline &pipeline_p, RecursiveCTEMetaPipelineEntryType type_p, + optional_ptr finish_group_p = nullptr) + : pipeline(pipeline_p), type(type_p), finish_group(finish_group_p) { + } + + reference pipeline; + RecursiveCTEMetaPipelineEntryType type; + optional_ptr finish_group; +}; + +struct RecursiveCTEMetaPipelinePlan { + vector entries; + vector> initialize_on_schedule_pipelines; +}; + +//===--------------------------------------------------------------------===// +// Recursive CTE Task and Event for optimized execution +//===--------------------------------------------------------------------===// +// +// The normal pipeline scheduler is optimized for one-shot query execution: build the query-wide +// event graph, allocate PipelineExecutors/tasks, run the pipelines once, then tear that runtime +// state down. Recursive CTEs repeatedly re-enter the recursive member, often with tiny frontiers, +// so paying that setup/teardown cost every iteration quickly dominates the actual data processing. 
+// +// The helpers below keep the same pipeline/finalize semantics, but split out a recursive-specific +// runtime that can reset and reuse executors, selected operator state, and the recursive dependency +// topology across iterations. + +// Recursive iterations often produce small batches. Aim for about half a vector of input per worker +// before increasing parallelism so scheduler overhead stays low on narrow recursive workloads. +static constexpr const idx_t RECURSIVE_ROWS_PER_THREAD = STANDARD_VECTOR_SIZE / 2; + +static idx_t GetRecursiveThreadLimit(const RecursiveCTEState &state) { + idx_t recursive_rows = 0; + if (state.op.working_table && state.op.recursive_reference_count > 0) { + recursive_rows += state.CurrentInputTable().Count() * state.op.recursive_reference_count; + } + if (state.op.recurring_table && state.op.recurring_reference_count > 0) { + recursive_rows += state.op.recurring_table->Count() * state.op.recurring_reference_count; + } + return MaxValue((recursive_rows + RECURSIVE_ROWS_PER_THREAD - 1) / RECURSIVE_ROWS_PER_THREAD, 1); +} + +static idx_t GetRecursivePipelineMaxThreads(RecursiveCTEState &state, Pipeline &pipeline) { + auto max_threads = pipeline.GetMaxThreads(); + if (max_threads < 1) { + max_threads = 1; + } + return MinValue(max_threads, GetRecursiveThreadLimit(state)); +} + +static void ExecuteRecursivePipelineInline(PipelineExecutor &pipeline_executor) { + auto signal_state = make_shared_ptr(); + pipeline_executor.SetInterruptState(InterruptState(weak_ptr(signal_state))); + while (true) { + auto result = pipeline_executor.Execute(); + switch (result) { + case PipelineExecuteResult::FINISHED: + return; + case PipelineExecuteResult::NOT_FINISHED: + throw InternalException("Execute without limit should not return NOT_FINISHED"); + case PipelineExecuteResult::INTERRUPTED: + signal_state->Await(); + break; + } + } +} + +//! 
A task that executes a cached PipelineExecutor +class RecursiveCTETask : public ExecutorTask { +public: + RecursiveCTETask(Pipeline &pipeline_p, shared_ptr event_p, PipelineExecutor &executor_p) + : ExecutorTask(pipeline_p.executor, std::move(event_p)), pipeline(pipeline_p), pipeline_executor(executor_p) { + } + + Pipeline &pipeline; + PipelineExecutor &pipeline_executor; + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { + pipeline_executor.SetTaskForInterrupts(shared_from_this()); + + if (mode == TaskExecutionMode::PROCESS_PARTIAL) { + auto res = pipeline_executor.Execute(PARTIAL_CHUNK_COUNT); + switch (res) { + case PipelineExecuteResult::NOT_FINISHED: + return TaskExecutionResult::TASK_NOT_FINISHED; + case PipelineExecuteResult::INTERRUPTED: + return TaskExecutionResult::TASK_BLOCKED; + case PipelineExecuteResult::FINISHED: + break; + } + } else { + auto res = pipeline_executor.Execute(); + switch (res) { + case PipelineExecuteResult::NOT_FINISHED: + throw InternalException("Execute without limit should not return NOT_FINISHED"); + case PipelineExecuteResult::INTERRUPTED: + return TaskExecutionResult::TASK_BLOCKED; + case PipelineExecuteResult::FINISHED: + break; + } + } + + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; + } + + bool TaskBlockedOnResult() const override { + return pipeline_executor.RemainingSinkChunk(); + } + +private: + // Keep partial execution reasonably coarse so blocked pipelines make progress without creating tiny tasks. + static constexpr const idx_t PARTIAL_CHUNK_COUNT = 50; +}; + +//! Recursive execute event. +//! We still use BasePipelineEvent for dependency tracking and task bookkeeping, but the stock +//! pipeline scheduling path is too expensive here: it assumes a one-shot execution, creates fresh +//! PipelineExecutors/tasks every time, and resets shared pipeline state immediately before launch. +//! 
Recursive CTEs need the same dependency semantics while reusing cached PipelineExecutors, and +//! root events must sometimes be reset up-front on the main thread to avoid reset-vs-execute races. +class RecursiveCTEPipelineEvent : public BasePipelineEvent { +public: + RecursiveCTEPipelineEvent(shared_ptr pipeline_p, RecursiveCTEState &state_p) + : BasePipelineEvent(std::move(pipeline_p)), state(state_p) { + } + + RecursiveCTEState &state; + bool prepared_for_schedule = false; + + void PrepareForSchedule() { + // Root recursive pipeline events can be scheduled back-to-back while sharing operator instances. + // Prepare their global pipeline state up-front on the main thread so later task execution does + // not race with another root event resetting the same operator state. + pipeline->ResetForReschedule(false); + prepared_for_schedule = true; + } + + void Schedule() override { + // Sink state is prepared up-front from the main thread. Reinitialize the remaining + // global state here, reusing existing state objects when operators expose reset hooks. + // Dependency-free pipeline events can be prepared up-front on the main thread to avoid + // racing with another root event that shares operator instances. + if (!prepared_for_schedule) { + pipeline->ResetForReschedule(false); + } + + auto max_threads = GetRecursivePipelineMaxThreads(state, *pipeline); + + // Get or create cached executors for this pipeline + auto &executors = state.GetCachedExecutors(*pipeline, max_threads); + + // Create tasks using cached executors + vector> tasks; + for (idx_t i = 0; i < max_threads; i++) { + executors[i]->PrepareForExecution(); + tasks.push_back(make_uniq(*pipeline, shared_from_this(), *executors[i])); + } + SetTasks(std::move(tasks)); + } +}; + +//! Inline finish event for the single-thread recursive fast path. +//! PipelineFinishEvent is correct for the general scheduler, but it is intentionally scheduler-centric: +//! 
it expects finish work to go back through the task/event machinery. In the single-thread recursive +//! path that would mostly bounce control back into the scheduler just to run the finish work on the +//! current thread. This event preserves the same finalize contract, including BLOCKED handling, while +//! keeping the finish step inline with the cached execute/prepare stages. +class RecursiveCTEFinishEvent : public BasePipelineEvent { +public: + explicit RecursiveCTEFinishEvent(shared_ptr pipeline_p) : BasePipelineEvent(std::move(pipeline_p)) { + } + + void Schedule() override { + // This event is only used by the single-thread inline recursive fast path. + // Blocking here is safe because the caller runs it synchronously and explicitly waits for completion. + D_ASSERT(total_tasks == 0); + total_tasks = 1; + + auto signal_state = make_shared_ptr(); + InterruptState interrupt_state {weak_ptr(signal_state)}; + auto sink = pipeline->GetSink(); + auto &operators = pipeline->GetIntermediateOperators(); + idx_t operator_idx = 0; + + while (true) { + bool blocked = false; + for (; operator_idx < operators.size(); operator_idx++) { + auto &op = operators[operator_idx].get(); + if (!op.RequiresOperatorFinalize()) { + continue; + } + OperatorFinalizeInput op_finalize_input {*op.op_state, interrupt_state}; + auto op_state = op.OperatorFinalize(*pipeline, *this, executor.context, op_finalize_input); + if (op_state == OperatorFinalResultType::BLOCKED) { + signal_state->Await(); + blocked = true; + break; + } + } + if (blocked) { + continue; + } + + OperatorSinkFinalizeInput finalize_input {*sink->sink_state, interrupt_state}; + auto sink_state = sink->Finalize(*pipeline, *this, executor.context, finalize_input); + if (sink_state == SinkFinalizeType::BLOCKED) { + signal_state->Await(); + continue; + } + + sink->sink_state->state = sink_state; + FinishTask(); + return; + } + } +}; + +static idx_t AddRecursiveInlineStage(RecursiveCTEInlinePlan &plan, RecursiveCTEInlineStageType 
type, + Pipeline &pipeline) { + plan.stages.emplace_back(type, pipeline); + return plan.stages.size() - 1; +} + +static void AddRecursiveInlineDependency(RecursiveCTEInlinePlan &plan, idx_t dependent_stage, idx_t dependency_stage) { + auto &dependent = plan.stages[dependent_stage]; + dependent.dependency_count++; + plan.stages[dependency_stage].dependents.push_back(dependent_stage); +} + +static RecursiveCTEInlineStageStack AddRecursiveInlineStageStack(RecursiveCTEInlinePlan &plan, Pipeline &pipeline) { + auto execute = AddRecursiveInlineStage(plan, RecursiveCTEInlineStageType::EXECUTE, pipeline); + auto prepare_finish = AddRecursiveInlineStage(plan, RecursiveCTEInlineStageType::PREPARE_FINISH, pipeline); + auto finish = AddRecursiveInlineStage(plan, RecursiveCTEInlineStageType::FINISH, pipeline); + AddRecursiveInlineDependency(plan, prepare_finish, execute); + AddRecursiveInlineDependency(plan, finish, prepare_finish); + return RecursiveCTEInlineStageStack(execute, prepare_finish, finish); +} + +static RecursiveCTEEventStack CreateRecursiveEventStack(const shared_ptr &pipeline, + RecursiveCTEState &state) { + auto execute = make_shared_ptr(pipeline, state); + auto prepare_finish = make_shared_ptr(pipeline); + auto finish = make_shared_ptr(pipeline); + prepare_finish->AddDependency(*execute); + finish->AddDependency(*prepare_finish); + return RecursiveCTEEventStack(execute, prepare_finish, finish, nullptr); +} + +static bool RequiresInitializeOnSchedule(Pipeline &pipeline) { + auto source = pipeline.GetSource(); + if (!source || source->type != PhysicalOperatorType::TABLE_SCAN) { + return false; + } + auto &table_function = source->Cast(); + return table_function.function.global_initialization == TableFunctionInitialization::INITIALIZE_ON_SCHEDULE; +} + +static RecursiveCTEMetaPipelinePlan BuildRecursiveMetaPipelinePlan(MetaPipeline &meta_pipeline) { + // MetaPipeline already knows how to classify pipelines for the normal scheduler, but recursive + // execution 
needs that classification in two different forms: + // 1. a cached scheduler-free stage plan for the single-thread inline fast path + // 2. a per-iteration Event graph for the multi-threaded path + // Materializing a tiny neutral plan keeps both paths in sync without forcing the inline path to + // instantiate Event objects it will never schedule. + RecursiveCTEMetaPipelinePlan result; + + vector> pipelines; + meta_pipeline.GetPipelines(pipelines, false); + result.entries.reserve(pipelines.size()); + result.entries.emplace_back(*meta_pipeline.GetBasePipeline(), RecursiveCTEMetaPipelineEntryType::BASE); + + for (idx_t i = 1; i < pipelines.size(); i++) { + auto &pipeline = pipelines[i]; + auto finish_group = meta_pipeline.GetFinishGroup(*pipeline); + if (finish_group) { + result.entries.emplace_back(*pipeline, RecursiveCTEMetaPipelineEntryType::SHARED_FINISH_GROUP, + finish_group.get()); + continue; + } + if (meta_pipeline.HasFinishEvent(*pipeline)) { + result.entries.emplace_back(*pipeline, RecursiveCTEMetaPipelineEntryType::HAS_FINISH_EVENT); + continue; + } + result.entries.emplace_back(*pipeline, RecursiveCTEMetaPipelineEntryType::SHARED_BASE_FINISH); + } + + for (auto &pipeline : pipelines) { + if (RequiresInitializeOnSchedule(*pipeline)) { + result.initialize_on_schedule_pipelines.push_back(*pipeline); + } + } + return result; +} + +static unique_ptr +BuildRecursiveInlinePlan(const vector> &meta_pipelines) { + // This is the scheduler-free equivalent of the recursive event graph. It is only used when the + // current iteration is capped to one thread; in that case, preserving the same execute -> + // prepare-finish -> finish ordering matters, but paying full scheduler/task overhead does not. 
+ auto plan = make_uniq(); + recursive_cte_inline_stage_map_t stage_map; + for (auto &meta_pipeline : meta_pipelines) { + auto meta_pipeline_plan = BuildRecursiveMetaPipelinePlan(*meta_pipeline); + RecursiveCTEInlineStageStack base_stack(DConstants::INVALID_INDEX, DConstants::INVALID_INDEX, + DConstants::INVALID_INDEX); + for (auto &entry : meta_pipeline_plan.entries) { + auto &pipeline = entry.pipeline.get(); + if (entry.type != RecursiveCTEMetaPipelineEntryType::BASE) { + D_ASSERT(base_stack.execute_stage != DConstants::INVALID_INDEX); + } + switch (entry.type) { + case RecursiveCTEMetaPipelineEntryType::BASE: { + base_stack = AddRecursiveInlineStageStack(*plan, pipeline); + stage_map.emplace(reference(pipeline), base_stack); + break; + } + case RecursiveCTEMetaPipelineEntryType::SHARED_FINISH_GROUP: { + D_ASSERT(entry.finish_group); + auto group_entry = stage_map.find(*entry.finish_group); + D_ASSERT(group_entry != stage_map.end()); + auto execute = AddRecursiveInlineStage(*plan, RecursiveCTEInlineStageType::EXECUTE, pipeline); + AddRecursiveInlineDependency(*plan, execute, base_stack.finish_stage); + AddRecursiveInlineDependency(*plan, group_entry->second.prepare_finish_stage, execute); + stage_map.emplace(reference(pipeline), + RecursiveCTEInlineStageStack(execute, group_entry->second.prepare_finish_stage, + group_entry->second.finish_stage)); + break; + } + case RecursiveCTEMetaPipelineEntryType::HAS_FINISH_EVENT: { + auto pipeline_stack = AddRecursiveInlineStageStack(*plan, pipeline); + AddRecursiveInlineDependency(*plan, pipeline_stack.execute_stage, base_stack.finish_stage); + stage_map.emplace(reference(pipeline), pipeline_stack); + break; + } + case RecursiveCTEMetaPipelineEntryType::SHARED_BASE_FINISH: { + auto execute = AddRecursiveInlineStage(*plan, RecursiveCTEInlineStageType::EXECUTE, pipeline); + AddRecursiveInlineDependency(*plan, base_stack.prepare_finish_stage, execute); + stage_map.emplace( + reference(pipeline), + 
RecursiveCTEInlineStageStack(execute, base_stack.prepare_finish_stage, base_stack.finish_stage)); + break; + } + default: + throw InternalException("Unsupported recursive meta pipeline plan entry"); + } + } + for (auto &pipeline : meta_pipeline_plan.initialize_on_schedule_pipelines) { + plan->initialize_on_schedule_pipelines.push_back(pipeline); + } + } + + for (auto &entry : stage_map) { + auto &pipeline = entry.first.get(); + for (auto &dependency : pipeline.GetDependencies()) { + auto dep = dependency.lock(); + D_ASSERT(dep); + auto stage_entry = stage_map.find(*dep); + if (stage_entry == stage_map.end()) { + continue; + } + AddRecursiveInlineDependency(*plan, entry.second.execute_stage, stage_entry->second.finish_stage); + } + } + + for (auto &meta_pipeline : meta_pipelines) { + for (auto &entry : meta_pipeline->GetDependencies()) { + auto pipeline_entry = stage_map.find(entry.first.get()); + if (pipeline_entry == stage_map.end()) { + continue; + } + for (auto &dependency : entry.second) { + auto dependency_entry = stage_map.find(dependency.get()); + if (dependency_entry == stage_map.end()) { + continue; + } + AddRecursiveInlineDependency(*plan, pipeline_entry->second.execute_stage, + dependency_entry->second.execute_stage); + } + } + } + + for (auto &meta_pipeline : meta_pipelines) { + vector> children; + meta_pipeline->GetMetaPipelines(children, false, true); + for (auto &child1 : children) { + if (child1->Type() != MetaPipelineType::JOIN_BUILD) { + continue; + } + auto child1_entry = stage_map.find(*child1->GetBasePipeline()); + if (child1_entry == stage_map.end()) { + continue; + } + + for (auto &child2 : children) { + if (child2->Type() != MetaPipelineType::JOIN_BUILD || child1.get() == child2.get()) { + continue; + } + if (child1->GetParent().get() != child2->GetParent().get()) { + continue; + } + auto child2_entry = stage_map.find(*child2->GetBasePipeline()); + if (child2_entry == stage_map.end()) { + continue; + } + + AddRecursiveInlineDependency(*plan, 
child1_entry->second.prepare_finish_stage, + child2_entry->second.execute_stage); + AddRecursiveInlineDependency(*plan, child1_entry->second.finish_stage, + child2_entry->second.prepare_finish_stage); + } + } + } + return plan; +} + +static bool OperatorDirectlyDependsOnRecursiveInput(const PhysicalOperator &op, TableIndex cte_index) { + if (op.type == PhysicalOperatorType::DELIM_SCAN) { + return true; + } + if (op.type == PhysicalOperatorType::RECURSIVE_CTE_SCAN || + op.type == PhysicalOperatorType::RECURSIVE_RECURRING_CTE_SCAN) { + auto &scan = op.Cast(); + return scan.cte_index == cte_index; + } + if (op.type == PhysicalOperatorType::LEFT_DELIM_JOIN || op.type == PhysicalOperatorType::RIGHT_DELIM_JOIN) { + auto &delim_join = op.Cast(); + for (auto &scan : delim_join.delim_scans) { + if (OperatorDirectlyDependsOnRecursiveInput(scan.get(), cte_index)) { + return true; + } + } + } + return false; +} + +static bool PipelineDirectlyDependsOnRecursiveInput(Pipeline &pipeline, TableIndex cte_index) { + auto source = pipeline.GetSource(); + if (source && OperatorDirectlyDependsOnRecursiveInput(*source, cte_index)) { + return true; + } + for (auto &op : pipeline.GetIntermediateOperators()) { + if (OperatorDirectlyDependsOnRecursiveInput(op.get(), cte_index)) { + return true; + } + } + return false; +} + +static reference_set_t +FindInvariantRecursiveMetaPipelines(const vector> &meta_pipelines, TableIndex cte_index) { + // By default the recursive executor reruns every meta-pipeline in the recursive member each + // iteration because the generic scheduling infrastructure does not know which build-side subplans + // are independent of the current recursive frontier. For small recursive workloads that repeats + // expensive setup for static subplans. 
Classify only JOIN_BUILD meta-pipelines that are + // transitively independent from the recursive input and whose build side does not propagate rows + // back into the recursive result, so their materialized state can safely survive later iterations. + reference_map_t> pipeline_to_meta_pipeline; + reference_set_t variant_meta_pipelines; + + for (auto &meta_pipeline : meta_pipelines) { + vector> pipelines; + meta_pipeline->GetPipelines(pipelines, false); + for (auto &pipeline : pipelines) { + pipeline_to_meta_pipeline.emplace(*pipeline, *meta_pipeline); + if (PipelineDirectlyDependsOnRecursiveInput(*pipeline, cte_index)) { + variant_meta_pipelines.insert(*meta_pipeline); + } + } + } + + bool changed = true; + while (changed) { + changed = false; + for (auto &meta_pipeline : meta_pipelines) { + if (variant_meta_pipelines.find(*meta_pipeline) != variant_meta_pipelines.end()) { + continue; + } + + bool depends_on_variant = false; + vector> pipelines; + meta_pipeline->GetPipelines(pipelines, false); + for (auto &pipeline : pipelines) { + for (auto &dependency : pipeline->GetDependencies()) { + auto dep = dependency.lock(); + if (!dep) { + continue; + } + auto dep_entry = pipeline_to_meta_pipeline.find(*dep); + if (dep_entry == pipeline_to_meta_pipeline.end()) { + continue; + } + if (variant_meta_pipelines.find(dep_entry->second) != variant_meta_pipelines.end()) { + depends_on_variant = true; + break; + } + } + if (depends_on_variant) { + break; + } + } + if (!depends_on_variant) { + for (auto &entry : meta_pipeline->GetDependencies()) { + for (auto &dependency : entry.second) { + auto dep_entry = pipeline_to_meta_pipeline.find(dependency.get()); + if (dep_entry == pipeline_to_meta_pipeline.end()) { + continue; + } + if (variant_meta_pipelines.find(dep_entry->second) != variant_meta_pipelines.end()) { + depends_on_variant = true; + break; + } + } + if (depends_on_variant) { + break; + } + } + } + if (depends_on_variant) { + variant_meta_pipelines.insert(*meta_pipeline); 
+ changed = true; + } + } + } + + reference_set_t result; + for (auto &meta_pipeline : meta_pipelines) { + if (meta_pipeline->Type() != MetaPipelineType::JOIN_BUILD) { + continue; + } + auto sink = meta_pipeline->GetSink(); + if (!sink) { + continue; + } + + bool can_cache_build = false; + switch (sink->type) { + case PhysicalOperatorType::HASH_JOIN: { + auto &hash_join = sink->Cast(); + can_cache_build = !PropagatesBuildSide(hash_join.join_type); + break; + } + case PhysicalOperatorType::NESTED_LOOP_JOIN: { + auto &nested_loop_join = sink->Cast(); + can_cache_build = !PropagatesBuildSide(nested_loop_join.join_type); + break; + } + case PhysicalOperatorType::BLOCKWISE_NL_JOIN: { + auto &blockwise_nl_join = sink->Cast(); + can_cache_build = !PropagatesBuildSide(blockwise_nl_join.join_type); + break; + } + case PhysicalOperatorType::CROSS_PRODUCT: + can_cache_build = true; + break; + default: + break; + } + if (!can_cache_build) { + continue; + } + if (variant_meta_pipelines.find(*meta_pipeline) == variant_meta_pipelines.end()) { + result.insert(*meta_pipeline); + } + } + return result; +} + +static vector> GetActiveRecursiveMetaPipelines(const PhysicalRecursiveCTE &op, + RecursiveCTEState &state) { + vector> meta_pipelines; + op.recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); + if (!state.allow_executor_reuse || !state.invariant_meta_pipelines_materialized || + op.invariant_meta_pipelines.empty()) { + return meta_pipelines; + } + + vector> active_meta_pipelines; + active_meta_pipelines.reserve(meta_pipelines.size()); + for (auto &meta_pipeline : meta_pipelines) { + if (op.invariant_meta_pipelines.find(*meta_pipeline) == op.invariant_meta_pipelines.end()) { + active_meta_pipelines.push_back(meta_pipeline); + } + } + return active_meta_pipelines; +} + +static void ConfigureInvariantRecursiveBuildReuse(const PhysicalRecursiveCTE &op, bool preserve_build) { + for (auto &meta_pipeline_ref : op.invariant_meta_pipelines) { + auto sink = 
meta_pipeline_ref.get().GetSink(); + if (!sink || sink->type != PhysicalOperatorType::HASH_JOIN) { + continue; + } + sink->Cast().SetPreserveBuildForRecursiveReuse(preserve_build); + } +} + +static bool InvariantRecursiveBuildsRemainReusable(const PhysicalRecursiveCTE &op) { + // Recursive invariant caching may omit these meta-pipelines entirely after the first materialization. + // That is only safe if every invariant HASH_JOIN actually kept a reusable in-memory build side. + // External/spilled hash joins run through mutable partition/probe-spill rounds and must therefore + // stay in the active recursive schedule instead of being treated as "materialized once". + for (auto &meta_pipeline_ref : op.invariant_meta_pipelines) { + auto sink = meta_pipeline_ref.get().GetSink(); + if (!sink || sink->type != PhysicalOperatorType::HASH_JOIN) { + continue; + } + if (!sink->Cast().CanPreserveBuildForRecursiveReuse()) { + return false; + } + } + return true; +} + +static void ProcessRecursiveExecutorTasks(Executor &executor) { + // Only drain the recursive executor here. Re-entering the outer query executor while waiting for + // recursive work can run unrelated tasks and used to break recursive completion tracking. + if (!executor.WorkOnTasks()) { + executor.WaitForTask(); + } + if (executor.HasError()) { + executor.ThrowException(); + } +} + +static void WaitForRecursiveEvents(Executor &executor, vector> &events) { + while (true) { + ProcessRecursiveExecutorTasks(executor); + bool finished = true; + for (auto &event : events) { + if (!event->IsFinished()) { + finished = false; + break; + } + } + if (finished) { + break; + } + } +} + +static void WaitForRecursiveEvent(Executor &executor, Event &event) { + while (!event.IsFinished()) { + ProcessRecursiveExecutorTasks(executor); + } +} + +// Build the recursive Event chain for one MetaPipeline using the neutral classification from +// BuildRecursiveMetaPipelinePlan(). 
+// +// Why this exists instead of letting MetaPipeline schedule itself directly: +// - the normal scheduler builds a query-wide, one-shot Event graph; recursive execution only wants +// to rebuild the recursive subset for the current iteration +// - recursive execution may omit already-materialized invariant build pipelines, so the event graph +// is not just a replay of the original query schedule +// - the recursive path needs RecursiveCTEPipelineEvent at the execute edge so it can reuse cached +// PipelineExecutors instead of constructing fresh ones every iteration +// +// The BASE pipeline anchors the chain for the whole meta-pipeline. Other pipelines either share the +// BASE finish event, own their own finish chain, or share another pipeline's finish group, mirroring +// the same finish/finalize semantics as the normal scheduler with less per-iteration setup work. +static void ScheduleRecursiveMetaPipeline(const shared_ptr &meta_pipeline, RecursiveCTEState &state, + Executor &executor, recursive_cte_event_map_t &event_map, + vector> &events) { + auto meta_pipeline_plan = BuildRecursiveMetaPipelinePlan(*meta_pipeline); + RecursiveCTEEventStack base_stack(nullptr, nullptr, nullptr, nullptr); + for (auto &entry : meta_pipeline_plan.entries) { + auto &pipeline = entry.pipeline.get(); + if (entry.type != RecursiveCTEMetaPipelineEntryType::BASE) { + D_ASSERT(base_stack.pipeline_event); + D_ASSERT(base_stack.pipeline_prepare_finish_event); + D_ASSERT(base_stack.pipeline_finish_event); + D_ASSERT(base_stack.pipeline_complete_event); + } + switch (entry.type) { + case RecursiveCTEMetaPipelineEntryType::BASE: { + // The base pipeline owns the full execute -> prepare-finish -> finish -> complete chain. 
+ base_stack = CreateRecursiveEventStack(pipeline.shared_from_this(), state); + auto completion = make_shared_ptr(executor, false); + completion->AddDependency(*base_stack.pipeline_finish_event); + base_stack.pipeline_complete_event = completion; + event_map.emplace(reference(pipeline), base_stack); + events.push_back(base_stack.pipeline_event); + events.push_back(base_stack.pipeline_prepare_finish_event); + events.push_back(base_stack.pipeline_finish_event); + events.push_back(completion); + break; + } + case RecursiveCTEMetaPipelineEntryType::SHARED_FINISH_GROUP: { + // This pipeline executes independently, but its finalize work is merged into an existing + // finish group selected by MetaPipeline. + D_ASSERT(entry.finish_group); + auto group_entry = event_map.find(*entry.finish_group); + D_ASSERT(group_entry != event_map.end()); + auto execute = make_shared_ptr(pipeline.shared_from_this(), state); + execute->AddDependency(*base_stack.pipeline_finish_event); + group_entry->second.pipeline_prepare_finish_event->AddDependency(*execute); + event_map.emplace(reference(pipeline), + RecursiveCTEEventStack(execute, group_entry->second.pipeline_prepare_finish_event, + group_entry->second.pipeline_finish_event, + base_stack.pipeline_complete_event)); + events.push_back(execute); + break; + } + case RecursiveCTEMetaPipelineEntryType::HAS_FINISH_EVENT: { + // This pipeline needs its own prepare/finish chain, but still joins the BASE completion edge + // so the whole meta-pipeline reports done only after every finish stage completed. 
+ auto pipeline_stack = CreateRecursiveEventStack(pipeline.shared_from_this(), state); + pipeline_stack.pipeline_event->AddDependency(*base_stack.pipeline_finish_event); + base_stack.pipeline_complete_event->AddDependency(*pipeline_stack.pipeline_finish_event); + pipeline_stack.pipeline_complete_event = base_stack.pipeline_complete_event; + event_map.emplace(reference(pipeline), pipeline_stack); + events.push_back(pipeline_stack.pipeline_event); + events.push_back(pipeline_stack.pipeline_prepare_finish_event); + events.push_back(pipeline_stack.pipeline_finish_event); + break; + } + case RecursiveCTEMetaPipelineEntryType::SHARED_BASE_FINISH: { + // This is the cheapest case: only the execute stage is private, while prepare/finish reuse + // the BASE pipeline chain. + auto execute = make_shared_ptr(pipeline.shared_from_this(), state); + base_stack.pipeline_prepare_finish_event->AddDependency(*execute); + event_map.emplace(reference(pipeline), + RecursiveCTEEventStack(execute, base_stack.pipeline_prepare_finish_event, + base_stack.pipeline_finish_event, + base_stack.pipeline_complete_event)); + events.push_back(execute); + break; + } + default: + throw InternalException("Unsupported recursive meta pipeline plan entry"); + } + } + + for (auto &pipeline : meta_pipeline_plan.initialize_on_schedule_pipelines) { + pipeline.get().ResetSource(true); + } +} + +static void ScheduleRecursivePipelines(const vector> &meta_pipelines, RecursiveCTEState &state, + Executor &executor, vector> &events) { + // Rebuild only the recursive subset of the event graph for this iteration. Reusing the query's + // original scheduler wiring is not enough here: recursive iterations may omit already-materialized + // invariant build pipelines, and root execute events need a dedicated prepare step before any task + // starts so shared pipeline state is not reset concurrently with active execution. 
+ recursive_cte_event_map_t event_map; + for (auto &meta_pipeline : meta_pipelines) { + ScheduleRecursiveMetaPipeline(meta_pipeline, state, executor, event_map, events); + } + + for (auto &entry : event_map) { + auto &pipeline = entry.first.get(); + for (auto &dependency : pipeline.GetDependencies()) { + auto dep = dependency.lock(); + D_ASSERT(dep); + auto event_entry = event_map.find(*dep); + if (event_entry == event_map.end()) { + continue; + } + entry.second.pipeline_event->AddDependency(*event_entry->second.pipeline_complete_event); + } + } + + for (auto &meta_pipeline : meta_pipelines) { + for (auto &entry : meta_pipeline->GetDependencies()) { + auto pipeline_entry = event_map.find(entry.first.get()); + if (pipeline_entry == event_map.end()) { + continue; + } + for (auto &dependency : entry.second) { + auto dependency_entry = event_map.find(dependency.get()); + if (dependency_entry == event_map.end()) { + continue; + } + pipeline_entry->second.pipeline_event->AddDependency(*dependency_entry->second.pipeline_event); + } + } + } + + for (auto &meta_pipeline : meta_pipelines) { + vector> children; + meta_pipeline->GetMetaPipelines(children, false, true); + for (auto &child1 : children) { + if (child1->Type() != MetaPipelineType::JOIN_BUILD) { + continue; + } + auto child1_entry = event_map.find(*child1->GetBasePipeline()); + if (child1_entry == event_map.end()) { + continue; + } + + for (auto &child2 : children) { + if (child2->Type() != MetaPipelineType::JOIN_BUILD || child1.get() == child2.get()) { + continue; + } + if (child1->GetParent().get() != child2->GetParent().get()) { + continue; + } + auto child2_entry = event_map.find(*child2->GetBasePipeline()); + if (child2_entry == event_map.end()) { + continue; + } + + child1_entry->second.pipeline_prepare_finish_event->AddDependency(*child2_entry->second.pipeline_event); + child1_entry->second.pipeline_finish_event->AddDependency( + *child2_entry->second.pipeline_prepare_finish_event); + } + } + } + + for (auto 
&entry : event_map) { + if (!entry.second.pipeline_event->HasDependencies()) { + entry.second.pipeline_event->PrepareForSchedule(); + } + } + + for (auto &event : events) { + if (!event->HasDependencies()) { + event->Schedule(); + } + } +} + +static void ExecuteRecursivePipelineFinishInline(Pipeline &pipeline, Executor &executor) { + auto finish_event = make_shared_ptr(pipeline.shared_from_this()); + auto complete_event = make_shared_ptr(executor, false); + complete_event->AddDependency(*finish_event); + finish_event->Schedule(); + if (!complete_event->IsFinished()) { + WaitForRecursiveEvent(executor, *complete_event); + } +} + +static void ExecuteRecursiveInlinePlan(RecursiveCTEState &state, Executor &executor, + const RecursiveCTEInlinePlan &plan) { + for (auto &pipeline : plan.initialize_on_schedule_pipelines) { + pipeline.get().ResetSource(true); + } + + vector remaining_dependencies; + remaining_dependencies.reserve(plan.stages.size()); + vector ready_stages; + ready_stages.reserve(plan.stages.size()); + for (idx_t stage_idx = 0; stage_idx < plan.stages.size(); stage_idx++) { + auto dependency_count = plan.stages[stage_idx].dependency_count; + remaining_dependencies.push_back(dependency_count); + if (dependency_count == 0) { + ready_stages.push_back(stage_idx); + } + } + + for (idx_t ready_idx = 0; ready_idx < ready_stages.size(); ready_idx++) { + auto stage_idx = ready_stages[ready_idx]; + auto &stage = plan.stages[stage_idx]; + auto &pipeline = stage.pipeline.get(); + switch (stage.type) { + case RecursiveCTEInlineStageType::EXECUTE: { + pipeline.ResetForReschedule(false); + auto max_threads = GetRecursivePipelineMaxThreads(state, pipeline); + D_ASSERT(max_threads == 1); + auto &executors = state.GetCachedExecutors(pipeline, max_threads); + executors[0]->PrepareForExecution(); + ExecuteRecursivePipelineInline(*executors[0]); + break; + } + case RecursiveCTEInlineStageType::PREPARE_FINISH: + pipeline.PrepareFinalize(); + break; + case 
RecursiveCTEInlineStageType::FINISH: + ExecuteRecursivePipelineFinishInline(pipeline, executor); + break; + default: + throw InternalException("Unsupported recursive inline stage"); + } + + for (auto dependent_stage : stage.dependents) { + auto &remaining = remaining_dependencies[dependent_stage]; + D_ASSERT(remaining > 0); + remaining--; + if (remaining == 0) { + ready_stages.push_back(dependent_stage); + } + } + } + + if (ready_stages.size() != plan.stages.size()) { + throw InternalException("Recursive inline plan did not schedule every stage"); + } +} + +void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { + if (!recursive_meta_pipeline) { + throw InternalException("Missing meta pipeline for recursive CTE"); + } + D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); + + auto &gstate = sink_state->Cast(); + auto &executor = recursive_meta_pipeline->GetExecutor(); + auto allow_reuse = Settings::Get(context.client); + auto active_meta_pipelines = GetActiveRecursiveMetaPipelines(*this, gstate); + auto can_cache_invariant_meta_pipelines = allow_reuse && !invariant_meta_pipelines.empty(); + + // The generic executor path would rebuild the recursive event graph and allocate fresh + // PipelineExecutors/tasks every iteration. Recursive execution keeps the already-built recursive + // MetaPipeline, resets only the state that must change, optionally skips invariant build pipelines + // after they have been materialized once, and then picks between: + // - a cached inline dependency plan when the iteration is effectively single-threaded + // - a custom recursive Event graph when the iteration still benefits from parallel execution + + // Reset sink state from the main thread so recursive iterations can reuse or recreate + // pipeline-local global sinks without tearing down the rest of the runtime state graph. 
+ for (auto &meta_pipeline : active_meta_pipelines) { + vector> pipelines; + meta_pipeline->GetPipelines(pipelines, false); + for (auto &pipeline : pipelines) { + auto sink = pipeline->GetSink(); + if (sink.get() != this) { + pipeline->ResetSinkForReschedule(); + } + } + } + + ConfigureInvariantRecursiveBuildReuse(*this, can_cache_invariant_meta_pipelines); + + if (!allow_reuse) { + gstate.ClearCachedExecutors(); + } + + auto inline_execution = allow_reuse && GetRecursiveThreadLimit(gstate) == 1; + + if (inline_execution) { + auto &inline_plan = + gstate.invariant_meta_pipelines_materialized ? gstate.invariant_inline_plan : gstate.inline_plan; + if (!inline_plan) { + inline_plan = BuildRecursiveInlinePlan(active_meta_pipelines); + } + ExecuteRecursiveInlinePlan(gstate, executor, *inline_plan); + if (can_cache_invariant_meta_pipelines && InvariantRecursiveBuildsRemainReusable(*this)) { + gstate.invariant_meta_pipelines_materialized = true; + } + return; + } + + vector> events; + ScheduleRecursivePipelines(active_meta_pipelines, gstate, executor, events); + WaitForRecursiveEvents(executor, events); + if (can_cache_invariant_meta_pipelines && InvariantRecursiveBuildsRemainReusable(*this)) { + gstate.invariant_meta_pipelines_materialized = true; + } +} + +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// + +static void GatherColumnDataScans(const PhysicalOperator &op, vector> &delim_scans) { + if (op.type == PhysicalOperatorType::DELIM_SCAN || op.type == PhysicalOperatorType::CTE_SCAN) { + delim_scans.push_back(op); + } + for (auto child : op.GetChildren()) { + GatherColumnDataScans(child.get(), delim_scans); + } +} + +static void GatherRecursiveScansInternal(PhysicalOperator &op, TableIndex cte_index, + vector> &recursive_scans, + reference_set_t &visited) { + if (!visited.insert(op).second) { + return; + } + if (op.type == 
PhysicalOperatorType::RECURSIVE_CTE_SCAN) { + auto &scan = op.Cast(); + if (scan.cte_index == cte_index) { + recursive_scans.push_back(scan); + } + } + for (auto &child : op.children) { + GatherRecursiveScansInternal(child.get(), cte_index, recursive_scans, visited); + } + if (op.type == PhysicalOperatorType::LEFT_DELIM_JOIN || op.type == PhysicalOperatorType::RIGHT_DELIM_JOIN) { + auto &delim_join = op.Cast(); + GatherRecursiveScansInternal(delim_join.join, cte_index, recursive_scans, visited); + GatherRecursiveScansInternal(delim_join.distinct, cte_index, recursive_scans, visited); + } +} + +static void GatherRecursiveScans(PhysicalOperator &op, TableIndex cte_index, + vector> &recursive_scans) { + reference_set_t visited; + GatherRecursiveScansInternal(op, cte_index, recursive_scans, visited); +} + +static void CountRecursiveReferencesInternal(const PhysicalOperator &op, TableIndex cte_index, + idx_t &recursive_reference_count, idx_t &recurring_reference_count, + reference_set_t &visited) { + if (!visited.insert(op).second) { + return; + } + if (op.type == PhysicalOperatorType::RECURSIVE_CTE_SCAN || + op.type == PhysicalOperatorType::RECURSIVE_RECURRING_CTE_SCAN) { + auto &scan = op.Cast(); + if (scan.cte_index == cte_index) { + if (op.type == PhysicalOperatorType::RECURSIVE_CTE_SCAN) { + recursive_reference_count++; + } else { + recurring_reference_count++; + } + } + } + for (auto child : op.GetChildren()) { + CountRecursiveReferencesInternal(child.get(), cte_index, recursive_reference_count, recurring_reference_count, + visited); + } +} + +static void CountRecursiveReferences(const PhysicalOperator &op, TableIndex cte_index, idx_t &recursive_reference_count, + idx_t &recurring_reference_count) { + reference_set_t visited; + CountRecursiveReferencesInternal(op, cte_index, recursive_reference_count, recurring_reference_count, visited); +} + +void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + op_state.reset(); + 
sink_state.reset(); + recursive_meta_pipeline.reset(); + { + D_ASSERT(shared_executor_pool); + lock_guard guard(shared_executor_pool->lock); + shared_executor_pool->executors.clear(); + } + recursive_reference_count = 0; + recurring_reference_count = 0; + recursive_scans.clear(); + invariant_meta_pipelines.clear(); + + auto &state = meta_pipeline.GetState(); + state.SetPipelineSource(current, *this); + + auto &executor = meta_pipeline.GetExecutor(); + executor.AddRecursiveCTE(*this); + + // the LHS of the recursive CTE is our initial state + auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + initial_state_pipeline.Build(children[0]); + + // the RHS is the recursive pipeline + recursive_meta_pipeline = make_shared_ptr(executor, state, this); + recursive_meta_pipeline->SetRecursiveCTE(); + recursive_meta_pipeline->Build(children[1]); + CountRecursiveReferences(children[1], table_index, recursive_reference_count, recurring_reference_count); + GatherRecursiveScans(children[1], table_index, recursive_scans); + + vector> ops; + GatherColumnDataScans(children[1], ops); + bool has_delim_scan = false; + for (auto op : ops) { + if (op.get().type == PhysicalOperatorType::DELIM_SCAN) { + has_delim_scan = true; + break; + } + } + if (!has_delim_scan) { + vector> recursive_meta_pipelines; + recursive_meta_pipeline->GetMetaPipelines(recursive_meta_pipelines, true, false); + invariant_meta_pipelines = FindInvariantRecursiveMetaPipelines(recursive_meta_pipelines, table_index); + } + + for (auto op : ops) { + auto entry = state.cte_dependencies.find(op); + if (entry == state.cte_dependencies.end()) { + continue; + } + // this chunk scan introduces a dependency to the current pipeline + // namely a dependency on the CTE pipeline to finish + auto cte_dependency = entry->second.get().shared_from_this(); + current.AddDependency(cte_dependency); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp 
b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp index 492c5add9..7005c4e41 100644 --- a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp +++ b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/perfect_aggregate_hashtable.hpp" +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/row_operations/row_operations.hpp" #include "duckdb/common/vector/flat_vector.hpp" @@ -23,9 +24,15 @@ PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, All total_groups = (uint64_t)1 << total_required_bits; // we don't need to store the groups in a perfect hash table, since the group keys can be deduced by their location grouping_columns = group_types_p.size(); + clustered_state.all_clustered = AllAggregatesClustered(aggregate_objects_p); + clustered_state.n_clustered = CountAggregatesClustered(aggregate_objects_p); layout_ptr->Initialize(std::move(aggregate_objects_p)); tuple_size = layout_ptr->GetRowWidth(); + if (clustered_state.n_clustered > 1) { + clustered_state.Initialize(); + } + // allocate and null initialize the data owned_data = make_unsafe_uniq_array_uninitialized(tuple_size * total_groups); data = owned_data.get(); @@ -81,7 +88,8 @@ static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value } } -static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) { +static void ComputeGroupLocation(const Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift) { + const idx_t count = group.size(); UnifiedVectorFormat vdata; group.ToUnifiedFormat(vdata); @@ -117,19 +125,24 @@ static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_d void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) { // first we need to find the location in the HT of each of the groups + 
FlatVector::SetSize(addresses, groups.size()); auto address_data = FlatVector::GetDataMutable(addresses); // zero-initialize the address data memset(address_data, 0, groups.size() * sizeof(uintptr_t)); D_ASSERT(groups.ColumnCount() == group_minima.size()); - // then compute the actual group location by iterating over each of the groups + // Compute raw group ids first; convert them to state pointers below if needed. idx_t current_shift = total_required_bits; for (idx_t i = 0; i < groups.ColumnCount(); i++) { current_shift -= required_bits[i]; - ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size()); + ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift); + } + + if (AddChunkClustered(address_data, payload)) { + return; } - // now we have the HT entry number for every tuple - // compute the actual pointer to the data by adding it to the base HT pointer and multiplying by the tuple size + + // Non-clustered path: convert group ids to state pointers and update aggregates. 
for (idx_t i = 0; i < groups.size(); i++) { const auto group = address_data[i]; if (group >= total_groups) { @@ -142,25 +155,63 @@ void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) address_data[i] = uintptr_t(data) + group * tuple_size; } - // after finding the group location we update the aggregates idx_t payload_idx = 0; auto &aggregates = layout_ptr->GetAggregates(); RowOperationsState row_state(*aggregate_allocator); for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { auto &aggregate = aggregates[aggr_idx]; - auto input_count = (idx_t)aggregate.child_count; if (aggregate.filter) { RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses, payload, payload_idx); } else { - RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size()); + RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx); } - // move to the next aggregate - payload_idx += input_count; - VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggregate.payload_size), payload.size()); + payload_idx += aggregate.child_count; + VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggregate.payload_size)); } } +bool PerfectAggregateHashTable::AddChunkClustered(uintptr_t *address_data, DataChunk &payload) { + // Build the clustered permutation from raw group ids. 
+ ClusteredAggr clustered; + uint64_t group_ids_buf[STANDARD_VECTOR_SIZE]; + const uint64_t *group_ids_ptr; + auto count = payload.size(); + if constexpr (sizeof(uintptr_t) == sizeof(uint64_t)) { + group_ids_ptr = reinterpret_cast(address_data); + } else { + for (idx_t i = 0; i < count; i++) { + group_ids_buf[i] = static_cast(address_data[i]); + } + group_ids_ptr = group_ids_buf; + } + if (total_groups >= ClusteredAggr::MAX_GID_COUNT) { + return false; + } + if (!clustered_state.TryBuild(clustered, group_ids_ptr, count)) { + return false; + } + clustered.InitializeStates([&](uint64_t gid) { + auto group_id = static_cast(gid); + group_is_set[group_id] = true; + return data + group_id * tuple_size; + }); + + // When all aggregates are clustered-aware, we can skip maintaining per-tuple addresses. + bool skip_addresses = clustered_state.all_clustered; + if (!skip_addresses) { + for (idx_t i = 0; i < count; i++) { + address_data[i] = uintptr_t(data) + address_data[i] * tuple_size; + } + } + + auto &aggregates = layout_ptr->GetAggregates(); + RowOperationsState row_state(*aggregate_allocator); + RowOperations::UpdateStatesClustered(row_state, aggregates, &filter_set, nullptr, addresses, payload, clustered, + skip_addresses); + return true; +} + void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) { D_ASSERT(total_groups == other.total_groups); D_ASSERT(tuple_size == other.tuple_size); @@ -184,14 +235,18 @@ void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) { target_addresses_ptr[combine_count] = target_ptr; combine_count++; if (combine_count == STANDARD_VECTOR_SIZE) { - RowOperations::CombineStates(row_state, *layout_ptr, source_addresses, target_addresses, combine_count); + FlatVector::SetSize(source_addresses, combine_count); + FlatVector::SetSize(target_addresses, combine_count); + RowOperations::CombineStates(row_state, *layout_ptr, source_addresses, target_addresses); combine_count = 0; } } source_ptr += tuple_size; 
target_ptr += tuple_size; } - RowOperations::CombineStates(row_state, *layout_ptr, source_addresses, target_addresses, combine_count); + FlatVector::SetSize(source_addresses, combine_count); + FlatVector::SetSize(target_addresses, combine_count); + RowOperations::CombineStates(row_state, *layout_ptr, source_addresses, target_addresses); // FIXME: after moving the arena allocator, we currently have to ensure that the pointer is not nullptr, because the // FIXME: Destroy()-function of the hash table expects an allocator in some cases (e.g., for sorted aggregates) @@ -250,7 +305,6 @@ static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t re default: throw InternalException("Invalid type for perfect aggregate HT group"); } - FlatVector::SetSize(result, count_t(entry_count)); } void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) { @@ -282,7 +336,6 @@ void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) { ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]); } // then construct the payloads - result.SetCardinality(entry_count); RowOperationsState row_state(*aggregate_allocator); RowOperations::FinalizeStates(row_state, *layout_ptr, addresses, result, grouping_columns); } @@ -309,12 +362,14 @@ void PerfectAggregateHashTable::Destroy() { for (idx_t i = 0; i < total_groups; i++) { data_pointers[count++] = payload_ptr; if (count == STANDARD_VECTOR_SIZE) { - RowOperations::DestroyStates(row_state, *layout_ptr, addresses, count); + FlatVector::SetSize(addresses, count); + RowOperations::DestroyStates(row_state, *layout_ptr, addresses); count = 0; } payload_ptr += tuple_size; } - RowOperations::DestroyStates(row_state, *layout_ptr, addresses, count); + FlatVector::SetSize(addresses, count); + RowOperations::DestroyStates(row_state, *layout_ptr, addresses); } } // namespace duckdb diff --git a/src/duckdb/src/execution/physical_operator.cpp 
b/src/duckdb/src/execution/physical_operator.cpp index ce9f6ed62..b1416b2aa 100644 --- a/src/duckdb/src/execution/physical_operator.cpp +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -96,6 +96,10 @@ unique_ptr PhysicalOperator::GetGlobalOperatorState(ClientC return make_uniq(); } +bool PhysicalOperator::ResetGlobalOperatorState(ClientContext &context, GlobalOperatorState &state) const { + return false; +} + OperatorResultType PhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate, OperatorState &state) const { throw InternalException("Calling Execute on a node that is not an operator!"); @@ -498,8 +502,6 @@ OperatorFinalizeResultType CachingPhysicalOperator::FinalExecute(ExecutionContex if (state.cached_chunk) { chunk.Move(*state.cached_chunk); state.cached_chunk.reset(); - } else { - chunk.SetCardinality(0); } return OperatorFinalizeResultType::FINISHED; } diff --git a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp index 6e384d64c..edadffb3a 100644 --- a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp @@ -1,7 +1,5 @@ #include "duckdb/main/settings.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/operator/subtract.hpp" #include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" #include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp" #include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" @@ -11,8 +9,6 @@ #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/main/client_context.hpp" -#include "duckdb/parser/expression/comparison_expression.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include 
"duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" @@ -44,15 +40,22 @@ static bool CanUsePartitionedAggregate(ClientContext &context, LogicalAggregate return false; } } + + return PhysicalPlanGenerator::HasSingleValuePartitions(context, op.groups, child, partition_columns); +} + +bool PhysicalPlanGenerator::HasSingleValuePartitions(ClientContext &context, + const vector> &partitions, + PhysicalOperator &child, vector &partition_columns) { // check if the source is partitioned by the aggregate columns // figure out the columns we are grouping by - for (auto &group_expr : op.groups) { + for (auto &group_expr : partitions) { // only support bound reference here if (group_expr->GetExpressionType() != ExpressionType::BOUND_REF) { return false; } auto &ref = group_expr->Cast(); - partition_columns.push_back(ref.index); + partition_columns.push_back(ref.Index()); } // traverse the children of the aggregate to find the source operator reference child_ref(child); @@ -70,7 +73,7 @@ static bool CanUsePartitionedAggregate(ClientContext &context, LogicalAggregate return false; } auto &ref = expr->Cast(); - new_columns.push_back(ref.index); + new_columns.push_back(ref.Index()); } // continue into child node with new columns partition_columns = std::move(new_columns); @@ -224,7 +227,7 @@ static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate } for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct() || !aggregate.function.HasStateCombineCallback()) { + if (aggregate.IsDistinct() || !aggregate.Function().HasStateCombineCallback()) { // distinct aggregates are not supported in perfect hash aggregates return false; } @@ -241,7 +244,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { bool can_use_simple_aggregation = true; for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if 
(!aggregate.function.HasStateSimpleUpdateCallback()) { + if (!aggregate.Function().GetStateClusterUpdateCallback()) { // unsupported aggregate for simple aggregation: use hash aggregation can_use_simple_aggregation = false; break; @@ -262,16 +265,9 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) { } if (op.groups.empty() && op.grouping_sets.size() <= 1) { - // no groups, check if we can use a simple aggregation - // special case: aggregate entire columns together - if (can_use_simple_aggregation) { - auto &group_by = Make(op.types, std::move(op.expressions), - op.estimated_cardinality, op.distinct_validity); - group_by.children.push_back(plan); - return group_by; - } - auto &group_by = - Make(context, op.types, std::move(op.expressions), op.estimated_cardinality); + // no groups: use the dedicated ungrouped aggregate path + auto &group_by = Make(op.types, std::move(op.expressions), op.estimated_cardinality, + op.distinct_validity); group_by.children.push_back(plan); return group_by; } @@ -313,7 +309,7 @@ PhysicalOperator &PhysicalPlanGenerator::ExtractAggregateExpressions(PhysicalOpe // bind sorted aggregates for (auto &aggr : aggregates) { auto &bound_aggr = aggr->Cast(); - if (bound_aggr.order_bys) { + if (bound_aggr.GetOrderBys()) { // sorted aggregate! 
FunctionBinder::BindSortedAggregate(context, bound_aggr, groups, grouping_sets); } @@ -326,18 +322,18 @@ PhysicalOperator &PhysicalPlanGenerator::ExtractAggregateExpressions(PhysicalOpe } for (auto &aggr : aggregates) { auto &bound_aggr = aggr->Cast(); - for (auto &child_expr : bound_aggr.children) { + for (auto &child_expr : bound_aggr.GetChildrenMutable()) { auto ref = make_uniq(child_expr->GetReturnType(), expressions.size()); types.push_back(child_expr->GetReturnType()); expressions.push_back(std::move(child_expr)); child_expr = std::move(ref); } - if (bound_aggr.filter) { - auto &filter = bound_aggr.filter; + if (bound_aggr.GetFilter()) { + auto &filter = bound_aggr.GetFilterMutable(); auto ref = make_uniq(filter->GetReturnType(), expressions.size()); types.push_back(filter->GetReturnType()); expressions.push_back(std::move(filter)); - bound_aggr.filter = std::move(ref); + bound_aggr.GetFilterMutable() = std::move(ref); } } if (expressions.empty()) { diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index d7fda1b06..9da0e69a2 100644 --- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -93,10 +93,10 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera ExpressionIterator::EnumerateExpression(predicate, [&](Expression &child) { if (child.GetExpressionClass() == ExpressionClass::BOUND_REF) { auto &ref = child.Cast(); - if (ref.index < lhs_width) { - ref.index = ref.index + rhs_width; + if (ref.Index() < lhs_width) { + ref.IndexMutable() = ref.Index() + rhs_width; } else { - ref.index = ref.index - lhs_width; + ref.IndexMutable() = ref.Index() - lhs_width; } } }); @@ -164,9 +164,9 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera return nullptr; } - EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, arg_min_max); - auto 
arg_min_max_func = - binder->GetCatalogEntry(SYSTEM_CATALOG, DEFAULT_SCHEMA, function_lookup, OnEntryNotFound::RETURN_NULL); + EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, Identifier(arg_min_max)); + auto arg_min_max_func = binder->GetCatalogEntry(Identifier::SystemCatalog(), Identifier::DefaultSchema(), + function_lookup, OnEntryNotFound::RETURN_NULL); // Can't find the arg_min/max aggregate we need, so give up before we break anything. if (!arg_min_max_func || arg_min_max_func->type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { return nullptr; @@ -242,8 +242,8 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera auto pk = RowNumberFun::GetFunction().Bind(context); D_ASSERT(pk->GetReturnType() == pk_type); - pk->start = WindowBoundary::UNBOUNDED_PRECEDING; - pk->end = WindowBoundary::CURRENT_ROW_ROWS; + pk->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + pk->WindowEndMutable() = WindowBoundary::CURRENT_ROW_ROWS; pk->SetAlias("row_number"); window_select.emplace_back(std::move(pk)); @@ -395,10 +395,10 @@ PhysicalOperator &PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) D_ASSERT(asof_end->GetReturnType() == asof_type); - asof_end->partitions = std::move(partitions); - asof_end->orders = std::move(orders); - asof_end->start = WindowBoundary::UNBOUNDED_PRECEDING; - asof_end->end = WindowBoundary::CURRENT_ROW_ROWS; + asof_end->PartitionsMutable() = std::move(partitions); + asof_end->OrderByMutable() = std::move(orders); + asof_end->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + asof_end->WindowEndMutable() = WindowBoundary::CURRENT_ROW_ROWS; vector> window_select; window_select.emplace_back(std::move(asof_end)); diff --git a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp index d18c7ee93..19a1dee10 100644 --- a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp +++ 
b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp @@ -15,7 +15,7 @@ namespace duckdb { static void RewriteJoinCondition(unique_ptr &root_expr, idx_t offset) { ExpressionIterator::VisitExpressionMutable( - root_expr, [&](BoundReferenceExpression &ref, unique_ptr &expr) { ref.index += offset; }); + root_expr, [&](BoundReferenceExpression &ref, unique_ptr &expr) { ref.IndexMutable() += offset; }); } PhysicalOperator &PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoin &op) { @@ -40,11 +40,11 @@ PhysicalOperator &PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoi switch (op.join_type) { case JoinType::SEMI: case JoinType::ANTI: + case JoinType::MARK: can_merge = can_merge && op.conditions.size() == 1; break; case JoinType::RIGHT_ANTI: case JoinType::RIGHT_SEMI: - case JoinType::MARK: can_merge = can_merge && op.conditions.size() == 1; can_iejoin = false; break; diff --git a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp index ab45ac8ac..15201e54e 100644 --- a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp @@ -25,7 +25,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile &op) { op.file_path = fs.JoinPath(path, "tmp_" + base); } if (op.per_thread_output || op.file_size_bytes.IsValid() || op.rotate || op.partition_output || - !op.partition_columns.empty()) { + !op.partition_columns.empty() || !op.order_columns.empty()) { if (op.preserve_order == PreserveOrderType::PRESERVE_ORDER) { throw InvalidInputException("PRESERVE_ORDER is not supported with these parameters"); } @@ -83,6 +83,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile &op) { cast_copy.parallel = mode == CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; cast_copy.write_empty_file = op.write_empty_file; cast_copy.hive_file_pattern = op.hive_file_pattern; + 
cast_copy.order_columns = std::move(op.order_columns); cast_copy.children.push_back(plan); return copy; diff --git a/src/duckdb/src/execution/physical_plan/plan_create_index.cpp b/src/duckdb/src/execution/physical_plan/plan_create_index.cpp index 040c7d90c..b71ad0593 100644 --- a/src/duckdb/src/execution/physical_plan/plan_create_index.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_create_index.cpp @@ -36,7 +36,7 @@ static PhysicalOperator &AddFilter(PhysicalPlanGenerator &plan, LogicalCreateInd auto is_not_null_expr = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); auto bound_ref = make_uniq(prev.types[i], i); - is_not_null_expr->children.push_back(std::move(bound_ref)); + is_not_null_expr->GetChildrenMutable().push_back(std::move(bound_ref)); filter_exprs.push_back(std::move(is_not_null_expr)); } diff --git a/src/duckdb/src/execution/physical_plan/plan_delete.cpp b/src/duckdb/src/execution/physical_plan/plan_delete.cpp index d51d48ac3..95ba09d7c 100644 --- a/src/duckdb/src/execution/physical_plan/plan_delete.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_delete.cpp @@ -12,8 +12,8 @@ PhysicalOperator &DuckCatalog::PlanDelete(ClientContext &context, PhysicalPlanGe // Get the row_id column index. 
auto &bound_ref = op.expressions[0]->Cast(); auto &del = planner.Make(op.types, op.table.Cast(), op.table.GetStorage(), - std::move(op.bound_constraints), bound_ref.index, op.estimated_cardinality, - op.return_chunk, std::move(op.return_columns)); + std::move(op.bound_constraints), bound_ref.Index(), + op.estimated_cardinality, op.return_chunk, std::move(op.return_columns)); del.children.push_back(plan); return del; } diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp index aa031c40b..865ea823c 100644 --- a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp @@ -45,7 +45,7 @@ PhysicalOperator &PhysicalPlanGenerator::PlanDelimJoin(LogicalComparisonJoin &op D_ASSERT(delim_expr->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref = delim_expr->Cast(); delim_types.push_back(bound_ref.GetReturnType()); - distinct_groups.push_back(make_uniq(bound_ref.GetReturnType(), bound_ref.index)); + distinct_groups.push_back(make_uniq(bound_ref.GetReturnType(), bound_ref.Index())); } // we still have to create the DISTINCT clause that is used to generate the duplicate eliminated chunk diff --git a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp index 76fb90c88..e0761133a 100644 --- a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp @@ -7,6 +7,7 @@ #include "duckdb/planner/operator/logical_distinct.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -26,7 +27,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { auto &target = distinct_targets[i]; if (target->GetExpressionType() == ExpressionType::BOUND_REF) { auto &bound_ref = 
target->Cast(); - group_by_references[bound_ref.index] = i; + group_by_references[bound_ref.Index()] = i; } aggregate_types.push_back(target->GetReturnType()); groups.push_back(std::move(target)); @@ -61,9 +62,9 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { auto first_aggregate = function_binder.BindAggregateFunction(FirstFunctionGetter::GetFunction(logical_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT); - first_aggregate->order_bys = op.order_by ? op.order_by->Copy() : nullptr; + first_aggregate->GetOrderBysMutable() = op.order_by ? op.order_by->Copy() : nullptr; - if (ClientConfig::GetConfig(context).enable_optimizer) { + if (Settings::Get(context)) { bool changes_made = false; auto new_expr = OrderedAggregateOptimizer::Apply(context, *first_aggregate, groups, nullptr, changes_made); diff --git a/src/duckdb/src/execution/physical_plan/plan_explain.cpp b/src/duckdb/src/execution/physical_plan/plan_explain.cpp index c061c24c0..319ad69d8 100644 --- a/src/duckdb/src/execution/physical_plan/plan_explain.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_explain.cpp @@ -46,7 +46,6 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalExplain &op) { for (idx_t i = 0; i < keys.size(); i++) { chunk.data[0].Append(Value(keys[i])); chunk.data[1].Append(Value(values[i])); - chunk.SetCardinality(chunk.size() + 1); if (chunk.size() == STANDARD_VECTOR_SIZE) { collection->Append(chunk); chunk.Reset(); diff --git a/src/duckdb/src/execution/physical_plan/plan_get.cpp b/src/duckdb/src/execution/physical_plan/plan_get.cpp index 860349745..359765f6f 100644 --- a/src/duckdb/src/execution/physical_plan/plan_get.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_get.cpp @@ -149,8 +149,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { // function does not support projection pushdown auto &table_scan = Make( op.returned_types, op.function, std::move(op.bind_data), op.returned_types, column_ids, 
vector(), - op.names, std::move(table_filters), op.estimated_cardinality, std::move(op.extra_info), - std::move(op.parameters), std::move(op.virtual_columns)); + IdentifiersToStrings(op.names), std::move(table_filters), op.estimated_cardinality, + std::move(op.extra_info), std::move(op.parameters), std::move(op.virtual_columns)); // first check if an additional projection is necessary if (column_ids.size() == op.returned_types.size()) { bool projection_necessary = false; @@ -199,7 +199,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { auto &table_scan = Make( op.types, op.function, std::move(op.bind_data), op.returned_types, column_ids, std::move(projection_indices), - op.names, std::move(table_filters), op.estimated_cardinality, std::move(op.extra_info), + IdentifiersToStrings(op.names), std::move(table_filters), op.estimated_cardinality, std::move(op.extra_info), std::move(op.parameters), std::move(op.virtual_columns)); auto &cast_table_scan = table_scan.Cast(); cast_table_scan.dynamic_filters = op.dynamic_filters; diff --git a/src/duckdb/src/execution/physical_plan/plan_limit.cpp b/src/duckdb/src/execution/physical_plan/plan_limit.cpp index 5892e73b7..99f66fa07 100644 --- a/src/duckdb/src/execution/physical_plan/plan_limit.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_limit.cpp @@ -5,6 +5,7 @@ #include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" #include "duckdb/execution/operator/helper/physical_limit_percent.hpp" #include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/execution/operator/scan/physical_table_scan.hpp" #include "duckdb/execution/physical_plan_generator.hpp" #include "duckdb/main/config.hpp" #include "duckdb/planner/operator/logical_limit.hpp" @@ -22,8 +23,15 @@ bool UseBatchLimit(PhysicalOperator &child_node, BoundLimitNode &limit_val, Boun while (!finished) { auto ¤t_op = current_ref.get(); switch (current_op.type) { - case PhysicalOperatorType::TABLE_SCAN: + 
case PhysicalOperatorType::TABLE_SCAN: { + auto &table_scan = current_op.Cast(); + if (table_scan.table_filters && table_scan.table_filters->HasFilters()) { + finished = true; + break; + } + // limit on a table scan without filters - never use batch limit return false; + } case PhysicalOperatorType::PROJECTION: current_ref = current_op.children[0]; break; diff --git a/src/duckdb/src/execution/physical_plan/plan_projection.cpp b/src/duckdb/src/execution/physical_plan/plan_projection.cpp index a0b65a2c4..72d81e869 100644 --- a/src/duckdb/src/execution/physical_plan/plan_projection.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_projection.cpp @@ -23,7 +23,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalProjection &op) { for (idx_t i = 0; i < op.types.size(); i++) { if (op.expressions[i]->GetExpressionType() == ExpressionType::BOUND_REF) { auto &bound_ref = op.expressions[i]->Cast(); - if (bound_ref.index == i) { + if (bound_ref.Index() == i) { continue; } } diff --git a/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp b/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp index d3d3a5576..71e6e06ed 100644 --- a/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp @@ -50,9 +50,9 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalRecursiveCTE &op) { auto &target = op.key_targets[i]; D_ASSERT(target->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref = target->Cast(); - distinct_idx.emplace_back(bound_ref.index); + distinct_idx.emplace_back(bound_ref.Index()); distinct_types.push_back(bound_ref.GetReturnType()); - group_by_references[bound_ref.index] = i; + group_by_references[bound_ref.Index()] = i; } /* diff --git a/src/duckdb/src/execution/physical_plan/plan_set.cpp b/src/duckdb/src/execution/physical_plan/plan_set.cpp index 25df13222..330093982 100644 --- a/src/duckdb/src/execution/physical_plan/plan_set.cpp +++ 
b/src/duckdb/src/execution/physical_plan/plan_set.cpp @@ -13,7 +13,7 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalSet &op) { // Set a variable. auto &plan = CreatePlan(*op.children[0]); - auto &set_variable = Make(std::move(op.name), op.estimated_cardinality); + auto &set_variable = Make(op.name, op.estimated_cardinality); set_variable.children.push_back(plan); return set_variable; } diff --git a/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp b/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp index 354af63a6..ad30e6a4c 100644 --- a/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp @@ -18,11 +18,11 @@ static vector> CreatePartitionedRowNumExpression(ClientCo const vector &types) { vector> res; auto expr = RowNumberFun::GetFunction().Bind(client); - expr->start = WindowBoundary::UNBOUNDED_PRECEDING; - expr->end = WindowBoundary::UNBOUNDED_FOLLOWING; + expr->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + expr->WindowEndMutable() = WindowBoundary::UNBOUNDED_FOLLOWING; for (idx_t i = 0; i < types.size(); i++) { - expr->partitions.push_back(make_uniq(types[i], i)); - ExpressionBinder::PushCollation(client, expr->partitions.back(), types[i]); + expr->PartitionsMutable().push_back(make_uniq(types[i], i)); + ExpressionBinder::PushCollation(client, expr->PartitionsMutable().back(), types[i]); } res.push_back(std::move(expr)); return res; diff --git a/src/duckdb/src/execution/physical_plan/plan_simple.cpp b/src/duckdb/src/execution/physical_plan/plan_simple.cpp index 608df673a..fe26224f8 100644 --- a/src/duckdb/src/execution/physical_plan/plan_simple.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_simple.cpp @@ -1,3 +1,5 @@ +#include "duckdb/execution/operator/helper/physical_connect.hpp" +#include "duckdb/execution/operator/helper/physical_disconnect.hpp" #include "duckdb/execution/operator/helper/physical_load.hpp" #include 
"duckdb/execution/operator/helper/physical_transaction.hpp" #include "duckdb/execution/operator/helper/physical_update_extensions.hpp" @@ -34,6 +36,12 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalSimple &op) { case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS: return Make(unique_ptr_cast(std::move(op.info)), op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_CONNECT: + return Make(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); + case LogicalOperatorType::LOGICAL_DISCONNECT: + return Make(unique_ptr_cast(std::move(op.info)), + op.estimated_cardinality); default: throw NotImplementedException("Unimplemented type for logical simple operator"); } diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp index ace9b5c1a..a85dbcba8 100644 --- a/src/duckdb/src/execution/physical_plan/plan_window.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_window.cpp @@ -6,6 +6,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -27,27 +28,67 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { const auto input_width = types.size() - op.expressions.size(); types.resize(input_width); - // Identify streaming windows - const bool enable_optimizer = ClientConfig::GetConfig(context).enable_optimizer; + // Identify streaming windows and partitioned windows + using Columns = vector; + const bool enable_optimizer = Settings::Get(context); vector blocking_windows; vector streaming_windows; + vector partitioned_windows; + vector partitioned_columns; for (idx_t expr_idx = 0; expr_idx < op.expressions.size(); expr_idx++) { - if (enable_optimizer && PhysicalStreamingWindow::IsStreamingFunction(context, op.expressions[expr_idx])) { + auto &wexpr = 
op.expressions[expr_idx]->Cast(); + Columns partition_columns; + if (enable_optimizer && PhysicalStreamingWindow::IsStreamingFunction(context, wexpr)) { streaming_windows.push_back(expr_idx); + } else if (!wexpr.Partitions().empty() && + HasSingleValuePartitions(context, wexpr.Partitions(), plan, partition_columns)) { + partitioned_windows.push_back(expr_idx); } else { blocking_windows.push_back(expr_idx); } + // Index these by expr_idx so we don't have to move them... + partitioned_columns.emplace_back(std::move(partition_columns)); } + // Streaming takes priority over partitioning + const bool has_streaming = !streaming_windows.empty(); + if (has_streaming) { + for (auto &expr_idx : partitioned_windows) { + blocking_windows.emplace_back(expr_idx); + } + partitioned_windows.clear(); + } + + // Find the widest partitioning supported by the input + if (!partitioned_windows.empty()) { + vector remaining; + auto widest = *std::max_element(partitioned_windows.begin(), partitioned_windows.end(), + [&](const idx_t lhs, const idx_t rhs) { + return partitioned_columns[lhs].size() < partitioned_columns[rhs].size(); + }); + for (auto &expr_idx : partitioned_windows) { + if (partitioned_columns[expr_idx] == partitioned_columns[widest]) { + remaining.emplace_back(expr_idx); + } else { + blocking_windows.emplace_back(expr_idx); + } + } + remaining.swap(partitioned_windows); + } + + // Restore blocking function order + std::sort(blocking_windows.begin(), blocking_windows.end()); + // Process the window functions by sharing the partition/order definitions unordered_map projection_map; vector> window_expressions; - idx_t streaming_count = 0; + idx_t special_count = 0; auto output_pos = input_width; - while (!blocking_windows.empty() || !streaming_windows.empty()) { - const bool process_blocking = streaming_windows.empty(); - auto &remaining = process_blocking ? blocking_windows : streaming_windows; - streaming_count += process_blocking ? 
0 : 1; + auto &special_windows = streaming_windows.empty() ? partitioned_windows : streaming_windows; + while (!blocking_windows.empty() || !special_windows.empty()) { + const bool process_blocking = special_windows.empty(); + auto &remaining = process_blocking ? blocking_windows : special_windows; + special_count += process_blocking ? 0 : 1; // Find all functions that share the partitioning of the first remaining expression auto over_idx = remaining[0]; @@ -88,14 +129,14 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { // Is there a common sort prefix? const auto prefix = over_expr.GetSharedOrders(wexpr); - if (prefix != MinValue(over_expr.orders.size(), wexpr.orders.size())) { + if (prefix != MinValue(over_expr.OrderBy().size(), wexpr.OrderBy().size())) { unprocessed.emplace_back(expr_idx); continue; } matching.emplace_back(expr_idx); // Switch to the longer prefix - if (prefix < wexpr.orders.size()) { + if (prefix < wexpr.OrderBy().size()) { over_idx = expr_idx; } } @@ -120,14 +161,21 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { } // Chain the new window operator on top of the plan - if (i >= streaming_count) { + if (i >= special_count) { auto &window = Make(types, std::move(select_list), op.estimated_cardinality); window.children.push_back(plan); plan = window; - } else { + } else if (has_streaming) { auto &window = Make(types, std::move(select_list), op.estimated_cardinality); window.children.push_back(plan); plan = window; + } else { + const auto expr_idx = matching[0]; + auto &partitions = partitioned_columns[expr_idx]; + auto &window = + Make(types, std::move(select_list), op.estimated_cardinality, std::move(partitions)); + window.children.push_back(plan); + plan = window; } } diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp index 87c3feac1..95b0d84fb 100644 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ 
b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -30,20 +30,23 @@ PhysicalOperator &PhysicalPlanGenerator::ResolveAndPlan(unique_ptrResolveOperatorTypes(); - profiler.EndPhase(); + { + auto timer = profiler.StartTimer(); + op->ResolveOperatorTypes(); + } // Resolve the column references. - profiler.StartPhase(MetricType::PHYSICAL_PLANNER_COLUMN_BINDING); - ColumnBindingResolver resolver; - resolver.VisitOperator(*op); - profiler.EndPhase(); + { + auto timer = profiler.StartTimer(); + ColumnBindingResolver resolver; + resolver.VisitOperator(*op); + } // Create the main physical plan. - profiler.StartPhase(MetricType::PHYSICAL_PLANNER_CREATE_PLAN); - physical_plan = PlanInternal(*op); - profiler.EndPhase(); + { + auto timer = profiler.StartTimer(); + physical_plan = PlanInternal(*op); + } // Return a reference to the root of this plan. return physical_plan->Root(); @@ -154,6 +157,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalOperator &op) { case LogicalOperatorType::LOGICAL_LOAD: case LogicalOperatorType::LOGICAL_ATTACH: case LogicalOperatorType::LOGICAL_DETACH: + case LogicalOperatorType::LOGICAL_CONNECT: + case LogicalOperatorType::LOGICAL_DISCONNECT: return CreatePlan(op.Cast()); case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: return CreatePlan(op.Cast()); diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp index d52bacdd7..a3a559fd8 100644 --- a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp +++ b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp @@ -104,6 +104,7 @@ struct RadixHTConfig { public: explicit RadixHTConfig(RadixHTGlobalSinkState &sink); + void Reset(); void SetRadixBits(const idx_t &radix_bits_p); bool SetRadixBitsToExternal(); idx_t GetRadixBits() const; @@ -184,6 +185,8 @@ class RadixHTGlobalSinkState : public GlobalSinkState { const idx_t block_alloc_size; //! If any thread has called combine atomic any_combined; + //! 
If any thread has called ht.Abandon() during Sink (meaning uncombined_data may have duplicates) + atomic any_abandoned; //! The radix HT const RadixPartitionedHashTable &radix_ht; @@ -215,7 +218,7 @@ RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const R number_of_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), memory_limit(BufferManager::GetBufferManager(context).GetOperatorMemoryLimit()), block_alloc_size(BufferManager::GetBufferManager(context).GetBlockAllocSize()), any_combined(false), - radix_ht(radix_ht_p), config(*this), stored_allocators_size(0), finalize_done(0), + any_abandoned(false), radix_ht(radix_ht_p), config(*this), stored_allocators_size(0), finalize_done(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0), max_partition_size(0) { // Compute minimum reservation @@ -266,7 +269,7 @@ void RadixHTGlobalSinkState::Destroy() { TupleDataChunkIterator iterator(data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); auto &row_locations = iterator.GetChunkState().row_locations; do { - RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); + RowOperations::DestroyStates(row_state, layout, row_locations); } while (iterator.Next()); data_collection.Reset(); } @@ -278,6 +281,10 @@ RadixHTConfig::RadixHTConfig(RadixHTGlobalSinkState &sink_p) sink_radix_bits(InitialSinkRadixBits()), maximum_sink_radix_bits(MaximumSinkRadixBits()) { } +void RadixHTConfig::Reset() { + sink_radix_bits = InitialSinkRadixBits(); +} + void RadixHTConfig::SetRadixBits(const idx_t &radix_bits_p) { const auto max_bits = MinValue(maximum_sink_radix_bits, ExternalRadixBits(true)); SetRadixBitsInternal(MinValue(radix_bits_p, max_bits), false); @@ -371,6 +378,7 @@ idx_t RadixHTConfig::SinkCapacity() const { class RadixHTLocalSinkState : public LocalSinkState { public: RadixHTLocalSinkState(ClientContext &context, const 
RadixPartitionedHashTable &radix_ht); + void ResetForReuse(const RadixPartitionedHashTable &radix_ht, RadixHTGlobalSinkState &gstate); public: //! Thread-local HT that is re-used after abandoning @@ -383,6 +391,8 @@ class RadixHTLocalSinkState : public LocalSinkState { static constexpr idx_t ADAPTIVITY_THRESHOLD = 1048576; //! Whether we have decided to adapt our strategy bool adapted; + //! Whether this local state has already registered itself as active for the current iteration + bool registered; //! Sink capacity for this thread idx_t local_sink_capacity; @@ -391,7 +401,7 @@ class RadixHTLocalSinkState : public LocalSinkState { }; RadixHTLocalSinkState::RadixHTLocalSinkState(ClientContext &, const RadixPartitionedHashTable &radix_ht) - : adapted(false), local_sink_capacity(DConstants::INVALID_INDEX) { + : adapted(false), registered(false), local_sink_capacity(DConstants::INVALID_INDEX) { // If there are no groups we create a fake group so everything has the same group group_chunk.InitializeEmpty(radix_ht.group_types); if (radix_ht.grouping_set.empty()) { @@ -399,6 +409,29 @@ RadixHTLocalSinkState::RadixHTLocalSinkState(ClientContext &, const RadixPartiti } } +void RadixHTLocalSinkState::ResetForReuse(const RadixPartitionedHashTable &radix_ht, RadixHTGlobalSinkState &gstate) { + group_chunk.Reset(); + if (radix_ht.grouping_set.empty()) { + group_chunk.data[0].Reference(Value::TINYINT(42), count_t(STANDARD_VECTOR_SIZE)); + } + registered = false; + abandoned_data.reset(); + if (!ht) { + adapted = false; + local_sink_capacity = DConstants::INVALID_INDEX; + return; + } + + local_sink_capacity = MaxValue(gstate.config.sink_capacity, ht->Capacity()); + ht->ResetForNewIteration(local_sink_capacity, gstate.config.GetRadixBits()); + if (gstate.number_of_threads > RadixHTConfig::GROW_STRATEGY_THREAD_THRESHOLD) { + ht->EnableHLL(true); + adapted = false; + } else { + adapted = true; + } +} + unique_ptr RadixPartitionedHashTable::GetGlobalSinkState(ClientContext &context) 
const { return make_uniq(context, *this); } @@ -407,6 +440,34 @@ unique_ptr RadixPartitionedHashTable::GetLocalSinkState(Executio return make_uniq(context.client, *this); } +void RadixPartitionedHashTable::ResetGlobalSinkState(ClientContext &context, GlobalSinkState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + gstate.Destroy(); + gstate.temporary_memory_state->SetMinimumReservation(gstate.minimum_reservation); + gstate.temporary_memory_state->SetRemainingSizeAndUpdateReservation(context, gstate.minimum_reservation); + gstate.finalized = false; + gstate.external = false; + gstate.active_threads = 0; + gstate.any_combined = false; + gstate.any_abandoned = false; + gstate.config.Reset(); + gstate.uncombined_data.reset(); + gstate.stored_allocators.clear(); + gstate.stored_allocators_size = 0; + gstate.partitions.clear(); + gstate.finalize_done = 0; + gstate.scan_pin_properties = TupleDataPinProperties::DESTROY_AFTER_DONE; + gstate.count_before_combining = 0; + gstate.max_partition_size = 0; +} + +void RadixPartitionedHashTable::ResetLocalSinkState(ExecutionContext &context, GlobalSinkState &gstate_p, + LocalSinkState &lstate_p) const { + auto &gstate = gstate_p.Cast(); + auto &lstate = lstate_p.Cast(); + lstate.ResetForReuse(*this, gstate); +} + void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const { idx_t chunk_index = 0; // Populate the group_chunk @@ -416,9 +477,9 @@ void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataC D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); auto &bound_ref_expr = group->Cast(); // Reference from input_chunk[group.index] -> group_chunk[chunk_index] - group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.index]); + group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.Index()]); } - group_chunk.SetCardinality(input_chunk.size()); + group_chunk.SetChildCardinality(input_chunk.size()); // the fake 
group for empty grouping_set was created with v_size=STANDARD_VECTOR_SIZE - resize to match if (grouping_set.empty()) { FlatVector::SetSize(group_chunk.data[0], count_t(input_chunk.size())); @@ -546,7 +607,10 @@ void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk // Using grow strategy, so won't ever adapt lstate.adapted = true; } + } + if (!lstate.registered) { gstate.active_threads++; + lstate.registered = true; } auto &group_chunk = lstate.group_chunk; @@ -571,6 +635,7 @@ void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk // This only works because we never resize the HT // We don't do this when running with 1 or 2 threads, it only makes sense when there's many threads ht.Abandon(); + gstate.any_abandoned = true; } // Check if we need to repartition @@ -615,9 +680,6 @@ void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkSta auto aggregate_allocator = ht.GetAggregateAllocator(); - // Eagerly destroy the HT - lstate.ht.reset(); - const annotated_lock_guard guard {gstate.lock}; D_ASSERT(!gstate.finalized); if (gstate.uncombined_data) { @@ -638,8 +700,10 @@ void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState auto &uncombined_data = *gstate.uncombined_data; gstate.count_before_combining = uncombined_data.Count(); - // If true there is no need to combine, it was all done by a single thread in a single HT - const auto single_ht = !gstate.external && gstate.active_threads == 1 && gstate.number_of_threads == 1; + // If true there is no need to combine, it was all done by a single thread in a single HT. + // This is the case when only one thread contributed data and the HT never overflowed its + // capacity (which would have caused Abandon() to be called, creating duplicates). 
+ const auto single_ht = !gstate.external && gstate.active_threads == 1 && !gstate.any_abandoned; auto &uncombined_partition_data = uncombined_data.GetPartitions(); const auto n_partitions = uncombined_partition_data.size(); @@ -730,6 +794,7 @@ enum class RadixHTScanStatus : uint8_t { INIT, IN_PROGRESS, DONE }; class RadixHTLocalSourceState : public LocalSourceState { public: explicit RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht); + void ResetForReuse(); public: //! Do the work this thread has been assigned @@ -771,6 +836,13 @@ unique_ptr RadixPartitionedHashTable::GetLocalSourceState(Exec return make_uniq(context, *this); } +void RadixPartitionedHashTable::ResetGlobalSourceState(ClientContext &context, GlobalSourceState &gstate_p) const { + auto &gstate = gstate_p.Cast(); + gstate.finished = false; + gstate.task_idx = 0; + gstate.task_done = 0; +} + RadixHTGlobalSourceState::RadixHTGlobalSourceState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht) : context(context_p), finished(false), task_idx(0), task_done(0) { for (column_t column_id = 0; column_id < radix_ht.group_types.size(); column_id++) { @@ -809,8 +881,7 @@ SourceResultType RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &si } RadixHTLocalSourceState::RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht) - : task(RadixHTSourceTaskType::NO_TASK), task_idx(DConstants::INVALID_INDEX), scan_status(RadixHTScanStatus::DONE), - layout(radix_ht.GetLayout().Copy()), aggregate_allocator(BufferAllocator::Get(context.client)), + : layout(radix_ht.GetLayout().Copy()), aggregate_allocator(BufferAllocator::Get(context.client)), row_state(aggregate_allocator) { auto &allocator = BufferAllocator::Get(context.client); auto scan_chunk_types = radix_ht.group_types; @@ -818,6 +889,23 @@ RadixHTLocalSourceState::RadixHTLocalSourceState(ExecutionContext &context, cons scan_chunk_types.push_back(aggr_type); } 
scan_chunk.Initialize(allocator, scan_chunk_types); + ResetForReuse(); +} + +void RadixHTLocalSourceState::ResetForReuse() { + task = RadixHTSourceTaskType::NO_TASK; + task_idx = DConstants::INVALID_INDEX; + ht.reset(); + scan_status = RadixHTScanStatus::DONE; + aggregate_allocator.Reset(); + row_state.addresses.reset(); + scan_state.Reset(); + scan_chunk.Reset(); +} + +void RadixPartitionedHashTable::ResetLocalSourceState(ExecutionContext &context, LocalSourceState &lstate_p) const { + auto &lstate = lstate_p.Cast(); + lstate.ResetForReuse(); } void RadixHTLocalSourceState::ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, @@ -921,7 +1009,7 @@ void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSo RowOperations::FinalizeStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk, group_cols); if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE && layout.HasDestructor()) { - RowOperations::DestroyStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk.size()); + RowOperations::DestroyStates(row_state, layout, scan_state.chunk_state.row_locations); } auto &radix_ht = sink.radix_ht; @@ -942,7 +1030,6 @@ void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSo chunk.data[radix_ht.op.GroupCount() + radix_ht.op.aggregates.size() + i].Reference(radix_ht.grouping_values[i], count_t(scan_chunk.size())); } - chunk.SetCardinality(scan_chunk); D_ASSERT(chunk.size() != 0); } @@ -977,7 +1064,7 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D // Special case hack to sort out aggregating from empty intermediates for aggregations without groups D_ASSERT(chunk.ColumnCount() == null_groups.size() + op.aggregates.size() + op.grouping_functions.size()); // For each column in the aggregates, set to initial state - chunk.SetCardinality(1); + chunk.SetChildCardinality(1); for (auto null_group : null_groups) { 
ConstantVector::SetNull(chunk.data[null_group], count_t(1)); } @@ -985,17 +1072,17 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D for (idx_t i = 0; i < op.aggregates.size(); i++) { D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); auto &aggr = op.aggregates[i]->Cast(); - auto aggr_state = - make_unsafe_uniq_array_uninitialized(aggr.function.GetStateSizeCallback()(aggr.function)); - aggr.function.GetStateInitCallback()(aggr.function, aggr_state.get()); + auto aggr_state = make_unsafe_uniq_array_uninitialized( + aggr.Function().GetStateSizeCallback()(aggr.Function())); + aggr.Function().GetStateInitCallback()(aggr.Function(), aggr_state.get()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get())), count_t(1)); auto &agg_result = chunk.data[null_groups.size() + i]; - aggr.function.GetStateFinalizeCallback()(state_vector, aggr_input_data, agg_result, 1, 0); + aggr.Function().GetStateFinalizeCallback()(state_vector, aggr_input_data, agg_result, 1, 0); FlatVector::SetSize(agg_result, count_t(1)); - if (aggr.function.HasStateDestructorCallback()) { - aggr.function.GetStateDestructorCallback()(state_vector, aggr_input_data, 1); + if (aggr.Function().HasStateDestructorCallback()) { + aggr.Function().GetStateDestructorCallback()(state_vector, aggr_input_data, 1); } } // Place the grouping values (all the groups of the grouping_set condensed into a single value) diff --git a/src/duckdb/src/execution/sample/reservoir_sample.cpp b/src/duckdb/src/execution/sample/reservoir_sample.cpp index 2f5721aaa..a3fd1f013 100644 --- a/src/duckdb/src/execution/sample/reservoir_sample.cpp +++ b/src/duckdb/src/execution/sample/reservoir_sample.cpp @@ -148,7 +148,6 @@ unique_ptr ReservoirSample::GetChunk() { ret->Initialize(allocator, reservoir_types, STANDARD_VECTOR_SIZE); 
ret->Slice(Chunk(), ret_sel, return_chunk_size); - ret->SetCardinality(return_chunk_size); return ret; } @@ -488,7 +487,7 @@ void ReservoirSample::EvictOverBudgetSamples() { MinValue(FIXED_SAMPLE_SIZE, static_cast(SAVE_PERCENTAGE * static_cast(GetTuplesSeen()))); if (num_samples_to_keep <= 0) { - reservoir_chunk->chunk.SetCardinality(0); + reservoir_chunk->chunk.SetChildCardinality(0); return; } @@ -537,7 +536,6 @@ void ReservoirSample::EvictOverBudgetSamples() { UpdateSampleAppend(new_reservoir_chunk->chunk, reservoir_chunk->chunk, new_sel, num_samples_to_keep); // set the cardinality - new_reservoir_chunk->chunk.SetCardinality(num_samples_to_keep); reservoir_chunk = std::move(new_reservoir_chunk); sel_size = num_samples_to_keep; base_reservoir_sample->UpdateMinWeightThreshold(); @@ -553,7 +551,6 @@ void ReservoirSample::ExpandSerializedSample() { auto copy_count = reservoir_chunk->chunk.size(); SelectionVector tmp_sel = SelectionVector(static_cast(0), copy_count); UpdateSampleAppend(new_res_chunk->chunk, reservoir_chunk->chunk, tmp_sel, copy_count); - new_res_chunk->chunk.SetCardinality(copy_count); std::swap(reservoir_chunk, new_res_chunk); } @@ -710,7 +707,7 @@ void ReservoirSample::UpdateSampleAppend(DataChunk &this_, DataChunk &other, Sel VectorOperations::Copy(other.data[i], this_.data[i], other_sel, append_count, 0, this_.size()); } } - this_.SetCardinality(new_size); + this_.SetChildCardinality(new_size); } void ReservoirSample::AddToReservoir(DataChunk &chunk) { @@ -744,7 +741,6 @@ void ReservoirSample::AddToReservoir(DataChunk &chunk) { } slice->Initialize(Allocator::DefaultAllocator(), types, samples_remaining); slice->Slice(chunk, input_sel, samples_remaining); - slice->SetCardinality(samples_remaining); AddToReservoir(*slice); return; } @@ -857,7 +853,7 @@ void ReservoirSamplePercentage::AddToReservoir(DataChunk &input) { current_sample->AddToReservoir(new_chunk); } else { input.Flatten(); - input.SetCardinality(append_to_current_sample_count); + 
input.SetChildCardinality(append_to_current_sample_count); current_sample->AddToReservoir(input); } } diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp index 3c46bd3a0..aae3baf67 100644 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -1,3 +1,4 @@ +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" @@ -7,11 +8,6 @@ namespace duckdb { namespace { struct BaseCountFunction { - template - static void Initialize(STATE &state) { - state = 0; - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { target += source; @@ -136,8 +132,90 @@ struct CountFunction : public BaseCountFunction { } } + template + static void CountClusteredRuns(const ClusteredAggr &cs, const ValidityMask &validity, INDEXER indexer) { + const bool all_valid = !validity.CanHaveNull(); + idx_t pos = 0; + for (idx_t r = 0; r < cs.n_group_runs; r++) { + auto &state = *reinterpret_cast(cs.group_runs[r].state); + const auto *run_sel = cs.group_runs[r].sel; + const auto run_count = cs.group_runs[r].count; + if (all_valid) { + state += UnsafeNumericCast(run_count); + } else { + for (idx_t k = 0; k < run_count; k++) { + state += validity.RowIsValidUnsafe(indexer(run_sel, pos, k)); + } + } + pos += run_count; + } + } + + static void CountScatterClusteredGeneric(Vector &input, const ClusteredAggr &cs, idx_t count) { + if (input.GetVectorType() == VectorType::FLAT_VECTOR) { + auto &validity = FlatVector::Validity(input); + CountClusteredRuns(cs, validity, + [&](const sel_t *run_sel, idx_t, idx_t k) { return run_sel ? 
run_sel[k] : k; }); + return; + } + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + CountClusteredRuns(cs, idata.validity, [&](const sel_t *run_sel, idx_t, idx_t k) { + return idata.sel->get_index(run_sel ? run_sel[k] : k); + }); + } + + template + static void CountClusteredDict(Vector &input, const ClusteredAggr &clustered, idx_t count, + const sel_t *cluster_iter = nullptr) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + if constexpr (SIMPLE_DICT) { + CountClusteredRuns(clustered, idata.validity, + [&](const sel_t *, idx_t pos, idx_t k) { return cluster_iter[pos + k]; }); + } else { + CountClusteredRuns(clustered, idata.validity, [&](const sel_t *run_sel, idx_t, idx_t k) { + return idata.sel->get_index(run_sel ? run_sel[k] : k); + }); + } + } + + static void CountClusterUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, + const ClusteredAggr &clustered, idx_t count) { + D_ASSERT(input_count == 1); + if (inputs[0].GetVectorType() == VectorType::CONSTANT_VECTOR) { + if (ConstantVector::IsNull(inputs[0])) { + return; + } + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + state += UnsafeNumericCast(clustered.group_runs[r].count); + } + return; + } + auto *cluster_iter = clustered.ClusterIter(inputs[0], count); + if (cluster_iter) { + CountClusteredDict(inputs[0], clustered, count, cluster_iter); + return; + } + CountScatterClusteredGeneric(inputs[0], clustered, count); + } + static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, idx_t count) { + // COUNT(col) clustered fast path: add run counts when the input is all-valid, + // otherwise count validity bits over each run. For simple DICT input, ClusterIter + // pre-composes the dict sel once for the whole chunk. 
+ if (aggr_input_data.clustered) { + auto &cs = *aggr_input_data.clustered; + auto *cluster_iter = cs.ClusterIter(inputs[0], count); + if (cluster_iter) { + CountClusteredDict(inputs[0], cs, count, cluster_iter); + return; + } + CountScatterClusteredGeneric(inputs[0], cs, count); + return; + } auto &input = inputs[0]; if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { auto sdata = FlatVector::GetDataMutable(states); @@ -176,65 +254,19 @@ struct CountFunction : public BaseCountFunction { } } } - - static inline void CountUpdateLoop(STATE &result, const ValidityMask &mask, idx_t count, - const SelectionVector &sel_vector) { - if (mask.CannotHaveNull()) { - // no NULL values - result += UnsafeNumericCast(count); - return; - } - for (idx_t i = 0; i < count; i++) { - auto idx = sel_vector.get_index(i); - if (mask.RowIsValid(idx)) { - result++; - } - } - } - - static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { - auto &input = inputs[0]; - auto &result = *reinterpret_cast(state_p); - switch (input.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - if (!ConstantVector::IsNull(input)) { - // if the constant is not null increment the state - result += UnsafeNumericCast(count); - } - break; - } - case VectorType::FLAT_VECTOR: { - CountFlatUpdateLoop(result, FlatVector::ValidityMutable(input), count); - break; - } - case VectorType::SEQUENCE_VECTOR: { - // sequence vectors cannot have NULL values - result += UnsafeNumericCast(count); - break; - } - default: { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(idata); - CountUpdateLoop(result, idata.validity, count, *idata.sel); - break; - } - } - } }; -LogicalType GetCountStateType(const BoundAggregateFunction &function) { - child_list_t children; - children.emplace_back("count", LogicalType::BIGINT); - return LogicalType::STRUCT(std::move(children)); +AggregateStateLayout 
GetCountStateType(const BoundAggregateFunction &function) { + return AggregateStateLayout(LogicalType::BIGINT, AlignValue(function.GetStateSizeCallback()(function))); } unique_ptr CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input) { if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { // count on a column without null values: use count star - expr.function.ReplaceImplementation(CountStarFun::GetFunction()); - expr.function.SetName("count_star"); - expr.children.clear(); + expr.FunctionMutable().ReplaceImplementation(CountStarFun::GetFunction()); + expr.FunctionMutable().SetName("count_star"); + expr.GetChildrenMutable().clear(); } return nullptr; } @@ -246,8 +278,8 @@ AggregateFunction CountFunctionBase::GetFunction() { AggregateFunction::StateInitialize, CountFunction::CountScatter, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, - FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); - fun.name = "count"; + FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountClusterUpdate); + fun.SetName("count"); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); fun.SetStructStateExport(GetCountStateType); fun.SetStatisticsCallback(CountPropagateStats); @@ -256,11 +288,10 @@ AggregateFunction CountFunctionBase::GetFunction() { AggregateFunction CountStarFun::GetFunction() { auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); - fun.name = "count_star"; + fun.SetName("count_star"); fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); fun.SetWindowBatchCallback(CountStarFunction::Window); - fun.SetStructStateExport(GetCountStateType); return fun; } diff --git a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp index 22a935f9d..7d987aa60 100644 --- 
a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp +++ b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp @@ -1,3 +1,4 @@ +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" @@ -10,23 +11,38 @@ namespace { template struct FirstState { + static constexpr const char *STATE_NAMES[] = {"value", "is_set", "is_null"}; + using STATE_TYPE = StructStateType; + + using VALUE_TYPE = T; T value; bool is_set; bool is_null; }; struct FirstFunctionBase { - template - static void Initialize(STATE &state) { - state.is_set = false; - state.is_null = false; - } - static bool IgnoreNull() { return false; } }; +template +static inline void ScanClusterRange(idx_t pos, idx_t end, FUNC &&func) { + if constexpr (LAST) { + for (idx_t k = end; k > pos; k--) { + if (func(k - 1)) { + return; + } + } + } else { + for (idx_t k = pos; k < end; k++) { + if (func(k)) { + return; + } + } + } +} + template struct FirstFunction : public FirstFunctionBase { template @@ -51,6 +67,41 @@ struct FirstFunction : public FirstFunctionBase { Operation(state, input, unary_input); } + // Clustered early-exit: first/any_value scan forward and stop at the first valid value; + // last scans backward and stops at the first valid value from the end. + template + static void ClusteredOpInternal(STATE_TYPE &state, const INPUT_TYPE *vals, const sel_t *sel, + const SelectionVector &isel, const ValidityMask &validity, idx_t pos, idx_t end) { + if (!LAST && state.is_set) { + return; + } + ScanClusterRange(pos, end, [&](idx_t k) { + auto idx = isel.get_index(sel ? 
sel[k] : k); + if (ALL_VALID || validity.RowIsValidUnsafe(idx)) { + state.is_set = true; + state.is_null = false; + state.value = vals[idx]; + return true; + } + if (!SKIP_NULLS) { + state.is_set = true; + state.is_null = true; + return true; + } + return false; + }); + } + + template + static void ClusteredOp(STATE_TYPE &state, const INPUT_TYPE *vals, AggregateUnaryInput &input, const sel_t *sel, + const SelectionVector &isel, const ValidityMask &validity, idx_t pos, idx_t end) { + if (validity.CanHaveNull()) { + ClusteredOpInternal(state, vals, sel, isel, validity, pos, end); + } else { + ClusteredOpInternal(state, vals, sel, isel, validity, pos, end); + } + } + template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { if (!target.is_set) { @@ -129,6 +180,32 @@ struct FirstFunctionString : FirstFunctionStringBase { Operation(state, input, unary_input); } + template + static void ClusteredOpInternal(STATE_TYPE &state, const INPUT_TYPE *vals, AggregateUnaryInput &input, + const sel_t *sel, const SelectionVector &isel, const ValidityMask &validity, + idx_t pos, idx_t end) { + if (!LAST && state.is_set) { + return; + } + ScanClusterRange(pos, end, [&](idx_t k) { + auto idx = isel.get_index(sel ? sel[k] : k); + bool is_null = ALL_VALID ? 
false : !validity.RowIsValidUnsafe(idx); + FirstFunctionStringBase::template SetValue(state, input.input, vals[idx], + is_null); + return state.is_set; + }); + } + + template + static void ClusteredOp(STATE_TYPE &state, const INPUT_TYPE *vals, AggregateUnaryInput &input, const sel_t *sel, + const SelectionVector &isel, const ValidityMask &validity, idx_t pos, idx_t end) { + if (validity.CanHaveNull()) { + ClusteredOpInternal(state, vals, input, sel, isel, validity, pos, end); + } else { + ClusteredOpInternal(state, vals, input, sel, isel, validity, pos, end); + } + } + template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (!state.is_set || state.is_null) { @@ -176,11 +253,11 @@ struct FirstVectorFunction : FirstFunctionStringBase { OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); // slice with a selection vector and generate sort keys if (assign_count == count) { - CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, sort_key); + CreateSortKeyHelpers::CreateSortKey(input, modifiers, sort_key); } else { SelectionVector sel(assign_sel, STANDARD_VECTOR_SIZE); Vector sliced_input(input, sel, assign_count); - CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); + CreateSortKeyHelpers::CreateSortKey(sliced_input, modifiers, sort_key); } auto sort_key_data = FlatVector::GetData(sort_key); @@ -219,54 +296,33 @@ struct FirstVectorFunction : FirstFunctionStringBase { } }; -LogicalType GetFirstStateType(const BoundAggregateFunction &function) { - child_list_t child_types; - LogicalType value_type = function.GetArguments()[0]; - child_types.emplace_back("value", value_type); - child_types.emplace_back("is_set", LogicalType::BOOLEAN); - child_types.emplace_back("is_null", LogicalType::BOOLEAN); - return LogicalType::STRUCT(std::move(child_types)); -} - template -void FirstFunctionSimpleUpdate(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t 
input_count, - data_ptr_t state, idx_t count) { - auto agg_state = reinterpret_cast *>(state); - if (LAST) { - // For LAST, iterate backward within each batch to find the last value - // This saves iterating through all elements when we only need the last one - D_ASSERT(input_count == 1); - UnifiedVectorFormat idata; - inputs[0].ToUnifiedFormat(idata); - auto input_data = UnifiedVectorFormat::GetData(idata); - - for (idx_t i = count; i-- > 0;) { - const auto idx = idata.sel->get_index(i); - const auto row_valid = idata.validity.RowIsValid(idx); - if (SKIP_NULLS && !row_valid) { - continue; - } - // Found the last value in this batch - update state and exit - agg_state->is_set = true; - agg_state->is_null = !row_valid; - if (row_valid) { - agg_state->value = input_data[idx]; - } - break; - } - // If we get here with SKIP_NULLS, all values were NULL - keep previous state - } else if (!agg_state->is_set) { - // For FIRST, this skips looping over the input once the aggregate state has been set - AggregateFunction::UnaryUpdate, T, FirstFunction>(inputs, aggregate_input_data, - input_count, state, count); +void FirstFunctionClusterUpdate(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count, + const ClusteredAggr &clustered, idx_t count) { + D_ASSERT(input_count == 1); + UnifiedVectorFormat idata; + inputs[0].ToUnifiedFormat(idata); + auto input_data = UnifiedVectorFormat::GetData(idata); + AggregateUnaryInput unary_input(aggregate_input_data, idata.validity); + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast *>(clustered.group_runs[r].state); + const auto *run_sel = clustered.group_runs[r].sel; + const auto run_count = clustered.group_runs[r].count; + FirstFunction::template ClusteredOp, FirstFunction>( + state, input_data, unary_input, run_sel, *idata.sel, idata.validity, 0, run_count); } } template AggregateFunction GetFirstAggregateTemplated(const LogicalType &type) { - auto result = 
AggregateFunction::UnaryAggregate, T, T, FirstFunction>(type, type); - result.SetStateSimpleUpdateCallback(FirstFunctionSimpleUpdate); - result.SetStructStateExport(GetFirstStateType); + auto result = + AggregateFunction({type}, type, AggregateFunction::StateSize>, + AggregateFunction::StateInitialize, FirstFunction>, + AggregateFunction::UnaryScatterUpdate, T, FirstFunction>, + AggregateFunction::StateCombine, FirstFunction>, + AggregateFunction::StateFinalize, T, FirstFunction>, + FunctionNullHandling::DEFAULT_NULL_HANDLING, FirstFunctionClusterUpdate); + AggregateFunction::WireStructStateType>(result); return result; } @@ -326,15 +382,11 @@ AggregateFunction GetFirstFunction(const LogicalType &type) { return GetFirstAggregateTemplated(type); case PhysicalType::VARCHAR: if (LAST) { - auto fun = AggregateFunction::UnaryAggregateDestructor, string_t, string_t, - FirstFunctionString>(type, type); - fun.SetStructStateExport(GetFirstStateType); - return fun; + return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, + FirstFunctionString>(type, type); } else { - auto fun = AggregateFunction::UnaryAggregate, string_t, string_t, - FirstFunctionString>(type, type); - fun.SetStructStateExport(GetFirstStateType); - return fun; + return AggregateFunction::UnaryAggregate, string_t, string_t, + FirstFunctionString>(type, type); } default: { using OP = FirstVectorFunction; @@ -342,8 +394,9 @@ AggregateFunction GetFirstFunction(const LogicalType &type) { auto fun = AggregateFunction( {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, OP::Update, AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, - nullptr, OP::Bind, LAST ? AggregateFunction::StateDestroy : nullptr, nullptr, nullptr); - fun.SetStructStateExport(GetFirstStateType); + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), OP::Bind, + LAST ? 
AggregateFunction::StateDestroy : nullptr, nullptr, nullptr); + AggregateFunction::WireStructStateType(fun); return fun; } } @@ -387,22 +440,24 @@ unique_ptr BindFirst(BindAggregateFunctionInput &input) { template void AddFirstOperator(AggregateFunctionSet &set) { set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindDecimalFirst)); + nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindDecimalFirst)); set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindFirst)); + FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + BindFirst)); } } // namespace AggregateFunction FirstFunctionGetter::GetFunction(const LogicalType &type) { auto fun = GetFirstFunction(type); - fun.name = "first"; + fun.SetName("first"); return fun; } AggregateFunction LastFunctionGetter::GetFunction(const LogicalType &type) { auto fun = GetFirstFunction(type); - fun.name = "last"; + fun.SetName("last"); return fun; } diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index 1ae833225..3fa331fc0 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -1,5 +1,7 @@ #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/aggregate_operators.hpp" #include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/types/null_value.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ -21,8 +23,10 @@ namespace { template struct MinMaxState { + using value_type = T; + using STATE_TYPE = OptionalStateType; T value; - bool isset; + bool is_set; }; template @@ -62,27 +66,17 @@ 
static AggregateFunction GetUnaryAggregate(const LogicalType &type) { } struct MinMaxBase { - template - static void Initialize(STATE &state) { - state.isset = false; - } - template static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, idx_t count) { - if (!state.isset) { - OP::template Assign(state, input, unary_input.input); - state.isset = true; - } else { - OP::template Execute(state, input, unary_input.input); - } + OP::template Operation(state, input, unary_input); } template static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (!state.isset) { + if (!state.is_set) { OP::template Assign(state, input, unary_input.input); - state.isset = true; + state.is_set = true; } else { OP::template Execute(state, input, unary_input.input); } @@ -93,71 +87,65 @@ struct MinMaxBase { } }; -struct NumericMinMaxBase : public MinMaxBase { +template +struct NumericMinMaxBase : public MinMaxBase, public ClusteredStateCopy { template static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &) { state.value = input; } - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); + template + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input) { + if (!local.is_set) { + local.value = input; + local.is_set = true; } else { - target = state.value; + local.value = REDUCE_OP::template Operation(local.value, input); } } -}; -struct MinOperation : public NumericMinMaxBase { template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (LessThan::Operation(input, state.value)) { - state.value = input; + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input, idx_t count) { + if (count != 0) { + UpdateClusteredLocal(local, input); } } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - 
if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - target = source; - } else if (GreaterThan::Operation(target.value, source.value)) { - target.value = source.value; - } - } -}; - -struct MaxOperation : public NumericMinMaxBase { template static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (GreaterThan::Operation(input, state.value)) { - state.value = input; + state.value = REDUCE_OP::template Operation(state.value, input); + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.is_set) { + finalize_data.ReturnNull(); + } else { + target = state.value; } } template static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.isset) { - // source is NULL, nothing to do + if (!source.is_set) { return; } - if (!target.isset) { - // target is NULL, use source value directly + if (!target.is_set) { target = source; - } else if (LessThan::Operation(target.value, source.value)) { - target.value = source.value; + } else { + using value_type = decltype(target.value); + target.value = REDUCE_OP::template Operation(target.value, source.value); } } }; -struct MinMaxStringState : MinMaxState { +using MinOperation = NumericMinMaxBase; +using MaxOperation = NumericMinMaxBase; + +struct MinMaxStringState { + string_t value; + bool is_set; void Destroy() { - if (isset && !value.IsInlined()) { + if (is_set && !value.IsInlined()) { delete[] value.GetData(); } } @@ -171,7 +159,7 @@ struct MinMaxStringState : MinMaxState { // non-inlined string, need to allocate space for it somehow auto len = input.GetSize(); char *ptr; - if (!isset || value.GetSize() < len) { + if (!is_set || value.GetSize() < len) { // we cannot fit this into the current slot - destroy it and re-allocate Destroy(); ptr = new char[len]; @@ -199,7 +187,7 @@ struct StringMinMaxBase : public MinMaxBase { 
template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { target = StringVector::AddStringOrBlob(finalize_data.result, state.value); @@ -208,37 +196,32 @@ struct StringMinMaxBase : public MinMaxBase { template static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.isset) { + if (!source.is_set) { // source is NULL, nothing to do return; } - if (!target.isset) { + if (!target.is_set) { // target is NULL, use source value directly Assign(target, source.value, input_data); - target.isset = true; + target.is_set = true; } else { OP::template Execute(target, source.value, input_data); } } }; -struct MinOperationString : public StringMinMaxBase { +template +struct StringMinMaxOperation : public StringMinMaxBase { template static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (LessThan::Operation(input, state.value)) { + if (COMPARE::template Operation(input, state.value)) { Assign(state, input, input_data); } } }; -struct MaxOperationString : public StringMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (GreaterThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } -}; +using MinOperationString = StringMinMaxOperation; +using MaxOperationString = StringMinMaxOperation; template struct VectorMinMaxBase { @@ -248,11 +231,6 @@ struct VectorMinMaxBase { return true; } - template - static void Initialize(STATE &state) { - state.isset = false; - } - template static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { state.Destroy(); @@ -265,9 +243,9 @@ struct VectorMinMaxBase { template static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (!state.isset) { + if (!state.is_set) { Assign(state, input, input_data); - state.isset = 
true; + state.is_set = true; return; } if (LessThan::Operation(input, state.value)) { @@ -277,7 +255,7 @@ struct VectorMinMaxBase { template static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.isset) { + if (!source.is_set) { // source is NULL, nothing to do return; } @@ -286,7 +264,7 @@ struct VectorMinMaxBase { template static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.isset) { + if (!state.is_set) { finalize_data.ReturnNull(); } else { CreateSortKeyHelpers::DecodeSortKey(state.value, finalize_data.result, finalize_data.result_idx, @@ -312,7 +290,8 @@ static AggregateFunction GetMinMaxFunction(const LogicalType &type) { return AggregateFunction( {type}, LogicalType::BLOB, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateSortKeyHelpers::UnaryUpdate, - AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, + AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), OP::Bind, AggregateFunction::StateDestroy); } @@ -337,18 +316,22 @@ unique_ptr BindMinMax(BindAggregateFunctionInput &input) { auto &context = input.GetClientContext(); auto &function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); + auto input_type = arguments[0]->GetReturnType(); - // We should also push collations for non-VARCHAR here, but we aren't ready for it yet (see internal #8704) - const auto collation = arguments[0]->GetReturnType().id() == LogicalTypeId::VARCHAR && - (!StringType::GetCollation(arguments[0]->GetReturnType()).empty() || - !Settings::Get(context).empty()); + // The generic non-VARCHAR collation path is not ready yet (see internal #8704). BIT uses an explicit + // binary-comparable key so min/max follows the same logical order as comparisons and ORDER BY. 
+ const auto varchar_collation = + input_type.id() == LogicalTypeId::VARCHAR && + (!StringType::GetCollation(input_type).empty() || !Settings::Get(context).empty()); + const auto collation = input_type.id() == LogicalTypeId::BIT || varchar_collation; auto collated_arg = collation ? arguments[0]->Copy() : nullptr; if (collation && ExpressionBinder::PushCollation(context, collated_arg, collated_arg->GetReturnType())) { // If aggr function is min/max and uses collations, replace bound_function with arg_min/arg_max // to make sure the result's correctness. string function_name = function.GetName() == "min" ? "arg_min" : "arg_max"; QueryErrorContext error_context; - auto func = Catalog::GetEntry(context, "", "", function_name, + auto func = Catalog::GetEntry(context, Identifier(), Identifier(), + Identifier(function_name), OnEntryNotFound::RETURN_NULL, error_context); if (!func) { throw NotImplementedException( @@ -376,7 +359,6 @@ unique_ptr BindMinMax(BindAggregateFunctionInput &input) { return make_uniq(); } - auto input_type = arguments[0]->GetReturnType(); if (input_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } @@ -385,22 +367,24 @@ unique_ptr BindMinMax(BindAggregateFunctionInput &input) { auto state_export_type = function.GetStateTypeCallback(); auto minmax_func = GetMinMaxOperator(input_type); - minmax_func.SetStructStateExport(state_export_type); + if (state_export_type) { + minmax_func.SetStructStateExport(state_export_type); + } minmax_func.SetName(std::move(name)); minmax_func.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); minmax_func.SetDistinctDependent(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT); auto expr = minmax_func.Bind(context, std::move(arguments)); - arguments = std::move(expr->children); + arguments = std::move(expr->GetChildrenMutable()); - function = std::move(expr->function); - return std::move(expr->bind_info); + function = std::move(expr->FunctionMutable()); + return 
std::move(expr->BindInfoMutable()); } template AggregateFunction GetMinMaxOperator(const string &name) { - return AggregateFunction(name, {LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindMinMax); + return AggregateFunction(Identifier(name), {LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, BindMinMax); } } // namespace @@ -447,9 +431,9 @@ void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_ UnifiedVectorFormat n_format; UnifiedVectorFormat state_format; - auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); + auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(); - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, true); + STATE::VAL_TYPE::PrepareData(val_vector, val_extra_state, val_format, true); n_vector.ToUnifiedFormat(n_format); state_vector.ToUnifiedFormat(state_format); @@ -548,14 +532,16 @@ unique_ptr MinMaxNBind(BindAggregateFunctionInput &input) { template AggregateFunction GetMinMaxNFunction() { return AggregateFunction({LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, - nullptr, nullptr, nullptr, nullptr, nullptr, MinMaxNBind, nullptr); + nullptr, nullptr, nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, + MinMaxNBind, nullptr); } -LogicalType GetExportStateType(const BoundAggregateFunction &function) { - auto struct_children_types = child_list_t {}; +AggregateStateLayout GetExportStateType(const BoundAggregateFunction &function) { + child_list_t struct_children_types; struct_children_types.emplace_back("value", function.GetReturnType()); - struct_children_types.emplace_back("isset", LogicalType::BOOLEAN); - return LogicalType::STRUCT(std::move(struct_children_types)); + struct_children_types.emplace_back("is_set", LogicalType::BOOLEAN); + return 
AggregateStateLayout(LogicalType::STRUCT(std::move(struct_children_types)), + AlignValue(function.GetStateSizeCallback()(function))); } } // namespace @@ -564,14 +550,14 @@ LogicalType GetExportStateType(const BoundAggregateFunction &function) { //---------------------------------------------------s AggregateFunctionSet MinFun::GetFunctions() { AggregateFunctionSet min("min"); - min.AddFunction(MinFunction::GetFunction().SetStructStateExport(GetExportStateType)); + min.AddFunction(MinFunction::GetFunction()); min.AddFunction(GetMinMaxNFunction().SetStructStateExport(GetExportStateType)); return min; } AggregateFunctionSet MaxFun::GetFunctions() { AggregateFunctionSet max("max"); - max.AddFunction(MaxFunction::GetFunction().SetStructStateExport(GetExportStateType)); + max.AddFunction(MaxFunction::GetFunction()); max.AddFunction(GetMinMaxNFunction().SetStructStateExport(GetExportStateType)); return max; } diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index 62a6c6ceb..6f586bfd5 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -1,3 +1,4 @@ +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/vector/flat_vector.hpp" @@ -80,11 +81,13 @@ struct SortedAggregateBindData : public FunctionData { } SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) - : SortedAggregateBindData(context, expr.children, expr.function, expr.bind_info, expr.order_bys->orders) { + : SortedAggregateBindData(context, expr.GetChildrenMutable(), expr.FunctionMutable(), expr.BindInfoMutable(), + expr.GetOrderBysMutable()->orders) { } SortedAggregateBindData(ClientContext &context, BoundWindowExpression &expr) - : SortedAggregateBindData(context, expr.children, *expr.aggregate, 
expr.bind_info, expr.arg_orders) { + : SortedAggregateBindData(context, expr.GetChildrenMutable(), *expr.AggregateFunction(), expr.BindInfoMutable(), + expr.ArgOrdersMutable()) { } SortedAggregateBindData(const SortedAggregateBindData &other) @@ -196,7 +199,6 @@ struct SortedAggregateState { idx_t total_count = 0; for (column_t i = 0; i < linked.size(); ++i) { funcs[i].BuildListVector(linked[i], chunk.data[i], total_count); - chunk.SetCardinality(linked[i].total_capacity); FlatVector::SetSize(chunk.data[i], count_t(linked[i].total_capacity)); } } @@ -373,7 +375,6 @@ struct SortedAggregateState { for (column_t col_idx = 0; col_idx < input_chunk->ColumnCount(); ++col_idx) { prefixed.data[col_idx + 1].Reference(input_chunk->data[col_idx]); } - prefixed.SetCardinality(*input_chunk); // data[0] was referenced as a constant with count=1 - resize to match FlatVector::SetSize(prefixed.data[0], count_t(input_chunk->size())); } @@ -443,17 +444,6 @@ struct SortedAggregateFunction { D_ASSERT(buffered_cols[b] < input_count); buffered.data[b].Reference(inputs[buffered_cols[b]]); } - buffered.SetCardinality(count); - } - - static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, - idx_t count) { - const auto order_bind = aggr_input_data.bind_data->Cast(); - DataChunk arg_input; - ProjectInputs(inputs, order_bind, input_count, count, arg_input); - - const auto order_state = reinterpret_cast(state); - order_state->Update(aggr_input_data, arg_input); } static void ScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, @@ -492,7 +482,7 @@ struct SortedAggregateFunction { order_state->sel.Initialize(sel_data.data() + order_state->offset, count - order_state->offset); start += order_state->nsel; } - sel_data[order_state->offset++] = UnsafeNumericCast(sidx); + sel_data[order_state->offset++] = UnsafeNumericCast(i); } // Append nonempty slices to the arguments @@ -546,12 +536,12 @@ 
struct SortedAggregateFunction { // State variables auto bind_info = order_bind.bind_info.get(); - AggregateInputData aggr_bind_info(bind_info, aggr_input_data.allocator); + AggregateInputData aggr_bind_info(aggr, bind_info, aggr_input_data.allocator); // Inner aggregate APIs auto initialize = aggr.GetCallbacks().GetStateInitCallback(); auto destructor = aggr.GetCallbacks().GetStateDestructorCallback(); - auto simple_update = aggr.GetCallbacks().GetStateSimpleUpdateCallback(); + auto cluster_update = aggr.GetCallbacks().GetStateClusterUpdateCallback(); auto update = aggr.GetCallbacks().GetStateUpdateCallback(); auto finalize = aggr.GetCallbacks().GetStateFinalizeCallback(); @@ -632,12 +622,14 @@ struct SortedAggregateFunction { for (column_t col_idx = 0; col_idx < scanned.ColumnCount(); ++col_idx) { sliced.data[col_idx].Slice(scanned.data[col_idx], consumed, consumed + input_count); } - sliced.SetCardinality(input_count); - // These are all simple updates, so use it if available - if (simple_update) { - simple_update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state.data(), - sliced.size()); + if (cluster_update) { + ClusteredAggr clustered; + clustered.SetSingleRun(agg_state.data(), sliced.size()); + aggr_bind_info.clustered = &clustered; + cluster_update(sliced.data.data(), aggr_bind_info, sliced.data.size(), clustered, + sliced.size()); + aggr_bind_info.clustered = nullptr; } else { // We are only updating a constant state agg_state_vec.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -689,20 +681,20 @@ struct SortedAggregateFunction { void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, const vector> &groups, optional_ptr> grouping_sets) { - if (!expr.order_bys || expr.order_bys->orders.empty() || expr.children.empty()) { + if (!expr.GetOrderBys() || expr.GetOrderBys()->orders.empty() || expr.GetChildren().empty()) { // not a sorted aggregate: return return; } // Remove unnecessary ORDER BY clauses and 
return if nothing remains - if (context.config.enable_optimizer) { - if (expr.order_bys->Simplify(groups, grouping_sets)) { - expr.order_bys.reset(); + if (Settings::Get(context)) { + if (expr.GetOrderBysMutable()->Simplify(groups, grouping_sets)) { + expr.GetOrderBysMutable().reset(); return; } } - auto &bound_function = expr.function; - auto &children = expr.children; - auto &order_bys = *expr.order_bys; + auto &bound_function = expr.Function(); + auto &children = expr.GetChildrenMutable(); + auto &order_bys = *expr.GetOrderBysMutable(); auto sorted_bind = make_uniq(context, expr); if (!sorted_bind->sorted_on_args) { @@ -726,41 +718,40 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE AggregateDestructorType::LEGACY>, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, bound_function.GetProperties().GetNullHandling(), - SortedAggregateFunction::SimpleUpdate, nullptr, + SortedAggregateFunction::Finalize, bound_function.GetProperties().GetNullHandling(), nullptr, nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::WindowBatch); - expr.function.ReplaceImplementation(ordered_aggregate); - expr.bind_info = std::move(sorted_bind); - expr.order_bys.reset(); + expr.FunctionMutable().ReplaceImplementation(ordered_aggregate); + expr.BindInfoMutable() = std::move(sorted_bind); + expr.GetOrderBysMutable().reset(); } void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr) { // Make implicit orderings explicit - auto &aggregate = *expr.aggregate; - if (aggregate.GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT && expr.arg_orders.empty()) { - for (auto &order : expr.orders) { + auto &aggregate = *expr.AggregateFunction(); + if (aggregate.GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT && expr.ArgOrders().empty()) { + for (auto &order : expr.OrderBy()) { const auto type = order.type; const auto 
null_order = order.null_order; auto expression = order.expression->Copy(); - expr.arg_orders.emplace_back(type, null_order, std::move(expression)); + expr.ArgOrdersMutable().emplace_back(type, null_order, std::move(expression)); } } - if (expr.arg_orders.empty() || expr.children.empty()) { + if (expr.ArgOrders().empty() || expr.GetChildren().empty()) { // not a sorted aggregate: return return; } // Remove unnecessary ORDER BY clauses and return if nothing remains - if (context.config.enable_optimizer) { - if (BoundOrderModifier::Simplify(expr.arg_orders, expr.partitions, nullptr)) { - expr.arg_orders.clear(); + if (Settings::Get(context)) { + if (BoundOrderModifier::Simplify(expr.ArgOrdersMutable(), expr.PartitionsMutable(), nullptr)) { + expr.ArgOrdersMutable().clear(); return; } } - auto &children = expr.children; - auto &arg_orders = expr.arg_orders; + auto &children = expr.GetChildrenMutable(); + auto &arg_orders = expr.ArgOrdersMutable(); auto sorted_bind = make_uniq(context, expr); if (!sorted_bind->sorted_on_args) { @@ -783,15 +774,14 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpr AggregateDestructorType::LEGACY>, SortedAggregateFunction::ScatterUpdate, AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, aggregate.GetProperties().GetNullHandling(), - SortedAggregateFunction::SimpleUpdate, nullptr, + SortedAggregateFunction::Finalize, aggregate.GetProperties().GetNullHandling(), nullptr, nullptr, AggregateFunction::StateDestroy, nullptr, SortedAggregateFunction::WindowBatch); ordered_aggregate.SetWindowCallback(SortedAggregateFunction::Window); aggregate.ReplaceImplementation(ordered_aggregate); - expr.bind_info = std::move(sorted_bind); - expr.arg_orders.clear(); + expr.BindInfoMutable() = std::move(sorted_bind); + expr.ArgOrdersMutable().clear(); } } // namespace duckdb diff --git a/src/duckdb/src/function/aggregate_function.cpp b/src/duckdb/src/function/aggregate_function.cpp index 025cd2825..0a86519a5 
100644 --- a/src/duckdb/src/function/aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate_function.cpp @@ -1,9 +1,22 @@ #include "duckdb/function/aggregate_function.hpp" + +#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" namespace duckdb { +AggregateInputData::AggregateInputData(const BoundAggregateExpression &expr, ArenaAllocator &allocator_p, + AggregateCombineType combine_type_p) + : AggregateInputData(expr.Function(), expr.BindInfo().get(), allocator_p, combine_type_p) { +} + +AggregateInputData::AggregateInputData(const AggregateObject &aggr, ArenaAllocator &allocator_p, + AggregateCombineType combine_type_p) + : AggregateInputData(aggr.function, aggr.bind_data_wrapper ? aggr.bind_data_wrapper->function_data.get() : nullptr, + allocator_p, combine_type_p) { +} + bool AggregateFunctionProperties::operator==(const AggregateFunctionProperties &rhs) const { return FunctionProperties::operator==(rhs) && order_dependent == rhs.order_dependent && distinct_dependent == rhs.distinct_dependent; @@ -14,10 +27,10 @@ bool AggregateFunctionProperties::operator!=(const AggregateFunctionProperties & bool AggregateFunctionCallbacks::operator==(const AggregateFunctionCallbacks &rhs) const { return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update && - combine == rhs.combine && finalize == rhs.finalize && simple_update == rhs.simple_update && + combine == rhs.combine && finalize == rhs.finalize && cluster_update == rhs.cluster_update && window == rhs.window && window_init == rhs.window_init && window_batch == rhs.window_batch && bind == rhs.bind && destructor == rhs.destructor && statistics == rhs.statistics && - serialize == rhs.serialize && deserialize == rhs.deserialize; + serialize == rhs.serialize && deserialize == rhs.deserialize && get_state_type == rhs.get_state_type; } bool 
AggregateFunctionCallbacks::operator!=(const AggregateFunctionCallbacks &rhs) const { diff --git a/src/duckdb/src/function/built_in_functions.cpp b/src/duckdb/src/function/built_in_functions.cpp index 58c9929c2..934e779ba 100644 --- a/src/duckdb/src/function/built_in_functions.cpp +++ b/src/duckdb/src/function/built_in_functions.cpp @@ -26,7 +26,7 @@ BuiltinFunctions::~BuiltinFunctions() { void BuiltinFunctions::AddCollation(string name, ScalarFunction function, bool combinable, bool not_required_for_equality) { - CreateCollationInfo info(std::move(name), std::move(function), combinable, not_required_for_equality); + CreateCollationInfo info(Identifier(std::move(name)), std::move(function), combinable, not_required_for_equality); info.internal = true; catalog.CreateCollation(transaction, info); } @@ -50,7 +50,7 @@ void BuiltinFunctions::AddFunction(PragmaFunction function) { } void BuiltinFunctions::AddFunction(const string &name, PragmaFunctionSet functions) { - CreatePragmaFunctionInfo info(name, std::move(functions)); + CreatePragmaFunctionInfo info(Identifier(name), std::move(functions)); info.internal = true; catalog.CreatePragmaFunction(transaction, info); } @@ -63,7 +63,7 @@ void BuiltinFunctions::AddFunction(ScalarFunction function) { void BuiltinFunctions::AddFunction(const vector &names, ScalarFunction function) { // NOLINT: false positive for (auto &name : names) { - function.name = name; + function.name = Identifier(name); AddFunction(function); } } @@ -123,7 +123,7 @@ static unique_ptr BindExtensionFunction(FunctionBindExpressionInput // now find the function in the catalog auto &catalog = Catalog::GetSystemCatalog(db); auto &function_entry = - catalog.GetEntry(context, DEFAULT_SCHEMA, bound_function.GetName()); + catalog.GetEntry(context, Identifier::DefaultSchema(), bound_function.GetName()); // override the function with the extension function const auto &func = function_entry.functions.GetFunctionByArguments(context, bound_function.GetArguments()); 
diff --git a/src/duckdb/src/function/cast/array_casts.cpp b/src/duckdb/src/function/cast/array_casts.cpp index d4cb51e2f..a115cc3e5 100644 --- a/src/duckdb/src/function/cast/array_casts.cpp +++ b/src/duckdb/src/function/cast/array_casts.cpp @@ -63,7 +63,7 @@ static bool ArrayToArrayCast(Vector &source, Vector &result, idx_t count, CastPa result.SetVectorType(VectorType::CONSTANT_VECTOR); if (ConstantVector::IsNull(source)) { - ConstantVector::SetNull(result, count_t(count)); + ConstantVector::SetNull(result, true); return true; } diff --git a/src/duckdb/src/function/cast/decimal_cast.cpp b/src/duckdb/src/function/cast/decimal_cast.cpp index 76355ce73..01c98db67 100644 --- a/src/duckdb/src/function/cast/decimal_cast.cpp +++ b/src/duckdb/src/function/cast/decimal_cast.cpp @@ -116,7 +116,7 @@ struct DecimalScaleDownOperator { // This function detects if we can scale a decimal down to another. template static bool CanScaleDownDecimal(INPUT_TYPE input, DecimalScaleInput &data) { - int64_t divisor = UnsafeNumericCast(NumericHelper::POWERS_OF_TEN[data.source_scale]); + int64_t divisor = UnsafeNumericCast(data.factor); auto value = input % divisor; auto rounded_input = input; if (rounded_input < 0) { @@ -131,7 +131,7 @@ static bool CanScaleDownDecimal(INPUT_TYPE input, DecimalScaleInput template <> bool CanScaleDownDecimal(hugeint_t input, DecimalScaleInput &data) { - auto divisor = UnsafeNumericCast(Hugeint::POWERS_OF_TEN[data.source_scale]); + auto divisor = UnsafeNumericCast(data.factor); hugeint_t value = input % divisor; hugeint_t rounded_input = input; if (rounded_input < 0) { diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 21d96bb17..0ca0048b7 100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -58,7 +58,7 @@ void HandleCastError::AssignError(const string &error_message, string *error_mes // NULL cast only works if all values in source 
are NULL, otherwise an unimplemented cast exception is thrown bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { bool success = true; - if (VectorOperations::HasNotNull(source, count)) { + if (VectorOperations::HasNotNull(source)) { HandleCastError::AssignError(TryCast::UnimplementedCastMessage(source.GetType(), result.GetType()), parameters); success = false; } @@ -71,54 +71,10 @@ bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, return true; } -static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - if (result.GetType().id() != LogicalTypeId::BLOB) { - throw TypeMismatchException(source.GetType(), result.GetType(), - "Cannot cast LEGACY_AGGREGATE_STATE to anything but BLOB"); - } - result.Reinterpret(source); - return true; -} - -static bool AggregateStateToStructReinterpret(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - // Both `AGGREGATE_STATE` and `STRUCT` are `PhysicalType::STRUCT` - // However, we can't use Reinterpret because it's blocked by D_ASSERT for nested types - // Instead, we manually reference the entries - auto &source_entries = StructVector::GetEntries(source); - auto &result_entries = StructVector::GetEntries(result); - - D_ASSERT(source_entries.size() == result_entries.size()); - - for (idx_t i = 0; i < source_entries.size(); i++) { - result_entries[i].Reference(source_entries[i]); - } - - source.Flatten(); - FlatVector::ValidityMutable(result) = FlatVector::Validity(source); - FlatVector::SetSize(result, count_t(count)); - result.Verify(); - return true; -} - bool BoundCastInfo::IsNopCast() const { return function == DefaultCasts::NopCast; } -static BoundCastInfo AggregateStateCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(source.IsAggregateStateStructType()); - - LogicalType dummy_struct = 
LogicalType::STRUCT(AggregateStateType::GetChildTypes(source)); - auto cast_info = input.GetCastFunction(dummy_struct, target); - if (cast_info.IsNopCast()) { - // 1. `NopCast` cannot be used since it expects the types to be the same. - // 2. `ReinterpretCast` cannot be used since it's blocked for nested types. - // 3. We don't want to use `StructToStructCast` in this case as it introduces more complexity while we know - // that we don't have casting to perform - cast_info.SetFunction(AggregateStateToStructReinterpret); - } - return cast_info; -} - static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { // cast a NULL to another type, just copy the properties and change the type ConstantVector::SetNull(result, count_t(count)); @@ -211,10 +167,6 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return TypeCastSwitch(input, source, target); case LogicalTypeId::BIGNUM: return BignumCastSwitch(input, source, target); - case LogicalTypeId::LEGACY_AGGREGATE_STATE: - return AggregateStateToBlobCast; - case LogicalTypeId::AGGREGATE_STATE: - return AggregateStateCast(input, source, target); default: return nullptr; } diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 65f4ae400..aee1f9ff6 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -197,7 +197,7 @@ bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_dat vector> child_masks; for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { if (!is_unnamed) { - child_names.insert({StructType::GetChildName(result.GetType(), child_idx), child_idx}); + child_names.insert({StructType::GetChildName(result.GetType(), child_idx).GetIdentifierName(), child_idx}); } child_masks.emplace_back(FlatVector::ValidityMutable(child_vectors[child_idx])); child_masks[child_idx].get().SetAllInvalid(count); @@ 
-461,14 +461,11 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical return BoundCastInfo( &VectorCastHelpers::TryCastErrorLoop); case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); case LogicalTypeId::BLOB: return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::BIT: diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp index fe134f200..0e26f4be8 100644 --- a/src/duckdb/src/function/cast/struct_cast.cpp +++ b/src/duckdb/src/function/cast/struct_cast.cpp @@ -26,7 +26,7 @@ unique_ptr StructBoundCastData::BindStructToStructCast(BindCastIn throw TypeMismatchException(input.query_location, source, target, "Cannot cast STRUCTs of different size"); } - InsertionOrderPreservingMap target_children_map; + InsertionOrderPreservingMap> target_children_map; if (!is_unnamed) { for (idx_t i = 0; i < target_children.size(); i++) { auto &name = target_children[i].first; @@ -180,7 +180,8 @@ static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, Cas auto &child_data = child_iterators[c]; if (!is_unnamed) { auto &name = child_types[c].first; - string_length += VectorCastHelpers::CalculateEscapedStringLength(name, key_needs_quotes[c]); + string_length += VectorCastHelpers::CalculateEscapedStringLength(name.GetIdentifierName(), + key_needs_quotes[c]); string_length += NAME_SEP_LENGTH; // ": " } auto child_entry = child_data[r]; @@ -210,7 +211,8 @@ static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, Cas if 
(!is_unnamed) { auto &name = child_types[c].first; // "{: }" - offset += VectorCastHelpers::WriteEscapedString(dataptr + offset, name, key_needs_quotes[c]); + offset += VectorCastHelpers::WriteEscapedString(dataptr + offset, name.GetIdentifierName(), + key_needs_quotes[c]); dataptr[offset++] = ':'; dataptr[offset++] = ' '; } @@ -287,7 +289,7 @@ static bool StructToMapCast(Vector &source, Vector &result, idx_t count, CastPar auto key_data = FlatVector::Writer(varchar_keys, field_count); for (idx_t field_idx = 0; field_idx < field_count; field_idx++) { auto &field_name = field_types[field_idx].first; - key_data.WriteValue(field_name); + key_data.WriteValue(field_name.GetIdentifierName()); } // Cast keys to result type diff --git a/src/duckdb/src/function/cast/time_casts.cpp b/src/duckdb/src/function/cast/time_casts.cpp index c330ddafe..9cab6536f 100644 --- a/src/duckdb/src/function/cast/time_casts.cpp +++ b/src/duckdb/src/function/cast/time_casts.cpp @@ -15,11 +15,11 @@ BoundCastInfo DefaultCasts::DateCastSwitch(BindCastInput &input, const LogicalTy return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_TZ_NS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); default: return TryVectorNullCast; } @@ -59,7 +59,7 @@ BoundCastInfo DefaultCasts::TimeTzCastSwitch(BindCastInput &input, const Logical switch (target.id()) { case LogicalTypeId::VARCHAR: // time with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::TIME: // time with time zone to time return 
BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); @@ -90,16 +90,13 @@ BoundCastInfo DefaultCasts::TimestampCastSwitch(BindCastInput &input, const Logi case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_TZ_NS: // timestamp (us) to timestamp [with time zone] (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_MS: // timestamp (us) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_SEC: // timestamp (us) to timestamp (s) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); default: return TryVectorNullCast; } @@ -111,25 +108,22 @@ BoundCastInfo DefaultCasts::TimestampTzCastSwitch(BindCastInput &input, const Lo switch (target.id()) { case LogicalTypeId::VARCHAR: // timestamp with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::TIME_TZ: // timestamp with time zone to time with time zone. 
- return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP: // timestamp with time zone to timestamp (us) return ReinterpretCast; case LogicalTypeId::TIMESTAMP_NS: // timestamptz (us) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_MS: // timestamptz (us) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_SEC: // timestamptz (us) to timestamp (s) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); default: return TryVectorNullCast; } @@ -141,17 +135,16 @@ BoundCastInfo DefaultCasts::TimestampTzNsCastSwitch(BindCastInput &input, const switch (target.id()) { case LogicalTypeId::VARCHAR: // timestamp(ns) with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::TIME_TZ: // timestamp with time zone to time with time zone. 
- return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_NS: // timestamp with time zone to timestamp (us) return ReinterpretCast; case LogicalTypeId::TIMESTAMP_TZ: // timestamp with time zone (ns) to timestamp with time zone (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); default: return TryVectorNullCast; } @@ -166,21 +159,17 @@ BoundCastInfo DefaultCasts::TimestampNsCastSwitch(BindCastInput &input, const Lo return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DATE: // timestamp (ns) to date - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIME: // timestamp (ns) to time - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIME_NS: // timestamp (ns) to time (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: // timestamp (ns) to timestamp [with time zone] (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_MS: return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); default: @@ -197,22 +186,18 @@ BoundCastInfo DefaultCasts::TimestampMsCastSwitch(BindCastInput &input, const Lo return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DATE: // timestamp (ms) to date - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIME: // timestamp (ms) to time - return BoundCastInfo( - 
&VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: // timestamp (ms) to timestamp [with time zone] (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_TZ_NS: // timestamp (ms) to timestamp [with time zone] (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_SEC: // timestamp (ms) to timestamp (s) return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); @@ -230,26 +215,21 @@ BoundCastInfo DefaultCasts::TimestampSecCastSwitch(BindCastInput &input, const L return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DATE: // timestamp (s) to date - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIME: // timestamp (s) to time - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_MS: // timestamp (s) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: // timestamp (s) to timestamp [with time zone] (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_TZ_NS: // timestamp (s) to timestamp [with time zone] (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); + return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); default: return TryVectorNullCast; } diff --git 
a/src/duckdb/src/function/cast/union/from_struct.cpp b/src/duckdb/src/function/cast/union/from_struct.cpp index b7c2dc322..ed8fbdf65 100644 --- a/src/duckdb/src/function/cast/union/from_struct.cpp +++ b/src/duckdb/src/function/cast/union/from_struct.cpp @@ -30,7 +30,7 @@ bool StructToUnionCast::AllowImplicitCastFromStruct(const LogicalType &source, c } continue; } - if (!StringUtil::CIEquals(target_field_name, field_name)) { + if (target_field_name != field_name) { return false; } if (target_field != field && field != LogicalType::VARCHAR) { diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp index 64fea74ac..8b77e3d72 100644 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -152,7 +152,7 @@ static unique_ptr BindUnionToUnionCast(BindCastInput &input, cons auto &target_member_name = UnionType::GetMemberName(target, target_idx); // found a matching member - if (StringUtil::CIEquals(source_member_name, target_member_name)) { + if (source_member_name == target_member_name) { auto &target_member_type = UnionType::GetMemberType(target, target_idx); tag_map[source_idx] = static_cast(target_idx); member_casts.push_back(input.GetCastFunction(source_member_type, target_member_type)); @@ -373,7 +373,7 @@ BoundCastInfo DefaultCasts::UnionCastSwitch(BindCastInput &input, const LogicalT // bind a cast in which we convert all members to VARCHAR first child_list_t varchar_members; for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(source); member_idx++) { - varchar_members.push_back(make_pair(UnionType::GetMemberName(source, member_idx), LogicalType::VARCHAR)); + varchar_members.emplace_back(make_pair(UnionType::GetMemberName(source, member_idx), LogicalType::VARCHAR)); } auto varchar_type = LogicalType::UNION(std::move(varchar_members)); return BoundCastInfo(UnionToVarcharCast, BindUnionToUnionCast(input, source, varchar_type), diff --git 
a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index 44113bfc5..852708a96 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -2,10 +2,8 @@ #include "duckdb/common/vector/constant_vector.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/map_vector.hpp" #include "duckdb/common/vector/shredded_vector.hpp" #include "duckdb/common/vector/string_vector.hpp" -#include "duckdb/common/vector/variant_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" #include "yyjson_utils.hpp" #include "duckdb/function/cast/default_casts.hpp" @@ -197,7 +195,7 @@ static bool CastVariantToPrimitive(FromVariantConversionData &conversion_data, V } static bool FindValues(UnifiedVariantVectorData &variant, idx_t row_index, SelectionVector &sel, - VariantNestedData &nested_data_entry) { + const VariantNestedData &nested_data_entry) { for (idx_t child_idx = 0; child_idx < nested_data_entry.child_count; child_idx++) { auto value_id = variant.GetValuesIndex(row_index, nested_data_entry.children_idx + child_idx); sel[child_idx] = value_id; @@ -210,14 +208,8 @@ static bool CastVariant(FromVariantConversionData &conversion_data, Vector &resu static bool ConvertVariantToList(FromVariantConversionData &conversion_data, Vector &result, const SelectionVector &sel, idx_t offset, idx_t count, optional_idx row) { - auto &allocator = Allocator::DefaultAllocator(); - - AllocatedData owned_child_data; - VariantNestedData *child_data = nullptr; - if (count) { - owned_child_data = allocator.Allocate(sizeof(VariantNestedData) * count); - child_data = reinterpret_cast(owned_child_data.get()); - } + const auto owned_child_data = make_unsafe_uniq_array_uninitialized(count); + const array_ptr child_data(owned_child_data.get(), count); //! 
Initialize the validity with that of the result (in case some rows are already set to invalid, we need to //! respect that) @@ -229,8 +221,8 @@ static bool ConvertVariantToList(FromVariantConversionData &conversion_data, Vec } } - auto collection_result = VariantUtils::CollectNestedData(conversion_data.variant, VariantLogicalType::ARRAY, sel, - count, row, offset, child_data, validity); + const auto collection_result = VariantUtils::CollectNestedData(conversion_data.variant, VariantLogicalType::ARRAY, + sel, count, row, offset, child_data, validity); if (!collection_result.success) { conversion_data.error = StringUtil::Format("Expected to find VARIANT(ARRAY), found VARIANT(%s) instead, can't convert", @@ -286,14 +278,8 @@ static bool ConvertVariantToList(FromVariantConversionData &conversion_data, Vec static bool ConvertVariantToArray(FromVariantConversionData &conversion_data, Vector &result, const SelectionVector &sel, idx_t offset, idx_t count, optional_idx row) { - auto &allocator = Allocator::DefaultAllocator(); - - AllocatedData owned_child_data; - VariantNestedData *child_data = nullptr; - if (count) { - owned_child_data = allocator.Allocate(sizeof(VariantNestedData) * count); - child_data = reinterpret_cast(owned_child_data.get()); - } + const auto owned_child_data = make_unsafe_uniq_array_uninitialized(count); + const array_ptr child_data(owned_child_data.get(), count); //! Initialize the validity with that of the result (in case some rows are already set to invalid, we need to //! 
respect that) @@ -353,14 +339,8 @@ static bool ConvertVariantToArray(FromVariantConversionData &conversion_data, Ve static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, Vector &result, const SelectionVector &sel, idx_t offset, idx_t count, optional_idx row) { auto &target_type = result.GetType(); - auto &allocator = Allocator::DefaultAllocator(); - - AllocatedData owned_child_data; - VariantNestedData *child_data = nullptr; - if (count) { - owned_child_data = allocator.Allocate(sizeof(VariantNestedData) * count); - child_data = reinterpret_cast(owned_child_data.get()); - } + const auto owned_child_data = make_unsafe_uniq_array_uninitialized(count); + array_ptr child_data(owned_child_data.get(), count); //! Initialize the validity with that of the result (in case some rows are already set to invalid, we need to //! respect that) @@ -407,7 +387,7 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V //! Then find the relevant child of the OBJECTs we're converting //! FIXME: there is nothing preventing an OBJECT from containing the same key twice I believe ? 
VariantPathComponent component; - component.key = child_name; + component.key = child_name.GetIdentifierName(); component.lookup_mode = VariantChildLookupMode::BY_KEY; ValidityMask lookup_validity(count); VariantUtils::FindChildValues(conversion_data.variant, component, row_sel, child_values_sel, lookup_validity, diff --git a/src/duckdb/src/function/cast/variant/to_variant.cpp b/src/duckdb/src/function/cast/variant/to_variant.cpp index 617b6ae29..f0525ecfe 100644 --- a/src/duckdb/src/function/cast/variant/to_variant.cpp +++ b/src/duckdb/src/function/cast/variant/to_variant.cpp @@ -214,7 +214,7 @@ static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParam offsets.Initialize(Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); - offsets.SetCardinality(count); + offsets.SetChildCardinality(count); auto &keys = VariantVector::GetKeys(result); auto &keys_entry = ListVector::GetChildMutable(keys); diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp index 57abf9f1b..b386c9124 100644 --- a/src/duckdb/src/function/cast_rules.cpp +++ b/src/duckdb/src/function/cast_rules.cpp @@ -494,7 +494,7 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) for (idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(to); to_member_idx++) { auto &to_member_name = UnionType::GetMemberName(to, to_member_idx); - if (StringUtil::CIEquals(from_member_name, to_member_name)) { + if (from_member_name == to_member_name) { auto &from_member_type = UnionType::GetMemberType(from, from_member_idx); auto &to_member_type = UnionType::GetMemberType(to, to_member_idx); @@ -513,8 +513,7 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) return cost; } } - if ((from.id() == LogicalTypeId::STRUCT || from.IsAggregateStateStructType()) && - (to.id() == LogicalTypeId::STRUCT || 
to.IsAggregateStateStructType())) { + if ((from.id() == LogicalTypeId::STRUCT) && (to.id() == LogicalTypeId::STRUCT)) { if (to.AuxInfo() == nullptr) { // If this struct is not fully resolved, we'll leave it to the actual cast logic to handle it. return 0; @@ -535,7 +534,7 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) int64_t cost = -1; if (named_struct_cast) { // Collect the target members in a map for easy lookup - case_insensitive_map_t target_members; + identifier_map_t target_members; for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { auto &target_name = target_children[target_idx].first; if (target_members.find(target_name) != target_members.end()) { diff --git a/src/duckdb/src/function/copy_blob.cpp b/src/duckdb/src/function/copy_blob.cpp index b4e1f857f..a3e5b8d56 100644 --- a/src/duckdb/src/function/copy_blob.cpp +++ b/src/duckdb/src/function/copy_blob.cpp @@ -34,7 +34,7 @@ string ParseStringOption(const Value &value, const string &loption) { } unique_ptr WriteBlobBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { + const vector &names, const vector &sql_types) { if (sql_types.size() != 1 || sql_types.back().id() != LogicalTypeId::BLOB) { throw BinderException("\"COPY (FORMAT BLOB)\" only supports a single BLOB column"); } diff --git a/src/duckdb/src/function/copy_function.cpp b/src/duckdb/src/function/copy_function.cpp index f59e6910f..8f9b997e1 100644 --- a/src/duckdb/src/function/copy_function.cpp +++ b/src/duckdb/src/function/copy_function.cpp @@ -4,7 +4,7 @@ namespace duckdb { -CopyFunction::CopyFunction(const string &name) +CopyFunction::CopyFunction(const Identifier &name) : Function(name), plan(nullptr), copy_to_select(nullptr), copy_to_bind(nullptr), copy_options(nullptr), copy_to_initialize_local(nullptr), copy_to_initialize_global(nullptr), copy_to_get_written_statistics(nullptr), copy_to_sink(nullptr), 
copy_to_combine(nullptr), copy_to_finalize(nullptr), execution_mode(nullptr), @@ -18,7 +18,7 @@ CopyOption::CopyOption() : type(LogicalType::ANY), mode(CopyOptionMode::READ_WRI CopyOption::CopyOption(LogicalType type_p, CopyOptionMode mode_p) : type(std::move(type_p)), mode(mode_p) { } -vector GetCopyFunctionReturnNames(CopyFunctionReturnType return_type) { +vector GetCopyFunctionReturnNames(CopyFunctionReturnType return_type) { switch (return_type) { case CopyFunctionReturnType::CHANGED_ROWS: return {"Count"}; diff --git a/src/duckdb/src/function/function.cpp b/src/duckdb/src/function/function.cpp index 78a749ce6..208477bcb 100644 --- a/src/duckdb/src/function/function.cpp +++ b/src/duckdb/src/function/function.cpp @@ -10,7 +10,7 @@ namespace duckdb { bool FunctionProperties::operator==(const FunctionProperties &rhs) const { return stability == rhs.stability && null_handling == rhs.null_handling && errors == rhs.errors && - collation_handling == rhs.collation_handling; + collation_handling == rhs.collation_handling && capture_argument_aliases == rhs.capture_argument_aliases; } bool FunctionProperties::operator!=(const FunctionProperties &rhs) const { @@ -45,12 +45,16 @@ bool FunctionData::SupportStatementCache() const { return true; } -Function::Function(string name_p) : name(std::move(name_p)) { +Function::Function(Identifier name_p) : name(std::move(name_p)) { } Function::~Function() { } -SimpleFunction::SimpleFunction(string name_p, vector arguments_p, LogicalType return_type, +SimpleFunction::SimpleFunction(Identifier name_p, FunctionSignature signature_p) + : Function(std::move(name_p)), signature(std::move(signature_p)) { +} + +SimpleFunction::SimpleFunction(Identifier name_p, vector arguments_p, LogicalType return_type, LogicalType varargs_p) : Function(std::move(name_p)), signature(std::move(arguments_p), std::move(varargs_p), std::move(return_type)) { } @@ -58,12 +62,15 @@ SimpleFunction::SimpleFunction(string name_p, vector arguments_p, L 
SimpleFunction::~SimpleFunction() { } -static bool RequiresCatalogAndSchemaNamePrefix(const string &catalog_name, const string &schema_name) { - return !catalog_name.empty() && catalog_name != SYSTEM_CATALOG && !schema_name.empty() && - schema_name != DEFAULT_SCHEMA; +static bool RequiresCatalogAndSchemaNamePrefix(const Identifier &catalog_name, const Identifier &schema_name) { + return !catalog_name.empty() && catalog_name != Identifier::SystemCatalog() && !schema_name.empty() && + schema_name != Identifier::DefaultSchema(); } string FunctionParameter::ToString() const { + if (default_value) { + return StringUtil::Format("%s %s := %s", name, type.ToString(), default_value->ToString()); + } return StringUtil::Format("%s %s", name, type.ToString()); } @@ -90,7 +97,7 @@ string SimpleFunction::ToString() const { return name + signature.ToString(); } -SimpleNamedParameterFunction::SimpleNamedParameterFunction(string name_p, vector arguments_p, +SimpleNamedParameterFunction::SimpleNamedParameterFunction(Identifier name_p, vector arguments_p, LogicalType varargs_p) : Function(std::move(name_p)), arguments(std::move(arguments_p)), varargs(std::move(varargs_p)) { } @@ -138,8 +145,10 @@ hash_t SimpleFunction::Hash() const { return signature.Hash(); } -string Function::CallToString(const string &catalog_name, const string &schema_name, const string &name, - const vector &arguments, const LogicalType &varargs) { +string Function::CallToString(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &name, + const vector &arguments, + const vector> &named_arguments, + const LogicalType &varargs) { string result; if (RequiresCatalogAndSchemaNamePrefix(catalog_name, schema_name)) { result += catalog_name + "." 
+ schema_name + "."; @@ -149,6 +158,11 @@ string Function::CallToString(const string &catalog_name, const string &schema_n for (auto &arg : arguments) { string_arguments.push_back(arg.ToString()); } + + for (const auto &[arg_name, arg_type] : named_arguments) { + string_arguments.push_back(StringUtil::Format("%s := %s", arg_name, arg_type.ToString())); + } + if (varargs.IsValid()) { string_arguments.push_back("[" + varargs.ToString() + "...]"); } @@ -156,15 +170,16 @@ string Function::CallToString(const string &catalog_name, const string &schema_n return result + ")"; } -string Function::CallToString(const string &catalog_name, const string &schema_name, const string &name, +string Function::CallToString(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &name, const vector &arguments, const LogicalType &varargs, const LogicalType &return_type) { - string result = CallToString(catalog_name, schema_name, name, arguments, varargs); + string result = + CallToString(catalog_name, schema_name, name, arguments, vector> {}, varargs); result += " -> " + return_type.ToString(); return result; } -string Function::CallToString(const string &catalog_name, const string &schema_name, const string &name, +string Function::CallToString(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &name, const vector &arguments, const named_parameter_type_map_t &named_parameters) { vector input_arguments; @@ -206,7 +221,7 @@ string BoundSimpleFunction::ToString() const { } bool FunctionParameter::operator==(const FunctionParameter &other) const { - return type == other.type && StringUtil::CIEquals(name, other.name); + return type == other.type && name == other.name; } bool FunctionParameter::operator!=(const FunctionParameter &other) const { diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index 81238c301..58fba7064 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ 
b/src/duckdb/src/function/function_binder.cpp @@ -1,7 +1,9 @@ #include "duckdb/function/function_binder.hpp" #include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/window_function_catalog_entry.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/common/type_visitor.hpp" #include "duckdb/execution/expression_executor.hpp" @@ -22,51 +24,85 @@ FunctionBinder::FunctionBinder(ClientContext &context_p) : binder(nullptr), cont FunctionBinder::FunctionBinder(Binder &binder_p) : binder(&binder_p), context(binder_p.context) { } -optional_idx FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments) { - auto &sig = func.GetSignature(); - - if (arguments.size() < sig.GetParameterCount()) { - // not enough arguments to fulfill the non-vararg part of the function - return optional_idx(); +// Split the full (maybe-named) argument list into positional types + named (name, type) pairs. +// Returns false if a positional argument follows a named one, which is not allowed. +static bool TrySplitArgumentTypes(const vector>> &arguments, + vector &positional, vector> &named) { + for (auto &arg : arguments) { + auto type = ExpressionBinder::GetExpressionReturnType(*arg.second); + if (!arg.first.empty()) { + named.emplace_back(arg.first, std::move(type)); + continue; + } + if (named.empty()) { + positional.push_back(std::move(type)); + continue; + } + // a positional argument cannot follow a named one + return false; } - idx_t cost = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - LogicalType arg_type = i < sig.GetParameterCount() ? sig.GetParameter(i).GetType() : sig.GetVarArgs(); - if (arguments[i] == arg_type) { - // arguments match: do nothing + return true; +} + +// Split the full (maybe-named) argument list into the positional + named (keyword) children. 
+static auto SplitArguments(vector>> arguments) + -> pair>, vector>>> { + vector> regular_args; + vector>> keyword_args; + + for (auto &arg : arguments) { + if (!arg.first.empty()) { + keyword_args.push_back(std::move(arg)); continue; } - int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, arguments[i], arg_type); - if (cast_cost >= 0) { - // we can implicitly cast, add the cost to the total cost - cost += idx_t(cast_cost); - } else { - // we can't implicitly cast: throw an error - return optional_idx(); + if (keyword_args.empty()) { + regular_args.push_back(std::move(arg.second)); + continue; } + // Defensive: a positional argument following a named one should already have been rejected during binding + // (see TrySplitArgumentTypes / BindFunctionWithImplicitNaming). + throw BinderException(arg.second->GetQueryLocation(), + "Positional argument '%s' cannot follow named arguments in a function call.", + arg.second->ToString()); } - return cost; + + return {std::move(regular_args), std::move(keyword_args)}; } -optional_idx FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector &arguments) { - auto &sig = func.GetSignature(); +optional_idx FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector &arguments, + const vector> &named_arguments) { + const auto &sig = func.GetSignature(); - if (sig.HasVarArgs()) { - // special case varargs function - return BindVarArgsFunctionCost(func, arguments); + // Compute total number of arguments passed + const auto received_arg_count = static_cast(arguments.size() + named_arguments.size()); + + // And the minimum and maximum number of arguments the function can accept + const auto minimum_arg_count = sig.GetRequiredParameterCount(); + + const auto maximum_arg_count = sig.HasVarArgs() ? NumericLimits::Maximum() : sig.GetParameterCount(); + + if (received_arg_count < minimum_arg_count) { + // We have fewer arguments than the function requires, so this function cannot be a match. 
+ return optional_idx(); } - if (sig.GetParameterCount() != arguments.size()) { - // invalid argument count: check the next function + + if (received_arg_count > maximum_arg_count) { + // We have more arguments than the function can take, so this function cannot be a match. return optional_idx(); } + idx_t cost = 0; bool has_parameter = false; + for (idx_t i = 0; i < arguments.size(); i++) { if (arguments[i].id() == LogicalTypeId::UNKNOWN) { has_parameter = true; continue; } - int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, arguments[i], sig.GetParameter(i).GetType()); + + auto arg_type = i < sig.GetParameterCount() ? sig.GetParameter(i).GetType() : sig.GetVarArgs(); + + int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, arguments[i], arg_type); if (cast_cost >= 0) { // we can implicitly cast, add the cost to the total cost cost += idx_t(cast_cost); @@ -75,6 +111,51 @@ optional_idx FunctionBinder::BindFunctionCost(const SimpleFunction &func, const return optional_idx(); } } + + // Now check the named arguments + for (idx_t i = 0; i < named_arguments.size(); i++) { + auto &named_arg = named_arguments[i]; + auto opt_param_idx = sig.GetParameterIndexByName(named_arg.first); + + if (!opt_param_idx.IsValid()) { + if (!sig.HasVarArgs()) { + // no parameter with this name: continue + return optional_idx(); + } + + // This is a named vararg argument, we can skip the parameter index check as varargs are always at the end + // of the argument list + auto &vararg_type = sig.GetVarArgs(); + int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, named_arg.second, vararg_type); + if (cast_cost >= 0) { + // we can implicitly cast, add the cost to the total cost + cost += static_cast(cast_cost); + } else { + // we can't implicitly cast: throw an error + return optional_idx(); + } + } else { + // This parameter index might technically be invalid, in that it may point to a parameter that has already + // been filled by a positional argument. 
However, we will catch that later when we actually bind the + // argument expressions to construct the bound function expression. At that point we will also have more + // context so we can give a better error message. + const auto param_idx = opt_param_idx.GetIndex(); + + if (param_idx < arguments.size()) { + // If already covered by a positional argument, skip the cost check here. + continue; + } + + auto ¶m = sig.GetParameter(param_idx); + int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, named_arg.second, param.GetType()); + if (cast_cost >= 0) { + cost += idx_t(cast_cost); + } else { + return optional_idx(); + } + } + } + if (has_parameter) { // all arguments are implicitly castable and there is a parameter - return 0 as cost return 0; @@ -108,7 +189,8 @@ optional_idx FunctionBinder::BindVarArgsFunctionCost(const SimpleNamedParameterF } optional_idx FunctionBinder::BindFunctionCost(const SimpleNamedParameterFunction &func, - const vector &arguments) { + const vector &arguments, + const vector> &) { if (func.HasVarArgs()) { // special case varargs function return BindVarArgsFunctionCost(func, arguments); @@ -141,15 +223,17 @@ optional_idx FunctionBinder::BindFunctionCost(const SimpleNamedParameterFunction } template -vector FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error) { +vector FunctionBinder::BindFunctionsFromArguments(const Identifier &name, const FunctionSet &functions, + const vector &arguments, + const vector> &named_arguments, + ErrorData &error) { optional_idx best_function; idx_t lowest_cost = NumericLimits::Maximum(); vector candidate_functions; for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) { auto &func = functions.functions[f_idx]; // check the arguments of the function - auto bind_cost = BindFunctionCost(func, arguments); + auto bind_cost = BindFunctionCost(func, arguments, named_arguments); if (!bind_cost.IsValid()) { // auto 
casting was not possible continue; @@ -169,8 +253,8 @@ vector FunctionBinder::BindFunctionsFromArguments(const string &name, Fun if (!best_function.IsValid()) { // no matching function was found, throw an error vector candidates; - string catalog_name; - string schema_name; + Identifier catalog_name; + Identifier schema_name; for (auto &f : functions.functions) { if (catalog_name.empty() && !f.catalog_name.empty()) { catalog_name = f.catalog_name; @@ -180,7 +264,8 @@ vector FunctionBinder::BindFunctionsFromArguments(const string &name, Fun } candidates.push_back(f.ToString()); } - error = ErrorData(BinderException::NoMatchingFunction(catalog_name, schema_name, name, arguments, candidates)); + error = ErrorData(BinderException::NoMatchingFunction(catalog_name, schema_name, name, arguments, + named_arguments, candidates)); return candidate_functions; } candidate_functions.push_back(best_function.GetIndex()); @@ -188,14 +273,15 @@ vector FunctionBinder::BindFunctionsFromArguments(const string &name, Fun } template -optional_idx FunctionBinder::MultipleCandidateException(const string &catalog_name, const string &schema_name, - const string &name, FunctionSet &functions, - vector &candidate_functions, - const vector &arguments, ErrorData &error) { +static optional_idx +MultipleCandidateException(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &name, + const FunctionSet &functions, const vector &candidate_functions, + const vector &arguments, + const vector> &named_arguments, ErrorData &error) { D_ASSERT(functions.functions.size() > 1); // there are multiple possible function definitions // throw an exception explaining which overloads are there - string call_str = Function::CallToString(catalog_name, schema_name, name, arguments); + string call_str = Function::CallToString(catalog_name, schema_name, name, arguments, named_arguments); string candidate_str; for (auto &conf : candidate_functions) { const auto &f = 
functions.GetFunctionByOffset(conf); @@ -210,9 +296,11 @@ optional_idx FunctionBinder::MultipleCandidateException(const string &catalog_na } template -optional_idx FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error) { - auto candidate_functions = BindFunctionsFromArguments(name, functions, arguments, error); +optional_idx FunctionBinder::BindFunctionFromArguments(const Identifier &name, const FunctionSet &functions, + const vector &arguments, + const vector> &named_arguments, + ErrorData &error) { + auto candidate_functions = BindFunctionsFromArguments(name, functions, arguments, named_arguments, error); if (candidate_functions.empty()) { // No candidates, return an invalid index. return optional_idx(); @@ -225,41 +313,107 @@ optional_idx FunctionBinder::BindFunctionFromArguments(const string &name, Funct throw ParameterNotResolvedException(); } } - auto catalog_name = functions.functions.size() > 0 ? functions.functions[0].catalog_name : ""; - auto schema_name = functions.functions.size() > 0 ? functions.functions[0].schema_name : ""; + auto catalog_name = functions.functions.size() > 0 ? functions.functions[0].GetCatalogName() : Identifier(); + auto schema_name = functions.functions.size() > 0 ? 
functions.functions[0].GetSchemaName() : Identifier(); return MultipleCandidateException(catalog_name, schema_name, name, functions, candidate_functions, arguments, - error); + named_arguments, error); } return candidate_functions[0]; } -optional_idx FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); +template +static bool AnyOverloadSupportsImplicitArgumentNames(const FunctionSet &functions) { + for (auto &func : functions.functions) { + if (func.GetProperties().GetCaptureArgumentAliases()) { + return true; + } + } + return false; +} + +static optional_idx PositionalAfterNamedArgumentError(const vector>> &arguments, + ErrorData &error) { + bool seen_named = false; + for (const auto &[name, expr] : arguments) { + if (!name.empty()) { + seen_named = true; + continue; + } + if (seen_named) { + error = ErrorData(BinderException( + expr->GetQueryLocation(), "Positional argument '%s' cannot follow named arguments in function call.", + expr->ToString())); + return optional_idx(); + } + } + throw InternalException("ThrowPositionalAfterNamed called without a positional-after-named argument"); +} + +template +optional_idx FunctionBinder::BindFunctionFromArguments(const Identifier &name, const FunctionSet &functions, + vector>> &arguments, + ErrorData &error) { + // First, attempt a regular bind, splitting the arguments into positional + named just once for all overloads. + vector positional; + vector> named; + + if (TrySplitArgumentTypes(arguments, positional, named)) { + return BindFunctionFromArguments(name, functions, positional, named, error); + } + + // The split failed because a positional argument follows a named one. + // Check if there is any overload that supports implicit argument names. 
+ if (AnyOverloadSupportsImplicitArgumentNames(functions)) { + // If so, we can attempt to salvage the call by implicitly naming the positional arguments and retrying again + for (auto &[name, expr] : arguments) { + if (name.empty()) { + name = expr->GetAlias(); + } + } + + positional.clear(); + named.clear(); + + if (TrySplitArgumentTypes(arguments, positional, named)) { + return BindFunctionFromArguments(name, functions, positional, named, error); + } + } + + // No overload could rescue the positional-after-named call, give a clear error. + return PositionalAfterNamedArgumentError(arguments, error); +} + +optional_idx FunctionBinder::BindFunction(const Identifier &name, const ScalarFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, keyword_args, error); } -optional_idx FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const AggregateFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, keyword_args, error); } -optional_idx FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const WindowFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, keyword_args, error); } -optional_idx FunctionBinder::BindFunction(const string &name, WindowFunctionSet &functions, - const vector &arguments, ErrorData &error) { - 
return BindFunctionFromArguments(name, functions, arguments, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const TableFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, keyword_args, error); } -optional_idx FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, vector ¶meters, - ErrorData &error) { +optional_idx FunctionBinder::BindFunction(const Identifier &name, const PragmaFunctionSet &functions, + vector ¶meters, ErrorData &error) { vector types; for (auto &value : parameters) { types.push_back(value.type()); } - auto entry = BindFunctionFromArguments(name, functions, types, error); + auto entry = BindFunctionFromArguments(name, functions, types, {}, error); if (!entry.IsValid()) { error.Throw(); } @@ -273,31 +427,43 @@ optional_idx FunctionBinder::BindFunction(const string &name, PragmaFunctionSet return entry; } -vector FunctionBinder::GetLogicalTypesFromExpressions(vector> &arguments) { - vector types; - types.reserve(arguments.size()); - for (auto &argument : arguments) { - types.push_back(ExpressionBinder::GetExpressionReturnType(*argument)); +pair, vector>> +FunctionBinder::GetArgumentsFromExpressions(const vector> ®ular_args, + const vector>> &keyword_args) { + vector regular_arg_types; + vector> keyword_arg_types; + + for (auto &arg : regular_args) { + regular_arg_types.push_back(ExpressionBinder::GetExpressionReturnType(*arg)); } - return types; + for (auto &kwarg : keyword_args) { + keyword_arg_types.emplace_back(kwarg.first, ExpressionBinder::GetExpressionReturnType(*kwarg.second)); + } + return {std::move(regular_arg_types), std::move(keyword_arg_types)}; } -optional_idx FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, 
functions, types, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const ScalarFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, + ErrorData &error) { + auto [args, kwargs] = GetArgumentsFromExpressions(regular_args, keyword_args); + return BindFunctionFromArguments(name, functions, args, kwargs, error); } -optional_idx FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const AggregateFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, + ErrorData &error) { + auto [args, kwargs] = GetArgumentsFromExpressions(regular_args, keyword_args); + return BindFunctionFromArguments(name, functions, args, kwargs, error); } -optional_idx FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); +optional_idx FunctionBinder::BindFunction(const Identifier &name, const TableFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, + ErrorData &error) { + auto [args, kwargs] = GetArgumentsFromExpressions(regular_args, keyword_args); + return BindFunctionFromArguments(name, functions, args, kwargs, error); } enum class LogicalTypeComparisonResult : uint8_t { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; @@ -379,7 +545,7 @@ void FunctionBinder::CastToFunctionArguments(BoundSimpleFunction &function, vect } } -unique_ptr FunctionBinder::BindScalarFunction(const string &schema, const string &name, +unique_ptr FunctionBinder::BindScalarFunction(const Identifier &schema, const Identifier &name, vector> children, ErrorData &error, bool is_operator, 
optional_ptr binder) { // bind the function @@ -388,11 +554,24 @@ unique_ptr FunctionBinder::BindScalarFunction(const string &schema, return BindScalarFunction(function, std::move(children), error, is_operator, binder); } -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func, +unique_ptr FunctionBinder::BindScalarFunction(const ScalarFunctionCatalogEntry &func, vector> children, ErrorData &error, bool is_operator, optional_ptr binder) { - // bind the function - auto best_function = BindFunction(func.name, func.functions, children, error); + vector>> arguments; + arguments.reserve(children.size()); + for (auto &child : children) { + arguments.emplace_back(string(), std::move(child)); + } + return BindScalarFunction(func, std::move(arguments), error, is_operator, binder); +} + +unique_ptr FunctionBinder::BindScalarFunction(const ScalarFunctionCatalogEntry &func, + vector>> arguments, + ErrorData &error, bool is_operator, + optional_ptr binder) { + // select the best matching overload (this may name positional arguments by their alias for functions that opt + // into implicit argument naming, e.g. struct_pack/row) + auto best_function = BindFunctionFromArguments(func.name, func.functions, arguments, error); if (!best_function.IsValid()) { return nullptr; } @@ -400,6 +579,9 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE // found a matching function! const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); + // now that the overload is fixed, split the arguments into their final positional/named children + auto [regular_args, keyword_args] = SplitArguments(std::move(arguments)); + // If any of the parameters are NULL, the function will just be replaced with a NULL constant. // We try to give the NULL constant the correct type, but we have to do this without binding the function, // because functions with DEFAULT_NULL_HANDLING should not have to deal with NULL inputs in their bind code. 
@@ -408,7 +590,7 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE const auto return_type_if_null = bound_function.GetReturnType().IsComplete() ? bound_function.GetReturnType() : LogicalType::SQLNULL; if (bound_function.GetNullHandling() == FunctionNullHandling::DEFAULT_NULL_HANDLING) { - for (auto &child : children) { + for (auto &child : regular_args) { if (child->GetReturnType() == LogicalTypeId::SQLNULL) { return make_uniq(Value(return_type_if_null)); } @@ -423,8 +605,23 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE return make_uniq(Value(return_type_if_null)); } } + for (auto &child : keyword_args) { + if (child.second->GetReturnType() == LogicalTypeId::SQLNULL) { + return make_uniq(Value(return_type_if_null)); + } + if (!child.second->IsFoldable()) { + continue; + } + Value result; + if (!ExpressionExecutor::TryEvaluateScalar(context, *child.second, result)) { + continue; + } + if (result.IsNull()) { + return make_uniq(Value(return_type_if_null)); + } + } } - return BindScalarFunction(bound_function, std::move(children), is_operator, binder); + return BindScalarFunction(bound_function, std::move(regular_args), std::move(keyword_args), is_operator, binder); } static bool RequiresCollationPropagation(const LogicalType &type) { @@ -634,7 +831,7 @@ static void InferTemplateType(ClientContext &context, const LogicalType &source, } static void SubstituteTemplateType(LogicalType &type, case_insensitive_map_t> &bindings, - const string &function_name) { + const Identifier &function_name) { // Replace all template types in with their bound concrete types. 
type = TypeVisitor::VisitReplace(type, [&](const LogicalType &t) -> LogicalType { if (t.id() == LogicalTypeId::TEMPLATE) { @@ -683,7 +880,7 @@ void FunctionBinder::ResolveTemplateTypes(BoundSimpleFunction &bound_function, } } -static void VerifyTemplateType(const LogicalType &type, const string &function_name) { +static void VerifyTemplateType(const LogicalType &type, const Identifier &function_name) { TypeVisitor::Contains(type, [&](const LogicalType &type) { if (type.id() == LogicalTypeId::TEMPLATE) { const auto msg = @@ -702,28 +899,126 @@ void FunctionBinder::CheckTemplateTypesResolved(const BoundSimpleFunction &bound VerifyTemplateType(bound_function.GetReturnType(), bound_function.GetName()); } +// Drain all named argument and insert them in the correct position according to the function signature. +// Also insert default arguments where needed. +static void ResolveArguments(const SimpleFunction &function, vector> &arguments, + vector>> &named_arguments) { + const auto &sig = function.GetSignature(); + + const auto kwargs_offset = arguments.size(); + + // Reserve space for the named arguments + if (arguments.size() < sig.GetParameterCount()) { + arguments.resize(sig.GetParameterCount()); + } + + identifier_set_t seen_names; + + vector> trailing_kwargs; + + // We now need to reorder them to match the function signature, before appending them to the argument list. + for (idx_t kwarg_idx = 0; kwarg_idx < named_arguments.size(); kwarg_idx++) { + auto &[name, arg] = named_arguments[kwarg_idx]; + const auto location = arg->GetQueryLocation(); + + if (name.empty()) { + // Somehow this was not a named argument, throw an error. 
+ throw BinderException(location, + "Positional arguments cannot follow named arguments in a function call to '%s'", + function.GetName()); + } + + if (seen_names.count(name)) { + // This should also not really happen when invoked through SQL + throw BinderException(location, "Duplicate named argument '%s' in function call to '%s'", + name.GetIdentifierName(), function.GetName()); + } + + seen_names.insert(name); + + const auto opt_param_idx = sig.GetParameterIndexByName(name); + if (!opt_param_idx.IsValid()) { + if (!sig.HasVarArgs()) { + throw BinderException(location, "Function '%s' does not have a parameter named '%s'", + function.GetName(), name); + } + + // This is a named vararg argument, come back for it later + trailing_kwargs.push_back(std::move(arg)); + continue; + } + + const auto param_idx = opt_param_idx.GetIndex(); + + if (param_idx < kwargs_offset) { + throw BinderException(location, + "Named argument '%s' cannot be used for parameter '%s' because it has already " + "been provided as a positional argument in function call to '%s'", + arg->ToString(), name, function.GetName()); + } + + // Move it into the correct reserved position + arguments[param_idx] = std::move(arg); + } + + // Fill out missing arguments with default values if they exist, otherwise throw an error. 
+ for (idx_t i = 0; i < sig.GetParameterCount(); i++) { + if (arguments[i]) { + continue; + } + + const auto ¶m = sig.GetParameter(i); + + if (param.HasDefaultValue()) { + arguments[i] = make_uniq(*param.GetDefaultValue()); + arguments[i]->SetAlias(param.GetName()); + + } else { + throw BinderException("Missing value for parameter '%s' in function call to '%s'", param.GetName(), + function.GetName()); + } + } + + // Now spread out any trailing named vararg arguments into the remaining argument slots, wherever they may be + idx_t kwarg_idx = 0; + for (idx_t slot_idx = kwargs_offset; slot_idx < arguments.size(); slot_idx++) { + if (!arguments[slot_idx]) { + arguments[slot_idx] = std::move(trailing_kwargs[kwarg_idx++]); + } + } + + // And if there are still some left, just append them to the end + idx_t kwargs_remaining = trailing_kwargs.size() - kwarg_idx; + while (kwargs_remaining) { + arguments.push_back(std::move(trailing_kwargs[kwarg_idx++])); + kwargs_remaining--; + } +} + pair> -FunctionBinder::ResolveFunction(const ScalarFunction &function, vector> &children) { +FunctionBinder::ResolveFunction(const ScalarFunction &function, vector> &arguments, + vector>> &named_arguments) { + // Reorder named args + ResolveArguments(function, arguments, named_arguments); + // Make a BoundScalarFunction out of the ScalarFunction, so we can store bind info and other properties in it. BoundScalarFunction bound_function(function); // Expand varargs if necessary if (function.HasVarArgs()) { const auto &varargs_type = function.GetVarArgs(); - for (idx_t i = function.GetSignature().GetParameterCount(); i < children.size(); i++) { + for (idx_t i = function.GetSignature().GetParameterCount(); i < arguments.size(); i++) { bound_function.GetArguments().push_back(varargs_type); } } // Attempt to resolve template types, before we call the "Bind" callback. 
- ResolveTemplateTypes(bound_function, children); - - // ResolveTemplateTypes(bound_function.arguments, bound_function.return_type, children, bound_function); + ResolveTemplateTypes(bound_function, arguments); unique_ptr bind_info; if (bound_function.HasBindCallback()) { - BindScalarFunctionInput input(context, bound_function, children, binder); + BindScalarFunctionInput input(context, bound_function, arguments, binder); bind_info = bound_function.GetBindCallback()(input); } @@ -736,10 +1031,10 @@ FunctionBinder::ResolveFunction(const ScalarFunction &function, vector FunctionBinder::BindScalarFunction(const ScalarFunction &function, vector> children, bool is_operator, optional_ptr binder) { - auto [bound_function, bind_info] = ResolveFunction(function, children); + return BindScalarFunction(function, std::move(children), {}, is_operator, binder); +} + +unique_ptr FunctionBinder::BindScalarFunction(const ScalarFunction &function, + vector> children, + vector>> keyword_args, + bool is_operator, optional_ptr binder) { + auto [bound_function, bind_info] = ResolveFunction(function, children, keyword_args); unique_ptr result; auto result_func = make_uniq(std::move(bound_function), std::move(children), std::move(bind_info), is_operator); - if (result_func->function.HasBindExpressionCallback()) { + if (result_func->Function().HasBindExpressionCallback()) { // if a bind_expression callback is registered - call it and emit the resulting expression - FunctionBindExpressionInput input(context, result_func->function, result_func->bind_info.get(), - result_func->children); - result = result_func->function.GetBindExpressionCallback()(input); + FunctionBindExpressionInput input(context, result_func->FunctionMutable(), result_func->BindInfoMutable().get(), + result_func->GetChildrenMutable()); + result = result_func->Function().GetBindExpressionCallback()(input); } if (!result) { @@ -769,7 +1071,11 @@ unique_ptr FunctionBinder::BindScalarFunction(const ScalarFunction & } pair> 
-FunctionBinder::ResolveFunction(const AggregateFunction &function, vector> &children) { +FunctionBinder::ResolveFunction(const AggregateFunction &function, vector> &children, + vector>> &named_arguments) { + // Reorder named args + ResolveArguments(function, children, named_arguments); + // Make a BoundFunction out of the func BoundAggregateFunction bound_function(function); @@ -805,16 +1111,48 @@ unique_ptr FunctionBinder::BindAggregateFunction(const vector> children, unique_ptr filter, AggregateType aggr_type) { - auto [bound_function, bind_info] = ResolveFunction(function, children); + return BindAggregateFunction(function, std::move(children), {}, std::move(filter), aggr_type); +} + +unique_ptr +FunctionBinder::BindAggregateFunction(const AggregateFunction &function, vector> children, + vector>> keyword_args, + unique_ptr filter, AggregateType aggr_type) { + auto [bound_function, bind_info] = ResolveFunction(function, children, keyword_args); return make_uniq(std::move(bound_function), std::move(children), std::move(filter), std::move(bind_info), aggr_type); } +unique_ptr +FunctionBinder::BindAggregateFunction(const AggregateFunctionCatalogEntry &func, + vector>> arguments, ErrorData &error, + unique_ptr filter, AggregateType aggr_type) { + // select the best matching overload (this may name positional arguments by their alias for functions that opt + // into implicit argument naming, e.g. struct_pack/row) + auto best_function = BindFunctionFromArguments(func.name, func.functions, arguments, error); + if (!best_function.IsValid()) { + return nullptr; + } + + // found a matching function! 
+ const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); + + // now that the overload is fixed, split the arguments into their final positional/named children + auto [regular_args, keyword_args] = SplitArguments(std::move(arguments)); + + return BindAggregateFunction(bound_function, std::move(regular_args), std::move(keyword_args), std::move(filter), + aggr_type); +} + pair> FunctionBinder::ResolveFunction(const WindowFunction &function, vector> &children, + vector>> &named_arguments, optional_ptr> orders, optional_ptr> arg_orders) { + // Reorder named args + ResolveArguments(function, children, named_arguments); + BoundWindowFunction bound_function(function); // Expand varargs if necessary @@ -844,18 +1182,45 @@ FunctionBinder::ResolveFunction(const WindowFunction &function, vector FunctionBinder::BindWindowFunction(const WindowFunction &function, - vector> children, - vector &orders, - vector &arg_orders) { - auto [bound_function, bind_info] = ResolveFunction(function, children, orders, arg_orders); +unique_ptr +FunctionBinder::BindWindowFunction(const WindowFunction &function, vector> children, + vector>> keyword_args, + vector &orders, vector &arg_orders) { + auto [bound_function, bind_info] = ResolveFunction(function, children, keyword_args, orders, arg_orders); auto return_type = bound_function.GetReturnType(); auto window = make_uniq(std::move(bound_function)); auto result = make_uniq(return_type, nullptr, std::move(window), std::move(bind_info)); - result->children = std::move(children); + result->GetChildrenMutable() = std::move(children); return result; } +unique_ptr FunctionBinder::BindWindowFunction(const WindowFunction &function, + vector> children, + vector &orders, + vector &arg_orders) { + vector>> empty_keyword_args; + return BindWindowFunction(function, std::move(children), std::move(empty_keyword_args), orders, arg_orders); +} + +unique_ptr +FunctionBinder::BindWindowFunction(const WindowFunctionCatalogEntry 
&func, + vector>> arguments, ErrorData &error, + vector &orders, vector &arg_orders) { + // select the best matching overload + auto best_function = BindFunctionFromArguments(func.name, func.functions, arguments, error); + if (!best_function.IsValid()) { + return nullptr; + } + + // found a matching function! + const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); + + // now that the overload is fixed, split the arguments into their final positional/named children + auto [regular_args, keyword_args] = SplitArguments(std::move(arguments)); + + return BindWindowFunction(bound_function, std::move(regular_args), std::move(keyword_args), orders, arg_orders); +} + } // namespace duckdb diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index e3aa7527c..5dfa42e7d 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -13,6 +13,7 @@ #include "duckdb/function/scalar/string_functions.hpp" #include "duckdb/function/scalar/struct_functions.hpp" #include "duckdb/function/scalar/system_functions.hpp" +#include "duckdb/function/scalar/tablefilter_functions.hpp" #include "duckdb/function/window/ranking_functions.hpp" #include "duckdb/function/window/rows_functions.hpp" #include "duckdb/function/window/value_functions.hpp" @@ -20,45 +21,69 @@ #include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" - namespace duckdb { // Scalar Function -#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, _PARAM::GetFunction, nullptr, nullptr, nullptr, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, 
_PARAM::Name, _PARAM::ALIAS::Name) +#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, \ + _PARAM::GetFunction, nullptr, nullptr, nullptr, nullptr, nullptr \ + } +#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) \ + DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Scalar Function Set -#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, _PARAM::GetFunctions, nullptr, nullptr, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) +#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, \ + _PARAM::GetFunctions, nullptr, nullptr, nullptr, nullptr \ + } +#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) \ + DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Aggregate Function -#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, nullptr, _PARAM::GetFunction, nullptr, nullptr, nullptr } -#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name, 
_PARAM::ALIAS::Name) +#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, \ + nullptr, _PARAM::GetFunction, nullptr, nullptr, nullptr \ + } +#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) \ + DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Aggregate Function Set -#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, nullptr, nullptr, _PARAM::GetFunctions, nullptr, nullptr } -#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) +#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, \ + nullptr, nullptr, _PARAM::GetFunctions, nullptr, nullptr \ + } +#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) \ + DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Window Function -#define DUCKDB_WINDOW_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, nullptr, nullptr, nullptr, _PARAM::GetFunction, nullptr } -#define DUCKDB_WINDOW_FUNCTION(_PARAM) DUCKDB_WINDOW_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_WINDOW_FUNCTION_ALIAS(_PARAM) DUCKDB_WINDOW_FUNCTION_BASE(_PARAM::ALIAS, 
_PARAM::Name, _PARAM::ALIAS::Name) +#define DUCKDB_WINDOW_FUNCTION_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, \ + nullptr, nullptr, nullptr, _PARAM::GetFunction, nullptr \ + } +#define DUCKDB_WINDOW_FUNCTION(_PARAM) DUCKDB_WINDOW_FUNCTION_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_WINDOW_FUNCTION_ALIAS(_PARAM) \ + DUCKDB_WINDOW_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Window Function Set -#define DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ - { _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, nullptr, nullptr, nullptr, nullptr, _PARAM::GetFunctions } -#define DUCKDB_WINDOW_FUNCTION_SET(_PARAM) DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) -#define DUCKDB_WINDOW_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) +#define DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM, _NAME, _ALIAS_OF) \ + { \ + _NAME, _ALIAS_OF, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::Categories, nullptr, \ + nullptr, nullptr, nullptr, nullptr, _PARAM::GetFunctions \ + } +#define DUCKDB_WINDOW_FUNCTION_SET(_PARAM) DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM, _PARAM::Name, _PARAM::Name) +#define DUCKDB_WINDOW_FUNCTION_SET_ALIAS(_PARAM) \ + DUCKDB_WINDOW_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name, _PARAM::ALIAS::Name) // Final Function #define FINAL_FUNCTION \ { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } // this list is generated by scripts/generate_functions.py static const StaticFunctionDefinition function[] = { + DUCKDB_SCALAR_FUNCTION(OperatorNotEqualFun), DUCKDB_SCALAR_FUNCTION(NotLikeFun), DUCKDB_SCALAR_FUNCTION(NotILikeFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorModuloFun), @@ -68,8 +93,14 @@ static const 
StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(OperatorSubtractFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorFloatDivideFun), DUCKDB_SCALAR_FUNCTION_SET(OperatorIntegerDivideFun), + DUCKDB_SCALAR_FUNCTION(OperatorLessThanFun), + DUCKDB_SCALAR_FUNCTION(OperatorLessThanEqualsFun), + DUCKDB_SCALAR_FUNCTION(OperatorEqualFun), + DUCKDB_SCALAR_FUNCTION(OperatorGreaterThanFun), + DUCKDB_SCALAR_FUNCTION(OperatorGreaterThanEqualsFun), + DUCKDB_SCALAR_FUNCTION(IsDistinctFromFun), + DUCKDB_SCALAR_FUNCTION(IsNotDistinctFromFun), DUCKDB_SCALAR_FUNCTION(BetweenFun), - DUCKDB_SCALAR_FUNCTION(ComparisonFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUbigintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUintegerFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUsmallintFun), @@ -89,6 +120,12 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUintegerFun), DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUsmallintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressStringFun), + DUCKDB_SCALAR_FUNCTION(TableFilterBloomFilterFun), + DUCKDB_SCALAR_FUNCTION(TableFilterDynamicFun), + DUCKDB_SCALAR_FUNCTION(TableFilterOptionalFun), + DUCKDB_SCALAR_FUNCTION(TableFilterPerfectHashJoinFun), + DUCKDB_SCALAR_FUNCTION(TableFilterPrefixRangeFun), + DUCKDB_SCALAR_FUNCTION(TableFilterSelectivityOptionalFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AddFun), DUCKDB_AGGREGATE_FUNCTION_SET(AnyValueFun), DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArbitraryFun), @@ -108,7 +145,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(BitLengthFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CharLengthFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CharacterLengthFun), - DUCKDB_SCALAR_FUNCTION_SET(CombineFun), + DUCKDB_SCALAR_FUNCTION(CombineFun), DUCKDB_AGGREGATE_FUNCTION(CombineAggrFun), DUCKDB_SCALAR_FUNCTION(ConcatFun), DUCKDB_SCALAR_FUNCTION(ConcatWsFun), @@ -122,21 +159,22 @@ static const 
StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(CurrentQueryId), DUCKDB_SCALAR_FUNCTION(CurrentTransactionId), DUCKDB_SCALAR_FUNCTION(CurrvalFun), + DUCKDB_SCALAR_FUNCTION_SET(DecimalDivisionFun), DUCKDB_WINDOW_FUNCTION(DenseRankFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DivideFun), DUCKDB_SCALAR_FUNCTION_ALIAS(EndsWithFun), DUCKDB_SCALAR_FUNCTION(ErrorFun), DUCKDB_WINDOW_FUNCTION(FillFun), - DUCKDB_SCALAR_FUNCTION_SET(FinalizeFun), + DUCKDB_SCALAR_FUNCTION(FinalizeFun), DUCKDB_AGGREGATE_FUNCTION_SET(FirstFun), DUCKDB_WINDOW_FUNCTION(FirstValueFun), DUCKDB_SCALAR_FUNCTION(GetVariableFun), DUCKDB_SCALAR_FUNCTION(IlikeEscapeFun), - DUCKDB_WINDOW_FUNCTION_SET(LagFun), + DUCKDB_WINDOW_FUNCTION(LagFun), DUCKDB_AGGREGATE_FUNCTION_SET(LastFun), DUCKDB_WINDOW_FUNCTION(LastValueFun), DUCKDB_SCALAR_FUNCTION_ALIAS(LcaseFun), - DUCKDB_WINDOW_FUNCTION_SET(LeadFun), + DUCKDB_WINDOW_FUNCTION(LeadFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LenFun), DUCKDB_SCALAR_FUNCTION_SET(LengthFun), DUCKDB_SCALAR_FUNCTION_SET(LengthGraphemeFun), @@ -169,6 +207,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_WINDOW_FUNCTION(NthValueFun), DUCKDB_WINDOW_FUNCTION(NtileFun), DUCKDB_SCALAR_FUNCTION_SET(OctetLengthFun), + DUCKDB_SCALAR_FUNCTION_SET(OverlayFun), DUCKDB_SCALAR_FUNCTION(ParseLogMessage), DUCKDB_SCALAR_FUNCTION(PathJoinFun), DUCKDB_WINDOW_FUNCTION(PercentRankFun), @@ -218,10 +257,13 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(SubstringGraphemeFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(SubtractFun), DUCKDB_SCALAR_FUNCTION(SuffixFun), + DUCKDB_SCALAR_FUNCTION(ToAggregateStateFun), DUCKDB_SCALAR_FUNCTION_SET(TryStrpTimeFun), DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), DUCKDB_SCALAR_FUNCTION(UpperFun), + DUCKDB_SCALAR_FUNCTION_SET(VariantExistsFun), DUCKDB_SCALAR_FUNCTION_SET(VariantExtractFun), + DUCKDB_SCALAR_FUNCTION_SET(VariantKeysFun), DUCKDB_SCALAR_FUNCTION(VariantNormalizeFun), DUCKDB_SCALAR_FUNCTION(VariantTypeofFun), 
DUCKDB_SCALAR_FUNCTION_SET(WriteLogFun), diff --git a/src/duckdb/src/function/function_set.cpp b/src/duckdb/src/function/function_set.cpp index 767ea4f72..14e3de92d 100644 --- a/src/duckdb/src/function/function_set.cpp +++ b/src/duckdb/src/function/function_set.cpp @@ -6,10 +6,10 @@ namespace duckdb { ScalarFunctionSet::ScalarFunctionSet() : FunctionSet("") { } -ScalarFunctionSet::ScalarFunctionSet(string name) : FunctionSet(std::move(name)) { +ScalarFunctionSet::ScalarFunctionSet(Identifier name) : FunctionSet(std::move(name)) { } -ScalarFunctionSet::ScalarFunctionSet(ScalarFunction fun) : FunctionSet(std::move(fun.name)) { +ScalarFunctionSet::ScalarFunctionSet(ScalarFunction fun) : FunctionSet(fun.name) { functions.push_back(std::move(fun)); } @@ -19,8 +19,8 @@ const ScalarFunction &ScalarFunctionSet::GetFunctionByArguments(ClientContext &c FunctionBinder binder(context); auto index = binder.BindFunction(name, *this, arguments, error); if (!index.IsValid()) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); + throw BinderException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error.RawMessage()); } return GetFunctionByOffset(index.GetIndex()); } @@ -28,10 +28,10 @@ const ScalarFunction &ScalarFunctionSet::GetFunctionByArguments(ClientContext &c AggregateFunctionSet::AggregateFunctionSet() : FunctionSet("") { } -AggregateFunctionSet::AggregateFunctionSet(string name) : FunctionSet(std::move(name)) { +AggregateFunctionSet::AggregateFunctionSet(Identifier name) : FunctionSet(std::move(name)) { } -AggregateFunctionSet::AggregateFunctionSet(AggregateFunction fun) : FunctionSet(std::move(fun.name)) { +AggregateFunctionSet::AggregateFunctionSet(AggregateFunction fun) : FunctionSet(fun.name) { functions.push_back(std::move(fun)); } @@ -60,8 +60,8 @@ const AggregateFunction &AggregateFunctionSet::GetFunctionByArguments(ClientCont return func; } } - throw 
InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); + throw BinderException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error.RawMessage()); } return GetFunctionByOffset(index.GetIndex()); } @@ -69,10 +69,10 @@ const AggregateFunction &AggregateFunctionSet::GetFunctionByArguments(ClientCont WindowFunctionSet::WindowFunctionSet() : FunctionSet("") { } -WindowFunctionSet::WindowFunctionSet(string name) : FunctionSet(std::move(name)) { +WindowFunctionSet::WindowFunctionSet(Identifier name) : FunctionSet(std::move(name)) { } -WindowFunctionSet::WindowFunctionSet(WindowFunction fun) : FunctionSet(std::move(fun.name)) { +WindowFunctionSet::WindowFunctionSet(WindowFunction fun) : FunctionSet(fun.name) { functions.push_back(std::move(fun)); } @@ -82,16 +82,16 @@ const WindowFunction &WindowFunctionSet::GetFunctionByArguments(ClientContext &c FunctionBinder binder(context); auto index = binder.BindFunction(name, *this, arguments, error); if (!index.IsValid()) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); + throw BinderException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error.RawMessage()); } return GetFunctionByOffset(index.GetIndex()); } -TableFunctionSet::TableFunctionSet(string name) : FunctionSet(std::move(name)) { +TableFunctionSet::TableFunctionSet(Identifier name) : FunctionSet(std::move(name)) { } -TableFunctionSet::TableFunctionSet(TableFunction fun) : FunctionSet(std::move(fun.name)) { +TableFunctionSet::TableFunctionSet(TableFunction fun) : FunctionSet(fun.name) { functions.push_back(std::move(fun)); } @@ -101,16 +101,16 @@ const TableFunction &TableFunctionSet::GetFunctionByArguments(ClientContext &con FunctionBinder binder(context); auto index = binder.BindFunction(name, *this, arguments, error); if (!index.IsValid()) { - throw 
InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); + throw BinderException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), + error.RawMessage()); } return GetFunctionByOffset(index.GetIndex()); } -PragmaFunctionSet::PragmaFunctionSet(string name) : FunctionSet(std::move(name)) { +PragmaFunctionSet::PragmaFunctionSet(Identifier name) : FunctionSet(std::move(name)) { } -PragmaFunctionSet::PragmaFunctionSet(PragmaFunction fun) : FunctionSet(std::move(fun.name)) { +PragmaFunctionSet::PragmaFunctionSet(PragmaFunction fun) : FunctionSet(fun.name) { functions.push_back(std::move(fun)); } diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp index d41803adb..b885392b9 100644 --- a/src/duckdb/src/function/macro_function.cpp +++ b/src/duckdb/src/function/macro_function.cpp @@ -17,7 +17,7 @@ namespace duckdb { MacroFunction::MacroFunction(MacroType type) : type(type) { } -string FormatMacroFunction(const MacroFunction &function, const string &name) { +string FormatMacroFunction(const MacroFunction &function, const Identifier &name) { auto result = name + "("; string parameters; for (idx_t param_idx = 0; param_idx < function.parameters.size(); param_idx++) { @@ -42,9 +42,10 @@ string FormatMacroFunction(const MacroFunction &function, const string &name) { } MacroBindResult MacroFunction::BindMacroFunction( - Binder &binder, const vector> &functions, const string &name, + Binder &binder, const vector> &functions, const Identifier &name, FunctionExpression &function_expr, vector> &positional_arguments, - InsertionOrderPreservingMap> &named_arguments, idx_t depth) { + InsertionOrderPreservingMap, Identifier, identifier_map_t> &named_arguments, + idx_t depth) { ExpressionBinder expr_binder(binder, binder.context); expr_binder.lambda_bindings = binder.lambda_bindings; @@ -64,28 +65,28 @@ MacroBindResult 
MacroFunction::BindMacroFunction( // Find argument types and separate positional and default arguments vector positional_arg_types; - InsertionOrderPreservingMap named_arg_types; - for (auto &arg : function_expr.children) { - auto arg_copy = arg->Copy(); + InsertionOrderPreservingMap> named_arg_types; + for (auto &arg : function_expr.GetArgumentsMutable()) { + auto arg_copy = arg.GetExpression().Copy(); LogicalType arg_type = LogicalType::UNKNOWN; if (requires_bind) { const auto arg_bind_result = expr_binder.BindExpression(arg_copy, depth + 1); arg_type = arg_bind_result.HasError() ? LogicalType::UNKNOWN : arg_bind_result.expression->GetReturnType(); } - if (!arg->GetAlias().empty()) { + if (!arg.GetExpression().GetAlias().empty()) { // Default argument - if (named_arguments.find(arg->GetAlias()) != named_arguments.end()) { - return MacroBindResult( - StringUtil::Format("Macro %s() has named argument repeated '%s'", name, arg->GetAlias())); + if (named_arguments.find(arg.GetExpression().GetAlias()) != named_arguments.end()) { + return MacroBindResult(StringUtil::Format("Macro %s() has named argument repeated '%s'", name, + arg.GetExpression().GetAlias())); } - named_arg_types.insert(arg->GetAlias(), std::move(arg_type)); - named_arguments[arg->GetAlias()] = std::move(arg); + named_arg_types.insert(arg.GetExpression().GetAlias(), std::move(arg_type)); + named_arguments[arg.GetExpression().GetAlias()] = std::move(arg.GetExpressionMutable()); } else if (!named_arguments.empty()) { return MacroBindResult( StringUtil::Format("Macro %s() has positional argument following named argument", name)); } else { // Positional argument - positional_arguments.push_back(std::move(arg)); + positional_arguments.push_back(std::move(arg.GetExpressionMutable())); positional_arg_types.push_back(std::move(arg_type)); } } @@ -111,7 +112,7 @@ MacroBindResult MacroFunction::BindMacroFunction( for (auto &kv : named_arguments) { bool found = false; for (const auto ¶meter : 
function->parameters) { - if (StringUtil::CIEquals(kv.first, parameter->Cast().GetColumnName())) { + if (kv.first == parameter->Cast().GetColumnName()) { found = true; break; } @@ -254,23 +255,22 @@ MacroBindResult MacroFunction::BindMacroFunction( return MacroBindResult(macro_idx); } -unique_ptr -MacroFunction::CreateDummyBinding(const MacroFunction ¯o_def, const string &name, - vector> &positional_arguments, - InsertionOrderPreservingMap> &named_arguments) { +unique_ptr MacroFunction::CreateDummyBinding( + const MacroFunction ¯o_def, const Identifier &name, vector> &positional_arguments, + InsertionOrderPreservingMap, Identifier, identifier_map_t> &named_arguments) { // create a MacroBinding to bind this macro's parameters to its arguments vector types = macro_def.types; types.resize(macro_def.parameters.size(), LogicalType::UNKNOWN); - vector names; + vector names; for (idx_t i = 0; i < positional_arguments.size(); i++) { - names.push_back(macro_def.parameters[i]->Cast().GetColumnName()); + names.emplace_back(macro_def.parameters[i]->Cast().GetColumnName()); } for (auto &kv : named_arguments) { - names.push_back(kv.first); + names.emplace_back(kv.first); positional_arguments.push_back(std::move(kv.second)); // push defaults into positionals } - auto res = make_uniq(types, names, name); + auto res = make_uniq(types, names, name.GetIdentifierName()); res->arguments = &positional_arguments; return res; } @@ -289,7 +289,7 @@ void MacroFunction::CopyProperties(MacroFunction &other) const { vector> MacroFunction::GetPositionalParametersForSerialization(Serializer &serializer) const { vector> result; - if (serializer.ShouldSerialize(6)) { + if (serializer.ShouldSerialize(StorageVersion::V1_4_0)) { // We serialize all positional parameters as-is for (auto ¶m : parameters) { result.push_back(param->Copy()); @@ -299,7 +299,7 @@ MacroFunction::GetPositionalParametersForSerialization(Serializer &serializer) c // Serializing targeting an older version - delete all named 
parameters from the list of positional parameters for (auto ¶m : parameters) { auto &colref = param->Cast(); - if (default_parameters.find(colref.GetName()) != default_parameters.end()) { + if (default_parameters.find(Identifier(colref.GetName())) != default_parameters.end()) { // This is a default parameter - do not serialize continue; } @@ -313,7 +313,7 @@ void MacroFunction::FinalizeDeserialization() { for (auto &kv : default_parameters) { bool found = false; for (const auto ¶meter : parameters) { - if (StringUtil::CIEquals(kv.first, parameter->Cast().GetColumnName())) { + if (kv.first == parameter->Cast().GetColumnName()) { found = true; break; } @@ -332,7 +332,7 @@ string MacroFunction::ToSQL() const { vector param_strings; for (idx_t param_idx = 0; param_idx < parameters.size(); param_idx++) { const auto ¶m_name = parameters[param_idx]->Cast().GetColumnName(); - auto param_string = param_name; + auto param_string = param_name.GetIdentifierName(); if (types[param_idx] != LogicalType::UNKNOWN) { param_string += " " + types[param_idx].ToString(); } @@ -340,7 +340,7 @@ string MacroFunction::ToSQL() const { if (it != default_parameters.end()) { param_string += " := " + it->second->ToString(); } - param_strings.push_back(std::move(param_string)); + param_strings.emplace_back(std::move(param_string)); } return StringUtil::Format("(%s) AS ", StringUtil::Join(param_strings, ", ")); } diff --git a/src/duckdb/src/function/pragma/pragma_functions.cpp b/src/duckdb/src/function/pragma/pragma_functions.cpp index fb9e3c2b0..163e1a29d 100644 --- a/src/duckdb/src/function/pragma/pragma_functions.cpp +++ b/src/duckdb/src/function/pragma/pragma_functions.cpp @@ -11,6 +11,7 @@ #include "duckdb/planner/expression_binder.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/storage_manager.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/common/encryption_functions.hpp" #include "duckdb/logging/log_manager.hpp" @@ -26,7 +27,7 @@ static void 
PragmaEnableProfilingStatement(ClientContext &context, const Functio void RegisterEnableProfiling(BuiltinFunctions &set) { PragmaFunctionSet functions(""); - functions.AddFunction(PragmaFunction::PragmaStatement(string(), PragmaEnableProfilingStatement)); + functions.AddFunction(PragmaFunction::PragmaStatement(Identifier(), PragmaEnableProfilingStatement)); set.AddFunction("enable_profile", functions); set.AddFunction("enable_profiling", functions); @@ -87,11 +88,11 @@ static void PragmaDisableCheckpointOnShutdown(ClientContext &context, const Func } static void PragmaEnableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = true; + Settings::Set(context, SetScope::SESSION, Value::BOOLEAN(true)); } static void PragmaDisableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = false; + Settings::Set(context, SetScope::SESSION, Value::BOOLEAN(false)); } void PragmaFunctions::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/pragma_function.cpp b/src/duckdb/src/function/pragma_function.cpp index 531cdfb43..9208a906b 100644 --- a/src/duckdb/src/function/pragma_function.cpp +++ b/src/duckdb/src/function/pragma_function.cpp @@ -3,28 +3,28 @@ namespace duckdb { -PragmaFunction::PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, - vector arguments, LogicalType varargs) +PragmaFunction::PragmaFunction(Identifier name, PragmaType pragma_type, pragma_query_t query, + pragma_function_t function, vector arguments, LogicalType varargs) : SimpleNamedParameterFunction(std::move(name), std::move(arguments), std::move(varargs)), type(pragma_type), query(query), function(function) { } -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_query_t query, vector arguments, +PragmaFunction PragmaFunction::PragmaCall(const Identifier &name, pragma_query_t 
query, vector arguments, LogicalType varargs) { return PragmaFunction(name, PragmaType::PRAGMA_CALL, query, nullptr, std::move(arguments), std::move(varargs)); } -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_function_t function, vector arguments, - LogicalType varargs) { +PragmaFunction PragmaFunction::PragmaCall(const Identifier &name, pragma_function_t function, + vector arguments, LogicalType varargs) { return PragmaFunction(name, PragmaType::PRAGMA_CALL, nullptr, function, std::move(arguments), std::move(varargs)); } -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_query_t query) { +PragmaFunction PragmaFunction::PragmaStatement(const Identifier &name, pragma_query_t query) { vector types; return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, query, nullptr, std::move(types), LogicalType::INVALID); } -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_function_t function) { +PragmaFunction PragmaFunction::PragmaStatement(const Identifier &name, pragma_function_t function) { vector types; return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, nullptr, function, std::move(types), LogicalType::INVALID); diff --git a/src/duckdb/src/function/register_function_list.cpp b/src/duckdb/src/function/register_function_list.cpp index 0add0b769..d183784c4 100644 --- a/src/duckdb/src/function/register_function_list.cpp +++ b/src/duckdb/src/function/register_function_list.cpp @@ -43,7 +43,7 @@ struct ExtensionRegister { template static void FillExtraInfo(const StaticFunctionDefinition &function, T &info) { info.internal = true; - info.alias_of = function.alias_of; + info.alias_of = Identifier(function.alias_of); FillFunctionDescriptions(function, info); OP::FillExtraInfo(info); } @@ -60,7 +60,7 @@ static void RegisterFunctionList(REGISTER_CONTEXT &context, const StaticFunction } else { result = function.get_function_set(); } - result.name = function.name; + result.name = Identifier(function.name); 
CreateScalarFunctionInfo info(result); FillExtraInfo(function, info); OP::RegisterFunction(context, info); @@ -72,7 +72,7 @@ static void RegisterFunctionList(REGISTER_CONTEXT &context, const StaticFunction } else { result = function.get_aggregate_function_set(); } - result.name = function.name; + result.name = Identifier(function.name); CreateAggregateFunctionInfo info(result); FillExtraInfo(function, info); OP::RegisterFunction(context, info); @@ -83,7 +83,7 @@ static void RegisterFunctionList(REGISTER_CONTEXT &context, const StaticFunction } else { result = function.get_window_function_set(); } - result.name = function.name; + result.name = Identifier(function.name); CreateWindowFunctionInfo info(result); FillExtraInfo(function, info); OP::RegisterFunction(context, info); diff --git a/src/duckdb/src/function/scalar/comparison/between.cpp b/src/duckdb/src/function/scalar/comparison/between.cpp index 56c665363..9add1da4b 100644 --- a/src/duckdb/src/function/scalar/comparison/between.cpp +++ b/src/duckdb/src/function/scalar/comparison/between.cpp @@ -36,25 +36,24 @@ void BetweenFunction(DataChunk &args, ExpressionState &state, Vector &result) { Vector intermediate1(LogicalType::BOOLEAN); Vector intermediate2(LogicalType::BOOLEAN); - auto &input = args.data[0]; - auto &lower = args.data[1]; - auto &upper = args.data[2]; - auto count = args.size(); + const auto &input = args.data[0]; + const auto &lower = args.data[1]; + const auto &upper = args.data[2]; if (upper_inclusive && lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); + VectorOperations::GreaterThanEquals(input, lower, intermediate1); + VectorOperations::LessThanEquals(input, upper, intermediate2); } else if (lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); + 
VectorOperations::GreaterThanEquals(input, lower, intermediate1); + VectorOperations::LessThan(input, upper, intermediate2); } else if (upper_inclusive) { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); + VectorOperations::GreaterThan(input, lower, intermediate1); + VectorOperations::LessThanEquals(input, upper, intermediate2); } else { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); + VectorOperations::GreaterThan(input, lower, intermediate1); + VectorOperations::LessThan(input, upper, intermediate2); } - VectorOperations::And(intermediate1, intermediate2, result, count); + VectorOperations::And(intermediate1, intermediate2, result); } #ifndef DUCKDB_SMALLER_BINARY @@ -87,9 +86,9 @@ struct ExclusiveBetweenOperator { }; template -static idx_t BetweenLoopTypeSwitch(Vector &input, Vector &lower, Vector &upper, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel) { +static idx_t BetweenLoopTypeSwitch(const Vector &input, const Vector &lower, const Vector &upper, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel) { switch (input.GetType().InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: @@ -144,9 +143,9 @@ idx_t BetweenSelect(DataChunk &args, ExpressionState &state, optional_ptrCast(); + auto &data = between_expr.BindInfo()->Cast(); return data.lower_inclusive; } bool BoundBetweenExpression::UpperInclusive(const BoundFunctionExpression &between_expr) { - auto &data = between_expr.bind_info->Cast(); + auto &data = between_expr.BindInfo()->Cast(); return data.upper_inclusive; } const Expression &BoundBetweenExpression::Input(const BoundFunctionExpression &between_expr) { - return *between_expr.children[0]; + return *between_expr.GetChildren()[0]; } const Expression &BoundBetweenExpression::LowerBound(const 
BoundFunctionExpression &between_expr) { - return *between_expr.children[1]; + return *between_expr.GetChildren()[1]; } const Expression &BoundBetweenExpression::UpperBound(const BoundFunctionExpression &between_expr) { - return *between_expr.children[2]; + return *between_expr.GetChildren()[2]; } unique_ptr &BoundBetweenExpression::InputMutable(BoundFunctionExpression &between_expr) { - return between_expr.children[0]; + return between_expr.GetChildrenMutable()[0]; } unique_ptr &BoundBetweenExpression::LowerBoundMutable(BoundFunctionExpression &between_expr) { - return between_expr.children[1]; + return between_expr.GetChildrenMutable()[1]; } unique_ptr &BoundBetweenExpression::UpperBoundMutable(BoundFunctionExpression &between_expr) { - return between_expr.children[2]; + return between_expr.GetChildrenMutable()[2]; } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/comparison/comparison.cpp b/src/duckdb/src/function/scalar/comparison/comparison.cpp index 15cd950ce..90fd67384 100644 --- a/src/duckdb/src/function/scalar/comparison/comparison.cpp +++ b/src/duckdb/src/function/scalar/comparison/comparison.cpp @@ -1,142 +1,143 @@ #include "duckdb/function/scalar/comparison_functions.hpp" #include "duckdb/function/scalar_function.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/legacy_bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/parser/expression/comparison_expression.hpp" #include "duckdb/common/enums/expression_type.hpp" namespace duckdb { -struct ComparisonFunctionData : public FunctionData { - explicit ComparisonFunctionData(ExpressionType expression_type) : expression_type(expression_type) { - } - - ExpressionType expression_type; - -public: - unique_ptr Copy() const override { - return 
make_uniq(expression_type); - }; - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return expression_type == other.expression_type; - } -}; - +template void ComparisonFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto expr_type = state.expr.GetExpressionType(); - auto &left = args.data[0]; - auto &right = args.data[1]; - auto count = args.size(); - - switch (expr_type) { - case ExpressionType::COMPARE_EQUAL: - VectorOperations::Equals(left, right, result, count); - break; - case ExpressionType::COMPARE_NOTEQUAL: - VectorOperations::NotEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHAN: - VectorOperations::LessThan(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHAN: - VectorOperations::GreaterThan(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - VectorOperations::LessThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - VectorOperations::GreaterThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - VectorOperations::DistinctFrom(left, right, result, count); - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - VectorOperations::NotDistinctFrom(left, right, result, count); - break; - default: + (void)state; + const auto &left = args.data[0]; + const auto &right = args.data[1]; + + if constexpr (TYPE == ExpressionType::COMPARE_EQUAL) { + VectorOperations::Equals(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_NOTEQUAL) { + VectorOperations::NotEquals(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_LESSTHAN) { + VectorOperations::LessThan(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_GREATERTHAN) { + VectorOperations::GreaterThan(left, right, result); + } else if constexpr (TYPE == 
ExpressionType::COMPARE_LESSTHANOREQUALTO) { + VectorOperations::LessThanEquals(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { + VectorOperations::GreaterThanEquals(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_DISTINCT_FROM) { + VectorOperations::DistinctFrom(left, right, result); + } else if constexpr (TYPE == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + VectorOperations::NotDistinctFrom(left, right, result); + } else { throw InternalException("Unknown comparison type!"); } } +#ifndef DUCKDB_SMALLER_BINARY +template idx_t ComparisonSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, optional_ptr true_sel, optional_ptr false_sel) { - auto expr_type = state.expr.GetExpressionType(); - auto &left = args.data[0]; - auto &right = args.data[1]; + (void)state; + const auto &left = args.data[0]; + const auto &right = args.data[1]; auto count = args.size(); - switch (expr_type) { - case ExpressionType::COMPARE_EQUAL: + if constexpr (TYPE == ExpressionType::COMPARE_EQUAL) { return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOTEQUAL: + } else if constexpr (TYPE == ExpressionType::COMPARE_NOTEQUAL) { return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHAN: + } else if constexpr (TYPE == ExpressionType::COMPARE_LESSTHAN) { return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_GREATERTHAN: + } else if constexpr (TYPE == ExpressionType::COMPARE_GREATERTHAN) { return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: + } else if constexpr (TYPE == ExpressionType::COMPARE_LESSTHANOREQUALTO) { return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel); - case 
ExpressionType::COMPARE_GREATERTHANOREQUALTO: + } else if constexpr (TYPE == ExpressionType::COMPARE_GREATERTHANOREQUALTO) { return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_DISTINCT_FROM: + } else if constexpr (TYPE == ExpressionType::COMPARE_DISTINCT_FROM) { return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + } else if constexpr (TYPE == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel); + } else { + throw InternalException("Unknown comparison type!"); + } +} +#endif + +template +ExpressionType ComparisonGetExpressionType(FunctionToStringInput &input) { + (void)input; + return TYPE; +} + +template +static ScalarFunction GetComparisonFunctionInternal(const string &name) { + ScalarFunction comparison_fun(Identifier(name), {LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, + ComparisonFunction); + comparison_fun.SetGetExpressionTypeCallback(ComparisonGetExpressionType); + if constexpr (TYPE == ExpressionType::COMPARE_DISTINCT_FROM || TYPE == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + comparison_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + } +#ifndef DUCKDB_SMALLER_BINARY + comparison_fun.SetSelectCallback(ComparisonSelect); +#endif + return comparison_fun; +} + +static ScalarFunction GetComparisonFunction(ExpressionType type) { + switch (type) { + case ExpressionType::COMPARE_EQUAL: + return OperatorEqualFun::GetFunction(); + case ExpressionType::COMPARE_NOTEQUAL: + return OperatorNotEqualFun::GetFunction(); + case ExpressionType::COMPARE_LESSTHAN: + return OperatorLessThanFun::GetFunction(); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return OperatorLessThanEqualsFun::GetFunction(); + case ExpressionType::COMPARE_GREATERTHAN: + return OperatorGreaterThanFun::GetFunction(); + case 
ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return OperatorGreaterThanEqualsFun::GetFunction(); + case ExpressionType::COMPARE_DISTINCT_FROM: + return IsDistinctFromFun::GetFunction(); + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return IsNotDistinctFromFun::GetFunction(); default: throw InternalException("Unknown comparison type!"); } } -unique_ptr BindComparisonFun(BindScalarFunctionInput &input) { - throw InvalidInputException("Comparison function cannot be called directly"); +ScalarFunction OperatorEqualFun::GetFunction() { + return GetComparisonFunctionInternal(OperatorEqualFun::Name); } -string ComparisonToString(FunctionToStringInput &input) { - auto &comparison_data = input.bind_data->Cast(); - return ComparisonExpression::ToString(comparison_data.expression_type, *input.children[0], *input.children[1]); +ScalarFunction OperatorNotEqualFun::GetFunction() { + return GetComparisonFunctionInternal(OperatorNotEqualFun::Name); } -ExpressionType ComparisonGetExpressionType(FunctionToStringInput &input) { - auto &comparison_data = input.bind_data->Cast(); - return comparison_data.expression_type; +ScalarFunction OperatorLessThanFun::GetFunction() { + return GetComparisonFunctionInternal(OperatorLessThanFun::Name); } -unique_ptr ComparisonLegacySerializeCallback(FunctionToStringInput &input) { - auto &comparison_data = input.bind_data->Cast(); - return make_uniq(comparison_data.expression_type, input.GetChild(0).Copy(), - input.GetChild(1).Copy()); +ScalarFunction OperatorLessThanEqualsFun::GetFunction() { + return GetComparisonFunctionInternal(OperatorLessThanEqualsFun::Name); } -void ComparisonFunctionSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const BoundScalarFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "comparison_type", bind_data.expression_type); +ScalarFunction OperatorGreaterThanFun::GetFunction() { + return GetComparisonFunctionInternal(OperatorGreaterThanFun::Name); } 
-unique_ptr ComparisonFunctionDeserialize(Deserializer &deserializer, BoundScalarFunction &function) { - auto expression_type = deserializer.ReadProperty(100, "comparison_type"); - return make_uniq(expression_type); +ScalarFunction OperatorGreaterThanEqualsFun::GetFunction() { + return GetComparisonFunctionInternal( + OperatorGreaterThanEqualsFun::Name); } -ScalarFunction ComparisonFun::GetFunction() { - ScalarFunction Comparison_fun("__comparison", {LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, - ComparisonFunction, BindComparisonFun); - Comparison_fun.SetToStringCallback(ComparisonToString); - Comparison_fun.SetGetExpressionTypeCallback(ComparisonGetExpressionType); - Comparison_fun.SetLegacySerializeCallback(ComparisonLegacySerializeCallback); - Comparison_fun.SetSerializeCallback(ComparisonFunctionSerialize); - Comparison_fun.SetDeserializeCallback(ComparisonFunctionDeserialize); - // DISTINCT FROM / NOT DISTINCT FROM handle NULLs specially - they must not be short-circuited on NULL input - Comparison_fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); -#ifndef DUCKDB_SMALLER_BINARY - Comparison_fun.SetSelectCallback(ComparisonSelect); -#endif - return Comparison_fun; +ScalarFunction IsDistinctFromFun::GetFunction() { + return GetComparisonFunctionInternal(IsDistinctFromFun::Name); +} + +ScalarFunction IsNotDistinctFromFun::GetFunction() { + return GetComparisonFunctionInternal(IsNotDistinctFromFun::Name); } //===--------------------------------------------------------------------===// @@ -148,10 +149,8 @@ unique_ptr BoundComparisonExpression::Create(ExpressionType type, un children.push_back(std::move(left)); children.push_back(std::move(right)); - auto function_data = make_uniq(type); - - auto result = make_uniq(BoundScalarFunction(ComparisonFun::GetFunction()), - std::move(children), std::move(function_data), false); + auto result = make_uniq(BoundScalarFunction(GetComparisonFunction(type)), + std::move(children), nullptr, true); return 
std::move(result); } @@ -179,24 +178,31 @@ bool BoundComparisonExpression::IsComparison(const Expression &expr) { } const Expression &BoundComparisonExpression::Left(const BoundFunctionExpression &comparison_expr) { - return *comparison_expr.children[0]; + return *comparison_expr.GetChildren()[0]; } const Expression &BoundComparisonExpression::Right(const BoundFunctionExpression &comparison_expr) { - return *comparison_expr.children[1]; + return *comparison_expr.GetChildren()[1]; } unique_ptr &BoundComparisonExpression::LeftMutable(BoundFunctionExpression &comparison_expr) { - return comparison_expr.children[0]; + return comparison_expr.GetChildrenMutable()[0]; } unique_ptr &BoundComparisonExpression::RightMutable(BoundFunctionExpression &comparison_expr) { - return comparison_expr.children[1]; + return comparison_expr.GetChildrenMutable()[1]; } void BoundComparisonExpression::SetType(BoundFunctionExpression &comparison_expr, ExpressionType new_type) { + auto arguments = comparison_expr.FunctionMutable().GetArguments(); + auto original_arguments = comparison_expr.FunctionMutable().GetOriginalArguments(); + comparison_expr.SetExpressionTypeUnsafe(new_type); - comparison_expr.bind_info->Cast().expression_type = new_type; + comparison_expr.FunctionMutable() = BoundScalarFunction(GetComparisonFunction(new_type)); + comparison_expr.FunctionMutable().GetArguments() = std::move(arguments); + comparison_expr.FunctionMutable().GetOriginalArguments() = std::move(original_arguments); + comparison_expr.BindInfoMutable().reset(); + comparison_expr.IsOperatorMutable() = true; } void BoundComparisonExpression::FlipType(BoundFunctionExpression &comparison_expr) { diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp index 27a9862d1..a8ad54965 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp +++ 
b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp @@ -44,7 +44,7 @@ void IntegralCompressFunction(DataChunk &args, ExpressionState &state, Vector &r D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); const auto min_val = ConstantVector::GetData(args.data[1])[0]; UnaryExecutor::Execute( - args.data[0], result, args.size(), + args.data[0], result, [&](const INPUT_TYPE &input) { return TemplatedIntegralCompress::Operation(input, min_val); }, @@ -134,7 +134,7 @@ void IntegralDecompressFunction(DataChunk &args, ExpressionState &state, Vector D_ASSERT(args.data[1].GetType() == result.GetType()); const auto min_val = ConstantVector::GetData(args.data[1])[0]; UnaryExecutor::Execute( - args.data[0], result, args.size(), + args.data[0], result, [&](const INPUT_TYPE &input) { return TemplatedIntegralDecompress::Operation(input, min_val); }, @@ -202,7 +202,7 @@ unique_ptr CMIntegralDeserialize(Deserializer &deserializer, Bound } ScalarFunctionSet GetIntegralCompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralCompressFunctionName(result_type)); + ScalarFunctionSet set {Identifier(IntegralCompressFunctionName(result_type))}; for (const auto &input_type : LogicalType::Integral()) { if (GetTypeIdSize(result_type.InternalType()) < GetTypeIdSize(input_type.InternalType())) { set.AddFunction(CMIntegralCompressFun::GetFunction(input_type, result_type)); @@ -212,7 +212,7 @@ ScalarFunctionSet GetIntegralCompressFunctionSet(const LogicalType &result_type) } ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralDecompressFunctionName(result_type)); + ScalarFunctionSet set {Identifier(IntegralDecompressFunctionName(result_type))}; for (const auto &input_type : CMUtils::IntegralTypes()) { if (GetTypeIdSize(result_type.InternalType()) > GetTypeIdSize(input_type.InternalType())) { set.AddFunction(CMIntegralDecompressFun::GetFunction(input_type, 
result_type)); @@ -224,7 +224,7 @@ ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_typ } // namespace ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction result(IntegralCompressFunctionName(result_type), {input_type, input_type}, result_type, + ScalarFunction result(Identifier(IntegralCompressFunctionName(result_type)), {input_type, input_type}, result_type, GetIntegralCompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); result.SetSerializeCallback(CMIntegralSerialize); result.SetDeserializeCallback(CMIntegralDeserialize); @@ -237,8 +237,9 @@ ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, } ScalarFunction CMIntegralDecompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction result(IntegralDecompressFunctionName(result_type), {input_type, result_type}, result_type, - GetIntegralDecompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); + ScalarFunction result(Identifier(IntegralDecompressFunctionName(result_type)), {input_type, result_type}, + result_type, GetIntegralDecompressFunctionInputSwitch(input_type, result_type), + CMUtils::Bind); result.SetSerializeCallback(CMIntegralSerialize); result.SetDeserializeCallback(CMIntegralDeserialize); return result; diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp index 11e988141..f5e2d90b9 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -75,7 +75,7 @@ inline uint8_t StringCompress(const string_t &input) { template void StringCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { UnaryExecutor::Execute( - args.data[0], result, args.size(), 
StringCompress, + args.data[0], result, StringCompress, #if defined(D_ASSERT_IS_ENABLED) FunctionErrors::CAN_THROW_RUNTIME_ERROR); // Can only throw a runtime error when assertions are enabled #else @@ -179,8 +179,7 @@ void StringDecompressFunction(DataChunk &args, ExpressionState &state, Vector &r auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().allocator; allocator.Reset(); UnaryExecutor::Execute( - args.data[0], result, args.size(), - [&](const INPUT_TYPE &input) { return StringDecompress(input, allocator); }, + args.data[0], result, [&](const INPUT_TYPE &input) { return StringDecompress(input, allocator); }, FunctionErrors::CANNOT_ERROR); } @@ -234,7 +233,7 @@ unique_ptr CMStringDecompressDeserialize(Deserializer &deserialize } ScalarFunctionSet GetStringDecompressFunctionSet() { - ScalarFunctionSet set(StringDecompressFunctionName()); + ScalarFunctionSet set {Identifier(StringDecompressFunctionName())}; auto string_types = CMUtils::StringTypes(); // For backwards compatibility, see internal issue 5306 string_types.push_back(LogicalType::HUGEINT); @@ -247,7 +246,7 @@ ScalarFunctionSet GetStringDecompressFunctionSet() { } // namespace ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) { - ScalarFunction result(StringCompressFunctionName(result_type), {LogicalType::VARCHAR}, result_type, + ScalarFunction result(Identifier(StringCompressFunctionName(result_type)), {LogicalType::VARCHAR}, result_type, GetStringCompressFunctionSwitch(result_type), CMUtils::Bind); result.SetSerializeCallback(CMStringCompressSerialize); result.SetDeserializeCallback(CMStringCompressDeserialize); @@ -260,7 +259,7 @@ ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) } ScalarFunction CMStringDecompressFun::GetFunction(const LogicalType &input_type) { - ScalarFunction result(StringDecompressFunctionName(), {input_type}, LogicalType::VARCHAR, + ScalarFunction result(Identifier(StringDecompressFunctionName()), 
{input_type}, LogicalType::VARCHAR, GetStringDecompressFunctionSwitch(input_type), CMUtils::Bind, nullptr, StringDecompressLocalState::Init); result.SetSerializeCallback(CMStringDecompressSerialize); diff --git a/src/duckdb/src/function/scalar/create_sort_key.cpp b/src/duckdb/src/function/scalar/create_sort_key.cpp index 52b74b164..9bbfbbc48 100644 --- a/src/duckdb/src/function/scalar/create_sort_key.cpp +++ b/src/duckdb/src/function/scalar/create_sort_key.cpp @@ -87,7 +87,7 @@ struct SortKeyVectorData { static constexpr data_t LIST_DELIMITER = 0; static constexpr data_t BLOB_ESCAPE_CHARACTER = 1; - SortKeyVectorData(Vector &input, idx_t size, OrderModifiers modifiers) : vec(input) { + SortKeyVectorData(const Vector &input, idx_t size, OrderModifiers modifiers) : vec(input) { if (size != 0) { input.ToUnifiedFormat(format); } else { @@ -116,13 +116,13 @@ struct SortKeyVectorData { break; } case PhysicalType::ARRAY: { - auto &child_entry = ArrayVector::GetChildMutable(input); + const auto &child_entry = ArrayVector::GetChild(input); auto array_size = ArrayType::GetSize(input.GetType()); child_data.push_back(make_uniq(child_entry, size * array_size, child_modifiers)); break; } case PhysicalType::LIST: { - auto &child_entry = ListVector::GetChildMutable(input); + const auto &child_entry = ListVector::GetChild(input); auto child_size = size == 0 ? 
0 : ListVector::GetListSize(input); child_data.push_back(make_uniq(child_entry, child_size, child_modifiers)); break; @@ -142,7 +142,7 @@ struct SortKeyVectorData { return vec.GetType().InternalType(); } - Vector &vec; + const Vector &vec; idx_t size; UnifiedVectorFormat format; vector> child_data; @@ -189,7 +189,8 @@ struct SortKeyConstantOperator { } template - static idx_t Decode(const_data_ptr_t input, Vector &result, TYPE &result_value) { + static idx_t Decode(const_data_ptr_t input, idx_t input_size, Vector &result, TYPE &result_value) { + D_ASSERT(input_size >= sizeof(T)); if (FLIP_BYTES) { // descending order - so flip bytes data_t flipped_bytes[sizeof(T)]; @@ -259,15 +260,16 @@ struct SortKeyVarcharOperator { } template - static idx_t Decode(const_data_ptr_t input, Vector &result, TYPE &result_value) { + static idx_t Decode(const_data_ptr_t input, idx_t input_size, Vector &result, TYPE &result_value) { // iterate until we encounter the string delimiter to figure out the string length data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; if (FLIP_BYTES) { string_delimiter = static_cast(~string_delimiter); } idx_t pos; - for (pos = 0; input[pos] != string_delimiter; pos++) { + for (pos = 0; pos < input_size && input[pos] != string_delimiter; pos++) { } + D_ASSERT(pos < input_size); idx_t str_len = pos; // now allocate the string data and fill it with the decoded data result_value = StringVector::EmptyString(result, str_len); @@ -370,7 +372,7 @@ struct SortKeyBlobOperator { } template - static idx_t Decode(const_data_ptr_t input, Vector &result, TYPE &result_value) { + static idx_t Decode(const_data_ptr_t input, idx_t input_size, Vector &result, TYPE &result_value) { // scan until we find the delimiter, keeping in mind escapes data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; data_t escape_character = SortKeyVectorData::BLOB_ESCAPE_CHARACTER; @@ -380,13 +382,15 @@ struct SortKeyBlobOperator { } idx_t blob_len = 0; idx_t pos; - for (pos = 
0; input[pos] != string_delimiter; pos++) { + for (pos = 0; pos < input_size && input[pos] != string_delimiter; pos++) { blob_len++; if (input[pos] == escape_character) { // escape character - skip the next byte pos++; + D_ASSERT(pos < input_size); } } + D_ASSERT(pos < input_size); // now allocate the blob data and fill it with the decoded data result_value = StringVector::EmptyString(result, blob_len); auto str_data = data_ptr_cast(result_value.GetDataWriteable()); @@ -887,11 +891,74 @@ void CreateSortKeyInternal(vector> &sort_key_data, FinalizeSortData(result, row_count, key_lengths, offsets); } +#ifdef DEBUG +static void AssertSortKeyRoundTrip(vector> &sort_key_data, + const vector &modifiers, const Vector &result, idx_t row_count) { + D_ASSERT(sort_key_data.size() == modifiers.size()); + UnifiedVectorFormat result_format; + result.ToUnifiedFormat(result_format); + const auto result_is_blob = result.GetType() == LogicalType::BLOB; + const auto result_blob_data = result_is_blob ? UnifiedVectorFormat::GetData(result_format) : nullptr; + const auto result_int_data = result_is_blob ? 
nullptr : UnifiedVectorFormat::GetData(result_format); + idx_t constant_encoded_size = 0; + if (!result_is_blob) { + for (auto &column : sort_key_data) { + constant_encoded_size += 1 + GetTypeIdSize(column->vec.GetType().InternalType()); + } + D_ASSERT(constant_encoded_size <= sizeof(int64_t)); + } + + vector decoded_columns; + decoded_columns.reserve(sort_key_data.size()); + for (auto &column : sort_key_data) { + decoded_columns.emplace_back(column->vec.GetType()); + } + + for (idx_t r = 0; r < row_count; r++) { + auto key_idx = result_format.sel->get_index(r); + D_ASSERT(result_format.validity.RowIsValid(key_idx)); + + string_t full_key; + int64_t bswapped_key = 0; + if (result_is_blob) { + full_key = result_blob_data[key_idx]; + } else { + bswapped_key = BSwapIfLE(result_int_data[key_idx]); + full_key = string_t(const_char_ptr_cast(reinterpret_cast(&bswapped_key)), sizeof(int64_t)); + } + + const auto full_key_data = full_key.GetData(); + const auto full_key_size = full_key.GetSize(); + const auto expected_size = result_is_blob ? 
full_key_size : constant_encoded_size; + D_ASSERT(expected_size <= full_key_size); + idx_t offset = 0; + for (idx_t c = 0; c < sort_key_data.size(); c++) { + D_ASSERT(offset <= expected_size); + const auto sliced_data = full_key_data + offset; + const auto sliced_size = expected_size - offset; + auto sliced_key = string_t(sliced_data, UnsafeNumericCast(sliced_size)); + offset += CreateSortKeyHelpers::DecodeSortKey(sliced_key, decoded_columns[c], r, modifiers[c]); + } + D_ASSERT(offset <= expected_size); + + for (idx_t c = 0; c < sort_key_data.size(); c++) { + auto &source_column = sort_key_data[c]; + auto source_val = source_column->vec.GetValue(r); + auto decoded_val = decoded_columns[c].GetValue(r); + D_ASSERT(source_val.IsNull() == decoded_val.IsNull()); + if (!source_val.IsNull()) { + D_ASSERT(source_val == decoded_val); + } + } + } +} +#endif + } // namespace -void CreateSortKeyHelpers::CreateSortKey(Vector &input, idx_t input_count, OrderModifiers order_modifier, - Vector &result) { +void CreateSortKeyHelpers::CreateSortKey(const Vector &input, OrderModifiers order_modifier, Vector &result) { // prepare the sort key data + const auto input_count = input.size(); vector modifiers {order_modifier}; vector> sort_key_data; sort_key_data.push_back(make_uniq(input, input_count, order_modifier)); @@ -908,8 +975,32 @@ void CreateSortKeyHelpers::CreateSortKey(DataChunk &input, const vector modifiers {order_modifier}; + vector> sort_key_data; + sort_key_data.push_back(make_uniq(input, input_count, order_modifier)); + CreateSortKeyInternal(sort_key_data, modifiers, false, result, input_count); +} + +void CreateSortKeyHelpers::CreateSortKeyWithValidity(const Vector &input, Vector &result, + const OrderModifiers &modifiers) { + CreateSortKey(input, modifiers, result); + UnifiedVectorFormat format; + input.ToUnifiedFormat(format); + auto &validity = FlatVector::ValidityMutable(result); + + const auto count = input.size(); + for (idx_t i = 0; i < count; i++) { + auto idx = 
format.sel->get_index(i); + if (!format.validity.RowIsValid(idx)) { + validity.SetInvalid(i); + } + } +} + +void CreateSortKeyHelpers::CreateSortKeyWithValidity(const Vector &input, Vector &result, + const OrderModifiers &modifiers, idx_t count) { CreateSortKey(input, count, modifiers, result); UnifiedVectorFormat format; input.ToUnifiedFormat(format); @@ -924,7 +1015,7 @@ void CreateSortKeyHelpers::CreateSortKeyWithValidity(Vector &input, Vector &resu } static void CreateSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &bind_data = state.expr.Cast().bind_info->Cast(); + auto &bind_data = state.expr.Cast().BindInfo()->Cast(); // prepare the sort key data vector> sort_key_data; @@ -932,6 +1023,9 @@ static void CreateSortKeyFunction(DataChunk &args, ExpressionState &state, Vecto sort_key_data.push_back(make_uniq(args.data[c], args.size(), bind_data.modifiers[c / 2])); } CreateSortKeyInternal(sort_key_data, bind_data.modifiers, bind_data.all_constant, result, args.size()); +#ifdef DEBUG + AssertSortKeyRoundTrip(sort_key_data, bind_data.modifiers, result, args.size()); +#endif } //===--------------------------------------------------------------------===// @@ -1074,6 +1168,20 @@ struct DecodeSortKeyData { const_data_ptr_t data; idx_t size; idx_t position; + + inline void RequireRemaining(idx_t required, const char *context) const { + (void)context; + D_ASSERT(position <= size); + D_ASSERT(required <= size - position); + } + + inline data_t ReadByte(const char *context) { + RequireRemaining(1, context); + if (position >= size) { + return 0; + } + return data[position++]; + } }; void DecodeSortKeyRecursive(DecodeSortKeyData decode_data[], DecodeSortKeyVectorData &vector_data, Vector &result, @@ -1090,16 +1198,17 @@ void TemplatedDecodeSortKeyInternal(DecodeSortKeyData decode_data_arr[], DecodeS for (idx_t i = 0; i < count; i++) { const auto result_idx = result_offset + i; auto &decode_data = decode_data_arr[i]; - auto validity_byte = 
decode_data.data[decode_data.position]; - decode_data.position++; + auto validity_byte = decode_data.ReadByte("reading validity byte"); if (validity_byte == null_byte) { // NULL value result_validity.SetInvalid(result_idx); continue; } - idx_t increment = - OP::template Decode(decode_data.data + decode_data.position, result, result_data[result_idx]); + auto remaining = decode_data.size - decode_data.position; + idx_t increment = OP::template Decode(decode_data.data + decode_data.position, remaining, result, + result_data[result_idx]); decode_data.position += increment; + D_ASSERT(decode_data.position <= decode_data.size); } } @@ -1121,8 +1230,7 @@ void DecodeSortKeyStruct(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVecto const auto result_idx = result_offset + i; auto &decode_data = decode_data_arr[i]; // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; + auto validity_byte = decode_data.ReadByte("reading struct validity byte"); if (validity_byte == vector_data.null_byte) { // entire struct is NULL // note that we still deserialize the children @@ -1148,8 +1256,7 @@ void DecodeSortKeyList(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVectorD const auto result_idx = result_offset + i; auto &decode_data = decode_data_arr[i]; // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; + auto validity_byte = decode_data.ReadByte("reading list validity byte"); if (validity_byte == vector_data.null_byte) { // entire list is NULL result_validity.SetInvalid(result_idx); @@ -1167,7 +1274,11 @@ void DecodeSortKeyList(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVectorD auto start_list_size = ListVector::GetListSize(result); auto new_list_size = start_list_size; // loop until we find the list delimiter - while (decode_data.data[decode_data.position] != list_delimiter) { + while (true) { + 
decode_data.RequireRemaining(1, "scanning list delimiter"); + if (decode_data.data[decode_data.position] == list_delimiter) { + break; + } // found a valid entry here - decode it // first reserve space for it new_list_size++; @@ -1177,7 +1288,7 @@ void DecodeSortKeyList(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVectorD DecodeSortKeyRecursive(&decode_data, vector_data.child_data[0], child_vector, new_list_size - 1, 1); } // skip the list delimiter - decode_data.position++; + decode_data.ReadByte("consuming list delimiter"); // set the list_entry_t information and update the list size list_data[result_idx].length = new_list_size - start_list_size; list_data[result_idx].offset = start_list_size; @@ -1193,8 +1304,7 @@ void DecodeSortKeyArray(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVector const auto result_idx = result_offset + i; auto &decode_data = decode_data_arr[i]; // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; + auto validity_byte = decode_data.ReadByte("reading array validity byte"); if (validity_byte == vector_data.null_byte) { // entire array is NULL // note that we still read the child elements @@ -1214,7 +1324,11 @@ void DecodeSortKeyArray(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVector idx_t found_elements = 0; auto child_start = array_size * result_idx; // loop until we find the list delimiter - while (decode_data.data[decode_data.position] != list_delimiter) { + while (true) { + decode_data.RequireRemaining(1, "scanning array delimiter"); + if (decode_data.data[decode_data.position] == list_delimiter) { + break; + } found_elements++; if (found_elements > array_size) { // error - found too many elements @@ -1225,7 +1339,7 @@ void DecodeSortKeyArray(DecodeSortKeyData decode_data_arr[], DecodeSortKeyVector child_start + found_elements - 1, 1); } // skip the list delimiter - decode_data.position++; + decode_data.ReadByte("consuming array delimiter"); 
if (found_elements != array_size) { throw InvalidInputException("Failed to decode array - found %d elements but expected %d", found_elements, array_size); @@ -1332,10 +1446,10 @@ void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, DataChunk &result, i } static void DecodeSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &bind_data = state.expr.Cast().bind_info->Cast(); + auto &bind_data = state.expr.Cast().BindInfo()->Cast(); const auto count = args.size(); - auto &sort_key_vec = args.data[0]; + const auto &sort_key_vec = args.data[0]; UnifiedVectorFormat sort_key_vec_format; sort_key_vec.ToUnifiedFormat(sort_key_vec_format); @@ -1343,8 +1457,8 @@ static void DecodeSortKeyFunction(DataChunk &args, ExpressionState &state, Vecto // However, all the actual values should be valid, so we assert that // Construct utility for all sort keys that we will decode - DecodeSortKeyData decode_data[STANDARD_VECTOR_SIZE]; - int64_t bswapped_ints[STANDARD_VECTOR_SIZE]; + vector decode_data(count); + vector bswapped_ints(count); if (sort_key_vec.GetType() == LogicalType::BLOB) { const auto sort_keys = UnifiedVectorFormat::GetData(sort_key_vec_format); if (sort_key_vec_format.sel->IsSet()) { @@ -1384,7 +1498,7 @@ static void DecodeSortKeyFunction(DataChunk &args, ExpressionState &state, Vecto for (idx_t c = 0; c < StructType::GetChildCount(result_type); c++) { auto &child_vector = child_vectors[c]; DecodeSortKeyVectorData sort_key_data(child_vector.GetType(), bind_data.modifiers[c]); - DecodeSortKeyRecursive(decode_data, sort_key_data, child_vector, 0, count); + DecodeSortKeyRecursive(decode_data.data(), sort_key_data, child_vector, 0, count); } } diff --git a/src/duckdb/src/function/scalar/date/strftime.cpp b/src/duckdb/src/function/scalar/date/strftime.cpp index a258d9010..8573b4f3d 100644 --- a/src/duckdb/src/function/scalar/date/strftime.cpp +++ b/src/duckdb/src/function/scalar/date/strftime.cpp @@ -70,37 +70,37 @@ static unique_ptr 
StrfTimeBindFunction(BindScalarFunctionInput &in template static void StrfTimeFunctionDate(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.is_null) { ConstantVector::SetNull(result, count_t(args.size())); return; } - info.format.ConvertDateVector(args.data[REVERSED ? 1 : 0], result, args.size()); + info.format.ConvertDateVector(args.data[REVERSED ? 1 : 0], result); } template static void StrfTimeFunctionTimestamp(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.is_null) { ConstantVector::SetNull(result, count_t(args.size())); return; } - info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result, args.size()); + info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result); } template static void StrfTimeFunctionTimestampNS(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.is_null) { ConstantVector::SetNull(result, count_t(args.size())); return; } - info.format.ConvertTimestampNSVector(args.data[REVERSED ? 1 : 0], result, args.size()); + info.format.ConvertTimestampNSVector(args.data[REVERSED ? 
1 : 0], result); } struct StrpTimeBindData : public FunctionData { @@ -149,7 +149,7 @@ struct StrpTimeFunction { template static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); // There is a bizarre situation where the format column is foldable but not constant // (i.e., the statistics tell us it has only one value) @@ -160,7 +160,7 @@ struct StrpTimeFunction { ConstantVector::SetNull(result, count_t(args.size())); return; } - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(args.data[0], result, [&](string_t input) { StrpTimeFormat::ParseResult result; for (auto &format : info.formats) { if (format.Parse(input, result)) { @@ -174,14 +174,14 @@ struct StrpTimeFunction { template static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { ConstantVector::SetNull(result, count_t(args.size())); return; } - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) -> optional { + UnaryExecutor::Execute(args.data[0], result, [&](string_t input) -> optional { T result; string error; for (auto &format : info.formats) { diff --git a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp index b9640e7c4..e71f148bc 100644 --- a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp +++ b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp @@ -27,7 +27,7 @@ struct ConstantOrNullBindData : public FunctionData { static void ConstantOrNullFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = 
state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.value, count_t(args.size())); for (idx_t idx = 1; idx < args.ColumnCount(); idx++) { switch (args.data[idx].GetVectorType()) { @@ -94,11 +94,11 @@ unique_ptr ConstantOrNull::Bind(Value value) { } bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value &val) { - if (expr.function.GetName() != "constant_or_null") { + if (expr.Function().GetName() != "constant_or_null") { return false; } - D_ASSERT(expr.bind_info); - auto &bind_data = expr.bind_info->Cast(); + D_ASSERT(expr.BindInfo()); + auto &bind_data = expr.BindInfo()->Cast(); D_ASSERT(bind_data.value.type() == val.type()); return bind_data.value == val; } diff --git a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp index 2bb3a01b8..1a410c1e6 100644 --- a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp +++ b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp @@ -17,7 +17,7 @@ ScalarFunction StGeomfromwkbFun::GetFunction() { } static void ToWKBFunction(DataChunk &input, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](const string_t &geom) { + UnaryExecutor::Execute(input.data[0], result, [&](const string_t &geom) { // TODO: convert to internal representation return geom; }); @@ -32,7 +32,7 @@ ScalarFunction StAswkbFun::GetFunction() { static void ToWKTFunction(DataChunk &input, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(input.data[0], result, input.size(), + UnaryExecutor::Execute(input.data[0], result, [&](const string_t &geom) { return Geometry::ToString(heap, geom); }); } @@ -43,7 +43,7 @@ ScalarFunction StAstextFun::GetFunction() { static void IntersectsExtentFunction(DataChunk &input, ExpressionState &state, Vector 
&result) { BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [](const string_t &lhs_geom, const string_t &rhs_geom) { + input.data[0], input.data[1], result, [](const string_t &lhs_geom, const string_t &rhs_geom) { auto lhs_extent = GeometryExtent::Empty(); auto rhs_extent = GeometryExtent::Empty(); diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index 641c25954..510f1d426 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -14,9 +14,9 @@ static void ListSearchFunction(DataChunk &input, ExpressionState &state, Vector } auto target_count = input.size(); - auto &input_list = input.data[0]; - auto &list_child = ListVector::GetChildMutable(input_list); - auto &target = input.data[1]; + const auto &input_list = input.data[0]; + const auto &list_child = ListVector::GetChild(input_list); + const auto &target = input.data[1]; ListSearchOp(input_list, list_child, target, result, target_count); } diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp index 8c1192f64..58e2b8c92 100644 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -38,8 +38,9 @@ static optional_idx TryGetChildOffset(const list_entry_t &list_entry, const int6 return optional_idx(list_entry.offset + unsigned_offset); } -static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) { +static void ExecuteListExtract(Vector &result, const Vector &list, const Vector &offsets) { D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); + const auto count = list.size(); auto list_entries = list.Values(); auto offsets_entries = offsets.Values(); @@ -93,25 +94,25 @@ static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, co 
} } -static void ExecuteStringExtract(Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { +static void ExecuteStringExtract(Vector &result, const Vector &input_vector, const Vector &subscript_vector) { BinaryExecutor::Execute( - input_vector, subscript_vector, result, count, + input_vector, subscript_vector, result, [&](string_t input_string, int64_t subscript) { return SubstringUnicode(result, input_string, subscript, 1); }); } static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 2); - auto count = args.size(); + const auto count = args.size(); - Vector &base = args.data[0]; - Vector &subscript = args.data[1]; + const Vector &base = args.data[0]; + const Vector &subscript = args.data[1]; switch (base.GetType().id()) { case LogicalTypeId::LIST: - ExecuteListExtract(result, base, subscript, count); + ExecuteListExtract(result, base, subscript); break; case LogicalTypeId::VARCHAR: - ExecuteStringExtract(result, base, subscript, count); + ExecuteStringExtract(result, base, subscript); break; case LogicalTypeId::SQLNULL: ConstantVector::SetNull(result, count_t(count)); @@ -146,6 +147,9 @@ ScalarFunctionSet ListExtractFun::GetFunctions() { ScalarFunction lfun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, ListExtractStats); + lfun.GetSignature().GetParameter(0).SetName("list"); + lfun.GetSignature().GetParameter(1).SetName("index"); + ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); lfun.SetFallible(); sfun.SetFallible(); @@ -161,6 +165,9 @@ ScalarFunctionSet ArrayExtractFun::GetFunctions() { ScalarFunction lfun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, ListExtractStats); + lfun.GetSignature().GetParameter(0).SetName("array"); + 
lfun.GetSignature().GetParameter(1).SetName("index"); + ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); array_extract_set.AddFunction(lfun); diff --git a/src/duckdb/src/function/scalar/list/list_intersect.cpp b/src/duckdb/src/function/scalar/list/list_intersect.cpp index 6555b135c..46f96c5c6 100644 --- a/src/duckdb/src/function/scalar/list/list_intersect.cpp +++ b/src/duckdb/src/function/scalar/list/list_intersect.cpp @@ -34,8 +34,8 @@ static void ListIntersectFunction(DataChunk &args, ExpressionState &state, Vecto const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); - CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(l_child, order_modifiers, l_sortkey_vec); + CreateSortKeyHelpers::CreateSortKey(r_child, order_modifiers, r_sortkey_vec); const auto l_sortkey_ptr = l_sortkey_vec.Values(); const auto r_sortkey_ptr = r_sortkey_vec.Values(); diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp index 1f77f83ba..9530fab40 100644 --- a/src/duckdb/src/function/scalar/list/list_resize.cpp +++ b/src/duckdb/src/function/scalar/list/list_resize.cpp @@ -15,8 +15,8 @@ static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &resul return; } - auto &lists = args.data[0]; - auto &new_sizes = args.data[1]; + const auto &lists = args.data[0]; + const auto &new_sizes = args.data[1]; auto row_count = args.size(); D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); D_ASSERT(new_sizes.GetType().id() == LogicalTypeId::UBIGINT); @@ -73,7 +73,7 @@ static void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &resul for (idx_t j = 0; j < remaining_count.GetValue(); j++) { sel.set_index(j, row_idx); } - auto &default_vector = args.data[2]; 
+ const auto &default_vector = args.data[2]; list.Append(default_vector, sel, args.size(), 0, remaining_count.GetValue()); continue; } diff --git a/src/duckdb/src/function/scalar/list/list_select.cpp b/src/duckdb/src/function/scalar/list/list_select.cpp index e5fea0492..7ec6c0c33 100644 --- a/src/duckdb/src/function/scalar/list/list_select.cpp +++ b/src/duckdb/src/function/scalar/list/list_select.cpp @@ -89,8 +89,8 @@ struct SetSelectionVectorWhere { template void ListSelectFunction(const DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() == 2); - auto &list = args.data[0]; - auto &selection_list = args.data[1]; + const auto &list = args.data[0]; + const auto &selection_list = args.data[1]; idx_t count = args.size(); auto selection_list_data = selection_list.Values(); diff --git a/src/duckdb/src/function/scalar/list/list_zip.cpp b/src/duckdb/src/function/scalar/list/list_zip.cpp index ef1276d5c..16051c565 100644 --- a/src/duckdb/src/function/scalar/list/list_zip.cpp +++ b/src/duckdb/src/function/scalar/list/list_zip.cpp @@ -74,6 +74,13 @@ static void ListZipFunction(DataChunk &args, ExpressionState &state, Vector &res masks.push_back(ValidityMask(result_size)); } + // Ensure list children are flat so FlatVector::Validity can be used below + for (idx_t i = 0; i < args_size; i++) { + if (input_lists[i].has_value() && ListVector::GetListSize(args.data[i]) > 0) { + ListVector::GetChildMutable(args.data[i]).Flatten(); + } + } + idx_t offset = 0; auto result_data = FlatVector::Writer(result, count); for (idx_t row_idx = 0; row_idx < count; row_idx++) { @@ -113,15 +120,17 @@ static void ListZipFunction(DataChunk &args, ExpressionState &state, Vector &res result_data.WriteValue(entry); offset += len; } - if (result_size > 0) { - for (idx_t child_idx = 0; child_idx < args_size; child_idx++) { - if (args.data[child_idx].GetType() != LogicalType::SQLNULL) { - struct_entries[child_idx].Slice(ListVector::GetChild(args.data[child_idx]), 
selections[child_idx], - result_size); - } - struct_entries[child_idx].Flatten(); - FlatVector::SetValidity((struct_entries[child_idx]), masks[child_idx]); + for (idx_t child_idx = 0; child_idx < args_size; child_idx++) { + if (args.data[child_idx].GetType() == LogicalType::SQLNULL || + ListVector::GetListSize(args.data[child_idx]) == 0) { + struct_entries[child_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(struct_entries[child_idx], true); + } else { + struct_entries[child_idx].Slice(ListVector::GetChild(args.data[child_idx]), selections[child_idx], + result_size); } + struct_entries[child_idx].Flatten(); + FlatVector::SetValidity((struct_entries[child_idx]), masks[child_idx]); } } @@ -142,17 +151,17 @@ static unique_ptr ListZipBind(BindScalarFunctionInput &input) { } } - case_insensitive_set_t struct_names; + identifier_set_t struct_names; for (idx_t i = 0; i < size; i++) { auto &child = arguments[i]; switch (child->GetReturnType().id()) { case LogicalTypeId::LIST: case LogicalTypeId::ARRAY: child = BoundCastExpression::AddArrayCastToList(context, std::move(child)); - struct_children.push_back(make_pair(string(), ListType::GetChildType(child->GetReturnType()))); + struct_children.emplace_back(make_pair(string(), ListType::GetChildType(child->GetReturnType()))); break; case LogicalTypeId::SQLNULL: - struct_children.push_back(make_pair(string(), LogicalTypeId::SQLNULL)); + struct_children.emplace_back(make_pair(string(), LogicalTypeId::SQLNULL)); break; case LogicalTypeId::UNKNOWN: throw ParameterNotResolvedException(); diff --git a/src/duckdb/src/function/scalar/map/map_contains.cpp b/src/duckdb/src/function/scalar/map/map_contains.cpp index 5f83ef914..85e3f1f93 100644 --- a/src/duckdb/src/function/scalar/map/map_contains.cpp +++ b/src/duckdb/src/function/scalar/map/map_contains.cpp @@ -8,9 +8,9 @@ namespace duckdb { static void MapContainsFunction(DataChunk &input, ExpressionState &state, Vector &result) { const auto count = 
input.size(); - auto &map_vec = input.data[0]; - auto &key_vec = MapVector::GetKeys(map_vec); - auto &arg_vec = input.data[1]; + const auto &map_vec = input.data[0]; + const auto &key_vec = MapVector::GetKeys(map_vec); + const auto &arg_vec = input.data[1]; ListSearchOp(map_vec, key_vec, arg_vec, result, count); } diff --git a/src/duckdb/src/function/scalar/nested_functions.cpp b/src/duckdb/src/function/scalar/nested_functions.cpp index 69e306079..cb6d13ce2 100644 --- a/src/duckdb/src/function/scalar/nested_functions.cpp +++ b/src/duckdb/src/function/scalar/nested_functions.cpp @@ -5,11 +5,12 @@ namespace duckdb { -void MapUtil::ReinterpretMap(Vector &result, Vector &input, idx_t count) { +void MapUtil::ReinterpretMap(Vector &result, const Vector &input) { input.Flatten(); + const auto count = input.size(); - auto &input_keys = MapVector::GetKeys(input); - auto &input_values = MapVector::GetValues(input); + const auto &input_keys = MapVector::GetKeys(input); + const auto &input_values = MapVector::GetValues(input); // Copy the list offsets and top-level validity auto result_data = FlatVector::Writer(result, count); diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index dbcdc6d35..51e840dc5 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -1,3 +1,9 @@ +#include "duckdb/common/assert.hpp" +#include "duckdb/common/checked_integer.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/function/scalar_function.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/common/enum_util.hpp" @@ -21,6 +27,10 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include 
"duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include +#include namespace duckdb { @@ -81,6 +91,46 @@ static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { return function; } +template +static Value NumericStatsValue(const LogicalType &type, T value) { + D_ASSERT(type.IsNumeric()); + switch (type.InternalType()) { + case PhysicalType::FLOAT: + return Value::FLOAT(value); + case PhysicalType::DOUBLE: + return Value::DOUBLE(value); + default: + return Value::Numeric(type, value); + } +} + +template +static bool WidenFloatingBounds(const LogicalType &type, Value &new_min, Value &new_max) { + auto min = new_min.GetValue(); + auto max = new_max.GetValue(); + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + return false; + } + min = std::nextafter(min, -std::numeric_limits::infinity()); + max = std::nextafter(max, std::numeric_limits::infinity()); + + new_min = NumericStatsValue(type, min); + new_max = NumericStatsValue(type, max); + return true; +} + +template +struct FloatingTryOperator { + template + static bool Operation(T left, T right, T &result) { + if (!Value::IsFinite(left) || !Value::IsFinite(right)) { + return false; + } + result = BASEOP::template Operation(left, right); + return Value::IsFinite(result); + } +}; + //===--------------------------------------------------------------------===// // + [add] //===--------------------------------------------------------------------===// @@ -99,8 +149,9 @@ struct AddPropagateStatistics { if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMax(rstats), max)) { return true; } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); + + new_min = NumericStatsValue(type, min); + new_max = NumericStatsValue(type, max); return false; } }; @@ -116,8 +167,8 @@ struct SubtractPropagateStatistics { if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMin(rstats), max)) { return true; } - new_min = 
Value::Numeric(type, min); - new_max = Value::Numeric(type, max); + new_min = NumericStatsValue(type, min); + new_max = NumericStatsValue(type, max); return false; } }; @@ -181,7 +232,8 @@ unique_ptr PropagateNumericStats(ClientContext &context, Functio auto &bind_data = input.bind_data->Cast(); bind_data.check_overflow = false; } - expr.function.SetFunctionCallback(GetScalarIntegerFunction(expr.GetReturnType().InternalType())); + expr.FunctionMutable().SetFunctionCallback( + GetScalarIntegerFunction(expr.GetReturnType().InternalType())); } auto result = NumericStats::CreateEmpty(expr.GetReturnType()); NumericStats::SetMin(result, new_min); @@ -190,6 +242,48 @@ unique_ptr PropagateNumericStats(ClientContext &context, Functio return result.ToUnique(); } +template +unique_ptr PropagateFloatingStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &child_stats = input.child_stats; + auto &expr = input.expr; + D_ASSERT(child_stats.size() == 2); + + auto &return_type = expr.GetReturnType(); + + auto &lstats = child_stats[0]; + auto &rstats = child_stats[1]; + if (!NumericStats::HasMinMax(lstats) || !NumericStats::HasMinMax(rstats)) { + return nullptr; + } + + Value new_min, new_max; + bool failed; + switch (return_type.InternalType()) { + case PhysicalType::FLOAT: + failed = PROPAGATE::template Operation>(return_type, lstats, rstats, new_min, + new_max); + if (failed || !WidenFloatingBounds(return_type, new_min, new_max)) { + return nullptr; + } + break; + case PhysicalType::DOUBLE: + failed = PROPAGATE::template Operation>(return_type, lstats, rstats, + new_min, new_max); + if (failed || !WidenFloatingBounds(return_type, new_min, new_max)) { + return nullptr; + } + break; + default: + return nullptr; + } + + auto result = NumericStats::CreateEmpty(return_type); + NumericStats::SetMin(result, new_min); + NumericStats::SetMax(result, new_max); + result.CombineValidity(lstats, rstats); + return result.ToUnique(); +} + template unique_ptr 
BindDecimalArithmetic(BindScalarFunctionInput &input) { auto &bound_function = input.GetBoundFunction(); @@ -323,7 +417,7 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &type) { static void BignumAdd(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), + BinaryExecutor::Execute(args.data[0], args.data[1], result, [&](bignum_t a, bignum_t b) { const BignumIntermediate lhs(a); const BignumIntermediate rhs(b); @@ -334,7 +428,7 @@ static void BignumAdd(DataChunk &args, ExpressionState &state, Vector &result) { static void BignumSubtract(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](bignum_t a, bignum_t b) { + args.data[0], args.data[1], result, [&](bignum_t a, bignum_t b) { const BignumIntermediate lhs(a); BignumIntermediate rhs(b); rhs.NegateInPlace(); @@ -346,7 +440,7 @@ static void BignumSubtract(DataChunk &args, ExpressionState &state, Vector &resu static void BignumNegate(DataChunk &args, ExpressionState &state, Vector &result) { auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](bignum_t a) { + UnaryExecutor::Execute(args.data[0], result, [&](bignum_t a) { const BignumIntermediate lhs(a); return lhs.Negate(heap); }); @@ -372,6 +466,13 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi function.SetFallible(); function.SetArgProperties({inc, inc}); return function; + } else if (left_type.IsFloating()) { + ScalarFunction function("+", {left_type, right_type}, left_type, + GetScalarBinaryFunction(left_type.InternalType())); + function.SetFallible(); + function.SetArgProperties({inc, inc}); + function.SetStatisticsCallback(PropagateFloatingStats); + return function; } else { ScalarFunction 
function("+", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); @@ -617,12 +718,12 @@ static unique_ptr IntegerNegateBind(BindScalarFunctionInput &input return nullptr; } auto &const_expr = arguments[0]->Cast(); - if (const_expr.value.IsNull()) { + if (const_expr.GetValue().IsNull()) { return nullptr; } auto &type = bound_function.GetArguments()[0]; // only need to promote if the constant exactly equals the type's minimum value - if (const_expr.value != Value::MinimumValue(type)) { + if (const_expr.GetValue() != Value::MinimumValue(type)) { return nullptr; } LogicalType promoted_type; @@ -692,6 +793,13 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const function.SetArgProperties({ArgProperties().StrictlyIncreasing(), ArgProperties().StrictlyDecreasing()}); return function; + } else if (left_type.IsFloating()) { + ScalarFunction function("-", {left_type, right_type}, left_type, + GetScalarBinaryFunction(left_type.InternalType())); + function.SetFallible(); + function.SetArgProperties({ArgProperties().StrictlyIncreasing(), ArgProperties().StrictlyDecreasing()}); + function.SetStatisticsCallback(PropagateFloatingStats); + return function; } else { ScalarFunction function("-", {left_type, right_type}, left_type, GetScalarBinaryFunction(left_type.InternalType())); @@ -853,8 +961,8 @@ struct MultiplyPropagateStatistics { } } } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); + new_min = NumericStatsValue(type, min); + new_max = NumericStatsValue(type, max); return false; } }; @@ -947,6 +1055,10 @@ ScalarFunctionSet OperatorMultiplyFun::GetFunctions() { multiply.AddFunction(ScalarFunction( {type, type}, type, GetScalarIntegerFunction(type.InternalType()), nullptr, PropagateNumericStats)); + } else if (type.IsFloating()) { + multiply.AddFunction(ScalarFunction({type, type}, type, + GetScalarBinaryFunction(type.InternalType()), nullptr, + PropagateFloatingStats)); } else { 
multiply.AddFunction( ScalarFunction({type, type}, type, GetScalarBinaryFunction(type.InternalType()))); @@ -1060,7 +1172,7 @@ struct BinaryNumericDivideHugeintWrapper { template void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size()); + BinaryExecutor::Execute(input.data[0], input.data[1], result); } template diff --git a/src/duckdb/src/function/scalar/operator/decimal_division.cpp b/src/duckdb/src/function/scalar/operator/decimal_division.cpp new file mode 100644 index 000000000..bc706cd91 --- /dev/null +++ b/src/duckdb/src/function/scalar/operator/decimal_division.cpp @@ -0,0 +1,289 @@ +#include "duckdb/common/exception.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/types/cast_helpers.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/operator_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Bind data +//===--------------------------------------------------------------------===// + +struct DecimalDivBindData : public FunctionData { + // scale_exp = s2 + result_scale - s1 + // >= 0: scale numerator up: result_int = (a_int * 10^scale_exp) / b_int + // < 0: scale divisor up: result_int = a_int / (b_int * 10^|scale_exp|) + int32_t scale_exp; + + explicit DecimalDivBindData(int32_t scale_exp_p) : scale_exp(scale_exp_p) { + } + + unique_ptr Copy() const override { + return make_uniq(scale_exp); + } + + bool Equals(const FunctionData &other) const override { + return scale_exp == other.Cast().scale_exp; + } +}; + +//===--------------------------------------------------------------------===// +// Execute 
functions +// +// COMPUTE_TYPE controls intermediate arithmetic: +// int64_t -- when input_max_width + |scale_exp| < CACHED_POWERS_OF_TEN (product fits in int64_t) +// hugeint_t -- otherwise +//===--------------------------------------------------------------------===// + +template +static void DecimalDivExecute(DataChunk &args, ExpressionState &state, Vector &result) { + auto &bind_data = state.expr.Cast().BindInfo()->Cast(); + + bool scale_divisor = bind_data.scale_exp < 0; + uint8_t abs_exp = UnsafeNumericCast(AbsValue(bind_data.scale_exp)); + + COMPUTE_TYPE scale_factor; + if constexpr (std::is_same_v) { + scale_factor = NumericHelper::POWERS_OF_TEN[abs_exp]; + } else { + scale_factor = Hugeint::POWERS_OF_TEN[abs_exp]; + } + + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](INPUT_TYPE a, INPUT_TYPE b) -> RESULT_TYPE { + if (b == INPUT_TYPE(0)) { + throw InvalidInputException("decimal_division: division by zero"); + } + + COMPUTE_TYPE numerator, divisor; + if (scale_divisor) { + numerator = COMPUTE_TYPE(a); + if constexpr (std::is_same_v) { + divisor = COMPUTE_TYPE(b) * scale_factor; + } else { + if (!Hugeint::TryMultiply(hugeint_t(b), scale_factor, divisor)) { + throw OutOfRangeException("decimal_division: scaled divisor overflowed DECIMAL range"); + } + } + } else { + numerator = COMPUTE_TYPE(a) * scale_factor; + divisor = COMPUTE_TYPE(b); + } + + bool negative; + COMPUTE_TYPE abs_num, abs_div; + if constexpr (std::is_same_v) { + negative = (numerator < 0) != (divisor < 0); + abs_num = numerator < 0 ? -numerator : numerator; + abs_div = divisor < 0 ? -divisor : divisor; + } else { + negative = (numerator.upper < 0) != (divisor.upper < 0); + abs_num = numerator.upper < 0 ? -numerator : numerator; + abs_div = divisor.upper < 0 ? 
-divisor : divisor; + } + + COMPUTE_TYPE q; + if constexpr (std::is_same_v) { + q = abs_num / abs_div; + } else { + if (abs_div.upper == 0) { + uint64_t r64; + q = Hugeint::DivModPositive(abs_num, abs_div.lower, r64); + } else { + hugeint_t r; + q = Hugeint::DivMod(abs_num, abs_div, r); + } + } + + COMPUTE_TYPE final_val = negative ? -q : q; + RESULT_TYPE out; + if constexpr (std::is_same_v) { + if (!TryCast::Operation(final_val, out)) { + throw OutOfRangeException("decimal_division: result out of range for result type"); + } + } else { + if (!Hugeint::TryCast(final_val, out)) { + throw OutOfRangeException("decimal_division: result out of range for result type"); + } + } + return out; + }); +} + +template +static scalar_function_t GetDecimalDivExecuteFunction(PhysicalType result_physical) { + switch (result_physical) { + case PhysicalType::INT16: + return DecimalDivExecute; + case PhysicalType::INT32: + return DecimalDivExecute; + case PhysicalType::INT64: + return DecimalDivExecute; + default: + return DecimalDivExecute; + } +} + +//===--------------------------------------------------------------------===// +// Bind function +//===--------------------------------------------------------------------===// + +// Result width and scale for e1 / e2 (SQL Server semantics): +// result_scale = max(6, s1 + p2 + 1) +// result_width = p1 - s1 + s2 + result_scale +// Both are capped at Decimal::MAX_WIDTH_DECIMAL (38). 
+static unique_ptr DecimalDivisionBind(BindScalarFunctionInput &input) { + auto &bound_function = input.GetBoundFunction(); + auto &arguments = input.GetArguments(); + + uint8_t p1, s1, p2, s2; + if (!arguments[0]->GetReturnType().GetDecimalProperties(p1, s1) || + !arguments[1]->GetReturnType().GetDecimalProperties(p2, s2)) { + throw InvalidInputException("decimal_division: both arguments must be DECIMAL"); + } + + uint8_t result_scale; + if (arguments.size() == 3) { + if (!arguments[2]->IsFoldable()) { + throw NotImplementedException("decimal_division: scale argument must be a constant integer"); + } + Value scale_val = ExpressionExecutor::EvaluateScalar(input.GetClientContext(), *arguments[2]); + if (scale_val.IsNull()) { + throw InvalidInputException("decimal_division: scale argument must not be NULL"); + } + int32_t scale_i = scale_val.GetValue(); + if (scale_i < 0 || scale_i > Decimal::MAX_WIDTH_DECIMAL) { + throw InvalidInputException("decimal_division: scale must be between 0 and %d, got %d", + Decimal::MAX_WIDTH_DECIMAL, scale_i); + } + result_scale = UnsafeNumericCast(scale_i); + } else { + result_scale = MaxValue(6, s1 + p2 + 1); + } + // result_width uses signed arithmetic to catch underflow before casting. + auto result_width_i = p1 - s1 + s2 + result_scale; + if (result_width_i < 1) { + throw InvalidInputException("decimal_division: scale %d produces a zero-width result for these input types", + result_scale); + } + uint8_t result_width = UnsafeNumericCast(result_width_i); + + if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { + throw OutOfRangeException( + "Needed scale %d to accurately represent the division result, but this is out of range of the " + "DECIMAL type. Max scale is %d; could not perform an accurate division. 
Either add a cast to DOUBLE, " + "or add an explicit cast to a decimal with a lower scale.", + result_scale, Decimal::MAX_WIDTH_DECIMAL); + } + if (result_width > Decimal::MAX_WIDTH_DECIMAL) { + throw OutOfRangeException( + "Needed width %d to accurately represent the division result, but this is out of range of the " + "DECIMAL type. Max width is %d; could not perform an accurate division. Either add a cast to DOUBLE, " + "or add an explicit cast to a decimal with a lower scale.", + result_width, Decimal::MAX_WIDTH_DECIMAL); + } + + PhysicalType result_physical; + if (result_width <= Decimal::MAX_WIDTH_INT16) { + result_physical = PhysicalType::INT16; + } else if (result_width <= Decimal::MAX_WIDTH_INT32) { + result_physical = PhysicalType::INT32; + } else if (result_width <= Decimal::MAX_WIDTH_INT64) { + result_physical = PhysicalType::INT64; + } else { + result_physical = PhysicalType::INT128; + } + + bound_function.SetReturnType(LogicalType::DECIMAL(result_width, result_scale)); + + auto lhs_physical = arguments[0]->GetReturnType().InternalType(); + auto rhs_physical = arguments[1]->GetReturnType().InternalType(); + auto wider = MaxValue(lhs_physical, rhs_physical); + + int32_t scale_exp = s2 + result_scale - s1; + int32_t abs_scale_exp = scale_exp < 0 ? -scale_exp : scale_exp; + + uint8_t input_max_width; + switch (wider) { + case PhysicalType::INT16: + input_max_width = Decimal::MAX_WIDTH_INT16; + break; + case PhysicalType::INT32: + input_max_width = Decimal::MAX_WIDTH_INT32; + break; + case PhysicalType::INT64: + input_max_width = Decimal::MAX_WIDTH_INT64; + break; + case PhysicalType::INT128: + input_max_width = Decimal::MAX_WIDTH_DECIMAL; + break; + default: + throw InternalException("decimal_division: unexpected physical type"); + } + + // For same-tier inputs, preserve the original argument type so bind is idempotent + // during plan deserialization (bind is re-called with the stored argument types, and + // changing p changes the result-width formula). 
+ // For cross-tier inputs, promote the narrower argument to the wider physical tier. + bound_function.GetArguments()[0] = + (lhs_physical != wider) ? LogicalType::DECIMAL(input_max_width, s1) : arguments[0]->GetReturnType(); + bound_function.GetArguments()[1] = + (rhs_physical != wider) ? LogicalType::DECIMAL(input_max_width, s2) : arguments[1]->GetReturnType(); + + switch (wider) { + case PhysicalType::INT16: + if (Decimal::MAX_WIDTH_INT16 + abs_scale_exp < NumericHelper::CACHED_POWERS_OF_TEN) { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } else { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } + break; + case PhysicalType::INT32: + if (Decimal::MAX_WIDTH_INT32 + abs_scale_exp < NumericHelper::CACHED_POWERS_OF_TEN) { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } else { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } + break; + case PhysicalType::INT64: + if (Decimal::MAX_WIDTH_INT64 + abs_scale_exp < NumericHelper::CACHED_POWERS_OF_TEN) { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } else { + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + } + break; + default: + bound_function.SetFunctionCallback(GetDecimalDivExecuteFunction(result_physical)); + break; + } + + return make_uniq(scale_exp); +} + +//===--------------------------------------------------------------------===// +// Registration +//===--------------------------------------------------------------------===// + +ScalarFunctionSet DecimalDivisionFun::GetFunctions() { + auto decimal_type = LogicalType(LogicalTypeId::DECIMAL); + ScalarFunctionSet set("decimal_division"); + + ScalarFunction two_arg({decimal_type, decimal_type}, decimal_type, + DecimalDivExecute, DecimalDivisionBind); + two_arg.SetFallible(); + set.AddFunction(two_arg); + + ScalarFunction 
three_arg({decimal_type, decimal_type, LogicalType::INTEGER}, decimal_type, + DecimalDivExecute, DecimalDivisionBind); + three_arg.SetFallible(); + set.AddFunction(three_arg); + + return set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp index ded3f5af5..cfc02424e 100644 --- a/src/duckdb/src/function/scalar/sequence/nextval.cpp +++ b/src/duckdb/src/function/scalar/sequence/nextval.cpp @@ -31,21 +31,21 @@ struct NextSequenceValueOperator { } }; -SequenceCatalogEntry &BindSequence(Binder &binder, string &catalog, string &schema, const string &name) { +SequenceCatalogEntry &BindSequence(Binder &binder, Identifier &catalog, Identifier &schema, const Identifier &name) { // fetch the sequence from the catalog Binder::BindSchemaOrCatalog(binder.context, catalog, schema); EntryLookupInfo sequence_lookup(CatalogType::SEQUENCE_ENTRY, name); return binder.EntryRetriever().GetEntry(catalog, schema, sequence_lookup)->Cast(); } -SequenceCatalogEntry &BindSequenceFromContext(ClientContext &context, string &catalog, string &schema, - const string &name) { +SequenceCatalogEntry &BindSequenceFromContext(ClientContext &context, Identifier &catalog, Identifier &schema, + const Identifier &name) { Binder::BindSchemaOrCatalog(context, catalog, schema); return Catalog::GetEntry(context, catalog, schema, name); } -SequenceCatalogEntry &BindSequence(Binder &binder, const string &name) { - auto qname = QualifiedName::Parse(name); +SequenceCatalogEntry &BindSequence(Binder &binder, const Identifier &name) { + auto qname = QualifiedName::Parse(name.GetIdentifierName()); return BindSequence(binder, qname.catalog, qname.schema, qname.name); } @@ -73,7 +73,7 @@ unique_ptr NextValLocalFunction(ExpressionState &state, cons template void NextValFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - if (!func_expr.bind_info) { + if 
(!func_expr.BindInfo()) { // no bind info - return null ConstantVector::SetNull(result, count_t(args.size())); return; @@ -107,7 +107,7 @@ unique_ptr NextValBind(BindScalarFunctionInput &input) { if (seqname.IsNull()) { return nullptr; } - auto &seq = BindSequence(binder, seqname.ToString()); + auto &seq = BindSequence(binder, Identifier(seqname.ToString())); return make_uniq(seq); } diff --git a/src/duckdb/src/function/scalar/strftime_format.cpp b/src/duckdb/src/function/scalar/strftime_format.cpp index 11fe59393..51bd7822b 100644 --- a/src/duckdb/src/function/scalar/strftime_format.cpp +++ b/src/duckdb/src/function/scalar/strftime_format.cpp @@ -667,11 +667,11 @@ string StrTimeFormat::ParseFormatSpecifier(const string &format_string, StrTimeF return string(); } -void StrfTimeFormat::ConvertDateVector(Vector &input, Vector &result, idx_t count) { +void StrfTimeFormat::ConvertDateVector(const Vector &input, Vector &result) { D_ASSERT(input.GetType().id() == LogicalTypeId::DATE); D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(input, result, count, [&](date_t input) { + UnaryExecutor::Execute(input, result, [&](date_t input) { if (input.IsFinite()) { dtime_t time(0); idx_t len = GetLength(input, time, 0, nullptr); @@ -733,21 +733,21 @@ string_t StrfTimeFormat::ConvertTimestampValue(const timestamp_ns_t &input, Stri } } -void StrfTimeFormat::ConvertTimestampVector(Vector &input, Vector &result, idx_t count) { +void StrfTimeFormat::ConvertTimestampVector(const Vector &input, Vector &result) { D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ); D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); auto &heap = StringVector::GetStringHeap(result); - UnaryExecutor::Execute(input, result, count, + UnaryExecutor::Execute(input, result, [&](timestamp_t ts) { return ConvertTimestampValue(ts, heap); }); } -void 
StrfTimeFormat::ConvertTimestampNSVector(Vector &input, Vector &result, idx_t count) { +void StrfTimeFormat::ConvertTimestampNSVector(const Vector &input, Vector &result) { D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP_NS || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ_NS); D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); auto &heap = StringVector::GetStringHeap(result); UnaryExecutor::Execute( - input, result, count, [&](timestamp_ns_t ts) { return ConvertTimestampValue(ts, heap); }); + input, result, [&](timestamp_ns_t ts) { return ConvertTimestampValue(ts, heap); }); } void StrpTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { @@ -810,7 +810,7 @@ int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t s // compare the characters idx_t i; for (i = 0; i < entry_size; i++) { - if (std::tolower(entry_data[i]) != std::tolower(data[pos + i])) { + if (StringUtil::CharacterToLower(entry_data[i]) != StringUtil::CharacterToLower(data[pos + i])) { break; } } @@ -1243,8 +1243,8 @@ bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result, b error_position = pos; return false; } - char pa_char = char(std::tolower(data[pos])); - char m_char = char(std::tolower(data[pos + 1])); + char pa_char = StringUtil::CharacterToLower(data[pos]); + char m_char = StringUtil::CharacterToLower(data[pos + 1]); if (m_char != 'm') { error_message = "Expected AM/PM"; error_position = pos; diff --git a/src/duckdb/src/function/scalar/string/caseconvert.cpp b/src/duckdb/src/function/scalar/string/caseconvert.cpp index b3e151b13..0cea5f4e4 100644 --- a/src/duckdb/src/function/scalar/string/caseconvert.cpp +++ b/src/duckdb/src/function/scalar/string/caseconvert.cpp @@ -107,7 +107,7 @@ struct CaseConvertOperator { template static void CaseConvertFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); + 
UnaryExecutor::ExecuteString>(args.data[0], result); } namespace { @@ -124,8 +124,7 @@ struct CaseConvertOperatorASCII { template static void CaseConvertFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, - args.size()); + UnaryExecutor::ExecuteString>(args.data[0], result); } template @@ -135,7 +134,7 @@ static unique_ptr CaseConvertPropagateStats(ClientContext &conte D_ASSERT(child_stats.size() == 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.SetFunctionCallback(CaseConvertFunctionASCII); + expr.FunctionMutable().SetFunctionCallback(CaseConvertFunctionASCII); } return nullptr; } diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp index 998e5aa52..a0d7cb3e1 100644 --- a/src/duckdb/src/function/scalar/string/concat.cpp +++ b/src/duckdb/src/function/scalar/string/concat.cpp @@ -42,7 +42,7 @@ void StringConcatFunction(DataChunk &args, ExpressionState &state, Vector &resul idx_t constant_lengths = 0; vector result_lengths(args.size(), 0); for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto &input = args.data[col_idx]; + const auto &input = args.data[col_idx]; D_ASSERT(input.GetType().InternalType() == PhysicalType::VARCHAR); if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (ConstantVector::IsNull(input)) { @@ -74,7 +74,7 @@ void StringConcatFunction(DataChunk &args, ExpressionState &state, Vector &resul // now that the empty space for the strings has been allocated, perform the concatenation for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto &input = args.data[col_idx]; + const auto &input = args.data[col_idx]; // loop over the vector and concat to all results if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -112,7 +112,7 @@ void StringConcatFunction(DataChunk &args, 
ExpressionState &state, Vector &resul void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &result) { BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t a, string_t b) { + args.data[0], args.data[1], result, [&](string_t a, string_t b) { auto a_data = a.GetData(); auto b_data = b.GetData(); auto a_length = a.GetSize(); @@ -130,7 +130,7 @@ void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &result) { } struct ListConcatInputData { - ListConcatInputData(Vector &input, idx_t size) + ListConcatInputData(const Vector &input, idx_t size) : input(input), child_vec(ListVector::GetChild(input)), list_data(input.Values()) { } @@ -143,7 +143,7 @@ void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result, auto count = args.size(); vector input_data; - for (auto &input : args.data) { + for (const auto &input : args.data) { if (!is_operator && input.GetType().id() == LogicalTypeId::SQLNULL) { // LIST_CONCAT ignores NULL values continue; @@ -184,7 +184,7 @@ void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result, void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.return_type.id() == LogicalTypeId::SQLNULL) { result.SetVectorType(VectorType::CONSTANT_VECTOR); return; diff --git a/src/duckdb/src/function/scalar/string/length.cpp b/src/duckdb/src/function/scalar/string/length.cpp index 9dea0bf49..7555c5d57 100644 --- a/src/duckdb/src/function/scalar/string/length.cpp +++ b/src/duckdb/src/function/scalar/string/length.cpp @@ -63,7 +63,7 @@ unique_ptr LengthPropagateStats(ClientContext &context, Function D_ASSERT(child_stats.size() == 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - 
expr.function.SetFunctionCallback(ScalarFunction::UnaryFunction); + expr.FunctionMutable().SetFunctionCallback(ScalarFunction::UnaryFunction); } return nullptr; } @@ -72,14 +72,14 @@ unique_ptr LengthPropagateStats(ClientContext &context, Function // ARRAY / LIST LENGTH //------------------------------------------------------------------ void ListLengthFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; D_ASSERT(input.GetType().id() == LogicalTypeId::LIST); UnaryExecutor::Execute( - input, result, args.size(), [](list_entry_t input) { return UnsafeNumericCast(input.length); }); + input, result, [](list_entry_t input) { return UnsafeNumericCast(input.length); }); } void ArrayLengthFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; auto array_size = static_cast(ArrayType::GetSize(input.GetType())); auto validity_entries = input.Validity(); @@ -126,10 +126,10 @@ unique_ptr ArrayOrListLengthBind(BindScalarFunctionInput &input) { //------------------------------------------------------------------ void ListLengthBinaryFunction(DataChunk &args, ExpressionState &, Vector &result) { auto type = args.data[0].GetType(); - auto &input = args.data[0]; - auto &dimension = args.data[1]; + const auto &input = args.data[0]; + const auto &dimension = args.data[1]; BinaryExecutor::Execute( - input, dimension, result, args.size(), [](list_entry_t input, int64_t dimension) { + input, dimension, result, [](list_entry_t input, int64_t dimension) { if (dimension != 1) { throw NotImplementedException("array_length for lists with dimensions other than 1 not implemented"); } @@ -154,14 +154,14 @@ struct ArrayLengthBinaryFunctionData : public FunctionData { void ArrayLengthBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto type = args.data[0].GetType(); - auto &dimension = args.data[1]; + const auto 
&dimension = args.data[1]; auto &expr = state.expr.Cast(); - auto &data = expr.bind_info->Cast(); + auto &data = expr.BindInfo()->Cast(); auto &dimensions = data.dimensions; auto max_dimension = static_cast(dimensions.size()); - UnaryExecutor::Execute(dimension, result, args.size(), [&](int64_t dimension) { + UnaryExecutor::Execute(dimension, result, [&](int64_t dimension) { if (dimension < 1 || dimension > max_dimension) { throw OutOfRangeException(StringUtil::Format( "array_length dimension '%lld' out of range (min: '1', max: '%lld')", dimension, max_dimension)); diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp index b6bdaed9a..c728dba20 100644 --- a/src/duckdb/src/function/scalar/string/like.cpp +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -496,12 +496,12 @@ struct GlobOperator { // This can be moved to the scalar_function class template void LikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str = args.data[0]; - auto &pattern = args.data[1]; - auto &escape = args.data[2]; + const auto &str = args.data[0]; + const auto &pattern = args.data[1]; + const auto &escape = args.data[2]; TernaryExecutor::Execute( - str, pattern, escape, result, args.size(), FUNC::template Operation); + str, pattern, escape, result, FUNC::template Operation); } template @@ -511,7 +511,7 @@ unique_ptr ILikePropagateStats(ClientContext &context, FunctionS D_ASSERT(child_stats.size() >= 1); // can only propagate stats if the children have stats if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.SetFunctionCallback(ScalarFunction::BinaryFunction); + expr.FunctionMutable().SetFunctionCallback(ScalarFunction::BinaryFunction); } return nullptr; } @@ -519,16 +519,15 @@ unique_ptr ILikePropagateStats(ClientContext &context, FunctionS template void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - if 
(func_expr.bind_info) { - auto &matcher = func_expr.bind_info->Cast(); + if (func_expr.BindInfo()) { + auto &matcher = func_expr.BindInfo()->Cast(); // use fast like matcher - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](string_t input) { + UnaryExecutor::Execute(input.data[0], result, [&](string_t input) { return INVERT ? !matcher.Match(input) : matcher.Match(input); }); } else { // use generic like matcher - BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, - input.size()); + BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result); } } diff --git a/src/duckdb/src/function/scalar/string/md5.cpp b/src/duckdb/src/function/scalar/string/md5.cpp index 93e95dfb2..af7db74c4 100644 --- a/src/duckdb/src/function/scalar/string/md5.cpp +++ b/src/duckdb/src/function/scalar/string/md5.cpp @@ -33,15 +33,15 @@ struct MD5Number128Operator { }; void MD5Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; - UnaryExecutor::ExecuteString(input, result, args.size()); + UnaryExecutor::ExecuteString(input, result); } void MD5NumberFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; - UnaryExecutor::Execute(input, result, args.size()); + UnaryExecutor::Execute(input, result); } } // namespace diff --git a/src/duckdb/src/function/scalar/string/nfc_normalize.cpp b/src/duckdb/src/function/scalar/string/nfc_normalize.cpp index 29d97606e..6f36105e2 100644 --- a/src/duckdb/src/function/scalar/string/nfc_normalize.cpp +++ b/src/duckdb/src/function/scalar/string/nfc_normalize.cpp @@ -25,7 +25,7 @@ struct NFCNormalizeOperator { void NFCNormalizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); 
StringVector::AddHeapReference(result, args.data[0]); } diff --git a/src/duckdb/src/function/scalar/string/overlay.cpp b/src/duckdb/src/function/scalar/string/overlay.cpp new file mode 100644 index 000000000..fa85baa4d --- /dev/null +++ b/src/duckdb/src/function/scalar/string/overlay.cpp @@ -0,0 +1,107 @@ +#include "duckdb/function/scalar/string_common.hpp" +#include "duckdb/function/scalar/string_functions.hpp" + +namespace duckdb { + +namespace { + +// Returns the byte offset of the start of the `char_pos`th character (1-based). +// Returns input_size if char_pos is past the end of the string. +// Clamps to 0 if char_pos <= 1. +static idx_t CharPosToByte(const char *data, idx_t input_size, int64_t char_pos) { + if (char_pos <= 1) { + return 0; + } + int64_t current = 1; + for (idx_t i = 0; i < input_size; i++) { + if (IsCharacter(data[i])) { + if (current == char_pos) { + return i; + } + current++; + } + } + return input_size; +} + +static int64_t CharLength(const char *data, idx_t size) { + int64_t length = 0; + for (idx_t i = 0; i < size; i++) { + length += IsCharacter(data[i]); + } + return length; +} + +void OverlayFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + + auto input_values = args.data[0].Values(); + auto repl_values = args.data[1].Values(); + auto start_values = args.data[2].Values(); + optional> count_values; + if (args.ColumnCount() == 4) { + count_values = args.data[3].Values(); + } + + auto result_data = FlatVector::Writer(result, count); + for (idx_t i = 0; i < count; i++) { + auto input_entry = input_values[i]; + auto repl_entry = repl_values[i]; + auto start_entry = start_values[i]; + + if (!input_entry.IsValid() || !repl_entry.IsValid() || !start_entry.IsValid()) { + result_data.WriteNull(); + continue; + } + auto input = input_entry.GetValue(); + auto repl = repl_entry.GetValue(); + int64_t char_count = CharLength(repl.GetData(), repl.GetSize()); + if (count_values.has_value()) { + auto 
count_entry = count_values.value()[i]; + if (!count_entry.IsValid()) { + result_data.WriteNull(); + continue; + } + if (count_entry.GetValue() >= 0) { + char_count = count_entry.GetValue(); + } + } + + auto input_ptr = input.GetData(); + auto input_size = input.GetSize(); + auto repl_ptr = repl.GetData(); + auto repl_size = repl.GetSize(); + auto start = start_entry.GetValue(); + + idx_t byte_start = CharPosToByte(input_ptr, input_size, start); + int64_t end_pos = (char_count > 0 && start > NumericLimits::Maximum() - char_count) + ? NumericLimits::Maximum() + : start + char_count; + idx_t byte_end = CharPosToByte(input_ptr, input_size, end_pos); + idx_t after_size = input_size - byte_end; + idx_t blob_size = byte_start + repl_size + after_size; + + auto &blob = result_data.WriteEmptyString(blob_size); + auto blob_ptr = blob.GetDataWriteable(); + + memcpy(blob_ptr, input_ptr, byte_start); + memcpy(blob_ptr + byte_start, repl_ptr, repl_size); + memcpy(blob_ptr + byte_start + repl_size, input_ptr + byte_end, after_size); + + blob.Finalize(); + } +} + +} // namespace + +ScalarFunctionSet OverlayFun::GetFunctions() { + ScalarFunctionSet overlay_set("overlay"); + overlay_set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::BIGINT}, + LogicalType::VARCHAR, OverlayFunction)); + overlay_set.AddFunction( + ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, + LogicalType::VARCHAR, OverlayFunction)); + return overlay_set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp.cpp b/src/duckdb/src/function/scalar/string/regexp.cpp index ebc3ee87c..b21e37425 100644 --- a/src/duckdb/src/function/scalar/string/regexp.cpp +++ b/src/duckdb/src/function/scalar/string/regexp.cpp @@ -128,19 +128,19 @@ struct RegexFullMatch { template static void RegexpMatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &strings = args.data[0]; - auto 
&patterns = args.data[1]; + const auto &strings = args.data[0]; + const auto &patterns = args.data[1]; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.constant_pattern) { auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(strings, result, [&](string_t input) { return OP::Operation(CreateStringPiece(input), lstate.constant_pattern); }); } else { - BinaryExecutor::Execute(strings, patterns, result, args.size(), + BinaryExecutor::Execute(strings, patterns, result, [&](string_t input, string_t pattern) { RE2 re(CreateStringPiece(pattern), info.options); if (!re.ok()) { @@ -187,17 +187,17 @@ static unique_ptr RegexReplaceBind(BindScalarFunctionInput &input) static void RegexReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - auto &replaces = args.data[2]; + const auto &strings = args.data[0]; + const auto &patterns = args.data[1]; + const auto &replaces = args.data[2]; auto &heap = StringVector::GetStringHeap(result); if (info.constant_pattern) { auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); BinaryExecutor::Execute( - strings, replaces, result, args.size(), [&](string_t input, string_t replace) { + strings, replaces, result, [&](string_t input, string_t replace) { std::string sstring = input.GetString(); if (info.global_replace) { RE2::GlobalReplace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); @@ -208,7 +208,7 @@ static void RegexReplaceFunction(DataChunk &args, ExpressionState &state, Vector }); } else { TernaryExecutor::Execute( - strings, patterns, replaces, result, args.size(), [&](string_t input, string_t 
pattern, string_t replace) { + strings, patterns, replaces, result, [&](string_t input, string_t pattern, string_t replace) { RE2 re(CreateStringPiece(pattern), info.options); if (!re.ok()) { throw InvalidInputException(re.error()); @@ -249,10 +249,10 @@ bool RegexpExtractBindData::Equals(const FunctionData &other_p) const { static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); - auto &strings = args.data[0]; - auto &patterns = args.data[1]; + const auto &strings = args.data[0]; + const auto &patterns = args.data[1]; // Result strings are zero-copy slices of the input vector (or the input itself when // no_match_returns_input is set), so register the heap reference once per chunk regardless of // per-row outcomes. @@ -261,12 +261,12 @@ static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector if (info.constant_pattern) { auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); const auto &re = lstate.constant_pattern; - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { + UnaryExecutor::Execute(strings, result, [&](string_t input) { return ExtractCaptureGroup(input, re, info.group_index, info.no_match_returns_input); }); } else { BinaryExecutor::Execute( - strings, patterns, result, args.size(), [&](string_t input, string_t pattern) { + strings, patterns, result, [&](string_t input, string_t pattern) { RE2 re(CreateStringPiece(pattern), info.options); return ExtractCaptureGroup(input, re, info.group_index, info.no_match_returns_input); }); @@ -284,8 +284,7 @@ static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, } auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - const auto count = args.size(); - auto &input = args.data[0]; + const auto &input = args.data[0]; auto &child_entries = 
StructVector::GetEntries(result); const auto groupSize = child_entries.size(); @@ -307,7 +306,7 @@ static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, result.SetVectorType(VectorType::CONSTANT_VECTOR); if (ConstantVector::IsNull(input)) { - ConstantVector::SetNull(result, count_t(count)); + ConstantVector::SetNull(result, count_t(args.size())); } else { ConstantVector::SetNull(result, false); auto idata = ConstantVector::GetData(input); @@ -386,8 +385,8 @@ static unique_ptr RegexExtractBind(BindScalarFunctionInput &input) } vector dummy_names; // not reused after bind child_list_t struct_children; - regexp_util::ParseGroupNameList(context, bound_function.GetName(), *arguments[2], constant_string, options, - constant_pattern, dummy_names, struct_children); + regexp_util::ParseGroupNameList(context, bound_function.GetName().GetIdentifierName(), *arguments[2], + constant_string, options, constant_pattern, dummy_names, struct_children); bound_function.SetReturnType(LogicalType::STRUCT(struct_children)); } else { int32_t group_idx = group.GetValue(); diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp index 470362496..4aa4dabe3 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp @@ -107,10 +107,10 @@ duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, ExpressionState &sta void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); - auto &strings = args.data[0]; - auto &patterns = args.data[1]; + const auto &strings = args.data[0]; + const auto &patterns = args.data[1]; D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); auto strings_entries = strings.Values(); @@ 
-242,13 +242,13 @@ static list_entry_t ExtractStructAllSingleTuple(const string_t &string_val, duck void RegexpExtractAllStruct::Execute(DataChunk &args, ExpressionState &state, Vector &result) { #ifdef D_ASSERT_IS_ENABLED auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); // Struct multi-match variant only supports constant pattern (enforced in Bind) D_ASSERT(info.constant_pattern); #endif // Expect arguments: string, pattern, list_of_group_names [, options] - auto &strings = args.data[0]; + const auto &strings = args.data[0]; D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); auto &struct_vector = ListVector::GetChildMutable(result); @@ -304,8 +304,8 @@ unique_ptr RegexpExtractAllStruct::Bind(BindScalarFunctionInput &i options.set_log_errors(false); vector group_names; child_list_t struct_children; - regexp_util::ParseGroupNameList(context, function.GetName(), *arguments[2], constant_string, options, true, - group_names, struct_children); + regexp_util::ParseGroupNameList(context, function.GetName().GetIdentifierName(), *arguments[2], constant_string, + options, true, group_names, struct_children); function.SetReturnType(LogicalType::LIST(LogicalType::STRUCT(struct_children))); return make_uniq(options, std::move(constant_string), constant_pattern, std::move(group_names)); diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp index e9f6dd4b9..84e955afb 100644 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp +++ b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp @@ -107,16 +107,16 @@ void ParseGroupNameList(ClientContext &context, const string &function_name, Exp if (children.empty()) { throw BinderException("Group name list must be non-empty"); } - case_insensitive_set_t name_set; + identifier_set_t name_set; for (auto &child : children) { if (child.IsNull()) { 
throw BinderException("NULL group name in %s", function_name); } auto name = child.ToString(); - if (name_set.find(name) != name_set.end()) { + if (name_set.find(Identifier(name)) != name_set.end()) { throw BinderException("Duplicate group name '%s' in %s", name, function_name); } - name_set.insert(name); + name_set.insert(Identifier(name)); out_names.push_back(name); out_struct_children.emplace_back(make_pair(name, LogicalType::VARCHAR)); } diff --git a/src/duckdb/src/function/scalar/string/regexp_escape.cpp b/src/duckdb/src/function/scalar/string/regexp_escape.cpp index f998754d5..4aa85ba9a 100644 --- a/src/duckdb/src/function/scalar/string/regexp_escape.cpp +++ b/src/duckdb/src/function/scalar/string/regexp_escape.cpp @@ -14,7 +14,7 @@ struct EscapeOperator { }; void RegexpEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); } } // namespace diff --git a/src/duckdb/src/function/scalar/string/sha1.cpp b/src/duckdb/src/function/scalar/string/sha1.cpp index a6ff99735..0af67cbd1 100644 --- a/src/duckdb/src/function/scalar/string/sha1.cpp +++ b/src/duckdb/src/function/scalar/string/sha1.cpp @@ -22,9 +22,9 @@ struct SHA1Operator { }; void SHA1Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = args.data[0]; - UnaryExecutor::ExecuteString(input, result, args.size()); + UnaryExecutor::ExecuteString(input, result); } } // namespace diff --git a/src/duckdb/src/function/scalar/string/sha256.cpp b/src/duckdb/src/function/scalar/string/sha256.cpp index 1220f5414..e3fabbdce 100644 --- a/src/duckdb/src/function/scalar/string/sha256.cpp +++ b/src/duckdb/src/function/scalar/string/sha256.cpp @@ -22,9 +22,9 @@ struct SHA256Operator { }; void SHA256Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; + const auto &input = 
args.data[0]; - UnaryExecutor::ExecuteString(input, result, args.size()); + UnaryExecutor::ExecuteString(input, result); } } // namespace diff --git a/src/duckdb/src/function/scalar/string/string_split.cpp b/src/duckdb/src/function/scalar/string/string_split.cpp index 0b1bac3b1..76cdd2187 100644 --- a/src/duckdb/src/function/scalar/string/string_split.cpp +++ b/src/duckdb/src/function/scalar/string/string_split.cpp @@ -124,7 +124,7 @@ void StringSplitFunction(DataChunk &args, ExpressionState &state, Vector &result void StringSplitRegexFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); if (info.constant_pattern) { // fast path: pre-compiled regex auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); diff --git a/src/duckdb/src/function/scalar/string/strip_accents.cpp b/src/duckdb/src/function/scalar/string/strip_accents.cpp index 1c15773ca..a0ff143f5 100644 --- a/src/duckdb/src/function/scalar/string/strip_accents.cpp +++ b/src/duckdb/src/function/scalar/string/strip_accents.cpp @@ -48,7 +48,7 @@ struct StripAccentsOperator { void StripAccentsFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() == 1); - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); + UnaryExecutor::ExecuteString(args.data[0], result); StringVector::AddHeapReference(result, args.data[0]); } diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp index 0e4e6c9d9..110cd3fb9 100644 --- a/src/duckdb/src/function/scalar/string/substring.cpp +++ b/src/duckdb/src/function/scalar/string/substring.cpp @@ -265,38 +265,38 @@ struct SubstringGraphemeOp { template void SubstringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; + const auto 
&input_vector = args.data[0]; + const auto &offset_vector = args.data[1]; if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; + const auto &length_vector = args.data[2]; TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), + input_vector, offset_vector, length_vector, result, [&](string_t input_string, int64_t offset, int64_t length) { return OP::Substring(result, input_string, offset, length); }); } else { BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { + input_vector, offset_vector, result, [&](string_t input_string, int64_t offset) { return OP::Substring(result, input_string, offset, NumericLimits::Maximum()); }); } } void SubstringFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; + const auto &input_vector = args.data[0]; + const auto &offset_vector = args.data[1]; if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; + const auto &length_vector = args.data[2]; TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), + input_vector, offset_vector, length_vector, result, [&](string_t input_string, int64_t offset, int64_t length) { return SubstringASCII(result, input_string, offset, length); }); } else { BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { + input_vector, offset_vector, result, [&](string_t input_string, int64_t offset) { return SubstringASCII(result, input_string, offset, NumericLimits::Maximum()); }); } @@ -308,7 +308,7 @@ unique_ptr SubstringPropagateStats(ClientContext &context, Funct // can only propagate stats if the children have stats // we only care about the stats of the first child (i.e. 
the string) if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.SetFunctionCallback(SubstringFunctionASCII); + expr.FunctionMutable().SetFunctionCallback(SubstringFunctionASCII); } return nullptr; } diff --git a/src/duckdb/src/function/scalar/struct/remap_struct.cpp b/src/duckdb/src/function/scalar/struct/remap_struct.cpp index 250daf4dc..8b4083723 100644 --- a/src/duckdb/src/function/scalar/struct/remap_struct.cpp +++ b/src/duckdb/src/function/scalar/struct/remap_struct.cpp @@ -231,7 +231,7 @@ void RemapNested(Vector &input, Vector &default_vector, Vector &result, idx_t re void RemapStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); auto &input = args.data[0]; @@ -240,10 +240,10 @@ void RemapStructFunction(DataChunk &args, ExpressionState &state, Vector &result struct RemapIndex { idx_t index; LogicalType type; - unique_ptr> child_map; + unique_ptr> child_map; - static case_insensitive_map_t GetMap(const LogicalType &type) { - case_insensitive_map_t result; + static identifier_map_t GetMap(const LogicalType &type) { + identifier_map_t result; switch (type.id()) { case LogicalTypeId::STRUCT: { auto &children = StructType::GetChildTypes(type); @@ -276,7 +276,7 @@ struct RemapIndex { index.index = idx; index.type = type; if (IsRemappable(type)) { - index.child_map = make_uniq>(GetMap(type)); + index.child_map = make_uniq>(GetMap(type)); } return index; } @@ -286,12 +286,11 @@ struct RemapEntry { optional_idx index; optional_idx default_index; LogicalType target_type; - unique_ptr> child_remaps; + unique_ptr> child_remaps; - static void PerformRemap(const string &remap_target, const Value &remap_val, - case_insensitive_map_t &source_map, - case_insensitive_map_t &target_map, case_insensitive_map_t &result, - const LogicalType &parent_type) { + static void PerformRemap(const Identifier &remap_target, 
const Value &remap_val, + identifier_map_t &source_map, identifier_map_t &target_map, + identifier_map_t &result, const LogicalType &parent_type) { string remap_source; Value struct_val; if (remap_val.type().id() == LogicalTypeId::VARCHAR) { @@ -314,13 +313,13 @@ struct RemapEntry { } // find the source index - auto entry = source_map.find(remap_source); + auto entry = source_map.find(Identifier(remap_source)); if (entry == source_map.end()) { throw BinderException("Source value %s not found", remap_source); } auto target_entry = target_map.find(remap_target); if (target_entry == target_map.end()) { - throw BinderException("Target value %s not found", remap_target); + throw BinderException("Target value %s not found", remap_target.GetIdentifierName()); } auto &source_type = entry->second.type; @@ -344,7 +343,7 @@ struct RemapEntry { struct_val.ToString(), entry->second.type.ToString(), target_entry->second.type.ToString()); } - remap.child_remaps = make_uniq>(); + remap.child_remaps = make_uniq>(); auto &remap_types = StructType::GetChildTypes(struct_val.type()); auto &remap_values = StructValue::GetChildren(struct_val); for (idx_t child_idx = 0; child_idx < remap_types.size(); child_idx++) { @@ -357,9 +356,8 @@ struct RemapEntry { } static void HandleDefault(idx_t default_idx, const string &default_target, const LogicalType &default_type, - case_insensitive_map_t &target_map, - case_insensitive_map_t &result) { - auto entry = target_map.find(default_target); + identifier_map_t &target_map, identifier_map_t &result) { + auto entry = target_map.find(Identifier(default_target)); if (entry == target_map.end()) { throw BinderException("Default value %s not found for remap", default_target); } @@ -374,11 +372,11 @@ struct RemapEntry { target_type.ToString()); } // add to the map at this level only if it does not yet exist - auto result_entry = result.find(default_target); + auto result_entry = result.find(Identifier(default_target)); if (result_entry == result.end()) { 
result.emplace(default_target, std::move(remap)); - result_entry = result.find(default_target); - result_entry->second.child_remaps = make_uniq>(); + result_entry = result.find(Identifier(default_target)); + result_entry->second.child_remaps = make_uniq>(); } else { // the entry exists - add the default index result_entry->second.default_index = default_idx; @@ -389,8 +387,8 @@ struct RemapEntry { if (!result_entry->second.child_remaps || !entry->second.child_map) { throw BinderException("No child remaps found"); } - HandleDefault(child_idx, child_default.first, child_default.second, *entry->second.child_map, - *result_entry->second.child_remaps); + HandleDefault(child_idx, child_default.first.GetIdentifierName(), child_default.second, + *entry->second.child_map, *result_entry->second.child_remaps); } return; } @@ -406,7 +404,7 @@ struct RemapEntry { } static vector ConstructMapFromChildren(const child_list_t &target_children, - const case_insensitive_map_t &remap_map) { + const identifier_map_t &remap_map) { vector result; for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { auto &target_name = target_children[target_idx].first; @@ -428,7 +426,7 @@ struct RemapEntry { } static vector ConstructMap(const LogicalType &type, - const case_insensitive_map_t &remap_map) { + const identifier_map_t &remap_map) { D_ASSERT(IsRemappable(type)); switch (type.id()) { case LogicalTypeId::STRUCT: { @@ -455,7 +453,7 @@ struct RemapEntry { } static child_list_t RemapCastChildren(const child_list_t &source_children, - const case_insensitive_map_t &remap_map, + const identifier_map_t &remap_map, const unordered_map &source_name_map) { child_list_t new_source_children; for (idx_t source_idx = 0; source_idx < source_children.size(); source_idx++) { @@ -463,7 +461,7 @@ struct RemapEntry { auto &child_type = source_children[source_idx].second; auto entry = source_name_map.find(source_idx); if (entry != source_name_map.end()) { - auto remap_entry = 
remap_map.find(entry->second); + auto remap_entry = remap_map.find(Identifier(entry->second)); D_ASSERT(remap_entry != remap_map.end()); // this entry is remapped - fetch the target type if (IsRemappable(child_type) && remap_entry->second.child_remaps) { @@ -481,7 +479,7 @@ struct RemapEntry { return new_source_children; } - static LogicalType RemapCast(const LogicalType &type, const case_insensitive_map_t &remap_map) { + static LogicalType RemapCast(const LogicalType &type, const identifier_map_t &remap_map) { unordered_map source_name_map; for (auto &entry : remap_map) { if (entry.second.index.IsValid()) { @@ -569,7 +567,7 @@ unique_ptr RemapStructBind(BindScalarFunctionInput &input) { Value remap_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); // (recursively) generate the remap entries - case_insensitive_map_t remap_map; + identifier_map_t remap_map; if (!remap_val.IsNull()) { auto &remap_types = StructType::GetChildTypes(arguments[2]->GetReturnType()); auto &remap_values = StructValue::GetChildren(remap_val); @@ -589,7 +587,8 @@ unique_ptr RemapStructBind(BindScalarFunctionInput &input) { for (idx_t default_idx = 0; default_idx < default_types.size(); default_idx++) { auto &default_target = default_types[default_idx].first; auto &default_type = default_types[default_idx].second; - RemapEntry::HandleDefault(default_idx, default_target, default_type, target_map, remap_map); + RemapEntry::HandleDefault(default_idx, default_target.GetIdentifierName(), default_type, target_map, + remap_map); } } diff --git a/src/duckdb/src/function/scalar/struct/struct_concat.cpp b/src/duckdb/src/function/scalar/struct/struct_concat.cpp index 9c640f87d..ae81c6efd 100644 --- a/src/duckdb/src/function/scalar/struct/struct_concat.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_concat.cpp @@ -12,7 +12,7 @@ namespace duckdb { static void StructConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &result_cols = 
StructVector::GetEntries(result); idx_t offset = 0; - for (auto &arg : args.data) { + for (const auto &arg : args.data) { const auto &child_cols = StructVector::GetEntries(arg); for (auto &child_col : child_cols) { result_cols[offset++].Reference(child_col); @@ -30,7 +30,7 @@ static unique_ptr StructConcatBind(BindScalarFunctionInput &input) } child_list_t combined_children; - case_insensitive_set_t name_set; + identifier_set_t name_set; bool has_unnamed = false; @@ -50,13 +50,13 @@ static unique_ptr StructConcatBind(BindScalarFunctionInput &input) if (!child.first.empty()) { auto it = name_set.find(child.first); if (it != name_set.end()) { - if (*it == child.first) { + if (it->GetIdentifierName() == child.first.GetIdentifierName()) { throw InvalidInputException("struct_concat: Arguments contain duplicate STRUCT entry \"%s\"", - child.first); + child.first.GetIdentifierName()); } throw InvalidInputException( "struct_concat: Arguments contain case-insensitive duplicate STRUCT entry \"%s\" and \"%s\"", - child.first, *it); + child.first.GetIdentifierName(), it->GetIdentifierName()); } name_set.insert(child.first); } else { @@ -78,7 +78,7 @@ static unique_ptr StructConcatStats(ClientContext &context, Func const auto &expr = input.expr; auto &arg_stats = input.child_stats; - auto &arg_exprs = input.expr.children; + auto &arg_exprs = input.expr.GetChildren(); auto struct_stats = StructStats::CreateUnknown(expr.GetReturnType()); idx_t struct_index = 0; diff --git a/src/duckdb/src/function/scalar/struct/struct_contains.cpp b/src/duckdb/src/function/scalar/struct/struct_contains.cpp index 18ecf09a2..ad2271536 100644 --- a/src/duckdb/src/function/scalar/struct/struct_contains.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_contains.cpp @@ -10,8 +10,8 @@ namespace duckdb { template -static void TemplatedStructSearch(Vector &input_vector, vector &members, Vector &target, const idx_t count, - Vector &result) { +static void TemplatedStructSearch(const Vector 
&input_vector, const vector &members, const Vector &target, + const idx_t count, Vector &result) { // If the return type is not a bool, return the position const auto return_pos = std::is_same::value; @@ -27,7 +27,7 @@ static void TemplatedStructSearch(Vector &input_vector, vector &members, vector member_data_ptrs; vector member_vectors; idx_t total_matches = 0; - for (auto &member : members) { + for (const auto &member : members) { if (member.GetType().InternalType() == target_type.InternalType()) { UnifiedVectorFormat member_format; member.ToUnifiedFormat(member_format); @@ -111,8 +111,8 @@ static void TemplatedStructSearch(Vector &input_vector, vector &members, } template -static void StructNestedOp(Vector &input_vector, vector &members, Vector &target, const idx_t count, - Vector &result) { +static void StructNestedOp(const Vector &input_vector, const vector &members, const Vector &target, + const idx_t count, Vector &result) { const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); // Set up sort keys for nested types. 
@@ -120,21 +120,21 @@ static void StructNestedOp(Vector &input_vector, vector &members, Vector vector member_sort_key_vectors; for (idx_t i = 0; i < members_size; i++) { Vector member_sort_key_vec(LogicalType::BLOB, count); - CreateSortKeyHelpers::CreateSortKeyWithValidity(members[i], member_sort_key_vec, order_modifiers, count); + CreateSortKeyHelpers::CreateSortKeyWithValidity(members[i], member_sort_key_vec, order_modifiers); member_sort_key_vectors.push_back(std::move(member_sort_key_vec)); } Vector target_sort_key_vec(LogicalType::BLOB, count); - CreateSortKeyHelpers::CreateSortKeyWithValidity(target, target_sort_key_vec, order_modifiers, count); + CreateSortKeyHelpers::CreateSortKeyWithValidity(target, target_sort_key_vec, order_modifiers); TemplatedStructSearch(input_vector, member_sort_key_vectors, target_sort_key_vec, count, result); } template -static void StructSearchOp(Vector &input_vector, vector &members, Vector &target, const idx_t count, - Vector &result) { +static void StructSearchOp(const Vector &input_vector, const vector &members, const Vector &target, + const idx_t count, Vector &result) { const auto &target_type = target.GetType().InternalType(); switch (target_type) { case PhysicalType::BOOL: @@ -184,9 +184,9 @@ static void StructSearchFunction(DataChunk &args, ExpressionState &state, Vector } const auto count = args.size(); - auto &input_vector = args.data[0]; - auto &members = StructVector::GetEntries(input_vector); - auto &target = args.data[1]; + const auto &input_vector = args.data[0]; + const auto &members = StructVector::GetEntries(input_vector); + const auto &target = args.data[1]; StructSearchOp(input_vector, members, target, count, result); } diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp index eb6731199..4c1d47beb 100644 --- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp @@ 
-13,13 +13,13 @@ namespace duckdb { static void StructExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); // this should be guaranteed by the binder - auto &vec = args.data[0]; + const auto &vec = args.data[0]; vec.Verify(); - auto &children = StructVector::GetEntries(vec); + const auto &children = StructVector::GetEntries(vec); D_ASSERT(info.index < children.size()); auto &struct_child = children[info.index]; result.Reference(struct_child); @@ -35,7 +35,7 @@ static unique_ptr StructExtractBind(BindScalarFunctionInput &input if (child_type.id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } - D_ASSERT(LogicalTypeId::STRUCT == child_type.id() || child_type.IsAggregateStateStructType()); + D_ASSERT(LogicalTypeId::STRUCT == child_type.id()); auto &struct_children = StructType::GetChildTypes(child_type); if (struct_children.empty()) { throw InternalException("Can't extract something from an empty struct"); @@ -60,7 +60,7 @@ static unique_ptr StructExtractBind(BindScalarFunctionInput &input if (key_val.IsNull() || key_str.empty()) { throw BinderException("Key name for struct_extract needs to be neither NULL nor empty"); } - string key = StringUtil::Lower(key_str); + auto key = Identifier(key_str); LogicalType return_type; idx_t key_index = 0; @@ -68,7 +68,7 @@ static unique_ptr StructExtractBind(BindScalarFunctionInput &input for (size_t i = 0; i < struct_children.size(); i++) { auto &child = struct_children[i]; - if (StringUtil::Lower(child.first) == key) { + if (child.first == key) { found_key = true; key_index = i; return_type = child.second; @@ -80,11 +80,11 @@ static unique_ptr StructExtractBind(BindScalarFunctionInput &input vector candidates; candidates.reserve(struct_children.size()); for (auto &struct_child : struct_children) { - candidates.push_back(struct_child.first); + 
candidates.emplace_back(struct_child.first); } auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); + throw BinderException("Could not find key \"%s\" in struct\n%s", key.GetIdentifierName(), message); } bound_function.SetReturnType(std::move(return_type)); @@ -162,12 +162,12 @@ ScalarFunction GetKeyExtractFunction() { ScalarFunction GetIndexExtractFunction() { return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::BIGINT}, LogicalType::ANY, - StructExtractFunction, StructExtractBindIndex); + StructExtractFunction, StructExtractBindIndex, PropagateStructExtractStats); } ScalarFunction GetExtractAtFunction() { return ScalarFunction("struct_extract_at", {LogicalTypeId::STRUCT, LogicalType::BIGINT}, LogicalType::ANY, - StructExtractFunction, StructExtractAtBind); + StructExtractFunction, StructExtractAtBind, PropagateStructExtractStats); } ScalarFunctionSet StructExtractFun::GetFunctions() { diff --git a/src/duckdb/src/function/scalar/struct/struct_pack.cpp b/src/duckdb/src/function/scalar/struct/struct_pack.cpp index fa6cb03fd..01c3c61b3 100644 --- a/src/duckdb/src/function/scalar/struct/struct_pack.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_pack.cpp @@ -14,7 +14,7 @@ namespace duckdb { static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector &result) { #ifdef DEBUG auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); // this should never happen if the binder below is sane D_ASSERT(args.ColumnCount() == StructType::GetChildTypes(info.stype).size()); #endif @@ -42,7 +42,7 @@ template static unique_ptr StructPackBind(BindScalarFunctionInput &input) { auto &bound_function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); - case_insensitive_set_t 
name_collision_set; + identifier_set_t name_collision_set; // collect names and deconflict, construct return type if (arguments.empty()) { @@ -56,13 +56,13 @@ static unique_ptr StructPackBind(BindScalarFunctionInput &input) { if (child->GetAlias().empty()) { throw BinderException("Need named argument for struct pack, e.g. STRUCT_PACK(a := b)"); } - alias = child->GetAlias(); - if (name_collision_set.find(alias) != name_collision_set.end()) { + alias = child->GetAlias().GetIdentifierName(); + if (name_collision_set.find(Identifier(alias)) != name_collision_set.end()) { throw BinderException("Duplicate struct entry name \"%s\"", alias); } - name_collision_set.insert(alias); + name_collision_set.insert(Identifier(alias)); } - struct_children.push_back(make_pair(alias, arguments[i]->GetReturnType())); + struct_children.emplace_back(make_pair(alias, arguments[i]->GetReturnType())); } // this is more for completeness reasons @@ -86,6 +86,10 @@ static ScalarFunction GetStructPackFunction() { StructPackBind, StructPackStats); fun.SetVarArgs(LogicalType::ANY); fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + // struct_pack/row derive their (struct field) names from argument aliases, so the binder must capture argument + // expression aliases as named-argument names. This also preserves the legacy behavior of allowing positional + // arguments after named ones (the positional arguments simply take their expression's name as the field name). 
+ fun.SetCaptureArgumentAliases(true); fun.SetSerializeCallback(VariableReturnBindData::Serialize); fun.SetDeserializeCallback(VariableReturnBindData::Deserialize); return fun; diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index 44b97ae49..c83510905 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -1,10 +1,10 @@ #include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/vector/map_vector.hpp" -#include "duckdb/common/vector/string_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/types/value.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/function/scalar/generic_common.hpp" #include "duckdb/function/scalar/system_functions.hpp" @@ -22,24 +22,34 @@ namespace { // aggregate state export struct ExportAggregateBindData : public FunctionData { BoundAggregateFunction aggr; + unique_ptr bind_data; idx_t state_size; - explicit ExportAggregateBindData(BoundAggregateFunction aggr_p, idx_t state_size_p) - : aggr(std::move(aggr_p)), state_size(state_size_p) { + explicit ExportAggregateBindData(BoundAggregateFunction aggr_p, unique_ptr bind_data_p, + idx_t state_size_p) + : aggr(std::move(aggr_p)), bind_data(std::move(bind_data_p)), state_size(state_size_p) { } unique_ptr Copy() const override { - return make_uniq(aggr, state_size); + return make_uniq(aggr, bind_data ? 
bind_data->Copy() : nullptr, state_size); } bool Equals(const FunctionData &other_p) const override { auto &other = other_p.Cast(); + if (bind_data.get() != other.bind_data.get()) { + if (!bind_data || !other.bind_data) { + return false; + } + if (!bind_data->Equals(*other.bind_data)) { + return false; + } + } return aggr == other.aggr && state_size == other.state_size; } static ExportAggregateBindData &GetFrom(ExpressionState &state) { auto &func_expr = state.expr.Cast(); - return func_expr.bind_info->Cast(); + return func_expr.BindInfo()->Cast(); } }; @@ -92,262 +102,144 @@ void TemplateDispatch(PhysicalType type, ARGS &&... args) { OP::template Operation(std::forward(args)...); break; default: - throw InternalException("Unsupported physical type: %s for aggregate state", TypeIdToString(type)); + throw NotImplementedException("Unsupported physical type for default aggregate state export: %s", + TypeIdToString(type)); } } -struct AggregateStateLayout { - AggregateStateLayout(const LogicalType &type, idx_t state_size) - : state_size(state_size), aligned_state_size(AlignValue(state_size)) { - owned_type = type; - is_struct = type.InternalType() == PhysicalType::STRUCT; - if (is_struct) { - child_types = &StructType::GetChildTypes(owned_type); - } - } - - // Reconstruct a packed binary state from the vector representation - // Works only on a legacy state format where the entire state is stored as a blob - void Load(Vector &vec, const UnifiedVectorFormat &state_data, idx_t row, data_ptr_t dest) { - D_ASSERT(!is_struct); - auto idx = state_data.sel->get_index(row); - auto &blob = UnifiedVectorFormat::GetData(state_data)[idx]; - if (blob.GetSize() != state_size) { - throw IOException("Aggregate state size mismatch, expect %llu, got %llu", state_size, blob.GetSize()); - } - memcpy(dest, blob.GetData(), state_size); - } - - // Serializes a packed binary state back into a `Vector` format - // Works only on a legacy state format where the entire state is stored as a blob - 
void Store(Vector &result, idx_t row, data_ptr_t src) const { - D_ASSERT(!is_struct); - auto result_ptr = FlatVector::GetDataMutable(result); - result_ptr[row] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(src), state_size); - } - - bool is_struct; - idx_t state_size; - idx_t aligned_state_size; - LogicalType owned_type; - const child_list_t *child_types = nullptr; -}; +static AggregateStateLayout GetLayout(const BoundAggregateFunction &aggr) { + return aggr.GetStateTypeCallback()(aggr); +} -/* - * Load a specific field from the struct aggregate state into the packed binary representation - * By iterating over rows in the inner loop for a specific type T, the compiler is given a tight loop with a - * predictable memory access pattern. Since the field in a struct is a contiguous array of type `T`, the compiler - * can easily vectorize the read operations using SIMD instructions - */ -struct LoadFieldOp { +// Load rows from input_vec into the packed binary state buffer. Skips null rows. 
+struct LoadOp { template - static void Operation(idx_t root_stride, Vector &struct_vec, idx_t field_idx, const UnifiedVectorFormat &state_data, - idx_t count, data_ptr_t base_ptr, idx_t field_offset) { - auto &child = StructVector::GetEntries(struct_vec)[field_idx]; - auto child_data = FlatVector::GetData(child); - - for (idx_t row = 0; row < count; row++) { - auto row_idx = state_data.sel->get_index(row); - if (!state_data.validity.RowIsValid(row_idx)) { - continue; + static void Operation(idx_t root_stride, const Vector &input_vec, idx_t count, data_ptr_t base_ptr, + idx_t field_offset) { + auto values = input_vec.Values(); + for (idx_t i = 0; i < count; i++) { + const auto entry = values[i]; + if (entry.IsValid()) { + Store(entry.GetValue(), base_ptr + i * root_stride + field_offset); } - - auto dest = base_ptr + row * root_stride + field_offset; - *reinterpret_cast(dest) = child_data[row]; } } }; -struct StoreFieldOp { +// Store rows from the packed binary state buffer into a result vector. +struct StoreOp { template - static void Operation(Vector &struct_vec, idx_t field_idx, idx_t count, const data_ptr_t *sources, - idx_t field_offset) { - auto &child = StructVector::GetEntries(struct_vec)[field_idx]; - auto child_data = FlatVector::Writer(child, count); - - for (idx_t row = 0; row < count; row++) { - auto src = sources[row] + field_offset; - child_data.WriteValue(*reinterpret_cast(src)); + static void Operation(Vector &result, idx_t count, const data_ptr_t *sources, idx_t field_offset) { + auto dst = FlatVector::Writer(result, count); + for (idx_t i = 0; i < count; i++) { + dst.WriteValue(Load(sources[i] + field_offset)); } } }; -struct CopyFromInputFieldOp { +// Serialize an OptionalStateType state buffer (flat: T value at +0, bool is_set at +sizeof(T)) +// into a result vector, setting NULL when is_set=false. 
+struct StorePrimitiveOptionalOp { template - static void Operation(Vector &input_vec, Vector &result_vec, idx_t field_idx, const SelectionVector &sel, - idx_t count, const UnifiedVectorFormat &input_data) { - auto &input_child = StructVector::GetEntries(input_vec)[field_idx]; - auto input_child_data = FlatVector::GetData(input_child); - - auto &result_child = StructVector::GetEntries(result_vec)[field_idx]; - auto result_child_data = FlatVector::GetDataMutable(result_child); - + static void Operation(Vector &result, idx_t count, const data_ptr_t *addresses, idx_t base_offset) { + auto dst = FlatVector::Writer(result, count); for (idx_t i = 0; i < count; i++) { - idx_t row = sel.get_index(i); - auto src_idx = input_data.sel->get_index(row); - result_child_data[row] = input_child_data[src_idx]; + const auto value = Load(addresses[i] + base_offset); + const bool is_set = Load(addresses[i] + base_offset + sizeof(T)); + if (is_set) { + dst.WriteValue(value); + } else { + dst.WriteNull(); + } } } }; -struct LoadFieldForSelectedRowsOp { +// Deserialize a nullable input vector into OptionalStateType state buffer slots +// (flat: T value at +0, bool is_set at +sizeof(T)). 
+struct LoadPrimitiveOptionalOp { template - static void Operation(const AggregateStateLayout &layout, Vector &struct_vec, idx_t field_idx, - const SelectionVector &sel, idx_t count, const UnifiedVectorFormat &state_data, - data_ptr_t base_ptr, idx_t field_offset) { - auto &child = StructVector::GetEntries(struct_vec)[field_idx]; - auto child_data = FlatVector::GetData(child); - + static void Operation(const Vector &input, idx_t count, data_ptr_t base_ptr, idx_t aligned_state_size) { + auto values = input.Values(); for (idx_t i = 0; i < count; i++) { - idx_t row = sel.get_index(i); - auto src_idx = state_data.sel->get_index(row); - auto dest = base_ptr + i * layout.aligned_state_size + field_offset; - *reinterpret_cast(dest) = child_data[src_idx]; + const auto entry = values[i]; + auto *value_ptr = reinterpret_cast(base_ptr + i * aligned_state_size); + auto *is_set_ptr = reinterpret_cast(base_ptr + i * aligned_state_size + sizeof(T)); + if (entry.IsValid()) { + *value_ptr = entry.GetValue(); + *is_set_ptr = true; + } else { + *is_set_ptr = false; + } } } }; -struct StoreFieldForSelectedRowsOp { - template - static void Operation(const AggregateStateLayout &layout, Vector &result, idx_t field_idx, - const SelectionVector &sel, idx_t count, data_ptr_t base_ptr, idx_t field_offset) { - auto &child = StructVector::GetEntries(result)[field_idx]; - auto child_data = FlatVector::ScatterWriter(child); - +void DeserializeState(const AggregateStateLayout &layout, const Vector &input_vec, idx_t count, + data_ptr_t dest_buffer) { + if (layout.field.is_optional && layout.type.id() == LogicalTypeId::STRUCT) { + // Optional struct: deserialize children, then derive is_set from the struct's row validity. 
+ const idx_t struct_size = AggregateStateField::GetPhysicalSize(layout.type); + const auto &child_types = StructType::GetChildTypes(layout.type); + const auto &struct_entries = StructVector::GetEntries(input_vec); + for (idx_t field_idx = 0; field_idx < layout.field.children.size(); field_idx++) { + const auto &child = layout.field.children[field_idx]; + AggregateStateLayout child_layout(child_types[field_idx].second, layout.total_state_size, + child.is_optional); + DeserializeState(child_layout, struct_entries[field_idx], count, dest_buffer + child.field_offset); + } + auto &validity = FlatVector::Validity(input_vec); for (idx_t i = 0; i < count; i++) { - idx_t row = sel.get_index(i); - auto src = base_ptr + i * layout.aligned_state_size + field_offset; - child_data[row] = *reinterpret_cast(src); + *reinterpret_cast(dest_buffer + i * layout.total_state_size + struct_size) = validity.RowIsValid(i); } - } -}; - -idx_t GetRecursivePhysicalSize(const LogicalType &type) { - if (type.id() != LogicalTypeId::STRUCT) { - return GetTypeIdSize(type.InternalType()); - } - idx_t size = 0; - for (const auto &child : StructType::GetChildTypes(type)) { - size += GetRecursivePhysicalSize(child.second); - } - return size; -} - -// Deserialize struct fields from a flattened struct vector into a state buffer. -// Uses LoadFieldOp for tight SIMD-friendly loops. 
-// root_stride is the aligned size of the top-level state, used to stride between rows in the buffer (root-state's -// size must be taken into account) -void DeserializeStructFields(const AggregateStateLayout &layout, idx_t root_stride, Vector &struct_vec, - const UnifiedVectorFormat &state_data, idx_t count, data_ptr_t dest_buffer) { - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetRecursivePhysicalSize(field_type); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, alignment); - - if (field_type.id() == LogicalTypeId::STRUCT) { - auto child_layout = AggregateStateLayout(field_type, field_size); - auto &struct_entries = StructVector::GetEntries(struct_vec); - DeserializeStructFields(child_layout, root_stride, struct_entries[field_idx], state_data, count, - dest_buffer + offset_in_state); - } else { - TemplateDispatch(physical, root_stride, struct_vec, field_idx, state_data, count, dest_buffer, - offset_in_state); + } else if (layout.type.id() == LogicalTypeId::STRUCT) { + const auto &child_types = StructType::GetChildTypes(layout.type); + const auto &struct_entries = StructVector::GetEntries(input_vec); + for (idx_t field_idx = 0; field_idx < layout.field.children.size(); field_idx++) { + const auto &child = layout.field.children[field_idx]; + AggregateStateLayout child_layout(child_types[field_idx].second, layout.total_state_size, + child.is_optional); + DeserializeState(child_layout, struct_entries[field_idx], count, dest_buffer + child.field_offset); } - - offset_in_state += field_size; + } else if (layout.field.is_optional) { + TemplateDispatch(layout.type.InternalType(), input_vec, count, dest_buffer, + layout.total_state_size); + } else { + TemplateDispatch(layout.type.InternalType(), layout.total_state_size, input_vec, 
count, dest_buffer, 0); } } -// Serialize packed binary states into a struct result vector. -// Uses StoreFieldOp for tight SIMD-friendly loops -void SerializeStructFields(const AggregateStateLayout &layout, Vector &result, idx_t count, - const data_ptr_t *addresses_ptrs) { - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetRecursivePhysicalSize(field_type); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, alignment); - - if (field_type.id() == LogicalTypeId::STRUCT) { - auto child_layout = AggregateStateLayout(field_type, field_size); - - // we need to write to the buffers with the current offset the child is pointing to in the state - Vector child_addresses(LogicalType::POINTER); - auto child_writer = FlatVector::Writer(child_addresses, count); - for (idx_t row = 0; row < count; row++) { - child_writer.WriteValue(addresses_ptrs[row] + offset_in_state); +void SerializeState(const AggregateStateLayout &layout, Vector &result, idx_t count, const data_ptr_t *addresses, + idx_t base_offset = 0) { + if (layout.field.is_optional && layout.type.id() == LogicalTypeId::STRUCT) { + // Optional struct: mark null rows in the struct's validity, then serialize children for all rows. 
+ const idx_t struct_size = AggregateStateField::GetPhysicalSize(layout.type); + const auto &child_types = StructType::GetChildTypes(layout.type); + auto &struct_entries = StructVector::GetEntries(result); + for (idx_t i = 0; i < count; i++) { + if (!Load(addresses[i] + base_offset + struct_size)) { + FlatVector::SetNull(result, i, true); } - - auto child_ptrs = FlatVector::GetData(child_addresses); - auto &struct_entries = StructVector::GetEntries(result); - SerializeStructFields(child_layout, struct_entries[field_idx], count, child_ptrs); - } else { - TemplateDispatch(physical, result, field_idx, count, addresses_ptrs, offset_in_state); } - - offset_in_state += field_size; - } -} - -#ifdef DEBUG -// Internal verification: export all states to struct, import back, finalize, and compare to main path result -// using vector operations. Catches serialization/deserialization bugs for struct-based aggregate state. -void VerifyStructStateRoundtrip(const AggregateStateLayout &layout, const Vector &state_vec, idx_t count, - const UnifiedVectorFormat &state_data, const data_ptr_t *state_vec_ptr, - const Vector &result, const ExportAggregateBindData &bind_data, - AggregateInputData &aggr_input_data) { - // Build selection of valid state rows - SelectionVector valid_sel(count); - idx_t valid_count = 0; - for (idx_t i = 0; i < count; i++) { - if (state_data.validity.RowIsValid(state_data.sel->get_index(i))) { - valid_sel.set_index(valid_count++, i); + for (idx_t field_idx = 0; field_idx < layout.field.children.size(); field_idx++) { + const auto &child = layout.field.children[field_idx]; + AggregateStateLayout child_layout(child_types[field_idx].second, 0, child.is_optional); + SerializeState(child_layout, struct_entries[field_idx], count, addresses, base_offset + child.field_offset); + } + } else if (layout.type.id() == LogicalTypeId::STRUCT) { + const auto &child_types = StructType::GetChildTypes(layout.type); + auto &struct_entries = StructVector::GetEntries(result); + for 
(idx_t field_idx = 0; field_idx < layout.field.children.size(); field_idx++) { + const auto &child = layout.field.children[field_idx]; + AggregateStateLayout child_layout(child_types[field_idx].second, 0, child.is_optional); + SerializeState(child_layout, struct_entries[field_idx], count, addresses, base_offset + child.field_offset); } + } else if (layout.field.is_optional) { + TemplateDispatch(layout.type.InternalType(), result, count, addresses, base_offset); + } else { + TemplateDispatch(layout.type.InternalType(), result, count, addresses, base_offset); } - if (valid_count == 0) { - return; - } - - // SerializeStructFields: all valid states -> struct vector - Vector struct_vec(state_vec.GetType(), valid_count); - unsafe_unique_array addresses_ptrs(make_unsafe_uniq_array(valid_count)); - for (idx_t i = 0; i < valid_count; i++) { - addresses_ptrs[i] = state_vec_ptr[valid_sel.get_index(i)]; - } - SerializeStructFields(layout, struct_vec, valid_count, addresses_ptrs.get()); - - // DeserializeStructFields: struct vector -> packed buffer - unsafe_unique_array temp_state_buf(make_unsafe_uniq_array(valid_count * layout.aligned_state_size)); - UnifiedVectorFormat struct_format; - struct_vec.ToUnifiedFormat(struct_format); - DeserializeStructFields(layout, layout.aligned_state_size, struct_vec, struct_format, valid_count, - temp_state_buf.get()); - - // AggregateStateFinalize: packed buffer -> result values - Vector addresses_vec(LogicalType::POINTER); - auto addresses_finalize = FlatVector::Writer(addresses_vec, valid_count); - for (idx_t i = 0; i < valid_count; i++) { - addresses_finalize.WriteValue(temp_state_buf.get() + i * layout.aligned_state_size); - } - Vector result_roundtrip(result.GetType(), valid_count); - bind_data.aggr.GetStateFinalizeCallback()(addresses_vec, aggr_input_data, result_roundtrip, valid_count, 0); - - Vector result_slice(result.GetType(), valid_count); - result_slice.Slice(result, valid_sel, valid_count); - SelectionVector 
true_sel(valid_count); - SelectionVector false_sel(valid_count); - idx_t match_count = - VectorOperations::NotDistinctFrom(result_slice, result_roundtrip, nullptr, valid_count, &true_sel, &false_sel); - D_ASSERT(match_count == valid_count && - "Struct state roundtrip failed: finalize(serialize->deserialize(state)) != finalize(state). " - "Check SerializeStructFields/DeserializeStructFields for this aggregate."); } -#endif struct CombineState : public FunctionLocalState { idx_t state_size; @@ -398,58 +290,26 @@ void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector & D_ASSERT(bind_data.state_size == bind_data.aggr.GetStateSizeCallback()(bind_data.aggr)); D_ASSERT(input.data.size() == 1); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::LEGACY_AGGREGATE_STATE || - input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); - - AggregateStateLayout layout(input.data[0].GetType(), bind_data.state_size); - - auto state_vec_writer = FlatVector::Writer(local_state.addresses, input.size()); - - input.data[0].Flatten(); - - UnifiedVectorFormat state_data; - input.data[0].ToUnifiedFormat(state_data); - - if (layout.is_struct) { - for (idx_t i = 0; i < input.size(); i++) { - state_vec_writer.WriteValue(local_state.state_buffer.get() + i * layout.aligned_state_size); - } - DeserializeStructFields(layout, layout.aligned_state_size, input.data[0], state_data, input.size(), - local_state.state_buffer.get()); - } else { - for (idx_t i = 0; i < input.size(); i++) { - auto target_ptr = local_state.state_buffer.get() + layout.aligned_state_size * i; - auto state_idx = state_data.sel->get_index(i); + auto layout = GetLayout(bind_data.aggr); - if (state_data.validity.RowIsValid(state_idx)) { - layout.Load(input.data[0], state_data, i, target_ptr); - } else { - // create a dummy state because finalize does not understand NULLs in its input - // we put the NULL back in explicitly below - bind_data.aggr.GetStateInitCallback()(bind_data.aggr, 
data_ptr_cast(target_ptr)); - } - state_vec_writer.WriteValue(data_ptr_cast(target_ptr)); - } + auto count = input.size(); + auto state_vec_writer = FlatVector::Writer(local_state.addresses, count); + for (idx_t i = 0; i < count; i++) { + state_vec_writer.WriteValue(local_state.state_buffer.get() + i * layout.total_state_size); } - AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.GetStateFinalizeCallback()(local_state.addresses, aggr_input_data, result, input.size(), 0); + DeserializeState(layout, input.data[0], count, local_state.state_buffer.get()); - for (idx_t i = 0; i < input.size(); i++) { - auto state_idx = state_data.sel->get_index(i); - if (!state_data.validity.RowIsValid(state_idx)) { + AggregateInputData aggr_input_data(bind_data.aggr, bind_data.bind_data.get(), local_state.allocator); + bind_data.aggr.GetStateFinalizeCallback()(local_state.addresses, aggr_input_data, result, count, 0); + + auto validity = input.data[0].Validity(); + for (idx_t i = 0; i < count; i++) { + if (!validity.IsValid(i)) { FlatVector::SetNull(result, i, true); } } - -#ifdef DEBUG - if (layout.is_struct) { - auto state_vec_ptr = FlatVector::GetData(local_state.addresses); - VerifyStructStateRoundtrip(layout, input.data[0], input.size(), state_data, state_vec_ptr, result, bind_data, - aggr_input_data); - } -#endif } void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) { @@ -458,248 +318,130 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r local_state.allocator.Reset(); D_ASSERT(bind_data.state_size == bind_data.aggr.GetStateSizeCallback()(bind_data.aggr)); - D_ASSERT(input.data.size() == 2); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::LEGACY_AGGREGATE_STATE || - input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); D_ASSERT(input.data[0].GetType() == result.GetType()); - AggregateStateLayout layout(input.data[0].GetType(), bind_data.state_size); - if 
(input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) { throw IOException("Aggregate state combine type mismatch, expect %s, got %s", input.data[0].GetType().ToString(), input.data[1].GetType().ToString()); } - if (layout.is_struct) { - input.data[0].Flatten(); - input.data[1].Flatten(); - result.Flatten(); - } - - UnifiedVectorFormat state0_data, state1_data; - input.data[0].ToUnifiedFormat(state0_data); - input.data[1].ToUnifiedFormat(state1_data); - - // Partition rows by NULL using SelectionVector - SelectionVector both_null_sel(STANDARD_VECTOR_SIZE); - // input1 is null - SelectionVector copy_from_0_sel(STANDARD_VECTOR_SIZE); - // input0 is null - SelectionVector copy_from_1_sel(STANDARD_VECTOR_SIZE); - SelectionVector both_valid_sel(STANDARD_VECTOR_SIZE); - - idx_t both_null_count = 0, copy_from_0_count = 0, copy_from_1_count = 0, both_valid_count = 0; - - for (idx_t i = 0; i < input.size(); i++) { - const bool is_null0 = !state0_data.validity.RowIsValid(state0_data.sel->get_index(i)); - const bool is_null1 = !state1_data.validity.RowIsValid(state1_data.sel->get_index(i)); + auto layout = GetLayout(bind_data.aggr); + auto count = input.size(); - if (is_null0 && is_null1) { - both_null_sel.set_index(both_null_count++, i); - } else if (is_null0) { - copy_from_1_sel.set_index(copy_from_1_count++, i); - } else if (is_null1) { - copy_from_0_sel.set_index(copy_from_0_count++, i); - } else { - both_valid_sel.set_index(both_valid_count++, i); - } - } + result.Flatten(); - D_ASSERT(both_null_count + copy_from_0_count + copy_from_1_count + both_valid_count == input.size()); + auto validity0 = input.data[0].Validity(); + auto validity1 = input.data[1].Validity(); - // Handle both-null rows - for (idx_t i = 0; i < both_null_count; i++) { - FlatVector::SetNull(result, both_null_sel.get_index(i), true); - } - - // Handle one-null rows - copy non-null input directly to result - // copy_from_0: input1 is null, copy input0 - if (copy_from_0_count 
> 0) { - if (layout.is_struct) { - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetTypeIdSize(physical); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, alignment); - TemplateDispatch(physical, input.data[0], result, field_idx, copy_from_0_sel, - copy_from_0_count, state0_data); - offset_in_state += field_size; - } - } else { - for (idx_t i = 0; i < copy_from_0_count; i++) { - idx_t row = copy_from_0_sel.get_index(i); - layout.Load(input.data[0], state0_data, row, local_state.state_buffer0.get()); - layout.Store(result, row, local_state.state_buffer0.get()); - } - } - } - // copy_from_1: input0 is null, copy input1 - if (copy_from_1_count > 0) { - if (layout.is_struct) { - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetTypeIdSize(physical); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, alignment); - TemplateDispatch(physical, input.data[1], result, field_idx, copy_from_1_sel, - copy_from_1_count, state1_data); - offset_in_state += field_size; - } - } else { - for (idx_t i = 0; i < copy_from_1_count; i++) { - idx_t row = copy_from_1_sel.get_index(i); - layout.Load(input.data[1], state1_data, row, local_state.state_buffer1.get()); - layout.Store(result, row, local_state.state_buffer1.get()); - } - } + // Initialize both state buffers and build address vectors for all rows + auto state0_writer = FlatVector::Writer(local_state.addresses0, count); + auto state1_writer = FlatVector::Writer(local_state.addresses1, count); + for (idx_t i = 0; i < count; i++) { + auto ptr0 = 
local_state.state_buffer0.get() + i * layout.total_state_size; + auto ptr1 = local_state.state_buffer1.get() + i * layout.total_state_size; + bind_data.aggr.GetStateInitCallback()(bind_data.aggr, ptr0); + bind_data.aggr.GetStateInitCallback()(bind_data.aggr, ptr1); + state0_writer.WriteValue(ptr0); + state1_writer.WriteValue(ptr1); } - // Handle both-valid rows - batched load, combine, store - if (both_valid_count > 0) { - auto state0_writer = FlatVector::Writer(local_state.addresses0, both_valid_count); - auto state1_writer = FlatVector::Writer(local_state.addresses1, both_valid_count); + // Deserialize both inputs — null rows are skipped by LoadOp, keeping the initialized empty state + DeserializeState(layout, input.data[0], count, local_state.state_buffer0.get()); + DeserializeState(layout, input.data[1], count, local_state.state_buffer1.get()); - // Pack state buffer pointers in selection order (not row order) - for (idx_t i = 0; i < both_valid_count; i++) { - state0_writer.WriteValue(local_state.state_buffer0.get() + i * layout.aligned_state_size); - state1_writer.WriteValue(local_state.state_buffer1.get() + i * layout.aligned_state_size); - } + AggregateInputData aggr_input_data(bind_data.aggr, bind_data.bind_data.get(), local_state.allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); + bind_data.aggr.GetStateCombineCallback()(local_state.addresses0, local_state.addresses1, aggr_input_data, count); - auto state0_ptrs = FlatVector::GetData(local_state.addresses0); - auto state1_ptrs = FlatVector::GetData(local_state.addresses1); - - if (layout.is_struct) { - // Use tight loops to load both inputs - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetTypeIdSize(physical); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, 
alignment); - - TemplateDispatch(physical, layout, input.data[0], field_idx, both_valid_sel, - both_valid_count, state0_data, - local_state.state_buffer0.get(), offset_in_state); - TemplateDispatch(physical, layout, input.data[1], field_idx, both_valid_sel, - both_valid_count, state1_data, - local_state.state_buffer1.get(), offset_in_state); - offset_in_state += field_size; - } - } else { - // Handle blob case - load both inputs for all both_valid rows - for (idx_t i = 0; i < both_valid_count; i++) { - idx_t row = both_valid_sel.get_index(i); - layout.Load(input.data[0], state0_data, row, state0_ptrs[i]); - layout.Load(input.data[1], state1_data, row, state1_ptrs[i]); - } - } + SerializeState(layout, result, count, FlatVector::GetData(local_state.addresses1)); - // Single batched combine call - AggregateInputData aggr_input_data(nullptr, local_state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - bind_data.aggr.GetStateCombineCallback()(local_state.addresses0, local_state.addresses1, aggr_input_data, - both_valid_count); - - // Store results - if (layout.is_struct) { - idx_t offset_in_state = 0; - for (idx_t field_idx = 0; field_idx < layout.child_types->size(); field_idx++) { - auto &field_type = layout.child_types->at(field_idx).second; - auto physical = field_type.InternalType(); - idx_t field_size = GetTypeIdSize(physical); - idx_t alignment = MinValue(field_size, 8); - offset_in_state = AlignValue(offset_in_state, alignment); - - TemplateDispatch(physical, layout, result, field_idx, both_valid_sel, - both_valid_count, local_state.state_buffer1.get(), - offset_in_state); - offset_in_state += field_size; - } - } else { - // Handle blob case - store results using the legacy not tight-loop strategy - for (idx_t i = 0; i < both_valid_count; i++) { - idx_t row = both_valid_sel.get_index(i); - layout.Store(result, row, state1_ptrs[i]); - } + // Rows where both inputs were NULL produce no meaningful combined state — mark result as NULL + for (idx_t i = 0; i < 
count; i++) { + if (!validity0.IsValid(i) && !validity1.IsValid(i)) { + FlatVector::SetNull(result, i, true); } } } -unique_ptr BindAggregateStateInternal(ClientContext &context, BoundSimpleFunction &function, - vector> &arguments, - bool allow_legacy) { - auto &arg_return_type = arguments[0]->GetReturnType(); - for (auto &arg_type : function.GetArguments()) { - arg_type = arg_return_type; - } - - if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE && - (!allow_legacy || arg_return_type.id() != LogicalTypeId::LEGACY_AGGREGATE_STATE)) { - string allowed = allow_legacy ? "AGGREGATE_STATE or LEGACY_AGGREGATE_STATE" : "AGGREGATE_STATE"; - throw BinderException("Can only %s %s, not %s", function.GetName(), allowed, arg_return_type.ToString()); - } - - // following error states are only reachable when someone messes up creating the state_type - // which is impossible from SQL - - auto state_type = AggregateStateType::GetStateType(arg_return_type); - - // now we can look up the function in the catalog again and bind it - auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, DEFAULT_SCHEMA, - state_type.function_name); +// looks up the aggregate function with the given name in the catalog and binds it with the given argument types +unique_ptr BindExportedAggregate(ClientContext &context, const string &function_name, + const vector &argument_types) { + auto &func = Catalog::GetSystemCatalog(context).GetEntry( + context, Identifier::DefaultSchema(), Identifier(function_name)); if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { - throw InternalException("Could not find aggregate %s", state_type.function_name); + throw InternalException("Could not find aggregate %s", function_name); } auto &aggr_entry = func.Cast(); ErrorData error; - FunctionBinder function_binder(context); - auto best_function = - function_binder.BindFunction(aggr_entry.name, aggr_entry.functions, state_type.bound_argument_types, error); + auto best_function = 
function_binder.BindFunction(aggr_entry.name, aggr_entry.functions, argument_types, error); if (!best_function.IsValid()) { - throw InternalException("Could not re-bind exported aggregate %s: %s", state_type.function_name, - error.Message()); + throw InternalException("Could not re-bind exported aggregate %s: %s", function_name, error.Message()); } const auto &aggr = aggr_entry.functions.GetFunctionByOffset(best_function.GetIndex()); // FIXME: this is really hacky // but the aggregate state export needs a rework around how it handles more complex aggregates anyway vector> args; - args.reserve(state_type.bound_argument_types.size()); - for (auto &arg_type : state_type.bound_argument_types) { + args.reserve(argument_types.size()); + for (auto &arg_type : argument_types) { args.push_back(make_uniq(Value(arg_type))); } auto [bound_aggr, bind_info] = function_binder.ResolveFunction(aggr, args); - if (bind_info) { - throw BinderException("Aggregate function with bind info not supported yet in aggregate state export"); + if (bound_aggr.GetArguments() != argument_types) { + throw InternalException("Type mismatch for exported aggregate %s", function_name); } - if (bound_aggr.GetReturnType() != state_type.return_type || - bound_aggr.GetArguments() != state_type.bound_argument_types) { - throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name); + return make_uniq(bound_aggr, std::move(bind_info), + bound_aggr.GetStateSizeCallback()(bound_aggr)); +} + +unique_ptr BindAggregateStateInternal(ClientContext &context, BoundSimpleFunction &function, + vector> &arguments) { + auto &arg_return_type = arguments[0]->GetReturnType(); + if (!arg_return_type.IsAggregateState()) { + throw BinderException("Can only %s %s, not AGGREGATE_STATE", function.GetName(), arg_return_type.ToString()); + } + // + // now we can look up the function in the catalog again and bind it + auto ext_info = arg_return_type.GetExtensionInfo(); + auto entry = 
ext_info->properties.find("function_name"); + if (entry == ext_info->properties.end() || entry->second.type().id() != LogicalTypeId::VARCHAR || + entry->second.IsNull()) { + throw InternalException("Aggregate state object should have a property called function_name that is a string"); + } + auto &function_name = StringValue::Get(entry->second); + + entry = ext_info->properties.find("parameters"); + if (entry == ext_info->properties.end() || entry->second.type().id() != LogicalTypeId::LIST || + entry->second.IsNull()) { + throw InternalException( + "Aggregate state object should have a property called parameters that is a list of types"); + } + vector argument_types; + for (auto &val : ListValue::GetChildren(entry->second)) { + if (val.IsNull() || val.type().id() != LogicalTypeId::TYPE) { + throw InternalException( + "Aggregate state object should have a property called parameters that is a list of types"); + } + argument_types.push_back(TypeValue::GetType(val)); } - return make_uniq(bound_aggr, bound_aggr.GetStateSizeCallback()(bound_aggr)); + return BindExportedAggregate(context, function_name, argument_types); } unique_ptr BindAggregateState(BindScalarFunctionInput &input) { auto &bound_function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); auto bind_data = - BindAggregateStateInternal(input.GetClientContext(), input.GetBoundFunction(), input.GetArguments(), true); + BindAggregateStateInternal(input.GetClientContext(), input.GetBoundFunction(), input.GetArguments()); - // combine - if (arguments.size() == 2 && arguments[0]->GetReturnType() != arguments[1]->GetReturnType() && - arguments[1]->GetReturnType().id() != LogicalTypeId::BLOB && - arguments[1]->GetReturnType().id() != LogicalTypeId::STRUCT) { + // combine - both arguments must be aggregate states of the same function with the same signature + if (arguments.size() == 2 && arguments[0]->GetReturnType() != arguments[1]->GetReturnType()) { throw BinderException("Cannot COMBINE aggregate 
states from different functions, %s <> %s", arguments[0]->GetReturnType().ToString(), arguments[1]->GetReturnType().ToString()); } @@ -709,11 +451,6 @@ unique_ptr BindAggregateState(BindScalarFunctionInput &input) { } else { D_ASSERT(bound_function.GetName() == "combine"); bound_function.SetReturnType(arguments[0]->GetReturnType()); - // When the second argument is a STRUCT (e.g. from a parquet round-trip), prevent casting to AGGREGATE_STATE by - // matching the declared parameter type to the actual argument type. - if (arguments.size() == 2 && arguments[1]->GetReturnType().id() == LogicalTypeId::STRUCT) { - bound_function.GetArguments()[1] = arguments[1]->GetReturnType(); - } } return std::move(bind_data); @@ -722,7 +459,6 @@ unique_ptr BindAggregateState(BindScalarFunctionInput &input) { void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { D_ASSERT(offset == 0); - auto &bind_data = aggr_input_data.bind_data->Cast(); const data_ptr_t *addresses_ptrs; if (state.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (count != 1) { @@ -733,41 +469,10 @@ void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, addresses_ptrs = FlatVector::GetData(state); } - auto state_size = bind_data.aggregate->function.GetStateSizeCallback()(bind_data.aggregate->function); - - // Note: The underlying state type should always be a struct (we have a D_ASSERT for that in `GetStateType` - bool should_result_as_struct = bind_data.aggregate->function.HasGetStateTypeCallback(); - if (should_result_as_struct) { - AggregateStateLayout layout(bind_data.aggregate->function.GetStateType(), state_size); + auto layout = GetLayout(aggr_input_data.function); - result.Flatten(); - SerializeStructFields(layout, result, count, addresses_ptrs); - return; - } - - auto blob_ptr = FlatVector::Writer(result, count); - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto data_ptr = 
addresses_ptrs[row_idx]; - blob_ptr.WriteValue(string_t(const_char_ptr_cast(data_ptr), UnsafeNumericCast(state_size))); - } -} - -void ExportStateAggregateSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const BoundAggregateFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -unique_ptr ExportStateAggregateDeserialize(Deserializer &deserializer, BoundAggregateFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); -} - -void ExportStateScalarSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const BoundScalarFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -unique_ptr ExportStateScalarDeserialize(Deserializer &deserializer, BoundScalarFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); + result.Flatten(); + SerializeState(layout, result, count, addresses_ptrs); } unique_ptr CombineAggrBind(BindAggregateFunctionInput &input) { @@ -775,7 +480,7 @@ unique_ptr CombineAggrBind(BindAggregateFunctionInput &input) { auto &function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); - auto bind_data = BindAggregateStateInternal(context, function, arguments, false); + auto bind_data = BindAggregateStateInternal(context, function, arguments); // Copy underlying aggregate's callbacks into this function (same pattern as `ExportAggregateFunction::Bind`) function.SetStateSizeCallback(bind_data->aggr.GetStateSizeCallback()); @@ -793,22 +498,14 @@ void CombineAggrUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx auto &bind_data = aggr_input_data.bind_data->Cast(); auto &underlying_aggr = bind_data.aggr; - auto state_size = bind_data.state_size; - - AggregateStateLayout layout(inputs[0].GetType(), state_size); - UnifiedVectorFormat sdata; - states.ToUnifiedFormat(sdata); - auto state_ptrs = UnifiedVectorFormat::GetData(sdata); + auto layout = 
GetLayout(underlying_aggr); - inputs[0].Flatten(); - - UnifiedVectorFormat input_data; - inputs[0].ToUnifiedFormat(input_data); - - auto aligned_size = layout.aligned_state_size; + auto aligned_size = layout.total_state_size; unsafe_unique_array temp_state_buf = make_unsafe_uniq_array(count * aligned_size); + auto state_values = states.Values(); + // source_vec holds pointers to the binary states buffer (temp_state_buf) deserialized from the input states Vector source_vec(LogicalType::POINTER); auto source_data = FlatVector::Writer(source_vec, count); @@ -822,13 +519,14 @@ void CombineAggrUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx auto temp_ptr = temp_state_buf.get() + i * aligned_size; underlying_aggr.GetStateInitCallback()(underlying_aggr, temp_ptr); source_data.WriteValue(temp_ptr); - target_data.WriteValue(state_ptrs[sdata.sel->get_index(i)]); + target_data.WriteValue(state_values[i].GetValue()); } - DeserializeStructFields(layout, layout.aligned_state_size, inputs[0], input_data, count, temp_state_buf.get()); + DeserializeState(layout, inputs[0], count, temp_state_buf.get()); ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData combine_input(nullptr, allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); + AggregateInputData combine_input(bind_data.aggr, bind_data.bind_data.get(), allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); underlying_aggr.GetStateCombineCallback()(source_vec, target_vec, combine_input, count); } @@ -837,7 +535,6 @@ void CombineAggrFinalize(Vector &state, AggregateInputData &aggr_input_data, Vec D_ASSERT(offset == 0); auto &bind_data = aggr_input_data.bind_data->Cast(); auto &underlying_aggr = bind_data.aggr; - auto state_size = bind_data.state_size; const data_ptr_t *addresses_ptrs; if (state.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (count != 1) { @@ -848,23 +545,104 @@ void CombineAggrFinalize(Vector &state, AggregateInputData &aggr_input_data, Vec addresses_ptrs = 
FlatVector::GetData(state); } - AggregateStateLayout layout(underlying_aggr.GetStateType(), state_size); + auto layout = GetLayout(underlying_aggr); result.Flatten(); - SerializeStructFields(layout, result, count, addresses_ptrs); + SerializeState(layout, result, count, addresses_ptrs); +} + +// constructs the AGGREGATE_STATE type for the given bound aggregate function +// the state layout (a struct) is aliased to AGGREGATE_STATE, with the function name and signature stored in the +// extension type info so that the aggregate can be re-bound later (e.g. by FINALIZE/COMBINE) +LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_function) { + LogicalType state_layout = bound_function.GetStateType().type; + state_layout.SetAlias("AGGREGATE_STATE"); + auto ext_info = make_uniq(); + ext_info->properties.emplace("function_name", bound_function.GetName()); + vector arguments; + for (auto &arg : bound_function.GetOriginalArguments().empty() ? bound_function.GetArguments() + : bound_function.GetOriginalArguments()) { + arguments.push_back(Value::TYPE(arg)); + } + ext_info->properties.emplace("parameters", Value::LIST(LogicalType::TYPE(), std::move(arguments))); + state_layout.SetExtensionInfo(std::move(ext_info)); + return state_layout; +} + +unique_ptr ToAggregateStateBind(BindScalarFunctionInput &input) { + auto &bound_function = input.GetBoundFunction(); + auto &arguments = input.GetArguments(); + auto &context = input.GetClientContext(); + for (idx_t i = 1; i < 3; i++) { + if (arguments[i]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[i]->IsFoldable()) { + throw BinderException("to_aggregate_state: the aggregate name and signature must be constant"); + } + } + auto function_name_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); + if (function_name_val.IsNull()) { + throw BinderException("to_aggregate_state: the aggregate name cannot be NULL"); + } + auto function_name = 
StringValue::Get(function_name_val); + + auto signature_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); + if (signature_val.IsNull()) { + throw BinderException("to_aggregate_state: the signature must be a list of types"); + } + vector argument_types; + for (auto &arg : ListValue::GetChildren(signature_val)) { + if (arg.IsNull()) { + throw BinderException("to_aggregate_state: the signature cannot contain NULL values"); + } + if (arg.type().id() == LogicalTypeId::TYPE) { + argument_types.push_back(TypeValue::GetType(arg)); + } else if (arg.type().id() == LogicalTypeId::VARCHAR) { + argument_types.push_back(TransformStringToLogicalType(StringValue::Get(arg), context)); + } else { + throw BinderException("to_aggregate_state: the signature must be a list of types"); + } + } + + auto bind_data = BindExportedAggregate(context, function_name, argument_types); + auto &aggr = bind_data->aggr; + if (!aggr.HasGetStateTypeCallback()) { + throw BinderException( + "Aggregate function \"%s\" does not have a state type callback defined - cannot convert to its state", + function_name); + } + auto state_layout = aggr.GetStateType().type; + bound_function.GetArguments()[0] = state_layout; + bound_function.SetReturnType(CreateAggregateStateType(aggr)); + return std::move(bind_data); +} + +void ToAggregateStateFunction(DataChunk &input, ExpressionState &state, Vector &result) { + // the input type is verified to match the state layout at bind time - we only need to reinterpret the vector + result.Reinterpret(input.data[0]); } } // namespace +void ExportAggregateFunction::SetStateExport(BoundAggregateExpression &aggregate, LogicalType state_layout) { + auto &bound_function = aggregate.FunctionMutable(); + bound_function.SetStateFinalizeCallback(ExportAggregateFinalize); + // statistics propagation is no longer correct post + bound_function.SetStatisticsCallback(nullptr); + bound_function.SetReturnType(state_layout); + // exported state always produces a valid (non-NULL) 
struct even for empty inputs + bound_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + aggregate.StateExportModeMutable() = AggregateStateExportMode::STATE_EXPORT; + aggregate.SetReturnType(std::move(state_layout)); +} + unique_ptr ExportAggregateFunction::Bind(unique_ptr child_aggregate) { - auto &bound_function = child_aggregate->function; + auto &bound_function = child_aggregate->Function(); if (!bound_function.HasStateCombineCallback()) { throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.GetName()); } - if (bound_function.HasBindCallback()) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders"); - } if (bound_function.HasStateDestructorCallback()) { throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors"); } @@ -872,41 +650,14 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega D_ASSERT(bound_function.HasStateSizeCallback()); D_ASSERT(bound_function.HasStateFinalizeCallback()); - D_ASSERT(child_aggregate->function.GetReturnType().id() != LogicalTypeId::INVALID); -#ifdef DEBUG - for (auto &arg_type : child_aggregate->function.GetArguments()) { - D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); - } -#endif - auto export_bind_data = make_uniq(child_aggregate->Copy()); - aggregate_state_t state_type(child_aggregate->function.GetName(), child_aggregate->function.GetReturnType(), - child_aggregate->function.GetArguments()); - - LogicalType return_type; - if (bound_function.HasGetStateTypeCallback()) { - LogicalType state_layout = bound_function.GetStateType(); - auto struct_child_types = StructType::GetChildTypes(state_layout); - return_type = LogicalType::AGGREGATE_STATE(std::move(state_type), std::move(struct_child_types)); - } else { - return_type = LogicalType::LEGACY_AGGREGATE_STATE(std::move(state_type)); + D_ASSERT(child_aggregate->Function().GetReturnType().id() != LogicalTypeId::INVALID); + if 
(!bound_function.HasGetStateTypeCallback()) { + throw NotImplementedException( + "Aggregate function \"%s\" does not have a state type callback defined - cannot export state", + bound_function.GetName()); } - - auto export_function = - AggregateFunction("aggregate_state_export_" + bound_function.GetName(), bound_function.GetArguments(), - return_type, bound_function.GetStateSizeCallback(), bound_function.GetStateInitCallback(), - bound_function.GetStateUpdateCallback(), bound_function.GetStateCombineCallback(), - ExportAggregateFinalize, bound_function.GetStateSimpleUpdateCallback(), - /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, - /* can't propagate statistics */ nullptr, nullptr); - export_function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - export_function.SetSerializeCallback(ExportStateAggregateSerialize); - export_function.SetDeserializeCallback(ExportStateAggregateDeserialize); - - BoundAggregateFunction bound_func(export_function); - - return make_uniq(std::move(bound_func), std::move(child_aggregate->children), - std::move(child_aggregate->filter), std::move(export_bind_data), - child_aggregate->aggr_type); + SetStateExport(*child_aggregate, CreateAggregateStateType(bound_function)); + return child_aggregate; } ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr aggregate_p) { @@ -923,55 +674,35 @@ bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const return aggregate->Equals(*other.aggregate); } -static ScalarFunction CreateFinalizeFun(LogicalTypeId aggregate_state_logical_type_id) { - auto function = ScalarFunction("finalize", {aggregate_state_logical_type_id}, LogicalTypeId::INVALID, - AggregateStateFinalize, BindAggregateState, nullptr, InitFinalizeState); +ScalarFunction FinalizeFun::GetFunction() { + auto function = ScalarFunction("finalize", {LogicalTypeId::ANY}, LogicalTypeId::INVALID, AggregateStateFinalize, + BindAggregateState, nullptr, 
InitFinalizeState); function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - function.SetSerializeCallback(ExportStateScalarSerialize); - function.SetDeserializeCallback(ExportStateScalarDeserialize); return function; } -static ScalarFunction CreateCombineFun(LogicalTypeId aggregate_state_logical_type_id) { - auto function = ScalarFunction("combine", {aggregate_state_logical_type_id, LogicalTypeId::ANY}, - aggregate_state_logical_type_id, AggregateStateCombine, BindAggregateState, nullptr, - InitCombineState); +ScalarFunction CombineFun::GetFunction() { + auto function = ScalarFunction("combine", {LogicalTypeId::ANY, LogicalTypeId::ANY}, LogicalTypeId::ANY, + AggregateStateCombine, BindAggregateState, nullptr, InitCombineState); function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - function.SetSerializeCallback(ExportStateScalarSerialize); - function.SetDeserializeCallback(ExportStateScalarDeserialize); return function; } -ScalarFunctionSet FinalizeFun::GetFunctions() { - ScalarFunctionSet finalize_set; - - auto blob_finalize = CreateFinalizeFun(LogicalTypeId::LEGACY_AGGREGATE_STATE); - finalize_set.AddFunction(blob_finalize); - - auto struct_based_finalize = CreateFinalizeFun(LogicalTypeId::AGGREGATE_STATE); - finalize_set.AddFunction(struct_based_finalize); - - return finalize_set; -} - -ScalarFunctionSet CombineFun::GetFunctions() { - ScalarFunctionSet combine_set; - - auto blob_combine = CreateCombineFun(LogicalTypeId::LEGACY_AGGREGATE_STATE); - combine_set.AddFunction(blob_combine); - - auto struct_based_combine = CreateCombineFun(LogicalTypeId::AGGREGATE_STATE); - combine_set.AddFunction(struct_based_combine); - - return combine_set; +ScalarFunction ToAggregateStateFun::GetFunction() { + auto function = ScalarFunction("to_aggregate_state", + {LogicalTypeId::ANY, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::ANY)}, + LogicalTypeId::ANY, ToAggregateStateFunction, ToAggregateStateBind); + 
function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return function; } AggregateFunction CombineAggrFun::GetFunction() { - auto function = AggregateFunction("combine_aggr", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::AGGREGATE_STATE, - nullptr, nullptr, CombineAggrUpdate, nullptr, CombineAggrFinalize, nullptr, - CombineAggrBind, nullptr, nullptr, nullptr); - function.SetNullHandling(FunctionNullHandling::DEFAULT_NULL_HANDLING); + auto function = + AggregateFunction("combine_aggr", {LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, CombineAggrUpdate, + nullptr, CombineAggrFinalize, FunctionNullHandling::SPECIAL_HANDLING, nullptr, + CombineAggrBind, nullptr, nullptr, nullptr); + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return function; } diff --git a/src/duckdb/src/function/scalar/system/current_connection_id.cpp b/src/duckdb/src/function/scalar/system/current_connection_id.cpp index 9050f61af..47236e3eb 100644 --- a/src/duckdb/src/function/scalar/system/current_connection_id.cpp +++ b/src/duckdb/src/function/scalar/system/current_connection_id.cpp @@ -29,7 +29,7 @@ unique_ptr CurrentConnectionIdBind(BindScalarFunctionInput &input) void CurrentConnectionIdFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.connection_id, count_t(args.size())); } diff --git a/src/duckdb/src/function/scalar/system/current_query_id.cpp b/src/duckdb/src/function/scalar/system/current_query_id.cpp index 65cf3d0f4..afcdb837e 100644 --- a/src/duckdb/src/function/scalar/system/current_query_id.cpp +++ b/src/duckdb/src/function/scalar/system/current_query_id.cpp @@ -35,7 +35,7 @@ unique_ptr CurrentQueryIdBind(BindScalarFunctionInput &input) { void CurrentQueryIdFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - 
const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.query_id, count_t(args.size())); } diff --git a/src/duckdb/src/function/scalar/system/current_transaction_id.cpp b/src/duckdb/src/function/scalar/system/current_transaction_id.cpp index 6f84b00d6..6a1756df2 100644 --- a/src/duckdb/src/function/scalar/system/current_transaction_id.cpp +++ b/src/duckdb/src/function/scalar/system/current_transaction_id.cpp @@ -36,7 +36,7 @@ unique_ptr CurrentTransactionIdBind(BindScalarFunctionInput &input void CurrentTransactionIdFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); result.Reference(info.transaction_id, count_t(args.size())); } diff --git a/src/duckdb/src/function/scalar/system/parse_log_message.cpp b/src/duckdb/src/function/scalar/system/parse_log_message.cpp index 47fe60c37..9368d9372 100644 --- a/src/duckdb/src/function/scalar/system/parse_log_message.cpp +++ b/src/duckdb/src/function/scalar/system/parse_log_message.cpp @@ -68,7 +68,7 @@ unique_ptr ParseLogMessageBind(BindScalarFunctionInput &input) { void ParseLogMessageFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); if (info.log_type.is_structured) { // TODO: allow more complex parsing operations than DefaultCast diff --git a/src/duckdb/src/function/scalar/system/write_log.cpp b/src/duckdb/src/function/scalar/system/write_log.cpp index fde746769..c826a4373 100644 --- a/src/duckdb/src/function/scalar/system/write_log.cpp +++ b/src/duckdb/src/function/scalar/system/write_log.cpp @@ -126,7 +126,7 @@ void WriteLogFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.ColumnCount() >= 1); auto &func_expr = 
state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); + const auto &info = func_expr.BindInfo()->Cast(); UnifiedVectorFormat idata; args.data[0].ToUnifiedFormat(idata); diff --git a/src/duckdb/src/function/scalar/variant/variant_bind_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_bind_utils.cpp new file mode 100644 index 000000000..7741e8c90 --- /dev/null +++ b/src/duckdb/src/function/scalar/variant/variant_bind_utils.cpp @@ -0,0 +1,118 @@ +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +VariantPathBindData::VariantPathBindData() : FunctionData() { +} +VariantPathBindData::VariantPathBindData(const string &input_path) : FunctionData() { + if (input_path.empty()) { + paths.emplace_back(); + } else { + paths.push_back({VariantPathComponent(input_path)}); + } +} +VariantPathBindData::VariantPathBindData(const vector &input_paths) : FunctionData() { + for (const auto &path : input_paths) { + if (path.empty()) { + paths.emplace_back(); + } else { + paths.push_back({VariantPathComponent(path)}); + } + } +} + +unique_ptr VariantPathBindData::Copy() const { + return make_uniq(*this); +} + +bool VariantPathBindData::Equals(const FunctionData &other) const { + auto &bind_data = other.Cast(); + if (paths.size() != bind_data.paths.size()) { + return false; + } + + for (idx_t i = 0; i < paths.size(); i++) { + if (paths[i].size() != bind_data.paths[i].size()) { + return false; + } + for (idx_t j = 0; j < paths[i].size(); j++) { + if (paths[i][j] != bind_data.paths[i][j]) { + return false; + } + } + } + return true; +} + +static vector CollectPaths(const Value &constant_arg, const string &function_name) { + vector paths; + const auto &children = ListValue::GetChildren(constant_arg); + for (const auto &child : children) { + if (child.IsNull()) { + throw BinderException("'%s' does not accept NULL paths", function_name); + } + paths.push_back(child.GetValue()); + } + + return 
paths; +} + +bool VariantBindUtils::GetConstantArgument(ClientContext &context, const Expression &expr, Value &constant_arg) { + if (!expr.IsFoldable()) { + return false; + } + constant_arg = ExpressionExecutor::EvaluateScalar(context, expr); + if (!constant_arg.IsNull()) { + return true; + } + return false; +} + +unique_ptr VariantBindUtils::VariantPathBind(BindScalarFunctionInput &input) { + auto &context = input.GetClientContext(); + auto &arguments = input.GetArguments(); + const auto &function_name = input.GetBoundFunction().GetName(); + + if (arguments.size() != 1 && arguments.size() != 2) { + throw BinderException("'%s' expects either one VARIANT column argument, or two VARIANT column and " + "VARCHAR path arguments", + function_name); + } + + if (arguments.size() == 1) { + // No path supplied, execute function over the root of the variant + return make_uniq(); + } + + const auto &path_expr = *arguments[1]; + const auto &return_type = path_expr.GetReturnType(); + if (return_type.id() != LogicalTypeId::VARCHAR && return_type.id() != LogicalTypeId::LIST) { + throw BinderException("'%s' expects the second argument to be of type VARCHAR or VARCHAR[], not %s", + function_name, return_type.ToString()); + } + if (return_type.id() == LogicalTypeId::LIST) { + const auto child_type_id = ListType::GetChildType(return_type).id(); + if (child_type_id != LogicalTypeId::VARCHAR && child_type_id != LogicalTypeId::SQLNULL) { + throw BinderException("'%s' expects the second argument to be of type VARCHAR or VARCHAR[], not %s", + function_name, return_type.ToString()); + } + } + + Value constant_arg; + if (!GetConstantArgument(context, path_expr, constant_arg)) { + throw BinderException("'%s' expects the second argument to be a constant expression", function_name); + } + + const auto path_type_id = constant_arg.type().id(); + if (path_type_id == LogicalTypeId::VARCHAR) { + return make_uniq(constant_arg.GetValue()); + } + if (path_type_id == LogicalTypeId::LIST) { + return 
make_uniq(CollectPaths(constant_arg, function_name.GetIdentifierName())); + } + + throw BinderException("'%s' received an unexpected type for the second argument", function_name); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/variant/variant_exists.cpp b/src/duckdb/src/function/scalar/variant/variant_exists.cpp new file mode 100644 index 000000000..f95983138 --- /dev/null +++ b/src/duckdb/src/function/scalar/variant/variant_exists.cpp @@ -0,0 +1,102 @@ +#include "duckdb/common/vector/flat_vector.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +namespace duckdb { + +// TODO: Currently collection will always happen on the unshredded variant, introduce a fast path for shredded variants. +static ValidityMask CollectVariantExistence(const UnifiedVariantVectorData &variant, + const vector &components, const idx_t count) { + ValidityMask path_validity(count); + VariantPathSelection path_selection(count); + + const auto owned_nested_data = make_unsafe_uniq_array_uninitialized(count); + const array_ptr nested_data(owned_nested_data.get(), count); + + VariantUtils::TraversePath(variant, components, count, nested_data, path_validity, path_selection); + + return path_validity; +} + +static void UnaryVariantExists(const Vector &variant_vec, const vector &components, + Vector &result, const idx_t count) { + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, source_format); + const UnifiedVariantVectorData variant(source_format); + + const auto &path_validity = CollectVariantExistence(variant, components, count); + + result.Initialize(VectorDataInitialization::UNINITIALIZED, count); + auto row_writer = FlatVector::Writer(result, count); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!variant.RowIsValid(row_idx)) { + row_writer.WriteNull(); + continue; + } + + if 
(path_validity.RowIsValid(row_idx)) { + row_writer.WriteValue(true); + } else { + row_writer.WriteValue(false); + } + } +} + +static void ManyVariantExists(const Vector &variant_vec, const vector> &paths, + Vector &result, const idx_t count) { + vector existence_by_path; + existence_by_path.reserve(paths.size()); + + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, source_format); + const UnifiedVariantVectorData variant(source_format); + + for (const auto &path : paths) { + existence_by_path.push_back(CollectVariantExistence(variant, path, count)); + } + + result.Initialize(VectorDataInitialization::UNINITIALIZED, count); + auto result_writer = FlatVector::Writer>(result, count); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!variant.RowIsValid(row_idx)) { + result_writer.WriteNull(); + continue; + } + + auto row_writer = result_writer.WriteList(paths.size()); + idx_t path_idx = 0; + for (auto &path_existence_writer : row_writer) { + if (existence_by_path[path_idx].RowIsValid(row_idx)) { + path_existence_writer.WriteValue(true); + } else { + path_existence_writer.WriteValue(false); + } + + path_idx++; + } + } +} + +static void VariantExistsFunction(DataChunk &input, ExpressionState &state, Vector &result) { + VariantUtils::ExecutePathFunction(input, state, result, UnaryVariantExists, ManyVariantExists); +} + +ScalarFunctionSet VariantExistsFun::GetFunctions() { + ScalarFunctionSet fun_set; + + ScalarFunction variant_exists("variant_exists", {LogicalType::VARIANT(), LogicalType::VARCHAR}, + LogicalType::BOOLEAN, VariantExistsFunction, VariantBindUtils::VariantPathBind, + nullptr); + fun_set.AddFunction(variant_exists); + + variant_exists.GetSignature().GetParameter(1).SetType(LogicalType::LIST(LogicalType::VARCHAR)); + variant_exists.SetReturnType(LogicalType::LIST(LogicalType::BOOLEAN)); + fun_set.AddFunction(variant_exists); + + return fun_set; +} + +} // namespace duckdb diff --git 
a/src/duckdb/src/function/scalar/variant/variant_extract.cpp b/src/duckdb/src/function/scalar/variant/variant_extract.cpp index 3215df2bc..99f76497c 100644 --- a/src/duckdb/src/function/scalar/variant/variant_extract.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_extract.cpp @@ -9,9 +9,7 @@ #include "duckdb/function/scalar/variant_functions.hpp" #include "duckdb/function/scalar/regexp.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/execution/expression_executor.hpp" #include "duckdb/storage/statistics/struct_stats.hpp" -#include "duckdb/storage/statistics/list_stats.hpp" namespace duckdb { @@ -30,28 +28,7 @@ unique_ptr VariantExtractBindData::Copy() const { bool VariantExtractBindData::Equals(const FunctionData &other) const { auto &bind_data = other.Cast(); - if (bind_data.component.lookup_mode != component.lookup_mode) { - return false; - } - if (bind_data.component.lookup_mode == VariantChildLookupMode::BY_INDEX && - bind_data.component.index != component.index) { - return false; - } - if (bind_data.component.lookup_mode == VariantChildLookupMode::BY_KEY && bind_data.component.key != component.key) { - return false; - } - return true; -} - -static bool GetConstantArgument(ClientContext &context, Expression &expr, Value &constant_arg) { - if (!expr.IsFoldable()) { - return false; - } - constant_arg = ExpressionExecutor::EvaluateScalar(context, expr); - if (!constant_arg.IsNull()) { - return true; - } - return false; + return component == bind_data.component; } static unique_ptr VariantExtractPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -83,14 +60,14 @@ static unique_ptr VariantExtractBind(BindScalarFunctionInput &inpu if (arguments.size() != 2) { throw BinderException("'variant_extract' expects two arguments, VARIANT column and VARCHAR path"); } - auto &path = *arguments[1]; + const auto &path = *arguments[1]; if (path.GetReturnType().id() != LogicalTypeId::VARCHAR && 
path.GetReturnType().id() != LogicalTypeId::UINTEGER) { throw BinderException("'variant_extract' expects the second argument to be of type VARCHAR or UINTEGER, not %s", path.GetReturnType().ToString()); } Value constant_arg; - if (!GetConstantArgument(context, path, constant_arg)) { + if (!VariantBindUtils::GetConstantArgument(context, path, constant_arg)) { throw BinderException("'variant_extract' expects the second argument to be a constant expression"); } @@ -103,8 +80,8 @@ static unique_ptr VariantExtractBind(BindScalarFunctionInput &inpu } } -static bool TryShreddedExtractRecursive(Vector &input, const vector &components, Vector &result, - idx_t count, idx_t path_index = 0) { +static bool TryShreddedExtractRecursive(const Vector &input, const vector &components, + Vector &result, idx_t count, idx_t path_index = 0) { if (path_index >= components.size()) { // reached the end of the path - shred if (input.GetType().IsNested()) { @@ -142,7 +119,7 @@ static bool TryShreddedExtractRecursive(Vector &input, const vector &components, Vector &result, - idx_t count) { +static bool TryFromShreddedExtract(const Vector &variant_vec, const vector &components, + Vector &result, idx_t count) { if (variant_vec.GetVectorType() != VectorType::SHREDDED_VECTOR) { // input vector is not shredded return false; @@ -165,12 +142,11 @@ static bool TryFromShreddedExtract(Vector &variant_vec, const vector &components, Vector &result, - idx_t count) { +void VariantUtils::VariantExtract(const Vector &variant_vec, const vector &components, + Vector &result, idx_t count) { if (TryFromShreddedExtract(variant_vec, components, result, count)) { return; } - auto &allocator = Allocator::DefaultAllocator(); RecursiveUnifiedVectorFormat source_format; Vector::RecursiveToUnifiedFormat(variant_vec, source_format); @@ -188,8 +164,8 @@ void VariantUtils::VariantExtract(Vector &variant_vec, const vector(owned_nested_data.get()); + const auto owned_nested_data = make_unsafe_uniq_array_uninitialized(count); 
+ array_ptr nested_data(owned_nested_data.get(), count); //! Perform the extract ValidityMask validity(count); @@ -294,15 +270,15 @@ static void VariantExtractFunction(DataChunk &input, ExpressionState &state, Vec auto count = input.size(); D_ASSERT(input.ColumnCount() == 2); - auto &variant_vec = input.data[0]; + const auto &variant_vec = input.data[0]; D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); - auto &path = input.data[1]; + const auto &path = input.data[1]; D_ASSERT(path.GetVectorType() == VectorType::CONSTANT_VECTOR); (void)path; auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); + auto &info = func_expr.BindInfo()->Cast(); VariantUtils::VariantExtract(variant_vec, {info.component}, result, count); } diff --git a/src/duckdb/src/function/scalar/variant/variant_keys.cpp b/src/duckdb/src/function/scalar/variant/variant_keys.cpp new file mode 100644 index 000000000..81f44a728 --- /dev/null +++ b/src/duckdb/src/function/scalar/variant/variant_keys.cpp @@ -0,0 +1,173 @@ +#include "duckdb/common/vector/flat_vector.hpp" +#include "duckdb/common/vector/list_vector.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" + +namespace duckdb { + +static void WriteKeyIdList(const UnifiedVariantVectorData &variant, Vector &key_ids, + array_ptr nested_data, const ValidityMask &object_validity, + const idx_t count) { + auto writer = FlatVector::Writer>(key_ids, count); + const auto &list_validity = FlatVector::Validity(key_ids); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!list_validity.RowIsValid(row_idx)) { + writer.WriteNull(); + continue; + } + if (!object_validity.RowIsValid(row_idx)) { + writer.WriteList(0); + continue; + } + + const auto &[child_count, children_idx] = nested_data[row_idx]; + auto row_writer = writer.WriteList(child_count); + + idx_t child_idx = 0; + for (auto &key_id_writer : 
row_writer) { + const auto key_id = variant.GetKeysIndex(row_idx, children_idx + child_idx); + key_id_writer.WriteValue(key_id); + child_idx++; + } + } +} + +// TODO: Currently collection will always happen on the unshredded variant, introduce a fast path for shredded variants. +static Vector CollectVariantKeys(const UnifiedVariantVectorData &variant, + const vector &components, const idx_t count) { + // By row found keys at the requested path in the VARIANT + Vector key_ids(LogicalType::LIST(LogicalType::UBIGINT), count); + VariantPathSelection path_selection(count); + + const auto owned_nested_data = make_unsafe_uniq_array_uninitialized(count); + const array_ptr nested_data(owned_nested_data.get(), count); + + auto &list_validity = FlatVector::ValidityMutable(key_ids); + VariantUtils::TraversePath(variant, components, count, nested_data, list_validity, path_selection); + + // For the final collection of nested_data we use an auxiliary validity vector so we can distinguish "path missing" + // from "path exists but not an object" (producing NULL and [] as output respectively). 
+ ValidityMask object_validity(count); + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!list_validity.RowIsValid(row_idx)) { + object_validity.SetInvalid(row_idx); + } + } + + const auto &final_indices = path_selection.Input(components.size()); + (void)VariantUtils::CollectNestedData(variant, VariantLogicalType::OBJECT, final_indices, count, optional_idx(), 0, + nested_data, object_validity); + + WriteKeyIdList(variant, key_ids, nested_data, object_validity, count); + + return key_ids; +} + +static void UnaryVariantKeys(const Vector &variant_vec, const vector &components, Vector &result, + const idx_t count) { + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, source_format); + const UnifiedVariantVectorData variant(source_format); + + auto key_ids = CollectVariantKeys(variant, components, count); + const auto &list_validity = FlatVector::Validity(key_ids); + const auto list_entries = FlatVector::GetData(key_ids); + const auto &child = ListVector::GetChild(key_ids); + const auto key_ids_data = FlatVector::GetData(child); + + result.Initialize(VectorDataInitialization::UNINITIALIZED, count); + auto result_writer = FlatVector::Writer>(result, count); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + if (!list_validity.RowIsValid(row_idx)) { + result_writer.WriteNull(); + continue; + } + + auto &entry = list_entries[row_idx]; + auto row_writer = result_writer.WriteList(entry.length); + + idx_t key_idx = 0; + for (auto &key_writer : row_writer) { + const auto key_id = key_ids_data[entry.offset + key_idx++]; + key_writer.WriteValue(variant.GetKey(row_idx, key_id)); + } + } +} + +static void ManyVariantKeys(const Vector &variant_vec, const vector> &paths, + Vector &result, const idx_t count) { + vector keys_by_path; + keys_by_path.reserve(paths.size()); + + RecursiveUnifiedVectorFormat source_format; + Vector::RecursiveToUnifiedFormat(variant_vec, source_format); + const UnifiedVariantVectorData 
variant(source_format); + + for (const auto &path : paths) { + keys_by_path.push_back(CollectVariantKeys(variant, path, count)); + } + + result.Initialize(VectorDataInitialization::UNINITIALIZED, count); + auto result_writer = FlatVector::Writer>>(result, count); + + for (idx_t row_idx = 0; row_idx < count; row_idx++) { + auto row_writer = result_writer.WriteList(paths.size()); + idx_t path_idx = 0; + for (auto &path_keys_writer : row_writer) { + const auto &key_ids = keys_by_path[path_idx]; + const auto &list_validity = FlatVector::Validity(key_ids); + const auto list_entries = FlatVector::GetData(key_ids); + const auto &child = ListVector::GetChild(key_ids); + const auto key_ids_data = FlatVector::GetData(child); + + if (!list_validity.RowIsValid(row_idx)) { + path_keys_writer.WriteNull(); + path_idx++; + continue; + } + + auto &entry = list_entries[row_idx]; + auto keys_writer = path_keys_writer.WriteList(entry.length); + + idx_t key_idx = 0; + for (auto &key_writer : keys_writer) { + const auto key_id = key_ids_data[entry.offset + key_idx++]; + key_writer.WriteValue(variant.GetKey(row_idx, key_id)); + } + + path_idx++; + } + } +} + +static void VariantKeysFunction(DataChunk &input, ExpressionState &state, Vector &result) { + VariantUtils::ExecutePathFunction(input, state, result, UnaryVariantKeys, ManyVariantKeys); +} + +static void AddFunctionsWithParameterType(ScalarFunctionSet &fun_set, const LogicalType &input_type) { + ScalarFunction variant_keys("variant_keys", {}, LogicalType::LIST(LogicalType::VARCHAR), VariantKeysFunction, + VariantBindUtils::VariantPathBind, nullptr); + + variant_keys.GetSignature().AddParameter(input_type); + fun_set.AddFunction(variant_keys); + + variant_keys.GetSignature().AddParameter(LogicalType::VARCHAR); + fun_set.AddFunction(variant_keys); + + variant_keys.GetSignature().GetParameter(1).SetType(LogicalType::LIST(LogicalType::VARCHAR)); + variant_keys.SetReturnType(LogicalType::LIST(LogicalType::LIST(LogicalType::VARCHAR))); + 
fun_set.AddFunction(variant_keys); +} + +ScalarFunctionSet VariantKeysFun::GetFunctions() { + ScalarFunctionSet fun_set; + + AddFunctionsWithParameterType(fun_set, LogicalType::VARIANT()); + + return fun_set; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp index 4b63015f2..c44975bb0 100644 --- a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp @@ -184,8 +184,9 @@ void VariantNormalizer::VisitDefault(VariantLogicalType type_id, const_data_ptr_ throw InternalException("VariantLogicalType(%s) not handled", EnumUtil::ToString(type_id)); } -void VariantNormalizer::Normalize(Vector &variant_vec, Vector &result, idx_t count) { +void VariantNormalizer::Normalize(const Vector &variant_vec, Vector &result) { D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); + const idx_t count = variant_vec.size(); //! Set up the access helper for the source VARIANT RecursiveUnifiedVectorFormat source_format; @@ -193,9 +194,9 @@ void VariantNormalizer::Normalize(Vector &variant_vec, Vector &result, idx_t cou UnifiedVariantVectorData variant(source_format); //! 
Take the original sizes of the lists, the result will be similar size, never bigger - auto original_keys_size = ListVector::GetTotalEntryCount(VariantVector::GetKeys(variant_vec), count); - auto original_children_size = ListVector::GetTotalEntryCount(VariantVector::GetChildren(variant_vec), count); - auto original_values_size = ListVector::GetTotalEntryCount(VariantVector::GetValues(variant_vec), count); + auto original_keys_size = ListVector::GetTotalEntryCount(VariantVector::GetKeys(variant_vec)); + auto original_children_size = ListVector::GetTotalEntryCount(VariantVector::GetChildren(variant_vec)); + auto original_values_size = ListVector::GetTotalEntryCount(VariantVector::GetValues(variant_vec)); auto &keys = VariantVector::GetKeys(result); auto &children = VariantVector::GetChildren(result); @@ -254,11 +255,9 @@ void VariantNormalizer::Normalize(Vector &variant_vec, Vector &result, idx_t cou } static void VariantNormalizeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto count = input.size(); - D_ASSERT(input.ColumnCount() == 1); - auto &variant_vec = input.data[0]; - VariantNormalizer::Normalize(variant_vec, result, count); + const auto &variant_vec = input.data[0]; + VariantNormalizer::Normalize(variant_vec, result); } ScalarFunction VariantNormalizeFun::GetFunction() { diff --git a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp index f8c8e5921..709dcafa9 100644 --- a/src/duckdb/src/function/scalar/variant/variant_typeof.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_typeof.cpp @@ -13,7 +13,7 @@ static void VariantTypeofFunction(DataChunk &input, ExpressionState &state, Vect auto count = input.size(); D_ASSERT(input.ColumnCount() == 1); - auto &variant_vec = input.data[0]; + const auto &variant_vec = input.data[0]; D_ASSERT(variant_vec.GetType() == LogicalType::VARIANT()); RecursiveUnifiedVectorFormat source_format; diff --git 
a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index dd64adff8..9d199c688 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -26,6 +26,40 @@ PhysicalType VariantDecimalData::GetPhysicalType() const { } } +void VariantUtils::ExecutePathFunction(DataChunk &input, const ExpressionState &state, Vector &result, + const unary_path_function_t &unary_fn, const many_path_function_t &many_fn) { + D_ASSERT(input.ColumnCount() == 1 || input.ColumnCount() == 2); + const auto count = input.size(); + const auto &variant_vec = input.data[0]; + + if (input.ColumnCount() == 2) { + const auto &path = input.data[1]; + D_ASSERT(path.GetVectorType() == VectorType::CONSTANT_VECTOR); + (void)path; + } + + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.BindInfo()->Cast(); + auto n_columns = input.ColumnCount(); + + if (n_columns == 1) { + unary_fn(variant_vec, {}, result, count); + return; + } + + D_ASSERT(n_columns == 2); + const auto &path_type_id = input.data[1].GetType().id(); + + if (path_type_id == LogicalTypeId::VARCHAR) { + unary_fn(variant_vec, info.paths[0], result, count); + return; + } + if (path_type_id == LogicalTypeId::LIST) { + many_fn(variant_vec, info.paths, result, count); + return; + } +} + bool VariantUtils::IsNestedType(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { auto type_id = variant.GetTypeId(row, value_index); return type_id == VariantLogicalType::ARRAY || type_id == VariantLogicalType::OBJECT; @@ -82,7 +116,7 @@ vector VariantUtils::GetObjectKeys(const UnifiedVariantVectorData &varia void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, optional_ptr sel, SelectionVector &res, - ValidityMask &res_validity, const VariantNestedData *nested_data, + ValidityMask &res_validity, const array_ptr &nested_data, 
const ValidityMask &validity, idx_t count) { for (idx_t i = 0; i < count; i++) { auto row_index = sel ? sel->get_index(i) : i; @@ -99,7 +133,7 @@ void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, cons continue; } auto value_id = variant.GetValuesIndex(row_index, nested_data_entry.children_idx + child_idx); - res[i] = static_cast(value_id); + res[i] = value_id; continue; } bool found_child = false; @@ -145,7 +179,7 @@ vector VariantUtils::ValueIsNull(const UnifiedVariantVectorData &varia VariantNestedDataCollectionResult VariantUtils::CollectNestedData(const UnifiedVariantVectorData &variant, VariantLogicalType expected_type, const SelectionVector &value_index_sel, idx_t count, optional_idx row, idx_t offset, - VariantNestedData *child_data, ValidityMask &validity) { + array_ptr child_data, ValidityMask &validity) { VariantLogicalType wrong_type = VariantLogicalType::VARIANT_NULL; for (idx_t i = 0; i < count; i++) { auto row_index = row.IsValid() ? row.GetIndex() : i; @@ -183,6 +217,36 @@ VariantUtils::CollectNestedData(const UnifiedVariantVectorData &variant, Variant return VariantNestedDataCollectionResult(); } +void VariantUtils::TraversePath(const UnifiedVariantVectorData &variant, const vector &components, + const idx_t count, array_ptr nested_data, ValidityMask &validity, + VariantPathSelection &path_selection) { + for (idx_t i = 0; i < components.size(); i++) { + auto &component = components[i]; + auto &input_indices = path_selection.Input(i); + auto &output_indices = path_selection.Output(i); + + if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { + throw InternalException("Path indexes are not supported for this function"); + } + + (void)VariantUtils::CollectNestedData(variant, VariantLogicalType::OBJECT, input_indices, count, optional_idx(), + 0, nested_data, validity); + + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, component, nullptr, output_indices, lookup_validity, nested_data, + 
validity, count); + + for (idx_t j = 0; j < count; j++) { + if (!validity.RowIsValid(j)) { + continue; + } + if (lookup_validity.CanHaveNull() && !lookup_validity.RowIsValid(j)) { + validity.SetInvalid(j); + } + } + } +} + Value VariantUtils::ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx) { return VariantVisitor::Visit(variant, row, values_idx); } @@ -220,7 +284,7 @@ void VariantUtils::FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap arguments, LogicalType return_type, +ScalarFunction::ScalarFunction(Identifier name, FunctionSignature sig, scalar_function_t function) + : SimpleFunction(std::move(name), std::move(sig)) { + callbacks.function = std::move(function); +} + +ScalarFunction::ScalarFunction(Identifier name, vector arguments, LogicalType return_type, scalar_function_t function, bind_scalar_function_t bind, function_statistics_t statistics, init_local_state_t init_local_state, LogicalType varargs, FunctionStability side_effects, FunctionNullHandling null_handling, @@ -41,7 +46,7 @@ ScalarFunction::ScalarFunction(vector arguments, LogicalType return bind_scalar_function_t bind, function_statistics_t statistics, init_local_state_t init_local_state, LogicalType varargs, FunctionStability side_effects, FunctionNullHandling null_handling, bind_lambda_function_t bind_lambda) - : ScalarFunction(string(), std::move(arguments), std::move(return_type), std::move(function), bind, statistics, + : ScalarFunction(Identifier(), std::move(arguments), std::move(return_type), std::move(function), bind, statistics, init_local_state, std::move(varargs), side_effects, null_handling, bind_lambda) { } diff --git a/src/duckdb/src/function/scalar_macro_function.cpp b/src/duckdb/src/function/scalar_macro_function.cpp index 3760c8c6f..561f2c8e8 100644 --- a/src/duckdb/src/function/scalar_macro_function.cpp +++ b/src/duckdb/src/function/scalar_macro_function.cpp @@ -32,8 +32,9 @@ unique_ptr ScalarMacroFunction::Copy() const { 
void RemoveQualificationRecursive(unique_ptr &root_expr) { ParsedExpressionIterator::VisitExpressionMutable( *root_expr, [&](ColumnRefExpression &col_ref) { - auto &col_names = col_ref.column_names; - if (col_names.size() == 2 && col_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { + auto &col_names = col_ref.ColumnNamesMutable(); + if (col_names.size() == 2 && + col_names[0].GetIdentifierName().find(DummyBinding::DUMMY_NAME) != string::npos) { col_names.erase(col_names.begin()); } }); diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp index 5007a443e..8154f3e23 100644 --- a/src/duckdb/src/function/table/arrow.cpp +++ b/src/duckdb/src/function/table/arrow.cpp @@ -200,11 +200,11 @@ void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunction data.lines_read += output_size; if (global_state.CanRemoveFilterColumns()) { state.all_columns.Reset(); - state.all_columns.SetCardinality(output_size); + state.all_columns.SetChildCardinality(output_size); ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns); output.ReferenceColumns(state.all_columns, global_state.projection_ids); } else { - output.SetCardinality(output_size); + output.SetChildCardinality(output_size); ArrowToDuckDB(state, data.arrow_table.GetColumns(), output); } diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index e644bc229..e93003233 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -318,7 +318,7 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, idx_t chunk_of } } -static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { +static void ArrowToDuckDBMapVerify(const Vector &vector, idx_t count) { auto valid_check = MapVector::CheckMapValidity(vector, count); switch (valid_check) { case MapInvalidReason::VALID: diff --git 
a/src/duckdb/src/function/table/checkpoint.cpp b/src/duckdb/src/function/table/checkpoint.cpp index 9fdd19a79..8ba250d3d 100644 --- a/src/duckdb/src/function/table/checkpoint.cpp +++ b/src/duckdb/src/function/table/checkpoint.cpp @@ -36,7 +36,7 @@ static unique_ptr CheckpointBind(ClientContext &context, TableFunc throw BinderException("Database cannot be NULL"); } auto &db_name = StringValue::Get(input.inputs[0]); - db = db_manager.GetDatabase(context, db_name); + db = db_manager.GetDatabase(context, Identifier(db_name)); if (!db) { throw BinderException("Database \"%s\" not found", db_name); } diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp index c630c52f1..ac8c70f8d 100644 --- a/src/duckdb/src/function/table/copy_csv.cpp +++ b/src/duckdb/src/function/table/copy_csv.cpp @@ -117,7 +117,7 @@ void BaseCSVData::Finalize() { } static vector> CreateCastExpressions(WriteCSVData &bind_data, ClientContext &context, - const vector &names, + const vector &names, const vector &sql_types) { auto &options = bind_data.options; auto &formats = options.write_date_format; @@ -173,7 +173,7 @@ static vector> CreateCastExpressions(WriteCSVData &bind_d } static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { + const vector &names, const vector &sql_types) { auto bind_data = make_uniq(names); // check all the options in the copy info @@ -337,7 +337,6 @@ static void WriteCSVChunkInternal(CSVWriter &writer, CSVWriterState &writer_loca DataChunk &input, ExpressionExecutor &executor) { // first cast the columns of the chunk to varchar cast_chunk.Reset(); - cast_chunk.SetCardinality(input); executor.Execute(input, cast_chunk); @@ -414,7 +413,8 @@ unique_ptr WriteCSVPrepareBatch(ClientContext &context, Funct cast_chunk.Initialize(Allocator::Get(context), types); auto &original_types = collection->Types(); - auto expressions = CreateCastExpressions(csv_data, 
context, csv_data.options.name_list, original_types); + auto expressions = + CreateCastExpressions(csv_data, context, StringsToIdentifiers(csv_data.options.name_list), original_types); ExpressionExecutor executor(context, expressions); auto &global_state = gstate.Cast(); diff --git a/src/duckdb/src/function/table/direct_file_reader.cpp b/src/duckdb/src/function/table/direct_file_reader.cpp index 9b7436678..b03ef12da 100644 --- a/src/duckdb/src/function/table/direct_file_reader.cpp +++ b/src/duckdb/src/function/table/direct_file_reader.cpp @@ -20,7 +20,7 @@ DirectFileReader::DirectFileReader(OpenFileInfo file_p, const LogicalType &type) DirectFileReader::~DirectFileReader() { } -unique_ptr DirectFileReader::GetStatistics(ClientContext &context, const string &name) { +unique_ptr DirectFileReader::GetStatistics(ClientContext &context, const Identifier &name) { return nullptr; } @@ -76,7 +76,6 @@ AsyncResult DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionSt // At least verify that the file exist // The globbing behavior in remote filesystems can lead to files being listed that do not actually exist if (FileSystem::IsRemoteFile(file.path) && !fs.FileExists(file.path)) { - output.SetCardinality(0); done = true; return SourceResultType::FINISHED; } @@ -177,7 +176,7 @@ AsyncResult DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionSt } } } - output.SetCardinality(1); + output.SetChildCardinality(1); done = true; return AsyncResult(SourceResultType::HAVE_MORE_OUTPUT); } diff --git a/src/duckdb/src/function/table/glob.cpp b/src/duckdb/src/function/table/glob.cpp index 9ab60e8cb..6aaeac9a4 100644 --- a/src/duckdb/src/function/table/glob.cpp +++ b/src/duckdb/src/function/table/glob.cpp @@ -53,7 +53,6 @@ static void GlobFunction(ClientContext &context, TableFunctionInput &data_p, Dat count++; state.file_list_scan.scan_type = MultiFileListScanType::FETCH_IF_AVAILABLE; } - output.SetCardinality(count); } void 
GlobTableFunction::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/range.cpp b/src/duckdb/src/function/table/range.cpp index 39dfccc25..ecca52ff1 100644 --- a/src/duckdb/src/function/table/range.cpp +++ b/src/duckdb/src/function/table/range.cpp @@ -148,7 +148,6 @@ static OperatorResultType RangeFunction(ExecutionContext &context, TableFunction } if (state.empty_range) { // empty range - output.SetCardinality(0); state.current_input_row++; state.initialized_row = false; return OperatorResultType::HAVE_MORE_OUTPUT; @@ -170,7 +169,6 @@ static OperatorResultType RangeFunction(ExecutionContext &context, TableFunction output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); // increment the index pointer by the remaining count state.current_idx += remaining; - output.SetCardinality(remaining); if (remaining == 0) { // move to next row state.current_input_row++; @@ -348,7 +346,6 @@ static OperatorResultType RangeDateTimeFunction(ExecutionContext &context, Table } if (state.empty_range) { // empty range - output.SetCardinality(0); state.current_input_row++; state.initialized_row = false; return OperatorResultType::HAVE_MORE_OUTPUT; @@ -406,6 +403,7 @@ void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { TableFunctionSet generate_series("generate_series"); range_function.bind = RangeFunctionBind; range_function.in_out_function = RangeFunction; + range_function.return_type = TableFunctionReturnType::SET_RETURNING_FUNCTION; range_function.GetArguments() = {LogicalType::BIGINT}; generate_series.AddFunction(range_function); range_function.GetArguments() = {LogicalType::BIGINT, LogicalType::BIGINT}; @@ -416,6 +414,7 @@ void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { nullptr, RangeDateTimeBind, nullptr, RangeDateTimeLocalInit); generate_series_in_out.in_out_function = RangeDateTimeFunction; generate_series_in_out.parallelism = TableFunctionParallelism::FORCE_SINGLE_THREADED; + 
generate_series_in_out.return_type = TableFunctionReturnType::SET_RETURNING_FUNCTION; generate_series.AddFunction(generate_series_in_out); set.AddFunction(generate_series); } diff --git a/src/duckdb/src/function/table/read_csv.cpp b/src/duckdb/src/function/table/read_csv.cpp index 28a58afa0..ecd0e886e 100644 --- a/src/duckdb/src/function/table/read_csv.cpp +++ b/src/duckdb/src/function/table/read_csv.cpp @@ -79,6 +79,7 @@ void ReadCSVTableFunction::ReadCSVAddNamedParameters(TableFunction &table_functi table_function.named_parameters["rejects_table"] = LogicalType::VARCHAR; table_function.named_parameters["rejects_scan"] = LogicalType::VARCHAR; table_function.named_parameters["rejects_limit"] = LogicalType::BIGINT; + table_function.named_parameters["rejects_line_size_limit"] = LogicalType::BIGINT; table_function.named_parameters["force_not_null"] = LogicalType::LIST(LogicalType::VARCHAR); table_function.named_parameters["buffer_size"] = LogicalType::UBIGINT; table_function.named_parameters["decimal_separator"] = LogicalType::VARCHAR; @@ -151,7 +152,7 @@ TableFunction ReadCSVTableFunction::GetFunction() { TableFunction ReadCSVTableFunction::GetAutoFunction() { auto read_csv_auto = ReadCSVTableFunction::GetFunction(); - read_csv_auto.name = "read_csv_auto"; + read_csv_auto.SetName("read_csv_auto"); return read_csv_auto; } @@ -184,7 +185,7 @@ unique_ptr ReadCSVReplacement(ClientContext &context, ReplacementScanI if (!FileSystem::HasGlob(table_name)) { auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); + table_function->alias = Identifier(fs.ExtractBaseName(table_name)); } return std::move(table_function); diff --git a/src/duckdb/src/function/table/read_duckdb.cpp b/src/duckdb/src/function/table/read_duckdb.cpp index 9f5f83760..d96565d20 100644 --- a/src/duckdb/src/function/table/read_duckdb.cpp +++ b/src/duckdb/src/function/table/read_duckdb.cpp @@ -27,7 +27,7 @@ struct DuckDBMultiFileInfo : MultiFileReaderInterface 
{ unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, + void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) override; unique_ptr InitializeGlobalState(ClientContext &context, MultiFileBindData &bind_data, MultiFileGlobalState &global_state) override; @@ -50,8 +50,8 @@ struct DuckDBMultiFileInfo : MultiFileReaderInterface { class DuckDBFileReaderOptions : public BaseFileReaderOptions { public: - string schema_name; - string table_name; + Identifier schema_name; + Identifier table_name; bool Matches(TableCatalogEntry &table) const; bool HasSelection() const; @@ -93,7 +93,7 @@ class DuckDBReader : public BaseFileReader { shared_ptr GetUnionData(idx_t file_idx) override; void FinishFile(ClientContext &context, GlobalTableFunctionState &gstate) override; double GetProgressInFile(ClientContext &context) override; - unique_ptr GetStatistics(ClientContext &context, const string &name) override; + unique_ptr GetStatistics(ClientContext &context, const Identifier &name) override; void AddVirtualColumn(column_t virtual_column_id) override; string GetReaderType() const override { return "duckdb"; @@ -110,8 +110,8 @@ class DuckDBReader : public BaseFileReader { unique_ptr global_state; atomic finished; idx_t column_count; - string schema_name; - string table_name; + Identifier schema_name; + Identifier table_name; }; struct DuckDBReadGlobalState : GlobalTableFunctionState {}; @@ -125,7 +125,7 @@ string DuckDBFileReaderOptions::GetCandidates(const vector table_names; + identifier_map_t table_names; for (auto &table : tables) { table_names[table.get().name]++; } @@ -137,10 +137,10 @@ string DuckDBFileReaderOptions::GetCandidates(const vector attach_kv; @@ -350,7 +350,7 @@ optional_idx DuckDBReader::NumRows() { return table_entry.GetStorage().GetTotalRows(); } -unique_ptr 
DuckDBReader::GetStatistics(ClientContext &context, const string &name) { +unique_ptr DuckDBReader::GetStatistics(ClientContext &context, const Identifier &name) { if (!scan_function.statistics) { return BaseFileReader::GetStatistics(context, name); } @@ -393,11 +393,11 @@ bool DuckDBMultiFileInfo::ParseOption(ClientContext &context, const string &key, MultiFileOptions &file_options, BaseFileReaderOptions &options_p) { auto &options = options_p.Cast(); if (key == "schema_name") { - options.schema_name = StringValue::Get(val); + options.schema_name = Identifier(StringValue::Get(val)); return true; } if (key == "table_name") { - options.table_name = StringValue::Get(val); + options.table_name = Identifier(StringValue::Get(val)); return true; } return false; @@ -424,8 +424,8 @@ unique_ptr DuckDBMultiFileInfo::InitializeBindData(MultiFileB return std::move(result); } -void DuckDBMultiFileInfo::BindReader(ClientContext &context, vector &return_types, vector &names, - MultiFileBindData &bind_data) { +void DuckDBMultiFileInfo::BindReader(ClientContext &context, vector &return_types, + vector &names, MultiFileBindData &bind_data) { auto &duckdb_bind_data = bind_data.bind_data->Cast(); bind_data.reader_bind = bind_data.multi_file_reader->BindReader(context, return_types, names, *bind_data.file_list, bind_data, @@ -472,7 +472,7 @@ shared_ptr DuckDBMultiFileInfo::CreateReader(ClientContext &cont shared_ptr DuckDBReader::GetUnionData(idx_t file_idx) { auto result = make_uniq(file); for (auto &column : columns) { - result->names.push_back(column.name); + result->names.push_back(column.name.GetIdentifierName()); result->types.push_back(column.type); } result->reader = shared_from_this(); @@ -553,7 +553,7 @@ unique_ptr ReadDuckDBTableFunction::ReplacementScan(ClientContext &con if (!FileSystem::HasGlob(table_name)) { auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); + table_function->alias = 
Identifier(fs.ExtractBaseName(table_name)); } return std::move(table_function); } diff --git a/src/duckdb/src/function/table/read_file.cpp b/src/duckdb/src/function/table/read_file.cpp index f0f15a539..bf413abcd 100644 --- a/src/duckdb/src/function/table/read_file.cpp +++ b/src/duckdb/src/function/table/read_file.cpp @@ -28,7 +28,7 @@ struct DirectMultiFileInfo : MultiFileReaderInterface { BaseFileReaderOptions &options) override; unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, + void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) override; optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, FileExpandResult expand_result) override; @@ -81,7 +81,7 @@ unique_ptr DirectMultiFileInfo::InitializeBindData(MultiF template void DirectMultiFileInfo::BindReader(ClientContext &context, vector &return_types, - vector &names, MultiFileBindData &bind_data) { + vector &names, MultiFileBindData &bind_data) { auto &read_bind = bind_data.bind_data->Cast(); bind_data.reader_bind = bind_data.multi_file_reader->BindReader( context, return_types, names, *bind_data.file_list, bind_data, *read_bind.options, bind_data.file_options); diff --git a/src/duckdb/src/function/table/repeat.cpp b/src/duckdb/src/function/table/repeat.cpp index 1e05aca9a..7643ffedf 100644 --- a/src/duckdb/src/function/table/repeat.cpp +++ b/src/duckdb/src/function/table/repeat.cpp @@ -43,7 +43,7 @@ static void RepeatFunction(ClientContext &context, TableFunctionInput &data_p, D idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); output.data[0].Reference(bind_data.value, count_t(remaining)); - output.SetCardinality(remaining); + output.SetChildCardinality(remaining); state.current_count += remaining; } diff --git 
a/src/duckdb/src/function/table/repeat_row.cpp b/src/duckdb/src/function/table/repeat_row.cpp index 31f75c88e..82ae22675 100644 --- a/src/duckdb/src/function/table/repeat_row.cpp +++ b/src/duckdb/src/function/table/repeat_row.cpp @@ -47,7 +47,7 @@ static void RepeatRowFunction(ClientContext &context, TableFunctionInput &data_p for (idx_t val_idx = 0; val_idx < bind_data.values.size(); val_idx++) { output.data[val_idx].Reference(bind_data.values[val_idx], count_t(remaining)); } - output.SetCardinality(remaining); + output.SetChildCardinality(remaining); state.current_count += remaining; } diff --git a/src/duckdb/src/function/table/sniff_csv.cpp b/src/duckdb/src/function/table/sniff_csv.cpp index efb53ee0a..ce4ac7e44 100644 --- a/src/duckdb/src/function/table/sniff_csv.cpp +++ b/src/duckdb/src/function/table/sniff_csv.cpp @@ -176,8 +176,7 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, } string str_opt; string separator = ", "; - // Set output - output.SetCardinality(1); + // Set output (called after all Appends below) // 1. 
Delimiter str_opt = sniffer_options.dialect_options.state_machine_options.delimiter.FormatValue(); @@ -207,7 +206,8 @@ static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, child_list_t struct_children {{"name", sniffer_result.names[i]}, {"type", {sniffer_result.return_types[i].ToString()}}}; values.emplace_back(Value::STRUCT(struct_children)); - columns << "'" << sniffer_result.names[i] << "': '" << sniffer_result.return_types[i].ToString() << "'"; + columns << "'" << sniffer_result.names[i].GetIdentifierName() << "': '" + << sniffer_result.return_types[i].ToString() << "'"; if (i != sniffer_result.return_types.size() - 1) { columns << separator; } diff --git a/src/duckdb/src/function/table/summary.cpp b/src/duckdb/src/function/table/summary.cpp index f21963e45..4c8e3a223 100644 --- a/src/duckdb/src/function/table/summary.cpp +++ b/src/duckdb/src/function/table/summary.cpp @@ -22,8 +22,6 @@ static unique_ptr SummaryFunctionBind(ClientContext &context, Tabl static OperatorResultType SummaryFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, DataChunk &output) { - output.SetCardinality(input.size()); - auto &summary_col = output.data[0]; for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { string summary_val = "["; diff --git a/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp b/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp index 5e40586ef..e62a6a03e 100644 --- a/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp +++ b/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp @@ -31,7 +31,6 @@ void DuckDBApproxDatabaseCountFunction(ClientContext &context, TableFunctionInpu return; } output.data[0].Append(Value::UBIGINT(data.count)); - output.SetCardinality(1); data.finished = true; } diff --git a/src/duckdb/src/function/table/system/duckdb_columns.cpp b/src/duckdb/src/function/table/system/duckdb_columns.cpp index 
f6f680d65..9d71f3932 100644 --- a/src/duckdb/src/function/table/system/duckdb_columns.cpp +++ b/src/duckdb/src/function/table/system/duckdb_columns.cpp @@ -82,6 +82,12 @@ static unique_ptr DuckDBColumnsBind(ClientContext &context, TableF names.emplace_back("tags"); return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); + names.emplace_back("is_generated"); + return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("generation_expression"); + return_types.emplace_back(LogicalType::VARCHAR); + return nullptr; } @@ -112,6 +118,8 @@ class ColumnHelper { virtual bool IsNullable(idx_t col) = 0; virtual const Value ColumnComment(idx_t col) = 0; virtual const Value ColumnTags(idx_t col) = 0; + virtual bool IsGenerated(idx_t col) = 0; + virtual const Value GenerationExpression(idx_t col) = 0; void WriteColumns(idx_t start_col, idx_t end_col, DataChunk &output); }; @@ -157,6 +165,16 @@ class TableColumnHelper : public ColumnHelper { const Value ColumnTags(idx_t col) override { return Value::MAP(entry.GetColumn(LogicalIndex(col)).Tags()); } + bool IsGenerated(idx_t col) override { + return entry.GetColumn(LogicalIndex(col)).Generated(); + } + const Value GenerationExpression(idx_t col) override { + auto &column = entry.GetColumn(LogicalIndex(col)); + if (!column.Generated()) { + return Value(); + } + return Value(column.GeneratedExpression().ToString()); + } private: TableCatalogEntry &entry; @@ -211,10 +229,16 @@ class ViewColumnHelper : public ColumnHelper { InsertionOrderPreservingMap empty; return Value::MAP(empty); } + bool IsGenerated(idx_t col) override { + return false; + } + const Value GenerationExpression(idx_t col) override { + return Value(); + } private: ViewCatalogEntry &entry; - vector column_names; + vector column_names; vector types; bool bound_view = false; }; @@ -270,6 +294,10 @@ void ColumnHelper::WriteColumns(idx_t start_col, idx_t end_col, DataChunk &outpu auto &numeric_scale_col = output.data[17]; // 
tags, MAP(VARCHAR, VARCHAR) auto &tags = output.data[18]; + // is_generated, BOOLEAN + auto &is_generated = output.data[19]; + // generation_expression, VARCHAR + auto &generation_expression = output.data[20]; for (idx_t i = start_col; i < end_col; i++) { auto &entry = Entry(); @@ -345,6 +373,8 @@ void ColumnHelper::WriteColumns(idx_t start_col, idx_t end_col, DataChunk &outpu numeric_precision_radix_col.Append(numeric_precision_radix); numeric_scale_col.Append(numeric_scale); tags.Append(ColumnTags(i)); + is_generated.Append(Value::BOOLEAN(IsGenerated(i))); + generation_expression.Append(GenerationExpression(i)); } } @@ -369,7 +399,6 @@ static void DuckDBColumnsFunction(ClientContext &context, TableFunctionInput &da // Check to see if we are going to exceed the maximum index for a DataChunk if (index + (columns - column_offset) > STANDARD_VECTOR_SIZE) { idx_t column_limit = column_offset + (STANDARD_VECTOR_SIZE - index); - output.SetCardinality(STANDARD_VECTOR_SIZE); column_helper->WriteColumns(column_offset, column_limit, output); // Make the current column limit the column offset when we process the next chunk @@ -378,13 +407,13 @@ static void DuckDBColumnsFunction(ClientContext &context, TableFunctionInput &da } else { // Otherwise, write all of the columns from the current relation and // then move on to the next one. 
- output.SetCardinality(index + (columns - column_offset)); column_helper->WriteColumns(column_offset, columns, output); index += columns - column_offset; next++; column_offset = 0; } } + // WriteColumns appends to the child vectors - record the resulting cardinality on the chunk data.offset = next; data.column_offset = column_offset; } diff --git a/src/duckdb/src/function/table/system/duckdb_connection_count.cpp b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp index 5f329936b..33f49a187 100644 --- a/src/duckdb/src/function/table/system/duckdb_connection_count.cpp +++ b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp @@ -33,7 +33,6 @@ void DuckDBConnectionCountFunction(ClientContext &context, TableFunctionInput &d return; } output.data[0].Append(Value::UBIGINT(data.count)); - output.SetCardinality(1); data.finished = true; } diff --git a/src/duckdb/src/function/table/system/duckdb_constraints.cpp b/src/duckdb/src/function/table/system/duckdb_constraints.cpp index 247258ef6..a5c5acb16 100644 --- a/src/duckdb/src/function/table/system/duckdb_constraints.cpp +++ b/src/duckdb/src/function/table/system/duckdb_constraints.cpp @@ -38,7 +38,7 @@ struct DuckDBConstraintsData : public GlobalTableFunctionState { idx_t offset; idx_t constraint_offset; idx_t unique_constraint_offset; - case_insensitive_set_t constraint_names; + identifier_set_t constraint_names; }; static unique_ptr DuckDBConstraintsBind(ClientContext &context, TableFunctionBindInput &input, @@ -119,14 +119,14 @@ unique_ptr DuckDBConstraintsInit(ClientContext &contex struct ExtraConstraintInfo { vector column_indexes; - vector column_names; - string referenced_table; - vector referenced_columns; + vector column_names; + Identifier referenced_table; + vector referenced_columns; }; -void ExtractReferencedColumns(const ParsedExpression &root_expr, vector &result) { +void ExtractReferencedColumns(const ParsedExpression &root_expr, vector &result) { 
ParsedExpressionIterator::VisitExpression( - root_expr, [&](const ColumnRefExpression &colref) { result.push_back(colref.GetColumnName()); }); + root_expr, [&](const ColumnRefExpression &colref) { result.emplace_back(colref.GetColumnName()); }); } ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const Constraint &constraint) { @@ -154,7 +154,7 @@ ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const case ConstraintType::FOREIGN_KEY: { auto &fk = constraint.Cast(); result.referenced_columns = fk.pk_columns; - result.referenced_table = fk.info.table; + result.referenced_table = Identifier(fk.info.table.GetIdentifierName()); result.column_names = fk.fk_columns; break; } @@ -164,12 +164,13 @@ ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const if (result.column_indexes.empty()) { // generate column indexes from names for (auto &name : result.column_names) { - result.column_indexes.push_back(table.GetColumnIndex(name)); + auto col_name = name; + result.column_indexes.push_back(table.GetColumnIndex(col_name)); } } else { // generate names from column indexes for (auto &index : result.column_indexes) { - result.column_names.push_back(table.GetColumn(index).GetName()); + result.column_names.emplace_back(table.GetColumn(index).GetName()); } } return result; @@ -178,10 +179,10 @@ ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const string GetConstraintName(const TableCatalogEntry &table, Constraint &constraint, const ExtraConstraintInfo &info) { string result = table.name + "_"; for (auto &col : info.column_names) { - result += StringUtil::Lower(col) + "_"; + result += StringUtil::Lower(col.GetIdentifierName()) + "_"; } for (auto &col : info.referenced_columns) { - result += StringUtil::Lower(col) + "_"; + result += StringUtil::Lower(col.GetIdentifierName()) + "_"; } switch (constraint.type) { case ConstraintType::CHECK: @@ -288,10 +289,10 @@ void 
DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ auto info = GetExtraConstraintInfo(table, *constraint); auto constraint_name = GetConstraintName(table, *constraint, info); - if (data.constraint_names.find(constraint_name) != data.constraint_names.end()) { + if (data.constraint_names.find(Identifier(constraint_name)) != data.constraint_names.end()) { // duplicate constraint name idx_t index = 2; - while (data.constraint_names.find(constraint_name + "_" + to_string(index)) != + while (data.constraint_names.find(Identifier(constraint_name + "_" + to_string(index))) != data.constraint_names.end()) { index++; } @@ -333,7 +334,6 @@ void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_ data.offset++; } } - output.SetCardinality(count); } void DuckDBConstraintsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_coordinate_systems.cpp b/src/duckdb/src/function/table/system/duckdb_coordinate_systems.cpp index f09ca08c8..46a52d4af 100644 --- a/src/duckdb/src/function/table/system/duckdb_coordinate_systems.cpp +++ b/src/duckdb/src/function/table/system/duckdb_coordinate_systems.cpp @@ -122,7 +122,6 @@ static void DuckDBCoordinateSystemsFunction(ClientContext &context, TableFunctio count++; } - output.SetCardinality(count); } void DuckDBCoordinateSystemsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_databases.cpp b/src/duckdb/src/function/table/system/duckdb_databases.cpp index c62398515..4b2abbbb0 100644 --- a/src/duckdb/src/function/table/system/duckdb_databases.cpp +++ b/src/duckdb/src/function/table/system/duckdb_databases.cpp @@ -131,7 +131,6 @@ void DuckDBDatabasesFunction(ClientContext &context, TableFunctionInput &data_p, count++; } - output.SetCardinality(count); } void DuckDBDatabasesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_dependencies.cpp 
b/src/duckdb/src/function/table/system/duckdb_dependencies.cpp index 24b5bc4d4..40bc783c1 100644 --- a/src/duckdb/src/function/table/system/duckdb_dependencies.cpp +++ b/src/duckdb/src/function/table/system/duckdb_dependencies.cpp @@ -56,7 +56,7 @@ unique_ptr DuckDBDependenciesInit(ClientContext &conte auto result = make_uniq(); // scan all the schemas and collect them - auto &catalog = Catalog::GetCatalog(context, INVALID_CATALOG); + auto &catalog = Catalog::GetCatalog(context, Identifier::InvalidCatalog()); auto dependency_manager = catalog.GetDependencyManager(); if (dependency_manager) { dependency_manager->Scan( @@ -113,7 +113,6 @@ void DuckDBDependenciesFunction(ClientContext &context, TableFunctionInput &data data.offset++; count++; } - output.SetCardinality(count); } void DuckDBDependenciesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_eviction_queues.cpp b/src/duckdb/src/function/table/system/duckdb_eviction_queues.cpp index d2ce0df3c..0e1c29288 100644 --- a/src/duckdb/src/function/table/system/duckdb_eviction_queues.cpp +++ b/src/duckdb/src/function/table/system/duckdb_eviction_queues.cpp @@ -65,7 +65,6 @@ void DuckDBEvictionQueuesFunction(ClientContext &context, TableFunctionInput &da total_insertions.Append(Value::BIGINT(UnsafeNumericCast(entry.total_insertions))); count++; } - output.SetCardinality(count); } void DuckDBEvictionQueuesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_extensions.cpp b/src/duckdb/src/function/table/system/duckdb_extensions.cpp index 101678c88..83529d8b0 100644 --- a/src/duckdb/src/function/table/system/duckdb_extensions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_extensions.cpp @@ -87,7 +87,7 @@ unique_ptr DuckDBExtensionsInit(ClientContext &context extension.statically_loaded ? 
ExtensionInstallMode::STATICALLY_LINKED : ExtensionInstallMode::NOT_INSTALLED; info.description = extension.description; for (idx_t k = 0; k < alias_count; k++) { - auto alias = ExtensionHelper::GetExtensionAlias(k); + auto alias = ExtensionHelper::GetInternalExtensionAlias(k); if (info.name == alias.extension) { info.aliases.emplace_back(alias.alias); } @@ -230,7 +230,6 @@ void DuckDBExtensionsFunction(ClientContext &context, TableFunctionInput &data_p data.offset++; count++; } - output.SetCardinality(count); } void DuckDBExtensionsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_external_file_cache.cpp b/src/duckdb/src/function/table/system/duckdb_external_file_cache.cpp index c0a76b544..3a87e393c 100644 --- a/src/duckdb/src/function/table/system/duckdb_external_file_cache.cpp +++ b/src/duckdb/src/function/table/system/duckdb_external_file_cache.cpp @@ -62,7 +62,6 @@ void DuckDBExternalFileCacheFunction(ClientContext &context, TableFunctionInput loaded.Append(Value::BOOLEAN(entry.loaded)); count++; } - output.SetCardinality(count); } void DuckDBExternalFileCacheFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp index 51e14c5ff..8f19feefa 100644 --- a/src/duckdb/src/function/table/system/duckdb_functions.cpp +++ b/src/duckdb/src/function/table/system/duckdb_functions.cpp @@ -779,7 +779,6 @@ void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, } count++; } - output.SetCardinality(count); } void DuckDBFunctionsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_indexes.cpp b/src/duckdb/src/function/table/system/duckdb_indexes.cpp index 126ec96a1..71a63742b 100644 --- a/src/duckdb/src/function/table/system/duckdb_indexes.cpp +++ b/src/duckdb/src/function/table/system/duckdb_indexes.cpp @@ -154,7 +154,6 @@ void 
DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, D count++; } - output.SetCardinality(count); } void DuckDBIndexesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_keywords.cpp b/src/duckdb/src/function/table/system/duckdb_keywords.cpp index e390a1346..3b0d48b29 100644 --- a/src/duckdb/src/function/table/system/duckdb_keywords.cpp +++ b/src/duckdb/src/function/table/system/duckdb_keywords.cpp @@ -71,7 +71,6 @@ void DuckDBKeywordsFunction(ClientContext &context, TableFunctionInput &data_p, count++; } - output.SetCardinality(count); } void DuckDBKeywordsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_memory.cpp b/src/duckdb/src/function/table/system/duckdb_memory.cpp index 8225a2021..f2cce56cd 100644 --- a/src/duckdb/src/function/table/system/duckdb_memory.cpp +++ b/src/duckdb/src/function/table/system/duckdb_memory.cpp @@ -62,7 +62,6 @@ void DuckDBMemoryFunction(ClientContext &context, TableFunctionInput &data_p, Da temporary_storage_bytes.Append(Value::BIGINT(ClampReportedMemory(entry.evicted_data))); count++; } - output.SetCardinality(count); } void DuckDBMemoryFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_metrics.cpp b/src/duckdb/src/function/table/system/duckdb_metrics.cpp new file mode 100644 index 000000000..b94648ec7 --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_metrics.cpp @@ -0,0 +1,66 @@ +#include "duckdb/function/table/system_functions.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/metrics_manager.hpp" + +namespace duckdb { + +struct DuckDBMetricsData : public GlobalTableFunctionState { + DuckDBMetricsData() : offset(0) { + } + + vector metrics; + idx_t offset; +}; + +static unique_ptr DuckDBMetricsBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + 
names.emplace_back("metric_name"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("metric_type"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("description"); + return_types.emplace_back(LogicalType::VARCHAR); + + names.emplace_back("unit"); + return_types.emplace_back(LogicalType::VARCHAR); + + return nullptr; +} + +unique_ptr DuckDBMetricsInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + result->metrics = MetricsManager::Get(context).GetAllMetrics(); + return std::move(result); +} + +void DuckDBMetricsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.offset >= data.metrics.size()) { + return; + } + + idx_t count = 0; + auto &metric_name = output.data[0]; + auto &metric_type = output.data[1]; + auto &description = output.data[2]; + auto &unit = output.data[3]; + + while (data.offset < data.metrics.size() && count < STANDARD_VECTOR_SIZE) { + auto &entry = data.metrics[data.offset++]; + + metric_name.Append(Value(entry.name)); + metric_type.Append(Value(entry.metric_type)); + description.Append(Value(entry.description)); + unit.Append(Value(entry.unit)); + count++; + } +} + +void DuckDBMetricsFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction( + TableFunction("duckdb_available_metrics", {}, DuckDBMetricsFunction, DuckDBMetricsBind, DuckDBMetricsInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_optimizers.cpp b/src/duckdb/src/function/table/system/duckdb_optimizers.cpp index e49695158..3dd0f899c 100644 --- a/src/duckdb/src/function/table/system/duckdb_optimizers.cpp +++ b/src/duckdb/src/function/table/system/duckdb_optimizers.cpp @@ -48,7 +48,6 @@ void DuckDBOptimizersFunction(ClientContext &context, TableFunctionInput &data_p name.Append(Value(entry)); count++; } - output.SetCardinality(count); } void 
DuckDBOptimizersFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_prepared_statements.cpp b/src/duckdb/src/function/table/system/duckdb_prepared_statements.cpp index a9651a73b..3df7fd683 100644 --- a/src/duckdb/src/function/table/system/duckdb_prepared_statements.cpp +++ b/src/duckdb/src/function/table/system/duckdb_prepared_statements.cpp @@ -104,7 +104,6 @@ void DuckDBPreparedStatementsFunction(ClientContext &context, TableFunctionInput } count++; } - output.SetCardinality(count); } void DuckDBPreparedStatementsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_schemas.cpp b/src/duckdb/src/function/table/system/duckdb_schemas.cpp index bc0fa7668..b7e20c3c5 100644 --- a/src/duckdb/src/function/table/system/duckdb_schemas.cpp +++ b/src/duckdb/src/function/table/system/duckdb_schemas.cpp @@ -96,7 +96,6 @@ void DuckDBSchemasFunction(ClientContext &context, TableFunctionInput &data_p, D data.offset++; count++; } - output.SetCardinality(count); } void DuckDBSchemasFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_secret_types.cpp b/src/duckdb/src/function/table/system/duckdb_secret_types.cpp index 053910750..fe60d8365 100644 --- a/src/duckdb/src/function/table/system/duckdb_secret_types.cpp +++ b/src/duckdb/src/function/table/system/duckdb_secret_types.cpp @@ -64,7 +64,6 @@ void DuckDBSecretTypesFunction(ClientContext &context, TableFunctionInput &data_ count++; } - output.SetCardinality(count); } void DuckDBSecretTypesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_secrets.cpp b/src/duckdb/src/function/table/system/duckdb_secrets.cpp index 8847edaab..b698e9f8d 100644 --- a/src/duckdb/src/function/table/system/duckdb_secrets.cpp +++ b/src/duckdb/src/function/table/system/duckdb_secrets.cpp @@ -132,7 +132,6 @@ void DuckDBSecretsFunction(ClientContext 
&context, TableFunctionInput &data_p, D data.offset++; count++; } - output.SetCardinality(count); } void DuckDBSecretsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_sequences.cpp b/src/duckdb/src/function/table/system/duckdb_sequences.cpp index 6eae85bb5..441576a32 100644 --- a/src/duckdb/src/function/table/system/duckdb_sequences.cpp +++ b/src/duckdb/src/function/table/system/duckdb_sequences.cpp @@ -153,7 +153,6 @@ void DuckDBSequencesFunction(ClientContext &context, TableFunctionInput &data_p, count++; } - output.SetCardinality(count); } void DuckDBSequencesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_settings.cpp b/src/duckdb/src/function/table/system/duckdb_settings.cpp index 5a18f6e31..9184ae860 100644 --- a/src/duckdb/src/function/table/system/duckdb_settings.cpp +++ b/src/duckdb/src/function/table/system/duckdb_settings.cpp @@ -145,7 +145,6 @@ void DuckDBSettingsFunction(ClientContext &context, TableFunctionInput &data_p, aliases.Append(Value::LIST(LogicalType::VARCHAR, std::move(entry.aliases))); count++; } - output.SetCardinality(count); } void DuckDBSettingsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_tables.cpp b/src/duckdb/src/function/table/system/duckdb_tables.cpp index c1b585c09..2b31e44c4 100644 --- a/src/duckdb/src/function/table/system/duckdb_tables.cpp +++ b/src/duckdb/src/function/table/system/duckdb_tables.cpp @@ -172,7 +172,6 @@ void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, Da sql.Append(Value(table_info->ToString())); count++; } - output.SetCardinality(count); } void DuckDBTablesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp b/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp index 57390dfe8..690cf4da0 100644 --- 
a/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp +++ b/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp @@ -50,7 +50,6 @@ void DuckDBTemporaryFilesFunction(ClientContext &context, TableFunctionInput &da size.Append(Value::BIGINT(NumericCast(entry.size))); count++; } - output.SetCardinality(count); } void DuckDBTemporaryFilesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_triggers.cpp b/src/duckdb/src/function/table/system/duckdb_triggers.cpp index aacf15da0..ceab6f2b0 100644 --- a/src/duckdb/src/function/table/system/duckdb_triggers.cpp +++ b/src/duckdb/src/function/table/system/duckdb_triggers.cpp @@ -146,7 +146,6 @@ void DuckDBTriggersFunction(ClientContext &context, TableFunctionInput &data_p, count++; } - output.SetCardinality(count); } void DuckDBTriggersFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_types.cpp b/src/duckdb/src/function/table/system/duckdb_types.cpp index b918bda07..aaa1f77c5 100644 --- a/src/duckdb/src/function/table/system/duckdb_types.cpp +++ b/src/duckdb/src/function/table/system/duckdb_types.cpp @@ -215,7 +215,6 @@ void DuckDBTypesFunction(ClientContext &context, TableFunctionInput &data_p, Dat count++; } - output.SetCardinality(count); } void DuckDBTypesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_variables.cpp b/src/duckdb/src/function/table/system/duckdb_variables.cpp index 7b376f18d..4694de73c 100644 --- a/src/duckdb/src/function/table/system/duckdb_variables.cpp +++ b/src/duckdb/src/function/table/system/duckdb_variables.cpp @@ -76,7 +76,6 @@ void DuckDBVariablesFunction(ClientContext &context, TableFunctionInput &data_p, type.Append(Value(variable_entry.value.type().ToString())); count++; } - output.SetCardinality(count); } void DuckDBVariablesFun::RegisterFunction(BuiltinFunctions &set) { diff --git 
a/src/duckdb/src/function/table/system/duckdb_views.cpp b/src/duckdb/src/function/table/system/duckdb_views.cpp index 0a6cf7d93..8cf4c5a77 100644 --- a/src/duckdb/src/function/table/system/duckdb_views.cpp +++ b/src/duckdb/src/function/table/system/duckdb_views.cpp @@ -163,7 +163,6 @@ void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, Dat } count++; } - output.SetCardinality(count); } void DuckDBViewsFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/duckdb_which_secret.cpp b/src/duckdb/src/function/table/system/duckdb_which_secret.cpp index bd6cd8673..55ee22467 100644 --- a/src/duckdb/src/function/table/system/duckdb_which_secret.cpp +++ b/src/duckdb/src/function/table/system/duckdb_which_secret.cpp @@ -59,7 +59,6 @@ void DuckDBWhichSecretFunction(ClientContext &context, TableFunctionInput &data_ auto secret_match = secret_manager.LookupSecret(transaction, path, type); if (secret_match.HasMatch()) { auto &secret_entry = *secret_match.secret_entry; - output.SetCardinality(1); output.data[0].Append(Value(secret_entry.secret->GetName())); output.data[1].Append(Value(EnumUtil::ToString(secret_entry.persist_type))); output.data[2].Append(Value(secret_entry.storage_mode)); diff --git a/src/duckdb/src/function/table/system/enable_profiling.cpp b/src/duckdb/src/function/table/system/enable_profiling.cpp index 3ad530148..78506dc36 100644 --- a/src/duckdb/src/function/table/system/enable_profiling.cpp +++ b/src/duckdb/src/function/table/system/enable_profiling.cpp @@ -52,7 +52,11 @@ static void EnableProfiling(ClientContext &context, TableFunctionInput &data, Da } if (!bind_data.metrics.IsNull()) { - ConfigureProfilingSetting::SetLocal(context, bind_data.metrics); + Value metrics_value = bind_data.metrics; + if (metrics_value.type().id() == LogicalTypeId::VARCHAR) { + metrics_value = Value::LIST(LogicalType::VARCHAR, {metrics_value}); + } + TrackedMetricsSetting::SetLocal(context, metrics_value); } } @@ 
-67,7 +71,7 @@ static unique_ptr BindEnableProfiling(ClientContext &context, Tabl bool metrics_set = false; for (const auto &named_param : input.named_parameters) { - const auto key = EnumUtil::FromString(named_param.first); + const auto key = EnumUtil::FromString(named_param.first.GetIdentifierName()); switch (key) { case ProfilingParameterNames::FORMAT: bind_data->format = StringUtil::Lower(named_param.second.ToString()); @@ -83,9 +87,8 @@ static unique_ptr BindEnableProfiling(ClientContext &context, Tabl break; case ProfilingParameterNames::METRICS: { if (named_param.second.type() != LogicalType::LIST(LogicalType::VARCHAR) && - named_param.second.type().id() != LogicalTypeId::STRUCT && named_param.second.type() != LogicalType::VARCHAR) { - throw InvalidInputException("EnableProfiling: metrics must be a list of strings or a JSON string"); + throw InvalidInputException("EnableProfiling: metrics must be a list of strings or a VARCHAR pattern"); } bind_data->metrics = named_param.second; @@ -100,8 +103,8 @@ static unique_ptr BindEnableProfiling(ClientContext &context, Tabl throw InvalidInputException("EnableProfiling: cannot specify both metrics and positional parameters"); } if (input.inputs[0].type() != LogicalType::LIST(LogicalType::VARCHAR) && - input.inputs[0].type().id() != LogicalTypeId::STRUCT && input.inputs[0].type() != LogicalType::VARCHAR) { - throw InvalidInputException("EnableProfiling: metrics must be a list of strings or a JSON string"); + input.inputs[0].type() != LogicalType::VARCHAR) { + throw InvalidInputException("EnableProfiling: metrics must be a list of strings or a VARCHAR pattern"); } bind_data->metrics = input.inputs[0]; diff --git a/src/duckdb/src/function/table/system/logging_utils.cpp b/src/duckdb/src/function/table/system/logging_utils.cpp index 160816858..aa1c34e73 100644 --- a/src/duckdb/src/function/table/system/logging_utils.cpp +++ b/src/duckdb/src/function/table/system/logging_utils.cpp @@ -53,10 +53,9 @@ static unique_ptr 
BindEnableLogging(ClientContext &context, TableF auto result = make_uniq(); bool storage_isset = false; - bool storage_path_isset = false; for (const auto ¶m : input.named_parameters) { - auto key = StringUtil::Lower(param.first); + auto &key = param.first; if (key == "level") { result->config.level = EnumUtil::FromString(param.second.ToString()); } else if (key == "storage") { @@ -68,10 +67,10 @@ static unique_ptr BindEnableLogging(ClientContext &context, TableF } auto &children = StructValue::GetChildren(param.second); for (idx_t i = 0; i < children.size(); i++) { - result->storage_config[StructType::GetChildName(param.second.type(), i)] = children[i]; + result->storage_config[StructType::GetChildName(param.second.type(), i).GetIdentifierName()] = + children[i]; } } else if (key == "storage_path") { - storage_path_isset = true; result->storage_config["path"] = param.second; } else if (key == "storage_normalize") { result->storage_config["normalize"] = param.second; @@ -82,9 +81,15 @@ static unique_ptr BindEnableLogging(ClientContext &context, TableF } } - // This will implicitly set the log storage if the storage_path param is set and the storage is omitted - if (!storage_isset && storage_path_isset) { - result->config.storage = LogConfig::FILE_STORAGE_NAME; + // If the user didn't explicitly set storage=, infer it. A 'path' in storage_config (provided + // either via storage_path= or storage_config={path:...}) implies file storage. Otherwise + // preserve the currently configured storage so logging_storage isn't silently reset. 
+ if (!storage_isset) { + if (result->storage_config.find("path") != result->storage_config.end()) { + result->config.storage = LogConfig::FILE_STORAGE_NAME; + } else { + result->config.storage = context.db->GetLogManager().GetConfig().storage; + } } // Process positional params diff --git a/src/duckdb/src/function/table/system/pragma_collations.cpp b/src/duckdb/src/function/table/system/pragma_collations.cpp index 5d1f4f584..de16983cc 100644 --- a/src/duckdb/src/function/table/system/pragma_collations.cpp +++ b/src/duckdb/src/function/table/system/pragma_collations.cpp @@ -29,7 +29,7 @@ unique_ptr PragmaCollateInit(ClientContext &context, T auto schemas = Catalog::GetAllSchemas(context); for (auto schema : schemas) { schema.get().Scan(context, CatalogType::COLLATION_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.name); }); + [&](CatalogEntry &entry) { result->entries.push_back(entry.name.GetIdentifierName()); }); } return std::move(result); } @@ -41,7 +41,6 @@ static void PragmaCollateFunction(ClientContext &context, TableFunctionInput &da return; } idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, data.entries.size()); - output.SetCardinality(next - data.offset); // collname, VARCHAR auto &collname = output.data[0]; @@ -49,7 +48,6 @@ static void PragmaCollateFunction(ClientContext &context, TableFunctionInput &da for (idx_t i = data.offset; i < next; i++) { collname.Append(Value(data.entries[i])); } - data.offset = next; } diff --git a/src/duckdb/src/function/table/system/pragma_database_size.cpp b/src/duckdb/src/function/table/system/pragma_database_size.cpp index 257fbd808..0bd4a3f69 100644 --- a/src/duckdb/src/function/table/system/pragma_database_size.cpp +++ b/src/duckdb/src/function/table/system/pragma_database_size.cpp @@ -97,7 +97,6 @@ void PragmaDatabaseSizeFunction(ClientContext &context, TableFunctionInput &data memory_limit.Append(data.memory_limit); row++; } - output.SetCardinality(row); } void 
PragmaDatabaseSize::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/pragma_metadata_info.cpp b/src/duckdb/src/function/table/system/pragma_metadata_info.cpp index 526003ea0..2daa75c69 100644 --- a/src/duckdb/src/function/table/system/pragma_metadata_info.cpp +++ b/src/duckdb/src/function/table/system/pragma_metadata_info.cpp @@ -34,14 +34,14 @@ static unique_ptr PragmaMetadataInfoBind(ClientContext &context, T names.emplace_back("free_list"); return_types.emplace_back(LogicalType::LIST(LogicalType::BIGINT)); - string db_name; + Identifier db_name; if (input.inputs.empty()) { db_name = DatabaseManager::GetDefaultDatabase(context); } else { if (input.inputs[0].IsNull()) { throw BinderException("Database argument for pragma_metadata_info cannot be NULL"); } - db_name = StringValue::Get(input.inputs[0]); + db_name = Identifier(StringValue::Get(input.inputs[0])); } auto &catalog = Catalog::GetCatalog(context, db_name); auto result = make_uniq(); @@ -80,7 +80,6 @@ static void PragmaMetadataInfoFunction(ClientContext &context, TableFunctionInpu free_list.Append(Value::LIST(LogicalType::BIGINT, std::move(list_values))); count++; } - output.SetCardinality(count); } void PragmaMetadataInfo::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp index e4ab711bd..eb54a712a 100644 --- a/src/duckdb/src/function/table/system/pragma_storage_info.cpp +++ b/src/duckdb/src/function/table/system/pragma_storage_info.cpp @@ -2,35 +2,47 @@ #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" #include "duckdb/parser/qualified_name.hpp" -#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" -#include "duckdb/planner/constraints/bound_unique_constraint.hpp" -#include "duckdb/common/exception.hpp" 
-#include "duckdb/common/limits.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/execution/partition_info.hpp" +#include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/storage/table/row_group_collection.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/storage/table/column_data.hpp" - -#include +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" namespace duckdb { struct PragmaStorageFunctionData : public TableFunctionData { - explicit PragmaStorageFunctionData(TableCatalogEntry &table_entry) : table_entry(table_entry) { + PragmaStorageFunctionData(TableCatalogEntry &table_entry, ColumnSegmentInfoScanOptions options) + : table_entry(table_entry), options(options) { } TableCatalogEntry &table_entry; - vector column_segments_info; + ColumnSegmentInfoScanOptions options; }; -struct PragmaStorageOperatorData : public GlobalTableFunctionState { - PragmaStorageOperatorData() : offset(0) { +struct PragmaStorageGlobalState : public GlobalTableFunctionState { + explicit PragmaStorageGlobalState(idx_t max_threads_p) : max_threads(max_threads_p) { + } + + mutex lock; + idx_t max_threads; + ColumnSegmentInfoScanState scan_state; + + idx_t MaxThreads() const override { + return max_threads; } +}; - idx_t offset; +struct PragmaStorageLocalState : public LocalTableFunctionState { + //! Buffered column segment info for the row group this thread is currently emitting. + vector buffer; + idx_t buffer_offset = 0; + //! Index of the row group currently being emitted; used as the batch index for partitioning. 
+ idx_t batch_index = 0; }; static unique_ptr PragmaStorageInfoBind(ClientContext &context, TableFunctionBindInput &input, @@ -77,8 +89,21 @@ static unique_ptr PragmaStorageInfoBind(ClientContext &context, Ta names.emplace_back("block_offset"); return_types.emplace_back(LogicalType::BIGINT); - names.emplace_back("segment_info"); - return_types.emplace_back(LogicalType::VARCHAR); + ColumnSegmentInfoScanOptions options; + auto include_entry = input.named_parameters.find("include_segment_info"); + if (include_entry != input.named_parameters.end()) { + options.include_segment_info = include_entry->second.GetValue(); + } + auto loaded_only_entry = input.named_parameters.find("loaded_segments_only"); + if (loaded_only_entry != input.named_parameters.end()) { + options.loaded_segments_only = loaded_only_entry->second.GetValue(); + } + + // segment_info is only emitted when the (expensive) per-segment info was actually collected. + if (options.include_segment_info) { + names.emplace_back("segment_info"); + return_types.emplace_back(LogicalType::VARCHAR); + } names.emplace_back("additional_block_ids"); return_types.emplace_back(LogicalType::LIST(LogicalTypeId::BIGINT)); @@ -88,13 +113,22 @@ static unique_ptr PragmaStorageInfoBind(ClientContext &context, Ta // look up the table name in the catalog Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); - auto result = make_uniq(table_entry); - result->column_segments_info = table_entry.GetColumnSegmentInfo(context); - return std::move(result); + return make_uniq(table_entry, options); } -unique_ptr PragmaStorageInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); +unique_ptr PragmaStorageInfoInitGlobal(ClientContext &context, + TableFunctionInitInput &input) { + auto &bind_data = input.bind_data->Cast(); + auto max_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); 
+ auto gstate = make_uniq(max_threads); + gstate->scan_state.options = bind_data.options; + bind_data.table_entry.InitializeColumnSegmentInfoScan(gstate->scan_state); + return std::move(gstate); +} + +unique_ptr PragmaStorageInfoInitLocal(ExecutionContext &context, TableFunctionInitInput &input, + GlobalTableFunctionState *global_state) { + return make_uniq(); } static Value ValueFromBlockIdList(const vector &block_ids) { @@ -107,45 +141,57 @@ static Value ValueFromBlockIdList(const vector &block_ids) { static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { auto &bind_data = data_p.bind_data->Cast(); - auto &data = data_p.global_state->Cast(); - idx_t count = 0; + auto &gstate = data_p.global_state->Cast(); + auto &lstate = data_p.local_state->Cast(); auto &columns = bind_data.table_entry.GetColumns(); + QueryContext query_context(context); + + idx_t count = 0; + + idx_t col_idx = 0; + auto &row_group_id = output.data[col_idx++]; + auto &column_name = output.data[col_idx++]; + auto &column_id = output.data[col_idx++]; + auto &column_path = output.data[col_idx++]; + auto &segment_id = output.data[col_idx++]; + auto &segment_type = output.data[col_idx++]; + auto &start = output.data[col_idx++]; + auto &count_col = output.data[col_idx++]; + auto &compression = output.data[col_idx++]; + auto &stats = output.data[col_idx++]; + auto &has_updates = output.data[col_idx++]; + auto &persistent = output.data[col_idx++]; + auto &block_id = output.data[col_idx++]; + auto &block_offset = output.data[col_idx++]; + optional_ptr segment_info = bind_data.options.include_segment_info ? 
&output.data[col_idx++] : nullptr; + auto &additional_block_ids = output.data[col_idx++]; + + while (count < STANDARD_VECTOR_SIZE) { + if (lstate.buffer_offset >= lstate.buffer.size()) { + // drained the current buffer; pull the next row group under the global lock + lstate.buffer.clear(); + lstate.buffer_offset = 0; + bool has_more; + { + lock_guard guard(gstate.lock); + has_more = bind_data.table_entry.ScanColumnSegmentInfo(query_context, gstate.scan_state, lstate.buffer); + } + if (!has_more) { + break; + } + if (lstate.buffer.empty()) { + continue; + } + } - // row_group_id - auto &row_group_id = output.data[0]; - // column_name - auto &column_name = output.data[1]; - // column_id - auto &column_id = output.data[2]; - // column_path - auto &column_path = output.data[3]; - // segment_id - auto &segment_id = output.data[4]; - // segment_type - auto &segment_type = output.data[5]; - // start - auto &start = output.data[6]; - // count - auto &count_col = output.data[7]; - // compression - auto &compression = output.data[8]; - // stats - auto &stats = output.data[9]; - // has_updates - auto &has_updates = output.data[10]; - // persistent - auto &persistent = output.data[11]; - // block_id - auto &block_id = output.data[12]; - // block_offset - auto &block_offset = output.data[13]; - // segment_info - auto &segment_info = output.data[14]; - // additional_block_ids - auto &additional_block_ids = output.data[15]; - - while (data.offset < bind_data.column_segments_info.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = bind_data.column_segments_info[data.offset++]; + auto &entry = lstate.buffer[lstate.buffer_offset]; + // keep each output chunk to a single row group so batch_index identifies its contents. + // merging is allowed when the next entry happens to belong to the same row group. 
+ if (count > 0 && entry.row_group_index != lstate.batch_index) { + break; + } + lstate.buffer_offset++; + lstate.batch_index = entry.row_group_index; row_group_id.Append(Value::BIGINT(NumericCast(entry.row_group_index))); auto &col = columns.GetColumn(PhysicalIndex(entry.column_id)); @@ -167,7 +213,9 @@ static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput block_id.Append(Value()); block_offset.Append(Value()); } - segment_info.Append(Value(entry.segment_info)); + if (segment_info) { + segment_info->Append(Value(entry.segment_info)); + } if (entry.persistent) { additional_block_ids.Append(ValueFromBlockIdList(entry.additional_blocks)); } else { @@ -176,12 +224,21 @@ static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput count++; } - output.SetCardinality(count); +} + +static OperatorPartitionData PragmaStorageInfoGetPartitionData(ClientContext &context, + TableFunctionGetPartitionInput &input) { + auto &lstate = input.local_state->Cast(); + return OperatorPartitionData(lstate.batch_index); } void PragmaStorageInfo::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_storage_info", {LogicalType::VARCHAR}, PragmaStorageInfoFunction, - PragmaStorageInfoBind, PragmaStorageInfoInit)); + TableFunction storage_info("pragma_storage_info", {LogicalType::VARCHAR}, PragmaStorageInfoFunction, + PragmaStorageInfoBind, PragmaStorageInfoInitGlobal, PragmaStorageInfoInitLocal); + storage_info.get_partition_data = PragmaStorageInfoGetPartitionData; + storage_info.named_parameters["include_segment_info"] = LogicalType::BOOLEAN; + storage_info.named_parameters["loaded_segments_only"] = LogicalType::BOOLEAN; + set.AddFunction(std::move(storage_info)); } } // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_table_info.cpp b/src/duckdb/src/function/table/system/pragma_table_info.cpp index 93312624b..0d8bf9eb6 100644 --- a/src/duckdb/src/function/table/system/pragma_table_info.cpp 
+++ b/src/duckdb/src/function/table/system/pragma_table_info.cpp @@ -85,7 +85,7 @@ struct PragmaTableInfoHelper { output.data[5].Append(Value::BOOLEAN(constraint_info.pk)); } - static void GetViewColumns(idx_t i, const string &name, const LogicalType &type, DataChunk &output) { + static void GetViewColumns(idx_t i, const Identifier &name, const LogicalType &type, DataChunk &output) { // "cid", INTEGER output.data[0].Append(Value::INTEGER((int32_t)i)); // "name", VARCHAR @@ -142,7 +142,7 @@ struct PragmaShowHelper { output.data[5].Append(Value()); } - static void GetViewColumns(idx_t i, const string &name, const LogicalType &type, DataChunk &output) { + static void GetViewColumns(idx_t i, const Identifier &name, const LogicalType &type, DataChunk &output) { // "column_name", VARCHAR output.data[0].Append(Value(name)); // "column_type", VARCHAR @@ -227,7 +227,6 @@ static void PragmaTableInfoTable(PragmaTableOperatorData &data, TableCatalogEntr // start returning values // either fill up the chunk or return all the remaining columns idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, table.GetColumns().LogicalColumnCount()); - output.SetCardinality(next - data.offset); for (idx_t i = data.offset; i < next; i++) { auto &column = table.GetColumn(LogicalIndex(i)); @@ -257,11 +256,10 @@ static void PragmaTableInfoView(ClientContext &context, PragmaTableOperatorData // start returning values // either fill up the chunk or return all the remaining columns idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, view_types.size()); - output.SetCardinality(next - data.offset); for (idx_t i = data.offset; i < next; i++) { auto type = view_types[i]; - auto &name = i < view.aliases.size() ? view.aliases[i] : view_names[i]; + auto name = i < view.aliases.size() ? 
view.aliases[i] : view_names[i]; if (is_table_info) { PragmaTableInfoHelper::GetViewColumns(i, name, type, output); diff --git a/src/duckdb/src/function/table/system/pragma_table_sample.cpp b/src/duckdb/src/function/table/system/pragma_table_sample.cpp index cf5a9ccfb..b8b04ad15 100644 --- a/src/duckdb/src/function/table/system/pragma_table_sample.cpp +++ b/src/duckdb/src/function/table/system/pragma_table_sample.cpp @@ -48,7 +48,7 @@ static unique_ptr DuckDBTableSampleBind(ClientContext &context, Ta for (idx_t i = 0; i < types.size(); i++) { auto logical_index = LogicalIndex(i); auto &col = table_entry.GetColumn(logical_index); - names.push_back(col.GetName()); + names.emplace_back(col.GetName().GetIdentifierName()); } return make_uniq(entry); diff --git a/src/duckdb/src/function/table/system/pragma_user_agent.cpp b/src/duckdb/src/function/table/system/pragma_user_agent.cpp index f40db7b67..8b6003906 100644 --- a/src/duckdb/src/function/table/system/pragma_user_agent.cpp +++ b/src/duckdb/src/function/table/system/pragma_user_agent.cpp @@ -35,7 +35,6 @@ void PragmaUserAgentFunction(ClientContext &context, TableFunctionInput &data_p, return; } - output.SetCardinality(1); output.data[0].Append(Value(data.user_agent)); data.finished = true; diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index e09d8a1b0..9357a9c6d 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -185,36 +185,36 @@ vector TestAllTypesFun::GetTestTypes(const bool use_large_enum, const // structs child_list_t struct_type_list; - struct_type_list.push_back(make_pair("a", LogicalType::INTEGER)); - struct_type_list.push_back(make_pair("b", LogicalType::VARCHAR)); + struct_type_list.emplace_back(make_pair("a", LogicalType::INTEGER)); + struct_type_list.emplace_back(make_pair("b", LogicalType::VARCHAR)); auto struct_type = 
LogicalType::STRUCT(struct_type_list); child_list_t min_struct_list; - min_struct_list.push_back(make_pair("a", Value(LogicalType::INTEGER))); - min_struct_list.push_back(make_pair("b", Value(LogicalType::VARCHAR))); + min_struct_list.emplace_back(make_pair("a", Value(LogicalType::INTEGER))); + min_struct_list.emplace_back(make_pair("b", Value(LogicalType::VARCHAR))); auto min_struct_val = Value::STRUCT(std::move(min_struct_list)); child_list_t max_struct_list; - max_struct_list.push_back(make_pair("a", Value::INTEGER(42))); - max_struct_list.push_back(make_pair("b", Value("🦆🦆🦆🦆🦆🦆"))); + max_struct_list.emplace_back(make_pair("a", Value::INTEGER(42))); + max_struct_list.emplace_back(make_pair("b", Value("🦆🦆🦆🦆🦆🦆"))); auto max_struct_val = Value::STRUCT(std::move(max_struct_list)); result.emplace_back(struct_type, "struct", min_struct_val, max_struct_val); // structs with lists child_list_t struct_list_type_list; - struct_list_type_list.push_back(make_pair("a", int_list_type)); - struct_list_type_list.push_back(make_pair("b", varchar_list_type)); + struct_list_type_list.emplace_back(make_pair("a", int_list_type)); + struct_list_type_list.emplace_back(make_pair("b", varchar_list_type)); auto struct_list_type = LogicalType::STRUCT(struct_list_type_list); child_list_t min_struct_vl_list; - min_struct_vl_list.push_back(make_pair("a", Value(int_list_type))); - min_struct_vl_list.push_back(make_pair("b", Value(varchar_list_type))); + min_struct_vl_list.emplace_back(make_pair("a", Value(int_list_type))); + min_struct_vl_list.emplace_back(make_pair("b", Value(varchar_list_type))); auto min_struct_val_list = Value::STRUCT(std::move(min_struct_vl_list)); child_list_t max_struct_vl_list; - max_struct_vl_list.push_back(make_pair("a", int_list)); - max_struct_vl_list.push_back(make_pair("b", varchar_list)); + max_struct_vl_list.emplace_back(make_pair("a", int_list)); + max_struct_vl_list.emplace_back(make_pair("b", varchar_list)); auto max_struct_val_list = 
Value::STRUCT(std::move(max_struct_vl_list)); result.emplace_back(struct_list_type, "struct_of_arrays", std::move(min_struct_val_list), @@ -232,11 +232,11 @@ vector TestAllTypesFun::GetTestTypes(const bool use_large_enum, const auto min_map_value = Value::MAP(ListType::GetChildType(map_type), vector()); child_list_t map_struct1; - map_struct1.push_back(make_pair("key", Value("key1"))); - map_struct1.push_back(make_pair("value", Value("🦆🦆🦆🦆🦆🦆"))); + map_struct1.emplace_back(make_pair("key", Value("key1"))); + map_struct1.emplace_back(make_pair("value", Value("🦆🦆🦆🦆🦆🦆"))); child_list_t map_struct2; - map_struct2.push_back(make_pair("key", Value("key2"))); - map_struct2.push_back(make_pair("value", Value("goose"))); + map_struct2.emplace_back(make_pair("key", Value("key2"))); + map_struct2.emplace_back(make_pair("value", Value("goose"))); vector map_values; map_values.push_back(Value::STRUCT(map_struct1)); @@ -428,7 +428,6 @@ static void TestAllTypesFunction(ClientContext &context, TableFunctionInput &dat } count++; } - output.SetCardinality(count); } void TestAllTypesFun::RegisterFunction(BuiltinFunctions &set) { diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp index 2d0623fb0..a0c736dc1 100644 --- a/src/duckdb/src/function/table/system/test_vector_types.cpp +++ b/src/duckdb/src/function/table/system/test_vector_types.cpp @@ -133,7 +133,6 @@ struct TestVectorFlat { result->data[c].Append(result_values.GetValue(cur_row + i, c)); } } - result->SetChildCardinality(cardinality); info.entries.push_back(std::move(result)); } } @@ -149,7 +148,6 @@ struct TestVectorConstant { for (idx_t c = 0; c < info.types.size(); c++) { result->data[c].Reference(values.GetValue(0, c), count_t(cardinality)); } - result->SetCardinality(cardinality); info.entries.push_back(std::move(result)); } @@ -298,7 +296,8 @@ unique_ptr TestVectorTypesInit(ClientContext &context, map test_type_map; for (auto &test_type : 
test_types) { - test_type_map.insert(make_pair(test_type.type.id(), std::move(test_type))); + auto type_id = test_type.type.id(); + test_type_map.insert(make_pair(type_id, std::move(test_type))); } TestVectorInfo info(bind_data.types, test_type_map, result->entries); diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index 056b3aa80..6eebbfa68 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -36,6 +36,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { DuckDBMemoryFun::RegisterFunction(*this); DuckDBEvictionQueuesFun::RegisterFunction(*this); DuckDBExternalFileCacheFun::RegisterFunction(*this); + DuckDBMetricsFun::RegisterFunction(*this); DuckDBOptimizersFun::RegisterFunction(*this); DuckDBSecretsFun::RegisterFunction(*this); DuckDBWhichSecretFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index 4c5a039e8..b480e98dc 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -26,9 +26,8 @@ #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/local_storage.hpp" #include "duckdb/storage/data_table.hpp" @@ -47,8 +46,8 @@ struct TableScanLocalState : public LocalTableFunctionState { //! This includes filter columns, which are immediately removed. 
DataChunk all_columns; - idx_t rows_scanned = 0; idx_t rows_in_current_row_group = 0; + idx_t row_groups_scanned = 0; }; struct IndexScanLocalState : public LocalTableFunctionState { @@ -73,11 +72,15 @@ class TableScanGlobalState : public GlobalTableFunctionState { D_ASSERT(bind_data_p); auto &bind_data = bind_data_p->Cast(); auto &duck_table = bind_data.table.Cast(); - max_threads = duck_table.GetStorage().MaxThreads(context); + auto &storage = duck_table.GetStorage(); + max_threads = storage.MaxThreads(context); + total_row_groups_to_scan = storage.GetRowGroupCountWithLocalStorage(context); } //! The maximum number of threads for this table scan. idx_t max_threads; + //! The total number of row groups available to this table scan. + idx_t total_row_groups_to_scan; //! The projected columns of this table scan. vector projection_ids; //! The types of all scanned columns. @@ -91,6 +94,7 @@ class TableScanGlobalState : public GlobalTableFunctionState { virtual OperatorPartitionData TableScanGetPartitionData(ClientContext &context, TableFunctionGetPartitionInput &input) = 0; virtual idx_t TableScanRowsScanned(LocalTableFunctionState &state) = 0; + virtual idx_t TableScanRowGroupsScanned(LocalTableFunctionState &state) = 0; idx_t MaxThreads() const override { return max_threads; @@ -259,6 +263,10 @@ class DuckIndexScanState : public TableScanGlobalState { auto &l_state = state.Cast(); return l_state.rows_scanned; } + + idx_t TableScanRowGroupsScanned(LocalTableFunctionState &) override { + return 0; + } }; class DuckTableScanState : public TableScanGlobalState { @@ -290,13 +298,18 @@ class DuckTableScanState : public TableScanGlobalState { } if (bind_data.order_options) { - l_state->scan_state.table_state.reorderer = make_uniq(*bind_data.order_options); - l_state->scan_state.local_state.reorderer = make_uniq(*bind_data.order_options); + l_state->scan_state.table_state.reorderer = + make_uniq(*bind_data.order_options, TransactionData(tx)); + 
l_state->scan_state.local_state.reorderer = + make_uniq(*bind_data.order_options, TransactionData(tx)); } l_state->scan_state.Initialize(std::move(storage_ids), context.client, input.filters, input.sample_options); l_state->rows_in_current_row_group = storage.NextParallelScan(context.client, state, l_state->scan_state); + if (l_state->rows_in_current_row_group > 0) { + l_state->row_groups_scanned++; + } if (input.CanRemoveFilterColumns()) { l_state->all_columns.Initialize(context.client, scanned_types); } @@ -323,9 +336,10 @@ class DuckTableScanState : public TableScanGlobalState { return; } - // We have fully processed a row group. Add to scanned_rows - l_state.rows_scanned += l_state.rows_in_current_row_group; l_state.rows_in_current_row_group = storage.NextParallelScan(context, state, l_state.scan_state); + if (l_state.rows_in_current_row_group > 0) { + l_state.row_groups_scanned++; + } if (data_p.results_execution_mode == AsyncResultsExecutionMode::TASK_EXECUTOR) { // We can avoid looping, and just return as appropriate @@ -376,8 +390,13 @@ class DuckTableScanState : public TableScanGlobalState { } idx_t TableScanRowsScanned(LocalTableFunctionState &state) override { + const auto &l_state = state.Cast(); + return l_state.scan_state.table_state.rows_scanned + l_state.scan_state.local_state.rows_scanned; + } + + idx_t TableScanRowGroupsScanned(LocalTableFunctionState &state) override { auto &l_state = state.Cast(); - return l_state.rows_scanned; + return l_state.row_groups_scanned; } }; @@ -391,8 +410,12 @@ unique_ptr DuckTableScanInitGlobal(ClientContext &cont DataTable &storage, const TableScanBindData &bind_data) { auto g_state = make_uniq(context, input.bind_data.get()); if (bind_data.order_options) { - g_state->state.scan_state.reorderer = make_uniq(*bind_data.order_options); - g_state->state.local_state.reorderer = make_uniq(*bind_data.order_options); + auto transaction = TransactionData(DuckTransaction::Get(context, storage.GetAttached())); + 
g_state->state.scan_state.reorderer = make_uniq(*bind_data.order_options, transaction); + g_state->state.local_state.reorderer = make_uniq(*bind_data.order_options, transaction); + } + if (bind_data.partitions_to_scan) { + g_state->state.scan_state.partitions_to_scan = bind_data.partitions_to_scan.get(); } // Check if row_number column is requested and initialize row_number_base @@ -476,14 +499,14 @@ static bool CollectValuesAndComparisonsFromExpression(const Expression &expr, va if (expr.GetExpressionClass() == ExpressionClass::BOUND_OPERATOR && expr.GetExpressionType() == ExpressionType::COMPARE_IN) { auto &op = expr.Cast(); - if (op.children.empty() || op.children[0]->GetExpressionClass() != ExpressionClass::BOUND_REF) { + if (op.GetChildren().empty() || op.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_REF) { return false; } - for (idx_t i = 1; i < op.children.size(); i++) { - if (op.children[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + for (idx_t i = 1; i < op.GetChildren().size(); i++) { + if (op.GetChildren()[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { return false; } - auto &value = op.children[i]->Cast().value; + auto &value = op.GetChildren()[i]->Cast().GetValue(); if (!value.IsNull()) { in_values.insert(value); } @@ -498,9 +521,9 @@ static bool CollectValuesAndComparisonsFromExpression(const Expression &expr, va bool left_is_ref = left.GetExpressionClass() == ExpressionClass::BOUND_REF; bool right_is_ref = right.GetExpressionClass() == ExpressionClass::BOUND_REF; if (right.GetExpressionType() == ExpressionType::VALUE_CONSTANT && left_is_ref) { - val = right.Cast().value; + val = right.Cast().GetValue(); } else if (left.GetExpressionType() == ExpressionType::VALUE_CONSTANT && right_is_ref) { - val = left.Cast().value; + val = left.Cast().GetValue(); } else { return false; } @@ -516,13 +539,35 @@ static bool CollectValuesAndComparisonsFromExpression(const Expression &expr, va if 
(expr.GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION && expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { auto &conj = expr.Cast(); - for (auto &child : conj.children) { + for (auto &child : conj.GetChildren()) { if (!CollectValuesAndComparisonsFromExpression(*child, in_values, comparisons)) { return false; } } return true; } + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (func.Function().GetName() == OptionalFilterScalarFun::NAME) { + if (!func.BindInfo()) { + return true; + } + auto &data = func.BindInfo()->Cast(); + return !data.child_filter_expr || + CollectValuesAndComparisonsFromExpression(*data.child_filter_expr, in_values, comparisons); + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME) { + if (!func.BindInfo()) { + return true; + } + auto &data = func.BindInfo()->Cast(); + return !data.child_filter_expr || + CollectValuesAndComparisonsFromExpression(*data.child_filter_expr, in_values, comparisons); + } + if (TableFilterFunctions::IsTableFilterFunction(func.Function())) { + return true; + } + } return false; } @@ -559,40 +604,10 @@ static bool ValueQualifies(const Value &value, const vector return true; } -static bool CollectValuesAndComparisonsFromTableFilter(const TableFilter &filter, value_set_t &in_values, - vector &comparisons) { - switch (filter.filter_type) { - case TableFilterType::EXPRESSION_FILTER: - return CollectValuesAndComparisonsFromExpression(*filter.Cast().expr, in_values, comparisons); - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = filter.Cast(); - if (!optional_filter.child_filter) { - return true; // No child filters, always OK - } - return CollectValuesAndComparisonsFromTableFilter(*optional_filter.child_filter, in_values, comparisons); - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - for (auto &child_filter : conjunction_and.child_filters) { - if 
(!CollectValuesAndComparisonsFromTableFilter(*child_filter, in_values, comparisons)) { - return false; - } - } - return true; - } - case TableFilterType::BLOOM_FILTER: - case TableFilterType::PERFECT_HASH_JOIN_FILTER: - case TableFilterType::PREFIX_RANGE_FILTER: - return true; - default: - return false; - } -} - -static bool ExtractValuesFromTableFilter(const TableFilter &filter, value_set_t &values) { +static bool ExtractValuesFromExpression(const Expression &expr, value_set_t &values) { value_set_t in_values; vector comparisons; - if (!CollectValuesAndComparisonsFromTableFilter(filter, in_values, comparisons) || in_values.empty()) { + if (!CollectValuesAndComparisonsFromExpression(expr, in_values, comparisons) || in_values.empty()) { return false; } for (auto &value : in_values) { @@ -615,21 +630,20 @@ void ExtractExpressionsFromValues(const value_set_t &unique_values, BoundColumnR vector> ExtractFilterExpressions(const ColumnDefinition &col, const TableFilter &filter, idx_t storage_idx) { + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ExtractFilterExpressions"); ColumnBinding binding(TableIndex(0), ProjectionIndex(storage_idx)); auto bound_ref = make_uniq(col.Name(), col.Type(), binding); // Extract all exact values we can derive from the filter tree. vector> expressions; value_set_t values; - if (ExtractValuesFromTableFilter(filter, values)) { + if (ExtractValuesFromExpression(*expr_filter.expr, values)) { ExtractExpressionsFromValues(values, *bound_ref, expressions); } // Attempt matching the top-level filter to the index expression. if (expressions.empty()) { - auto filter_expr = filter.filter_type == TableFilterType::EXPRESSION_FILTER - ? 
filter.Cast().ToExpression(*bound_ref) - : filter.ToExpression(*bound_ref); + auto filter_expr = expr_filter.ToExpression(*bound_ref); expressions.push_back(std::move(filter_expr)); } @@ -675,8 +689,8 @@ bool TryScanIndex(ART &art, IndexEntry &entry, const ColumnList &column_list, Ta auto &bound_column_ref_expr = expr.Cast(); // If the bound column references the index column, use updated_index_column - if (bound_column_ref_expr.binding.column_index == indexed_columns[0]) { - bound_column_ref_expr.binding.column_index = updated_index_column; + if (bound_column_ref_expr.Binding().column_index == indexed_columns[0]) { + bound_column_ref_expr.BindingMutable().column_index = updated_index_column; } }); } @@ -751,6 +765,12 @@ unique_ptr TableScanInitGlobal(ClientContext &context, if (!input.filters) { return DuckTableScanInitGlobal(context, input, storage, bind_data); } + + // Only scan specific partitions + if (bind_data.partitions_to_scan) { + return DuckTableScanInitGlobal(context, input, storage, bind_data); + } + auto &filter_set = *input.filters; // FIXME: We currently only support scanning one ART with one filter. @@ -785,8 +805,8 @@ unique_ptr TableScanInitGlobal(ClientContext &context, // row groups while we hold row IDs from the ART, ensuring we always see a consistent // pairing. 
unique_ptr vacuum_lock; - auto &db = DatabaseInstance::GetDatabase(context); - if (Settings::Get(db) > 0) { + const auto &attached = storage.GetAttached(); + if (attached.GetVacuumRebuildIndexThreshold() > 0) { auto &transaction_manager = DuckTransactionManager::Get(storage.GetAttached()); vacuum_lock = transaction_manager.SharedVacuumLock(); } @@ -882,9 +902,12 @@ unique_ptr TableScanCardinality(ClientContext &context, const Fu return make_uniq(estimated_cardinality, estimated_cardinality); } -idx_t TableScanRowsScanned(GlobalTableFunctionState &gstate_p, LocalTableFunctionState &local_state) { - auto &gstate = gstate_p.Cast(); - return gstate.TableScanRowsScanned(local_state); +void TableScanGetMetrics(TableFunctionGetMetricsInput &input) { + auto &gstate = input.global_state->Cast(); + auto &local_state = *input.local_state; + input.operator_metrics.rows_scanned = gstate.TableScanRowsScanned(local_state); + input.operator_metrics.row_groups_scanned = gstate.TableScanRowGroupsScanned(local_state); + input.operator_metrics.total_row_groups_to_scan = gstate.total_row_groups_to_scan; } InsertionOrderPreservingMap TableScanToString(TableFunctionToStringInput &input) { @@ -908,9 +931,9 @@ static void TableScanSerialize(Serializer &serializer, const optional_ptr TableScanDeserialize(Deserializer &deserializer, TableFunction &function) { - auto catalog = deserializer.ReadProperty(100, "catalog"); - auto schema = deserializer.ReadProperty(101, "schema"); - auto table = deserializer.ReadProperty(102, "table"); + auto catalog = deserializer.ReadProperty(100, "catalog"); + auto schema = deserializer.ReadProperty(101, "schema"); + auto table = deserializer.ReadProperty(102, "table"); auto &catalog_entry = Catalog::GetEntry(deserializer.Get(), catalog, schema, table); if (catalog_entry.type != CatalogType::TABLE_ENTRY) { @@ -956,6 +979,11 @@ void SetScanOrder(unique_ptr order_options, optional_ptr partition_indices, optional_ptr bind_data_p) { + auto &bind_data = 
bind_data_p->Cast(); + bind_data.partitions_to_scan = make_uniq>(partition_indices.begin(), partition_indices.end()); +} + TableFunction TableScanFunction::GetFunction() { TableFunction scan_function("seq_scan", {}, TableScanFunc); scan_function.init_local = TableScanInitLocal; @@ -963,7 +991,7 @@ TableFunction TableScanFunction::GetFunction() { scan_function.statistics_extended = TableScanStatistics; scan_function.dependency = TableScanDependency; scan_function.cardinality = TableScanCardinality; - scan_function.rows_scanned = TableScanRowsScanned; + scan_function.get_metrics = TableScanGetMetrics; scan_function.pushdown_complex_filter = nullptr; scan_function.to_string = TableScanToString; scan_function.table_scan_progress = TableScanProgress; @@ -981,6 +1009,7 @@ TableFunction TableScanFunction::GetFunction() { scan_function.get_virtual_columns = TableScanGetVirtualColumns; scan_function.get_row_id_columns = TableScanGetRowIdColumns; scan_function.set_scan_order = SetScanOrder; + scan_function.set_partitions_to_scan = SetPartitionsToScan; scan_function.supports_pushdown_extract = TableSupportsPushdownExtract; return scan_function; } diff --git a/src/duckdb/src/function/table/unnest.cpp b/src/duckdb/src/function/table/unnest.cpp index 8b09dbf3d..450adf39b 100644 --- a/src/duckdb/src/function/table/unnest.cpp +++ b/src/duckdb/src/function/table/unnest.cpp @@ -77,7 +77,7 @@ static unique_ptr UnnestInit(ClientContext &context, T } auto bound_unnest = make_uniq(ListType::GetChildType(child->GetReturnType())); - bound_unnest->child = std::move(child); + bound_unnest->ChildMutable() = std::move(child); result->select_list.push_back(std::move(bound_unnest)); return std::move(result); } @@ -92,6 +92,7 @@ static OperatorResultType UnnestFunction(ExecutionContext &context, TableFunctio void UnnestTableFunction::RegisterFunction(BuiltinFunctions &set) { TableFunction unnest_function("unnest", {LogicalType::ANY}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); 
unnest_function.in_out_function = UnnestFunction; + unnest_function.return_type = TableFunctionReturnType::SET_RETURNING_FUNCTION; set.AddFunction(unnest_function); } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index d6c32b2c2..38725657a 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev5627" +#define DUCKDB_PATCH_VERSION "0-dev8488" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 6 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.6.0-dev5627" +#define DUCKDB_VERSION "v1.6.0-dev8488" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "6320f3003c" +#define DUCKDB_SOURCE_ID "ff62f2bc89" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" @@ -50,7 +50,6 @@ static void PragmaVersionFunction(ClientContext &context, TableFunctionInput &da // finished returning values return; } - output.SetCardinality(1); output.data[0].Append(Value(DuckDB::LibraryVersion())); output.data[1].Append(Value(DuckDB::SourceID())); output.data[2].Append(Value(DuckDB::ReleaseCodename())); @@ -128,7 +127,6 @@ static void PragmaPlatformFunction(ClientContext &context, TableFunctionInput &d // finished returning values return; } - output.SetCardinality(1); output.data[0].Append(Value(DuckDB::Platform())); data.finished = true; } diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp index 155c6c912..18f170a62 100644 --- a/src/duckdb/src/function/table_function.cpp +++ b/src/duckdb/src/function/table_function.cpp @@ -14,34 +14,36 @@ PartitionStatistics::PartitionStatistics() : row_start(0), count(0), count_type( TableFunctionInfo::~TableFunctionInfo() { } 
-TableFunction::TableFunction(string name, const vector &arguments, table_function_t function_, +TableFunction::TableFunction(Identifier name, const vector &arguments, table_function_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), bind_operator(nullptr), init_global(init_global), init_local(init_local), function(function_), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), statistics_extended(nullptr), - dependency(nullptr), cardinality(nullptr), rows_scanned(nullptr), pushdown_complex_filter(nullptr), - pushdown_expression(nullptr), to_string(nullptr), dynamic_to_string(nullptr), table_scan_progress(nullptr), - get_partition_data(nullptr), get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), - supports_pushdown_type(nullptr), supports_pushdown_extract(nullptr), get_partition_info(nullptr), - get_partition_stats(nullptr), get_virtual_columns(nullptr), get_row_id_columns(nullptr), set_scan_order(nullptr), - serialize(nullptr), deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false), - sampling_pushdown(false), late_materialization(false) { + dependency(nullptr), cardinality(nullptr), get_metrics(nullptr), pushdown_complex_filter(nullptr), + pushdown_expression(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), + get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), + supports_pushdown_extract(nullptr), get_partition_info(nullptr), get_partition_stats(nullptr), + get_virtual_columns(nullptr), get_row_id_columns(nullptr), set_scan_order(nullptr), serialize(nullptr), + deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false), + sampling_pushdown(false), 
late_materialization(false), + return_type(TableFunctionReturnType::TABLE_RETURNING_FUNCTION) { } -TableFunction::TableFunction(string name, const vector &arguments, std::nullptr_t function_, +TableFunction::TableFunction(Identifier name, const vector &arguments, std::nullptr_t function_, table_function_bind_t bind, table_function_init_global_t init_global, table_function_init_local_t init_local) : SimpleNamedParameterFunction(std::move(name), arguments), bind(bind), bind_replace(nullptr), bind_operator(nullptr), init_global(init_global), init_local(init_local), function(nullptr), in_out_function(nullptr), in_out_function_final(nullptr), statistics(nullptr), statistics_extended(nullptr), - dependency(nullptr), cardinality(nullptr), rows_scanned(nullptr), pushdown_complex_filter(nullptr), - pushdown_expression(nullptr), to_string(nullptr), dynamic_to_string(nullptr), table_scan_progress(nullptr), - get_partition_data(nullptr), get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), - supports_pushdown_type(nullptr), supports_pushdown_extract(nullptr), get_partition_info(nullptr), - get_partition_stats(nullptr), get_virtual_columns(nullptr), get_row_id_columns(nullptr), set_scan_order(nullptr), - serialize(nullptr), deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false), - sampling_pushdown(false), late_materialization(false) { + dependency(nullptr), cardinality(nullptr), get_metrics(nullptr), pushdown_complex_filter(nullptr), + pushdown_expression(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), + get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), + supports_pushdown_extract(nullptr), get_partition_info(nullptr), get_partition_stats(nullptr), + get_virtual_columns(nullptr), get_row_id_columns(nullptr), set_scan_order(nullptr), serialize(nullptr), + deserialize(nullptr), projection_pushdown(false), 
filter_pushdown(false), filter_prune(false), + sampling_pushdown(false), late_materialization(false), + return_type(TableFunctionReturnType::TABLE_RETURNING_FUNCTION) { } TableFunction::TableFunction(const vector &arguments, table_function_t function_, @@ -65,16 +67,16 @@ bool TableFunction::operator==(const TableFunction &rhs) const { in_out_function_final == rhs.in_out_function_final && statistics == rhs.statistics && dependency == rhs.dependency && cardinality == rhs.cardinality && pushdown_complex_filter == rhs.pushdown_complex_filter && pushdown_expression == rhs.pushdown_expression && - to_string == rhs.to_string && dynamic_to_string == rhs.dynamic_to_string && - table_scan_progress == rhs.table_scan_progress && get_partition_data == rhs.get_partition_data && - get_bind_info == rhs.get_bind_info && type_pushdown == rhs.type_pushdown && - get_multi_file_reader == rhs.get_multi_file_reader && supports_pushdown_type == rhs.supports_pushdown_type && - get_partition_info == rhs.get_partition_info && get_partition_stats == rhs.get_partition_stats && - get_virtual_columns == rhs.get_virtual_columns && get_row_id_columns == rhs.get_row_id_columns && - serialize == rhs.serialize && deserialize == rhs.deserialize && - verify_serialization == rhs.verify_serialization && projection_pushdown == rhs.projection_pushdown && - filter_pushdown == rhs.filter_pushdown && filter_prune == rhs.filter_prune && - sampling_pushdown == rhs.sampling_pushdown && late_materialization == rhs.late_materialization && + to_string == rhs.to_string && table_scan_progress == rhs.table_scan_progress && + get_partition_data == rhs.get_partition_data && get_bind_info == rhs.get_bind_info && + type_pushdown == rhs.type_pushdown && get_multi_file_reader == rhs.get_multi_file_reader && + supports_pushdown_type == rhs.supports_pushdown_type && get_partition_info == rhs.get_partition_info && + get_partition_stats == rhs.get_partition_stats && get_virtual_columns == rhs.get_virtual_columns && + 
get_row_id_columns == rhs.get_row_id_columns && serialize == rhs.serialize && + deserialize == rhs.deserialize && verify_serialization == rhs.verify_serialization && + projection_pushdown == rhs.projection_pushdown && filter_pushdown == rhs.filter_pushdown && + filter_prune == rhs.filter_prune && sampling_pushdown == rhs.sampling_pushdown && + late_materialization == rhs.late_materialization && return_type == rhs.return_type && global_initialization == rhs.global_initialization; } diff --git a/src/duckdb/src/function/variant/variant_shredding.cpp b/src/duckdb/src/function/variant/variant_shredding.cpp index 84b6b3ffd..6d9ef7d1a 100644 --- a/src/duckdb/src/function/variant/variant_shredding.cpp +++ b/src/duckdb/src/function/variant/variant_shredding.cpp @@ -175,7 +175,9 @@ void VariantShredding::WriteTypedObjectValues(UnifiedVariantVectorData &variant, (void)validity; //! Collect the nested data for the objects - auto nested_data = make_unsafe_uniq_array_uninitialized(count); + const auto owned_nested_data = make_unsafe_uniq_array_uninitialized(count); + array_ptr nested_data(owned_nested_data.get(), count); + for (idx_t i = 0; i < count; i++) { auto row = sel[i]; //! 
When we're shredding an object, the top-level struct of it should always be valid @@ -204,12 +206,12 @@ void VariantShredding::WriteTypedObjectValues(UnifiedVariantVectorData &variant, auto &key = shredded_types[child_idx].first; VariantPathComponent path_component; path_component.lookup_mode = VariantChildLookupMode::BY_KEY; - path_component.key = key; + path_component.key = key.GetIdentifierName(); ValidityMask lookup_validity(count); ValidityMask all_valid_validity(count); - VariantUtils::FindChildValues(variant, path_component, sel, child_values_indexes, lookup_validity, - nested_data.get(), all_valid_validity, count); + VariantUtils::FindChildValues(variant, path_component, sel, child_values_indexes, lookup_validity, nested_data, + all_valid_validity, count); if (lookup_validity.CanHaveNull()) { optional_ptr typed_value_vector; diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp index f5ad4b5c5..735c342e9 100644 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ b/src/duckdb/src/function/window/window_aggregate_function.cpp @@ -31,18 +31,18 @@ class WindowAggregateExecutorGlobalState : public WindowExecutorGlobalState { static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &wexpr, ClientContext &context) { // Remove redundant/irrelevant modifiers (they can be serious performance cliffs) - if (wexpr.aggregate && ClientConfig::GetConfig(context).enable_optimizer) { - const auto &aggr = wexpr.aggregate; - auto &arg_orders = wexpr.arg_orders; + if (wexpr.AggregateFunction() && Settings::Get(context)) { + const auto &aggr = wexpr.AggregateFunction(); + auto &arg_orders = wexpr.ArgOrdersMutable(); if (aggr->GetDistinctDependent() != AggregateDistinctDependent::DISTINCT_DEPENDENT) { - wexpr.distinct = false; + wexpr.DistinctMutable() = false; } if (aggr->GetOrderDependent() != AggregateOrderDependent::ORDER_DEPENDENT) { arg_orders.clear(); 
} else { // If the argument order is prefix of the partition ordering, // then we can just use the partition ordering. - if (BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) == arg_orders.size()) { + if (BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) == arg_orders.size()) { arg_orders.clear(); } } @@ -56,7 +56,7 @@ WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, C : WindowExecutor(SimplifyWindowedAggregate(wexpr, client), shared), mode(Settings::Get(client)) { // Force naive for SEPARATE mode or for (currently!) unsupported functionality - if (!ClientConfig::GetConfig(client).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { + if (!Settings::Get(client) || mode == WindowAggregationMode::SEPARATE) { if (!WindowNaiveAggregator::CanAggregate(wexpr)) { throw InvalidInputException("Cannot use non-aggregate window function with naive window executor!"); } @@ -85,9 +85,9 @@ WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, C // Compute the FILTER with the other eval columns. // Anyone who needs it can then convert it to the form they need. - if (wexpr.filter_expr) { - const auto filter_idx = shared.RegisterSink(wexpr.filter_expr); - filter_ref = make_uniq(wexpr.filter_expr->GetReturnType(), filter_idx); + if (wexpr.Filter()) { + const auto filter_idx = shared.RegisterSink(wexpr.Filter()); + filter_ref = make_uniq(wexpr.Filter()->GetReturnType(), filter_idx); } } @@ -237,13 +237,13 @@ void WindowAggregateExecutor::Finalize(ExecutionContext &context, CollectionPtr // First entry is the frame start stats[0] = FrameDelta(-count, count); - auto base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[0].get(); - ApplyWindowStats(wexpr.start, stats[0], base, true); + auto base = wexpr.ExprStats().empty() ? 
nullptr : wexpr.ExprStats()[0].get(); + ApplyWindowStats(wexpr.WindowStart(), stats[0], base, true); // Second entry is the frame end stats[1] = FrameDelta(-count, count); - base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[1].get(); - ApplyWindowStats(wexpr.end, stats[1], base, false); + base = wexpr.ExprStats().empty() ? nullptr : wexpr.ExprStats()[1].get(); + ApplyWindowStats(wexpr.WindowEnd(), stats[1], base, false); auto &lastate = sink.local_state.Cast(); OperatorSinkInput asink {*gsink, *lastate.aggregator_state, sink.interrupt_state}; @@ -258,7 +258,7 @@ void WindowAggregateExecutor::EvaluateInternal(ExecutionContext &context, DataCh D_ASSERT(aggregator); OperatorSinkInput asink {*gsink, *lastate.aggregator_state, sink.interrupt_state}; - aggregator->Evaluate(context, bounds, result, eval_chunk.size(), row_idx, asink); + aggregator->Evaluate(context, bounds, result, bounds.size(), row_idx, asink); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregate_states.cpp b/src/duckdb/src/function/window/window_aggregate_states.cpp index 0b583a7a3..b18523b30 100644 --- a/src/duckdb/src/function/window/window_aggregate_states.cpp +++ b/src/duckdb/src/function/window/window_aggregate_states.cpp @@ -26,12 +26,12 @@ void WindowAggregateStates::Initialize(idx_t count) { } void WindowAggregateStates::Combine(WindowAggregateStates &target) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); + AggregateInputData aggr_input_data(aggr, allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); aggr.function.GetStateCombineCallback()(*statef, *target.statef, aggr_input_data, GetCount()); } void WindowAggregateStates::Finalize(Vector &result) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(*statef, aggr_input_data, result, GetCount(), 0); } @@ -40,7 +40,7 @@ 
void WindowAggregateStates::Destroy() { return; } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); if (aggr.function.HasStateDestructorCallback()) { aggr.function.GetStateDestructorCallback()(*statef, aggr_input_data, GetCount()); } diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp index 02cd55dc6..83092c21e 100644 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ b/src/duckdb/src/function/window/window_aggregator.cpp @@ -11,15 +11,15 @@ namespace duckdb { //===--------------------------------------------------------------------===// WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr) : wexpr(wexpr), aggr(wexpr), result_type(wexpr.GetReturnType()), - state_size(aggr.function.GetStateSizeCallback()(aggr.function)), exclude_mode(wexpr.exclude_clause) { - for (auto &child : wexpr.children) { + state_size(aggr.function.GetStateSizeCallback()(aggr.function)), exclude_mode(wexpr.WindowExclude()) { + for (auto &child : wexpr.GetChildren()) { arg_types.emplace_back(child->GetReturnType()); } } WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : WindowAggregator(wexpr) { - for (auto &child : wexpr.children) { + for (auto &child : wexpr.GetChildren()) { child_idx.emplace_back(shared.RegisterCollection(child, false)); } } diff --git a/src/duckdb/src/function/window/window_boundaries_state.cpp b/src/duckdb/src/function/window/window_boundaries_state.cpp index 94e2967f4..efb6c3dba 100644 --- a/src/duckdb/src/function/window/window_boundaries_state.cpp +++ b/src/duckdb/src/function/window/window_boundaries_state.cpp @@ -309,11 +309,13 @@ static idx_t FindOrderedRangeBound(WindowCursor &range_lo, WindowCursor &range_h } bool WindowBoundariesState::HasPrecedingRange(const BoundWindowExpression &wexpr) { - return (wexpr.start == 
WindowBoundary::EXPR_PRECEDING_RANGE || wexpr.end == WindowBoundary::EXPR_PRECEDING_RANGE); + return (wexpr.WindowStart() == WindowBoundary::EXPR_PRECEDING_RANGE || + wexpr.WindowEnd() == WindowBoundary::EXPR_PRECEDING_RANGE); } bool WindowBoundariesState::HasFollowingRange(const BoundWindowExpression &wexpr) { - return (wexpr.start == WindowBoundary::EXPR_FOLLOWING_RANGE || wexpr.end == WindowBoundary::EXPR_FOLLOWING_RANGE); + return (wexpr.WindowStart() == WindowBoundary::EXPR_FOLLOWING_RANGE || + wexpr.WindowEnd() == WindowBoundary::EXPR_FOLLOWING_RANGE); } void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const BoundWindowExpression &wexpr) { @@ -323,7 +325,7 @@ void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const Boun result.insert(PARTITION_END); // if we have EXCLUDE GROUP / TIES, we also need peer boundaries - if (wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { + if (wexpr.WindowExclude() != WindowExcludeMode::NO_OTHER) { result.insert(PEER_BEGIN); result.insert(PEER_END); } @@ -331,7 +333,7 @@ void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const Boun // If the frames are RANGE or GROUPS, then we need peer boundaries // If they are preceding or following, RANGE also needs to know // where the valid values begin or end. 
- switch (wexpr.start) { + switch (wexpr.WindowStart()) { case WindowBoundary::CURRENT_ROW_RANGE: case WindowBoundary::CURRENT_ROW_GROUPS: result.insert(PEER_BEGIN); @@ -360,7 +362,7 @@ void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const Boun break; } - switch (wexpr.end) { + switch (wexpr.WindowEnd()) { case WindowBoundary::CURRENT_ROW_RANGE: case WindowBoundary::CURRENT_ROW_GROUPS: result.insert(PEER_END); @@ -390,8 +392,8 @@ void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const Boun } } - const auto partition_count = wexpr.partitions.size(); - const auto order_count = wexpr.orders.size(); + const auto partition_count = wexpr.Partitions().size(); + const auto order_count = wexpr.OrderBy().size(); if (result.count(VALID_END)) { result.insert(PARTITION_END); @@ -417,14 +419,15 @@ void WindowBoundariesState::AddImpliedBounds(WindowBoundsSet &result, const Boun WindowBoundariesState::WindowBoundariesState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) : partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), type(gstate.executor.wexpr.GetExpressionType()), input_size(gstate.payload_count), - start_boundary(gstate.executor.wexpr.start), end_boundary(gstate.executor.wexpr.end), + start_boundary(gstate.executor.wexpr.WindowStart()), end_boundary(gstate.executor.wexpr.WindowEnd()), boundary_start_idx(gstate.executor.boundary_start_idx), boundary_end_idx(gstate.executor.boundary_end_idx), - partition_count(gstate.executor.wexpr.partitions.size()), order_count(gstate.executor.wexpr.orders.size()), - range_sense(gstate.executor.wexpr.orders.empty() ? OrderType::INVALID : gstate.executor.wexpr.orders[0].type), + partition_count(gstate.executor.wexpr.Partitions().size()), order_count(gstate.executor.wexpr.OrderBy().size()), + range_sense(gstate.executor.wexpr.OrderBy().empty() ? 
OrderType::INVALID + : gstate.executor.wexpr.OrderBy()[0].type), has_preceding_range(HasPrecedingRange(gstate.executor.wexpr)), has_following_range(HasFollowingRange(gstate.executor.wexpr)), range_idx(gstate.executor.range_idx) { - if (gstate.executor.wexpr.window) { - const auto &wfunc = *gstate.executor.wexpr.window; + if (gstate.executor.wexpr.WindowFunction()) { + const auto &wfunc = *gstate.executor.wexpr.WindowFunction(); if (wfunc.HasBoundsCallback()) { wfunc.GetBounds(required, gstate.executor.wexpr); AddImpliedBounds(required, gstate.executor.wexpr); @@ -445,13 +448,11 @@ void WindowBoundariesState::Finalize(CollectionPtr collection) { } } -void WindowBoundariesState::UpdateBounds(idx_t row_idx, DataChunk &eval_chunk) { +void WindowBoundariesState::UpdateBounds(idx_t row_idx, DataChunk &eval_chunk, idx_t count) { // Evaluate the row-level arguments WindowInputExpression boundary_start(eval_chunk, boundary_start_idx); WindowInputExpression boundary_end(eval_chunk, boundary_end_idx); - const auto count = eval_chunk.size(); - bounds.Reset(); D_ASSERT(bounds.ColumnCount() == 8); @@ -487,7 +488,7 @@ void WindowBoundariesState::UpdateBounds(idx_t row_idx, DataChunk &eval_chunk) { } next_pos += count; - bounds.SetCardinality(count); + bounds.SetChildCardinality(count); } void WindowBoundariesState::PartitionBegin(idx_t row_idx, const idx_t count, bool is_jump) { diff --git a/src/duckdb/src/function/window/window_collection.cpp b/src/duckdb/src/function/window/window_collection.cpp index 7055fa789..370065e92 100644 --- a/src/duckdb/src/function/window/window_collection.cpp +++ b/src/duckdb/src/function/window/window_collection.cpp @@ -129,7 +129,7 @@ WindowCursor::WindowCursor(const WindowCollection &paged, vector colum state.current_row_index = 0; state.next_row_index = paged.size(); state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - chunk.SetCardinality(state.next_row_index); + chunk.SetChildCardinality(state.next_row_index); return; } else if 
(chunk.data.empty()) { auto &inputs = paged.inputs; @@ -148,7 +148,7 @@ LogicalType WindowCollectionChunkScanner::PrefixStructType(column_t end, column_ for (auto c = begin; c < end; ++c) { auto name = std::to_string(c); auto type = chunk.data[c].GetType(); - std::pair child {name, type}; + std::pair child {Identifier(name), type}; partition_children.emplace_back(child); } // For single children, don;t build a struct - compare will be slow diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp index 5c6cec98d..a2364af94 100644 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ b/src/duckdb/src/function/window/window_constant_aggregator.cpp @@ -1,5 +1,6 @@ #include "duckdb/function/window/window_constant_aggregator.hpp" +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/function/window/window_aggregate_states.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" @@ -123,27 +124,27 @@ WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( // WindowConstantAggregator //===--------------------------------------------------------------------===// bool WindowConstantAggregator::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { + if (!wexpr.AggregateFunction()) { return false; } // The function must be able to be used as an aggregate - if (!wexpr.aggregate->CanAggregate()) { + if (!wexpr.AggregateFunction()->CanAggregate()) { return false; } // window exclusion cannot be handled by constant aggregates - if (wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { + if (wexpr.WindowExclude() != WindowExcludeMode::NO_OTHER) { return false; } // DISTINCT aggregation cannot be handled by constant aggregation - if (wexpr.distinct) { + if (wexpr.Distinct()) { return false; } // COUNT(*) is already handled efficiently by segment trees. 
- if (wexpr.children.empty()) { + if (wexpr.GetChildren().empty()) { return false; } @@ -164,11 +165,11 @@ bool WindowConstantAggregator::CanAggregate(const BoundWindowExpression &wexpr) offset PRECEDING and offset FOLLOWING options vary in meaning depending on the frame mode. */ - switch (wexpr.start) { + switch (wexpr.WindowStart()) { case WindowBoundary::UNBOUNDED_PRECEDING: break; case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { + if (!wexpr.OrderBy().empty()) { return false; } break; @@ -176,11 +177,11 @@ bool WindowConstantAggregator::CanAggregate(const BoundWindowExpression &wexpr) return false; } - switch (wexpr.end) { + switch (wexpr.WindowEnd()) { case WindowBoundary::UNBOUNDED_FOLLOWING: break; case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { + if (!wexpr.OrderBy().empty()) { return false; } break; @@ -201,7 +202,7 @@ WindowConstantAggregator::WindowConstantAggregator(BoundWindowExpression &wexpr, ClientContext &context) : WindowAggregator(RebindAggregate(context, wexpr)) { // We only need these values for Sink - for (auto &child : wexpr.children) { + for (auto &child : wexpr.GetChildren()) { child_idx.emplace_back(shared.RegisterSink(child)); } } @@ -237,7 +238,7 @@ void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChu payload_chunk.data[c].Reference(sink_chunk.data[child_idx[c]]); } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), statef.allocator); + AggregateInputData aggr_input_data(aggr, statef.allocator); idx_t begin = 0; idx_t filter_idx = 0; auto partition_end = partition_offsets[partition + 1]; @@ -277,22 +278,27 @@ void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChu } } else { // Slice to [begin, end) - if (begin) { + if (begin == 0 && end == sink_chunk.size()) { + inputs.Reference(payload_chunk); + } else { + // we cannot resize non-flat vectors (e.g. 
dictionary vectors), so we have to slice for (idx_t c = 0; c < payload_chunk.ColumnCount(); ++c) { inputs.data[c].Slice(payload_chunk.data[c], begin, end); } - } else { - inputs.Reference(payload_chunk); } - inputs.SetCardinality(end - begin); + inputs.SetChildCardinality(end - begin); } // Aggregate the filtered rows into a single state const auto count = inputs.size(); auto state = state_f_data[partition]; - if (aggr.function.HasStateSimpleUpdateCallback()) { - aggr.function.GetStateSimpleUpdateCallback()(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), - state, count); + auto cluster_update = aggr.function.GetStateClusterUpdateCallback(); + if (cluster_update) { + ClusteredAggr clustered; + clustered.SetSingleRun(state, count); + aggr_input_data.clustered = &clustered; + cluster_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), clustered, count); + aggr_input_data.clustered = nullptr; } else { state_p_data[0] = state_f_data[partition]; aggr.function.GetStateUpdateCallback()(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, diff --git a/src/duckdb/src/function/window/window_custom_aggregator.cpp b/src/duckdb/src/function/window/window_custom_aggregator.cpp index a7639d693..39b187d6c 100644 --- a/src/duckdb/src/function/window/window_custom_aggregator.cpp +++ b/src/duckdb/src/function/window/window_custom_aggregator.cpp @@ -8,16 +8,16 @@ namespace duckdb { // WindowCustomAggregator //===--------------------------------------------------------------------===// bool WindowCustomAggregator::CanAggregate(const BoundWindowExpression &wexpr, WindowAggregationMode mode) { - if (!wexpr.aggregate) { + if (!wexpr.AggregateFunction()) { return false; } - if (!wexpr.aggregate->CanWindow()) { + if (!wexpr.AggregateFunction()->CanWindow()) { return false; } // ORDER BY arguments are not currently supported - if (!wexpr.arg_orders.empty()) { + if (!wexpr.ArgOrders().empty()) { return false; } @@ -82,7 +82,7 @@ 
WindowCustomAggregatorLocalState::WindowCustomAggregatorLocalState(ExecutionCont WindowCustomAggregatorLocalState::~WindowCustomAggregatorLocalState() { if (aggr.function.HasStateDestructorCallback()) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); aggr.function.GetStateDestructorCallback()(statef, aggr_input_data, 1); } } @@ -120,7 +120,7 @@ void WindowCustomAggregator::Finalize(ExecutionContext &context, CollectionPtr c WindowPartitionInput partition(context, inputs, count, child_idx, all_valids, filter_packed, stats, sink.interrupt_state); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); + AggregateInputData aggr_input_data(aggr, gcstate.allocator); aggr.function.GetWindowInitCallback()(aggr_input_data, partition, gcstate.state.data()); } @@ -160,7 +160,7 @@ void WindowCustomAggregator::Evaluate(ExecutionContext &context, const DataChunk vector all_frames; all_frames.reserve(count); EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { all_frames.push_back(frames); }); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), lcstate.allocator); + AggregateInputData aggr_input_data(aggr, lcstate.allocator); aggr.function.GetWindowBatchCallback()(aggr_input_data, partition, gstate_p, lcstate.state.data(), all_frames.data(), count, result, row_idx); return; @@ -168,7 +168,7 @@ void WindowCustomAggregator::Evaluate(ExecutionContext &context, const DataChunk EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { // Extract the range - AggregateInputData aggr_input_data(aggr.GetFunctionData(), lcstate.allocator); + AggregateInputData aggr_input_data(aggr, lcstate.allocator); aggr.function.GetWindowCallback()(aggr_input_data, partition, gstate_p, lcstate.state.data(), frames, result, i); }); diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp 
b/src/duckdb/src/function/window/window_distinct_aggregator.cpp index 989715562..1060e8d98 100644 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ b/src/duckdb/src/function/window/window_distinct_aggregator.cpp @@ -18,15 +18,15 @@ enum class WindowDistinctSortStage : uint8_t { INIT, COMBINE, FINALIZE, SORTED, // WindowDistinctAggregator //===--------------------------------------------------------------------===// bool WindowDistinctAggregator::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { + if (!wexpr.AggregateFunction()) { return false; } - if (!wexpr.aggregate->CanAggregate()) { + if (!wexpr.AggregateFunction()->CanAggregate()) { return false; } - return wexpr.distinct && wexpr.exclude_clause == WindowExcludeMode::NO_OTHER && wexpr.arg_orders.empty(); + return wexpr.Distinct() && wexpr.WindowExclude() == WindowExcludeMode::NO_OTHER && wexpr.ArgOrders().empty(); } WindowDistinctAggregator::WindowDistinctAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared, @@ -269,7 +269,6 @@ void WindowDistinctAggregatorLocalState::Sink(ExecutionContext &context, DataChu for (column_t c = 0; c < child_idx.size(); ++c) { sort_chunk.data[c].Reference(coll_chunk.data[child_idx[c]]); } - sort_chunk.SetCardinality(sink_chunk); // Apply FILTER clause, if any if (filter_sel) { @@ -537,7 +536,7 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti auto &leaves = ldastate.leaves; auto &sel = ldastate.sel; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), ldastate.allocator); + AggregateInputData aggr_input_data(aggr, ldastate.allocator); //! 
The states to update auto &update_v = ldastate.update_v; @@ -637,7 +636,7 @@ void WindowDistinctAggregatorLocalState::FlushStates() { } const auto &aggr = gdstate.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); statel.Verify(); aggr.function.GetStateCombineCallback()(statel, statep, aggr_input_data, flush_count); diff --git a/src/duckdb/src/function/window/window_executor.cpp b/src/duckdb/src/function/window/window_executor.cpp index 8063f8c7b..a2c64d30f 100644 --- a/src/duckdb/src/function/window/window_executor.cpp +++ b/src/duckdb/src/function/window/window_executor.cpp @@ -12,21 +12,21 @@ namespace duckdb { WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) : wexpr(wexpr), range_expr((WindowBoundariesState::HasPrecedingRange(wexpr) || WindowBoundariesState::HasFollowingRange(wexpr)) - ? wexpr.orders[0].expression.get() + ? wexpr.OrderBy()[0].expression.get() : nullptr) { if (range_expr) { - range_idx = shared.RegisterCollection(wexpr.orders[0].expression, false); + range_idx = shared.RegisterCollection(wexpr.OrderByMutable()[0].expression, false); } - boundary_start_idx = shared.RegisterEvaluate(wexpr.start_expr); - boundary_end_idx = shared.RegisterEvaluate(wexpr.end_expr); + boundary_start_idx = shared.RegisterEvaluate(wexpr.StartExprMutable()); + boundary_end_idx = shared.RegisterEvaluate(wexpr.EndExprMutable()); - if (wexpr.window) { - if (wexpr.window->HasSharingCallback()) { - wexpr.window->GetSharing(*this, shared); + if (wexpr.WindowFunction()) { + if (wexpr.WindowFunction()->HasSharingCallback()) { + wexpr.WindowFunction()->GetSharing(*this, shared); } else { // If no one overrides, assume the arguments are only needed at evaluate time - for (auto &child : wexpr.children) { + for (auto &child : wexpr.GetChildrenMutable()) { child_idx.emplace_back(shared.RegisterEvaluate(child)); } } @@ -34,13 +34,13 @@ 
WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, WindowSharedExpress } void WindowExecutor::Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, - OperatorSinkInput &sink) const { + OperatorSinkInput &sink, idx_t count) const { auto &lbstate = sink.local_state.Cast(); - lbstate.state.UpdateBounds(row_idx, eval_chunk); + lbstate.state.UpdateBounds(row_idx, eval_chunk, count); EvaluateInternal(context, eval_chunk, lbstate.state.bounds, result, row_idx, sink); - FlatVector::SetSize(result, count_t(eval_chunk.size())); + FlatVector::SetSize(result, count_t(count)); result.Verify(); } @@ -49,7 +49,7 @@ WindowExecutorGlobalState::WindowExecutorGlobalState(ClientContext &client, cons const ValidityMask &order_mask) : client(client), executor(executor), payload_count(payload_count), partition_mask(partition_mask), order_mask(order_mask) { - for (const auto &child : executor.wexpr.children) { + for (const auto &child : executor.wexpr.GetChildren()) { arg_types.emplace_back(child->GetReturnType()); } } @@ -68,24 +68,24 @@ void WindowExecutorLocalState::Finalize(ExecutionContext &context, CollectionPtr unique_ptr WindowExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - if (wexpr.window && wexpr.window->HasGlobalCallback()) { - return wexpr.window->GetGlobalState(client, *this, payload_count, partition_mask, order_mask); + if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasGlobalCallback()) { + return wexpr.WindowFunction()->GetGlobalState(client, *this, payload_count, partition_mask, order_mask); } return make_uniq(client, *this, payload_count, partition_mask, order_mask); } unique_ptr WindowExecutor::GetLocalState(ExecutionContext &context, const GlobalSinkState &gstate) const { - if (wexpr.window && wexpr.window->HasLocalCallback()) { - return wexpr.window->GetLocalState(context, gstate); + if (wexpr.WindowFunction() && 
wexpr.WindowFunction()->HasLocalCallback()) { + return wexpr.WindowFunction()->GetLocalState(context, gstate); } return make_uniq(context, gstate.Cast()); } void WindowExecutor::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, OperatorSinkInput &sink) const { - if (wexpr.window && wexpr.window->HasSinkCallback()) { - wexpr.window->Sink(context, sink_chunk, coll_chunk, input_idx, sink); + if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasSinkCallback()) { + wexpr.WindowFunction()->Sink(context, sink_chunk, coll_chunk, input_idx, sink); } else { auto &lbstate = sink.local_state.Cast(); lbstate.Sink(context, sink_chunk, coll_chunk, input_idx, sink); @@ -96,8 +96,8 @@ void WindowExecutor::Finalize(ExecutionContext &context, CollectionPtr collectio auto &lbstate = sink.local_state.Cast(); lbstate.state.Finalize(collection); - if (wexpr.window && wexpr.window->HasFinalizeCallback()) { - wexpr.window->Finalize(context, collection, sink); + if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasFinalizeCallback()) { + wexpr.WindowFunction()->Finalize(context, collection, sink); } else { lbstate.Finalize(context, collection, sink); } @@ -105,8 +105,8 @@ void WindowExecutor::Finalize(ExecutionContext &context, CollectionPtr collectio void WindowExecutor::EvaluateInternal(ExecutionContext &context, DataChunk &eval_chunk, DataChunk &bounds, Vector &result, idx_t row_idx, OperatorSinkInput &sink) const { - if (wexpr.window && wexpr.window->HasEvaluateCallback()) { - wexpr.window->Evaluate(context, eval_chunk, bounds, result, row_idx, sink); + if (wexpr.WindowFunction() && wexpr.WindowFunction()->HasEvaluateCallback()) { + wexpr.WindowFunction()->Evaluate(context, eval_chunk, bounds, result, row_idx, sink); } } diff --git a/src/duckdb/src/function/window/window_index_tree.cpp b/src/duckdb/src/function/window/window_index_tree.cpp index 1d8352788..c12f4a361 100644 --- 
a/src/duckdb/src/function/window/window_index_tree.cpp +++ b/src/duckdb/src/function/window/window_index_tree.cpp @@ -47,7 +47,7 @@ void WindowIndexTreeLocalState::BuildLeaves() { if (count == 0) { break; } - auto &indices = payload_chunk.data[0]; + const auto &indices = payload_chunk.data[0]; if (window_tree.mst32) { auto &sorted = window_tree.mst32->LowestLevel(); auto data = FlatVector::GetData(indices); diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index 172d5583a..035dde505 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -79,7 +79,6 @@ void WindowMergeSortTreeLocalState::Sink(ExecutionContext &context, DataChunk &c for (column_t c = 0; c < order_idx.size(); ++c) { sort_chunk.data[c].Reference(chunk.data[order_idx[c]]); } - sort_chunk.SetCardinality(chunk); // Apply FILTER clause, if any if (filter_sel) { diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index 7322c6b0e..dd0a47c9b 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -13,7 +13,7 @@ namespace duckdb { //===--------------------------------------------------------------------===// WindowNaiveAggregator::WindowNaiveAggregator(const WindowAggregateExecutor &executor, WindowSharedExpressions &shared) : WindowAggregator(executor.wexpr, shared), executor(executor) { - for (const auto &order : wexpr.arg_orders) { + for (const auto &order : wexpr.ArgOrders()) { arg_order_idx.emplace_back(shared.RegisterCollection(order.expression, false)); } } @@ -133,7 +133,7 @@ void WindowNaiveLocalState::Finalize(ExecutionContext &context, WindowAggregator // The sort expressions have already been computed, so we just need to reference them vector orders; - for (const auto 
&order_by : aggregator.wexpr.arg_orders) { + for (const auto &order_by : aggregator.wexpr.ArgOrders()) { auto order = order_by.Copy(); const auto &type = order.expression->GetReturnType(); order.expression = make_uniq(type, orders.size()); @@ -164,7 +164,7 @@ void WindowNaiveLocalState::FlushStates(const WindowAggregatorGlobalState &gsink leaves.Slice(scanned, update_sel, flush_count); const auto &aggr = gsink.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); aggr.function.GetStateUpdateCallback()(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), statep, flush_count); @@ -351,7 +351,7 @@ void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggr FlushStates(gsink); // Finalise the result aggregates and write to the result - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates @@ -373,7 +373,7 @@ void WindowNaiveAggregator::Evaluate(ExecutionContext &context, const DataChunk } bool WindowNaiveAggregator::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate || !wexpr.aggregate->CanAggregate()) { + if (!wexpr.AggregateFunction() || !wexpr.AggregateFunction()->CanAggregate()) { return false; } return true; diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index 2528db396..3cbdd8dca 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -5,6 +5,7 @@ #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/function/window/ranking_functions.hpp" #include "duckdb/function/window_function.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -22,9 +23,9 @@ class 
WindowPeerGlobalState : public WindowExecutorGlobalState { // If the argument order is a prefix of the partition ordering // (and the optimizer is enabled), then we can just use the partition ordering. auto &wexpr = executor.wexpr; - auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; - if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { + auto &arg_orders = executor.wexpr.ArgOrders(); + const auto optimize = Settings::Get(client); + if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) != arg_orders.size()) { token_tree = make_uniq(client, arg_orders, executor.arg_order_idx, payload_count); } } @@ -139,7 +140,7 @@ struct WindowPeerExecutor : public WindowExecutor { void WindowPeerExecutor::GetSharing(WindowExecutor &executor, WindowSharedExpressions &shared) { const auto &wexpr = executor.wexpr; auto &arg_order_idx = executor.arg_order_idx; - for (const auto &order : wexpr.arg_orders) { + for (const auto &order : wexpr.ArgOrders()) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } @@ -176,7 +177,7 @@ class WindowRankLocalState : public WindowPeerLocalState { }; void WindowRankExecutor::GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); required.insert(PEER_BEGIN); } else { @@ -206,7 +207,7 @@ void WindowRankExecutor::GetData(ExecutionContext &context, DataChunk &eval_chun idx_t row_idx, OperatorSinkInput &sink) { auto &gpeer = sink.global_state.Cast(); auto &lpeer = sink.local_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto rdata = FlatVector::Writer(result, count); if (gpeer.use_framing) { @@ -290,7 +291,7 @@ void WindowDenseRankExecutor::GetData(ExecutionContext &context, DataChunk &eval Vector &result, idx_t row_idx, 
OperatorSinkInput &sink) { auto &gpeer = sink.global_state.Cast(); auto &lpeer = sink.local_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto &order_mask = gpeer.order_mask; auto partition_begin = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); @@ -378,7 +379,7 @@ class WindowPercentRankLocalState : public WindowPeerLocalState { }; void WindowPercentRankExecutor::GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); required.insert(PARTITION_END); required.insert(PEER_BEGIN); @@ -415,7 +416,7 @@ void WindowPercentRankExecutor::GetData(ExecutionContext &context, DataChunk &ev Vector &result, idx_t row_idx, OperatorSinkInput &sink) { auto &gpeer = sink.global_state.Cast(); auto &lpeer = sink.local_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto rdata = FlatVector::Writer(result, count); if (gpeer.use_framing) { @@ -471,7 +472,7 @@ class WindowCumeDistLocalState : public WindowPeerLocalState { }; void WindowCumeDistExecutor::GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); required.insert(PARTITION_END); required.insert(PEER_END); @@ -504,7 +505,7 @@ static inline double CumeDist(const idx_t begin, const idx_t end, const idx_t pe void WindowCumeDistExecutor::GetData(ExecutionContext &context, DataChunk &eval_chunk, DataChunk &bounds, Vector &result, idx_t row_idx, OperatorSinkInput &sink) { auto &gpeer = sink.global_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto rdata = FlatVector::Writer(result, count); if (gpeer.use_framing) { diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index 
7ce07adae..5c92c07c6 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -5,6 +5,7 @@ #include "duckdb/function/window_function.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -22,12 +23,12 @@ class WindowRowNumberGlobalState : public WindowExecutorGlobalState { // If the argument order is prefix of the partition ordering, // then we can just use the partition ordering. auto &wexpr = executor.wexpr; - auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; - if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { + auto &arg_orders = executor.wexpr.ArgOrders(); + const auto optimize = Settings::Get(client); + if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) != arg_orders.size()) { // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their // position in the input data, such that two elements never compare as equal." 
- token_tree = make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, + token_tree = make_uniq(client, executor.wexpr.ArgOrders(), executor.arg_order_idx, payload_count, true); } } @@ -142,7 +143,7 @@ WindowFunction RowNumberFun::GetFunction() { } void WindowRowNumberExecutor::GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); } else { // Secondary orders need to know where the frame is @@ -155,12 +156,12 @@ void WindowRowNumberExecutor::GetSharing(WindowExecutor &executor, WindowSharedE const auto &wexpr = executor.wexpr; auto &child_idx = executor.child_idx; - for (auto &child : wexpr.children) { + for (auto &child : wexpr.GetChildren()) { child_idx.emplace_back(shared.RegisterEvaluate(child)); } auto &arg_order_idx = executor.arg_order_idx; - for (const auto &order : wexpr.arg_orders) { + for (const auto &order : wexpr.ArgOrders()) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } @@ -179,7 +180,7 @@ unique_ptr WindowRowNumberExecutor::GetLocal(ExecutionContext &c void WindowRowNumberExecutor::GetData(ExecutionContext &context, DataChunk &eval_chunk, DataChunk &bounds, Vector &result, idx_t row_idx, OperatorSinkInput &sink) { auto &grstate = sink.global_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto rdata = FlatVector::Writer(result, count); if (grstate.use_framing) { @@ -234,7 +235,7 @@ WindowFunction NtileFun::GetFunction() { } void WindowNtileExecutor::GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); required.insert(PARTITION_END); } else { @@ -251,7 +252,7 @@ unique_ptr WindowNtileExecutor::GetLocal(ExecutionContext &conte void WindowNtileExecutor::GetData(ExecutionContext &context, DataChunk &eval_chunk, DataChunk &bounds, 
Vector &result, idx_t row_idx, OperatorSinkInput &sink) { auto &grstate = sink.global_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto partition_begin = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); auto partition_end = FlatVector::GetData(bounds.data[PARTITION_END]); diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index fd3b8f555..af07fcfaa 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -11,16 +11,16 @@ namespace duckdb { // WindowSegmentTree //===--------------------------------------------------------------------===// bool WindowSegmentTree::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { + if (!wexpr.AggregateFunction()) { return false; } - if (!wexpr.aggregate->CanAggregate()) { + if (!wexpr.AggregateFunction()->CanAggregate()) { return false; } // We can't handle DISTINCT, ORDER BY args or () args (COUNT(*)) - return !wexpr.distinct && wexpr.arg_orders.empty() && !wexpr.children.empty(); + return !wexpr.Distinct() && wexpr.ArgOrders().empty() && !wexpr.GetChildren().empty(); } WindowSegmentTree::WindowSegmentTree(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared) @@ -195,7 +195,7 @@ void WindowSegmentTreePart::FlushStates(bool combining) { return; } - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); if (combining) { statel.Verify(); aggr.function.GetStateCombineCallback()(statel, statep, aggr_input_data, flush_count); @@ -210,7 +210,7 @@ void WindowSegmentTreePart::FlushStates(bool combining) { } void WindowSegmentTreePart::Combine(WindowSegmentTreePart &other, idx_t count) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); 
aggr.function.GetStateCombineCallback()(other.statef, statef, aggr_input_data, count); } @@ -281,7 +281,7 @@ void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTreeGlobalStat } void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { // Finalise the result aggregates and write to result if write_result is set - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates @@ -293,7 +293,7 @@ void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &client, const WindowSegmentTree &aggregator, idx_t group_count) : WindowAggregatorGlobalState(client, aggregator, group_count), tree(aggregator), levels_flat_native(client, aggr) { - D_ASSERT(!aggregator.wexpr.children.empty()); + D_ASSERT(!aggregator.wexpr.GetChildren().empty()); // compute space required to store internal nodes of segment tree levels_flat_start.push_back(0); diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index 794aa12af..bf0c431b8 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -12,6 +12,7 @@ #include "duckdb/function/function_set.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/parser/expression/bound_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -27,7 +28,7 @@ class WindowValueGlobalState : public WindowExecutorGlobalState { value_idx(executor.child_idx[0]) { if (!executor.arg_order_idx.empty()) { value_tree = - make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, payload_count); + make_uniq(client, executor.wexpr.ArgOrders(), executor.arg_order_idx, 
payload_count); } } @@ -47,11 +48,11 @@ class WindowValueLocalState : public WindowExecutorLocalState { public: WindowValueLocalState(ExecutionContext &context, const WindowValueGlobalState &gvstate) : WindowExecutorLocalState(context, gvstate), gvstate(gvstate), ignore_nulls(&all_valid) { - WindowAggregatorLocalState::InitSubFrames(frames, gvstate.executor.wexpr.exclude_clause); + WindowAggregatorLocalState::InitSubFrames(frames, gvstate.executor.wexpr.WindowExclude()); if (gvstate.value_tree) { local_value = gvstate.value_tree->GetLocalState(context); - if (gvstate.executor.wexpr.ignore_nulls) { + if (gvstate.executor.wexpr.IgnoreNulls()) { sort_nulls.Initialize(); } } @@ -92,10 +93,10 @@ void WindowValueLocalState::Sinker(ExecutionContext &context, DataChunk &sink_ch // If we need to IGNORE NULLS for the child, and there are NULLs, // then build an SV to hold them const auto coll_count = coll_chunk.size(); - auto &values = coll_chunk.data[gvstate.value_idx]; + const auto &values = coll_chunk.data[gvstate.value_idx]; auto validity = values.Validity(); auto &sort_nulls = lvstate.sort_nulls; - if (gvstate.executor.wexpr.ignore_nulls && validity.CanHaveNull()) { + if (gvstate.executor.wexpr.IgnoreNulls() && validity.CanHaveNull()) { for (sel_t i = 0; i < coll_count; ++i) { if (validity.IsValid(i)) { sort_nulls[filtered++] = i; @@ -115,7 +116,7 @@ void WindowValueLocalState::Finalizer(ExecutionContext &context, CollectionPtr c const auto &value_idx = gvstate.value_idx; auto &lvstate = sink.local_state.Cast(); - if (value_idx != DConstants::INVALID_INDEX && executor.wexpr.ignore_nulls) { + if (value_idx != DConstants::INVALID_INDEX && executor.wexpr.IgnoreNulls()) { lvstate.ignore_nulls = &collection->validities[value_idx]; } @@ -140,9 +141,9 @@ class WindowValueStreamingState : public WindowExecutorStreamingState { static Value GetFirstValue(ClientContext &client, DataChunk &input, const BoundWindowExpression &wexpr) { // Just execute the expression once 
ExpressionExecutor executor(client); - executor.AddExpression(*wexpr.children[0]); + executor.AddExpression(*wexpr.GetChildren()[0]); DataChunk result; - result.Initialize(client, {wexpr.children[0]->GetReturnType()}); + result.Initialize(client, {wexpr.GetChildren()[0]->GetReturnType()}); executor.Execute(input, result); return result.GetValue(0, 0); @@ -150,8 +151,8 @@ class WindowValueStreamingState : public WindowExecutorStreamingState { WindowValueStreamingState(ClientContext &client, DataChunk &input, const BoundWindowExpression &wexpr) : wexpr(wexpr), vec(GetFirstValue(client, input, wexpr), count_t(STANDARD_VECTOR_SIZE)), - sel(STANDARD_VECTOR_SIZE), eval(client), arg(wexpr.children[0]->GetReturnType()) { - eval.AddExpression(*wexpr.children[0]); + sel(STANDARD_VECTOR_SIZE), eval(client), arg(wexpr.GetChildren()[0]->GetReturnType()) { + eval.AddExpression(*wexpr.GetChildren()[0]); } const BoundWindowExpression &wexpr; @@ -203,19 +204,19 @@ void WindowValueExecutor::GetBounds(WindowBoundsSet &required, const BoundWindow void WindowValueExecutor::GetSharing(WindowExecutor &executor, WindowSharedExpressions &shared) { // The children have to be handled separately because only the first one is global const auto &wexpr = executor.wexpr; - D_ASSERT(!wexpr.children.empty()); + D_ASSERT(!wexpr.GetChildren().empty()); auto &child_idx = executor.child_idx; - child_idx.emplace_back(shared.RegisterCollection(wexpr.children[0], wexpr.ignore_nulls)); + child_idx.emplace_back(shared.RegisterCollection(wexpr.GetChildren()[0], wexpr.IgnoreNulls())); - if (wexpr.children.size() > 1) { - child_idx.emplace_back(shared.RegisterEvaluate(wexpr.children[1])); + if (wexpr.GetChildren().size() > 1) { + child_idx.emplace_back(shared.RegisterEvaluate(wexpr.GetChildren()[1])); } - if (wexpr.children.size() > 2) { - child_idx.emplace_back(shared.RegisterEvaluate(wexpr.children[2])); + if (wexpr.GetChildren().size() > 2) { + 
child_idx.emplace_back(shared.RegisterEvaluate(wexpr.GetChildren()[2])); } auto &arg_order_idx = executor.arg_order_idx; - for (const auto &order : wexpr.arg_orders) { + for (const auto &order : wexpr.ArgOrders()) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } @@ -261,9 +262,9 @@ class WindowLeadLagGlobalState : public WindowValueGlobalState { // If the argument order is prefix of the partition ordering, // then we can just use the partition ordering. auto &wexpr = executor.wexpr; - auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; - if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { + auto &arg_orders = executor.wexpr.ArgOrders(); + const auto optimize = Settings::Get(client); + if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) != arg_orders.size()) { // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their // position in the input data, such that two elements never compare as equal." 
// Note: If the user specifies an partial secondary sort, the disambiguation will use the @@ -296,7 +297,7 @@ class WindowLeadLagLocalState : public WindowValueLocalState { } static void GetBounds(WindowBoundsSet &required, const BoundWindowExpression &wexpr) { - if (wexpr.arg_orders.empty()) { + if (wexpr.ArgOrders().empty()) { required.insert(PARTITION_BEGIN); required.insert(PARTITION_END); } else { @@ -348,8 +349,8 @@ class WindowLeadLagStreamingState : public WindowExecutorStreamingState { public: static bool ComputeOffset(ClientContext &client, const BoundWindowExpression &wexpr, int64_t &offset) { offset = 1; - if (wexpr.children.size() > 1) { - auto &offset_expr = wexpr.children[1]; + if (wexpr.GetChildren().size() > 1) { + auto &offset_expr = wexpr.GetChildren()[1]; if (offset_expr->HasParameter() || !offset_expr->IsFoldable()) { return false; } @@ -372,12 +373,12 @@ class WindowLeadLagStreamingState : public WindowExecutorStreamingState { } static bool ComputeDefault(ClientContext &client, const BoundWindowExpression &wexpr, Value &result) { - if (wexpr.children.size() < 3) { + if (wexpr.GetChildren().size() < 3) { result = Value(wexpr.GetReturnType()); return true; } - auto &default_expr = wexpr.children[2]; + auto &default_expr = wexpr.GetChildren()[2]; if (default_expr && (default_expr->HasParameter() || !default_expr->IsFoldable())) { return false; } @@ -386,8 +387,8 @@ class WindowLeadLagStreamingState : public WindowExecutorStreamingState { } WindowLeadLagStreamingState(ClientContext &context, const BoundWindowExpression &wexpr) - : wexpr(wexpr), executor(context, *wexpr.children[0]), prev(wexpr.GetReturnType()), temp(wexpr.GetReturnType()), - sel(STANDARD_VECTOR_SIZE) { + : wexpr(wexpr), executor(context, *wexpr.GetChildren()[0]), prev(wexpr.GetReturnType()), + temp(wexpr.GetReturnType()), sel(STANDARD_VECTOR_SIZE) { ComputeOffset(context, wexpr, offset); ComputeDefault(context, wexpr, dflt); @@ -413,7 +414,7 @@ class WindowLeadLagStreamingState : 
public WindowExecutorStreamingState { void ExecuteLag(ExecutionContext &context, DataChunk &input, Vector &result) { D_ASSERT(offset >= 0); - auto &curr = curr_chunk.data[0]; + const auto &curr = curr_chunk.data[0]; curr_chunk.Reset(); executor.Execute(input, curr_chunk); const idx_t count = input.size(); @@ -449,7 +450,7 @@ class WindowLeadLagStreamingState : public WindowExecutorStreamingState { D_ASSERT(offset < 0); // Input has been set up with the number of rows we CAN produce. const idx_t count = input.size(); - auto &curr = curr_chunk.data[0]; + const auto &curr = curr_chunk.data[0]; // Copy unified[buffered:count] => result[pos:] idx_t pos = 0; idx_t unified_offset = buffered; @@ -515,7 +516,7 @@ struct WindowLeadLagExecutor : public WindowValueExecutor { //! Streaming APIs static bool CanStream(ClientContext &client, const BoundWindowExpression &wexpr, idx_t max_delta) { // We can stream LEAD/LAG if the arguments are constant and the delta is less than a block behind - if (!wexpr.ignore_nulls) { + if (!wexpr.IgnoreNulls()) { Value dflt; if (!WindowLeadLagStreamingState::ComputeDefault(client, wexpr, dflt)) { return false; @@ -552,58 +553,48 @@ unique_ptr WindowLeadLagExecutor::Bind(BindWindowFunctionInput &in return nullptr; } -static WindowFunctionSet GetLeadLagFunctionSet(const char *name, const ExpressionType &type) { - WindowFunctionSet funcs(name); - - auto bind = WindowLeadLagExecutor::Bind; - auto bounds = WindowLeadLagLocalState::GetBounds; - auto sharing = WindowLeadLagExecutor::GetSharing; - auto global = WindowLeadLagExecutor::GetGlobal; - auto local = WindowLeadLagExecutor::GetLocal; - auto sink = WindowLeadLagLocalState::Sinker; - auto finalize = WindowLeadLagLocalState::Finalizer; - auto evaluate = WindowLeadLagExecutor::GetData; - - funcs.AddFunction(WindowFunction({LogicalTypeId::ANY, LogicalType::BIGINT, LogicalTypeId::ANY}, LogicalType::ANY, - type, bind, bounds, sharing, global, local, sink, finalize, evaluate)); - 
funcs.AddFunction(WindowFunction({LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::ANY, type, bind, bounds, - sharing, global, local, sink, finalize, evaluate)); - funcs.AddFunction(WindowFunction({LogicalTypeId::ANY}, LogicalType::ANY, type, bind, bounds, sharing, global, local, - sink, finalize, evaluate)); - - for (auto &f : funcs.functions) { - f.SetCanStreamCallback(WindowLeadLagExecutor::CanStream); - f.SetStreamingStateCallback(WindowLeadLagExecutor::GetStreamingState); - f.SetStreamingDataCallback(WindowLeadLagExecutor::StreamData); - } - - return funcs; +static WindowFunction GetLeadLagFunction(const char *name, const ExpressionType &type) { + const auto bind = WindowLeadLagExecutor::Bind; + const auto bounds = WindowLeadLagLocalState::GetBounds; + const auto sharing = WindowLeadLagExecutor::GetSharing; + const auto global = WindowLeadLagExecutor::GetGlobal; + const auto local = WindowLeadLagExecutor::GetLocal; + const auto sink = WindowLeadLagLocalState::Sinker; + const auto finalize = WindowLeadLagLocalState::Finalizer; + const auto evaluate = WindowLeadLagExecutor::GetData; + + WindowFunction func(name, {}, LogicalType::ANY, type, bind, bounds, sharing, global, local, sink, finalize, + evaluate); + auto &sig = func.GetSignature(); + + // Type of the default value is not actually T, as it is force-cast to the return type - not unified. 
+ sig.AddParameter("col", LogicalType::TEMPLATE("T")); + sig.AddParameter("offset", LogicalType::BIGINT, Value::BIGINT(1)); + sig.AddParameter("default", LogicalTypeId::ANY, Value(LogicalTypeId::SQLNULL)); + sig.SetReturnType(LogicalType::TEMPLATE("T")); + + func.SetCanStreamCallback(WindowLeadLagExecutor::CanStream); + func.SetStreamingStateCallback(WindowLeadLagExecutor::GetStreamingState); + func.SetStreamingDataCallback(WindowLeadLagExecutor::StreamData); + + return func; } -WindowFunctionSet LeadFun::GetFunctions() { - return GetLeadLagFunctionSet(Name, ExpressionType::WINDOW_LEAD); +WindowFunction LeadFun::GetFunction() { + return GetLeadLagFunction(Name, ExpressionType::WINDOW_LEAD); } WindowFunction LeadFun::GetTypedFunction(const LogicalType &type, idx_t nargs) { - auto funcs = GetLeadLagFunctionSet(Name, ExpressionType::WINDOW_LEAD); - - for (auto &func : funcs.functions) { - if (func.GetSignature().GetParameterCount() != nargs) { - continue; - } - - func.GetSignature().GetParameter(0).SetType(type); - if (nargs > 2) { - func.GetSignature().GetParameter(2).SetType(type); - } - return func; + auto func = GetLeadLagFunction(Name, ExpressionType::WINDOW_LEAD); + func.GetSignature().GetParameter(0).SetType(type); + if (nargs > 2) { + func.GetSignature().GetParameter(2).SetType(type); } - - throw InternalException("Invalid number of arguments requested for LEAD: %lld", nargs); + return func; } -WindowFunctionSet LagFun::GetFunctions() { - return GetLeadLagFunctionSet(Name, ExpressionType::WINDOW_LAG); +WindowFunction LagFun::GetFunction() { + return GetLeadLagFunction(Name, ExpressionType::WINDOW_LAG); } unique_ptr WindowLeadLagExecutor::GetGlobal(ClientContext &client, const WindowExecutor &executor, @@ -622,7 +613,7 @@ void WindowLeadLagExecutor::GetData(ExecutionContext &context, DataChunk &eval_c idx_t row_idx, OperatorSinkInput &sink) { auto &glstate = sink.global_state.Cast(); auto &llstate = sink.local_state.Cast(); - const auto count = 
eval_chunk.size(); + const auto count = bounds.size(); auto &cursor = *llstate.cursor; auto &wexpr = glstate.executor.wexpr; @@ -699,10 +690,10 @@ void WindowLeadLagExecutor::GetData(ExecutionContext &context, DataChunk &eval_c auto &ignore_nulls = llstate.ignore_nulls; bool can_shift = ignore_nulls->CannotHaveNull() && !glstate.use_framing; if (has_offset) { - can_shift = can_shift && wexpr.children[1]->IsFoldable(); + can_shift = can_shift && wexpr.GetChildren()[1]->IsFoldable(); } if (has_default) { - can_shift = can_shift && wexpr.children[2]->IsFoldable(); + can_shift = can_shift && wexpr.GetChildren()[2]->IsFoldable(); } const auto row_end = row_idx + count; @@ -747,7 +738,7 @@ void WindowLeadLagExecutor::GetData(ExecutionContext &context, DataChunk &eval_c const idx_t col_idx = 0; while (width) { const auto source_offset = cursor.Seek(index); - auto &source = cursor.chunk.data[col_idx]; + const auto &source = cursor.chunk.data[col_idx]; const auto copied = MinValue(cursor.chunk.size() - source_offset, width); VectorOperations::Copy(source, result, source_offset + copied, source_offset, i); i += copied; @@ -789,9 +780,10 @@ struct WindowFirstValueExecutor : public WindowValueExecutor { //! 
Streaming APIs static bool CanStream(ClientContext &client, const BoundWindowExpression &wexpr, idx_t max_delta) { - if (wexpr.ignore_nulls) { + if (wexpr.IgnoreNulls()) { // We can stream first values ignoring NULLs if they are "running totals" - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; + return wexpr.WindowStart() == WindowBoundary::UNBOUNDED_PRECEDING && + wexpr.WindowEnd() == WindowBoundary::CURRENT_ROW_ROWS; } return true; } @@ -807,7 +799,7 @@ void WindowFirstValueExecutor::StreamData(ExecutionContext &context, DataChunk & // If we are ignoring NULLs and we started with a NULL, // then look for a non-NULL value and update it - if (wexpr.ignore_nulls && ConstantVector::IsNull(sstate.vec)) { + if (wexpr.IgnoreNulls() && ConstantVector::IsNull(sstate.vec)) { // Find the first non-NULL value auto &executor = sstate.eval; auto &arg = sstate.arg; @@ -821,7 +813,7 @@ void WindowFirstValueExecutor::StreamData(ExecutionContext &context, DataChunk & result.Reference(prev); } else { auto &sel = sstate.sel; - Vector split(wexpr.children[0]->GetReturnType()); + Vector split(wexpr.GetChildren()[0]->GetReturnType()); split.SetValue(0, prev.GetValue(0)); sel_t s = 0; for (sel_t i = 0; i < count; ++i) { @@ -858,10 +850,10 @@ void WindowFirstValueExecutor::GetData(ExecutionContext &context, DataChunk &eva auto &gvstate = sink.global_state.Cast(); auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto &frames = lvstate.frames; auto &ignore_nulls = *lvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; + auto exclude_mode = gvstate.executor.wexpr.WindowExclude(); WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { if (gvstate.value_tree) { idx_t frame_width = 0; @@ -913,7 +905,8 @@ struct WindowLastValueExecutor : public WindowValueExecutor { 
//! Streaming APIs static bool CanStream(ClientContext &client, const BoundWindowExpression &wexpr, idx_t max_delta) { // We can stream last values if they are "running totals" - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; + return wexpr.WindowStart() == WindowBoundary::UNBOUNDED_PRECEDING && + wexpr.WindowEnd() == WindowBoundary::CURRENT_ROW_ROWS; } static void StreamData(ExecutionContext &context, DataChunk &input, DataChunk &delayed, idx_t delayed_capacity, Vector &result, LocalSourceState &state); @@ -926,7 +919,7 @@ void WindowLastValueExecutor::StreamData(ExecutionContext &context, DataChunk &i auto &wexpr = sstate.wexpr; const auto count = input.size(); auto &executor = sstate.eval; - if (wexpr.ignore_nulls) { + if (wexpr.IgnoreNulls()) { auto &prev = sstate.vec; auto &arg = sstate.arg; executor.ExecuteExpression(input, arg); @@ -937,7 +930,7 @@ void WindowLastValueExecutor::StreamData(ExecutionContext &context, DataChunk &i VectorOperations::Copy(arg, result, count, 0, 0); } else { // Copy the data as it may be a reference to the argument - Vector copy(wexpr.children[0]->GetReturnType()); + Vector copy(wexpr.GetChildren()[0]->GetReturnType()); VectorOperations::Copy(arg, copy, count, 0, 0); // Overwrite the previous non-NULL value if the first one is NULL if (!validity.RowIsValidUnsafe(0)) { @@ -978,10 +971,10 @@ void WindowLastValueExecutor::GetData(ExecutionContext &context, DataChunk &eval auto &gvstate = sink.global_state.Cast(); auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto &frames = lvstate.frames; auto &ignore_nulls = *lvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; + auto exclude_mode = gvstate.executor.wexpr.WindowExclude(); WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { if (gvstate.value_tree) { 
idx_t frame_width = 0; @@ -1034,7 +1027,7 @@ void WindowLastValueExecutor::GetData(ExecutionContext &context, DataChunk &eval class WindowNthValueStreamingState : public WindowValueStreamingState { public: static bool ComputeNthIndex(ClientContext &client, const BoundWindowExpression &wexpr, idx_t &nth_index) { - auto &nth_expr = wexpr.children[1]; + auto &nth_expr = wexpr.GetChildren()[1]; if (nth_expr && (nth_expr->HasParameter() || !nth_expr->IsFoldable())) { return false; } @@ -1080,7 +1073,8 @@ struct WindowNthValueExecutor : public WindowValueExecutor { } // We can stream Nth values if they are "running totals" - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; + return wexpr.WindowStart() == WindowBoundary::UNBOUNDED_PRECEDING && + wexpr.WindowEnd() == WindowBoundary::CURRENT_ROW_ROWS; } static unique_ptr GetStreamingState(ClientContext &client, DataChunk &input, const BoundWindowExpression &wexpr) { @@ -1095,7 +1089,7 @@ struct WindowNthValueExecutor : public WindowValueExecutor { void WindowNthValueStreamingState::StreamData(ExecutionContext &context, DataChunk &input, Vector &result) { // If we haven't reached the chunk with the value yet, reference the current value const auto count = input.size(); - if (!wexpr.ignore_nulls && nth_count + count < nth_index) { + if (!wexpr.IgnoreNulls() && nth_count + count < nth_index) { result.Reference(vec); nth_count += count; return; @@ -1113,12 +1107,12 @@ void WindowNthValueStreamingState::StreamData(ExecutionContext &context, DataChu const auto &validity = unified.validity; // Split the result between NULLs and the Nth Value - Vector split(wexpr.children[0]->GetReturnType(), 2); + Vector split(wexpr.GetChildren()[0]->GetReturnType(), 2); sel_t s = 0; split.SetValue(s, vec.GetValue(0)); // If we are ignoring NULLS and there are NULLS, search for a non-NULL value. 
- const bool any_value = wexpr.ignore_nulls && !validity.CannotHaveNull(); + const bool any_value = wexpr.IgnoreNulls() && !validity.CannotHaveNull(); for (idx_t i = 0; i < count; ++i) { if (!any_value || validity.RowIsValidUnsafe(unified.sel->get_index(i))) { ++nth_count; @@ -1152,10 +1146,10 @@ void WindowNthValueExecutor::GetData(ExecutionContext &context, DataChunk &eval_ auto &gvstate = sink.global_state.Cast(); auto &lvstate = sink.local_state.Cast(); auto &cursor = *lvstate.cursor; - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto &frames = lvstate.frames; auto &ignore_nulls = *lvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; + auto exclude_mode = gvstate.executor.wexpr.WindowExclude(); D_ASSERT(cursor.chunk.ColumnCount() == 1); const auto &child_idx = gvstate.executor.child_idx; const auto nth_idx = child_idx[1]; @@ -1492,18 +1486,18 @@ unique_ptr WindowFillExecutor::Bind(BindWindowFunctionInput &input void WindowFillExecutor::GetSharing(WindowExecutor &executor, WindowSharedExpressions &shared) { const auto &wexpr = executor.wexpr; - D_ASSERT(!wexpr.children.empty()); + D_ASSERT(!wexpr.GetChildren().empty()); //! Never ignore nulls (that's the point!) auto &child_idx = executor.child_idx; - child_idx.emplace_back(shared.RegisterCollection(wexpr.children[0], false)); + child_idx.emplace_back(shared.RegisterCollection(wexpr.GetChildren()[0], false)); // If the argument order is prefix of the partition ordering, // then we can just use the partition ordering. 
- auto &arg_orders = wexpr.arg_orders; + auto &arg_orders = wexpr.ArgOrders(); auto &arg_order_idx = executor.arg_order_idx; - if (BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { - for (const auto &order : wexpr.arg_orders) { + if (BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) != arg_orders.size()) { + for (const auto &order : wexpr.ArgOrders()) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } @@ -1512,22 +1506,22 @@ void WindowFillExecutor::GetSharing(WindowExecutor &executor, WindowSharedExpres if (arg_order_idx.empty()) { // We use the range ordering, even if it has not been defined if (executor.range_idx == DConstants::INVALID_INDEX) { - D_ASSERT(wexpr.orders.size() == 1); + D_ASSERT(wexpr.OrderBy().size() == 1); // We don't need the validity mask because we have also requested the valid range for the ordering. - executor.range_idx = shared.RegisterCollection(wexpr.orders[0].expression, false); + executor.range_idx = shared.RegisterCollection(wexpr.OrderBy()[0].expression, false); } executor.aux_idx.emplace_back(executor.range_idx); } else { // For secondary sorts, we need the entire collection so we can interpolate using the values D_ASSERT(arg_order_idx.size() == 1); - executor.aux_idx.emplace_back(shared.RegisterCollection(wexpr.arg_orders[0].expression, false)); + executor.aux_idx.emplace_back(shared.RegisterCollection(wexpr.ArgOrders()[0].expression, false)); } } static void WindowFillCopy(WindowCursor &cursor, Vector &result, idx_t count, idx_t row_idx, column_t col_idx = 0) { for (idx_t target_offset = 0; target_offset < count;) { const auto source_offset = cursor.Seek(row_idx); - auto &source = cursor.chunk.data[col_idx]; + const auto &source = cursor.chunk.data[col_idx]; const auto copied = MinValue(cursor.chunk.size() - source_offset, count - target_offset); VectorOperations::Copy(source, result, source_offset + copied, source_offset, target_offset); target_offset 
+= copied; @@ -1557,9 +1551,9 @@ class WindowFillLocalState : public WindowLeadLagLocalState { required.insert(FRAME_BEGIN); required.insert(FRAME_END); - auto &arg_orders = wexpr.arg_orders; - const auto shared = BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) == arg_orders.size(); - if (wexpr.arg_orders.empty() || shared) { + auto &arg_orders = wexpr.ArgOrders(); + const auto shared = BoundWindowExpression::GetSharedOrders(wexpr.OrderBy(), arg_orders) == arg_orders.size(); + if (wexpr.ArgOrders().empty() || shared) { // FILL uses the validity ranges to quickly eliminate indexes that can't be interpolated. // This only works for non-secondary orderings required.insert(VALID_BEGIN); @@ -1611,7 +1605,7 @@ unique_ptr WindowFillExecutor::GetLocal(ExecutionContext &contex void WindowFillExecutor::GetData(ExecutionContext &context, DataChunk &eval_chunk, DataChunk &bounds, Vector &result, idx_t row_idx, OperatorSinkInput &sink) { auto &lfstate = sink.local_state.Cast(); - const auto count = eval_chunk.size(); + const auto count = bounds.size(); auto &cursor = *lfstate.cursor; auto &order_cursor = *lfstate.order_cursor; @@ -1633,15 +1627,15 @@ void WindowFillExecutor::GetData(ExecutionContext &context, DataChunk &eval_chun idx_t next_valid = DConstants::INVALID_INDEX; const auto &wexpr = gfstate.executor.wexpr; - auto interpolate_func = GetFillInterpolateFunction(wexpr.children[0]->GetReturnType()); - auto value_func = GetFillValueFunction(wexpr.children[0]->GetReturnType()); + auto interpolate_func = GetFillInterpolateFunction(wexpr.GetChildren()[0]->GetReturnType()); + auto value_func = GetFillValueFunction(wexpr.GetChildren()[0]->GetReturnType()); // Secondary sort - use the MSTs if (gfstate.value_tree) { // Roughly what we need to do is find the previous and next non-null values // with non-null ordering values. 
This is essentially LEAD/LAG(IGNORE NULLS) - auto slope_func = GetFillSlopeFunction(wexpr.arg_orders[0].expression->GetReturnType()); - auto order_value_func = GetFillValueFunction(wexpr.arg_orders[0].expression->GetReturnType()); + auto slope_func = GetFillSlopeFunction(wexpr.ArgOrders()[0].expression->GetReturnType()); + auto order_value_func = GetFillValueFunction(wexpr.ArgOrders()[0].expression->GetReturnType()); auto &frames = lfstate.frames; frames.resize(1); auto &frame = frames[0]; @@ -1752,8 +1746,8 @@ void WindowFillExecutor::GetData(ExecutionContext &context, DataChunk &eval_chun auto valid_begin = FlatVector::GetData(bounds.data[VALID_BEGIN]); auto valid_end = FlatVector::GetData(bounds.data[VALID_END]); idx_t prev_partition = DConstants::INVALID_INDEX; - auto slope_func = GetFillSlopeFunction(wexpr.orders[0].expression->GetReturnType()); - auto order_value_func = GetFillValueFunction(wexpr.orders[0].expression->GetReturnType()); + auto slope_func = GetFillSlopeFunction(wexpr.OrderBy()[0].expression->GetReturnType()); + auto order_value_func = GetFillValueFunction(wexpr.OrderBy()[0].expression->GetReturnType()); for (idx_t i = 0; i < count; ++i, ++row_idx) { // Did we change partitions? diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index b6a1d5deb..fa3cf5688 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -934,7 +934,7 @@ process. Must be destroyed with 'duckdb_destroy_instance_cache'. * @return The database instance cache. */ -DUCKDB_C_API duckdb_instance_cache duckdb_create_instance_cache(); +DUCKDB_C_API duckdb_instance_cache duckdb_create_instance_cache(void); /*! Creates a new database instance in the instance cache, or retrieves an existing database instance. @@ -1071,7 +1071,7 @@ Returns the version of the linked DuckDB, with a version postfix for dev version Usually used for developing C extensions that must return this for a compatibility check. 
*/ -DUCKDB_C_API const char *duckdb_library_version(); +DUCKDB_C_API const char *duckdb_library_version(void); /*! Get the list of (fully qualified) table names of the query. @@ -1114,7 +1114,7 @@ This should not be called in a loop as it internally loops over all the options. * @return The amount of config options available. */ -DUCKDB_C_API size_t duckdb_config_count(); +DUCKDB_C_API size_t duckdb_config_count(void); /*! Obtains a human-readable name and description of a specific configuration option. This can be used to e.g. @@ -1653,7 +1653,7 @@ This is the amount of tuples that will fit into a data chunk created by `duckdb_ * @return The vector size. */ -DUCKDB_C_API idx_t duckdb_vector_size(); +DUCKDB_C_API idx_t duckdb_vector_size(void); /*! Whether or not the duckdb_string_t value is inlined. @@ -2518,8 +2518,11 @@ DUCKDB_C_API duckdb_value duckdb_create_bignum(duckdb_bignum input); /*! Creates a DECIMAL value from a duckdb_decimal +The width must be between 1 and 38, and the scale must not exceed the width. + * @param input The duckdb_decimal value -* @return The value. This must be destroyed with `duckdb_destroy_value`. +* @return The value, or `nullptr` if the width or scale are out of range. This must be destroyed with +`duckdb_destroy_value`. */ DUCKDB_C_API duckdb_value duckdb_create_decimal(duckdb_decimal input); @@ -3000,7 +3003,7 @@ Creates a value of type SQLNULL. * @return The duckdb_value representing SQLNULL. This must be destroyed with `duckdb_destroy_value`. */ -DUCKDB_C_API duckdb_value duckdb_create_null_value(); +DUCKDB_C_API duckdb_value duckdb_create_null_value(void); /*! Returns the number of elements in a LIST value. @@ -3156,9 +3159,9 @@ DUCKDB_C_API duckdb_logical_type duckdb_create_enum_type(const char **member_nam Creates a DECIMAL type with the specified width and scale. The resulting type should be destroyed with `duckdb_destroy_logical_type`. 
-* @param width The width of the decimal type -* @param scale The scale of the decimal type -* @return The logical type. +* @param width The width of the decimal type. Must be between 1 and 38. +* @param scale The scale of the decimal type. Must not exceed the width. +* @return The logical type, or `nullptr` if the width or scale are out of range. */ DUCKDB_C_API duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale); @@ -3706,7 +3709,7 @@ The return value must be destroyed with `duckdb_destroy_scalar_function`. * @return The scalar function object. */ -DUCKDB_C_API duckdb_scalar_function duckdb_create_scalar_function(); +DUCKDB_C_API duckdb_scalar_function duckdb_create_scalar_function(void); /*! Destroys the given scalar function object. @@ -4042,7 +4045,7 @@ The return value should be destroyed with `duckdb_destroy_aggregate_function`. * @return The aggregate function object. */ -DUCKDB_C_API duckdb_aggregate_function duckdb_create_aggregate_function(); +DUCKDB_C_API duckdb_aggregate_function duckdb_create_aggregate_function(void); /*! Destroys the given aggregate function object. @@ -4203,7 +4206,7 @@ The return value should be destroyed with `duckdb_destroy_table_function`. * @return The table function object. */ -DUCKDB_C_API duckdb_table_function duckdb_create_table_function(); +DUCKDB_C_API duckdb_table_function duckdb_create_table_function(void); /*! Destroys the given table function object. @@ -5417,7 +5420,7 @@ Creates a new cast function object. * @return The cast function object. */ -DUCKDB_C_API duckdb_cast_function duckdb_create_cast_function(); +DUCKDB_C_API duckdb_cast_function duckdb_create_cast_function(void); /*! Sets the source type of the cast function. @@ -5605,7 +5608,7 @@ Creates a new file open options instance with blank settings. * @return The new file open options instance. Must be destroyed with `duckdb_destroy_file_open_options`. 
*/ -DUCKDB_C_API duckdb_file_open_options duckdb_create_file_open_options(); +DUCKDB_C_API duckdb_file_open_options duckdb_create_file_open_options(void); /*! Sets a specific flag in the file open options. @@ -5716,7 +5719,7 @@ Creates a configuration option instance. * @return The resulting configuration option instance. Must be destroyed with `duckdb_destroy_config_option`. */ -DUCKDB_C_API duckdb_config_option duckdb_create_config_option(); +DUCKDB_C_API duckdb_config_option duckdb_create_config_option(void); /*! Destroys the given configuration option instance. @@ -5804,7 +5807,7 @@ The return value must be destroyed with `duckdb_destroy_copy_function`. * @return The copy function object. */ -DUCKDB_C_API duckdb_copy_function duckdb_create_copy_function(); +DUCKDB_C_API duckdb_copy_function duckdb_create_copy_function(void); /*! Sets the name of the copy function. @@ -6231,7 +6234,7 @@ Creates a new log storage object. * @return A log storage object. Must be destroyed with `duckdb_destroy_log_storage`. */ -DUCKDB_C_API duckdb_log_storage duckdb_create_log_storage(); +DUCKDB_C_API duckdb_log_storage duckdb_create_log_storage(void); /*! Destroys a log storage object. diff --git a/src/duckdb/src/include/duckdb/catalog/catalog.hpp b/src/duckdb/src/include/duckdb/catalog/catalog.hpp index 7a1e69536..48b158af3 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog.hpp @@ -79,6 +79,22 @@ class LogicalDelete; class LogicalUpdate; class CreateStatement; class CatalogEntryRetriever; +class QueryNode; + +//! Per-capability opt-in for remote catalogs. Each value gates a specific dispatch path or +//! engine-wide accommodation: +//! - IS_REMOTE: this catalog represents data hosted on a remote server. Drives generic +//! accommodations (e.g. CheckAmbiguousCatalogOrSchema skips schema lookups, database_manager +//! counts attached remote catalogs). +//! 
- EXECUTE_QUERY_NODE: `RemoteExecute(QueryNode)` is implemented; the RemotePushdownOptimizer +//! may push down structured queries to this catalog. +//! - CONNECT: `RemoteExecute(string)` is implemented; the CONNECT chokepoint may route raw SQL +//! to this catalog. +enum class RemoteCapability : uint8_t { + IS_REMOTE, + EXECUTE_QUERY_NODE, + CONNECT, +}; //! The Catalog object represents the catalog of the database. class Catalog { @@ -92,16 +108,16 @@ class Catalog { //! Get the SystemCatalog from the DatabaseInstance DUCKDB_API static Catalog &GetSystemCatalog(DatabaseInstance &db); //! Get the specified Catalog from the ClientContext - DUCKDB_API static Catalog &GetCatalog(ClientContext &context, const string &catalog_name); + DUCKDB_API static Catalog &GetCatalog(ClientContext &context, const Identifier &catalog_name); //! Get the specified Catalog from the ClientContext - DUCKDB_API static Catalog &GetCatalog(CatalogEntryRetriever &retriever, const string &catalog_name); + DUCKDB_API static Catalog &GetCatalog(CatalogEntryRetriever &retriever, const Identifier &catalog_name); //! Get the specified Catalog from the DatabaseInstance - DUCKDB_API static Catalog &GetCatalog(DatabaseInstance &db, const string &catalog_name); + DUCKDB_API static Catalog &GetCatalog(DatabaseInstance &db, const Identifier &catalog_name); //! Gets the specified Catalog from the database if it exists - DUCKDB_API static optional_ptr GetCatalogEntry(ClientContext &context, const string &catalog_name); + DUCKDB_API static optional_ptr GetCatalogEntry(ClientContext &context, const Identifier &catalog_name); //! Gets the specified Catalog from the database if it exists DUCKDB_API static optional_ptr GetCatalogEntry(CatalogEntryRetriever &retriever, - const string &catalog_name); + const Identifier &catalog_name); //! Get the specific Catalog from the AttachedDatabase DUCKDB_API static Catalog &GetCatalog(AttachedDatabase &db); @@ -129,7 +145,7 @@ class Catalog { } //! 
Returns the catalog name - based on how the catalog was attached - DUCKDB_API const string &GetName() const; + DUCKDB_API const Identifier &GetName() const; DUCKDB_API idx_t GetOid(); DUCKDB_API virtual string GetCatalogType() = 0; @@ -227,53 +243,55 @@ class Catalog { DUCKDB_API optional_ptr GetSchema(ClientContext &context, const EntryLookupInfo &schema_lookup, OnEntryNotFound if_not_found); //! Overloadable method for giving warnings on ambiguous naming id.tab due to a database and schema with name id - DUCKDB_API virtual bool CheckAmbiguousCatalogOrSchema(ClientContext &context, const string &schema); + DUCKDB_API virtual bool CheckAmbiguousCatalogOrSchema(ClientContext &context, const Identifier &schema); - DUCKDB_API SchemaCatalogEntry &GetSchema(ClientContext &context, const string &schema); - DUCKDB_API SchemaCatalogEntry &GetSchema(CatalogTransaction transaction, const string &schema); + DUCKDB_API SchemaCatalogEntry &GetSchema(ClientContext &context, const Identifier &schema); + DUCKDB_API SchemaCatalogEntry &GetSchema(CatalogTransaction transaction, const Identifier &schema); DUCKDB_API SchemaCatalogEntry &GetSchema(CatalogTransaction transaction, const EntryLookupInfo &schema_lookup); - DUCKDB_API static SchemaCatalogEntry &GetSchema(ClientContext &context, const string &catalog_name, + DUCKDB_API static SchemaCatalogEntry &GetSchema(ClientContext &context, const Identifier &catalog_name, const EntryLookupInfo &schema_lookup); - DUCKDB_API optional_ptr GetSchema(ClientContext &context, const string &schema, + DUCKDB_API optional_ptr GetSchema(ClientContext &context, const Identifier &schema, OnEntryNotFound if_not_found); - DUCKDB_API optional_ptr GetSchema(CatalogTransaction transaction, const string &schema, + DUCKDB_API optional_ptr GetSchema(CatalogTransaction transaction, const Identifier &schema, OnEntryNotFound if_not_found); - DUCKDB_API static optional_ptr GetSchema(ClientContext &context, const string &catalog_name, + DUCKDB_API static 
optional_ptr GetSchema(ClientContext &context, const Identifier &catalog_name, const EntryLookupInfo &schema_lookup, OnEntryNotFound if_not_found); - DUCKDB_API static SchemaCatalogEntry &GetSchema(ClientContext &context, const string &catalog_name, - const string &schema); - DUCKDB_API static optional_ptr GetSchema(ClientContext &context, const string &catalog_name, - const string &schema, OnEntryNotFound if_not_found); + DUCKDB_API static SchemaCatalogEntry &GetSchema(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema); + DUCKDB_API static optional_ptr GetSchema(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema, + OnEntryNotFound if_not_found); DUCKDB_API static optional_ptr GetSchema(CatalogEntryRetriever &retriever, - const string &catalog_name, + const Identifier &catalog_name, const EntryLookupInfo &schema_lookup, OnEntryNotFound if_not_found); //! Scans all the schemas in the system one-by-one, invoking the callback for each entry DUCKDB_API virtual void ScanSchemas(ClientContext &context, std::function callback) = 0; //! 
Gets the "schema.name" entry of the specified type, if entry does not exist behavior depends on OnEntryNotFound - DUCKDB_API optional_ptr GetEntry(ClientContext &context, const string &schema, + DUCKDB_API optional_ptr GetEntry(ClientContext &context, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found); DUCKDB_API optional_ptr GetEntry(ClientContext &context, CatalogType catalog_type, - const string &schema, const string &name, + const Identifier &schema, const Identifier &name, OnEntryNotFound if_not_found); - DUCKDB_API optional_ptr GetEntry(CatalogEntryRetriever &retriever, const string &schema, + DUCKDB_API optional_ptr GetEntry(CatalogEntryRetriever &retriever, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found); - DUCKDB_API CatalogEntry &GetEntry(ClientContext &context, const string &schema, const EntryLookupInfo &lookup_info); + DUCKDB_API CatalogEntry &GetEntry(ClientContext &context, const Identifier &schema, + const EntryLookupInfo &lookup_info); //! Gets the "catalog.schema.name" entry of the specified type, if entry does not exist behavior depends on //! 
OnEntryNotFound - DUCKDB_API static optional_ptr GetEntry(ClientContext &context, const string &catalog, - const string &schema, const EntryLookupInfo &lookup_info, + DUCKDB_API static optional_ptr GetEntry(ClientContext &context, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found); - DUCKDB_API static optional_ptr GetEntry(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema, const EntryLookupInfo &lookup_info, + DUCKDB_API static optional_ptr GetEntry(CatalogEntryRetriever &retriever, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found); - DUCKDB_API static CatalogEntry &GetEntry(ClientContext &context, const string &catalog, const string &schema, - const EntryLookupInfo &lookup_info); + DUCKDB_API static CatalogEntry &GetEntry(ClientContext &context, const Identifier &catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info); template - optional_ptr GetEntry(ClientContext &context, const string &schema_name, const string &name, + optional_ptr GetEntry(ClientContext &context, const Identifier &schema_name, const Identifier &name, OnEntryNotFound if_not_found, QueryErrorContext error_context = QueryErrorContext()) { EntryLookupInfo lookup_info(T::Type, name, error_context); auto entry = GetEntry(context, schema_name, lookup_info, if_not_found); @@ -287,16 +305,16 @@ class Catalog { } template - T &GetEntry(ClientContext &context, const string &schema_name, const string &name, + T &GetEntry(ClientContext &context, const Identifier &schema_name, const Identifier &name, QueryErrorContext error_context = QueryErrorContext()) { auto entry = GetEntry(context, schema_name, name, OnEntryNotFound::THROW_EXCEPTION, error_context); return *entry; } - static CatalogEntry &GetEntry(ClientContext &context, CatalogType catalog_type, const string &catalog_name, - const string &schema_name, const string 
&name); - CatalogEntry &GetEntry(ClientContext &context, CatalogType catalog_type, const string &schema_name, - const string &name); + static CatalogEntry &GetEntry(ClientContext &context, CatalogType catalog_type, const Identifier &catalog_name, + const Identifier &schema_name, const Identifier &name); + CatalogEntry &GetEntry(ClientContext &context, CatalogType catalog_type, const Identifier &schema_name, + const Identifier &name); //! Append a scalar or aggregate function to the catalog DUCKDB_API optional_ptr AddFunction(ClientContext &context, CreateFunctionInfo &info); @@ -343,6 +361,19 @@ class Catalog { } virtual ErrorData SupportsCreateTable(BoundCreateTableInfo &info); + virtual bool Supports(RemoteCapability capability) const { + return false; + } + virtual unique_ptr RemoteExecute(ClientContext &context, unique_ptr node); + virtual unique_ptr RemoteExecute(ClientContext &context, const string &sql); + virtual bool SupportsPushdown(const ParsedExpression &expression); + virtual bool SupportsPushdown(const TableRef &ref); + virtual bool SupportsPushdown(const QueryNode &node); + //! User-facing short identifier for this catalog (e.g. shown in the CLI prompt when CONNECT-ed). + //! Defaults to the AttachedDatabase name (the AS alias). Remote catalogs override to expose + //! backend-specific information — the URI for quack, host:port/dbname for postgres, etc. + DUCKDB_API virtual string GetConnectDisplay(); + //! Whether or not this catalog should search a specific type with the standard priority DUCKDB_API virtual CatalogLookupBehavior CatalogTypeLookupRule(CatalogType type) const { return CatalogLookupBehavior::STANDARD; @@ -354,7 +385,7 @@ class Catalog { //! The default table is used for `SELECT * FROM ;` //! 
FIXME: these should be virtual methods DUCKDB_API bool HasDefaultTable() const; - DUCKDB_API void SetDefaultTable(const string &schema, const string &name); + DUCKDB_API void SetDefaultTable(const Identifier &schema, const Identifier &name); DUCKDB_API string GetDefaultTable() const; DUCKDB_API string GetDefaultTableSchema() const; @@ -366,8 +397,8 @@ class Catalog { public: template - static optional_ptr GetEntry(ClientContext &context, const string &catalog_name, const string &schema_name, - const string &name, OnEntryNotFound if_not_found, + static optional_ptr GetEntry(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema_name, const Identifier &name, OnEntryNotFound if_not_found, QueryErrorContext error_context = QueryErrorContext()) { EntryLookupInfo lookup_info(T::Type, name, error_context); auto entry = GetEntry(context, catalog_name, schema_name, lookup_info, if_not_found); @@ -380,8 +411,8 @@ class Catalog { return &entry->template Cast(); } template - static T &GetEntry(ClientContext &context, const string &catalog_name, const string &schema_name, - const string &name, QueryErrorContext error_context = QueryErrorContext()) { + static T &GetEntry(ClientContext &context, const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, QueryErrorContext error_context = QueryErrorContext()) { auto entry = GetEntry(context, catalog_name, schema_name, name, OnEntryNotFound::THROW_EXCEPTION, error_context); return *entry; @@ -433,8 +464,8 @@ class Catalog { static CatalogEntryLookup TryLookupEntry(CatalogEntryRetriever &retriever, const vector &lookups, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found, bool allow_default_table_lookup); - static CatalogEntryLookup TryLookupEntry(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema, const EntryLookupInfo &lookup_info, + static CatalogEntryLookup TryLookupEntry(CatalogEntryRetriever &retriever, const Identifier 
&catalog, + const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound if_not_found); //! Looks for a Catalog with a DefaultTable that matches the lookup diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp index 02ed1484b..c931a960a 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/atomic.hpp" @@ -34,8 +35,8 @@ struct CreateInfo; //! Abstract base class of an entry in the catalog class CatalogEntry { public: - CatalogEntry(CatalogType type, Catalog &catalog, string name); - CatalogEntry(CatalogType type, string name, idx_t oid); + CatalogEntry(CatalogType type, Catalog &catalog, Identifier name); + CatalogEntry(CatalogType type, Identifier name, idx_t oid); virtual ~CatalogEntry(); //! The oid of the entry @@ -45,7 +46,7 @@ class CatalogEntry { //! Reference to the catalog set this entry is stored in optional_ptr set; //! The name of the entry - string name; + Identifier name; //! Whether or not the object is deleted bool deleted; //! Whether or not the object is temporary and should not be added to the WAL @@ -53,7 +54,7 @@ class CatalogEntry { //! Whether or not the entry is an internal entry (cannot be deleted, not dumped, etc) bool internal; //! The name of the extension that registered this entry (empty for core entries) - string extension_name; + Identifier extension_name; //! Timestamp at which the catalog entry was created atomic timestamp; //! (optional) comment on this entry @@ -65,7 +66,7 @@ class CatalogEntry { //! Child entry unique_ptr child; //! 
Parent entry (the node that dependents_map this node) - optional_ptr parent; + atomic parent; public: virtual unique_ptr AlterEntry(ClientContext &context, AlterInfo &info); @@ -119,7 +120,7 @@ class CatalogEntry { class InCatalogEntry : public CatalogEntry { public: - InCatalogEntry(CatalogType type, Catalog &catalog, string name); + InCatalogEntry(CatalogType type, Catalog &catalog, Identifier name); ~InCatalogEntry() override; //! The catalog the entry belongs to diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp index d02d52f00..3063335b6 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_index_entry.hpp @@ -16,12 +16,12 @@ class TableCatalogEntry; //! Wrapper class to allow copying a DuckIndexEntry (for altering the DuckIndexEntry metadata such as comments) struct IndexDataTableInfo { - IndexDataTableInfo(shared_ptr info_p, const string &index_name_p); + IndexDataTableInfo(shared_ptr info_p, const Identifier &index_name_p); //! Pointer to the DataTableInfo shared_ptr info; //! The index to be removed on destruction - string index_name; + Identifier index_name; }; //! 
A duck index entry @@ -43,8 +43,8 @@ class DuckIndexEntry : public IndexCatalogEntry { idx_t initial_index_size; public: - string GetSchemaName() const override; - string GetTableName() const override; + Identifier GetSchemaName() const override; + Identifier GetTableName() const override; DataTableInfo &GetDataTableInfo() const; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index b8943b8c8..0ebc3f355 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -53,7 +53,12 @@ class DuckTableEntry : public TableCatalogEntry { TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; - vector GetColumnSegmentInfo(const QueryContext &context) override; + vector + GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options = ColumnSegmentInfoScanOptions {}) override; + void InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) override; + bool ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) override; TableStorageInfo GetStorageInfo(ClientContext &context) override; @@ -70,7 +75,7 @@ class DuckTableEntry : public TableCatalogEntry { //! Scan all triggers without a transaction (used by checkpoint writer) void ScanTriggersNonTransactional(const std::function &callback); //! 
Drop a trigger by name - bool DropTrigger(CatalogTransaction transaction, const string &name, bool cascade); + bool DropTrigger(CatalogTransaction transaction, const Identifier &name, bool cascade); private: unique_ptr RenameColumn(ClientContext &context, RenameColumnInfo &info); diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp index d744364f0..53c9728a0 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/function_entry.hpp @@ -25,7 +25,7 @@ class FunctionEntry : public StandardEntry { this->extension_name = info.extension_name; } - string alias_of; + Identifier alias_of; vector descriptions; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp index 2b1ea037f..4298da932 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/index_catalog_entry.hpp @@ -46,8 +46,8 @@ class IndexCatalogEntry : public StandardEntry { //! Returns the original CREATE INDEX SQL string ToSQL() const override; - virtual string GetSchemaName() const = 0; - virtual string GetTableName() const = 0; + virtual Identifier GetSchemaName() const = 0; + virtual Identifier GetTableName() const = 0; //! 
Returns true, if this index is UNIQUE bool IsUnique() const; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp index f3832b920..e33f173cb 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/schema_catalog_entry.hpp @@ -99,7 +99,7 @@ class SchemaCatalogEntry : public InCatalogEntry { const EntryLookupInfo &lookup_info); DUCKDB_API optional_ptr GetEntry(CatalogTransaction transaction, CatalogType type, - const string &name); + const Identifier &name); //! Drops an entry from the schema virtual void DropEntry(ClientContext &context, DropInfo &info) = 0; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index c13d95f50..cbad23af6 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -11,6 +11,7 @@ #include "duckdb/catalog/catalog_entry/trigger_catalog_entry.hpp" #include "duckdb/catalog/catalog_transaction.hpp" #include "duckdb/catalog/standard_entry.hpp" +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" #include "duckdb/common/enums/trigger_type.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/column_list.hpp" @@ -49,6 +50,7 @@ struct EntryLookupInfo; class Binder; struct ColumnSegmentInfo; +struct ColumnSegmentInfoScanState; class TableStorageInfo; class LogicalGet; @@ -71,10 +73,10 @@ class TableCatalogEntry : public StandardEntry { DUCKDB_API bool HasGeneratedColumns() const; //! Returns whether or not a column with the given name exists - DUCKDB_API bool ColumnExists(const string &name) const; + DUCKDB_API bool ColumnExists(const Identifier &name) const; //! 
Returns a reference to the column of the specified name. Throws an //! exception if the column does not exist. - DUCKDB_API const ColumnDefinition &GetColumn(const string &name) const; + DUCKDB_API const ColumnDefinition &GetColumn(const Identifier &name) const; //! Returns a reference to the column of the specified logical index. Throws an //! exception if the column does not exist. DUCKDB_API const ColumnDefinition &GetColumn(LogicalIndex idx) const; @@ -98,7 +100,7 @@ class TableCatalogEntry : public StandardEntry { //! If the column does not exist: //! If if_column_exists is true, returns DConstants::INVALID_INDEX //! If if_column_exists is false, throws an exception - DUCKDB_API LogicalIndex GetColumnIndex(string &name, bool if_exists = false) const; + DUCKDB_API LogicalIndex GetColumnIndex(Identifier &name, bool if_exists = false) const; DUCKDB_API StorageIndex GetStorageIndex(const ColumnIndex &column_index) const; //! Returns the scan function that can be used to scan the given table @@ -116,7 +118,14 @@ class TableCatalogEntry : public StandardEntry { static string ColumnNamesToSQL(const ColumnList &columns); //! Returns a list of segment information for this table, if exists - virtual vector GetColumnSegmentInfo(const QueryContext &context); + virtual vector + GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options = ColumnSegmentInfoScanOptions {}); + //! Initialize an incremental scan over the table's column segment info. + virtual void InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state); + //! Append the next row group's column segment info to result. Returns false when no row groups remain. + virtual bool ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result); //! 
Returns the storage info of this table virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/trigger_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/trigger_catalog_entry.hpp index fd2bede2a..c98166786 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/trigger_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/trigger_catalog_entry.hpp @@ -23,9 +23,13 @@ class TriggerCatalogEntry : public StandardEntry { //! The event that fires the trigger (INSERT/DELETE/UPDATE) TriggerEventType event_type; //! Columns for UPDATE OF - vector columns; + vector columns; //! Whether this fires FOR EACH ROW or FOR EACH STATEMENT TriggerForEach for_each; + //! Alias for the NEW TABLE transition table (empty if not specified) + Identifier referencing_new_table; + //! Alias for the OLD TABLE transition table (empty if not specified) + Identifier referencing_old_table; //! The trigger action (INSERT/UPDATE/DELETE as QueryNode) unique_ptr trigger_action; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp index b860ec855..07df92eb9 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/view_catalog_entry.hpp @@ -25,7 +25,7 @@ enum class BindViewAction { BIND_IF_UNBOUND, FORCE_REBIND }; struct ViewColumnInfo { vector types; - vector names; + vector names; }; //! A view catalog entry @@ -43,14 +43,14 @@ class ViewCatalogEntry : public StandardEntry { //! The SQL query (if any) string sql; //! The set of aliases associated with the view - vector aliases; + vector aliases; //! Returns the view column info, if the view is bound. Otherwise returns `nullptr` virtual shared_ptr GetColumnInfo() const; //! 
Bind a view so we know the types / names returned by it virtual void BindView(ClientContext &context, BindViewAction action = BindViewAction::BIND_IF_UNBOUND); //! Update the view with a new set of types / names - virtual void UpdateBinding(const vector &types, const vector &names); + virtual void UpdateBinding(const vector &types, const vector &names); Value GetColumnComment(idx_t column_index); public: @@ -73,7 +73,7 @@ class ViewCatalogEntry : public StandardEntry { //! Current binding thread atomic bind_thread; //! The comments on the columns of the view: can be empty if there are no comments - unordered_map column_comments; + identifier_map_t column_comments; private: void Initialize(CreateViewInfo &info); diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry_retriever.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry_retriever.hpp index 6dda5cb6b..051722aec 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry_retriever.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry_retriever.hpp @@ -39,18 +39,19 @@ class CatalogEntryRetriever { return context; } - optional_ptr GetEntry(const string &catalog, const string &schema, const EntryLookupInfo &lookup_info, + optional_ptr GetEntry(const Identifier &catalog, const Identifier &schema, + const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found = OnEntryNotFound::THROW_EXCEPTION); - optional_ptr GetEntry(Catalog &catalog, const string &schema, const EntryLookupInfo &lookup_info, + optional_ptr GetEntry(Catalog &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found = OnEntryNotFound::THROW_EXCEPTION); - LogicalType GetType(const string &catalog, const string &schema, const string &name, + LogicalType GetType(const Identifier &catalog, const Identifier &schema, const Identifier &name, OnEntryNotFound on_entry_not_found = OnEntryNotFound::RETURN_NULL); - LogicalType GetType(Catalog &catalog, const string &schema, 
const string &name, + LogicalType GetType(Catalog &catalog, const Identifier &schema, const Identifier &name, OnEntryNotFound on_entry_not_found = OnEntryNotFound::RETURN_NULL); - optional_ptr GetSchema(const string &catalog, const EntryLookupInfo &schema_lookup, + optional_ptr GetSchema(const Identifier &catalog, const EntryLookupInfo &schema_lookup, OnEntryNotFound on_entry_not_found = OnEntryNotFound::THROW_EXCEPTION); const CatalogSearchPath &GetSearchPath() const; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp index 5b2ab62c5..8cc281198 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_search_path.hpp @@ -19,10 +19,10 @@ namespace duckdb { class ClientContext; struct CatalogSearchEntry { - CatalogSearchEntry(string catalog, string schema); + CatalogSearchEntry(Identifier catalog, Identifier schema); - string catalog; - string schema; + Identifier catalog; + Identifier schema; public: string ToString() const; @@ -32,7 +32,7 @@ struct CatalogSearchEntry { private: static CatalogSearchEntry ParseInternal(const string &input, idx_t &pos); - static string WriteOptionallyQuoted(const string &input); + static string WriteOptionallyQuoted(const Identifier &input); }; enum class CatalogSetPathType { SET_SCHEMA, SET_SCHEMAS, SET_DIRECTLY }; @@ -47,6 +47,7 @@ class CatalogSearchPath { DUCKDB_API void Set(CatalogSearchEntry new_value, CatalogSetPathType set_type); DUCKDB_API void Set(vector new_paths, CatalogSetPathType set_type); DUCKDB_API void Reset(); + DUCKDB_API void RefreshSetPaths(); DUCKDB_API vector Get() const; const vector &GetSetPaths() const { @@ -54,15 +55,15 @@ class CatalogSearchPath { } DUCKDB_API const CatalogSearchEntry &GetDefault() const; //! 
FIXME: this method is deprecated - DUCKDB_API string GetDefaultSchema(const string &catalog) const; - DUCKDB_API string GetDefaultSchema(ClientContext &context, const string &catalog) const; - DUCKDB_API string GetDefaultCatalog(const string &schema) const; + DUCKDB_API Identifier GetDefaultSchema(const Identifier &catalog) const; + DUCKDB_API Identifier GetDefaultSchema(ClientContext &context, const Identifier &catalog) const; + DUCKDB_API Identifier GetDefaultCatalog(const Identifier &schema) const; - DUCKDB_API vector GetSchemasForCatalog(const string &catalog) const; - DUCKDB_API vector GetCatalogsForSchema(const string &schema) const; + DUCKDB_API vector GetSchemasForCatalog(const Identifier &catalog) const; + DUCKDB_API vector GetCatalogsForSchema(const Identifier &schema) const; - DUCKDB_API bool SchemaInSearchPath(ClientContext &context, const string &catalog_name, - const string &schema_name) const; + DUCKDB_API bool SchemaInSearchPath(ClientContext &context, const Identifier &catalog_name, + const Identifier &schema_name) const; private: //! Set paths without checking if they exist diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp index 8cce35abf..ca941ac9e 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_set.hpp @@ -42,12 +42,12 @@ class CatalogEntryMap { void AddEntry(unique_ptr entry); void UpdateEntry(unique_ptr entry); void DropEntry(CatalogEntry &entry); - case_insensitive_tree_t> &Entries(); - optional_ptr GetEntry(const string &name); + identifier_tree_t> &Entries(); + optional_ptr GetEntry(const Identifier &name); private: - //! Mapping of string to catalog entry - case_insensitive_tree_t> entries; + //! Mapping of identifier to catalog entry + identifier_tree_t> entries; }; //! The Catalog Set stores (key, value) map of a set of CatalogEntries @@ -65,16 +65,16 @@ class CatalogSet { //! 
Create an entry in the catalog set. Returns whether or not it was //! successful. - DUCKDB_API bool CreateEntry(CatalogTransaction transaction, const string &name, unique_ptr value, + DUCKDB_API bool CreateEntry(CatalogTransaction transaction, const Identifier &name, unique_ptr value, const LogicalDependencyList &dependencies); - DUCKDB_API bool CreateEntry(ClientContext &context, const string &name, unique_ptr value, + DUCKDB_API bool CreateEntry(ClientContext &context, const Identifier &name, unique_ptr value, const LogicalDependencyList &dependencies); - DUCKDB_API bool AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info); + DUCKDB_API bool AlterEntry(CatalogTransaction transaction, const Identifier &name, AlterInfo &alter_info); - DUCKDB_API bool DropEntry(CatalogTransaction transaction, const string &name, bool cascade, + DUCKDB_API bool DropEntry(CatalogTransaction transaction, const Identifier &name, bool cascade, bool allow_drop_internal = false); - DUCKDB_API bool DropEntry(ClientContext &context, const string &name, bool cascade, + DUCKDB_API bool DropEntry(ClientContext &context, const Identifier &name, bool cascade, bool allow_drop_internal = false); //! Verify that the entry referenced by the dependency is still alive DUCKDB_API void VerifyExistenceOfDependency(transaction_t commit_id, CatalogEntry &entry); @@ -88,13 +88,13 @@ class CatalogSet { void CleanupEntry(CatalogEntry &catalog_entry); //! 
Returns the entry with the specified name - DUCKDB_API EntryLookup GetEntryDetailed(CatalogTransaction transaction, const string &name); - DUCKDB_API optional_ptr GetEntry(CatalogTransaction transaction, const string &name); - DUCKDB_API optional_ptr GetEntry(ClientContext &context, const string &name); + DUCKDB_API EntryLookup GetEntryDetailed(CatalogTransaction transaction, const Identifier &name); + DUCKDB_API optional_ptr GetEntry(CatalogTransaction transaction, const Identifier &name); + DUCKDB_API optional_ptr GetEntry(ClientContext &context, const Identifier &name); //! Gets the entry that is most similar to the given name (i.e. smallest levenshtein distance), or empty string if //! none is found. The returned pair consists of the entry name and the distance (smaller means closer). - SimilarCatalogEntry SimilarEntry(CatalogTransaction transaction, const string &name); + SimilarCatalogEntry SimilarEntry(CatalogTransaction transaction, const Identifier &name); //! Rollback to be the currently valid entry for a certain catalog //! entry @@ -104,7 +104,7 @@ class CatalogSet { DUCKDB_API void Scan(const std::function &callback); //! Scan the catalog set, invoking the callback method for every entry DUCKDB_API void ScanWithPrefix(CatalogTransaction transaction, const std::function &callback, - const string &prefix); + const Identifier &prefix); DUCKDB_API void Scan(CatalogTransaction transaction, const std::function &callback); DUCKDB_API void ScanWithReturn(CatalogTransaction transaction, const std::function &callback); DUCKDB_API void Scan(ClientContext &context, const std::function &callback); @@ -135,7 +135,7 @@ class CatalogSet { void SetDefaultGenerator(unique_ptr defaults); private: - bool DropDependencies(CatalogTransaction transaction, const string &name, bool cascade, + bool DropDependencies(CatalogTransaction transaction, const Identifier &name, bool cascade, bool allow_drop_internal = false); //! 
Given a root entry, gets the entry valid for this transaction, 'visible' is used to indicate whether the entry //! is actually visible to the transaction @@ -143,25 +143,25 @@ class CatalogSet { //! Given a root entry, gets the entry valid for this transaction CatalogEntry &GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry ¤t); CatalogEntry &GetCommittedEntry(CatalogEntry ¤t); - optional_ptr GetEntryInternal(CatalogTransaction transaction, const string &name); + optional_ptr GetEntryInternal(CatalogTransaction transaction, const Identifier &name); optional_ptr CreateCommittedEntry(unique_ptr entry); //! Create all default entries void CreateDefaultEntries(CatalogTransaction transaction, unique_lock &lock); //! Attempt to create a default entry with the specified name. Returns the entry if successful, nullptr otherwise. - optional_ptr CreateDefaultEntry(CatalogTransaction transaction, const string &name, + optional_ptr CreateDefaultEntry(CatalogTransaction transaction, const Identifier &name, unique_lock &lock); - bool DropEntryInternal(CatalogTransaction transaction, const string &name, bool allow_drop_internal = false); + bool DropEntryInternal(CatalogTransaction transaction, const Identifier &name, bool allow_drop_internal = false); - bool CreateEntryInternal(CatalogTransaction transaction, const string &name, unique_ptr value, + bool CreateEntryInternal(CatalogTransaction transaction, const Identifier &name, unique_ptr value, unique_lock &read_lock, bool should_be_empty = true); - void CheckCatalogEntryInvariants(CatalogEntry &value, const string &name); + void CheckCatalogEntryInvariants(CatalogEntry &value, const Identifier &name); //! Verify that the previous entry in the chain is dropped. bool VerifyVacancy(CatalogTransaction transaction, CatalogEntry &entry); //! 
Start the catalog entry chain with a dummy node - bool StartChain(CatalogTransaction transaction, const string &name, unique_lock &read_lock); - bool RenameEntryInternal(CatalogTransaction transaction, CatalogEntry &old, const string &new_name, + bool StartChain(CatalogTransaction transaction, const Identifier &name, unique_lock &read_lock); + bool RenameEntryInternal(CatalogTransaction transaction, CatalogEntry &old, const Identifier &new_name, AlterInfo &alter_info, unique_lock &read_lock); private: diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_coordinate_systems.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_coordinate_systems.hpp index f4125736b..7d210eb39 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_coordinate_systems.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_coordinate_systems.hpp @@ -23,8 +23,8 @@ class DefaultCoordinateSystemGenerator : public DefaultGenerator { SchemaCatalogEntry &schema; public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override; + vector GetDefaultEntries() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp index c9a5b8997..6aa04a380 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_functions.hpp @@ -29,8 +29,8 @@ class DefaultFunctionGenerator : public DefaultGenerator { DUCKDB_API static unique_ptr CreateInternalMacroInfo(const DefaultMacro &default_macro); public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier 
&entry_name) override; + vector GetDefaultEntries() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp index 88161f7c4..5424a8c90 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_generator.hpp @@ -24,10 +24,10 @@ class DefaultGenerator { public: //! Creates a default entry with the specified name, or returns nullptr if no such entry can be generated - virtual unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name); - virtual unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const string &entry_name); + virtual unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name); + virtual unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const Identifier &entry_name); //! Get a list of all default entries in the generator - virtual vector GetDefaultEntries() = 0; + virtual vector GetDefaultEntries() = 0; //! Whether or not we should keep the lock while calling CreateDefaultEntry //! If this is set to false, CreateDefaultEntry might be called multiple times in parallel also for the same entry //! 
Otherwise it will be called exactly once per entry diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp index 5b48165f7..1b1913577 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_schemas.hpp @@ -17,9 +17,9 @@ class DefaultSchemaGenerator : public DefaultGenerator { explicit DefaultSchemaGenerator(Catalog &catalog); public: - unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const string &entry_name) override; - vector GetDefaultEntries() override; - static bool IsDefaultSchema(const string &input_schema); + unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const Identifier &entry_name) override; + vector GetDefaultEntries() override; + static bool IsDefaultSchema(const Identifier &input_schema); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_table_functions.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_table_functions.hpp index c0eee2855..2901a4684 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_table_functions.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_table_functions.hpp @@ -34,8 +34,8 @@ class DefaultTableFunctionGenerator : public DefaultGenerator { SchemaCatalogEntry &schema; public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override; + vector GetDefaultEntries() override; static unique_ptr CreateTableMacroInfo(const DefaultTableMacro &default_macro); diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp index 5d58a8299..e37cbd51f 100644 --- 
a/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_types.hpp @@ -21,11 +21,11 @@ class DefaultTypeGenerator : public DefaultGenerator { SchemaCatalogEntry &schema; public: - DUCKDB_API static LogicalTypeId GetDefaultType(const string &name); + DUCKDB_API static LogicalTypeId GetDefaultType(const Identifier &name); DUCKDB_API static LogicalType TryDefaultBind(const string &name, const vector> ¶ms); - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override; + vector GetDefaultEntries() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp b/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp index 6ebd29eb1..2920448fe 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/default_views.hpp @@ -21,8 +21,8 @@ class DefaultViewGenerator : public DefaultGenerator { SchemaCatalogEntry &schema; public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override; + vector GetDefaultEntries() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/catalog/dependency.hpp b/src/duckdb/src/include/duckdb/catalog/dependency.hpp index 4ac3656e7..b24fe6c28 100644 --- a/src/duckdb/src/include/duckdb/catalog/dependency.hpp +++ b/src/duckdb/src/include/duckdb/catalog/dependency.hpp @@ -8,7 +8,9 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/common.hpp" +#include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/string_util.hpp" @@ 
-137,18 +139,18 @@ struct DependencyDependentFlags : public DependencyFlags { struct CatalogEntryInfo { public: CatalogType type; - string schema; - string name; + Identifier schema; + Identifier name; public: bool operator==(const CatalogEntryInfo &other) const { if (other.type != type) { return false; } - if (!StringUtil::CIEquals(other.schema, schema)) { + if (other.schema != schema) { return false; } - if (!StringUtil::CIEquals(other.name, name)) { + if (other.name != name) { return false; } return true; diff --git a/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp b/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp index f6169f306..3c7202a2d 100644 --- a/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp +++ b/src/duckdb/src/include/duckdb/catalog/dependency_list.hpp @@ -26,12 +26,12 @@ class LogicalDependencyList; struct LogicalDependency { public: CatalogEntryInfo entry; - string catalog; + Identifier catalog; public: explicit LogicalDependency(CatalogEntry &entry); LogicalDependency(); - LogicalDependency(optional_ptr catalog, CatalogEntryInfo entry, string catalog_str); + LogicalDependency(optional_ptr catalog, CatalogEntryInfo entry, Identifier catalog_str); bool operator==(const LogicalDependency &other) const; public: @@ -58,7 +58,7 @@ class LogicalDependencyList { DUCKDB_API bool Contains(CatalogEntry &entry); public: - DUCKDB_API void VerifyDependencies(Catalog &catalog, const string &name); + DUCKDB_API void VerifyDependencies(Catalog &catalog, const Identifier &name); void Serialize(Serializer &serializer) const; static LogicalDependencyList Deserialize(Deserializer &deserializer); bool operator==(const LogicalDependencyList &other) const; diff --git a/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp index 804d2e7cf..560d369dd 100644 --- a/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp +++ 
b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp @@ -57,11 +57,11 @@ struct MangledEntryName { public: //! Format: Type\0Schema\0Name - string name; + Identifier name; public: bool operator==(const MangledEntryName &other) const { - return StringUtil::CIEquals(other.name, name); + return other.name == name; } bool operator!=(const MangledEntryName &other) const { return !(*this == other); @@ -75,7 +75,7 @@ struct MangledDependencyName { public: //! Format: MangledEntryName\0MangledEntryName - string name; + Identifier name; }; //! The DependencyManager is in charge of managing dependencies between catalog entries @@ -108,7 +108,7 @@ class DependencyManager { void CleanupDependencies(CatalogTransaction transaction, CatalogEntry &entry); public: - static string GetSchema(const CatalogEntry &entry); + static Identifier GetSchema(const CatalogEntry &entry); static MangledEntryName MangleName(const CatalogEntryInfo &info); static MangledEntryName MangleName(const CatalogEntry &entry); static CatalogEntryInfo GetLookupProperties(const CatalogEntry &entry); diff --git a/src/duckdb/src/include/duckdb/catalog/entry_lookup_info.hpp b/src/duckdb/src/include/duckdb/catalog/entry_lookup_info.hpp index 893307190..90c2d6e6b 100644 --- a/src/duckdb/src/include/duckdb/catalog/entry_lookup_info.hpp +++ b/src/duckdb/src/include/duckdb/catalog/entry_lookup_info.hpp @@ -17,24 +17,26 @@ class BoundAtClause; struct EntryLookupInfo { public: - EntryLookupInfo(CatalogType catalog_type, const string &name, - QueryErrorContext error_context = QueryErrorContext()); - EntryLookupInfo(CatalogType catalog_type, const string &name, optional_ptr at_clause, + EntryLookupInfo(CatalogType catalog_type, Identifier name, QueryErrorContext error_context = QueryErrorContext()); + EntryLookupInfo(CatalogType catalog_type, Identifier name, optional_ptr at_clause, QueryErrorContext error_context); - EntryLookupInfo(const EntryLookupInfo &parent, const string &name); + EntryLookupInfo(const 
EntryLookupInfo &parent, Identifier name); EntryLookupInfo(const EntryLookupInfo &parent, optional_ptr at_clause); public: CatalogType GetCatalogType() const; + //! The identifier being looked up (the catalog stores/compares names case-insensitively). + const Identifier &GetEntryIdentifier() const; + //! The raw name of the identifier being looked up. const string &GetEntryName() const; const QueryErrorContext &GetErrorContext() const; const optional_ptr GetAtClause() const; - static EntryLookupInfo SchemaLookup(const EntryLookupInfo &parent, const string &schema_name); + static EntryLookupInfo SchemaLookup(const EntryLookupInfo &parent, Identifier schema_name); private: CatalogType catalog_type; - const string &name; + Identifier name; optional_ptr at_clause; QueryErrorContext error_context; }; diff --git a/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp index 2031e7980..a01fb8c00 100644 --- a/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/similar_catalog_entry.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class SchemaCatalogEntry; @@ -17,7 +18,7 @@ class SchemaCatalogEntry; //! Return value of SimilarEntryInSchemas struct SimilarCatalogEntry { //! The entry name. Empty if absent - string name; + Identifier name; //! The similarity score of the given name (between 0.0 and 1.0, higher is better) double score = 0.0; //! The schema of the entry. diff --git a/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp b/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp index a3fa83d88..022b59422 100644 --- a/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/standard_entry.hpp @@ -17,7 +17,7 @@ class SchemaCatalogEntry; //! 
A StandardEntry is a catalog entry that is a member of a schema class StandardEntry : public InCatalogEntry { public: - StandardEntry(CatalogType type, SchemaCatalogEntry &schema, Catalog &catalog, string name) + StandardEntry(CatalogType type, SchemaCatalogEntry &schema, Catalog &catalog, Identifier name) : InCatalogEntry(type, catalog, std::move(name)), schema(schema) { } ~StandardEntry() override { diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp index 235a02e2a..3e92a0b26 100644 --- a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp +++ b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp @@ -214,6 +214,9 @@ AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, struct AdbcError *error); +AdbcStatusCode StatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, + struct AdbcError *error); + AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, struct AdbcError *error); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp index 4a1e594d0..c0b60abf9 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp @@ -29,7 +29,8 @@ typedef void (*initialize_t)(ArrowAppendData &result, const LogicalType &type, i // from: The offset into the input we're scanning // to: The last index of the input we're scanning // input_size: The total size of the 'input' Vector. 
-typedef void (*append_vector_t)(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); +typedef void (*append_vector_t)(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, + idx_t input_size); typedef void (*finalize_t)(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); // This struct is used to save state for appending a column diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp index 3b8a4f142..dbb3f5bc6 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/bool_data.hpp @@ -16,7 +16,7 @@ namespace duckdb { struct ArrowBoolData { public: static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size); static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); }; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/fixed_size_list_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/fixed_size_list_data.hpp index 775fb2e17..11cc1acc6 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/fixed_size_list_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/fixed_size_list_data.hpp @@ -15,7 +15,7 @@ namespace duckdb { struct ArrowFixedSizeListData { public: static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size); static void 
Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); }; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp index 849da9a65..612937c70 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp @@ -24,7 +24,7 @@ struct ArrowListData { result.child_data.push_back(std::move(child_buffer)); } - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); idx_t size = to - from; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp index 79dfb11e2..cdd5d7d48 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp @@ -26,7 +26,7 @@ struct ArrowListViewData { result.child_data.push_back(std::move(child_buffer)); } - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); idx_t size = to - from; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp index e7569f91c..560e4c3e4 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp @@ -36,7 +36,7 @@ struct ArrowMapData { result.child_data.push_back(std::move(internal_struct)); } - static void 
Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { UnifiedVectorFormat format; input.ToUnifiedFormat(format); idx_t size = to - from; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/null_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/null_data.hpp index d39b5d8f3..17c2dc32a 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/null_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/null_data.hpp @@ -16,7 +16,7 @@ namespace duckdb { struct ArrowNullData { public: static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size); static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); }; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp index 091fb52a1..e4b63e6d2 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp @@ -89,7 +89,7 @@ struct ArrowUUIDBlobConverter { template struct ArrowScalarBaseData { - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { D_ASSERT(to >= from); idx_t size = to - from; D_ASSERT(size <= input_size); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp index 299a05bab..e4db6ae03 100644 
--- a/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/struct_data.hpp @@ -19,7 +19,7 @@ namespace duckdb { struct ArrowStructData { public: static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size); static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); }; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp index df8fc848f..bb441c389 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/union_data.hpp @@ -22,7 +22,7 @@ namespace duckdb { struct ArrowUnionData { public: static void Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity); - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size); + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size); static void Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result); }; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp index 2d73d20a8..24675b06c 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp @@ -50,7 +50,8 @@ struct ArrowVarcharData { } template - static void AppendTemplated(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void AppendTemplated(ArrowAppendData 
&append_data, const Vector &input, idx_t from, idx_t to, + idx_t input_size) { idx_t size = to - from; UnifiedVectorFormat format; input.ToUnifiedFormat(format); @@ -110,7 +111,7 @@ struct ArrowVarcharData { append_data.row_count += size; } - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { if (append_data.options.arrow_offset_size == ArrowOffsetSize::REGULAR) { // Check if the offset exceeds the max supported value AppendTemplated(append_data, input, from, to, input_size); @@ -133,7 +134,7 @@ struct ArrowVarcharToStringViewData { result.GetBufferSizeBuffer().reserve(sizeof(int64_t)); } - static void Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { + static void Append(ArrowAppendData &append_data, const Vector &input, idx_t from, idx_t to, idx_t input_size) { idx_t size = to - from; UnifiedVectorFormat format; input.ToUnifiedFormat(format); diff --git a/src/duckdb/src/include/duckdb/common/bind_helpers.hpp b/src/duckdb/src/include/duckdb/common/bind_helpers.hpp index 992047628..172918e6f 100644 --- a/src/duckdb/src/include/duckdb/common/bind_helpers.hpp +++ b/src/duckdb/src/include/duckdb/common/bind_helpers.hpp @@ -13,11 +13,16 @@ namespace duckdb { class Value; +struct BoundOrderByNode; +struct BoundStatement; +class Binder; Value ConvertVectorToValue(vector set); vector ParseColumnList(const vector &set, vector &names, const string &option_name); vector ParseColumnList(const Value &value, vector &names, const string &option_name); -vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption); -vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption); +vector ParseColumnsOrdered(const vector &set, const vector &names, const string &loption); +vector ParseColumnsOrdered(const Value &value, const 
vector &names, const string &loption); +vector ParseOrderByColumns(Binder &binder, const vector &set, + const BoundStatement &bound_statement, const string &loption); } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/box_renderer_context.hpp b/src/duckdb/src/include/duckdb/common/box_renderer_context.hpp index b22eb297a..f716304b0 100644 --- a/src/duckdb/src/include/duckdb/common/box_renderer_context.hpp +++ b/src/duckdb/src/include/duckdb/common/box_renderer_context.hpp @@ -29,7 +29,7 @@ class BoxRendererContext { virtual Allocator &GetAllocator() = 0; //! Cast a single source vector to VARCHAR. Delegates to the chunk-level virtual. - void CastToVarchar(Vector &source, Vector &result, idx_t count); + void CastToVarchar(const Vector &source, Vector &result, idx_t count); protected: //! Cast source chunk columns to VARCHAR into result chunk. Derived classes implement this. diff --git a/src/duckdb/src/include/duckdb/common/checked_integer.hpp b/src/duckdb/src/include/duckdb/common/checked_integer.hpp index 5d209a599..e60674f74 100644 --- a/src/duckdb/src/include/duckdb/common/checked_integer.hpp +++ b/src/duckdb/src/include/duckdb/common/checked_integer.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/multiply.hpp" #include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/windows_undefs.hpp" #include #include diff --git a/src/duckdb/src/include/duckdb/common/clustered_aggregate.hpp b/src/duckdb/src/include/duckdb/common/clustered_aggregate.hpp new file mode 100644 index 000000000..bf0d9d206 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/clustered_aggregate.hpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/clustered_aggregate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/optional_ptr.hpp" +#include 
"duckdb/common/string.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/vector_size.hpp" + +namespace duckdb { + +class Vector; +struct ClusteredAggrState; + +using DictProps = unsafe_unique_array; + +static constexpr uint64_t SUM_OVERFLOW_MASK = ~((uint64_t(1) << 53) - 1); +static inline bool I64VectorSumSafe(int64_t v) { + return ((static_cast(v) ^ static_cast(v >> 63)) & SUM_OVERFLOW_MASK) == 0; +} + +//! Per-chunk tuple clustering by group. Built once and passed to aggregate kernels. +//! Clustered-aware kernels can use per-run accumulation; everyone else keeps the +//! regular scatter path over input-order addresses. +struct ClusteredAggr { + static constexpr idx_t RUNLENGTH_THRESHOLD = 6; + static constexpr idx_t MAX_RUNS = STANDARD_VECTOR_SIZE; + static constexpr idx_t HOTKEYS_LOG2 = 5; + static constexpr idx_t MAX_HOTKEYS = idx_t(1) << HOTKEYS_LOG2; + static constexpr idx_t SAMPLE_SIZE = 128; + static constexpr idx_t HASHTAB_LOG2 = 11; + static constexpr idx_t HASHTAB_SZ = idx_t(1) << HASHTAB_LOG2; + static constexpr idx_t SLOT_GRP_BITS = 13; + static constexpr idx_t SLOT_CURSOR_BITS = (HOTKEYS_LOG2 + 1) + SLOT_GRP_BITS; + static constexpr idx_t SLOT_GID_BITS = 64 - (SLOT_GRP_BITS + SLOT_CURSOR_BITS); + static constexpr idx_t MAX_GID_COUNT = idx_t(1) << SLOT_GID_BITS; + static constexpr uint64_t FREE_SLOT = ~uint64_t(0); + + struct GroupRun { + data_ptr_t state; //! caller fills this after TryClustered; advanced between aggregates + const sel_t *sel; //! points to the tuple positions for this run + uint64_t gid; //! raw group id for this run + idx_t count; //! number of tuples in this group + }; + + idx_t n_group_runs = 0; + GroupRun group_runs[MAX_RUNS]; + + const ClusteredAggrState *state = nullptr; + + //! Build a clustered permutation of 0..count-1 from raw integer group ids. + //! On success fills group_runs[].sel/gid/count. 
+ bool TryClustered(const uint64_t *group_ids, sel_t count, sel_t *arena, uint64_t *slots); + + //! Initialize a single run covering 0..count-1 for one aggregate state. + void SetSingleRun(data_ptr_t state, idx_t count); + + //! Advance all run state pointers by payload_size. + void AdvanceStates(idx_t payload_size); + + template + void InitializeStates(GET_STATE &&get_state) { + for (idx_t r = 0; r < n_group_runs; r++) { + group_runs[r].state = get_state(group_runs[r].gid); + } + } + + //! Returns a composed dict sel for simple dictionary input, or nullptr. + const sel_t *ClusterIter(const Vector &input, idx_t count) const; + +private: + mutable sel_t composed_sel_data[STANDARD_VECTOR_SIZE]; + mutable const sel_t *cached_dict_sel = nullptr; +}; + +//! Scratch state shared by GroupedAggregateHashTable and PerfectAggregateHashTable. +struct ClusteredAggrState { + unsafe_unique_array arena; + unsafe_unique_array slots; + bool all_clustered = false; + idx_t n_clustered = 0; + idx_t skipped_opportunities = 0; + idx_t retry_backoff = 1; + + mutable unordered_map dict_props; + + void Initialize(); + bool TryBuild(ClusteredAggr &clustered, const uint64_t *group_ids, idx_t count); + optional_ptr GetDictProps(const Vector &input) const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index d04815c64..33021ac26 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -52,6 +52,8 @@ enum class AggregateHandling : uint8_t; enum class AggregateOrderDependent : uint8_t; +enum class AggregateStateExportMode : uint8_t; + enum class AggregateType : uint8_t; enum class AllowParserOverride : uint8_t; @@ -136,6 +138,8 @@ enum class ColumnSegmentType : uint8_t; enum class CompressedMaterializationDirection : uint8_t; +enum class CompressedMaterializationType : uint8_t; + enum class CompressionType : uint8_t; enum class 
CompressionValidity : uint8_t; @@ -180,6 +184,8 @@ enum class DeprecatedUsingKeySyntax : uint8_t; enum class DestroyBufferUpon : uint8_t; +enum class DistinctCountSource : uint8_t; + enum class DistinctType : uint8_t; enum class ErrorType : uint16_t; @@ -222,6 +228,8 @@ enum class FileExpandResult : uint8_t; enum class FileGlobOptions : uint8_t; +enum class FileIOMode : uint8_t; + enum class FileLockType : uint8_t; enum class FileNameSegmentType : uint8_t; @@ -258,6 +266,8 @@ enum class InsertColumnOrder : uint8_t; enum class InterruptMode : uint8_t; +enum class JoinFilterPushdownMode : uint8_t; + enum class JoinRefType : uint8_t; enum class JoinType : uint8_t; @@ -272,6 +282,8 @@ enum class LambdaSyntaxType : uint8_t; enum class LimitNodeType : uint8_t; +enum class LimitValueType : uint8_t; + enum class LoadType : uint8_t; enum class LogContextScope : uint8_t; @@ -300,10 +312,6 @@ enum class MergeActionType : uint8_t; enum class MetaPipelineType : uint8_t; -enum class MetricGroup : uint8_t; - -enum class MetricType : uint8_t; - enum class Monotonicity : uint8_t; enum class MultiFileColumnMappingMode : uint8_t; @@ -392,8 +400,14 @@ enum class QueryResultType : uint8_t; enum class RecoveryMode : uint8_t; +enum class RecursiveCTEInlineStageType : uint8_t; + +enum class RecursiveProbeSidePreference : uint8_t; + enum class RelationType : uint8_t; +enum class RemoteCapability : uint8_t; + enum class RenderMode : uint8_t; enum class RequestType : uint8_t; @@ -416,10 +430,14 @@ enum class SecretPersistType : uint8_t; enum class SecretSerializationType : uint8_t; +enum class SegmentTreeVerifyMode : uint8_t; + enum class SelectivityOptionalFilterType : uint8_t; enum class SequenceInfo : uint8_t; +enum class SerializationVersionDeprecated : uint64_t; + enum class SetOperationType : uint8_t; enum class SetScope : uint8_t; @@ -458,6 +476,8 @@ enum class StorageBlockPrefetch : uint8_t; enum class StorageIndexType : uint8_t; +enum class StorageVersion : uint64_t; + enum class 
StrTimeSpecifier : uint8_t; enum class StreamExecutionResult : uint8_t; @@ -480,6 +500,8 @@ enum class TaskExecutionMode : uint8_t; enum class TaskExecutionResult : uint8_t; +enum class TaskSchedulerType : uint8_t; + enum class TemporaryBufferSize : uint64_t; enum class TemporaryCompressionLevel : int; @@ -567,6 +589,9 @@ const char* EnumUtil::ToChars(AggregateHandling value); template<> const char* EnumUtil::ToChars(AggregateOrderDependent value); +template<> +const char* EnumUtil::ToChars(AggregateStateExportMode value); + template<> const char* EnumUtil::ToChars(AggregateType value); @@ -693,6 +718,9 @@ const char* EnumUtil::ToChars(ColumnSegmentType value); template<> const char* EnumUtil::ToChars(CompressedMaterializationDirection value); +template<> +const char* EnumUtil::ToChars(CompressedMaterializationType value); + template<> const char* EnumUtil::ToChars(CompressionType value); @@ -759,6 +787,9 @@ const char* EnumUtil::ToChars(DeprecatedUsingKeySyntax template<> const char* EnumUtil::ToChars(DestroyBufferUpon value); +template<> +const char* EnumUtil::ToChars(DistinctCountSource value); + template<> const char* EnumUtil::ToChars(DistinctType value); @@ -822,6 +853,9 @@ const char* EnumUtil::ToChars(FileExpandResult value); template<> const char* EnumUtil::ToChars(FileGlobOptions value); +template<> +const char* EnumUtil::ToChars(FileIOMode value); + template<> const char* EnumUtil::ToChars(FileLockType value); @@ -876,6 +910,9 @@ const char* EnumUtil::ToChars(InsertColumnOrder value); template<> const char* EnumUtil::ToChars(InterruptMode value); +template<> +const char* EnumUtil::ToChars(JoinFilterPushdownMode value); + template<> const char* EnumUtil::ToChars(JoinRefType value); @@ -897,6 +934,9 @@ const char* EnumUtil::ToChars(LambdaSyntaxType value); template<> const char* EnumUtil::ToChars(LimitNodeType value); +template<> +const char* EnumUtil::ToChars(LimitValueType value); + template<> const char* EnumUtil::ToChars(LoadType value); @@ -939,12 
+979,6 @@ const char* EnumUtil::ToChars(MergeActionType value); template<> const char* EnumUtil::ToChars(MetaPipelineType value); -template<> -const char* EnumUtil::ToChars(MetricGroup value); - -template<> -const char* EnumUtil::ToChars(MetricType value); - template<> const char* EnumUtil::ToChars(Monotonicity value); @@ -1077,9 +1111,18 @@ const char* EnumUtil::ToChars(QueryResultType value); template<> const char* EnumUtil::ToChars(RecoveryMode value); +template<> +const char* EnumUtil::ToChars(RecursiveCTEInlineStageType value); + +template<> +const char* EnumUtil::ToChars(RecursiveProbeSidePreference value); + template<> const char* EnumUtil::ToChars(RelationType value); +template<> +const char* EnumUtil::ToChars(RemoteCapability value); + template<> const char* EnumUtil::ToChars(RenderMode value); @@ -1113,12 +1156,18 @@ const char* EnumUtil::ToChars(SecretPersistType value); template<> const char* EnumUtil::ToChars(SecretSerializationType value); +template<> +const char* EnumUtil::ToChars(SegmentTreeVerifyMode value); + template<> const char* EnumUtil::ToChars(SelectivityOptionalFilterType value); template<> const char* EnumUtil::ToChars(SequenceInfo value); +template<> +const char* EnumUtil::ToChars(SerializationVersionDeprecated value); + template<> const char* EnumUtil::ToChars(SetOperationType value); @@ -1176,6 +1225,9 @@ const char* EnumUtil::ToChars(StorageBlockPrefetch value); template<> const char* EnumUtil::ToChars(StorageIndexType value); +template<> +const char* EnumUtil::ToChars(StorageVersion value); + template<> const char* EnumUtil::ToChars(StrTimeSpecifier value); @@ -1209,6 +1261,9 @@ const char* EnumUtil::ToChars(TaskExecutionMode value); template<> const char* EnumUtil::ToChars(TaskExecutionResult value); +template<> +const char* EnumUtil::ToChars(TaskSchedulerType value); + template<> const char* EnumUtil::ToChars(TemporaryBufferSize value); @@ -1324,6 +1379,9 @@ AggregateHandling EnumUtil::FromString(const char *value); template<> 
AggregateOrderDependent EnumUtil::FromString(const char *value); +template<> +AggregateStateExportMode EnumUtil::FromString(const char *value); + template<> AggregateType EnumUtil::FromString(const char *value); @@ -1450,6 +1508,9 @@ ColumnSegmentType EnumUtil::FromString(const char *value); template<> CompressedMaterializationDirection EnumUtil::FromString(const char *value); +template<> +CompressedMaterializationType EnumUtil::FromString(const char *value); + template<> CompressionType EnumUtil::FromString(const char *value); @@ -1516,6 +1577,9 @@ DeprecatedUsingKeySyntax EnumUtil::FromString(const ch template<> DestroyBufferUpon EnumUtil::FromString(const char *value); +template<> +DistinctCountSource EnumUtil::FromString(const char *value); + template<> DistinctType EnumUtil::FromString(const char *value); @@ -1579,6 +1643,9 @@ FileExpandResult EnumUtil::FromString(const char *value); template<> FileGlobOptions EnumUtil::FromString(const char *value); +template<> +FileIOMode EnumUtil::FromString(const char *value); + template<> FileLockType EnumUtil::FromString(const char *value); @@ -1633,6 +1700,9 @@ InsertColumnOrder EnumUtil::FromString(const char *value); template<> InterruptMode EnumUtil::FromString(const char *value); +template<> +JoinFilterPushdownMode EnumUtil::FromString(const char *value); + template<> JoinRefType EnumUtil::FromString(const char *value); @@ -1654,6 +1724,9 @@ LambdaSyntaxType EnumUtil::FromString(const char *value); template<> LimitNodeType EnumUtil::FromString(const char *value); +template<> +LimitValueType EnumUtil::FromString(const char *value); + template<> LoadType EnumUtil::FromString(const char *value); @@ -1696,12 +1769,6 @@ MergeActionType EnumUtil::FromString(const char *value); template<> MetaPipelineType EnumUtil::FromString(const char *value); -template<> -MetricGroup EnumUtil::FromString(const char *value); - -template<> -MetricType EnumUtil::FromString(const char *value); - template<> Monotonicity 
EnumUtil::FromString(const char *value); @@ -1834,9 +1901,18 @@ QueryResultType EnumUtil::FromString(const char *value); template<> RecoveryMode EnumUtil::FromString(const char *value); +template<> +RecursiveCTEInlineStageType EnumUtil::FromString(const char *value); + +template<> +RecursiveProbeSidePreference EnumUtil::FromString(const char *value); + template<> RelationType EnumUtil::FromString(const char *value); +template<> +RemoteCapability EnumUtil::FromString(const char *value); + template<> RenderMode EnumUtil::FromString(const char *value); @@ -1870,12 +1946,18 @@ SecretPersistType EnumUtil::FromString(const char *value); template<> SecretSerializationType EnumUtil::FromString(const char *value); +template<> +SegmentTreeVerifyMode EnumUtil::FromString(const char *value); + template<> SelectivityOptionalFilterType EnumUtil::FromString(const char *value); template<> SequenceInfo EnumUtil::FromString(const char *value); +template<> +SerializationVersionDeprecated EnumUtil::FromString(const char *value); + template<> SetOperationType EnumUtil::FromString(const char *value); @@ -1933,6 +2015,9 @@ StorageBlockPrefetch EnumUtil::FromString(const char *valu template<> StorageIndexType EnumUtil::FromString(const char *value); +template<> +StorageVersion EnumUtil::FromString(const char *value); + template<> StrTimeSpecifier EnumUtil::FromString(const char *value); @@ -1966,6 +2051,9 @@ TaskExecutionMode EnumUtil::FromString(const char *value); template<> TaskExecutionResult EnumUtil::FromString(const char *value); +template<> +TaskSchedulerType EnumUtil::FromString(const char *value); + template<> TemporaryBufferSize EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/column_segment_info_scan_type.hpp b/src/duckdb/src/include/duckdb/common/enums/column_segment_info_scan_type.hpp new file mode 100644 index 000000000..91353c865 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/column_segment_info_scan_type.hpp 
@@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/column_segment_info_scan_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +//! Options controlling how pragma_storage_info gathers column segment info. +//! Both axes are independent: +//! * include_segment_info: when true, populate the per-segment `segment_info` column +//! (this requires reading per-segment data — expensive for bitpacked segments). +//! * loaded_segments_only: when true, skip columns whose metadata has not been +//! lazy-loaded; used to introspect what is resident without triggering loads. +struct ColumnSegmentInfoScanOptions { + bool include_segment_info = false; + bool loaded_segments_only = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/debug_verification_mode.hpp b/src/duckdb/src/include/duckdb/common/enums/debug_verification_mode.hpp index 50eb54613..234692f56 100644 --- a/src/duckdb/src/include/duckdb/common/enums/debug_verification_mode.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/debug_verification_mode.hpp @@ -12,6 +12,6 @@ namespace duckdb { -enum class DebugVerificationMode : uint8_t { DEFAULT, NONE, VERIFY_VECTORS, VERIFY_SERIALIZATION }; +enum class DebugVerificationMode : uint8_t { DEFAULT, NONE, VERIFY_VECTORS, VERIFY_SERIALIZATION, VERIFY_FUNCTIONS }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp b/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp index bd6e58f22..092225758 100644 --- a/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/file_glob_options.hpp @@ -16,12 +16,18 @@ enum class FileGlobOptions : uint8_t { DISALLOW_EMPTY = 0, ALLOW_EMPTY = 1, FALL struct FileGlobInput { 
FileGlobInput(FileGlobOptions options) // NOLINT: allow implicit conversion from FileGlobOptions - : behavior(options) { + : behavior(options), allow_empty(options == FileGlobOptions::ALLOW_EMPTY) { } - FileGlobInput(FileGlobOptions options, string extension_p) : behavior(options), extension(std::move(extension_p)) { + FileGlobInput(FileGlobOptions options, string extension_p) + : behavior(options), allow_empty(options == FileGlobOptions::ALLOW_EMPTY), extension(std::move(extension_p)) { + } + + bool AllowsEmpty() const { + return allow_empty || behavior == FileGlobOptions::ALLOW_EMPTY; } FileGlobOptions behavior; + bool allow_empty; string extension; }; diff --git a/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp index 9b228d468..6dd62ce67 100644 --- a/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp @@ -103,6 +103,8 @@ enum class LogicalOperatorType : uint8_t { LOGICAL_LOAD = 180, LOGICAL_RESET = 181, LOGICAL_UPDATE_EXTENSIONS = 182, + LOGICAL_CONNECT = 183, + LOGICAL_DISCONNECT = 184, // ----------------------------- // Secrets diff --git a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp index bef89e7c0..61a413b5c 100644 --- a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp @@ -12,156 +12,26 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/common/unordered_set.hpp" -#include "duckdb/common/constants.hpp" -#include "duckdb/common/enums/optimizer_type.hpp" namespace duckdb { -enum class MetricGroup : uint8_t { - ALL, - CORE, - DEFAULT, - EXECUTION, - FILE, - OPERATOR, - OPTIMIZER, - PHASE_TIMING, - INVALID, -}; - -enum class MetricType : uint8_t { - // Core metrics - CPU_TIME = 2, - CUMULATIVE_CARDINALITY = 4, - 
CUMULATIVE_ROWS_SCANNED = 7, - EXTRA_INFO = 3, - LATENCY = 11, - QUERY_NAME = 0, - RESULT_SET_SIZE = 10, - ROWS_RETURNED = 12, - // Execution metrics - BLOCKED_THREAD_TIME = 1, - SYSTEM_PEAK_BUFFER_MEMORY = 14, - SYSTEM_PEAK_TEMP_DIR_SIZE = 15, - TOTAL_MEMORY_ALLOCATED = 91, - // File metrics - ATTACH_LOAD_STORAGE_LATENCY = 92, - ATTACH_REPLAY_WAL_LATENCY = 93, - CHECKPOINT_LATENCY = 94, - COMMIT_LOCAL_STORAGE_LATENCY = 95, - TOTAL_BYTES_READ = 16, - TOTAL_BYTES_WRITTEN = 17, - WAITING_TO_ATTACH_LATENCY = 96, - WAL_REPLAY_ENTRY_COUNT = 97, - WRITE_TO_WAL_LATENCY = 98, - // Operator metrics - OPERATOR_CARDINALITY = 6, - OPERATOR_NAME = 13, - OPERATOR_ROWS_SCANNED = 8, - OPERATOR_TIMING = 9, - OPERATOR_TYPE = 5, - // Optimizer metrics - OPTIMIZER_EXPRESSION_REWRITER = 26, - OPTIMIZER_FILTER_PULLUP = 27, - OPTIMIZER_FILTER_PUSHDOWN = 28, - OPTIMIZER_EMPTY_RESULT_PULLUP = 29, - OPTIMIZER_CTE_FILTER_PUSHER = 30, - OPTIMIZER_REGEX_RANGE = 31, - OPTIMIZER_IN_CLAUSE = 32, - OPTIMIZER_JOIN_ORDER = 33, - OPTIMIZER_DELIMINATOR = 34, - OPTIMIZER_UNNEST_REWRITER = 35, - OPTIMIZER_UNUSED_COLUMNS = 36, - OPTIMIZER_STATISTICS_PROPAGATION = 37, - OPTIMIZER_COMMON_SUBEXPRESSIONS = 38, - OPTIMIZER_COMMON_AGGREGATE = 39, - OPTIMIZER_COLUMN_LIFETIME = 40, - OPTIMIZER_BUILD_SIDE_PROBE_SIDE = 41, - OPTIMIZER_LIMIT_PUSHDOWN = 42, - OPTIMIZER_TOP_N = 43, - OPTIMIZER_COMPRESSED_MATERIALIZATION = 44, - OPTIMIZER_DUPLICATE_GROUPS = 45, - OPTIMIZER_REORDER_FILTER = 46, - OPTIMIZER_SAMPLING_PUSHDOWN = 47, - OPTIMIZER_JOIN_FILTER_PUSHDOWN = 48, - OPTIMIZER_EXTENSION = 49, - OPTIMIZER_MATERIALIZED_CTE = 50, - OPTIMIZER_AGGREGATE_FUNCTION_REWRITER = 51, - OPTIMIZER_LATE_MATERIALIZATION = 52, - OPTIMIZER_CTE_INLINING = 53, - OPTIMIZER_ROW_GROUP_PRUNER = 54, - OPTIMIZER_TOP_N_WINDOW_ELIMINATION = 55, - OPTIMIZER_COMMON_SUBPLAN = 56, - OPTIMIZER_JOIN_ELIMINATION = 57, - OPTIMIZER_WINDOW_SELF_JOIN = 58, - OPTIMIZER_PROJECTION_PULLUP = 59, - OPTIMIZER_OUTER_JOIN_SIMPLIFICATION = 60, - 
OPTIMIZER_ROW_NUMBER_REWRITER = 61, - OPTIMIZER_PARTITIONED_EXECUTION = 62, - // PhaseTiming metrics - ALL_OPTIMIZERS = 18, - CUMULATIVE_OPTIMIZER_TIMING = 19, - PARSER = 99, - PHYSICAL_PLANNER = 22, - PHYSICAL_PLANNER_COLUMN_BINDING = 23, - PHYSICAL_PLANNER_CREATE_PLAN = 25, - PHYSICAL_PLANNER_RESOLVE_TYPES = 24, - PLANNER = 20, - PLANNER_BINDING = 21, -}; - -struct MetricTypeHashFunction { - uint64_t operator()(const MetricType &index) const { - return std::hash()(static_cast(index)); - } -}; - -typedef unordered_set profiler_settings_t; -typedef unordered_map profiler_metrics_t; +typedef unordered_map profiler_metrics_t; class MetricsUtils { public: - static constexpr uint8_t START_OPTIMIZER = static_cast(MetricType::OPTIMIZER_EXPRESSION_REWRITER); - static constexpr uint8_t END_OPTIMIZER = static_cast(MetricType::OPTIMIZER_PARTITIONED_EXECUTION); - -public: - - // All metrics - static profiler_settings_t GetAllMetrics(); - static profiler_settings_t GetMetricsByGroupType(MetricGroup type); - - // Core metrics - static profiler_settings_t GetCoreMetrics(); - static bool IsCoreMetric(MetricType type); - - // Default metrics - static profiler_settings_t GetDefaultMetrics(); - static bool IsDefaultMetric(MetricType type); - - // Execution metrics - static profiler_settings_t GetExecutionMetrics(); - static bool IsExecutionMetric(MetricType type); - - // File metrics - static profiler_settings_t GetFileMetrics(); - static bool IsFileMetric(MetricType type); - - // Operator metrics - static profiler_settings_t GetOperatorMetrics(); - static bool IsOperatorMetric(MetricType type); - - // Optimizer metrics - static profiler_settings_t GetOptimizerMetrics(); - static bool IsOptimizerMetric(MetricType type); - static MetricType GetOptimizerMetricByType(OptimizerType type); - static OptimizerType GetOptimizerTypeByMetric(MetricType type); - - // PhaseTiming metrics - static profiler_settings_t GetPhaseTimingMetrics(); - static bool IsPhaseTimingMetric(MetricType type); - 
- // RootScope metrics - static profiler_settings_t GetRootScopeMetrics(); - static bool IsRootScopeMetric(MetricType type); + template + static bool IsMetric(const string &metric_name) { + return StringUtil::Equals(metric_name, T::Name); + } + + static bool MetricInGroup(const string &metric_name, const string &group_name) { + if (metric_name.size() < group_name.size() + 1) { + return false; + } + if (!StringUtil::StartsWith(metric_name, group_name)) { + return false; + } + return metric_name[group_name.size()] == '.'; + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index 8c8358e44..1470a430e 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -52,6 +52,8 @@ enum class OptimizerType : uint32_t { OUTER_JOIN_SIMPLIFICATION = 35, ROW_NUMBER_REWRITER = 36, PARTITIONED_EXECUTION = 37, + PARTIAL_AGGREGATE_PUSHDOWN = 38, + REMOTE_PUSHDOWN = 39, }; string OptimizerTypeToString(OptimizerType type); diff --git a/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp b/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp index b31a3bfab..1937dbae3 100644 --- a/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/physical_operator_type.hpp @@ -121,6 +121,8 @@ enum class PhysicalOperatorType : uint8_t { EXTENSION, VERIFY_VECTOR, UPDATE_EXTENSIONS, + CONNECT, + DISCONNECT, // ----------------------------- // Secret diff --git a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp index fe7d6d960..35d2d83b5 100644 --- a/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/statement_type.hpp @@ -9,6 +9,7 @@ #pragma once #include 
"duckdb/common/constants.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/main/query_parameters.hpp" @@ -50,7 +51,9 @@ enum class StatementType : uint8_t { MULTI_STATEMENT, COPY_DATABASE_STATEMENT, UPDATE_EXTENSIONS_STATEMENT, - MERGE_INTO_STATEMENT + MERGE_INTO_STATEMENT, + CONNECT_STATEMENT, + DISCONNECT_STATEMENT }; DUCKDB_API string StatementTypeToString(StatementType type); @@ -93,9 +96,9 @@ struct StatementProperties { }; //! The set of databases this statement will read from - unordered_map read_databases; + identifier_map_t read_databases; //! The set of databases this statement will modify - unordered_map modified_databases; + identifier_map_t modified_databases; //! Whether or not the statement requires a valid transaction. Almost all statements require this, with the //! exception of ROLLBACK bool requires_valid_transaction; diff --git a/src/duckdb/src/include/duckdb/common/enums/task_scheduler_type.hpp b/src/duckdb/src/include/duckdb/common/enums/task_scheduler_type.hpp new file mode 100644 index 000000000..513ce9f6a --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/task_scheduler_type.hpp @@ -0,0 +1,19 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/task_scheduler_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class TaskSchedulerType : uint8_t { REGULAR, ASYNC, ENUM_SIZE }; + +static constexpr size_t TASK_SCHEDULER_TYPE_COUNT = static_cast(TaskSchedulerType::ENUM_SIZE); + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/exception.hpp b/src/duckdb/src/include/duckdb/common/exception.hpp index b74629d30..546dfbc56 100644 --- a/src/duckdb/src/include/duckdb/common/exception.hpp +++ 
b/src/duckdb/src/include/duckdb/common/exception.hpp @@ -135,7 +135,7 @@ class Exception : public std::runtime_error { template static string ConstructMessageRecursive(const string &msg, std::vector &values, const T ¶m, ARGS &&...params) { - values.push_back(ExceptionFormatValue::CreateFormatValue(param)); + values.push_back(ExceptionFormatValue::CreateFormatValue(param)); return ConstructMessageRecursive(msg, values, params...); } @@ -303,6 +303,7 @@ class InterruptException : public Exception { public: static constexpr const char *INTERRUPT_MESSAGE = "Interrupted!"; DUCKDB_API InterruptException(); + DUCKDB_API explicit InterruptException(const string &message); }; class FatalException : public Exception { diff --git a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp index 5aec7e296..59d1274ee 100644 --- a/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/binder_exception.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/pair.hpp" #include "duckdb/parser/query_error_context.hpp" namespace duckdb { @@ -51,10 +52,12 @@ class BinderException : public Exception { ConstructMessage(msg, std::forward(params)...)) { } - static BinderException ColumnNotFound(const string &name, const vector &similar_bindings, + static BinderException ColumnNotFound(const Identifier &name, const vector &similar_bindings, QueryErrorContext context = QueryErrorContext()); - static BinderException NoMatchingFunction(const string &catalog_name, const string &schema_name, const string &name, - const vector &arguments, const vector &candidates); + static BinderException NoMatchingFunction(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, const vector &arguments, + const vector> &named_arguments, + const vector &candidates); 
static BinderException Unsupported(ParsedExpression &expr, const string &message); }; diff --git a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp index aadbc9f83..850d6b34b 100644 --- a/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp +++ b/src/duckdb/src/include/duckdb/common/exception/catalog_exception.hpp @@ -36,11 +36,11 @@ class CatalogException : public Exception { } static CatalogException MissingEntry(const EntryLookupInfo &lookup_info, const string &suggestion); - static CatalogException MissingEntry(CatalogType type, const string &name, const string &suggestion, + static CatalogException MissingEntry(CatalogType type, const Identifier &name, const string &suggestion, QueryErrorContext context = QueryErrorContext()); - static CatalogException MissingEntry(const string &type, const string &name, const vector &suggestions, + static CatalogException MissingEntry(const string &type, const Identifier &name, const vector &suggestions, QueryErrorContext context = QueryErrorContext()); - static CatalogException EntryAlreadyExists(CatalogType type, const string &name, + static CatalogException EntryAlreadyExists(CatalogType type, const Identifier &name, QueryErrorContext context = QueryErrorContext()); }; diff --git a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp index 57e968138..204acce33 100644 --- a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp +++ b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp @@ -18,6 +18,7 @@ namespace duckdb { class String; +class Identifier; enum class PhysicalType : uint8_t; struct LogicalType; @@ -48,6 +49,11 @@ struct ExceptionFormatValue { static ExceptionFormatValue CreateFormatValue(const T &value) { return int64_t(value); } + template + static ExceptionFormatValue CreateFormatValue(const char 
(&value)[N]) { + const char *ptr = value; + return CreateFormatValue(ptr); + } static string Format(const string &msg, std::vector &values); }; @@ -66,6 +72,8 @@ DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const do template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const string &value); template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const Identifier &value); +template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const String &value); template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *const &value); diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp index d6a7c5554..81a20bc22 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -32,8 +32,7 @@ enum class ExtraTypeInfoType : uint8_t { ANY_TYPE_INFO = 10, INTEGER_LITERAL_TYPE_INFO = 11, TEMPLATE_TYPE_INFO = 12, - GEO_TYPE_INFO = 13, - AGGREGATE_STATE_TYPE_INFO = 14 + GEO_TYPE_INFO = 13 }; struct ExtraTypeInfo { @@ -146,45 +145,25 @@ struct StructTypeInfo : public ExtraTypeInfo { }; struct LegacyAggregateStateTypeInfo : public ExtraTypeInfo { - explicit LegacyAggregateStateTypeInfo(aggregate_state_t state_type_p); - - aggregate_state_t state_type; - public: void Serialize(Serializer &serializer) const override; + // Legacy deserialize method kept only for compatibility with old database files static shared_ptr Deserialize(Deserializer &source); - shared_ptr Copy() const override; - -protected: - bool EqualsInternal(ExtraTypeInfo *other_p) const override; - -private: - LegacyAggregateStateTypeInfo(); -}; - -struct AggregateStateTypeInfo : public StructTypeInfo { - explicit AggregateStateTypeInfo(aggregate_state_t state_type_p, child_list_t child_types_p); - aggregate_state_t state_type; - 
-public: - void Serialize(Serializer &serializer) const override; - static shared_ptr Deserialize(Deserializer &source); - shared_ptr Copy() const override; - shared_ptr DeepCopy() const override; + static shared_ptr LegacyDeserialize(); protected: bool EqualsInternal(ExtraTypeInfo *other_p) const override; private: - AggregateStateTypeInfo(); + LegacyAggregateStateTypeInfo(); }; // If this type is primarily stored in the catalog or not. Enums from Pandas/Factors are not in the catalog. enum EnumDictType : uint8_t { INVALID = 0, VECTOR_DICT = 1 }; struct EnumTypeInfo : public ExtraTypeInfo { - explicit EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p); + explicit EnumTypeInfo(const Vector &values_insert_order_p, idx_t dict_size_p); EnumTypeInfo(const EnumTypeInfo &) = delete; EnumTypeInfo &operator=(const EnumTypeInfo &) = delete; @@ -194,7 +173,7 @@ struct EnumTypeInfo : public ExtraTypeInfo { const idx_t &GetDictSize() const; static PhysicalType DictType(idx_t size); - static LogicalType CreateType(Vector &ordered_data, idx_t size); + static LogicalType CreateType(const Vector &ordered_data, idx_t size); void Serialize(Serializer &serializer) const override; static shared_ptr Deserialize(Deserializer &source); diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info/enum_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info/enum_type_info.hpp index b907e700b..226113ea4 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info/enum_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info/enum_type_info.hpp @@ -9,7 +9,7 @@ namespace duckdb { template struct EnumTypeInfoTemplated : public EnumTypeInfo { - explicit EnumTypeInfoTemplated(Vector &values_insert_order_p, idx_t size_p) + explicit EnumTypeInfoTemplated(const Vector &values_insert_order_p, idx_t size_p) : EnumTypeInfo(values_insert_order_p, size_p) { D_ASSERT(values_insert_order_p.GetType().InternalType() == PhysicalType::VARCHAR); diff --git 
a/src/duckdb/src/include/duckdb/common/file_buffer.hpp b/src/duckdb/src/include/duckdb/common/file_buffer.hpp index 141a25c66..3cc67c02a 100644 --- a/src/duckdb/src/include/duckdb/common/file_buffer.hpp +++ b/src/duckdb/src/include/duckdb/common/file_buffer.hpp @@ -15,6 +15,7 @@ namespace duckdb { class BlockAllocator; class BlockManager; +class MemoryMappedFile; class QueryContext; struct FileHandle; @@ -39,8 +40,12 @@ class FileBuffer { public: //! Read into the FileBuffer from the location. void Read(QueryContext context, FileHandle &handle, uint64_t location); + //! Read into the FileBuffer from a memory-mapped file at the location. + void Read(QueryContext context, MemoryMappedFile &handle, uint64_t location); //! Write the FileBuffer to the location. void Write(QueryContext context, FileHandle &handle, const uint64_t location); + //! Write the FileBuffer to the location in a memory-mapped file. + void Write(QueryContext context, MemoryMappedFile &handle, uint64_t location); void Clear(); @@ -69,6 +74,11 @@ class FileBuffer { uint64_t Size() const { return size; } + //! Whether this FileBuffer owns its underlying memory; false if it adopted a pointer + //! into a memory-mapped region. The buffer manager's reuse path requires owned memory. + bool OwnsInternalBuffer() const { + return owns_internal_buffer; + } data_ptr_t InternalBuffer() { return internal_buffer; } @@ -96,6 +106,8 @@ class FileBuffer { //! The aligned size as passed to the constructor. //! This is the size that is read from or written to disk. uint64_t internal_size; + //! False when internal_buffer was adopted from a MemoryMappedFile (don't free in dtor). 
+ bool owns_internal_buffer = true; void ReallocBuffer(idx_t new_size); void Init(); diff --git a/src/duckdb/src/include/duckdb/common/file_system.hpp b/src/duckdb/src/include/duckdb/common/file_system.hpp index 8dadbd59d..9d1df1ebe 100644 --- a/src/duckdb/src/include/duckdb/common/file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/file_system.hpp @@ -39,6 +39,8 @@ class FileSystem; class Logger; class ClientContext; class QueryContext; +class MemoryMappedFile; +struct MMapOptions; class MultiFileList; enum class FileType { @@ -71,6 +73,14 @@ struct FileMetadata { unordered_map extended_file_info; }; +//! Measured network throughput for a (remote) file handle. Used to size prefetch coalescing gaps. +struct NetworkThroughputEstimate { + //! Round-trip latency + request setup, in seconds + double latency_seconds = 0; + //! Single-stream throughput, in bytes per second + double bandwidth_bytes_per_s = 0; +}; + struct FileHandle { public: DUCKDB_API FileHandle(FileSystem &file_system, string path, FileOpenFlags flags); @@ -102,6 +112,8 @@ struct FileHandle { DUCKDB_API bool CanSeek(); DUCKDB_API bool IsPipe(); DUCKDB_API bool OnDiskFile(); + //! Try to obtain a network throughput estimate (Local files return false). + DUCKDB_API bool TryGetNetworkThroughput(NetworkThroughputEstimate &result); DUCKDB_API idx_t GetFileSize(); DUCKDB_API FileType GetType(); DUCKDB_API FileMetadata Stats(); @@ -145,6 +157,7 @@ class FileSystem { public: DUCKDB_API static FileSystem &GetFileSystem(ClientContext &context); DUCKDB_API static FileSystem &GetFileSystem(DatabaseInstance &db); + DUCKDB_API static FileSystem &GetLocal(DatabaseInstance &db); DUCKDB_API static FileSystem &Get(AttachedDatabase &db); DUCKDB_API virtual unique_ptr OpenFile(const string &path, FileOpenFlags flags, @@ -152,6 +165,11 @@ class FileSystem { DUCKDB_API unique_ptr OpenFile(const OpenFileInfo &path, FileOpenFlags flags, optional_ptr opener = nullptr); + //! Open a memory-mapped view of [path]. 
Throws if not supported by this filesystem. + DUCKDB_API virtual unique_ptr MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener = nullptr); + //! Read exactly nr_bytes from the specified location in the file. Fails if nr_bytes could not be read. This is //! equivalent to calling SetFilePointer(location) followed by calling Read(). DUCKDB_API virtual void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location); @@ -294,6 +312,9 @@ class FileSystem { //! Whether or not the FS handles plain files on disk. This is relevant for certain optimizations, as random reads //! in a file on-disk are much cheaper than e.g. random reads in a file over the network DUCKDB_API virtual bool OnDiskFile(FileHandle &handle); + //! Try to obtain a measured network throughput estimate. Default: not supported (returns false). + //! Used for file systems + DUCKDB_API virtual bool TryGetNetworkThroughput(FileHandle &handle, NetworkThroughputEstimate &result); DUCKDB_API virtual unique_ptr OpenCompressedFile(QueryContext context, unique_ptr handle, bool write); @@ -301,6 +322,9 @@ class FileSystem { //! Create a LocalFileSystem. DUCKDB_API static unique_ptr CreateLocal(); + //! Whether this is a LocalFileSystem instance. + DUCKDB_API virtual bool IsLocalFileSystem() const; + //! Return the name of the filesystem. Used for forming diagnosis messages. 
DUCKDB_API virtual std::string GetName() const = 0; diff --git a/src/duckdb/src/include/duckdb/common/fsst.hpp b/src/duckdb/src/include/duckdb/common/fsst.hpp index 88f4414f6..2a62cda08 100644 --- a/src/duckdb/src/include/duckdb/common/fsst.hpp +++ b/src/duckdb/src/include/duckdb/common/fsst.hpp @@ -22,6 +22,13 @@ class Value; class Vector; struct string_t; +typedef struct { + uint32_t dict_size; + uint32_t dict_end; + uint32_t bitpacking_width; + uint32_t fsst_symbol_table_offset; +} fsst_compression_header_t; + class FSSTPrimitives { private: // This allows us to decode FSST strings efficiently directly into a string_t (if inlined) diff --git a/src/duckdb/src/include/duckdb/common/http_util.hpp b/src/duckdb/src/include/duckdb/common/http_util.hpp index f36dbd2bf..baecd2172 100644 --- a/src/duckdb/src/include/duckdb/common/http_util.hpp +++ b/src/duckdb/src/include/duckdb/common/http_util.hpp @@ -141,7 +141,7 @@ struct BaseRequest { BaseRequest(RequestType type, const string &url, const HTTPHeaders &headers, HTTPParams ¶ms); RequestType type; - const string &url; + string url; string path; string proto_host_port; HTTPHeaders headers; @@ -149,11 +149,16 @@ struct BaseRequest { //! Whether or not to return failed requests (instead of throwing) bool try_request = false; - // Requests will optionally contain their timings + //! Requests will optionally contain their timings bool have_request_timing = false; timestamp_t request_start; timestamp_t request_end; + //! Optional per-request network measurements, populated by clients that measure them. 
+ bool have_time_to_fst_byte = false; + double time_to_fst_byte_sec = 0; + idx_t bytes_received = 0; + template TARGET &Cast() { return reinterpret_cast(*this); diff --git a/src/duckdb/src/include/duckdb/common/identifier.hpp b/src/duckdb/src/include/duckdb/common/identifier.hpp new file mode 100644 index 000000000..7978c89fc --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/identifier.hpp @@ -0,0 +1,200 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/identifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector.hpp" + +#include + +namespace duckdb { + +//! An Identifier represents a SQL identifier (e.g. a column name, table name, alias, ...). +//! Unlike a regular string, identifiers compare case-insensitively (using StringUtil::CIEquals). +//! Internally the identifier is stored as-is (preserving the original casing), but all comparisons, +//! hashing and ordering are case-insensitive. +class Identifier { +public: + Identifier() = default; + //! Construction from a string literal is implicit: literals in the source are identifiers by intent. + Identifier(const char *str) : value(str) { // NOLINT: implicit conversion from literals is intentional + } + //! Construction from a runtime string is explicit: an Identifier carries case-insensitive semantics, so + //! promoting a runtime string must be a deliberate choice at the call site. + explicit Identifier(const string &str) : value(str) { + } + explicit Identifier(string &&str) : value(std::move(str)) { + } + + //! 
Named constructors for well-known identifiers + static Identifier DefaultSchema() { + return Identifier(DEFAULT_SCHEMA); + } + static Identifier InvalidSchema() { + return Identifier(INVALID_SCHEMA); + } + static Identifier InvalidCatalog() { + return Identifier(INVALID_CATALOG); + } + static Identifier SystemCatalog() { + return Identifier(SYSTEM_CATALOG); + } + static Identifier TempCatalog() { + return Identifier(TEMP_CATALOG); + } + + //! Conversion back to a string is explicit: it discards the case-insensitive semantics, so callers must opt in + //! (use GetIdentifierName() for the raw value). Keeping this explicit is what makes the Identifier type safe. + explicit operator const string &() const { + return value; + } + + //! The raw underlying string (preserving original casing) + const string &GetIdentifierName() const { + return value; + } + + bool empty() const { // NOLINT: match std::string interface + return value.empty(); + } + void clear() { // NOLINT: match std::string interface + value.clear(); + } + idx_t size() const { // NOLINT: match std::string interface + return value.size(); + } + const char *c_str() const { // NOLINT: match std::string interface + return value.c_str(); + } + + //! Case-insensitive hash of the identifier + DUCKDB_API hash_t Hash() const; + +private: + string value; +}; + +//! 
Equality (case-insensitive) +DUCKDB_API bool operator==(const Identifier &a, const Identifier &b); +DUCKDB_API bool operator==(const Identifier &a, const string &b); +DUCKDB_API bool operator==(const string &a, const Identifier &b); +DUCKDB_API bool operator==(const Identifier &a, const char *b); +DUCKDB_API bool operator==(const char *a, const Identifier &b); + +inline bool operator!=(const Identifier &a, const Identifier &b) { + return !(a == b); +} +inline bool operator!=(const Identifier &a, const string &b) { + return !(a == b); +} +inline bool operator!=(const string &a, const Identifier &b) { + return !(a == b); +} +inline bool operator!=(const Identifier &a, const char *b) { + return !(a == b); +} +inline bool operator!=(const char *a, const Identifier &b) { + return !(a == b); +} + +//! Ordering (case-insensitive) +DUCKDB_API bool operator<(const Identifier &a, const Identifier &b); + +//! Streaming an identifier writes its raw name (without quotes) - this mirrors writing the underlying string +DUCKDB_API std::ostream &operator<<(std::ostream &os, const Identifier &id); + +//! String concatenation (std::operator+ is a template and cannot use the implicit conversion, so we provide our own) +inline string operator+(const Identifier &a, const string &b) { + return a.GetIdentifierName() + b; +} +inline string operator+(const string &a, const Identifier &b) { + return a + b.GetIdentifierName(); +} +inline string operator+(const Identifier &a, const char *b) { + return a.GetIdentifierName() + b; +} +inline string operator+(const char *a, const Identifier &b) { + return a + b.GetIdentifierName(); +} +inline string operator+(const Identifier &a, const Identifier &b) { + return a.GetIdentifierName() + b.GetIdentifierName(); +} +inline string operator+(const Identifier &a, char b) { + return a.GetIdentifierName() + b; +} +inline string operator+(char a, const Identifier &b) { + return a + b.GetIdentifierName(); +} + +//! 
Appending an identifier to a string appends the raw name +inline string &operator+=(string &a, const Identifier &b) { + a += b.GetIdentifierName(); + return a; +} + +struct IdentifierHashFunction { + uint64_t operator()(const Identifier &id) const { + return id.Hash(); + } +}; + +struct IdentifierEquality { + bool operator()(const Identifier &a, const Identifier &b) const { + return a == b; + } +}; + +struct IdentifierCompare { + bool operator()(const Identifier &a, const Identifier &b) const { + return a < b; + } +}; + +template +using identifier_map_t = unordered_map; + +using identifier_set_t = unordered_set; + +template +using identifier_tree_t = map; + +//! Helper to convert a vector of identifiers to a vector of (raw) strings (for interop with string-based APIs) +inline vector IdentifiersToStrings(const vector &identifiers) { + vector result; + result.reserve(identifiers.size()); + for (auto &identifier : identifiers) { + result.push_back(identifier.GetIdentifierName()); + } + return result; +} + +//! Helper to convert a vector of (raw) strings to a vector of identifiers (to be removed at the end of the rework) +inline vector StringsToIdentifiers(const vector &strings) { + vector result; + result.reserve(strings.size()); + for (auto &str : strings) { + result.emplace_back(str); + } + return result; +} + +//! Identifier-aware overloads of the invalid-catalog/schema checks. These live here (rather than next to the +//! string versions in constants.hpp) because constants.hpp is a dependency of this header. 
+inline bool IsInvalidCatalog(const Identifier &catalog) { + return catalog.empty(); +} +inline bool IsInvalidSchema(const Identifier &schema) { + return schema.empty(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp index 5b6f11730..b4e6f9f38 100644 --- a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp +++ b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp @@ -18,11 +18,11 @@ namespace duckdb { -template +template > class InsertionOrderPreservingMap { public: - typedef vector> VECTOR_TYPE; // NOLINT: matching name of std - typedef string key_type; // NOLINT: matching name of std + typedef vector> VECTOR_TYPE; // NOLINT: matching name of std + typedef KEY key_type; // NOLINT: matching name of std public: InsertionOrderPreservingMap() { @@ -30,11 +30,11 @@ class InsertionOrderPreservingMap { private: VECTOR_TYPE map; - case_insensitive_map_t map_idx; + INDEX_MAP map_idx; public: - vector Keys() const { - vector keys; + vector Keys() const { + vector keys; keys.resize(this->size()); for (auto &kv : map_idx) { keys[kv.second] = kv.first; @@ -67,7 +67,7 @@ class InsertionOrderPreservingMap { return map.rend(); } - typename VECTOR_TYPE::iterator find(const string &key) { // NOLINT: match stl API + typename VECTOR_TYPE::iterator find(const KEY &key) { // NOLINT: match stl API auto entry = map_idx.find(key); if (entry == map_idx.end()) { return map.end(); @@ -75,7 +75,7 @@ class InsertionOrderPreservingMap { return map.begin() + static_cast(entry->second); } - typename VECTOR_TYPE::const_iterator find(const string &key) const { // NOLINT: match stl API + typename VECTOR_TYPE::const_iterator find(const KEY &key) const { // NOLINT: match stl API auto entry = map_idx.find(key); if (entry == map_idx.end()) { return map.end(); @@ -97,9 +97,10 @@ class InsertionOrderPreservingMap { void clear() 
{ // NOLINT: match stl API map.clear(); + map_idx.clear(); } - void insert(const string &key, V &&value) { // NOLINT: match stl API + void insert(const KEY &key, V &&value) { // NOLINT: match stl API if (contains(key)) { return; } @@ -107,7 +108,7 @@ class InsertionOrderPreservingMap { map_idx[key] = map.size() - 1; } - void insert(const string &key, const V &value) { // NOLINT: match stl API + void insert(const KEY &key, const V &value) { // NOLINT: match stl API if (contains(key)) { return; } @@ -115,7 +116,7 @@ class InsertionOrderPreservingMap { map_idx[key] = map.size() - 1; } - void insert(pair &&value) { // NOLINT: match stl API + void insert(pair &&value) { // NOLINT: match stl API auto &key = value.first; if (contains(key)) { return; @@ -136,15 +137,15 @@ class InsertionOrderPreservingMap { } } - bool contains(const string &key) const { // NOLINT: match stl API + bool contains(const KEY &key) const { // NOLINT: match stl API return map_idx.find(key) != map_idx.end(); } - const V &at(const string &key) const { // NOLINT: match stl API + const V &at(const KEY &key) const { // NOLINT: match stl API return map[map_idx.at(key)].second; } - V &operator[](const string &key) { + V &operator[](const KEY &key) { if (!contains(key)) { auto v = V(); insert(key, std::move(v)); diff --git a/src/duckdb/src/include/duckdb/common/local_file_system.hpp b/src/duckdb/src/include/duckdb/common/local_file_system.hpp index 5bac3da77..70543ae53 100644 --- a/src/duckdb/src/include/duckdb/common/local_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/local_file_system.hpp @@ -18,6 +18,10 @@ class LocalFileSystem : public FileSystem { unique_ptr OpenFile(const string &path, FileOpenFlags flags, optional_ptr opener = nullptr) override; + unique_ptr MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener = nullptr) override; + //! Read exactly nr_bytes from the specified location in the file. 
Fails if nr_bytes could not be read. This is //! equivalent to calling SetFilePointer(location) followed by calling Read(). void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; @@ -88,6 +92,10 @@ class LocalFileSystem : public FileSystem { //! in a file on-disk are much cheaper than e.g. random reads in a file over the network bool OnDiskFile(FileHandle &handle) override; + bool IsLocalFileSystem() const override { + return true; + } + std::string GetName() const override { return "LocalFileSystem"; } diff --git a/src/duckdb/src/include/duckdb/common/memory_mapped_file.hpp b/src/duckdb/src/include/duckdb/common/memory_mapped_file.hpp new file mode 100644 index 000000000..c0c1e8249 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/memory_mapped_file.hpp @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/memory_mapped_file.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_open_flags.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { + +struct MMapOptions { + //! Size of the virtual mapping. Writable mappings sparsely extend the file up to this + //! size at open time so the region never has to be remapped. Ignored for read-only. + idx_t reserve_size = 0; +}; + +//! A fixed-size memory-mapped view of a file. The mapping pointer is stable for the +//! lifetime of the object, so callers may share it across threads without synchronization. +class MemoryMappedFile { +public: + DUCKDB_API MemoryMappedFile(string path, FileOpenFlags flags, data_ptr_t data, idx_t size); + MemoryMappedFile(const MemoryMappedFile &) = delete; + DUCKDB_API virtual ~MemoryMappedFile(); + + //! Bounds-checked pointer to [location, location + nr_bytes) within the mapping. 
+ DUCKDB_API const_data_ptr_t GetData(idx_t location, idx_t nr_bytes) const; + //! Bounds-checked mutable pointer; throws if the mapping was opened read-only. + DUCKDB_API data_ptr_t GetDataMutable(idx_t location, idx_t nr_bytes); + + DUCKDB_API virtual void Sync() = 0; + //! Punch a hole at [offset, offset + length), releasing backing disk space. Returns + //! false if the platform does not support hole punching. + DUCKDB_API virtual bool Trim(idx_t offset, idx_t length) = 0; + DUCKDB_API virtual bool OnDiskFile() const { + return true; + } + DUCKDB_API virtual void Close() = 0; + + DUCKDB_API idx_t Size() const { + return size; + } + const string &GetPath() const { + return path; + } + FileOpenFlags GetFlags() const { + return flags; + } + +public: + string path; + FileOpenFlags flags; + +protected: + data_ptr_t data = nullptr; + idx_t size = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp index 4fc05100e..a03e7d9b3 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/base_file_reader.hpp @@ -74,7 +74,7 @@ class BaseFileReader : public enable_shared_from_this { public: DUCKDB_API virtual shared_ptr GetUnionData(idx_t file_idx); //! Get statistics for a specific column - DUCKDB_API virtual unique_ptr GetStatistics(ClientContext &context, const string &name); + DUCKDB_API virtual unique_ptr GetStatistics(ClientContext &context, const Identifier &name); //! 
Prepare reader for scanning DUCKDB_API virtual void PrepareReader(ClientContext &context, GlobalTableFunctionState &); @@ -144,7 +144,7 @@ class BaseUnionData { virtual optional_idx TryGetCardinalityEstimate() const { return optional_idx(); } - virtual unique_ptr GetStatistics(ClientContext &context, const string &name); + virtual unique_ptr GetStatistics(ClientContext &context, const Identifier &name); public: template diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_adaptive_filter_cache.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_adaptive_filter_cache.hpp index bc08721dd..505ee98f7 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_adaptive_filter_cache.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_adaptive_filter_cache.hpp @@ -24,13 +24,7 @@ class MultiFileAdaptiveFilterCache { const vector &filter_global_indices, shared_ptr logger, const string &file_path); - AdaptiveFilter &GetAdaptiveFilter() const { - if (!filter) { - throw InternalException( - "Filter from MultiFileAdaptiveFilterCache must be initialized by 'InitializeAdaptiveFilter' first."); - } - return *filter; - } + AdaptiveFilter &GetAdaptiveFilter() const; private: unique_ptr filter; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp index 6bb61a87a..ae7eb40bb 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_data.hpp @@ -39,6 +39,10 @@ struct HivePartitioningIndex { struct MultiFileColumnDefinition { public: + MultiFileColumnDefinition(const Identifier &name, const LogicalType &type) : name(name), type(type) { + } + MultiFileColumnDefinition(const char *name, const LogicalType &type) : name(name), type(type) { + } MultiFileColumnDefinition(const string &name, const LogicalType &type) : name(name), type(type) { } @@ 
-60,7 +64,7 @@ struct MultiFileColumnDefinition { } public: - static MultiFileColumnDefinition CreateFromNameAndType(const string &name, const LogicalType &type) { + static MultiFileColumnDefinition CreateFromNameAndType(const Identifier &name, const LogicalType &type) { MultiFileColumnDefinition result(name, type); if (type.id() == LogicalTypeId::STRUCT) { // recursively create for children @@ -71,7 +75,7 @@ struct MultiFileColumnDefinition { return result; } - static vector ColumnsFromNamesAndTypes(const vector &names, + static vector ColumnsFromNamesAndTypes(const vector &names, const vector &types) { vector columns; D_ASSERT(names.size() == types.size()); @@ -88,7 +92,7 @@ struct MultiFileColumnDefinition { D_ASSERT(names.empty()); D_ASSERT(types.empty()); for (auto &column : columns) { - names.push_back(column.name); + names.push_back(column.name.GetIdentifierName()); types.push_back(column.type); } } @@ -102,7 +106,7 @@ struct MultiFileColumnDefinition { string GetIdentifierName() const { if (identifier.IsNull()) { // No identifier was provided, assume the name as the identifier - return name; + return name.GetIdentifierName(); } D_ASSERT(identifier.type().id() == LogicalTypeId::VARCHAR); return identifier.GetValue(); @@ -118,7 +122,7 @@ struct MultiFileColumnDefinition { } public: - string name; + Identifier name; LogicalType type; vector children; unique_ptr default_expression; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp index feafdc903..3f0668c5f 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp @@ -33,7 +33,7 @@ struct MultiFileReaderInterface { const vector &expected_names, const vector &expected_types); virtual unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, unique_ptr options) = 0; - virtual void 
BindReader(ClientContext &context, vector &return_types, vector &names, + virtual void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) = 0; virtual void FinalizeBindData(MultiFileBindData &multi_file_data); virtual void GetBindInfo(const TableFunctionData &bind_data, BindInfo &info); @@ -69,7 +69,7 @@ struct MultiFileReaderInterface { template class MultiFileFunction : public TableFunction { public: - explicit MultiFileFunction(string name_p) + explicit MultiFileFunction(Identifier name_p) : TableFunction(std::move(name_p), {LogicalType::VARCHAR}, MultiFileScan, MultiFileBind, MultiFileInitGlobal, MultiFileInitLocal) { cardinality = MultiFileCardinality; @@ -80,10 +80,14 @@ class MultiFileFunction : public TableFunction { pushdown_complex_filter = MultiFileComplexFilterPushdown; get_partition_info = MultiFileGetPartitionInfo; get_virtual_columns = MultiFileGetVirtualColumns; - dynamic_to_string = MultiFileDynamicToString; + get_metrics = MultiFileGetMetrics; MultiFileReader::AddParameters(*this); } + static bool IsEmptyResult(const MultiFileBindData &bind_data) { + return bind_data.file_options.allow_empty && bind_data.file_list->IsEmpty(); + } + static unique_ptr MultiFileBindInternal(ClientContext &context, unique_ptr multi_file_reader_p, shared_ptr multi_file_list_p, @@ -101,6 +105,15 @@ class MultiFileFunction : public TableFunction { result->bind_data = interface.InitializeBindData(*result, std::move(options_p)); result->interface = std::move(interface_p); + if (IsEmptyResult(*result)) { + result->types.emplace_back(LogicalType::BOOLEAN); + result->names.emplace_back("empty"); + result->columns = MultiFileColumnDefinition::ColumnsFromNamesAndTypes(result->names, result->types); + return_types = result->types; + names = IdentifiersToStrings(result->names); + return std::move(result); + } + // now bind the readers // there are two ways of binding the readers // (1) MultiFileReader::Bind -> custom bind, used 
only for certain lakehouse extensions @@ -120,7 +133,7 @@ class MultiFileFunction : public TableFunction { if (return_types.empty()) { // no expected types - just copy the types return_types = result->types; - names = result->names; + names = IdentifiersToStrings(result->names); } else { // We're deserializing from a previously successful bind call // verify that the amount of columns still matches @@ -172,19 +185,30 @@ class MultiFileFunction : public TableFunction { auto multi_file_reader = MultiFileReader::Create(input.table_function); auto glob_input = multi_file_reader->GetGlobInput(*interface); + + MultiFileOptions file_options; + for (auto &kv : input.named_parameters) { + auto loption = StringUtil::Lower(kv.first.GetIdentifierName()); + if (loption == "allow_empty") { + multi_file_reader->ParseOption(loption, kv.second, file_options, context); + if (file_options.allow_empty) { + glob_input.allow_empty = true; + } + break; + } + } + auto file_list = multi_file_reader->CreateFileList(context, input.inputs[0], glob_input); interface->InitializeInterface(context, *multi_file_reader, *file_list); - MultiFileOptions file_options; - auto options = interface->InitializeOptions(context, input.info); for (auto &kv : input.named_parameters) { - auto loption = StringUtil::Lower(kv.first); + auto loption = StringUtil::Lower(kv.first.GetIdentifierName()); if (multi_file_reader->ParseOption(loption, kv.second, file_options, context)) { continue; } - if (interface->ParseOption(context, kv.first, kv.second, file_options, *options)) { + if (interface->ParseOption(context, kv.first.GetIdentifierName(), kv.second, file_options, *options)) { continue; } throw NotImplementedException("Unimplemented option %s", kv.first); @@ -332,6 +356,7 @@ class MultiFileFunction : public TableFunction { // Now re-lock the state and add the reader parallel_lock.lock(); + global_state.files_opened++; if (can_skip_file) { current_reader_data.file_state = MultiFileFileState::SKIPPED; // release 
the reader so its file handle is closed; skipped files are @@ -491,6 +516,10 @@ class MultiFileFunction : public TableFunction { auto &bind_data = input.bind_data->Cast(); auto &gstate = gstate_p->Cast(); + if (IsEmptyResult(bind_data)) { + return nullptr; + } + auto result = make_uniq(context.client); result->is_parallel = true; result->batch_index = 0; @@ -507,6 +536,17 @@ class MultiFileFunction : public TableFunction { auto &bind_data = input.bind_data->CastNoConst(); unique_ptr result; + if (IsEmptyResult(bind_data)) { + result = make_uniq(*bind_data.file_list); + result->file_list.InitializeScan(result->file_list_scan); + result->file_index = 0; + result->column_indexes = input.column_indexes; + result->filters = input.filters.get(); + result->op = input.op; + result->max_threads = 1; + return std::move(result); + } + // before instantiating a scan trigger a dynamic filter pushdown if possible auto new_list = MultiFileFilterPushdown(context, bind_data, input.column_ids, input.filters); if (new_list) { @@ -560,6 +600,7 @@ class MultiFileFunction : public TableFunction { } auto init_result = InitializeReader(*reader_data, bind_data, input.column_indexes, input.filters, context, file_idx, *result); + result->files_opened++; if (init_result == ReaderInitializeType::SKIP_READING_FILE) { //! 
File can be skipped entirely, close it and move on reader_data->file_state = MultiFileFileState::SKIPPED; @@ -658,9 +699,10 @@ class MultiFileFunction : public TableFunction { } } - output.SetCardinality(scan_chunk.size()); + output.SetChildCardinality(scan_chunk.size()); if (scan_chunk.size() > 0) { + data.rows_scanned += scan_chunk.size(); bind_data.multi_file_reader->FinalizeChunk(context, bind_data, *data.reader, *data.reader_data, scan_chunk, output, data.executor, gstate.multi_file_reader_state); @@ -805,6 +847,9 @@ class MultiFileFunction : public TableFunction { static unique_ptr MultiFileCardinality(ClientContext &context, const FunctionData *bind_data) { auto &data = bind_data->Cast(); + if (IsEmptyResult(data)) { + return make_uniq(0); + } auto file_list_cardinality_estimate = data.file_list->GetCardinality(context); if (file_list_cardinality_estimate) { return file_list_cardinality_estimate; @@ -868,6 +913,9 @@ class MultiFileFunction : public TableFunction { static TablePartitionInfo MultiFileGetPartitionInfo(ClientContext &context, TableFunctionPartitionInput &input) { auto &bind_data = input.bind_data->Cast(); + if (IsEmptyResult(bind_data)) { + return TablePartitionInfo::NOT_PARTITIONED; + } return bind_data.multi_file_reader->GetPartitionInfo(context, bind_data.reader_bind, input); } @@ -875,6 +923,9 @@ class MultiFileFunction : public TableFunction { optional_ptr bind_data_p) { auto &bind_data = bind_data_p->Cast(); virtual_column_map_t result; + if (IsEmptyResult(bind_data)) { + return result; + } MultiFileReader::GetVirtualColumns(context, bind_data.reader_bind, result); bind_data.interface->GetVirtualColumns(context, bind_data, result); @@ -883,11 +934,14 @@ class MultiFileFunction : public TableFunction { return result; } - static InsertionOrderPreservingMap MultiFileDynamicToString(TableFunctionDynamicToStringInput &input) { + static void MultiFileGetMetrics(TableFunctionGetMetricsInput &input) { auto &gstate = input.global_state->Cast(); - 
InsertionOrderPreservingMap result; - auto files_loaded = gstate.file_index.load(); - result.insert(make_pair("Total Files Read", std::to_string(files_loaded))); + if (input.local_state) { + auto &local = input.local_state->Cast(); + input.operator_metrics.rows_scanned = local.rows_scanned; + } + auto files_loaded = gstate.files_opened.load(); + input.operator_metrics.AddExtraInfo("Total Files Read", std::to_string(files_loaded)); constexpr size_t FILE_NAME_LIST_LIMIT = 5; auto file_paths = gstate.file_list.GetDisplayFileList(FILE_NAME_LIST_LIMIT + 1); @@ -900,9 +954,8 @@ class MultiFileFunction : public TableFunction { file_path_names.push_back("..."); } auto list_of_types = StringUtil::Join(file_path_names, ", "); - result.insert(make_pair("Filename(s)", list_of_types)); + input.operator_metrics.AddExtraInfo("Filename(s)", list_of_types); } - return result; } static void PushdownType(ClientContext &context, optional_ptr bind_data_p, diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_list.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_list.hpp index 0fc1655ed..b51dda837 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_list.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_list.hpp @@ -73,11 +73,11 @@ class MultiFileListIterationHelper { struct MultiFilePushdownInfo { explicit MultiFilePushdownInfo(LogicalGet &get); - MultiFilePushdownInfo(TableIndex table_index, const vector &column_names, + MultiFilePushdownInfo(TableIndex table_index, const vector &column_names, const vector &column_ids, ExtraOperatorInfo &extra_info); TableIndex table_index; - const vector &column_names; + const vector &column_names; vector column_ids; vector column_indexes; ExtraOperatorInfo &extra_info; @@ -109,7 +109,7 @@ class MultiFileList { MultiFilePushdownInfo &info, vector> &filters) const; virtual unique_ptr DynamicFilterPushdown(ClientContext &context, const MultiFileOptions &options, - const vector 
&names, + const vector &names, const vector &types, const vector &column_ids, TableFilterSet &filters) const; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_options.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_options.hpp index 78c16b6b2..282f87f58 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_options.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_options.hpp @@ -24,6 +24,7 @@ struct MultiFileOptions { bool auto_detect_hive_partitioning = true; bool union_by_name = false; bool hive_types_autocast = true; + bool allow_empty = false; MultiFileColumnMappingMode mapping = MultiFileColumnMappingMode::BY_NAME; case_insensitive_map_t hive_types_schema; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_reader.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_reader.hpp index d25a02d3b..b159073eb 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_reader.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_reader.hpp @@ -83,16 +83,16 @@ struct MultiFileReader { vector> &filters); DUCKDB_API virtual unique_ptr DynamicFilterPushdown(ClientContext &context, const MultiFileList &files, const MultiFileOptions &options, - const vector &names, const vector &types, + const vector &names, const vector &types, const vector &column_ids, TableFilterSet &filters); //! Try to use the MultiFileReader for binding. Returns true if a bind could be made, returns false if the //! MultiFileReader can not perform the bind and binding should be performed on 1 or more files in the MultiFileList //! directly. DUCKDB_API virtual bool Bind(MultiFileOptions &options, MultiFileList &files, vector &return_types, - vector &names, MultiFileReaderBindData &bind_data); + vector &names, MultiFileReaderBindData &bind_data); //! 
Bind the options of the multi-file reader, potentially emitting any extra columns that are required DUCKDB_API virtual void BindOptions(MultiFileOptions &options, MultiFileList &files, - vector &return_types, vector &names, + vector &return_types, vector &names, MultiFileReaderBindData &bind_data); //! Initialize global state used by the MultiFileReader @@ -141,12 +141,12 @@ struct MultiFileReader { virtual_column_map_t &result); MultiFileReaderBindData BindUnionReader(ClientContext &context, vector &return_types, - vector &names, MultiFileList &files, MultiFileBindData &result, + vector &names, MultiFileList &files, MultiFileBindData &result, BaseFileReaderOptions &options, MultiFileOptions &file_options); - MultiFileReaderBindData BindReader(ClientContext &context, vector &return_types, vector &names, - MultiFileList &files, MultiFileBindData &result, BaseFileReaderOptions &options, - MultiFileOptions &file_options); + MultiFileReaderBindData BindReader(ClientContext &context, vector &return_types, + vector &names, MultiFileList &files, MultiFileBindData &result, + BaseFileReaderOptions &options, MultiFileOptions &file_options); DUCKDB_API virtual ReaderInitializeType InitializeReader(MultiFileReaderData &reader_data, const MultiFileBindData &bind_data, @@ -198,7 +198,7 @@ struct MultiFileReader { protected: //! 
Used in errors to report which function is using this MultiFileReader - string function_name; + Identifier function_name; public: template diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp index bf48b73b0..fb8d0d995 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_states.hpp @@ -71,7 +71,7 @@ struct MultiFileBindData : public TableFunctionData { MultiFileReaderBindData reader_bind; MultiFileOptions file_options; vector types; - vector names; + vector names; virtual_column_map_t virtual_columns; //! Table column names - set when using COPY tbl FROM file.parquet vector table_columns; @@ -155,6 +155,8 @@ struct MultiFileGlobalState : public GlobalTableFunctionState { //! Index of file currently up for scanning atomic file_index; + //! Number of files that were actually opened by the scan + atomic files_opened = 0; //! Index of the lowest file we know we have completely read mutable idx_t completed_file_index = 0; //! The current set of readers @@ -198,6 +200,8 @@ struct MultiFileLocalState : public LocalTableFunctionState { DataChunk scan_chunk; //! The executor to transform scan_chunk into the final result with FinalizeChunk ExpressionExecutor executor; + //! 
Number of rows scanned by this thread (for profiling) + idx_t rows_scanned = 0; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/multi_file/union_by_name.hpp b/src/duckdb/src/include/duckdb/common/multi_file/union_by_name.hpp index 3e5b47e1f..6db0a2c02 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/union_by_name.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/union_by_name.hpp @@ -23,13 +23,13 @@ struct MultiFileReaderInterface; class UnionByName { public: static void CombineUnionTypes(const vector &new_names, const vector &new_types, - vector &union_col_types, vector &union_col_names, - case_insensitive_map_t &union_names_map); + vector &union_col_types, vector &union_col_names, + identifier_map_t &union_names_map); //! Union all files(readers) by their col names static vector> UnionCols(ClientContext &context, const vector &files, vector &union_col_types, - vector &union_col_names, BaseFileReaderOptions &options, MultiFileOptions &file_options, + vector &union_col_names, BaseFileReaderOptions &options, MultiFileOptions &file_options, MultiFileReader &multi_file_reader, MultiFileReaderInterface &interface); }; diff --git a/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp b/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp index b39e81e11..690eb6249 100644 --- a/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp +++ b/src/duckdb/src/include/duckdb/common/named_parameter_map.hpp @@ -12,7 +12,7 @@ #include "duckdb/common/types.hpp" namespace duckdb { -using named_parameter_type_map_t = case_insensitive_map_t; -using named_parameter_map_t = case_insensitive_map_t; +using named_parameter_type_map_t = identifier_map_t; +using named_parameter_map_t = identifier_map_t; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/opener_file_system.hpp b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp index b0a341aa6..9b06e92c5 100644 --- 
a/src/duckdb/src/include/duckdb/common/opener_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/opener_file_system.hpp @@ -106,6 +106,34 @@ class OpenerFileSystem : public FileSystem { GetFileSystem().MoveFile(source, target, GetOpener()); } + bool DirectoryExists(const string &directory) { + return DirectoryExists(directory, nullptr); + } + void CreateDirectory(const string &directory) { + CreateDirectory(directory, nullptr); + } + void RemoveDirectory(const string &directory) { + RemoveDirectory(directory, nullptr); + } + void MoveFile(const string &source, const string &target) { + MoveFile(source, target, nullptr); + } + bool FileExists(const string &filename) { + return FileExists(filename, nullptr); + } + bool IsPipe(const string &filename) { + return IsPipe(filename, nullptr); + } + void RemoveFile(const string &filename) { + RemoveFile(filename, nullptr); + } + bool TryRemoveFile(const string &filename) { + return TryRemoveFile(filename, nullptr); + } + void RemoveFiles(const vector &filenames) { + RemoveFiles(filenames, nullptr); + } + string GetHomeDirectory() override { return FileSystem::GetHomeDirectory(GetOpener()); } @@ -205,6 +233,12 @@ class OpenerFileSystem : public FileSystem { return true; } +public: + unique_ptr MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener = nullptr) override; + +protected: bool ListFilesExtended(const string &directory, const std::function &callback, optional_ptr opener) override { VerifyNoOpener(opener); diff --git a/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp index 583d862d0..e0da13d4f 100644 --- a/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/aggregate_operators.hpp @@ -29,4 +29,39 @@ struct Max { } }; +struct LogicalAnd { + template + static inline T Operation(T left, T right) 
{ + return left && right; + } +}; + +struct LogicalOr { + template + static inline T Operation(T left, T right) { + return left || right; + } +}; + +struct BitAnd { + template + static inline T Operation(T left, T right) { + return left & right; + } +}; + +struct BitOr { + template + static inline T Operation(T left, T right) { + return left | right; + } +}; + +struct BitXor { + template + static inline T Operation(T left, T right) { + return left ^ right; + } +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp index 6ae40db60..14397098a 100644 --- a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp @@ -636,12 +636,16 @@ DUCKDB_API bool TryCast::Operation(timestamp_t input, timestamp_ns_t &result, bo template <> DUCKDB_API bool TryCast::Operation(timestamp_tz_t input, timestamp_tz_t &result, bool strict); template <> +DUCKDB_API bool TryCast::Operation(timestamp_tz_t input, timestamp_t &result, bool strict); +template <> DUCKDB_API bool TryCast::Operation(timestamp_t input, timestamp_tz_t &result, bool strict); template <> DUCKDB_API bool TryCast::Operation(timestamp_tz_ns_t input, timestamp_tz_ns_t &result, bool strict); template <> DUCKDB_API bool TryCast::Operation(timestamp_ns_t input, timestamp_tz_ns_t &result, bool strict); template <> +DUCKDB_API bool TryCast::Operation(timestamp_tz_ns_t input, timestamp_ns_t &result, bool strict); +template <> DUCKDB_API bool TryCast::Operation(timestamp_ms_t input, timestamp_sec_t &result, bool strict); template <> DUCKDB_API bool TryCast::Operation(timestamp_ns_t input, timestamp_ms_t &result, bool strict); @@ -728,40 +732,20 @@ DUCKDB_API bool TryCastErrorMessage::Operation(string_t input, interval_t &resul //===--------------------------------------------------------------------===// // string -> Non-Standard 
Timestamps //===--------------------------------------------------------------------===// -struct TryCastToTimestampNS { - template - static inline bool Operation(SRC input, DST &result, bool strict = false) { - throw InternalException("Unsupported type for try cast to timestamp (ns)"); - } -}; - -struct TryCastToTimestampMS { - template - static inline bool Operation(SRC input, DST &result, bool strict = false) { - throw InternalException("Unsupported type for try cast to timestamp (ms)"); - } -}; - -struct TryCastToTimestampSec { - template - static inline bool Operation(SRC input, DST &result, bool strict = false) { - throw InternalException("Unsupported type for try cast to timestamp (s)"); - } -}; template <> -DUCKDB_API bool TryCastToTimestampNS::Operation(string_t input, timestamp_ns_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(string_t input, timestamp_ns_t &result, bool strict); template <> -DUCKDB_API bool TryCastToTimestampMS::Operation(string_t input, timestamp_ms_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(string_t input, timestamp_ms_t &result, bool strict); template <> -DUCKDB_API bool TryCastToTimestampSec::Operation(string_t input, timestamp_sec_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(string_t input, timestamp_sec_t &result, bool strict); template <> -DUCKDB_API bool TryCastToTimestampNS::Operation(date_t input, timestamp_ns_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(date_t input, timestamp_ns_t &result, bool strict); template <> -DUCKDB_API bool TryCastToTimestampMS::Operation(date_t input, timestamp_ms_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(date_t input, timestamp_ms_t &result, bool strict); template <> -DUCKDB_API bool TryCastToTimestampSec::Operation(date_t input, timestamp_sec_t &result, bool strict); +DUCKDB_API bool TryCast::Operation(date_t input, timestamp_sec_t &result, bool strict); 
//===--------------------------------------------------------------------===// // string -> Non-Standard Time types @@ -815,145 +799,38 @@ struct CastFromTimestampSec { } }; -struct CastTimestampUsToMs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampUsToNs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampUsToSec { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampMsToDate { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to DATE could not be performed!"); - } -}; - -struct CastTimestampMsToTime { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIME could not be performed!"); - } -}; - -struct CastTimestampMsToUs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampMsToNs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIMESTAMP_NS could not be performed!"); - } -}; - -struct CastTimestampNsToDate { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to DATE could not be performed!"); - } -}; -struct CastTimestampNsToTime { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIME could not be performed!"); - } -}; -struct CastTimestampNsToTimeNs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIME_NS could not be performed!"); - } -}; -struct CastTimestampNsToUs { - 
template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampSecToDate { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to DATE could not be performed!"); - } -}; -struct CastTimestampSecToTime { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIME could not be performed!"); - } -}; -struct CastTimestampSecToMs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIMESTAMP_MS could not be performed!"); - } -}; - -struct CastTimestampSecToUs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to timestamp could not be performed!"); - } -}; - -struct CastTimestampSecToNs { - template - static inline DST Operation(SRC input) { - throw duckdb::NotImplementedException("Cast to TIMESTAMP_NS could not be performed!"); - } -}; - template <> -duckdb::timestamp_sec_t CastTimestampUsToSec::Operation(duckdb::timestamp_t input); +duckdb::timestamp_sec_t Cast::Operation(duckdb::timestamp_t input); template <> -duckdb::timestamp_ms_t CastTimestampUsToMs::Operation(duckdb::timestamp_t input); +duckdb::timestamp_ms_t Cast::Operation(duckdb::timestamp_t input); template <> -duckdb::timestamp_ns_t CastTimestampUsToNs::Operation(duckdb::timestamp_t input); +duckdb::timestamp_ns_t Cast::Operation(duckdb::timestamp_t input); template <> -duckdb::date_t CastTimestampMsToDate::Operation(duckdb::timestamp_ms_t input); +duckdb::date_t Cast::Operation(duckdb::timestamp_ms_t input); template <> -duckdb::dtime_t CastTimestampMsToTime::Operation(duckdb::timestamp_ms_t input); +duckdb::dtime_t Cast::Operation(duckdb::timestamp_ms_t input); template <> -duckdb::timestamp_t CastTimestampMsToUs::Operation(duckdb::timestamp_ms_t input); +duckdb::timestamp_t 
Cast::Operation(duckdb::timestamp_ms_t input); template <> -duckdb::timestamp_ns_t CastTimestampMsToNs::Operation(duckdb::timestamp_ms_t input); +duckdb::timestamp_ns_t Cast::Operation(duckdb::timestamp_ms_t input); template <> -duckdb::date_t CastTimestampNsToDate::Operation(duckdb::timestamp_ns_t input); +duckdb::date_t Cast::Operation(duckdb::timestamp_ns_t input); template <> -duckdb::dtime_t CastTimestampNsToTime::Operation(duckdb::timestamp_ns_t input); +duckdb::dtime_t Cast::Operation(duckdb::timestamp_ns_t input); template <> -duckdb::dtime_ns_t CastTimestampNsToTimeNs::Operation(duckdb::timestamp_ns_t input); +duckdb::dtime_ns_t Cast::Operation(duckdb::timestamp_ns_t input); template <> -duckdb::timestamp_t CastTimestampNsToUs::Operation(duckdb::timestamp_ns_t input); +duckdb::timestamp_t Cast::Operation(duckdb::timestamp_ns_t input); template <> -duckdb::date_t CastTimestampSecToDate::Operation(duckdb::timestamp_sec_t input); +duckdb::date_t Cast::Operation(duckdb::timestamp_sec_t input); template <> -duckdb::dtime_t CastTimestampSecToTime::Operation(duckdb::timestamp_sec_t input); +duckdb::dtime_t Cast::Operation(duckdb::timestamp_sec_t input); template <> -duckdb::timestamp_ms_t CastTimestampSecToMs::Operation(duckdb::timestamp_sec_t input); +duckdb::timestamp_ms_t Cast::Operation(duckdb::timestamp_sec_t input); template <> -duckdb::timestamp_t CastTimestampSecToUs::Operation(duckdb::timestamp_sec_t input); +duckdb::timestamp_t Cast::Operation(duckdb::timestamp_sec_t input); template <> -duckdb::timestamp_ns_t CastTimestampSecToNs::Operation(duckdb::timestamp_sec_t input); +duckdb::timestamp_ns_t Cast::Operation(duckdb::timestamp_sec_t input); template <> duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_ns_t input, StringHeap &heap); diff --git a/src/duckdb/src/include/duckdb/common/operator/double_cast_operator.hpp b/src/duckdb/src/include/duckdb/common/operator/double_cast_operator.hpp index 733a95d44..38160a744 100644 --- 
a/src/duckdb/src/include/duckdb/common/operator/double_cast_operator.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/double_cast_operator.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb.h" #include "fast_float/fast_float.h" #include "duckdb/common/string_util.hpp" diff --git a/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp b/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp index 9c410e044..9113b346e 100644 --- a/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/string_cast.hpp @@ -70,21 +70,13 @@ DUCKDB_API duckdb::string_t StringCast::Operation(timestamp_t input, StringHeap template <> DUCKDB_API duckdb::string_t StringCast::Operation(timestamp_ns_t input, StringHeap &heap); -//! Temporary casting for Time Zone types. TODO: turn casting into functions. -struct StringCastTZ { - template - static inline string_t Operation(SRC input, StringHeap &heap) { - return StringCast::Operation(input, heap); - } -}; - template <> -duckdb::string_t StringCastTZ::Operation(date_t input, StringHeap &heap); +duckdb::string_t StringCast::Operation(date_t input, StringHeap &heap); template <> -duckdb::string_t StringCastTZ::Operation(dtime_tz_t input, StringHeap &heap); +duckdb::string_t StringCast::Operation(dtime_tz_t input, StringHeap &heap); template <> -duckdb::string_t StringCastTZ::Operation(timestamp_t input, StringHeap &heap); +duckdb::string_t StringCast::Operation(timestamp_tz_t input, StringHeap &heap); template <> -duckdb::string_t StringCastTZ::Operation(timestamp_ns_t input, StringHeap &heap); +duckdb::string_t StringCast::Operation(timestamp_tz_ns_t input, StringHeap &heap); } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp index ca24836e9..4a1beadda 100644 --- a/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp +++ 
b/src/duckdb/src/include/duckdb/common/progress_bar/progress_bar.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb.h" #include "duckdb/execution/executor.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/profiler.hpp" diff --git a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp index 5740aae64..dde8215e9 100644 --- a/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp +++ b/src/duckdb/src/include/duckdb/common/radix_partitioning.hpp @@ -52,7 +52,7 @@ struct RadixPartitioning { } //! Select using a cutoff on the radix bits of the hash - static idx_t Select(Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, + static idx_t Select(const Vector &hashes, const SelectionVector *sel, idx_t count, idx_t radix_bits, const ValidityMask &partition_mask, SelectionVector *true_sel, SelectionVector *false_sel); }; @@ -116,6 +116,9 @@ class RadixPartitionedTupleData : public PartitionedTupleData { return radix_bits; } + void ResetAppendState(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const override; + private: void Initialize(); diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp index 382b60dd2..943c5cadc 100644 --- a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -14,9 +14,12 @@ namespace duckdb { +struct ClusteredAggr; + class ArenaAllocator; struct AggregateObject; struct AggregateFilterData; +struct AggregateFilterDataSet; class DataChunk; class TupleDataLayout; struct SelectionVector; @@ -39,19 +42,52 @@ struct RowOperations { //! 
initialize - unaligned addresses static void InitializeStates(TupleDataLayout &layout, Vector &addresses, const SelectionVector &sel, idx_t count); //! destructor - unaligned addresses, updated - static void DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count); + static void DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses); //! update - aligned addresses static void UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, DataChunk &payload, - idx_t arg_idx, idx_t count); + idx_t arg_idx, optional_ptr clustered = nullptr); //! filtered update - aligned addresses static void UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx); + //! clustered update loop shared by grouped and perfect aggregate hash tables + static void UpdateStatesClustered(RowOperationsState &state, vector &aggregates, + AggregateFilterDataSet *filter_set, const unsafe_vector *filter, + Vector &addresses, DataChunk &payload, ClusteredAggr &clustered, + bool skip_addresses); //! combine - unaligned addresses, updated - static void CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, - idx_t count); + static void CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets); //! 
finalize - unaligned addresses, updated static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, idx_t aggr_idx); + + //===--------------------------------------------------------------------===// + // Deprecated overloads (count parameter removed - use count-free versions) + //===--------------------------------------------------------------------===// + [[deprecated("count parameter is deprecated; call DestroyStates without count instead")]] static void + DestroyStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, idx_t count) { + if (count != addresses.size()) { + throw InternalException("DestroyStates: count (%llu) does not match vector size (%llu)", count, + addresses.size()); + } + DestroyStates(state, layout, addresses); + } + [[deprecated("count parameter is deprecated; call UpdateStates without count instead")]] static void + UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx, + idx_t count) { + if (count != addresses.size()) { + throw InternalException("UpdateStates: count (%llu) does not match vector size (%llu)", count, + addresses.size()); + } + UpdateStates(state, aggr, addresses, payload, arg_idx); + } + [[deprecated("count parameter is deprecated; call CombineStates without count instead")]] static void + CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, idx_t count) { + if (count != sources.size()) { + throw InternalException("CombineStates: count (%llu) does not match vector size (%llu)", count, + sources.size()); + } + CombineStates(state, layout, sources, targets); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serialization_compatibility.hpp b/src/duckdb/src/include/duckdb/common/serialization_compatibility.hpp deleted file mode 100644 index 3c444f983..000000000 --- 
a/src/duckdb/src/include/duckdb/common/serialization_compatibility.hpp +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/serialization_compatibility.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" - -namespace duckdb { -class AttachedDatabase; - -class SerializationCompatibility { -public: - static SerializationCompatibility FromDatabase(AttachedDatabase &db); - static SerializationCompatibility FromIndex(idx_t serialization_version); - static SerializationCompatibility FromString(const string &input); - static SerializationCompatibility Default(); - static SerializationCompatibility Latest(); - -public: - bool Compare(idx_t property_version) const; - -public: - //! The user provided version - string duckdb_version; - //! The max version that should be serialized - idx_t serialization_version; - //! 
Whether this was set by a manual SET/PRAGMA or default - bool manually_set; - -protected: - SerializationCompatibility() = default; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp index 98a99c893..47d39ca5f 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/deserializer.hpp @@ -11,12 +11,15 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/common/serializer/serialization_data.hpp" #include "duckdb/common/serializer/serialization_traits.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/execution/operator/csv_scanner/csv_option.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" namespace duckdb { @@ -137,6 +140,16 @@ class Deserializer { OnPropertyEnd(); } + inline bool ReadOptionalProperty(const field_id_t field_id, const char *tag, data_ptr_t ret, idx_t count) { + if (!OnOptionalPropertyBegin(field_id, tag)) { + OnOptionalPropertyEnd(false); + return false; + } + ReadDataPtr(ret, count); + OnOptionalPropertyEnd(true); + return true; + } + // Try to read a property, if it is not present, continue, otherwise read and discard the value template inline void ReadDeletedProperty(const field_id_t field_id, const char *tag) { @@ -354,12 +367,13 @@ class Deserializer { template inline typename std::enable_if::value, T>::type Read() { using VALUE_TYPE = typename is_insertion_preserving_map::VALUE_TYPE; + using KEY_TYPE = typename is_insertion_preserving_map::KEY_TYPE; T map; auto size = OnListBegin(); for (idx_t i = 0; i < size; i++) { 
OnObjectBegin(); - auto key = ReadProperty(0, "key"); + auto key = ReadProperty(0, "key"); auto value = ReadProperty(1, "value"); OnObjectEnd(); map[key] = std::move(value); @@ -498,6 +512,12 @@ class Deserializer { return ReadString(); } + // Deserialize an Identifier (stored identically to a plain string) + template + inline typename std::enable_if::value, T>::type Read() { + return Identifier(ReadString()); + } + // Deserialize a Enum template inline typename std::enable_if::value, T>::type Read() { @@ -552,6 +572,12 @@ class Deserializer { return idx == DConstants::INVALID_INDEX ? optional_idx() : optional_idx(idx); } + // Deserialize a ProjectionIndex + template + inline typename std::enable_if::value, T>::type Read() { + return PerColumnMetadataBlock::Unpack(ReadUnsignedInt64()); + } + protected: // Hooks for subclasses to override to implement custom behavior virtual void OnPropertyBegin(const field_id_t field_id, const char *tag) = 0; diff --git a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp index 9f81b0fea..6f1d76d3b 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/serialization_traits.hpp @@ -3,6 +3,7 @@ #include #include +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/map.hpp" #include "duckdb/common/unordered_map.hpp" @@ -111,6 +112,7 @@ struct is_insertion_preserving_map : std::false_type {}; template struct is_insertion_preserving_map> : std::true_type { typedef typename std::tuple_element<0, std::tuple>::type VALUE_TYPE; + typedef typename duckdb::InsertionOrderPreservingMap::key_type KEY_TYPE; }; template @@ -339,6 +341,16 @@ struct SerializationDefaultValue { return value.empty(); } + template + static inline typename std::enable_if::value, T>::type GetDefault() { + return T(); + } + + template + static 
inline bool IsDefault(const typename std::enable_if::value, T>::type &value) { + return value.empty(); + } + template static inline typename std::enable_if::value, T>::type GetDefault() { return optional_idx(); diff --git a/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp b/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp index fe53f91a9..6885aeb2a 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/serializer.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/common/serializer/serialization_traits.hpp" #include "duckdb/common/serializer/serialization_data.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/types/interval.hpp" #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/uhugeint.hpp" @@ -22,10 +23,10 @@ #include "duckdb/common/value_operations/value_operations.hpp" #include "duckdb/execution/operator/csv_scanner/csv_option.hpp" #include "duckdb/common/insertion_order_preserving_map.hpp" -#include "duckdb/common/serialization_compatibility.hpp" +#include "duckdb/common/storage_compatibility.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" namespace duckdb { - class SerializationOptions { public: SerializationOptions() = default; @@ -33,7 +34,7 @@ class SerializationOptions { bool serialize_enum_as_string = false; bool serialize_default_values = false; - SerializationCompatibility serialization_compatibility = SerializationCompatibility::Default(); + StorageCompatibility storage_compatibility = StorageCompatibility::Default(); }; class Serializer { @@ -45,8 +46,12 @@ class Serializer { virtual ~Serializer() { } - bool ShouldSerialize(idx_t version_added) { - return options.serialization_compatibility.Compare(version_added); + bool ShouldSerializeInternal(StorageVersion version_added) const { + return options.storage_compatibility.Compare(version_added); + } + + bool 
ShouldSerialize(StorageVersion version_added) const { + return ShouldSerializeInternal(version_added); } class List { @@ -310,8 +315,8 @@ class Serializer { // Insertion Order Preserving Map // serialized as a list of pairs - template - void WriteValue(const duckdb::InsertionOrderPreservingMap &map) { + template + void WriteValue(const duckdb::InsertionOrderPreservingMap &map) { auto count = map.size(); OnListBegin(count); for (auto &entry : map) { @@ -378,6 +383,10 @@ class Serializer { virtual void WriteValue(const string &value) = 0; virtual void WriteValue(const char *str) = 0; virtual void WriteDataPtr(const_data_ptr_t ptr, idx_t count) = 0; + //! Identifiers are serialized identically to a plain string (preserving the original casing) + void WriteValue(const Identifier &value) { + WriteValue(value.GetIdentifierName()); + } void WriteValue(LogicalIndex value) { WriteValue(value.index); } @@ -393,6 +402,9 @@ class Serializer { void WriteValue(optional_idx value) { WriteValue(value.IsValid() ? 
value.GetIndex() : DConstants::INVALID_INDEX); } + void WriteValue(PerColumnMetadataBlock value) { + WriteValue(value.GetPacked()); + } }; // We need to special case vector because elements of vector cannot be referenced diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp index 79d71acfa..fedb2dba4 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp @@ -25,13 +25,12 @@ class SortedRunScanState { SortedRunScanState(ClientContext &context, const Sort &sort); public: - void Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, DataChunk &chunk); + void Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, DataChunk &chunk); void Clear(); private: template - void TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, - DataChunk &chunk); + void TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, DataChunk &chunk); private: const Sort &sort; diff --git a/src/duckdb/src/include/duckdb/common/sql_identifier.hpp b/src/duckdb/src/include/duckdb/common/sql_identifier.hpp index dee7f429c..d7dbb6e4b 100644 --- a/src/duckdb/src/include/duckdb/common/sql_identifier.hpp +++ b/src/duckdb/src/include/duckdb/common/sql_identifier.hpp @@ -15,6 +15,7 @@ namespace duckdb { class String; +class Identifier; // Helper class to support custom overloading // Optionally escaping " and quoting the value with " if required @@ -22,6 +23,9 @@ class SQLIdentifier { public: explicit SQLIdentifier(const string &raw_string) : raw_string(raw_string) { } + explicit SQLIdentifier(const char *raw_string) : raw_string(raw_string) { + } + explicit SQLIdentifier(const Identifier &id); //! Emits an optionally quoted identifier including required escapes (i.e. 
ident -> ident, table -> "table") static string ToString(const string &identifier); diff --git a/src/duckdb/src/include/duckdb/common/storage_compatibility.hpp b/src/duckdb/src/include/duckdb/common/storage_compatibility.hpp new file mode 100644 index 000000000..b72be82fd --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/storage_compatibility.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/storage_compatibility.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/storage/storage_info.hpp" + +namespace duckdb { +class AttachedDatabase; + +class StorageCompatibility { +public: + static StorageCompatibility FromDatabase(AttachedDatabase &db); + static StorageCompatibility FromIndex(StorageVersion storage_version_p); + static StorageCompatibility FromString(const string &input); + static StorageCompatibility Default(); + static StorageCompatibility Latest(); + +public: + bool Compare(StorageVersion property_version) const; + bool CompareVersionString(const string &property_version) const; + StorageVersion GetStorageVersionCompatibility() const; + +public: + //! The user provided version + string duckdb_version; + //! The max storage version that should be serialized + StorageVersion storage_version; + //! Whether this was set by a manual SET/PRAGMA or default + bool manually_set; + +protected: + StorageCompatibility() = default; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp index 71cbad735..71560fc5d 100644 --- a/src/duckdb/src/include/duckdb/common/string_util.hpp +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -143,6 +143,7 @@ class StringUtil { //! Join multiple strings into one string. 
Components are concatenated by the given separator DUCKDB_API static string Join(const vector &input, const string &separator); + DUCKDB_API static string Join(const vector &input, const string &separator); DUCKDB_API static string Join(const set &input, const string &separator); //! Encode special URL characters in a string @@ -229,6 +230,7 @@ class StringUtil { //! Case insensitive find, returns DConstants::INVALID_INDEX if not found DUCKDB_API static idx_t CIFind(vector &vec, const string &str); + DUCKDB_API static idx_t CIFind(const vector &vec, const Identifier &str); //! Format a string using printf semantics template @@ -262,6 +264,7 @@ class StringUtil { DUCKDB_API static idx_t SimilarityScore(const string &s1, const string &s2); //! Returns a normalized similarity rating between 0.0 - 1.0 (higher is more similar) DUCKDB_API static double SimilarityRating(const string &s1, const string &s2); + DUCKDB_API static double SimilarityRating(const Identifier &s1, const Identifier &s2); //! Get the top-n strings (sorted by the given score distance) from a set of scores. //! The scores should be normalized between 0.0 and 1.0, where 1.0 is the highest score //! At least one entry is returned (if there is one). @@ -278,6 +281,8 @@ class StringUtil { idx_t threshold = 5); //! Computes the jaro winkler distance of each string in strings, and compares it to target, then returns //! TopNStrings with the given params. 
+ DUCKDB_API static vector TopNJaroWinkler(const vector &strings, const Identifier &target, + idx_t n = 5, double threshold = 0.5); DUCKDB_API static vector TopNJaroWinkler(const vector &strings, const string &target, idx_t n = 5, double threshold = 0.5); DUCKDB_API static string CandidatesMessage(const vector &candidates, diff --git a/src/duckdb/src/include/duckdb/common/table_column.hpp b/src/duckdb/src/include/duckdb/common/table_column.hpp index ba5f674ef..84961c61f 100644 --- a/src/duckdb/src/include/duckdb/common/table_column.hpp +++ b/src/duckdb/src/include/duckdb/common/table_column.hpp @@ -15,10 +15,10 @@ namespace duckdb { struct TableColumn { TableColumn() = default; - TableColumn(string name_p, LogicalType type_p) : name(std::move(name_p)), type(std::move(type_p)) { + TableColumn(Identifier name_p, LogicalType type_p) : name(std::move(name_p)), type(std::move(type_p)) { } - string name; + Identifier name; LogicalType type; void Serialize(Serializer &serializer) const; diff --git a/src/duckdb/src/include/duckdb/common/thread.hpp b/src/duckdb/src/include/duckdb/common/thread.hpp index 277e34b39..ccf773c5c 100644 --- a/src/duckdb/src/include/duckdb/common/thread.hpp +++ b/src/duckdb/src/include/duckdb/common/thread.hpp @@ -23,10 +23,14 @@ using thread_id = std::thread::id; using thread_id = uint64_t; #endif +#include "duckdb/common/optional_ptr.hpp" + namespace duckdb { +class ClientContext; + struct ThreadUtil { - static void SleepMs(idx_t ms); + static void SleepMs(idx_t ms, optional_ptr context = nullptr); static void SleepMicroSeconds(idx_t micros); static thread_id GetThreadId(); static string GetThreadIdString(); diff --git a/src/duckdb/src/include/duckdb/common/type_util.hpp b/src/duckdb/src/include/duckdb/common/type_util.hpp index 72b7a78a4..aebaae653 100644 --- a/src/duckdb/src/include/duckdb/common/type_util.hpp +++ b/src/duckdb/src/include/duckdb/common/type_util.hpp @@ -135,4 +135,64 @@ bool IsIntegerType() { return 
TypeIsIntegral(GetTypeId()); } +//! Returns the LogicalType corresponding to a C++ primitive type +template +LogicalType PrimitiveToLogicalType() { + using TYPE = typename std::remove_cv::type; + if (std::is_same()) { + return LogicalType::BOOLEAN; + } else if (std::is_same()) { + return LogicalType::TINYINT; + } else if (std::is_same()) { + return LogicalType::SMALLINT; + } else if (std::is_same()) { + return LogicalType::INTEGER; + } else if (std::is_same()) { + return LogicalType::BIGINT; + } else if (std::is_same()) { + return LogicalType::UTINYINT; + } else if (std::is_same()) { + return LogicalType::USMALLINT; + } else if (std::is_same()) { + return LogicalType::UINTEGER; + } else if (std::is_same() || std::is_same()) { + return LogicalType::UBIGINT; + } else if (std::is_same()) { + return LogicalType::HUGEINT; + } else if (std::is_same()) { + return LogicalType::UHUGEINT; + } else if (std::is_same()) { + return LogicalType::FLOAT; + } else if (std::is_same()) { + return LogicalType::DOUBLE; + } else if (std::is_same()) { + return LogicalType::DATE; + } else if (std::is_same()) { + return LogicalType::TIME; + } else if (std::is_same()) { + return LogicalType::TIME_TZ; + } else if (std::is_same()) { + return LogicalType::TIME_NS; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP_S; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP_MS; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP_NS; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP_TZ; + } else if (std::is_same()) { + return LogicalType::TIMESTAMP_TZ_NS; + } else if (std::is_same()) { + return LogicalType::INTERVAL; + } else if (std::is_same() || std::is_same() || std::is_same() || + std::is_same()) { + return LogicalType::VARCHAR; + } else { + throw InternalException("Unsupported type in PrimitiveToLogicalType"); + } +} + } // namespace duckdb diff --git 
a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp index a6753880d..ebdca4d00 100644 --- a/src/duckdb/src/include/duckdb/common/types.hpp +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -10,10 +10,10 @@ #include "duckdb/common/constants.hpp" #include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/helper.hpp" - namespace duckdb { class Serializer; @@ -28,7 +28,7 @@ class CoordinateReferenceSystem; struct string_t; // NOLINT: mimic std casing template -using child_list_t = vector>; +using child_list_t = vector>; template using buffer_ptr = shared_ptr; @@ -176,7 +176,7 @@ enum class PhysicalType : uint8_t { /// DuckDB Extensions VARCHAR = 200, // our own string representation, different from STRING and LARGE_STRING above UINT128 = 203, // 128-bit unsigned integers - INT128 = 204, // 128-bit integers + INT128 = 204, // 128-bit integers UNKNOWN = 205, // Unknown physical type of user defined types /// Boolean as 1 bit, LSB bit-packed ordering BIT = 206, @@ -192,7 +192,7 @@ enum class LogicalTypeId : uint8_t { SQLNULL = 1, /* NULL type, used for constant NULL */ UNKNOWN = 2, /* unknown type, used for parameter expressions */ ANY = 3, /* ANY type, used for functions that accept any type as parameter */ - UNBOUND = 4, /* A parsed but unbound type, used during query planning */ + UNBOUND = 4, /* A parsed but unbound type, used during query planning */ // A "template" type functions as a "placeholder" type for function arguments and return types. // Templates only exist during the binding phase, in the scope of a function, and are replaced with concrete types @@ -201,7 +201,7 @@ enum class LogicalTypeId : uint8_t { // name are always resolved to the same concrete type. 
TEMPLATE = 5, - TYPE = 6, /* Type type, used for type parameters */ + TYPE = 6, /* Type type, used for type parameters */ BOOLEAN = 10, TINYINT = 11, @@ -230,8 +230,8 @@ enum class LogicalTypeId : uint8_t { TIME_TZ = 34, TIME_NS = 35, BIT = 36, - STRING_LITERAL = 37, /* string literals, used for constant strings - only exists while binding */ - INTEGER_LITERAL = 38,/* integer literals, used for constant integers - only exists while binding */ + STRING_LITERAL = 37, /* string literals, used for constant strings - only exists while binding */ + INTEGER_LITERAL = 38, /* integer literals, used for constant integers - only exists while binding */ BIGNUM = 39, UHUGEINT = 49, HUGEINT = 50, @@ -250,15 +250,12 @@ enum class LogicalTypeId : uint8_t { LAMBDA = 106, UNION = 107, ARRAY = 108, - VARIANT = 109, - AGGREGATE_STATE = 110, // struct-based aggregate state + VARIANT = 109 }; struct ExtraTypeInfo; struct ExtensionTypeInfo; -struct aggregate_state_t; // NOLINT: mimic std casing - struct LogicalType { DUCKDB_API LogicalType(); DUCKDB_API LogicalType(LogicalTypeId id); // NOLINT: Allow implicit conversion from `LogicalTypeId` @@ -361,16 +358,28 @@ struct LogicalType { DUCKDB_API optional_ptr GetExtensionInfo(); DUCKDB_API void SetExtensionInfo(unique_ptr info); - //! Returns the maximum logical type when combining the two types - or throws an exception if combining is not possible - DUCKDB_API static LogicalType MaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right); - DUCKDB_API static bool TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right, LogicalType &result); - DUCKDB_API static bool TryGetMaxLogicalTypeUnchecked(const LogicalType &left, const LogicalType &right, LogicalType &result); - //! Forcibly returns a maximum logical type - similar to MaxLogicalType but never throws. As a fallback either left or right are returned. 
- DUCKDB_API static LogicalType ForceMaxLogicalType(const LogicalType &left, const LogicalType &right); + //! Returns the maximum logical type when combining the two types - or throws an exception if combining is not + //! possible + DUCKDB_API static LogicalType MaxLogicalType(ClientContext &context, const LogicalType &left, + const LogicalType &right); + DUCKDB_API static bool TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, + const LogicalType &right, LogicalType &result); + DUCKDB_API static bool TryGetMaxLogicalTypeUnchecked(ClientContext &context, const LogicalType &left, + const LogicalType &right, LogicalType &result); + //! Variant of TryGetMaxLogicalTypeUnchecked for call-sites that have no ClientContext available; uses only the + //! built-in CastRules (no extension-registered casts). + DUCKDB_API static bool DefaultTryGetMaxLogicalTypeUnchecked(const LogicalType &left, const LogicalType &right, + LogicalType &result); + //! Forcibly returns a maximum logical type - similar to MaxLogicalType but never throws. As a fallback either left + //! or right are returned. + DUCKDB_API static LogicalType ForceMaxLogicalType(ClientContext &context, const LogicalType &left, + const LogicalType &right); + //! Variant of ForceMaxLogicalType for call-sites that have no ClientContext available; uses only the built-in + //! CastRules (no extension-registered casts). + DUCKDB_API static LogicalType DefaultForceMaxLogicalType(const LogicalType &left, const LogicalType &right); //! Normalize a type - removing literals DUCKDB_API static LogicalType NormalizeType(const LogicalType &type); - //! Gets the decimal properties of a numeric type. Fails if the type is not numeric. DUCKDB_API bool GetDecimalProperties(uint8_t &width, uint8_t &scale) const; @@ -386,10 +395,9 @@ struct LogicalType { //! True, if this type supports in-place updates. 
bool SupportsRegularUpdate() const; - private: - LogicalTypeId id_; // NOLINT: allow this naming for legacy reasons - PhysicalType physical_type_; // NOLINT: allow this naming for legacy reasons + LogicalTypeId id_; // NOLINT: allow this naming for legacy reasons + PhysicalType physical_type_; // NOLINT: allow this naming for legacy reasons shared_ptr type_info_; // NOLINT: allow this naming for legacy reasons private: @@ -438,30 +446,27 @@ struct LogicalType { // explicitly allowing these functions to be capitalized to be in-line with the remaining functions DUCKDB_API static LogicalType DECIMAL(uint8_t width, uint8_t scale); // NOLINT - DUCKDB_API static LogicalType VARCHAR_COLLATION(string collation); // NOLINT - DUCKDB_API static LogicalType LIST(const LogicalType &child); // NOLINT - DUCKDB_API static LogicalType STRUCT(child_list_t children); // NOLINT - DUCKDB_API static LogicalType LEGACY_AGGREGATE_STATE(aggregate_state_t state_type); // NOLINT - DUCKDB_API static LogicalType AGGREGATE_STATE(aggregate_state_t state_type, // NOLINT - child_list_t struct_child_types); // NOLINT - DUCKDB_API static LogicalType MAP(const LogicalType &child); // NOLINT - DUCKDB_API static LogicalType MAP(LogicalType key, LogicalType value); // NOLINT - DUCKDB_API static LogicalType UNION(child_list_t members); // NOLINT + DUCKDB_API static LogicalType VARCHAR_COLLATION(string collation); // NOLINT + DUCKDB_API static LogicalType LIST(const LogicalType &child); // NOLINT + DUCKDB_API static LogicalType STRUCT(child_list_t children); // NOLINT + DUCKDB_API static LogicalType MAP(const LogicalType &child); // NOLINT + DUCKDB_API static LogicalType MAP(LogicalType key, LogicalType value); // NOLINT + DUCKDB_API static LogicalType UNION(child_list_t members); // NOLINT DUCKDB_API static LogicalType ARRAY(const LogicalType &child, optional_idx index); // NOLINT - DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT - DUCKDB_API static LogicalType GEOMETRY(); 
// NOLINT + DUCKDB_API static LogicalType ENUM(const Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType GEOMETRY(); // NOLINT DUCKDB_API static LogicalType GEOMETRY(const string &crs); DUCKDB_API static LogicalType GEOMETRY(const CoordinateReferenceSystem &crs); // ANY but with special rules (default is LogicalType::ANY, 5) DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT - DUCKDB_API static LogicalType TEMPLATE(const string &name); // NOLINT - DUCKDB_API static LogicalType VARIANT(); // NOLINT + DUCKDB_API static LogicalType TEMPLATE(const string &name); // NOLINT + DUCKDB_API static LogicalType VARIANT(); // NOLINT //! Integer literal of the specified value - DUCKDB_API static LogicalType INTEGER_LITERAL(const Value &constant); // NOLINT + DUCKDB_API static LogicalType INTEGER_LITERAL(const Value &constant); // NOLINT // DEPRECATED - provided for backwards compatibility - DUCKDB_API static LogicalType ENUM(const string &enum_name, Vector &ordered_data, idx_t size); // NOLINT - DUCKDB_API static LogicalType UNBOUND(unique_ptr expr); // NOLINT - DUCKDB_API static LogicalType TYPE(); // NOLINT + DUCKDB_API static LogicalType ENUM(const string &enum_name, const Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType UNBOUND(unique_ptr expr); // NOLINT + DUCKDB_API static LogicalType TYPE(); // NOLINT //! A list of all NUMERIC types (integral and floating point types) DUCKDB_API static const vector Numeric(); //! 
A list of all INTEGRAL types @@ -476,7 +481,7 @@ struct LogicalType { static constexpr auto JSON_TYPE_NAME = "JSON"; DUCKDB_API static LogicalType JSON(); // NOLINT DUCKDB_API bool IsJSONType() const; - DUCKDB_API bool IsAggregateStateStructType() const; + DUCKDB_API bool IsAggregateState() const; }; struct DecimalType { @@ -493,7 +498,6 @@ struct ListType { DUCKDB_API static const LogicalType &GetChildType(const LogicalType &type); }; - struct UnboundType { // Try to bind the unbound type into a concrete type, using just the built in types DUCKDB_API static LogicalType TryParseAndDefaultBind(const string &type_str); @@ -513,7 +517,7 @@ struct EnumType { struct StructType { DUCKDB_API static const child_list_t &GetChildTypes(const LogicalType &type); DUCKDB_API static const LogicalType &GetChildType(const LogicalType &type, idx_t index); - DUCKDB_API static const string &GetChildName(const LogicalType &type, idx_t index); + DUCKDB_API static const Identifier &GetChildName(const LogicalType &type, idx_t index); DUCKDB_API static idx_t GetChildIndexUnsafe(const LogicalType &type, const string &name); DUCKDB_API static idx_t GetChildCount(const LogicalType &type); DUCKDB_API static bool IsUnnamed(const LogicalType &type); @@ -528,7 +532,7 @@ struct UnionType { DUCKDB_API static const idx_t MAX_UNION_MEMBERS = 256; DUCKDB_API static idx_t GetMemberCount(const LogicalType &type); DUCKDB_API static const LogicalType &GetMemberType(const LogicalType &type, idx_t index); - DUCKDB_API static const string &GetMemberName(const LogicalType &type, idx_t index); + DUCKDB_API static const Identifier &GetMemberName(const LogicalType &type, idx_t index); DUCKDB_API static const child_list_t CopyMemberTypes(const LogicalType &type); }; @@ -541,16 +545,6 @@ struct ArrayType { DUCKDB_API static LogicalType ConvertToList(const LogicalType &type); }; -struct LegacyAggregateStateType { - DUCKDB_API static const string GetTypeName(const LogicalType &type); - DUCKDB_API static const 
aggregate_state_t &GetStateType(const LogicalType &type); -}; - -struct AggregateStateType : public StructType { - DUCKDB_API static const string GetTypeName(const LogicalType &type); - DUCKDB_API static const aggregate_state_t &GetStateType(const LogicalType &type); -}; - struct AnyType { DUCKDB_API static LogicalType GetTargetType(const LogicalType &type); DUCKDB_API static idx_t GetCastScore(const LogicalType &type); @@ -593,20 +587,4 @@ DUCKDB_API bool TypeIsInteger(PhysicalType type); bool ApproxEqual(float l, float r); bool ApproxEqual(double l, double r); -struct aggregate_state_t { - aggregate_state_t() { - } - // NOLINTNEXTLINE: work around bug in clang-tidy - aggregate_state_t(string function_name_p, LogicalType return_type_p, vector bound_argument_types_p) - : function_name(std::move(function_name_p)), return_type(std::move(return_type_p)), - bound_argument_types(std::move(bound_argument_types_p)) { - } - - string function_name; - LogicalType return_type; - vector bound_argument_types; -}; - - - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp index 6f2c8a1c1..8ae321eb6 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_allocator.hpp @@ -83,6 +83,8 @@ class ColumnDataAllocator { public: void AllocateData(idx_t size, uint32_t &block_id, uint32_t &offset, ChunkManagementState *chunk_state); + void Reset(); + void ResetPreserveLastBlock(); void Initialize(ColumnDataAllocator &other); void InitializeChunkState(ChunkManagementState &state, ChunkMetaData &meta_data); diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp index 6cb1d7bdd..c9f7770db 100644 --- 
a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection.hpp @@ -111,6 +111,8 @@ class ColumnDataCollection { //! Appends the other ColumnDataCollection to this, destroying the other data collection DUCKDB_API void Combine(ColumnDataCollection &other); + //! Swaps the contents of two compatible ColumnDataCollections + DUCKDB_API void Swap(ColumnDataCollection &other); DUCKDB_API void Verify(); @@ -118,6 +120,7 @@ class ColumnDataCollection { DUCKDB_API void Print() const; DUCKDB_API void Reset(); + DUCKDB_API void ResetForReuse(); //! Returns the number of data chunks present in the ColumnDataCollection DUCKDB_API idx_t ChunkCount() const; diff --git a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp index 910c71dae..ee000c023 100644 --- a/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp +++ b/src/duckdb/src/include/duckdb/common/types/column/column_data_collection_segment.hpp @@ -106,6 +106,8 @@ class ColumnDataCollectionSegment { shared_ptr heap; public: + //! Reset this segment for reuse, clearing all metadata while keeping allocated buffers + void Reset(); void AllocateNewChunk(); //! 
Allocate space for a vector of a specific type in the segment VectorDataIndex AllocateVector(const LogicalType &type, ChunkMetaData &chunk_data, diff --git a/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp b/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp index fd75d6e57..1f9ff0b12 100644 --- a/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp +++ b/src/duckdb/src/include/duckdb/common/types/conflict_manager.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/execution/index/art/art.hpp" @@ -137,7 +138,7 @@ class ConflictManager { //! Delete indexes matching the conflict target. vector> matching_delete_indexes; //! All matching indexes by their name (unique identifier). - case_insensitive_set_t index_names; + identifier_set_t index_names; //! Registers all conflicting rows in a data chunk. unordered_set conflict_rows; diff --git a/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp b/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp index cc706c757..0143c0d20 100644 --- a/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp +++ b/src/duckdb/src/include/duckdb/common/types/data_chunk.hpp @@ -52,18 +52,42 @@ class DataChunk { vector data; public: - inline idx_t size() const { // NOLINT - return count; + inline idx_t size() const { + if (count.IsValid()) { + return count.GetIndex(); + } + for (const auto &v : data) { + if (v.GetBufferRef()) { + return v.size(); + } + } + if (data.empty()) { + // a column-less chunk has nothing to derive a cardinality from; without an explicit count it is empty + return 0; + } + throw InternalException( + "DataChunk::size() called but neither count was set, nor any vectors with valid counts were set"); } inline idx_t ColumnCount() const { return data.size(); } - void SetCardinality(idx_t count_p); - inline void SetCardinality(const DataChunk &other) { - 
SetCardinality(other.size()); - } + //! Verify all child vectors have the expected cardinality + void CheckCardinality(idx_t count_p); //! Sets the cardinality of all child vectors of this chunk void SetChildCardinality(idx_t count_p); + //! Deprecated: use SetChildCardinality instead. + //! NOTE: this only sets the chunk's cardinality, it does NOT resize the child vectors (matching the historical + //! behavior on main). Callers that mutate the child vectors directly (e.g. Vector::Append/SetValue) and then call + //! SetCardinality rely on this - forwarding to SetChildCardinality would resize/overwrite their data. + [[deprecated("Use CheckCardinality (preferred) or SetChildCardinality instead")]] DUCKDB_API void + SetCardinality(idx_t count_p) { + this->count = count_p; + } + //! Deprecated: use SetChildCardinality instead + [[deprecated("Use CheckCardinality (preferred) or SetChildCardinality instead")]] DUCKDB_API void + SetCardinality(const DataChunk &chunk) { + this->count = chunk.size(); + } DUCKDB_API Value GetValue(idx_t col_idx, idx_t index) const; [[deprecated("Use Vector::Append on data[col_idx] instead (or Vector::SetValue for write-at-index " @@ -169,8 +193,9 @@ class DataChunk { DUCKDB_API void Verify(); private: - //! The amount of tuples stored in the data chunk - idx_t count; + optional_idx count; + +private: //! Vector caches, used to store data when ::Initialize is called vector vector_caches; diff --git a/src/duckdb/src/include/duckdb/common/types/date.hpp b/src/duckdb/src/include/duckdb/common/types/date.hpp index 8fd556ef3..92eee3cc4 100644 --- a/src/duckdb/src/include/duckdb/common/types/date.hpp +++ b/src/duckdb/src/include/duckdb/common/types/date.hpp @@ -128,6 +128,18 @@ class Date { DUCKDB_API static date_t FromString(const string &str, bool strict = false); //! Convert a string in the format "YYYY-MM-DD" to a date object DUCKDB_API static date_t FromCString(const char *str, idx_t len, bool strict = false); + //! 
Convert an infinite temporal object to a string + template + static inline string ToInfinity(const T &t) { + if (t == T::infinity()) { + return PINF.str; + } else if (t == T::ninfinity()) { + return NINF.str; + } else { + D_ASSERT(!t.IsFinite()); + return string(); + } + } //! Convert a date object to a string in the format "YYYY-MM-DD" DUCKDB_API static string ToString(date_t date); //! Try to convert the string as a give "special" date (e.g, PINF, ...) diff --git a/src/duckdb/src/include/duckdb/common/types/decimal.hpp b/src/duckdb/src/include/duckdb/common/types/decimal.hpp index e08e544a1..7c4fe7d27 100644 --- a/src/duckdb/src/include/duckdb/common/types/decimal.hpp +++ b/src/duckdb/src/include/duckdb/common/types/decimal.hpp @@ -45,6 +45,10 @@ class Decimal { static constexpr uint8_t MAX_WIDTH_DECIMAL = MAX_WIDTH_INT128; public: + //! Whether width/scale form a valid DECIMAL type: width in [1, MAX_WIDTH_DECIMAL] and scale not exceeding width. + static bool IsValidWidthScale(uint8_t width, uint8_t scale) { + return width >= 1 && width <= MAX_WIDTH_DECIMAL && scale <= width; + } static string ToString(int16_t value, uint8_t width, uint8_t scale); static string ToString(int32_t value, uint8_t width, uint8_t scale); static string ToString(int64_t value, uint8_t width, uint8_t scale); diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp index 50fe55137..75a39f2b7 100644 --- a/src/duckdb/src/include/duckdb/common/types/geometry.hpp +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/pair.hpp" +#include "duckdb/storage/storage_info.hpp" #include #include @@ -247,7 +248,7 @@ enum class GeometryStorageType : uint8_t { class Geometry { public: static constexpr idx_t MAX_RECURSION_DEPTH = 16; - static constexpr idx_t VERSION_ADDED = 7; // Added to core in DuckDB v1.5.0 + static 
constexpr StorageVersion VERSION_ADDED = StorageVersion::V1_5_0; // Added to core in DuckDB v1.5.0 //! Check for legayc geometry type (pre v1.5) static bool IsSpatialGeometryType(const LogicalType &type); @@ -264,10 +265,10 @@ class Geometry { //! Convert from WKB DUCKDB_API static bool FromBinary(const string_t &wkb, string_t &result, StringHeap &heap, bool strict); - DUCKDB_API static bool FromBinary(Vector &source, Vector &result, idx_t count, bool strict); + DUCKDB_API static bool FromBinary(const Vector &source, Vector &result, idx_t count, bool strict); //! Convert to WKB - DUCKDB_API static void ToBinary(Vector &source, Vector &result, idx_t count); + DUCKDB_API static void ToBinary(const Vector &source, Vector &result); //! Get the geometry type and vertex type from the WKB DUCKDB_API static pair GetType(const string_t &wkb); @@ -277,14 +278,15 @@ class Geometry { DUCKDB_API static uint32_t GetExtent(const string_t &wkb, GeometryExtent &extent, bool &has_any_empty); //! Convert to vectorized format - DUCKDB_API static void ToVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryType geom_type, + DUCKDB_API static void ToVectorizedFormat(const Vector &source, Vector &target, idx_t count, GeometryType geom_type, VertexType vert_type); - DUCKDB_API static void ToVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryStorageType type); + DUCKDB_API static void ToVectorizedFormat(const Vector &source, Vector &target, idx_t count, + GeometryStorageType type); //! 
Convert from vectorized format - DUCKDB_API static void FromVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryType geom_type, - VertexType vert_type, idx_t result_offset); - DUCKDB_API static void FromVectorizedFormat(Vector &source, Vector &target, idx_t count, GeometryStorageType type, - idx_t result_offset); + DUCKDB_API static void FromVectorizedFormat(const Vector &source, Vector &target, idx_t count, + GeometryType geom_type, VertexType vert_type, idx_t result_offset); + DUCKDB_API static void FromVectorizedFormat(const Vector &source, Vector &target, idx_t count, + GeometryStorageType type, idx_t result_offset); //! Get the vectorized logical type for a given geometry and vertex type DUCKDB_API static LogicalType GetVectorizedType(GeometryStorageType type); @@ -293,11 +295,11 @@ class Geometry { DUCKDB_API static pair GetSpecializedType(GeometryStorageType type); DUCKDB_API static void FromSpatialGeometry(const string_t &source, string_t &target, Vector &vector); - DUCKDB_API static void FromSpatialGeometry(Vector &source, Vector &target, idx_t count, idx_t result_offset); + DUCKDB_API static void FromSpatialGeometry(const Vector &source, Vector &target, idx_t count, idx_t result_offset); DUCKDB_API static void FromSpatialGeometry(const string_t &source, string &target); DUCKDB_API static void ToSpatialGeometry(const string_t &source, string_t &target, Vector &vector); - DUCKDB_API static void ToSpatialGeometry(Vector &source, Vector &target, idx_t count); + DUCKDB_API static void ToSpatialGeometry(const Vector &source, Vector &target, idx_t count); DUCKDB_API static void ToSpatialGeometry(const string_t &source, string &target); }; diff --git a/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp b/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp index d6222da5e..5a18c4379 100644 --- a/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp +++ b/src/duckdb/src/include/duckdb/common/types/hyperloglog.hpp @@ -79,7 +79,7 
@@ class HyperLogLogP : public HyperLogLogBase { public: //! Add data to this HLL - void Update(Vector &input, Vector &hash_vec, const idx_t count) { + void Update(const Vector &input, const Vector &hash_vec) { auto idata = input.Validity(); const auto hashes = hash_vec.Values(); @@ -89,6 +89,7 @@ class HyperLogLogP : public HyperLogLogBase { } } else { D_ASSERT(hash_vec.GetVectorType() == VectorType::FLAT_VECTOR); + const idx_t count = hash_vec.size(); if (idata.CannotHaveNull()) { for (idx_t i = 0; i < count; ++i) { const auto hash = hashes[i].GetValue(); diff --git a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp index 999d1bf75..6e376fc2c 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/partitioned_tuple_data.hpp @@ -92,6 +92,10 @@ class PartitionedTupleData { //! Initializes a local state for parallel partitioning that can be merged into this PartitionedTupleData void InitializeAppendState(PartitionedTupleDataAppendState &state, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; + //! Reuses an existing append state for a new iteration - faster than InitializeAppendState when the + //! append state already has pre-allocated buffers (selection vectors, partition pin states, chunk state) + virtual void ResetAppendState(PartitionedTupleDataAppendState &state, + TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; //! 
Appends a DataChunk to this PartitionedTupleData void Append(PartitionedTupleDataAppendState &state, DataChunk &input, const SelectionVector &append_sel = *FlatVector::IncrementalSelectionVector(), diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp index b7a3b8dd4..7d9b3d60e 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_allocator.hpp @@ -101,6 +101,8 @@ class TupleDataAllocator { void ReleaseOrStoreHandles(TupleDataPinState &state, TupleDataSegment &segment); //! Sets 'can_destroy' to true for all blocks so they aren't added to the eviction queue void SetDestroyBufferUponUnpin(); + //! Release all row/heap blocks and reset the allocator for reuse without a heap allocation + void Reset(); //! Destroy the blocks between the given indices void DestroyRowBlocks(idx_t row_block_begin, idx_t row_block_end); void DestroyHeapBlocks(idx_t heap_block_begin, idx_t heap_block_end); diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp index 93558050a..49c78d23f 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_segment.hpp @@ -161,6 +161,8 @@ struct TupleDataSegment { idx_t SizeInBytes() const; //! Unpins all held pins void Unpin(); + //! Reset this segment for reuse, clearing chunk metadata while keeping the segment object alive + void Reset(); //! 
Verify counts of the chunks in this segment void Verify() const; diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp index bbe487d7c..7e21e505e 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_states.hpp @@ -89,6 +89,12 @@ struct TupleDataPinState { buffer_handle_map_t row_handles; buffer_handle_map_t heap_handles; TupleDataPinProperties properties = TupleDataPinProperties::INVALID; + + void Reset() { + row_handles.clear(); + heap_handles.clear(); + properties = TupleDataPinProperties::INVALID; + } }; struct CombinedListData { @@ -130,6 +136,15 @@ struct TupleDataChunkState { //! Re-usable arrays used while building buffer space unsafe_vector> chunk_parts; unsafe_vector> chunk_part_indices; + + void ResetForScan() { + column_ids.clear(); + chunk_lock = nullptr; + cached_cast_vectors.clear(); + cached_cast_vector_cache.clear(); + chunk_parts.clear(); + chunk_part_indices.clear(); + } }; struct SortKeyPayloadState { @@ -147,6 +162,13 @@ struct TupleDataScanState { TupleDataChunkState chunk_state; idx_t segment_index = DConstants::INVALID_INDEX; idx_t chunk_index = DConstants::INVALID_INDEX; + + void Reset() { + pin_state.Reset(); + chunk_state.ResetForScan(); + segment_index = DConstants::INVALID_INDEX; + chunk_index = DConstants::INVALID_INDEX; + } }; struct TupleDataParallelScanState { diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index ffe6114c8..4ecce53b6 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -52,7 +52,8 @@ class Value { //! Create a DOUBLE value DUCKDB_API Value(double val); // NOLINT: Allow implicit conversion from `double` //! 
Create a VARCHAR value - DUCKDB_API Value(const char *val); // NOLINT: Allow implicit conversion from `const char *` + DUCKDB_API Value(const char *val); // NOLINT: Allow implicit conversion from `const char *` + DUCKDB_API Value(const Identifier &val); // NOLINT: Allow implicit conversion from `Identifier` //! Create a NULL value DUCKDB_API Value(std::nullptr_t val); // NOLINT: Allow implicit conversion from `nullptr_t` //! Create a VARCHAR value diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index bc3b33d7c..2b7219de4 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/string.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/string_type.hpp" +#include "duckdb/storage/storage_info.hpp" namespace duckdb_yyjson { struct yyjson_mut_doc; @@ -29,7 +30,7 @@ enum class VariantChildLookupMode : uint8_t { INVALID, BY_KEY, BY_INDEX }; struct Variant { public: - static constexpr idx_t VERSION_ADDED = 7; // Added to core in DuckDB v1.5.0 + static constexpr StorageVersion VERSION_ADDED = StorageVersion::V1_5_0; // Added to core in DuckDB v1.5.0 }; struct VariantPathComponent { @@ -41,6 +42,23 @@ struct VariantPathComponent { explicit VariantPathComponent(uint32_t index) : lookup_mode(VariantChildLookupMode::BY_INDEX), index(index) { } + bool operator==(const VariantPathComponent &other) const { + if (lookup_mode != other.lookup_mode) { + return false; + } + switch (lookup_mode) { + case VariantChildLookupMode::BY_KEY: + return key == other.key; + case VariantChildLookupMode::BY_INDEX: + return index == other.index; + default: + return false; + } + } + bool operator!=(const VariantPathComponent &other) const { + return !(*this == other); + } + public: VariantChildLookupMode lookup_mode; string key; diff --git a/src/duckdb/src/include/duckdb/common/types/vector.hpp 
b/src/duckdb/src/include/duckdb/common/types/vector.hpp index 064847b42..e478a8610 100644 --- a/src/duckdb/src/include/duckdb/common/types/vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/vector.hpp @@ -166,7 +166,7 @@ class Vector { [[deprecated("Resize has been replaced by Reserve - use Reserve(to_reserve) instead")]] DUCKDB_API void Resize(idx_t current_size, idx_t to_reserve); - DUCKDB_API void Serialize(Serializer &serializer, idx_t count, bool compressed_serialization = true); + DUCKDB_API void Serialize(Serializer &serializer, bool compressed_serialization = true); DUCKDB_API void Deserialize(Deserializer &deserializer, idx_t count); //! Returns the uncompressed size of the data stored within this vector @@ -204,11 +204,12 @@ class Vector { // Setters DUCKDB_API void SetVectorType(VectorType vector_type); + DUCKDB_API void FlattenAndSetConstant(); // Transform vector to an equivalent dictionary vector - static void DebugTransformToDictionary(Vector &vector, idx_t count); + static void DebugTransformToDictionary(Vector &vector); // Transform vector to an equivalent nested vector - static void DebugShuffleNestedVector(Vector &vector, idx_t count); + static void DebugShuffleNestedVector(Vector &vector); template VectorIterator Values() const; diff --git a/src/duckdb/src/include/duckdb/common/vector.hpp b/src/duckdb/src/include/duckdb/common/vector.hpp index 3035edcd5..10ccb02be 100644 --- a/src/duckdb/src/include/duckdb/common/vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector.hpp @@ -110,6 +110,13 @@ class vector : public std::vector { // NOLINT: matching na return get(original::size() - 1); } + void pop_back() { // NOLINT: hiding on purpose + if (MemorySafety::ENABLED && original::empty()) { + throw InternalException("'pop_back' called on an empty vector!"); + } + original::pop_back(); + } + void unsafe_erase_at(idx_t idx) { // NOLINT: not using camelcase on purpose here original::erase(original::begin() + static_cast(idx)); } diff --git 
a/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp index 0a5bff99b..047687a1c 100644 --- a/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp @@ -22,9 +22,9 @@ class DictionaryEntry { Vector data; //! Optional id to uniquely identify re-occurring dictionaries string id; - //! For caching the hashes of a child buffer - mutex cached_hashes_lock; - unique_ptr cached_hashes; + //! For caching the hashes of a child buffer (mutable: cache is logically const) + mutable mutex cached_hashes_lock; + mutable unique_ptr cached_hashes; }; //! The DictionaryBuffer holds a selection vector and a reference to a DictionaryEntry @@ -151,7 +151,7 @@ struct DictionaryVector { return DictionarySize(vector).IsValid() && !DictionaryId(vector).empty() && CanCacheHashes(vector.GetType()); } static buffer_ptr CreateReusableDictionary(const LogicalType &type, const idx_t &size); - static const Vector &GetCachedHashes(Vector &input); + static const Vector &GetCachedHashes(const Vector &input); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector/flat_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/flat_vector.hpp index dcd299140..4fb5c0b9d 100644 --- a/src/duckdb/src/include/duckdb/common/vector/flat_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/flat_vector.hpp @@ -182,6 +182,7 @@ struct FlatVector { template static VectorWriter Writer(Vector &vector, idx_t count, idx_t offset) { + SetSize(vector, offset + count); return VectorWriter(vector, count, offset); } template diff --git a/src/duckdb/src/include/duckdb/common/vector/list_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/list_vector.hpp index d1993c42f..e99e4c5af 100644 --- a/src/duckdb/src/include/duckdb/common/vector/list_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/list_vector.hpp @@ 
-109,13 +109,15 @@ struct ListVector { DUCKDB_API static void Append(Vector &target, const Vector &source, const SelectionVector &sel, idx_t source_size); DUCKDB_API static void PushBack(Vector &target, const Value &insert); //! Returns the child_vector of list starting at offset until offset + count, and its length - DUCKDB_API static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count); + DUCKDB_API static idx_t GetConsecutiveChildList(const Vector &list, Vector &result, idx_t offset, idx_t count); //! Returns information to only copy a section of a list child vector - DUCKDB_API static ConsecutiveChildListInfo GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count); + DUCKDB_API static ConsecutiveChildListInfo GetConsecutiveChildListInfo(const Vector &list, idx_t offset, + idx_t count); //! Slice and flatten a child vector to only contain a consecutive subsection of the child entries - DUCKDB_API static void GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count); + DUCKDB_API static void GetConsecutiveChildSelVector(const Vector &list, SelectionVector &sel, idx_t offset, + idx_t count); //! 
Returns the total number of entries in the list - DUCKDB_API static idx_t GetTotalEntryCount(Vector &list, idx_t count); + DUCKDB_API static idx_t GetTotalEntryCount(const Vector &list); private: template diff --git a/src/duckdb/src/include/duckdb/common/vector/map_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/map_vector.hpp index 7354970ef..128da35cd 100644 --- a/src/duckdb/src/include/duckdb/common/vector/map_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/map_vector.hpp @@ -21,9 +21,10 @@ struct MapVector { DUCKDB_API static Vector &GetKeys(Vector &vector); DUCKDB_API static Vector &GetValues(Vector &vector); DUCKDB_API static MapInvalidReason - CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel = *FlatVector::IncrementalSelectionVector()); + CheckMapValidity(const Vector &map, idx_t count, + const SelectionVector &sel = *FlatVector::IncrementalSelectionVector()); DUCKDB_API static void EvalMapInvalidReason(MapInvalidReason reason); - DUCKDB_API static void MapConversionVerify(Vector &vector, idx_t count); + DUCKDB_API static void MapConversionVerify(const Vector &vector, idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp index b67e66458..03ef7b140 100644 --- a/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp @@ -64,7 +64,7 @@ struct ShreddedVector { DUCKDB_API static void Unshred(const Vector &vec, const SelectionVector &sel, idx_t count); //! 
Returns whether or not the vector is fully shredded - DUCKDB_API static bool IsFullyShredded(Vector &vec); + DUCKDB_API static bool IsFullyShredded(const Vector &vec); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector/struct_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/struct_vector.hpp index 54eb3430b..8d9f2e389 100644 --- a/src/duckdb/src/include/duckdb/common/vector/struct_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/struct_vector.hpp @@ -42,6 +42,7 @@ class VectorStructBuffer : public VectorBuffer { return children; } void SetVectorType(VectorType vector_type) override; + void PrepareChildrenForSetConstant(); public: idx_t GetDataSize(const LogicalType &type, idx_t count) const override; diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp index 2615feb42..cd036822c 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -11,9 +11,12 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector/constant_vector.hpp" +#include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/function/aggregate_state.hpp" +#include +#include namespace duckdb { @@ -108,7 +111,7 @@ class AggregateExecutor { #endif static inline void UnaryScatterLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, STATE_TYPE **__restrict states, const SelectionVector &isel, - const SelectionVector &ssel, ValidityMask &mask, idx_t count) { + const SelectionVector &ssel, const ValidityMask &mask, idx_t count) { #ifdef DUCKDB_SMALLER_BINARY const auto HAS_ISEL = isel.IsSet(); const auto HAS_SSEL = 
ssel.IsSet(); @@ -169,7 +172,7 @@ class AggregateExecutor { template static inline void UnaryUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, - STATE_TYPE *__restrict state, idx_t count, ValidityMask &mask, + STATE_TYPE *__restrict state, idx_t count, const ValidityMask &mask, const SelectionVector &__restrict sel_vector) { AggregateUnaryInput input(aggr_input_data, mask); if (OP::IgnoreNull() && mask.CanHaveNull()) { @@ -189,12 +192,49 @@ class AggregateExecutor { } } + template + static inline void UnarySelectedUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE *__restrict state, idx_t count, const ValidityMask &mask, + const sel_t *__restrict sel) { + AggregateUnaryInput input(aggr_input_data, mask); + for (idx_t i = 0; i < count; i++) { + input.input_idx = sel[i]; + if (!CHECK_VALIDITY || mask.RowIsValidUnsafe(input.input_idx)) { + OP::template Operation(*state, idata[input.input_idx], input); + } + } + } + + template + static inline void UnarySelectedUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE *__restrict state, idx_t count, const ValidityMask &mask, + const SelectionVector &isel, const sel_t *__restrict sel) { + AggregateUnaryInput input(aggr_input_data, mask); + for (idx_t i = 0; i < count; i++) { + input.input_idx = isel.get_index(sel[i]); + if (!CHECK_VALIDITY || mask.RowIsValidUnsafe(input.input_idx)) { + OP::template Operation(*state, idata[input.input_idx], input); + } + } + } + + template + static inline void UnaryUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, + STATE_TYPE *__restrict state, idx_t count, const ValidityMask &mask, + const sel_t *__restrict sel) { + if (OP::IgnoreNull() && mask.CanHaveNull()) { + UnarySelectedUpdateLoop(idata, aggr_input_data, state, count, mask, sel); + } else { + UnarySelectedUpdateLoop(idata, aggr_input_data, state, count, mask, sel); + } + } + 
template static inline void BinaryScatterLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, const B_TYPE *__restrict bdata, STATE_TYPE **__restrict states, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, - const SelectionVector &ssel, ValidityMask &avalidity, - ValidityMask &bvalidity) { + const SelectionVector &ssel, const ValidityMask &avalidity, + const ValidityMask &bvalidity) { AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); if (OP::IgnoreNull() && (avalidity.CanHaveNull() || bvalidity.CanHaveNull())) { // potential NULL values and NULL values are ignored @@ -223,7 +263,7 @@ class AggregateExecutor { static inline void BinaryUpdateLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, const B_TYPE *__restrict bdata, STATE_TYPE *__restrict state, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, - ValidityMask &avalidity, ValidityMask &bvalidity) { + const ValidityMask &avalidity, const ValidityMask &bvalidity) { AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); if (OP::IgnoreNull() && (avalidity.CanHaveNull() || bvalidity.CanHaveNull())) { // potential NULL values and NULL values are ignored @@ -246,9 +286,36 @@ class AggregateExecutor { } } + template + static inline void BinarySelectedUpdateLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, + const B_TYPE *__restrict bdata, STATE_TYPE *__restrict state, + idx_t count, const SelectionVector &asel, const SelectionVector &bsel, + ValidityMask &avalidity, ValidityMask &bvalidity, + const sel_t *__restrict sel) { + AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); + for (idx_t i = 0; i < count; i++) { + const auto idx = sel[i]; + input.lidx = asel.get_index(idx); + input.ridx = bsel.get_index(idx); + if (!CHECK_VALIDITY || (avalidity.RowIsValidUnsafe(input.lidx) && bvalidity.RowIsValidUnsafe(input.ridx))) { + OP::template Operation(*state, 
adata[input.lidx], bdata[input.ridx], + input); + } + } + } + public: template static void NullaryScatter(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + // COUNT(*) can add run lengths directly. + if (aggr_input_data.clustered) { + auto &cs = *aggr_input_data.clustered; + for (idx_t r = 0; r < cs.n_group_runs; r++) { + OP::template ConstantOperation(*reinterpret_cast(cs.group_runs[r].state), + aggr_input_data, cs.group_runs[r].count); + } + return; + } if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { auto sdata = ConstantVector::GetData(states); OP::template ConstantOperation(**sdata, aggr_input_data, count); @@ -264,13 +331,338 @@ class AggregateExecutor { } } + template + static void NullaryClustUpdate(AggregateInputData &aggr_input_data, const ClusteredAggr &clustered, idx_t count) { + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + OP::template ConstantOperation( + *reinterpret_cast(clustered.group_runs[r].state), aggr_input_data, + clustered.group_runs[r].count); + } + } + template static void NullaryUpdate(data_ptr_t state, AggregateInputData &aggr_input_data, idx_t count) { OP::template ConstantOperation(*reinterpret_cast(state), aggr_input_data, count); } + template + using void_t_helper = void; + + template + struct HasClusteredLocalState : std::false_type {}; + template + struct HasClusteredLocalState::Type>> + : std::true_type {}; + template + using clustered_local_state_t = typename OP::template ClusteredLocalState::Type; + + template + struct HasClusteredLocalRepeatCount : std::false_type {}; + template + struct HasClusteredLocalRepeatCount< + OP, INPUT, LOCAL, + void_t_helper( + std::declval(), std::declval(), std::declval()))>> : std::true_type {}; + + template + static inline bool UpdateUnaryClusteredOpt(const INPUT_TYPE *__restrict vals, LOCAL_TYPE &local, idx_t count, + const ValidityMask &mask, const sel_t *sel = nullptr, + const SelectionVector *isel = nullptr) { + bool saw_value = !CHECK_VALIDITY && 
count != 0; + if (isel) { + if (sel) { + for (idx_t i = 0; i < count; i++) { + auto idx = isel->get_index(sel[i]); + if constexpr (CHECK_VALIDITY) { + if (!mask.RowIsValidUnsafe(idx)) { + continue; + } + saw_value = true; + } + OP::template UpdateClusteredLocal(local, vals[idx]); + } + } else { + for (idx_t i = 0; i < count; i++) { + auto idx = isel->get_index(i); + if constexpr (CHECK_VALIDITY) { + if (!mask.RowIsValidUnsafe(idx)) { + continue; + } + saw_value = true; + } + OP::template UpdateClusteredLocal(local, vals[idx]); + } + } + } else if (sel) { + for (idx_t i = 0; i < count; i++) { + auto idx = sel[i]; + if constexpr (CHECK_VALIDITY) { + if (!mask.RowIsValidUnsafe(idx)) { + continue; + } + saw_value = true; + } + OP::template UpdateClusteredLocal(local, vals[idx]); + } + } else { + for (idx_t i = 0; i < count; i++) { + if constexpr (CHECK_VALIDITY) { + if (!mask.RowIsValidUnsafe(i)) { + continue; + } + saw_value = true; + } + OP::template UpdateClusteredLocal(local, vals[i]); + } + } + return saw_value; + } + + template + static inline bool UpdateUnaryClusteredDispatch(const INPUT_TYPE *__restrict vals, LOCAL_TYPE &local, idx_t count, + const ValidityMask &mask, const sel_t *sel = nullptr, + const SelectionVector *isel = nullptr) { + if (OP::IgnoreNull() && mask.CanHaveNull()) { + return UpdateUnaryClusteredOpt(vals, local, count, mask, sel, isel); + } + return UpdateUnaryClusteredOpt(vals, local, count, mask, sel, isel); + } + + template + static inline void UpdateUnaryClusteredLocalRepeat(LOCAL_TYPE &local, const INPUT_TYPE &input, idx_t count, + std::true_type) { + OP::template UpdateClusteredLocal(local, input, count); + } + + template + static inline void UpdateUnaryClusteredLocalRepeat(LOCAL_TYPE &local, const INPUT_TYPE &input, idx_t count, + std::false_type) { + for (idx_t i = 0; i < count; i++) { + OP::template UpdateClusteredLocal(local, input); + } + } + + template + static void ExecuteUnaryClusteredOpt(const INPUT_TYPE *vals, const 
ClusteredAggr &clustered, + const ValidityMask &validity, const SelectionVector *isel = nullptr, + const sel_t *cluster_iter = nullptr) { + idx_t pos = 0; + using local_type = clustered_local_state_t; + if (cluster_iter) { + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + local_type local; + OP::template InitializeClusteredLocal(local, state); + auto run_count = clustered.group_runs[r].count; + auto saw_value = UpdateUnaryClusteredOpt( + vals, local, run_count, validity, cluster_iter + pos); + OP::template FlushClusteredLocal(state, local, saw_value); + pos += run_count; + } + } else if (isel) { + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + local_type local; + OP::template InitializeClusteredLocal(local, state); + auto run_count = clustered.group_runs[r].count; + auto saw_value = UpdateUnaryClusteredOpt( + vals, local, run_count, validity, clustered.group_runs[r].sel, isel); + OP::template FlushClusteredLocal(state, local, saw_value); + } + } else { + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + local_type local; + OP::template InitializeClusteredLocal(local, state); + auto run_count = clustered.group_runs[r].count; + auto saw_value = UpdateUnaryClusteredOpt( + vals, local, run_count, validity, clustered.group_runs[r].sel); + OP::template FlushClusteredLocal(state, local, saw_value); + } + } + } + + template + static inline void + ExecuteUnaryClusteredDispatch(const INPUT_TYPE *vals, const ClusteredAggr &clustered, const ValidityMask &validity, + const SelectionVector *isel = nullptr, const sel_t *cluster_iter = nullptr) { + if constexpr (HasI64HugeintSumFastPath::value) { + if (!validity.CanHaveNull()) { + OP::template ExecuteFlatI64HugeintSum(vals, clustered, isel, cluster_iter); + return; + } + } + if (OP::IgnoreNull() && 
validity.CanHaveNull()) { + ExecuteUnaryClusteredOpt(vals, clustered, validity, isel, cluster_iter); + } else { + ExecuteUnaryClusteredOpt(vals, clustered, validity, isel, cluster_iter); + } + } + + static inline bool IsDenseSingleRun(const ClusteredAggr &clustered, idx_t count) { + return clustered.n_group_runs == 1 && clustered.group_runs[0].count == count && + clustered.group_runs[0].sel == nullptr; + } + + template + static void ExecuteUnaryClusteredDictOpt(Vector &input, const ClusteredAggr &clustered, idx_t count, + const sel_t *cluster_iter = nullptr) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + auto vals = UnifiedVectorFormat::GetData(idata); + if constexpr (SIMPLE_DICT) { + ExecuteUnaryClusteredDispatch(vals, clustered, idata.validity, nullptr, + cluster_iter); + } else { + ExecuteUnaryClusteredDispatch(vals, clustered, idata.validity, idata.sel); + } + } + + template + static void ExecuteUnaryClusteredOpt(Vector &input, const ClusteredAggr &clustered, idx_t count) { + auto vals = FlatVector::GetData(input); + auto &validity = FlatVector::Validity(input); + if (IsDenseSingleRun(clustered, count)) { + auto &state = *reinterpret_cast(clustered.group_runs[0].state); + using local_type = clustered_local_state_t; + local_type local; + OP::template InitializeClusteredLocal(local, state); + auto saw_value = UpdateUnaryClusteredDispatch(vals, local, count, validity); + OP::template FlushClusteredLocal(state, local, saw_value); + return; + } + ExecuteUnaryClusteredDispatch(vals, clustered, validity); + } + + template + static void ExecuteUnaryClusteredConstantOpt(Vector &input, const ClusteredAggr &clustered) { + using local_type = clustered_local_state_t; + if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { + return; + } + const auto &constant_input = *ConstantVector::GetData(input); + using has_repeat_count = HasClusteredLocalRepeatCount; + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = 
*reinterpret_cast(clustered.group_runs[r].state); + local_type local; + OP::template InitializeClusteredLocal(local, state); + auto run_count = clustered.group_runs[r].count; + if (run_count != 0) { + UpdateUnaryClusteredLocalRepeat(local, constant_input, run_count, + has_repeat_count {}); + OP::template FlushClusteredLocal(state, local, true); + } else { + OP::template FlushClusteredLocal(state, local, false); + } + } + } + + template + static void ExecuteUnaryClusteredUnifiedOpt(Vector &input, const ClusteredAggr &clustered, idx_t count) { + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + auto vals = UnifiedVectorFormat::GetData(idata); + ExecuteUnaryClusteredDispatch(vals, clustered, idata.validity, idata.sel); + } + + // true if OP defines ClusteredOp (early-exit optimization for first/last/any_value). + template + struct HasClusteredOperation : std::false_type {}; + template + struct HasClusteredOperation)>> + : std::true_type {}; + + template + struct HasI64HugeintSumFastPath : std::false_type {}; + template + struct HasI64HugeintSumFastPath> + : std::integral_constant::value && OP::ClusteredI64HugeintSum> {}; + + template + static void ExecuteUnaryClustConstantCustomState(Vector &input, AggregateInputData &aggr_input_data, + const ClusteredAggr &clustered) { + if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { + return; + } + auto idata = ConstantVector::GetData(input); + AggregateUnaryInput unary_input(aggr_input_data, ConstantVector::Validity(input)); + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + OP::template ConstantOperation(state, *idata, unary_input, + clustered.group_runs[r].count); + } + } + + template + static void ExecuteUnaryNoClusteredOpt(Vector &input, AggregateInputData &aggr_input_data, + const ClusteredAggr &clustered, idx_t count) { + static_assert(HasClusteredOperation::value, "Expected custom clustered operation"); + if (input.GetVectorType() == 
VectorType::CONSTANT_VECTOR) { + ExecuteUnaryClustConstantCustomState(input, aggr_input_data, clustered); + return; + } + UnifiedVectorFormat idata; + input.ToUnifiedFormat(idata); + auto vals = UnifiedVectorFormat::GetData(idata); + AggregateUnaryInput unary_input(aggr_input_data, idata.validity); + for (idx_t r = 0; r < clustered.n_group_runs; r++) { + auto &state = *reinterpret_cast(clustered.group_runs[r].state); + OP::template ClusteredOp(state, vals, unary_input, clustered.group_runs[r].sel, + *idata.sel, idata.validity, 0, + clustered.group_runs[r].count); + } + } + + template + static void ExecuteUnaryClustered(Vector &input, AggregateInputData &aggr_input_data, + const ClusteredAggr &clustered, idx_t count) { + if constexpr (HasClusteredLocalState::value) { + switch (input.GetVectorType()) { + case VectorType::FLAT_VECTOR: + ExecuteUnaryClusteredOpt(input, clustered, count); + return; + case VectorType::CONSTANT_VECTOR: + ExecuteUnaryClusteredConstantOpt(input, clustered); + return; + case VectorType::DICTIONARY_VECTOR: { + if constexpr (HasI64HugeintSumFastPath::value) { + if (clustered.state) { + auto props = clustered.state->GetDictProps(input); + if (props && *props) { + OP::template ExecuteDictI64HugeintSum(input, clustered, props->get()); + return; + } + } + } + auto *cluster_iter = clustered.ClusterIter(input, count); + if (cluster_iter) { + ExecuteUnaryClusteredDictOpt(input, clustered, count, + cluster_iter); + } else { + ExecuteUnaryClusteredDictOpt(input, clustered, count); + } + return; + } + default: + ExecuteUnaryClusteredUnifiedOpt(input, clustered, count); + return; + } + } + if constexpr (HasClusteredOperation::value) { + ExecuteUnaryNoClusteredOpt(input, aggr_input_data, clustered, count); + } else { + D_ASSERT(false); + } + } + template static void UnaryScatter(Vector &input, Vector &states, AggregateInputData &aggr_input_data, idx_t count) { + if (aggr_input_data.clustered) { + ExecuteUnaryClustered(input, aggr_input_data, 
*aggr_input_data.clustered, + count); + return; + } if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && states.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp index 09a003650..f18a253bb 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/binary_executor.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/optional.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector/constant_vector.hpp" @@ -19,6 +20,27 @@ namespace duckdb { +//! Complement-fold trait for comparison selection. On the NO_NULL path a comparison op's selection +//! is the exact complement of its complement op's selection (no NULL third bucket). +//! Note LessThanEquals is not handled since that's done upstream in the general case (see +//! 
VectorOperations::LessThan[Equals]) +template +struct ComparisonSelectComplement { + static constexpr bool FOLD = false; +}; +template <> +struct ComparisonSelectComplement { + static constexpr bool FOLD = true; + using COMPLEMENT = Equals; + static constexpr bool SWAP_OPERANDS = false; +}; +template <> +struct ComparisonSelectComplement { + static constexpr bool FOLD = true; + using COMPLEMENT = GreaterThan; + static constexpr bool SWAP_OPERANDS = true; // GreaterThanEquals(a,b) == !GreaterThan(b,a) +}; + struct DefaultNullCheckOperator { template static inline bool Operation(LEFT_TYPE left, RIGHT_TYPE right) { @@ -127,9 +149,11 @@ struct BinaryExecutor { #endif template - static void ExecuteConstant(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + static void ExecuteConstant(const Vector &left, const Vector &right, Vector &result, idx_t count, FUNC fun) { result.SetVectorType(VectorType::CONSTANT_VECTOR); - FlatVector::SetSize(result, count); + if (result.size() != count) { + FlatVector::SetSize(result, count); + } auto ldata = ConstantVector::GetData(left); auto rdata = ConstantVector::GetData(right); @@ -146,7 +170,7 @@ struct BinaryExecutor { #ifndef DUCKDB_SMALLER_BINARY template - static void ExecuteFlat(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + static void ExecuteFlat(const Vector &left, const Vector &right, Vector &result, idx_t count, FUNC fun) { auto ldata = LEFT_CONSTANT ? ConstantVector::GetData(left) : FlatVector::GetData(left); auto rdata = RIGHT_CONSTANT ? 
ConstantVector::GetData(right) : FlatVector::GetData(right); @@ -158,31 +182,34 @@ struct BinaryExecutor { } result.SetVectorType(VectorType::FLAT_VECTOR); + if (result.size() != count) { + FlatVector::SetSize(result, count); + } auto result_data = FlatVector::GetDataMutable(result); auto &result_validity = FlatVector::ValidityMutable(result); if (LEFT_CONSTANT) { if (OPWRAPPER::AddsNulls()) { - result_validity.Copy(FlatVector::ValidityMutable(right), count); + result_validity.Copy(FlatVector::Validity(right), count); } else { - FlatVector::SetValidity(result, FlatVector::ValidityMutable(right)); + FlatVector::SetValidity(result, FlatVector::Validity(right)); } } else if (RIGHT_CONSTANT) { if (OPWRAPPER::AddsNulls()) { - result_validity.Copy(FlatVector::ValidityMutable(left), count); + result_validity.Copy(FlatVector::Validity(left), count); } else { - FlatVector::SetValidity(result, FlatVector::ValidityMutable(left)); + FlatVector::SetValidity(result, FlatVector::Validity(left)); } } else { if (OPWRAPPER::AddsNulls()) { - result_validity.Copy(FlatVector::ValidityMutable(left), count); + result_validity.Copy(FlatVector::Validity(left), count); if (result_validity.CannotHaveNull()) { - result_validity.Copy(FlatVector::ValidityMutable(right), count); + result_validity.Copy(FlatVector::Validity(right), count); } else { - result_validity.Combine(FlatVector::ValidityMutable(right), count); + result_validity.Combine(FlatVector::Validity(right), count); } } else { - FlatVector::SetValidity(result, FlatVector::ValidityMutable(left)); - result_validity.Combine(FlatVector::ValidityMutable(right), count); + FlatVector::SetValidity(result, FlatVector::Validity(left)); + result_validity.Combine(FlatVector::Validity(right), count); } } ExecuteFlatLoop( @@ -219,13 +246,16 @@ struct BinaryExecutor { } template - static void ExecuteGeneric(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + static void ExecuteGeneric(const Vector &left, const Vector &right, 
Vector &result, idx_t count, FUNC fun) { UnifiedVectorFormat ldata, rdata; left.ToUnifiedFormat(ldata); right.ToUnifiedFormat(rdata); result.SetVectorType(VectorType::FLAT_VECTOR); + if (result.size() != count) { + FlatVector::SetSize(result, count); + } auto result_data = FlatVector::GetDataMutable(result); ExecuteGenericLoop( UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), @@ -234,7 +264,7 @@ struct BinaryExecutor { } template - static void ExecuteSwitch(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + static void ExecuteSwitch(const Vector &left, const Vector &right, Vector &result, idx_t count, FUNC fun) { auto left_vector_type = left.GetVectorType(); auto right_vector_type = right.GetVectorType(); if (left_vector_type == VectorType::CONSTANT_VECTOR && right_vector_type == VectorType::CONSTANT_VECTOR) { @@ -258,7 +288,7 @@ struct BinaryExecutor { public: template > - static void Execute(Vector &left, Vector &right, Vector &result, idx_t count, FUNC fun) { + static void Execute(const Vector &left, const Vector &right, Vector &result, idx_t count, FUNC fun) { constexpr bool adds_nulls = std::is_same(), std::declval())), optional>::value; ExecuteSwitch, bool, FUNC>( @@ -267,19 +297,48 @@ struct BinaryExecutor { template - static void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { + static void Execute(const Vector &left, const Vector &right, Vector &result, idx_t count) { ExecuteSwitch(left, right, result, count, false); } template - static void ExecuteStandard(Vector &left, Vector &right, Vector &result, idx_t count) { + static void ExecuteStandard(const Vector &left, const Vector &right, Vector &result, idx_t count) { ExecuteSwitch(left, right, result, count, false); } +private: + static idx_t CheckExecuteCount(const Vector &left, const Vector &right) { + if (left.size() != right.size()) { + throw InternalException( + "Mismatch in input vector sizes for BinaryExecutor - left has %d rows but 
right has %d", left.size(), + right.size()); + } + return left.size(); + } + +public: + //! Convenience overloads without explicit count - count is derived from the input vectors. + template > + static void Execute(const Vector &left, const Vector &right, Vector &result, FUNC fun) { + Execute(left, right, result, CheckExecuteCount(left, right), fun); + } + + template + static void Execute(const Vector &left, const Vector &right, Vector &result) { + Execute(left, right, result, CheckExecuteCount(left, right)); + } + + template + static void ExecuteStandard(const Vector &left, const Vector &right, Vector &result) { + ExecuteStandard(left, right, result, CheckExecuteCount(left, right)); + } + public: template - static idx_t SelectConstant(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + static idx_t SelectConstant(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { auto ldata = ConstantVector::GetData(left); auto rdata = ConstantVector::GetData(right); @@ -303,6 +362,8 @@ struct BinaryExecutor { } } +// NOTE: the flat path is intentionally NOT covered by the ComparisonSelectComplement fold (unlike +// the generic and constant paths above). It is null-unified — there is no NO_NULL template split; #ifndef DUCKDB_SMALLER_BINARY template @@ -387,7 +448,7 @@ struct BinaryExecutor { } template - static idx_t SelectFlat(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + static idx_t SelectFlat(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { auto ldata = LEFT_CONSTANT ? 
ConstantVector::GetData(left) : FlatVector::GetData(left); auto rdata = @@ -412,13 +473,13 @@ struct BinaryExecutor { if (LEFT_CONSTANT) { return SelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::ValidityMutable(right), true_sel, false_sel); + ldata, rdata, sel, count, FlatVector::Validity(right), true_sel, false_sel); } else if (RIGHT_CONSTANT) { return SelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::ValidityMutable(left), true_sel, false_sel); + ldata, rdata, sel, count, FlatVector::Validity(left), true_sel, false_sel); } else { - ValidityMask combined_mask = FlatVector::ValidityMutable(left); - combined_mask.Combine(FlatVector::ValidityMutable(right), count); + ValidityMask combined_mask = FlatVector::Validity(left); + combined_mask.Combine(FlatVector::Validity(right), count); return SelectFlatLoopSwitch( ldata, rdata, sel, count, combined_mask, true_sel, false_sel); } @@ -461,6 +522,15 @@ struct BinaryExecutor { const SelectionVector &generic_sel, const ValidityMask &mask, const SelectionVector &result_sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { + // NO_NULL complement fold: NotEquals/GreaterThanEquals are exactly the complement of + // Equals/GreaterThan when there are no NULLs + if constexpr (!CAN_HAVE_NULL && ComparisonSelectComplement::FOLD) { + using FOLDED = ComparisonSelectComplement; + constexpr bool FOLDED_RIGHT_CONSTANT = FOLDED::SWAP_OPERANDS ? 
!RIGHT_CONSTANT : RIGHT_CONSTANT; + return count - SelectGenericConstant( + constant, data, generic_sel, mask, result_sel, count, false_sel, true_sel); + } if (true_sel && false_sel) { return SelectGenericConstant( constant, data, generic_sel, mask, result_sel, count, true_sel, false_sel); @@ -490,7 +560,7 @@ struct BinaryExecutor { } template - static idx_t SelectGenericConstant(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + static idx_t SelectGenericConstant(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { constexpr bool LEFT_CONSTANT = !RIGHT_CONSTANT; if (LEFT_CONSTANT && ConstantVector::IsNull(left)) { @@ -540,7 +610,7 @@ struct BinaryExecutor { static inline idx_t SelectGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, const SelectionVector &result_sel, - idx_t count, ValidityMask &lvalidity, ValidityMask &rvalidity, + idx_t count, const ValidityMask &lvalidity, const ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { idx_t true_count = 0, false_count = 0; #ifdef DUCKDB_SMALLER_BINARY @@ -552,15 +622,16 @@ struct BinaryExecutor { auto result_idx = result_sel.get_index(i); auto lindex = lsel->get_index(i); auto rindex = rsel->get_index(i); - if ((NO_NULL || (lvalidity.RowIsValid(lindex) && rvalidity.RowIsValid(rindex))) && - OP::Operation(ldata[lindex], rdata[rindex])) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } + const bool comparison_result = + (NO_NULL || (lvalidity.RowIsValid(lindex) && rvalidity.RowIsValid(rindex))) && + OP::Operation(ldata[lindex], rdata[rindex]); + if (HAS_TRUE_SEL) { + true_sel->set_index(true_count, result_idx); + true_count += comparison_result; + } + if 
(HAS_FALSE_SEL) { + false_sel->set_index(false_count, result_idx); + false_count += !comparison_result; } } if (HAS_TRUE_SEL) { @@ -575,8 +646,22 @@ struct BinaryExecutor { static inline idx_t SelectGenericLoopSelSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector &result_sel, idx_t count, ValidityMask &lvalidity, - ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + const SelectionVector &result_sel, idx_t count, const ValidityMask &lvalidity, + const ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + // NO_NULL complement fold: NotEquals/GreaterThanEquals are exactly the complement of + // Equals/GreaterThan when there are no NULLs + if constexpr (NO_NULL && ComparisonSelectComplement::FOLD) { + using FOLDED = ComparisonSelectComplement; + if constexpr (FOLDED::SWAP_OPERANDS) { + return count - + SelectGenericLoopSelSwitch( + rdata, ldata, rsel, lsel, result_sel, count, rvalidity, lvalidity, false_sel, true_sel); + } else { + return count - + SelectGenericLoopSelSwitch( + ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, false_sel, true_sel); + } + } if (true_sel && false_sel) { return SelectGenericLoop( ldata, rdata, lsel, rsel, result_sel, count, lvalidity, rvalidity, true_sel, false_sel); @@ -595,8 +680,8 @@ struct BinaryExecutor { static inline idx_t SelectGenericLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector &result_sel, idx_t count, ValidityMask &lvalidity, - ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { + const SelectionVector &result_sel, idx_t count, const ValidityMask &lvalidity, + const ValidityMask &rvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { 
#ifndef DUCKDB_SMALLER_BINARY if (lvalidity.CanHaveNull() || rvalidity.CanHaveNull()) { return SelectGenericLoopSelSwitch( @@ -612,7 +697,7 @@ struct BinaryExecutor { } template - static idx_t SelectGeneric(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, + static idx_t SelectGeneric(const Vector &left, const Vector &right, const SelectionVector &sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { UnifiedVectorFormat ldata, rdata; @@ -625,8 +710,8 @@ struct BinaryExecutor { } template - static idx_t Select(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { + static idx_t Select(const Vector &left, const Vector &right, const SelectionVector *sel, idx_t count, + SelectionVector *true_sel, SelectionVector *false_sel) { if (!sel) { sel = FlatVector::IncrementalSelectionVector(); } diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp index 842e1e8c1..b6b63602a 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/generic_executor.hpp @@ -23,7 +23,7 @@ namespace duckdb { struct PrimitiveTypeState { UnifiedVectorFormat main_data; - void PrepareVector(Vector &input, idx_t count) { + void PrepareVector(const Vector &input) { input.ToUnifiedFormat(main_data); } }; @@ -57,8 +57,8 @@ struct StructTypeState { UnifiedVectorFormat main_data; UnifiedVectorFormat child_data[CHILD_COUNT]; - void PrepareVector(Vector &input, idx_t count) { - auto &entries = StructVector::GetEntries(input); + void PrepareVector(const Vector &input) { + const auto &entries = StructVector::GetEntries(input); input.ToUnifiedFormat(main_data); @@ -250,11 +250,11 @@ struct GenericListType { struct GenericExecutor { private: template - static void ExecuteUnaryInternal(Vector 
&input, Vector &result, idx_t count, FUNC &fun) { + static void ExecuteUnaryInternal(const Vector &input, Vector &result, idx_t count, FUNC &fun) { auto constant = input.GetVectorType() == VectorType::CONSTANT_VECTOR; typename A_TYPE::STRUCT_STATE state; - state.PrepareVector(input, count); + state.PrepareVector(input); for (idx_t i = 0; i < (constant ? 1 : count); i++) { auto idx = state.main_data.sel->get_index(i); @@ -275,14 +275,14 @@ struct GenericExecutor { } template - static void ExecuteBinaryInternal(Vector &a, Vector &b, Vector &result, idx_t count, FUNC &fun) { + static void ExecuteBinaryInternal(const Vector &a, const Vector &b, Vector &result, idx_t count, FUNC &fun) { auto constant = a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR; typename A_TYPE::STRUCT_STATE a_state; typename B_TYPE::STRUCT_STATE b_state; - a_state.PrepareVector(a, count); - b_state.PrepareVector(b, count); + a_state.PrepareVector(a); + b_state.PrepareVector(b); for (idx_t i = 0; i < (constant ? 1 : count); i++) { auto a_idx = a_state.main_data.sel->get_index(i); @@ -305,7 +305,8 @@ struct GenericExecutor { } template - static void ExecuteTernaryInternal(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUNC &fun) { + static void ExecuteTernaryInternal(const Vector &a, const Vector &b, const Vector &c, Vector &result, idx_t count, + FUNC &fun) { auto constant = a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR && c.GetVectorType() == VectorType::CONSTANT_VECTOR; @@ -314,9 +315,9 @@ struct GenericExecutor { typename B_TYPE::STRUCT_STATE b_state; typename C_TYPE::STRUCT_STATE c_state; - a_state.PrepareVector(a, count); - b_state.PrepareVector(b, count); - c_state.PrepareVector(c, count); + a_state.PrepareVector(a); + b_state.PrepareVector(b); + c_state.PrepareVector(c); for (idx_t i = 0; i < (constant ? 
1 : count); i++) { auto a_idx = a_state.main_data.sel->get_index(i); @@ -343,8 +344,8 @@ struct GenericExecutor { } template - static void ExecuteQuaternaryInternal(Vector &a, Vector &b, Vector &c, Vector &d, Vector &result, idx_t count, - FUNC &fun) { + static void ExecuteQuaternaryInternal(const Vector &a, const Vector &b, const Vector &c, const Vector &d, + Vector &result, idx_t count, FUNC &fun) { auto constant = a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR && c.GetVectorType() == VectorType::CONSTANT_VECTOR && d.GetVectorType() == VectorType::CONSTANT_VECTOR; @@ -354,10 +355,10 @@ struct GenericExecutor { typename C_TYPE::STRUCT_STATE c_state; typename D_TYPE::STRUCT_STATE d_state; - a_state.PrepareVector(a, count); - b_state.PrepareVector(b, count); - c_state.PrepareVector(c, count); - d_state.PrepareVector(d, count); + a_state.PrepareVector(a); + b_state.PrepareVector(b); + c_state.PrepareVector(c); + d_state.PrepareVector(d); for (idx_t i = 0; i < (constant ? 
1 : count); i++) { auto a_idx = a_state.main_data.sel->get_index(i); @@ -385,25 +386,60 @@ struct GenericExecutor { } } + static idx_t CheckExecuteCount(std::initializer_list inputs) { + idx_t count = (*inputs.begin())->size(); + for (auto *v : inputs) { + if (v->size() != count) { + throw InternalException("Mismatch in input vector sizes for GenericExecutor - " + "expected %d rows but got %d", + count, v->size()); + } + } + return count; + } + public: template > - static void ExecuteUnary(Vector &input, Vector &result, idx_t count, FUNC fun) { + static void ExecuteUnary(const Vector &input, Vector &result, idx_t count, FUNC fun) { ExecuteUnaryInternal(input, result, count, fun); } + template > + static void ExecuteUnary(const Vector &input, Vector &result, FUNC fun) { + ExecuteUnaryInternal(input, result, input.size(), fun); + } template > - static void ExecuteBinary(Vector &a, Vector &b, Vector &result, idx_t count, FUNC fun) { + static void ExecuteBinary(const Vector &a, const Vector &b, Vector &result, idx_t count, FUNC fun) { ExecuteBinaryInternal(a, b, result, count, fun); } + template > + static void ExecuteBinary(const Vector &a, const Vector &b, Vector &result, FUNC fun) { + ExecuteBinaryInternal(a, b, result, CheckExecuteCount({&a, &b}), fun); + } template > - static void ExecuteTernary(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUNC fun) { + static void ExecuteTernary(const Vector &a, const Vector &b, const Vector &c, Vector &result, idx_t count, + FUNC fun) { ExecuteTernaryInternal(a, b, c, result, count, fun); } + template > + static void ExecuteTernary(const Vector &a, const Vector &b, const Vector &c, Vector &result, FUNC fun) { + ExecuteTernaryInternal(a, b, c, result, + CheckExecuteCount({&a, &b, &c}), fun); + } template > - static void ExecuteQuaternary(Vector &a, Vector &b, Vector &c, Vector &d, Vector &result, idx_t count, FUNC fun) { + static void ExecuteQuaternary(const Vector &a, const Vector &b, const Vector &c, const 
Vector &d, Vector &result, + idx_t count, FUNC fun) { ExecuteQuaternaryInternal(a, b, c, d, result, count, fun); } + template > + static void ExecuteQuaternary(const Vector &a, const Vector &b, const Vector &c, const Vector &d, Vector &result, + FUNC fun) { + ExecuteQuaternaryInternal( + a, b, c, d, result, CheckExecuteCount({&a, &b, &c, &d}), fun); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp index 2313a76f3..2a0bd95f0 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/ternary_executor.hpp @@ -17,24 +17,48 @@ namespace duckdb { struct TernaryExecutor { template > - static void Execute(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { + static void Execute(const Vector &a, const Vector &b, const Vector &c, Vector &result, FUN fun) { std::array inputs = {{a, b, c}}; - VariadicExecutor::Execute(inputs, result, count, fun); + VariadicExecutor::Execute(inputs, result, fun); } template - static void ExecuteStandard(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count) { + static void ExecuteStandard(const Vector &a, const Vector &b, const Vector &c, Vector &result) { std::array inputs = {{a, b, c}}; - VariadicExecutor::ExecuteStandard(inputs, result, count); + VariadicExecutor::ExecuteStandard(inputs, result); } template - static idx_t Select(Vector &a, Vector &b, Vector &c, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel) { + static idx_t Select(const Vector &a, const Vector &b, const Vector &c, optional_ptr sel, + idx_t count, optional_ptr true_sel, optional_ptr false_sel) { std::array inputs = {{a, b, c}}; return VariadicExecutor::Select(inputs, sel.get(), count, true_sel.get(), false_sel.get()); } + + 
//===--------------------------------------------------------------------===// + // Deprecated overloads (count parameter removed - use count-free versions) + //===--------------------------------------------------------------------===// + template > + [[deprecated("count parameter is deprecated; call Execute without count instead")]] static void + Execute(const Vector &a, const Vector &b, const Vector &c, Vector &result, idx_t count, FUN fun) { + if (count != a.size()) { + throw InternalException("TernaryExecutor::Execute: count (%llu) does not match vector size (%llu)", count, + a.size()); + } + Execute(a, b, c, result, fun); + } + + template + [[deprecated("count parameter is deprecated; call ExecuteStandard without count instead")]] static void + ExecuteStandard(const Vector &a, const Vector &b, const Vector &c, Vector &result, idx_t count) { + if (count != a.size()) { + throw InternalException("TernaryExecutor::ExecuteStandard: count (%llu) does not match vector size (%llu)", + count, a.size()); + } + ExecuteStandard(a, b, c, result); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp index 2179a6582..ce135d340 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/unary_executor.hpp @@ -145,12 +145,15 @@ struct UnaryExecutor { #endif template - static inline void ExecuteStandard(Vector &input, Vector &result, idx_t count, DATA_TYPE &data, bool adds_nulls, + static inline void ExecuteStandard(const Vector &input, Vector &result, idx_t count, DATA_TYPE &data, + bool adds_nulls, FunctionErrors errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR) { switch (input.GetVectorType()) { case VectorType::CONSTANT_VECTOR: { result.SetVectorType(VectorType::CONSTANT_VECTOR); - FlatVector::SetSize(result, count); + if (result.size() != count) { + 
FlatVector::SetSize(result, count); + } auto result_data = ConstantVector::GetData(result); auto ldata = ConstantVector::GetData(input); @@ -166,6 +169,9 @@ struct UnaryExecutor { #ifndef DUCKDB_SMALLER_BINARY case VectorType::FLAT_VECTOR: { result.SetVectorType(VectorType::FLAT_VECTOR); + if (result.size() != count) { + FlatVector::SetSize(result, count); + } auto result_data = FlatVector::GetDataMutable(result); auto ldata = FlatVector::GetData(input); @@ -208,6 +214,9 @@ struct UnaryExecutor { input.ToUnifiedFormat(vdata); result.SetVectorType(VectorType::FLAT_VECTOR); + if (result.size() != count) { + FlatVector::SetSize(result, count); + } auto result_data = FlatVector::GetDataMutable(result); auto ldata = UnifiedVectorFormat::GetData(vdata); @@ -220,13 +229,13 @@ struct UnaryExecutor { public: template - static void Execute(Vector &input, Vector &result, idx_t count) { + static void Execute(const Vector &input, Vector &result, idx_t count) { std::nullptr_t no_data = nullptr; ExecuteStandard(input, result, count, no_data, false); } template > - static void Execute(Vector &input, Vector &result, idx_t count, FUNC fun, + static void Execute(const Vector &input, Vector &result, idx_t count, FUNC fun, FunctionErrors errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR) { constexpr bool adds_nulls = std::is_same())), optional>::value; @@ -235,23 +244,46 @@ struct UnaryExecutor { } template - static void GenericExecute(Vector &input, Vector &result, idx_t count, DATA_TYPE &data, bool adds_nulls = false) { + static void GenericExecute(const Vector &input, Vector &result, idx_t count, DATA_TYPE &data, + bool adds_nulls = false) { ExecuteStandard(input, result, count, data, adds_nulls); } template - static void ExecuteString(Vector &input, Vector &result, idx_t count) { + static void ExecuteString(const Vector &input, Vector &result, idx_t count) { auto &heap = StringVector::GetStringHeap(result); UnaryExecutor::GenericExecute>(input, result, count, heap); } + //! 
Convenience overloads without explicit count - count is derived from input.size(). + template + static void Execute(const Vector &input, Vector &result) { + Execute(input, result, input.size()); + } + + template > + static void Execute(const Vector &input, Vector &result, FUNC fun, + FunctionErrors errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR) { + Execute(input, result, input.size(), fun, errors); + } + + template + static void GenericExecute(const Vector &input, Vector &result, DATA_TYPE &data, bool adds_nulls = false) { + GenericExecute(input, result, input.size(), data, adds_nulls); + } + + template + static void ExecuteString(const Vector &input, Vector &result) { + ExecuteString(input, result, input.size()); + } + private: // Select logic copied from TernaryExecutor, but with a lambda instead of a static functor template , bool NO_NULL, bool HAS_TRUE_SEL, bool HAS_FALSE_SEL> static inline idx_t SelectLoop(const INPUT_TYPE *__restrict input_data, const SelectionVector *result_sel, const idx_t count, FUNC fun, const SelectionVector &input_sel, - ValidityMask &input_validity, SelectionVector *true_sel, + const ValidityMask &input_validity, SelectionVector *true_sel, SelectionVector *false_sel) { idx_t true_count = 0, false_count = 0; for (idx_t i = 0; i < count; i++) { @@ -306,7 +338,7 @@ struct UnaryExecutor { public: template > - static idx_t Select(Vector &input, const SelectionVector *sel, const idx_t count, FUNC fun, + static idx_t Select(const Vector &input, const SelectionVector *sel, const idx_t count, FUNC fun, SelectionVector *true_sel, SelectionVector *false_sel) { if (!sel) { sel = FlatVector::IncrementalSelectionVector(); diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/variadic_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/variadic_executor.hpp index 742912aa5..7fb9d0e04 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/variadic_executor.hpp +++ 
b/src/duckdb/src/include/duckdb/common/vector_operations/variadic_executor.hpp @@ -56,12 +56,12 @@ struct VariadicStandardOperatorWrapper { //! a parameter pack must be last (or followed only by deducible params). //! The named executors (TernaryExecutor, etc.) provide backward-compatible APIs. struct VariadicExecutor { - using VectorRef = std::reference_wrapper; + using VectorRef = std::reference_wrapper; private: template static std::array MakeInputArrayImpl(DataChunk &input, std::index_sequence) { - return {{std::ref(input.data[Is])...}}; + return {{std::cref(input.data[Is])...}}; } template @@ -100,6 +100,7 @@ struct VariadicExecutor { if (AllConstant(inputs)) { result.SetVectorType(VectorType::CONSTANT_VECTOR); + FlatVector::SetSize(result, count); if (AnyConstantNull(inputs)) { ConstantVector::SetNull(result, true); } else { @@ -210,12 +211,32 @@ struct VariadicExecutor { } } +private: + //------------------------------------------------------------------- + // Verify all inputs have the same size and return that size. 
+ //------------------------------------------------------------------- + template + static idx_t CheckExecuteCount(const std::array &inputs) { + static_assert(N > 0, "VariadicExecutor requires at least one input"); + idx_t count = inputs[0].get().size(); + for (size_t i = 1; i < N; i++) { + if (inputs[i].get().size() != count) { + throw InternalException( + "Mismatch in input vector sizes for VariadicExecutor - expected %d rows but got %d", count, + inputs[i].get().size()); + } + } + return count; + } + public: //------------------------------------------------------------------- // Execute: lambda-based, with Vector array //------------------------------------------------------------------- template - static void Execute(std::array inputs, Vector &result, idx_t count, FUN fun) { + static void Execute(std::array inputs, Vector &result, FUN fun) { + constexpr size_t N = sizeof...(ARGS); + const idx_t count = CheckExecuteCount(inputs); ExecuteImplWithIndices(inputs, result, count, fun, std::index_sequence_for {}); } @@ -234,7 +255,9 @@ struct VariadicExecutor { // ExecuteStandard: static OP::Operation, with Vector array //------------------------------------------------------------------- template - static void ExecuteStandard(std::array inputs, Vector &result, idx_t count) { + static void ExecuteStandard(std::array inputs, Vector &result) { + constexpr size_t N = sizeof...(ARGS); + const idx_t count = CheckExecuteCount(inputs); bool dummy = false; ExecuteImplWithIndices, bool, ARGS...>( inputs, result, count, dummy, std::index_sequence_for {}); diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp index 50a807707..ff2bb4004 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/vector_operations.hpp @@ -11,8 +11,6 @@ #include 
"duckdb/common/types/data_chunk.hpp" #include "duckdb/common/types/vector.hpp" -#include - namespace duckdb { class CastFunctionSet; struct GetCastFunctionInput; @@ -29,121 +27,127 @@ struct VectorOperations { // In-Place Operators //===--------------------------------------------------------------------===// //! left += delta - static void AddInPlace(Vector &left, int64_t delta, idx_t count); + static void AddInPlace(Vector &left, int64_t delta); //===--------------------------------------------------------------------===// // NULL Operators //===--------------------------------------------------------------------===// //! result = IS NOT NULL(input) - static void IsNotNull(Vector &arg, Vector &result, idx_t count); + static void IsNotNull(const Vector &arg, Vector &result); //! result = IS NULL (input) - static void IsNull(Vector &input, Vector &result, idx_t count); + static void IsNull(const Vector &input, Vector &result); // Returns whether or not arg vector has a NULL value - static bool HasNull(Vector &input, idx_t count); - static bool HasNotNull(Vector &input, idx_t count); + static bool HasNull(const Vector &input); + static bool HasNotNull(const Vector &input); //! Count the number of not-NULL values. 
- static idx_t CountNotNull(Vector &input, const idx_t count); + static idx_t CountNotNull(const Vector &input); //===--------------------------------------------------------------------===// // Boolean Operations //===--------------------------------------------------------------------===// // result = left && right - static void And(Vector &left, Vector &right, Vector &result, idx_t count); + static void And(const Vector &left, const Vector &right, Vector &result); // result = left || right - static void Or(Vector &left, Vector &right, Vector &result, idx_t count); - // result = NOT(left) - static void Not(Vector &left, Vector &result, idx_t count); + static void Or(const Vector &left, const Vector &right, Vector &result); + // result = NOT(input) + static void Not(const Vector &input, Vector &result); //===--------------------------------------------------------------------===// // Comparison Operations //===--------------------------------------------------------------------===// // result = left == right - static void Equals(Vector &left, Vector &right, Vector &result, idx_t count); + static void Equals(const Vector &left, const Vector &right, Vector &result); // result = left != right - static void NotEquals(Vector &left, Vector &right, Vector &result, idx_t count); + static void NotEquals(const Vector &left, const Vector &right, Vector &result); // result = left > right - static void GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count); + static void GreaterThan(const Vector &left, const Vector &right, Vector &result); // result = left >= right - static void GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + static void GreaterThanEquals(const Vector &left, const Vector &right, Vector &result); // result = left < right - static void LessThan(Vector &left, Vector &right, Vector &result, idx_t count); + static void LessThan(const Vector &left, const Vector &right, Vector &result); // result = left <= right - static 
void LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + static void LessThanEquals(const Vector &left, const Vector &right, Vector &result); // result = -1 if left < right, 0 if left == right, 1 if left > right (stored in int8_t TINYINT result vector) - static void Comparator(Vector &left, Vector &right, Vector &result, idx_t count); + static void Comparator(const Vector &left, const Vector &right, Vector &result); + // result = -1 if left < right, 0 if left == right, 1 if left > right; fills exactly count rows (for select paths) + static void ComparatorFill(const Vector &left, const Vector &right, Vector &result, idx_t count); // result = A != B with nulls being equal - static void DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count); + static void DistinctFrom(const Vector &left, const Vector &right, Vector &result); // result := A == B with nulls being equal - static void NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count); - // result := A > B with nulls being maximal - static void DistinctGreaterThan(Vector &left, Vector &right, Vector &result, idx_t count); - // result := A >= B with nulls being maximal - static void DistinctGreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); - // result := A < B with nulls being maximal - static void DistinctLessThan(Vector &left, Vector &right, Vector &result, idx_t count); - // result := A <= B with nulls being maximal - static void DistinctLessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count); + static void NotDistinctFrom(const Vector &left, const Vector &right, Vector &result); // result := comparator(A, B) with nulls being maximal (NULLS LAST), returns -1/0/1 as int8_t - static void DistinctComparator(Vector &left, Vector &right, Vector &result, idx_t count); + static void DistinctComparator(const Vector &left, const Vector &right, Vector &result); // result := comparator(A, B) with nulls being minimal (NULLS 
FIRST), returns -1/0/1 as int8_t - static void DistinctComparatorNullsFirst(Vector &left, Vector &right, Vector &result, idx_t count); + static void DistinctComparatorNullsFirst(const Vector &left, const Vector &right, Vector &result); + // DistinctComparator/NullsFirst variants that fill exactly count rows (for select paths with SelectionVector) + static void DistinctComparatorFill(const Vector &left, const Vector &right, Vector &result, idx_t count); + static void DistinctComparatorNullsFirstFill(const Vector &left, const Vector &right, Vector &result, idx_t count); //===--------------------------------------------------------------------===// // Select Comparisons //===--------------------------------------------------------------------===// - static idx_t Equals(Vector &left, Vector &right, optional_ptr sel, idx_t count, + static idx_t Equals(const Vector &left, const Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); - static idx_t NotEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t NotEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); - static idx_t GreaterThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask = nullptr); - static idx_t GreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t GreaterThan(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask = nullptr); + static idx_t GreaterThanEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr 
null_mask = nullptr); - static idx_t LessThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, + static idx_t LessThan(const Vector &left, const Vector &right, optional_ptr sel, idx_t count, optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); - static idx_t LessThanEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t LessThanEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A != B with nulls being equal - static idx_t DistinctFrom(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel); + static idx_t DistinctFrom(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel); // true := A == B with nulls being equal - static idx_t NotDistinctFrom(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel); + static idx_t NotDistinctFrom(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel); // true := A > B with nulls being maximal - static idx_t DistinctGreaterThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t DistinctGreaterThan(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A >= B with nulls being maximal - static idx_t DistinctGreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, + static idx_t DistinctGreaterThanEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, 
optional_ptr null_mask = nullptr); // true := A < B with nulls being maximal - static idx_t DistinctLessThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t DistinctLessThan(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A <= B with nulls being maximal - static idx_t DistinctLessThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, + static idx_t DistinctLessThanEquals(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A > B with nulls being minimal - static idx_t DistinctGreaterThanNullsFirst(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, + static idx_t DistinctGreaterThanNullsFirst(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A < B with nulls being minimal - static idx_t DistinctLessThanNullsFirst(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, + static idx_t DistinctLessThanNullsFirst(const Vector &left, const Vector &right, + optional_ptr sel, idx_t count, + optional_ptr true_sel, optional_ptr false_sel, optional_ptr null_mask = nullptr); @@ -151,23 +155,30 @@ struct VectorOperations { // Nested Comparisons //===--------------------------------------------------------------------===// // true := A != B with nulls being equal - static idx_t NestedNotEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, + static idx_t NestedNotEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + 
optional_ptr false_sel, optional_ptr null_mask = nullptr); // true := A == B with nulls being equal - static idx_t NestedEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask = nullptr); + static idx_t NestedEquals(const Vector &left, const Vector &right, optional_ptr sel, + idx_t count, optional_ptr true_sel, + optional_ptr false_sel, optional_ptr null_mask = nullptr); //===--------------------------------------------------------------------===// // Hash functions //===--------------------------------------------------------------------===// // hashes = HASH(input) - static void Hash(Vector &input, Vector &hashes, idx_t count); - static void Hash(Vector &input, Vector &hashes, const SelectionVector &rsel, idx_t count); + // FIXME: this "count" needs to go + static void Hash(const Vector &input, Vector &hashes, idx_t count); + static void Hash(const Vector &input, Vector &hashes, const SelectionVector &rsel, idx_t count); + //! Convenience overload without explicit count - count is derived from input.size() + static void Hash(const Vector &input, Vector &hashes); // hashes ^= HASH(input) - static void CombineHash(Vector &hashes, Vector &input, idx_t count); - static void CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count); + // FIXME: this "count" needs to go + static void CombineHash(Vector &hashes, const Vector &input, idx_t count); + static void CombineHash(Vector &hashes, const Vector &input, const SelectionVector &rsel, idx_t count); + //! Convenience overload without explicit count - count is derived from the input vectors (with mismatch check) + static void CombineHash(Vector &hashes, const Vector &input); //===--------------------------------------------------------------------===// // Generate functions @@ -203,9 +214,167 @@ struct VectorOperations { // Copy the data of to the target location, setting null values to // NullValue. 
Used to store data without separate NULL mask. - static void WriteToStorage(Vector &source, idx_t count, data_ptr_t target); + static void WriteToStorage(const Vector &source, data_ptr_t target); // Reads the data of to the target vector, setting the nullmask // for any NullValue of source. Used to go back from storage to a proper vector static void ReadFromStorage(data_ptr_t source, idx_t count, Vector &result); + + //===--------------------------------------------------------------------===// + // Deprecated overloads (count parameter removed - use count-free versions) + //===--------------------------------------------------------------------===// + [[deprecated("count parameter is deprecated; call AddInPlace without count instead")]] static void + AddInPlace(Vector &left, int64_t delta, idx_t count) { + if (count != left.size()) { + throw InternalException("AddInPlace: count (%llu) does not match vector size (%llu)", count, left.size()); + } + AddInPlace(left, delta); + } + [[deprecated("count parameter is deprecated; call IsNotNull without count instead")]] static void + IsNotNull(const Vector &arg, Vector &result, idx_t count) { + if (count != arg.size()) { + throw InternalException("IsNotNull: count (%llu) does not match vector size (%llu)", count, arg.size()); + } + IsNotNull(arg, result); + } + [[deprecated("count parameter is deprecated; call IsNull without count instead")]] static void + IsNull(const Vector &input, Vector &result, idx_t count) { + if (count != input.size()) { + throw InternalException("IsNull: count (%llu) does not match vector size (%llu)", count, input.size()); + } + IsNull(input, result); + } + [[deprecated("count parameter is deprecated; call HasNull without count instead")]] static bool + HasNull(const Vector &input, idx_t count) { + if (count != input.size()) { + throw InternalException("HasNull: count (%llu) does not match vector size (%llu)", count, input.size()); + } + return HasNull(input); + } + [[deprecated("count parameter 
is deprecated; call HasNotNull without count instead")]] static bool + HasNotNull(const Vector &input, idx_t count) { + if (count != input.size()) { + throw InternalException("HasNotNull: count (%llu) does not match vector size (%llu)", count, input.size()); + } + return HasNotNull(input); + } + [[deprecated("count parameter is deprecated; call CountNotNull without count instead")]] static idx_t + CountNotNull(const Vector &input, const idx_t count) { + if (count != input.size()) { + throw InternalException("CountNotNull: count (%llu) does not match vector size (%llu)", count, + input.size()); + } + return CountNotNull(input); + } + [[deprecated("count parameter is deprecated; call And without count instead")]] static void + And(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("And: count (%llu) does not match vector size (%llu)", count, left.size()); + } + And(left, right, result); + } + [[deprecated("count parameter is deprecated; call Or without count instead")]] static void + Or(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("Or: count (%llu) does not match vector size (%llu)", count, left.size()); + } + Or(left, right, result); + } + [[deprecated("count parameter is deprecated; call Not without count instead")]] static void + Not(const Vector &input, Vector &result, idx_t count) { + if (count != input.size()) { + throw InternalException("Not: count (%llu) does not match vector size (%llu)", count, input.size()); + } + Not(input, result); + } + [[deprecated("count parameter is deprecated; call Equals without count instead")]] static void + Equals(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("Equals: count (%llu) does not match vector size (%llu)", count, left.size()); + } + Equals(left, right, result); + } + 
[[deprecated("count parameter is deprecated; call NotEquals without count instead")]] static void + NotEquals(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("NotEquals: count (%llu) does not match vector size (%llu)", count, left.size()); + } + NotEquals(left, right, result); + } + [[deprecated("count parameter is deprecated; call GreaterThan without count instead")]] static void + GreaterThan(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("GreaterThan: count (%llu) does not match vector size (%llu)", count, left.size()); + } + GreaterThan(left, right, result); + } + [[deprecated("count parameter is deprecated; call GreaterThanEquals without count instead")]] static void + GreaterThanEquals(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("GreaterThanEquals: count (%llu) does not match vector size (%llu)", count, + left.size()); + } + GreaterThanEquals(left, right, result); + } + [[deprecated("count parameter is deprecated; call LessThan without count instead")]] static void + LessThan(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("LessThan: count (%llu) does not match vector size (%llu)", count, left.size()); + } + LessThan(left, right, result); + } + [[deprecated("count parameter is deprecated; call LessThanEquals without count instead")]] static void + LessThanEquals(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("LessThanEquals: count (%llu) does not match vector size (%llu)", count, + left.size()); + } + LessThanEquals(left, right, result); + } + [[deprecated("count parameter is deprecated; call Comparator without count instead")]] static void + Comparator(const 
Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("Comparator: count (%llu) does not match vector size (%llu)", count, left.size()); + } + Comparator(left, right, result); + } + [[deprecated("count parameter is deprecated; call DistinctFrom without count instead")]] static void + DistinctFrom(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("DistinctFrom: count (%llu) does not match vector size (%llu)", count, left.size()); + } + DistinctFrom(left, right, result); + } + [[deprecated("count parameter is deprecated; call NotDistinctFrom without count instead")]] static void + NotDistinctFrom(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("NotDistinctFrom: count (%llu) does not match vector size (%llu)", count, + left.size()); + } + NotDistinctFrom(left, right, result); + } + [[deprecated("count parameter is deprecated; call DistinctComparator without count instead")]] static void + DistinctComparator(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("DistinctComparator: count (%llu) does not match vector size (%llu)", count, + left.size()); + } + DistinctComparator(left, right, result); + } + [[deprecated("count parameter is deprecated; call DistinctComparatorNullsFirst without count instead")]] static void + DistinctComparatorNullsFirst(const Vector &left, const Vector &right, Vector &result, idx_t count) { + if (count != left.size()) { + throw InternalException("DistinctComparatorNullsFirst: count (%llu) does not match vector size (%llu)", + count, left.size()); + } + DistinctComparatorNullsFirst(left, right, result); + } + [[deprecated("count parameter is deprecated; call WriteToStorage without count instead")]] static void + WriteToStorage(const Vector 
&source, idx_t count, data_ptr_t target) { + if (count != source.size()) { + throw InternalException("WriteToStorage: count (%llu) does not match vector size (%llu)", count, + source.size()); + } + WriteToStorage(source, target); + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp index 17b59af59..54ad61779 100644 --- a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp @@ -28,6 +28,10 @@ class VirtualFileSystem : public FileSystem { int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + unique_ptr MemoryMapFile(const OpenFileInfo &path, FileOpenFlags flags, + const MMapOptions &options, + optional_ptr opener = nullptr) override; + int64_t GetFileSize(FileHandle &handle) override; timestamp_t GetLastModifiedTime(FileHandle &handle) override; string GetVersionTag(FileHandle &handle) override; @@ -60,6 +64,8 @@ class VirtualFileSystem : public FileSystem { vector ListSubSystems() override; + FileSystem &GetDefaultFileSystem(); + std::string GetName() const override; void SetDisabledFileSystems(const vector &names) override; diff --git a/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp index 21a72c747..3dd4bee7d 100644 --- a/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/aggregate_hashtable.hpp @@ -15,6 +15,7 @@ #include "duckdb/storage/arena_allocator.hpp" #include "duckdb/common/row_operations/row_operations.hpp" #include "duckdb/common/types/hyperloglog.hpp" +#include "duckdb/common/clustered_aggregate.hpp" namespace duckdb { @@ -125,6 +126,8 @@ class GroupedAggregateHashTable : public BaseAggregateHashTable { //! 
Executes the filter(if any) and update the aggregates void Combine(GroupedAggregateHashTable &other); void Combine(TupleDataCollection &other_data, optional_ptr> progress = nullptr); + //! Reset the HT for a new execution while reusing internal allocations where possible + void ResetForNewIteration(idx_t initial_capacity, idx_t radix_bits); private: ClientContext &context; @@ -143,6 +146,17 @@ class GroupedAggregateHashTable : public BaseAggregateHashTable { unique_ptr dictionary_addresses; unsafe_unique_array found_entry; idx_t capacity = 0; + //! Cumulative count of dict slots added to found_entry under the current id; the + //! unique-entries walk is skipped once it reaches dict_size. + idx_t resolved_count = 0; + //! For the clustered-aggregate fast path: track whether every state pointer added to this + //! dictionary shares the same high (64 - SLOT_GID_BITS) bits. The clustered path stores only + //! the low SLOT_GID_BITS of each pointer, so if any two pointers differ in their high bits + //! we cannot reuse the truncated form to disambiguate them and must fall back to scatter. + //! address_high_bits starts at the sentinel ~0 ("no address seen yet"); a real high-bits + //! value always has its low SLOT_GID_BITS cleared so ~0 cannot collide with a valid value. + uint64_t address_high_bits = ~uint64_t(0); + bool address_high_bits_uniform = true; }; //! If we have this many or more radix bits, we use the unpartitioned data collection too @@ -202,6 +216,8 @@ class GroupedAggregateHashTable : public BaseAggregateHashTable { RowOperationsState row_state; } state; + ClusteredAggrState clustered_state; + private: //! Disabled the copy constructor GroupedAggregateHashTable(const GroupedAggregateHashTable &) = delete; @@ -218,7 +234,10 @@ class GroupedAggregateHashTable : public BaseAggregateHashTable { //! 
Reinserts tuples (triggered by Resize) void ReinsertTuples(PartitionedTupleData &data); - void UpdateAggregates(DataChunk &payload, const unsafe_vector &filter); + void UpdateAggregates(DataChunk &payload, const unsafe_vector &filter, idx_t count, + bool ht_offsets_valid = true); + bool UpdateAggregatesClustered(DataChunk &payload, const unsafe_vector &filter, idx_t count, + bool ht_offsets_valid); //! Does the actual group matching / creation idx_t FindOrCreateGroupsInternal(DataChunk &groups, Vector &group_hashes, Vector &addresses, diff --git a/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp index ce59051fa..dd0102a21 100644 --- a/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/base_aggregate_hashtable.hpp @@ -24,6 +24,9 @@ class BaseAggregateHashTable { } protected: + static bool AllAggregatesClustered(const vector &aggregates); + static idx_t CountAggregatesClustered(const vector &aggregates); + Allocator &allocator; BufferManager &buffer_manager; //! A helper for managing offsets into the data buffers diff --git a/src/duckdb/src/include/duckdb/execution/executor.hpp b/src/duckdb/src/include/duckdb/execution/executor.hpp index 46315a989..b91005e8d 100644 --- a/src/duckdb/src/include/duckdb/execution/executor.hpp +++ b/src/duckdb/src/include/duckdb/execution/executor.hpp @@ -75,7 +75,7 @@ class Executor { void ThrowException(); //! Work on tasks for this specific executor, until there are no tasks remaining - void WorkOnTasks(); + bool WorkOnTasks(); //! 
Flush a thread context into the client context void Flush(ThreadContext &context); diff --git a/src/duckdb/src/include/duckdb/execution/expression_executor.hpp b/src/duckdb/src/include/duckdb/execution/expression_executor.hpp index 52b7ed4d7..bbd13de56 100644 --- a/src/duckdb/src/include/duckdb/execution/expression_executor.hpp +++ b/src/duckdb/src/include/duckdb/execution/expression_executor.hpp @@ -146,7 +146,7 @@ class ExpressionExecutor { //! Verify that the output of a step in the ExpressionExecutor is correct void Verify(const Expression &expr, Vector &result, idx_t count); - void FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count); + void FillSwitch(const Vector &vector, Vector &result, const SelectionVector &sel, sel_t count); private: //! Client context diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp index 8fbed71bc..981e715f2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -47,7 +47,7 @@ class ART : public BoundIndex { static constexpr uint8_t DEPRECATED_ALLOCATOR_COUNT = ALLOCATOR_COUNT - 3; public: - ART(const string &name, const IndexConstraintType index_constraint_type, const vector &column_ids, + ART(const Identifier &name, const IndexConstraintType index_constraint_type, const vector &column_ids, TableIOManager &table_io_manager, const vector> &unbound_expressions, AttachedDatabase &db, const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr = nullptr, @@ -69,7 +69,7 @@ class ART : public BoundIndex { //! True, if the ART owns its data. bool owns_data; //! Storage version that the ART was created in, used for backwards compatible key generation - optional_idx storage_version; + StorageVersion storage_version; public: //! Try to initialize a scan on the ART with the given expression and filter. @@ -146,8 +146,8 @@ class ART : public BoundIndex { //! 
ART key generation. template void GenerateKeys(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys); - void GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, - unsafe_vector &row_id_keys); + void GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, const Vector &row_ids, + unsafe_vector &keys, unsafe_vector &row_id_keys); //! Verifies the nodes. void Verify(IndexLock &l) override; diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp index d6190d737..87bf60fa3 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_key.hpp @@ -14,6 +14,7 @@ #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/storage/storage_info.hpp" namespace duckdb { @@ -56,7 +57,7 @@ class ARTKey { return ARTKey(new_data, len); } - static ARTKey CreateKey(ArenaAllocator &allocator, Value &value, optional_idx storage_version); + static ARTKey CreateKey(ArenaAllocator &allocator, Value &value, StorageVersion storage_version); public: data_t &operator[](idx_t i) { diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp index 7d66df318..d882e3ba7 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp @@ -321,8 +321,9 @@ class ARTOperator { const auto byte = current_key.get()[depth]; if (current.get().HasByte(art, byte)) { Node::DeleteChild(art, current, parent, byte, status, row_id); + return true; } - return true; + return false; } } } diff --git a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp index 
dba80ce62..4cc25c0c2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp @@ -59,7 +59,7 @@ enum class DeltaIndexType { //! The index is an abstract base class that serves as the basis for indexes class BoundIndex : public Index { public: - BoundIndex(const string &name, const string &index_type, IndexConstraintType index_constraint_type, + BoundIndex(const Identifier &name, const string &index_type, IndexConstraintType index_constraint_type, const vector &column_ids, TableIOManager &table_io_manager, const vector> &unbound_expressions, AttachedDatabase &db); @@ -69,7 +69,7 @@ class BoundIndex : public Index { vector logical_types; //! The name of the index - string name; + Identifier name; //! The index type (ART, B+-tree, Skip-List, ...) string index_type; //! The index constraint type @@ -97,7 +97,7 @@ class BoundIndex : public Index { const string &GetIndexType() const override { return index_type; } - const string &GetIndexName() const override { + const Identifier &GetIndexName() const override { return name; } IndexConstraintType GetConstraintType() const override { diff --git a/src/duckdb/src/include/duckdb/execution/index/index_type.hpp b/src/duckdb/src/include/duckdb/execution/index/index_type.hpp index 8d0768758..fd7c81edf 100644 --- a/src/duckdb/src/include/duckdb/execution/index/index_type.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/index_type.hpp @@ -33,14 +33,14 @@ struct CreateIndexInput { TableIOManager &table_io_manager; AttachedDatabase &db; IndexConstraintType constraint_type; - const string &name; + const Identifier &name; const vector &column_ids; const vector> &unbound_expressions; const IndexStorageInfo &storage_info; const case_insensitive_map_t &options; CreateIndexInput(ClientContext &context, TableIOManager &table_io_manager, AttachedDatabase &db, - IndexConstraintType constraint_type, const string &name, const vector &column_ids, + 
IndexConstraintType constraint_type, const Identifier &name, const vector &column_ids, const vector> &unbound_expressions, const IndexStorageInfo &storage_info, const case_insensitive_map_t &options) : context(context), table_io_manager(table_io_manager), db(db), constraint_type(constraint_type), name(name), diff --git a/src/duckdb/src/include/duckdb/execution/index/index_type_set.hpp b/src/duckdb/src/include/duckdb/execution/index/index_type_set.hpp index 201e3b90c..a535f4a2d 100644 --- a/src/duckdb/src/include/duckdb/execution/index/index_type_set.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/index_type_set.hpp @@ -18,7 +18,7 @@ namespace duckdb { class IndexTypeSet { mutex lock; - case_insensitive_tree_t functions; + identifier_tree_t functions; public: IndexTypeSet(); diff --git a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp index 0173b4e3f..09928e893 100644 --- a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp @@ -88,7 +88,7 @@ class UnboundIndex final : public Index { const string &GetIndexType() const override { return GetCreateInfo().index_type; } - const string &GetIndexName() const override { + const Identifier &GetIndexName() const override { return GetCreateInfo().index_name; } IndexConstraintType GetConstraintType() const override { @@ -103,7 +103,7 @@ class UnboundIndex final : public Index { const vector> &GetParsedExpressions() const { return GetCreateInfo().parsed_expressions; } - const string &GetTableName() const { + const Identifier &GetTableName() const { return GetCreateInfo().table; } diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp index 75ea9055d..1f2ba57fb 100644 --- a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/helper.hpp" #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/types/column/column_data_consumer.hpp" #include "duckdb/common/types/column/partitioned_column_data.hpp" @@ -19,8 +20,7 @@ #include "duckdb/common/unique_ptr.hpp" #include "duckdb/execution/aggregate_hashtable.hpp" #include "duckdb/execution/ht_entry.hpp" -#include "duckdb/planner/filter/bloom_filter.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { @@ -117,6 +117,7 @@ class JoinHashTable { SelectionVector last_sel_vector; explicit ScanStructure(JoinHashTable &ht, TupleDataChunkState &key_state); + void Reset(); //! Get the next batch of data from the scan structure void Next(DataChunk &keys, DataChunk &probe_data, DataChunk &result); //! Are pointer chains all pointing to NULL? @@ -148,7 +149,7 @@ class JoinHashTable { //! Apply residual predicate filtering idx_t ApplyResidualPredicate(DataChunk &probe_data, SelectionVector &match_sel, idx_t match_count, - SelectionVector *no_match_sel, idx_t no_match_offset = 0); + optional_ptr no_match_sel, idx_t no_match_offset = 0); private: unique_ptr residual_executor; @@ -162,7 +163,7 @@ class JoinHashTable { void GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count, const idx_t col_idx); void GatherResult(Vector &result, const idx_t count, const idx_t col_idx); idx_t ResolvePredicates(DataChunk &keys, DataChunk &probe_data, SelectionVector &match_sel, - SelectionVector *no_match_sel); + optional_ptr no_match_sel); }; public: @@ -175,12 +176,34 @@ class JoinHashTable { SelectionVector keys_no_match_sel; }; + //! Mirrors GroupedAggregateHashTable::AggregateDictionaryState for the join probe path + struct ProbeDictionaryState { + ProbeDictionaryState(); + + //! 
The current dictionary vector id (if any) + string dictionary_id; + DataChunk unique_values; + TupleDataChunkState unique_key_state; + Vector hashes; + Vector new_dictionary_pointers; + SelectionVector unique_entries; + //! Per-slot head-of-chain pointer cache; nullptr marks a miss + unique_ptr dictionary_pointers; + unsafe_unique_array found_entry; + idx_t capacity = 0; + SelectionVector match_sel; + //! Number of dict slots already resolved; the unique-entries walk is skipped once it reaches dict_size + idx_t resolved_count = 0; + }; + struct ProbeState : SharedState { ProbeState(); Vector ht_offsets_and_salts_v; Vector hashes_dense_v; SelectionVector non_empty_sel; + //! Allocated only when the operator gates the compressed-probe paths on; null otherwise + unique_ptr dict_state; }; struct InsertState : SharedState { @@ -376,24 +399,39 @@ class JoinHashTable { private: void InitializeScanStructure(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, - const SelectionVector *¤t_sel); + optional_ptr ¤t_sel); void Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes); + //! Dictionary-aware variant of Probe. Returns false if the LHS keys are not dictionary-eligible. + bool TryProbeDictionary(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state); + //! Constant-vector variant of Probe. Returns false if the LHS keys are not a constant vector. + bool TryProbeConstant(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, + ProbeState &probe_state); + bool UseSalt() const; //! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the //! 
key_match_sel vector and the count argument to the number and position of the matches void GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, - const SelectionVector *sel, idx_t &count, Vector &pointers_result_v, SelectionVector &match_sel, - bool has_sel); + optional_ptr sel, idx_t &count, Vector &pointers_result_v, + SelectionVector &match_sel, bool has_sel); private: //! Insert the given set of locations into the HT with the given set of hashes_v - void InsertHashes(Vector &hashes_v, idx_t count, TupleDataChunkState &chunk_state, InsertState &insert_statebool, - bool parallel); + void InsertHashes(Vector &hashes_v, TupleDataChunkState &chunk_state, InsertState &insert_state, bool parallel); //! Prepares keys by filtering NULLs - idx_t PrepareKeys(DataChunk &keys, vector &vector_data, const SelectionVector *¤t_sel, - SelectionVector &sel, bool build_side); + idx_t PrepareKeys(DataChunk &keys, vector &vector_data, + optional_ptr ¤t_sel, SelectionVector &sel, bool build_side); + + unsafe_optional_ptr GetEntries() { + D_ASSERT(hash_map.get()); + return reinterpret_cast(hash_map.get()); + } + unsafe_optional_ptr> GetAtomicEntries() { + D_ASSERT(hash_map.get()); + return reinterpret_cast *>(hash_map.get()); + } //! Lock for combining data_collection when merging HTs mutex data_lock; @@ -404,7 +442,6 @@ class JoinHashTable { //! The hash map of the HT, created after finalization AllocatedData hash_map; - ht_entry_t *entries = nullptr; //! Whether or not NULL values are considered equal in each of the comparisons vector null_values_are_equal; //! An empty tuple that's a "dead end", can be used to stop chains early @@ -413,6 +450,7 @@ class JoinHashTable { //! 
Whether or not to use a bloom filter will be determined by the operator BloomFilter bloom_filter; bool should_build_bloom_filter = false; + idx_t bloom_filter_init_count = 0; unique_ptr prefix_range_filter; bool should_build_prefix_range_filter = false; @@ -497,6 +535,8 @@ class JoinHashTable { void SetBuildBloomFilter(const bool should_build) { this->should_build_bloom_filter = should_build; } + void PrepareBuildBloomFilter(idx_t estimated_row_count); + void PrepareBloomFilterForFinalize(); BloomFilter &GetBloomFilter() { return bloom_filter; @@ -523,7 +563,7 @@ class JoinHashTable { void MergePrefixRangeBuildState(PrefixRangeFilter::BuildState &state); //! Get total size of HT if all partitions would be built - idx_t GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, + idx_t GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, idx_t &max_partition_count) const; idx_t GetTotalSize(const vector &partition_sizes, const vector &partition_counts, idx_t &max_partition_size, idx_t &max_partition_count) const; @@ -545,6 +585,10 @@ class JoinHashTable { //! Delete blocks that belong to the current partitioned HT void Reset(); + //! Collapses the sink collection to a single partition. + //! Used by recursive CTEs to avoid per-iteration overhead of managing many radix partitions + //! when only one thread is building the hash table. + void ResetForNewIterationSinglePartition(); //! Build HT for the next partitioned probe round bool PrepareExternalFinalize(const idx_t max_ht_size); //! 
Probe whatever we can, sink the rest into a thread-local HT diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp index b729df36e..3fd137b63 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/aggregate_object.hpp @@ -25,7 +25,8 @@ struct FunctionDataWrapper { struct AggregateObject { // NOLINT: work-around bug in clang-tidy AggregateObject(BoundAggregateFunction function, FunctionData *bind_data, idx_t child_count, idx_t payload_size, - AggregateType aggr_type, PhysicalType return_type, Expression *filter = nullptr); + AggregateType aggr_type, PhysicalType return_type, optional_ptr filter = nullptr); + explicit AggregateObject(BoundAggregateExpression &aggr); explicit AggregateObject(BoundAggregateExpression *aggr); explicit AggregateObject(const BoundWindowExpression &window); @@ -39,7 +40,7 @@ struct AggregateObject { // NOLINT: work-around bug in clang-tidy idx_t payload_size; AggregateType aggr_type; PhysicalType return_type; - Expression *filter = nullptr; + optional_ptr filter = nullptr; public: bool IsDistinct() const { @@ -49,7 +50,8 @@ struct AggregateObject { // NOLINT: work-around bug in clang-tidy }; struct AggregateFilterData { - AggregateFilterData(ClientContext &context, Expression &filter_expr, const vector &payload_types); + AggregateFilterData(ClientContext &context, const Expression &filter_expr, + const vector &payload_types); idx_t ApplyFilter(DataChunk &payload); diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp index f5bc4128e..4ef6809e3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp @@ -88,7 +88,7 @@ class PhysicalHashAggregate : public PhysicalOperator { unsafe_vector non_distinct_filter; unsafe_vector distinct_filter; - unordered_map filter_indexes; + reference_map_t filter_indexes; public: // Source interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp index 939f743b3..1783ed277 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp @@ -81,7 +81,7 @@ class PhysicalPerfectHashAggregate : public PhysicalOperator { //! The number of bits we need to completely cover each of the groups vector required_bits; - unordered_map filter_indexes; + reference_map_t filter_indexes; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp index dfca3fa5e..74c3c8a81 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_streaming_window.hpp @@ -18,7 +18,7 @@ class PhysicalStreamingWindow : public PhysicalOperator { public: static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::STREAMING_WINDOW; - static bool IsStreamingFunction(ClientContext &context, unique_ptr &expr); + static bool IsStreamingFunction(ClientContext &context, BoundWindowExpression &wexpr); public: PhysicalStreamingWindow(PhysicalPlan &physical_plan, vector types, diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp 
b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp index c35fe02c0..850ce1ea3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/physical_window.hpp @@ -21,7 +21,9 @@ class PhysicalWindow : public PhysicalOperator { public: PhysicalWindow(PhysicalPlan &physical_plan, vector types, vector> select_list, - idx_t estimated_cardinality, PhysicalOperatorType type = PhysicalOperatorType::WINDOW); + idx_t estimated_cardinality, vector partitions); + PhysicalWindow(PhysicalPlan &physical_plan, vector types, vector> select_list, + idx_t estimated_cardinality); //! The projection list of the WINDOW statement (may contain aggregates) vector> select_list; @@ -30,6 +32,8 @@ class PhysicalWindow : public PhysicalOperator { //! Whether or not the window is order dependent (only true if ANY window function contains neither an order nor a //! partition clause) bool is_order_dependent; + //! 
The partitions over which this is grouped (if any) + OperatorPartitionInfo partition_info; public: // Source interface @@ -56,6 +60,7 @@ class PhysicalWindow : public PhysicalOperator { public: // Sink interface + SinkNextBatchType NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &input) const override; SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, @@ -66,6 +71,8 @@ class PhysicalWindow : public PhysicalOperator { unique_ptr GetLocalSinkState(ExecutionContext &context) const override; unique_ptr GetGlobalSinkState(ClientContext &context) const override; + OperatorPartitionInfo RequiredPartitionInfo() const override; + bool IsSink() const override { return true; } diff --git a/src/duckdb/src/include/duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp b/src/duckdb/src/include/duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp index cf1b13c55..8349882cc 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp @@ -19,21 +19,25 @@ struct LocalUngroupedAggregateState; struct UngroupedAggregateState { explicit UngroupedAggregateState(const vector> &aggregate_expressions); + UngroupedAggregateState(const UngroupedAggregateState &global_state); ~UngroupedAggregateState(); void Move(UngroupedAggregateState &other); public: - //! Aggregates - const vector> &aggregate_expressions; //! The aggregate values vector> aggregate_data; //! The bind data - vector> bind_data; - //! The destructors - vector destructors; + vector> bind_data; + //! Copies of aggregate functions (owns the callbacks and state size) + vector functions; + //! 
Aggregate types (distinct vs non-distinct) + vector aggregate_types; + //! Number of arguments per aggregate (children count) + vector argument_counts; //! Counts (used for verification) - unique_array> counts; + //! Note: these are either thread-local, or only modified while holding the global state's lock + unique_array counts; }; struct GlobalUngroupedAggregateState { @@ -68,9 +72,11 @@ struct LocalUngroupedAggregateState { ArenaAllocator &allocator; //! The local aggregate state UngroupedAggregateState state; + //! Reusable flat state-pointer vector for generic update callbacks + Vector repeated_state_vector; public: - void Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx); + void Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx, idx_t count); }; struct UngroupedAggregateExecuteState { diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp index aad90df94..597bcb169 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp @@ -55,7 +55,7 @@ class CSVError { const string &fixes, const String ¤t_path); CSVError(string error_message, CSVErrorType type, LinesPerBoundary error_info); //! Produces error messages for column name -> type mismatch. - static CSVError ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names); + static CSVError ColumnTypesError(identifier_map_t sql_types_per_column, const vector &names); //! 
Produces error messages for casting errors static CSVError CastError(const CSVReaderOptions &options, const string &column_name, string &cast_error, idx_t column_idx, string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp index 5446739ad..801c8ba08 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp @@ -34,7 +34,7 @@ class CSVFileScan : public BaseFileReader { //! Constructor for new CSV Files, we must initialize the buffer manager and the state machine //! Path to this file CSVFileScan(ClientContext &context, const OpenFileInfo &file, CSVReaderOptions options, - const MultiFileOptions &file_options, const vector &names, const vector &types, + const MultiFileOptions &file_options, const vector &names, const vector &types, CSVSchema &file_schema, bool per_file_single_threaded, shared_ptr buffer_manager = nullptr, bool fixed_schema = false); @@ -43,7 +43,7 @@ class CSVFileScan : public BaseFileReader { public: void SetStart(); - void SetNamesAndTypes(const vector &names, const vector &types); + void SetNamesAndTypes(const vector &names, const vector &types); public: string GetReaderType() const override { diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_multi_file_info.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_multi_file_info.hpp index 50d046fb0..1a5d86a67 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_multi_file_info.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_multi_file_info.hpp @@ -27,7 +27,7 @@ class CSVFileReaderOptions : public BaseFileReaderOptions { struct CSVSchemaDiscovery { static CSVSchema 
SchemaDiscovery(ClientContext &context, shared_ptr &buffer_manager, CSVReaderOptions &options, const MultiFileOptions &file_options, - vector &return_types, vector &names, + vector &return_types, vector &names, MultiFileList &multi_file_list); }; @@ -45,7 +45,7 @@ struct CSVMultiFileInfo : MultiFileReaderInterface { const vector &expected_types) override; unique_ptr InitializeBindData(MultiFileBindData &multi_file_data, unique_ptr options) override; - void BindReader(ClientContext &context, vector &return_types, vector &names, + void BindReader(ClientContext &context, vector &return_types, vector &names, MultiFileBindData &bind_data) override; void FinalizeBindData(MultiFileBindData &multi_file_data) override; optional_idx MaxThreads(const MultiFileBindData &bind_data_p, const MultiFileGlobalState &global_state, diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp index 744710bef..b53b14ebb 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_reader_options.hpp @@ -54,6 +54,8 @@ struct CSVReaderOptions { CSVOption rejects_scan_name = {"reject_scans"}; //! Rejects table entry limit (0 = no limit) idx_t rejects_limit = 0; + //! Rejects line size limit (0 = no limit) + idx_t rejects_line_size_limit = 10000; //! Number of samples to buffer idx_t buffer_sample_size = static_cast(STANDARD_VECTOR_SIZE * 50); //! Specifies the strings that represents a null value @@ -72,7 +74,7 @@ struct CSVReaderOptions { // CSVAutoOptions //===--------------------------------------------------------------------===// //! SQL Type list mapping of name to SQL type index in sql_type_list - case_insensitive_map_t sql_types_per_column; + identifier_map_t sql_types_per_column; //! User-defined SQL type list vector sql_type_list; //! 
User-defined name list @@ -80,10 +82,10 @@ struct CSVReaderOptions { //! If the names and types were set by the columns parameter bool columns_set = false; //! Types considered as candidates for auto-detection ordered by ascending specificity (~ from low to high) - vector auto_type_candidates = { - LogicalType::VARCHAR, LogicalType::DOUBLE, LogicalType::BIGINT, - LogicalType::TIMESTAMP_TZ, LogicalType::TIMESTAMP, LogicalType::DATE, - LogicalType::TIME, LogicalType::BOOLEAN, LogicalType::SQLNULL}; + vector auto_type_candidates = {LogicalType::VARCHAR, LogicalType::DOUBLE, LogicalType::BIGNUM, + LogicalType::HUGEINT, LogicalType::BIGINT, LogicalType::TIMESTAMP_TZ, + LogicalType::TIMESTAMP, LogicalType::DATE, LogicalType::TIME, + LogicalType::BOOLEAN, LogicalType::SQLNULL}; //! In case the sniffer found a mismatch error from user defined types or dialect string sniffer_user_mismatch_error; //! In case the sniffer found a mismatch error from user defined types or dialect diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_schema.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_schema.hpp index 7cb93f485..a7c40001e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_schema.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_schema.hpp @@ -14,9 +14,9 @@ namespace duckdb { //! 
Basic CSV Column Info struct CSVColumnInfo { - CSVColumnInfo(const string &name_p, const LogicalType &type_p) : name(name_p), type(type_p) { + CSVColumnInfo(const Identifier &name_p, const LogicalType &type_p) : name(name_p), type(type_p) { } - string name; + Identifier name; LogicalType type; }; @@ -25,11 +25,11 @@ struct CSVSchema { explicit CSVSchema(const bool empty = false) : empty(empty) { } - CSVSchema(const vector &names, const vector &types, const string &file_path, idx_t rows_read, - const bool empty = false); + CSVSchema(const vector &names, const vector &types, const string &file_path, + idx_t rows_read, const bool empty = false); //! Initializes the schema based on names and types - void Initialize(const vector &names, const vector &types, const string &file_path); + void Initialize(const vector &names, const vector &types, const string &file_path); //! If the schema is empty bool Empty() const; @@ -66,7 +66,7 @@ struct CSVSchema { //! If a type can be cast to another static bool CanWeCastIt(LogicalTypeId source, LogicalTypeId destination); vector columns; - unordered_map name_idx_map; + identifier_map_t name_idx_map; string file_path; idx_t rows_read = 0; bool empty = false; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp index 8ab081ae3..09b44ee66 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp @@ -102,6 +102,7 @@ class CSVSniffer { static bool CanYouCastIt(ClientContext &context, const string_t value, const LogicalType &type, const DialectOptions &dialect_options, const bool is_null, const char decimal_separator, const char thousands_separator); + static bool CanYouCastBignum(const char *value_ptr, idx_t value_size); idx_t LinesSniffed() const; diff --git 
a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/sniff_result.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/sniff_result.hpp index 2a8e0de19..1c41b05b4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/sniff_result.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/sniffer/sniff_result.hpp @@ -15,17 +15,17 @@ namespace duckdb { //! Struct to store the result of the Sniffer struct SnifferResult { - SnifferResult(vector return_types_p, vector names_p) + SnifferResult(vector return_types_p, vector names_p) : return_types(std::move(return_types_p)), names(std::move(names_p)) { } //! Return Types that were detected vector return_types; //! Column Names that were detected - vector names; + vector names; }; struct AdaptiveSnifferResult : SnifferResult { - AdaptiveSnifferResult(vector return_types_p, vector names_p, bool more_than_one_row_p) + AdaptiveSnifferResult(vector return_types_p, vector names_p, bool more_than_one_row_p) : SnifferResult(std::move(return_types_p), std::move(names_p)), more_than_one_row(more_than_one_row_p) { } bool more_than_one_row; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp index 32ce543e7..ab4216dc7 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/string_value_scanner.hpp @@ -180,8 +180,8 @@ class StringValueResult : public ScannerResult { StringValueResult(CSVStates &states, CSVStateMachine &state_machine, const shared_ptr &buffer_handle, Allocator &buffer_allocator, idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_handler, CSVIterator &iterator, - bool store_line_size, shared_ptr csv_file_scan, idx_t &lines_read, bool sniffing, - const string &path, idx_t scan_id, bool 
&used_unstrictness); + shared_ptr csv_file_scan, idx_t &lines_read, bool sniffing, const string &path, + idx_t scan_id, bool &used_unstrictness); ~StringValueResult(); @@ -218,7 +218,6 @@ class StringValueResult : public ScannerResult { FullLinePosition current_line_position; //! Used for CSV line reconstruction on flushed errors unordered_map line_positions_per_row; - bool store_line_size = false; bool added_last_line = false; bool quoted_new_line = false; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_connect.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_connect.hpp new file mode 100644 index 000000000..6d0a68a9e --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_connect.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_connect.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/connect_info.hpp" + +namespace duckdb { + +//! PhysicalConnect routes subsequent SQL on this ClientContext to the named AttachedDatabase via +//! Catalog::ExecuteSQL. The case-insensitive name "LOCAL" unbinds the routing. 
+class PhysicalConnect : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CONNECT; + +public: + PhysicalConnect(PhysicalPlan &physical_plan, unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(physical_plan, PhysicalOperatorType::CONNECT, {LogicalType::BOOLEAN}, estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_disconnect.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_disconnect.hpp new file mode 100644 index 000000000..33af06117 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_disconnect.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/operator/helper/physical_disconnect.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/execution/physical_operator.hpp" +#include "duckdb/parser/parsed_data/disconnect_info.hpp" + +namespace duckdb { + +//! PhysicalDisconnect clears any active CONNECT binding on the ClientContext. 
+class PhysicalDisconnect : public PhysicalOperator { +public: + static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::DISCONNECT; + +public: + PhysicalDisconnect(PhysicalPlan &physical_plan, unique_ptr info, idx_t estimated_cardinality) + : PhysicalOperator(physical_plan, PhysicalOperatorType::DISCONNECT, {LogicalType::BOOLEAN}, + estimated_cardinality), + info(std::move(info)) { + } + + unique_ptr info; + +public: + SourceResultType GetDataInternal(ExecutionContext &context, DataChunk &chunk, + OperatorSourceInput &input) const override; + + bool IsSource() const override { + return true; + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp index 5f3e6cb23..d55260c43 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_prepare.hpp @@ -20,10 +20,10 @@ class PhysicalPrepare : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::PREPARE; public: - PhysicalPrepare(PhysicalPlan &physical_plan, const std::string &name_p, shared_ptr prepared, + PhysicalPrepare(PhysicalPlan &physical_plan, const Identifier &name_p, shared_ptr prepared, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::PREPARE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(physical_plan.ArenaRef().MakeString(name_p)), prepared(std::move(prepared)) { + name(physical_plan.ArenaRef().MakeString(name_p.GetIdentifierName())), prepared(std::move(prepared)) { } String name; diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp index aaf593db8..dbcfbf63c 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_reset.hpp @@ -24,9 +24,9 @@ class PhysicalReset : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::RESET; public: - PhysicalReset(PhysicalPlan &physical_plan, const std::string &name_p, SetScope scope_p, idx_t estimated_cardinality) + PhysicalReset(PhysicalPlan &physical_plan, const Identifier &name_p, SetScope scope_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::RESET, {LogicalType::BOOLEAN}, estimated_cardinality), - name(physical_plan.ArenaRef().MakeString(name_p)), scope(scope_p) { + name(physical_plan.ArenaRef().MakeString(name_p.GetIdentifierName())), scope(scope_p) { } public: diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp index c1114b091..35cc31c8e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_result_collector.hpp @@ -29,7 +29,7 @@ class PhysicalResultCollector : public PhysicalOperator { StatementProperties properties; QueryResultMemoryType memory_type; PhysicalOperator &plan; - vector names; + vector names; public: static unique_ptr GetResultCollector(ClientContext &context, PreparedStatementData &data); diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp index 56da7a90e..3ade08da6 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp @@ -25,10 +25,11 @@ class PhysicalSet : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::SET; public: - PhysicalSet(PhysicalPlan &physical_plan, 
const string &name_p, Value value_p, SetScope scope_p, + PhysicalSet(PhysicalPlan &physical_plan, const Identifier &name_p, Value value_p, SetScope scope_p, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::SET, {LogicalType::BOOLEAN}, estimated_cardinality), - name(physical_plan.ArenaRef().MakeString(name_p)), value(std::move(value_p)), scope(scope_p) { + name(physical_plan.ArenaRef().MakeString(name_p.GetIdentifierName())), value(std::move(value_p)), + scope(scope_p) { } public: diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp index 2e1a0252f..2859f6aa6 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set_variable.hpp @@ -18,7 +18,7 @@ class PhysicalSetVariable : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::SET_VARIABLE; public: - PhysicalSetVariable(PhysicalPlan &physical_plan, const string &name_p, idx_t estimated_cardinality); + PhysicalSetVariable(PhysicalPlan &physical_plan, const Identifier &name_p, idx_t estimated_cardinality); public: // Source interface diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp index 76723233a..dcb6b9dd3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/projection_index.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/planner/column_binding.hpp" @@ -24,11 +25,23 @@ class PerfectHashJoinExecutor; struct GlobalUngroupedAggregateState; struct 
LocalUngroupedAggregateState; +enum class JoinFilterPushdownMode : uint8_t { + //! The pushed expression can be reconstructed on top of the raw scan value for BF/PRF/PHJ runtime filters + RECONSTRUCT_EXPRESSION, + //! Only storage-domain filters are safe; BF/PRF/PHJ reconstruction on raw scan values is not + STORAGE_ONLY +}; + struct JoinFilterPushdownColumn { //! The probe column index to which this filter should be applied ColumnBinding probe_column_index; //! The type of the value in storage (LogicalGet) LogicalType storage_type; + //! Whether runtime filters can reconstruct the pushed expression, or whether only storage-domain filters are safe + JoinFilterPushdownMode mode = JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION; + //! The original type of the pushed probe expression before rewriting to the LogicalGet storage column. Only used + //! when the mode allows reconstruction of the probe expression for BF/PRF/PHJ runtime filters. + LogicalType runtime_filter_type; }; struct JoinFilterGlobalState { @@ -61,6 +74,11 @@ struct PushdownFilterTarget { vector columns; }; +struct JoinFilterPushdownUtil { + static bool PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColumn &filter); + static bool JoinTypeIsSupported(JoinType join_type); +}; + struct JoinFilterPushdownInfo { //! 
The join condition indexes for which we compute the min/max aggregates vector join_condition; @@ -90,13 +108,14 @@ struct JoinFilterPushdownInfo { void PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx) const; - void PushBloomFilter(const PhysicalOperator &op, JoinHashTable &ht, const JoinFilterPushdownFilter &info, - ProjectionIndex filter_col_idx) const; - void PushPerfectHashJoinFilter(const PhysicalOperator &op, PerfectHashJoinExecutor &perfect_join_executor, - const JoinFilterPushdownFilter &info, ProjectionIndex filter_col_idx) const; + void PushBloomFilter(ClientContext &context, const PhysicalOperator &op, JoinHashTable &ht, + const JoinFilterPushdownFilter &info, idx_t filter_idx, ProjectionIndex filter_col_idx) const; + void PushPerfectHashJoinFilter(ClientContext &context, const PhysicalOperator &op, + PerfectHashJoinExecutor &perfect_join_executor, const JoinFilterPushdownFilter &info, + idx_t filter_idx, ProjectionIndex filter_col_idx) const; void RegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, JoinHashTable &ht, - const PhysicalOperator &op, ProjectionIndex filter_col_idx, const Value &min_val, - const Value &max_val) const; + const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx, + const Value &min_val, const Value &max_val) const; bool CanUseInFilter(const ClientContext &context, optional_ptr ht, const ExpressionType &cmp) const; bool CanUseBloomFilter(const ClientContext &context, const PhysicalComparisonJoin &op, const ExpressionType &cmp, diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp index 09815ce1e..744a8fb60 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/operator/join/perfect_hash_join_executor.hpp @@ -44,21 +44,21 @@ class PerfectHashJoinExecutor { OperatorResultType ProbePerfectHashTable(ExecutionContext &context, DataChunk &input, DataChunk &lhs_output_columns, DataChunk &chunk, OperatorState &state); - void FillSelectionVectorSwitchProbe(Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, + void FillSelectionVectorSwitchProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, optional_ptr build_sel_vec) const; private: template - void FillSelectionVectorSwitchProbe(Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, + void FillSelectionVectorSwitchProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, SelectionVector *build_sel_vec) const; template - void TemplatedFillSelectionVectorProbe(Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, + void TemplatedFillSelectionVectorProbe(const Vector &source, const idx_t &count, SelectionVector &probe_sel_vec, idx_t &probe_sel_count, SelectionVector *build_sel_vec) const; - bool FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, + bool FillSelectionVectorSwitchBuild(const Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, idx_t count); template - bool TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, + bool TemplatedFillSelectionVectorBuild(const Vector &source, SelectionVector &sel_vec, SelectionVector &seq_sel_vec, idx_t count); bool FullScanHashTable(); diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp index 87a6b5198..8c78a0c6c 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp +++ 
b/src/duckdb/src/include/duckdb/execution/operator/join/physical_cross_product.hpp @@ -40,6 +40,7 @@ class PhysicalCrossProduct : public CachingPhysicalOperator { public: // Sink Interface unique_ptr GetGlobalSinkState(ClientContext &context) const override; + unique_ptr GetLocalSinkState(ExecutionContext &context) const override; SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; bool IsSink() const override { @@ -62,6 +63,7 @@ class CrossProductExecutor { explicit CrossProductExecutor(ColumnDataCollection &rhs); OperatorResultType Execute(const DataChunk &input, DataChunk &output); + void Reset(); // returns if the left side is scanned as a constant vector bool ScanLHS() { diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp index 42989bcb5..445624c93 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_hash_join.hpp @@ -110,6 +110,8 @@ class PhysicalHashJoin : public PhysicalComparisonJoin { unique_ptr GetGlobalSinkState(ClientContext &context) const override; unique_ptr GetLocalSinkState(ExecutionContext &context) const override; + void SetPreserveBuildForRecursiveReuse(bool preserve) const; + bool CanPreserveBuildForRecursiveReuse() const; SinkResultType Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const override; SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; void PrepareFinalize(ClientContext &context, GlobalSinkState &global_state) const override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 394ef8479..9cee4889e 100644 --- 
a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -26,6 +26,7 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child); void Sink(ExecutionContext &context, DataChunk &input); + void ResetForReuse(ExecutionContext &context); //! The global table we are connected to GlobalSortedTable &global_table; @@ -80,6 +81,7 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { //! Combine local states void Combine(ExecutionContext &context, LocalSortedTable &ltable); + void ResetForReuse(ClientContext &client); //! Prepare for sorting. void Finalize(ClientContext &client, InterruptState &interrupt); //! Schedules the materialisation process. @@ -160,7 +162,7 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { TupleDataChunkState &chunk_state, const idx_t chunk_idx, const unsafe_vector &result, SortedRunScanState &scan_state); // Apply a tail condition to the current selection - static idx_t SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, + static idx_t SelectJoinTail(const ExpressionType &condition, const Vector &left, const Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel); //! 
Utility to project full width internal chunks to projected results diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp index f4bcf13c7..d8d31d5d3 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp @@ -19,6 +19,7 @@ namespace duckdb { struct GlobalFileState; +struct BoundOrderByNode; struct CopyToFileInfo { explicit CopyToFileInfo(string file_path_p) : file_path(std::move(file_path_p)) { @@ -70,6 +71,9 @@ class PhysicalCopyToFile : public PhysicalOperator { SinkCombineResultType Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const override; SinkFinalizeType Finalize(Pipeline &pipeline, Event &event, ClientContext &context, OperatorSinkFinalizeInput &input) const override; + //! Synchronously complete any pending partitioned-copy flush work after Finalize has been called. + //! Used by callers (e.g. DuckLakeUpdate) that drive Sink/Finalize/GetData manually outside a pipeline. + void FinalizePartitionedSync(ExecutionContext &execution_context, InterruptState &interrupt_state) const; unique_ptr GetLocalSinkState(ExecutionContext &context) const override; unique_ptr GetGlobalSinkState(ClientContext &context) const override; @@ -103,7 +107,7 @@ class PhysicalCopyToFile : public PhysicalOperator { unique_ptr bind_data; //! Names and types going into the file(s) - vector names; + vector names; vector expected_types; //! Where to write the file @@ -138,6 +142,9 @@ class PhysicalCopyToFile : public PhysicalOperator { bool partition_output; bool write_partition_columns; bool hive_file_pattern; + + //! 
If the data should be sorted + vector order_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp index 87e1a7de7..f76265eb6 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/scan/physical_table_scan.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/enums/metric_type.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/execution/physical_operator.hpp" #include "duckdb/execution/physical_operator_states.hpp" @@ -87,9 +88,8 @@ class PhysicalTableScan : public PhysicalOperator { ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const override; - InsertionOrderPreservingMap ExtraSourceParams(GlobalSourceState &gstate, - LocalSourceState &lstate) const override; - optional_idx GetRowsScanned(GlobalSourceState &gstate_p, LocalSourceState &lstate) const; + void GetMetrics(ClientContext &context, GlobalSourceState &gstate_p, LocalSourceState &lstate, + OperatorMetrics &operator_metrics) const; private: string GetFilterInfo(const TableFilterSet &filter_set) const; diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp index e54e2e3c2..45a0c4c55 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_cte.hpp @@ -20,7 +20,7 @@ class PhysicalCTE : public PhysicalOperator { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::CTE; public: - PhysicalCTE(PhysicalPlan &physical_plan, string ctename, TableIndex table_index, vector types, + PhysicalCTE(PhysicalPlan &physical_plan, Identifier ctename, TableIndex table_index, vector types, PhysicalOperator &top, PhysicalOperator &bottom, 
idx_t estimated_cardinality); ~PhysicalCTE() override; @@ -29,7 +29,7 @@ class PhysicalCTE : public PhysicalOperator { shared_ptr working_table; TableIndex table_index; - string ctename; + Identifier ctename; bool cte_body_is_dml = false; public: @@ -55,6 +55,9 @@ class PhysicalCTE : public PhysicalOperator { InsertionOrderPreservingMap ParamsToString() const override; + ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, + const ProgressData source_progress) const override; + public: void BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) override; diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp index a04d50a6b..5ccff6f6b 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/reference_map.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/execution/physical_operator.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" @@ -15,17 +16,24 @@ namespace duckdb { class RecursiveCTEState; +struct RecursiveExecutorPool; +class PhysicalColumnDataScan; +class Pipeline; +class PipelineExecutor; class PhysicalRecursiveCTE : public PhysicalOperator { public: static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::RECURSIVE_CTE; + using executor_cache_t = reference_map_t>>; + friend class RecursiveCTEState; public: - PhysicalRecursiveCTE(PhysicalPlan &physical_plan, string ctename, TableIndex table_index, vector types, - bool union_all, PhysicalOperator &top, PhysicalOperator &bottom, idx_t estimated_cardinality); + PhysicalRecursiveCTE(PhysicalPlan &physical_plan, Identifier ctename, TableIndex table_index, + vector types, bool union_all, PhysicalOperator &top, 
PhysicalOperator &bottom, + idx_t estimated_cardinality); ~PhysicalRecursiveCTE() override; - string ctename; + Identifier ctename; TableIndex table_index; // Flag if recurring table is referenced, if not we do not copy ht into ColumnDataCollection bool ref_recurring; @@ -45,6 +53,14 @@ class PhysicalRecursiveCTE : public PhysicalOperator { vector payload_idx, distinct_idx; // Contains the aggregates for the payload vector> payload_aggregates; + //! Number of recursive table scans inside the recursive member + idx_t recursive_reference_count = 0; + //! Number of recurring table scans inside the recursive member + idx_t recurring_reference_count = 0; + //! Recursive table scans rebound to the current iteration input buffer + vector> recursive_scans; + //! Recursive meta-pipelines that are independent of the active recursive scan graph and can be materialized once + reference_set_t invariant_meta_pipelines; public: // Source interface @@ -81,6 +97,9 @@ class PhysicalRecursiveCTE : public PhysicalOperator { idx_t ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const; void ExecuteRecursivePipelines(ExecutionContext &context) const; + +private: + mutable shared_ptr shared_executor_pool; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte_state.hpp b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte_state.hpp new file mode 100644 index 000000000..1613d0d75 --- /dev/null +++ b/src/duckdb/src/include/duckdb/execution/operator/set/physical_recursive_cte_state.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include "duckdb/common/types/column/column_data_collection.hpp" +#include "duckdb/execution/aggregate_hashtable.hpp" +#include "duckdb/execution/executor.hpp" +#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" + +namespace duckdb { + +struct RecursiveExecutorPool { + mutex lock; + PhysicalRecursiveCTE::executor_cache_t executors; +}; + +enum class 
RecursiveCTEInlineStageType : uint8_t { EXECUTE, PREPARE_FINISH, FINISH }; + +struct RecursiveCTEInlineStage { + RecursiveCTEInlineStage(RecursiveCTEInlineStageType type_p, Pipeline &pipeline_p) + : type(type_p), pipeline(pipeline_p), dependency_count(0) { + } + + RecursiveCTEInlineStageType type; + reference pipeline; + vector dependents; + idx_t dependency_count; +}; + +struct RecursiveCTEInlinePlan { + vector stages; + vector> initialize_on_schedule_pipelines; +}; + +class RecursiveCTEState : public GlobalSinkState { +public: + explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op); + ~RecursiveCTEState() override; + + void InitializeIntermediateAppend(); + void ResetRecurringTable(); + ColumnDataCollection &CurrentOutputTable(); + ColumnDataCollection &CurrentInputTable(); + const ColumnDataCollection &CurrentInputTable() const; + ColumnDataAppendState &CurrentOutputAppendState(); + void AdvanceIterationBuffers(); + void ResetCurrentOutputTableForReuse(); + void RebindRecursiveScans(); + vector> &GetCachedExecutors(Pipeline &pipeline, idx_t max_threads); + void ClearCachedExecutors(); + + unique_ptr ht; + const PhysicalRecursiveCTE &op; + ExpressionExecutor executor; + DataChunk payload_rows; + const bool allow_executor_reuse; + shared_ptr executor_pool; + + mutex intermediate_table_lock; + ColumnDataCollection intermediate_table; + ColumnDataAppendState intermediate_append_state; + ColumnDataAppendState working_append_state; + ColumnDataAppendState recurring_append_state; + ColumnDataScanState scan_state; + bool initialized = false; + bool finished_scan = false; + bool output_is_working = false; + SelectionVector new_groups; + //! Cached dummy address vector for ProbeHT (avoids per-chunk VectorBuffer allocation) + Vector dummy_addresses; + //! Cached chunk for distinct key extraction in the using_key Sink path + DataChunk distinct_rows; + //! 
Cached chunks for source-side hash table scans and recurring table copy paths + DataChunk source_result; + DataChunk source_payload_rows; + DataChunk source_distinct_rows; + AggregateHTScanState ht_scan_state; + + //! Cached PipelineExecutors per pipeline for reuse across recursive iterations + //! When both locks are needed, always acquire cached_executor_lock before executor_pool->lock. + mutex cached_executor_lock; + PhysicalRecursiveCTE::executor_cache_t cached_executors; + //! Cached dependency graph for the single-thread inline recursive fast path + unique_ptr inline_plan; + //! Cached dependency graph after invariant meta-pipelines have been materialized once + unique_ptr invariant_inline_plan; + //! Whether invariant recursive meta-pipelines have already been materialized for this state + bool invariant_meta_pipelines_materialized = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp index 211c27b36..17bec3b99 100644 --- a/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/perfect_aggregate_hashtable.hpp @@ -10,6 +10,7 @@ #include "duckdb/execution/base_aggregate_hashtable.hpp" #include "duckdb/storage/arena_allocator.hpp" +#include "duckdb/common/clustered_aggregate.hpp" namespace duckdb { @@ -61,7 +62,11 @@ class PerfectAggregateHashTable : public BaseAggregateHashTable { //! Owning arena allocators that this HT has data from vector> stored_allocators; + ClusteredAggrState clustered_state; + private: + //! Try adding a chunk using the clustered aggregation path. Returns false if not applicable. + bool AddChunkClustered(uintptr_t *address_data, DataChunk &payload); //! 
Destroy the perfect aggregate HT (called automatically by the destructor) void Destroy(); }; diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp index 4efd2f6f2..ba1082140 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp @@ -97,6 +97,7 @@ class PhysicalOperator { // Operator interface virtual unique_ptr GetOperatorState(ExecutionContext &context) const; virtual unique_ptr GetGlobalOperatorState(ClientContext &context) const; + virtual bool ResetGlobalOperatorState(ClientContext &context, GlobalOperatorState &state) const; virtual OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate, OperatorState &state) const; virtual OperatorFinalizeResultType FinalExecute(ExecutionContext &context, DataChunk &chunk, @@ -170,11 +171,6 @@ class PhysicalOperator { return source_progress; } - virtual InsertionOrderPreservingMap ExtraSourceParams(GlobalSourceState &gstate, - LocalSourceState &lstate) const { - return InsertionOrderPreservingMap(); - } - public: // Sink interface @@ -259,6 +255,14 @@ class CachingOperatorState : public OperatorState { void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { } + void ResetCachingState() { + cached_chunk.reset(); + initialized = false; + can_cache_chunk = OperatorCachingMode::NONE; + must_return_continuation_chunk = false; + cached_result = OperatorResultType::NEED_MORE_INPUT; + } + unique_ptr cached_chunk; bool initialized = false; //! 
Whether or not the chunk can be cached diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp index a471c229f..73de239aa 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_operator_states.hpp @@ -36,6 +36,13 @@ class OperatorState { virtual void Finalize(const PhysicalOperator &op, ExecutionContext &context) { } + virtual bool SupportsReuse() const { + return false; + } + + virtual void Reset() { + } + template TARGET &Cast() { DynamicCastCheck(this); @@ -78,6 +85,16 @@ class GlobalSinkState : public StateWithBlockableTasks { SinkFinalizeType state; + virtual bool SupportsReuse() const { + return false; + } + + virtual void Reset(ClientContext &context) { + annotated_lock_guard guard(lock); + ResetBlocking(); + state = SinkFinalizeType::READY; + } + template TARGET &Cast() { DynamicCastCheck(this); @@ -102,6 +119,13 @@ class LocalSinkState { //! 
Source partition info SourcePartitionInfo partition_info; + virtual bool SupportsReuse() const { + return false; + } + + virtual void Reset(ExecutionContext &context, GlobalSinkState &gstate) { + } + template TARGET &Cast() { DynamicCastCheck(this); @@ -123,6 +147,15 @@ class GlobalSourceState : public StateWithBlockableTasks { return 1; } + virtual bool SupportsReuse() const { + return false; + } + + virtual void Reset(ClientContext &context) { + annotated_lock_guard guard(lock); + ResetBlocking(); + } + template TARGET &Cast() { DynamicCastCheck(this); @@ -140,6 +173,13 @@ class LocalSourceState { virtual ~LocalSourceState() { } + virtual bool SupportsReuse() const { + return false; + } + + virtual void Reset(ExecutionContext &context, GlobalSourceState &gstate) { + } + template TARGET &Cast() { DynamicCastCheck(this); diff --git a/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp index db998509a..4f46d320e 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_plan_generator.hpp @@ -93,7 +93,9 @@ class PhysicalPlanGenerator { static bool PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan); //! The order preservation type of the given operator decided by recursively looking at its children static OrderPreservationType OrderPreservationRecursive(PhysicalOperator &op); - + //! Determine whether a child has a single value partitioning for the given expressions. + static bool HasSingleValuePartitions(ClientContext &context, const vector> &partitions, + PhysicalOperator &child, vector &partition_columns); //! Make a physical operator in the physical plan. template PhysicalOperator &Make(ARGS &&... 
args) { diff --git a/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp index cf294cfee..464130a15 100644 --- a/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/radix_partitioned_hashtable.hpp @@ -40,6 +40,8 @@ class RadixPartitionedHashTable { //! Sink Interface unique_ptr GetGlobalSinkState(ClientContext &context) const; unique_ptr GetLocalSinkState(ExecutionContext &context) const; + void ResetGlobalSinkState(ClientContext &context, GlobalSinkState &gstate) const; + void ResetLocalSinkState(ExecutionContext &context, GlobalSinkState &gstate, LocalSinkState &lstate) const; void Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, DataChunk &aggregate_input_chunk, const unsafe_vector &filter) const; @@ -50,6 +52,8 @@ class RadixPartitionedHashTable { //! Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const; unique_ptr GetLocalSourceState(ExecutionContext &context) const; + void ResetGlobalSourceState(ClientContext &context, GlobalSourceState &gstate) const; + void ResetLocalSourceState(ExecutionContext &context, LocalSourceState &lstate) const; SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, GlobalSinkState &sink, OperatorSourceInput &input) const; diff --git a/src/duckdb/src/include/duckdb/function/aggregate/distributive_function_utils.hpp b/src/duckdb/src/include/duckdb/function/aggregate/distributive_function_utils.hpp index c42c34346..820bcfd1d 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/distributive_function_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/distributive_function_utils.hpp @@ -12,6 +12,87 @@ namespace duckdb { +struct ClusteredStateCopy { + template + struct ClusteredLocalState { + typedef STATE Type; + }; + + template + static void InitializeClusteredLocal(STATE &local, const 
STATE &state) { + local = state; + } + + template + static void FlushClusteredLocal(STATE &state, const STATE &local, bool) { + state = local; + } +}; + +template +struct ConstantInit { + template + static constexpr T Value() { + return T(INIT_VALUE); + } +}; + +template +struct EmptyValAggregate : public ClusteredStateCopy { + template + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input) { + local.empty = false; + using value_type = decltype(local.val); + local.val = REDUCE_OP::template Operation(local.val, value_type(input)); + } + + template + static void UpdateClusteredLocal(STATE &local, const INPUT_TYPE &input, idx_t count) { + if (count != 0) { + UpdateClusteredLocal(local, input); + } + } + + template + static void Initialize(STATE &state) { + state.val = INIT_OP::template Value(); + state.empty = true; + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &) { + using value_type = decltype(target.val); + target.val = REDUCE_OP::template Operation(target.val, source.val); + target.empty = target.empty && source.empty; + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (state.empty) { + finalize_data.ReturnNull(); + return; + } + target = state.val; + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + UpdateClusteredLocal(state, input); + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, + idx_t count) { + for (idx_t i = 0; i < count; i++) { + Operation(state, input, unary_input); + } + } + + static bool IgnoreNull() { + return true; + } +}; + struct CountFunctionBase { static AggregateFunction GetFunction(); }; diff --git a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp index 87a45eb0e..d82e90c91 100644 --- 
a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp @@ -315,12 +315,11 @@ struct MinMaxFixedValue { } // Nothing to do here - static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + static EXTRA_STATE CreateExtraState() { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, - const bool nulls_last) { + static void PrepareData(const Vector &input, EXTRA_STATE &, UnifiedVectorFormat &format, const bool nulls_last) { input.ToUnifiedFormat(format); } }; @@ -338,12 +337,11 @@ struct MinMaxStringValue { } // Nothing to do here - static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + static EXTRA_STATE CreateExtraState() { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &, UnifiedVectorFormat &format, - const bool nulls_last) { + static void PrepareData(const Vector &input, EXTRA_STATE &, UnifiedVectorFormat &format, const bool nulls_last) { input.ToUnifiedFormat(format); } }; @@ -363,15 +361,15 @@ struct MinMaxFallbackValue { CreateSortKeyHelpers::DecodeSortKey(value, vector, idx, modifiers); } - static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + static EXTRA_STATE CreateExtraState() { return Vector(LogicalTypeId::BLOB); } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + static void PrepareData(const Vector &input, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, const bool nulls_last) { auto order_by_null_type = nulls_last ? 
OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; const OrderModifiers modifiers(OrderType::ASCENDING, order_by_null_type); - CreateSortKeyHelpers::CreateSortKeyWithValidity(input, extra_state, modifiers, count); + CreateSortKeyHelpers::CreateSortKeyWithValidity(input, extra_state, modifiers); input.Flatten(); extra_state.ToUnifiedFormat(format); } @@ -412,11 +410,11 @@ struct MinMaxFixedValueOrNull { FlatVector::GetDataMutable(vector)[idx] = value.value; } - static EXTRA_STATE CreateExtraState(Vector &input, idx_t count) { + static EXTRA_STATE CreateExtraState() { return false; } - static void PrepareData(Vector &input, const idx_t count, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, + static void PrepareData(const Vector &input, EXTRA_STATE &extra_state, UnifiedVectorFormat &format, const bool nulls_last) { input.ToUnifiedFormat(format); } @@ -426,11 +424,6 @@ struct MinMaxFixedValueOrNull { // MinMaxN Operation (common for both ArgMinMaxN and MinMaxN) //------------------------------------------------------------------------------ struct MinMaxNOperation { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - template static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input) { if (!source.is_initialized) { @@ -502,7 +495,6 @@ struct MinMaxNOperation { D_ASSERT(current_offset == old_len + new_entries); ListVector::SetListSize(result, current_offset); - FlatVector::SetSize(result, count_t(offset + count)); result.Verify(); } diff --git a/src/duckdb/src/include/duckdb/function/aggregate/sort_key_helpers.hpp b/src/duckdb/src/include/duckdb/function/aggregate/sort_key_helpers.hpp index 3ed9eecb5..86ce51c69 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/sort_key_helpers.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/sort_key_helpers.hpp @@ -22,7 +22,7 @@ struct AggregateSortKeyHelpers { Vector sort_key(LogicalType::BLOB); auto modifiers = OrderModifiers(ORDER_TYPE, 
OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, sort_key); + CreateSortKeyHelpers::CreateSortKey(input, modifiers, sort_key); UnifiedVectorFormat idata; if (IGNORE_NULLS) { diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp index b238bf3fb..fc7b2e1db 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/array.hpp" #include "duckdb/common/vector_operations/aggregate_executor.hpp" #include "duckdb/function/aggregate_state.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/planner/expression.hpp" @@ -96,9 +97,9 @@ typedef unique_ptr (*bind_aggregate_function_t)(BindAggregateFunct //! The type used for the aggregate destructor method. NOTE: this method is used in destructors and MAY NOT throw. typedef void (*aggregate_destructor_t)(Vector &state, AggregateInputData &aggr_input_data, idx_t count); -//! The type used for updating simple (non-grouped) aggregate functions -typedef void (*aggregate_simple_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, - data_ptr_t state, idx_t count); +//! The type used for updating a clustered set of aggregate states. +typedef void (*aggregate_cluster_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + const ClusteredAggr &clustered, idx_t count); //! 
The type used for computing complex/custom windowed aggregate functions (optional) typedef void (*aggregate_window_t)(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, @@ -125,7 +126,7 @@ typedef void (*aggregate_serialize_t)(Serializer &serializer, const optional_ptr typedef unique_ptr (*aggregate_deserialize_t)(Deserializer &deserializer, BoundAggregateFunction &function); -typedef LogicalType (*aggregate_get_state_type_t)(const BoundAggregateFunction &function); +typedef AggregateStateLayout (*aggregate_get_state_type_t)(const BoundAggregateFunction &function); struct AggregateFunctionInfo { DUCKDB_API virtual ~AggregateFunctionInfo(); @@ -172,9 +173,9 @@ class AggregateFunctionCallbacks { aggregate_update_t GetStateUpdateCallback() const { return update; } void SetStateUpdateCallback(aggregate_update_t callback) { update = callback; } - bool HasStateSimpleUpdateCallback() const { return simple_update != nullptr; } - aggregate_simple_update_t GetStateSimpleUpdateCallback() const { return simple_update; } - void SetStateSimpleUpdateCallback(aggregate_simple_update_t callback) { simple_update = callback; } + bool HasStateClusterUpdateCallback() const { return cluster_update != nullptr; } + aggregate_cluster_update_t GetStateClusterUpdateCallback() const { return cluster_update; } + void SetStateClusterUpdateCallback(aggregate_cluster_update_t callback) { cluster_update = callback; } void SetStateCombineCallback(aggregate_combine_t callback) { combine = callback; } aggregate_combine_t GetStateCombineCallback() const { return combine; } @@ -219,8 +220,8 @@ class AggregateFunctionCallbacks { aggregate_combine_t combine = nullptr; //! The hashed aggregate finalization function (may be null, if window is set) aggregate_finalize_t finalize = nullptr; - //! The simple aggregate update function (may be null) - aggregate_simple_update_t simple_update = nullptr; + //! 
The clustered aggregate update function (may be null) + aggregate_cluster_update_t cluster_update = nullptr; //! The windowed aggregate custom function (may be null) aggregate_window_t window = nullptr; //! The windowed aggregate custom initialization function (may be null) @@ -323,9 +324,9 @@ class BaseAggregateFunction { auto GetStateUpdateCallback() const -> aggregate_update_t { return callbacks.update; } auto SetStateUpdateCallback(aggregate_update_t callback) -> void { callbacks.update = callback; } - auto HasStateSimpleUpdateCallback() const -> bool { return callbacks.simple_update != nullptr; } - auto GetStateSimpleUpdateCallback() const -> aggregate_simple_update_t { return callbacks.simple_update; } - auto SetStateSimpleUpdateCallback(aggregate_simple_update_t callback) -> void { callbacks.simple_update = callback; } + auto HasStateClusterUpdateCallback() const -> bool { return callbacks.cluster_update != nullptr; } + auto GetStateClusterUpdateCallback() const -> aggregate_cluster_update_t { return callbacks.cluster_update; } + auto SetStateClusterUpdateCallback(aggregate_cluster_update_t callback) -> void { callbacks.cluster_update = callback; } auto HasStateCombineCallback() const -> bool { return callbacks.combine != nullptr; } auto GetStateCombineCallback() const -> aggregate_combine_t { return callbacks.combine; } @@ -388,11 +389,18 @@ class BaseAggregateFunction { class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { // NOLINT: work-around bug in clang-tidy public: - AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, + static constexpr aggregate_cluster_update_t NoClusterUpdate() { + return nullptr; + } + static constexpr bind_aggregate_function_t NoBind() { + return nullptr; + } + + AggregateFunction(const Identifier &name, const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, 
aggregate_combine_t combine, aggregate_finalize_t finalize, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, - aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_cluster_update_t cluster_update_p = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_batch_t window_batch = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) @@ -404,7 +412,7 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { callbacks.update = update; callbacks.combine = combine; callbacks.finalize = finalize; - callbacks.simple_update = simple_update; + callbacks.cluster_update = cluster_update_p; callbacks.window_batch = window_batch; callbacks.bind = bind; callbacks.destructor = destructor; @@ -413,10 +421,10 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { callbacks.deserialize = deserialize; } - AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, + AggregateFunction(const Identifier &name, const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, - aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_cluster_update_t cluster_update_p = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_batch_t window_batch = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) @@ -426,7 +434,7 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { callbacks.update = update; callbacks.combine = 
combine; callbacks.finalize = finalize; - callbacks.simple_update = simple_update; + callbacks.cluster_update = cluster_update_p; callbacks.bind = bind; callbacks.destructor = destructor; callbacks.statistics = statistics; @@ -439,23 +447,23 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, aggregate_finalize_t finalize, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, - aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr, + aggregate_cluster_update_t cluster_update_p = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_batch_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) - : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, - null_handling, simple_update, bind, destructor, statistics, window, serialize, + : AggregateFunction(Identifier(), arguments, return_type, state_size, initialize, update, combine, finalize, + null_handling, cluster_update_p, bind, destructor, statistics, window, serialize, deserialize) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine, - aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr, + aggregate_finalize_t finalize, aggregate_cluster_update_t cluster_update_p = nullptr, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_window_batch_t window = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) - : 
AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize, - FunctionNullHandling::DEFAULT_NULL_HANDLING, simple_update, bind, destructor, statistics, + : AggregateFunction(Identifier(), arguments, return_type, state_size, initialize, update, combine, finalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, cluster_update_p, bind, destructor, statistics, window, serialize, deserialize) { } @@ -465,7 +473,7 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { aggregate_window_batch_t window_batch, bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr, aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr) - : SimpleFunction(name, arguments, return_type) { + : SimpleFunction(Identifier(), arguments, return_type) { callbacks.state_size = state_size; callbacks.initialize = initialize; callbacks.window_batch = window_batch; @@ -492,13 +500,31 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { return *this; } + AggregateFunction &SetClusterCallback(aggregate_cluster_update_t cluster_update) { + callbacks.cluster_update = cluster_update; + return *this; + } + public: + template + static aggregate_cluster_update_t UnaryClusterUpdateCallback() { + if constexpr (AggregateExecutor::HasClusteredLocalState::value || + AggregateExecutor::HasClusteredOperation::value) { + return AggregateFunction::UnaryClusterUpdate; + } else { + return nullptr; + } + } + template static AggregateFunction NullaryAggregate(LogicalType return_type) { - return AggregateFunction( - {}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::NullaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::NullaryUpdate); + AggregateFunction result( + Identifier(), {}, return_type, 
AggregateFunction::StateSize, + AggregateFunction::StateInitialize, AggregateFunction::NullaryScatterUpdate, + AggregateFunction::StateCombine, AggregateFunction::StateFinalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NullaryClusterUpdate); + WireStructStateType(result); + return result; } template , + AggregateFunction result(Identifier(), {input_type}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, AggregateFunction::StateFinalize, null_handling, - AggregateFunction::UnaryUpdate); + UnaryClusterUpdateCallback()); + WireStructStateType(result); + return result; } template static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type, LogicalType return_type) { - return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize, + AggregateFunction result({a_type, b_type}, return_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, AggregateFunction::BinaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - AggregateFunction::BinaryUpdate); + AggregateFunction::StateFinalize, nullptr); + WireStructStateType(result); + return result; } public: + template + static void WireStructStateType(AggregateFunction &result) { + if constexpr (HasStructStateType::value) { + if constexpr (IsOptionalStateType::value) { + result.SetStructStateExport([](const BoundAggregateFunction &) { + using T = typename STATE::STATE_TYPE::value_type; + if constexpr (IsStructStateType::value) { + return AggregateStateLayout(T::GetLogicalType(STATE::STATE_NAMES), + AlignValue(sizeof(STATE)), true); + } else { + return AggregateStateLayout(PrimitiveToLogicalType(), AlignValue(sizeof(STATE)), + true); + } + }); + } else { + result.SetStructStateExport([](const BoundAggregateFunction &) { + AggregateStateLayout layout; + layout.type = 
STATE::STATE_TYPE::GetLogicalType(STATE::STATE_NAMES); + layout.total_state_size = AlignValue(sizeof(STATE)); + STATE::STATE_TYPE::PopulateField(layout.field); + return layout; + }); + } + } else if constexpr (HasPrimitiveLogicalType::value) { + result.SetStructStateExport([](const BoundAggregateFunction &) { + return AggregateStateLayout(PrimitiveToLogicalType(), AlignValue(sizeof(STATE))); + }); + } + } + template static idx_t StateSize(const BoundAggregateFunction &) { return sizeof(STATE); } + template + using initialize_void_t = void; + + //! Detects whether "OP" provides an "Initialize(STATE &)" method + template + struct OperationHasInitialize : std::false_type {}; + template + struct OperationHasInitialize()))>> + : std::true_type {}; + template static void StateInitialize(const BoundAggregateFunction &, data_ptr_t state) { // FIXME: we should remove the "destructor_type" option in the future @@ -548,7 +617,12 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { destructor_type == AggregateDestructorType::LEGACY, "Aggregate state must be trivially move constructible"); #endif - OP::Initialize(*reinterpret_cast(state)); + if constexpr (OperationHasInitialize::value) { + OP::Initialize(*reinterpret_cast(state)); + } else { + // if the operation does not define an Initialize method - initialize the state by zero-initializing it + memset(state, 0, sizeof(STATE)); + } } template @@ -565,6 +639,13 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { AggregateExecutor::NullaryUpdate(state, aggr_input_data, count); } + template + static void NullaryClusterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + const ClusteredAggr &clustered, idx_t count) { + D_ASSERT(input_count == 0); + AggregateExecutor::NullaryClustUpdate(aggr_input_data, clustered, count); + } + template static void UnaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, 
Vector &states, idx_t count) { @@ -579,6 +660,13 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { AggregateExecutor::UnaryUpdate(inputs[0], aggr_input_data, state, count); } + template + static void UnaryClusterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + const ClusteredAggr &clustered, idx_t count) { + D_ASSERT(input_count == 1); + AggregateExecutor::ExecuteUnaryClustered(inputs[0], aggr_input_data, clustered, count); + } + template static void BinaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, idx_t count) { @@ -626,12 +714,9 @@ class BoundAggregateFunction : public BaseAggregateFunction, public BoundSimpleF DUCKDB_API bool operator==(const BoundAggregateFunction &rhs) const; DUCKDB_API bool operator!=(const BoundAggregateFunction &rhs) const; - LogicalType GetStateType() const { + AggregateStateLayout GetStateType() const { D_ASSERT(callbacks.get_state_type); - const auto result = callbacks.get_state_type(*this); - // The underlying type of the AggregateState should be a struct - D_ASSERT(result.id() == LogicalTypeId::STRUCT); - return result; + return callbacks.get_state_type(*this); } }; diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp index ff59695aa..ca8e9930f 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_state.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp @@ -13,9 +13,12 @@ #include "duckdb/common/vector/string_vector.hpp" #include "duckdb/function/function.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/clustered_aggregate.hpp" #include "duckdb/storage/statistics/node_statistics.hpp" namespace duckdb { +class BoundAggregateFunction; +struct AggregateObject; enum class AggregateType : uint8_t { NON_DISTINCT = 1, DISTINCT = 2 }; //! 
Whether or not the input order influences the result of the aggregate @@ -24,17 +27,27 @@ enum class AggregateOrderDependent : uint8_t { ORDER_DEPENDENT = 1, NOT_ORDER_DE enum class AggregateDistinctDependent : uint8_t { DISTINCT_DEPENDENT = 1, NOT_DISTINCT_DEPENDENT = 2 }; //! Whether or not the combiner needs to preserve the source enum class AggregateCombineType : uint8_t { PRESERVE_INPUT = 1, ALLOW_DESTRUCTIVE = 2 }; +//! Whether or not we are exporting the state of the aggregate +enum class AggregateStateExportMode : uint8_t { NONE = 1, STATE_EXPORT = 2 }; class BoundAggregateExpression; struct AggregateInputData { - AggregateInputData(optional_ptr bind_data_p, ArenaAllocator &allocator_p, + AggregateInputData(const BoundAggregateFunction &function_p, optional_ptr bind_data_p, + ArenaAllocator &allocator_p, AggregateCombineType combine_type_p = AggregateCombineType::PRESERVE_INPUT) - : bind_data(bind_data_p), allocator(allocator_p), combine_type(combine_type_p) { + : function(function_p), bind_data(bind_data_p), allocator(allocator_p), combine_type(combine_type_p) { } + AggregateInputData(const BoundAggregateExpression &expr, ArenaAllocator &allocator_p, + AggregateCombineType combine_type_p = AggregateCombineType::PRESERVE_INPUT); + AggregateInputData(const AggregateObject &aggr, ArenaAllocator &allocator_p, + AggregateCombineType combine_type_p = AggregateCombineType::PRESERVE_INPUT); + + const BoundAggregateFunction &function; optional_ptr bind_data; ArenaAllocator &allocator; AggregateCombineType combine_type; + optional_ptr clustered; }; struct AggregateUnaryInput { @@ -52,13 +65,13 @@ struct AggregateUnaryInput { }; struct AggregateBinaryInput { - AggregateBinaryInput(AggregateInputData &input_p, ValidityMask &left_mask_p, ValidityMask &right_mask_p) + AggregateBinaryInput(AggregateInputData &input_p, const ValidityMask &left_mask_p, const ValidityMask &right_mask_p) : input(input_p), left_mask(left_mask_p), right_mask(right_mask_p) { } 
AggregateInputData &input; - ValidityMask &left_mask; - ValidityMask &right_mask; + const ValidityMask &left_mask; + const ValidityMask &right_mask; idx_t lidx; idx_t ridx; }; diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp new file mode 100644 index 000000000..f98ff3799 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp @@ -0,0 +1,223 @@ +//===----------------------------------------------------------------------===// +// +// duckdb/function/aggregate_state_layout.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/helper.hpp" +#include "duckdb/common/type_util.hpp" + +namespace duckdb { + +//! Detection trait: true when STATE defines a nested STATE_TYPE (i.e. StructStateType<...> or OptionalStateType) +template +struct HasStructStateType : std::false_type {}; + +template +struct HasStructStateType> : std::true_type {}; + +//! Phantom marker type for use in StructStateType only. +//! Represents two consecutive flat fields at base+0 and base+sizeof(T): T value, bool is_set. +//! memset(0) initializes to "not set" (is_set = false, value = 0). +template +struct OptionalStateType { + using value_type = T; +}; + +//! Detection trait: true when T is OptionalStateType for some U. +template +struct IsOptionalStateType : std::false_type {}; +template +struct IsOptionalStateType> : std::true_type {}; + +//! 
Detection trait: true when STATE is itself a C++ primitive type mappable to a LogicalType via PrimitiveToLogicalType +template +struct HasPrimitiveLogicalType : std::false_type {}; + +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; +template <> +struct HasPrimitiveLogicalType : std::true_type {}; + +//! Maps a single C++ field type to a LogicalType. +//! OptionalStateType → PrimitiveToLogicalType() (the optional encoding is captured in +//! 
AggregateStateField::is_optional) T with STATE_TYPE → nested struct type otherwise → PrimitiveToLogicalType() +template +LogicalType FieldToLogicalType() { + if constexpr (IsOptionalStateType::value) { + return PrimitiveToLogicalType(); + } else if constexpr (HasStructStateType::value) { + return T::STATE_TYPE::GetLogicalType(T::STATE_NAMES); + } else { + return PrimitiveToLogicalType(); + } +} + +//! Per-field layout information within an aggregate state. +//! field_offset: byte offset of this field relative to the parent struct's base. +//! is_optional: true when this field is an OptionalStateType — physically T value + bool is_set at +//! field_offset+sizeof(T). children: non-empty only when the field is itself a STRUCT; each child's offset is relative +//! to this field's base. +struct AggregateStateField { + idx_t field_offset = 0; + bool is_optional = false; + vector children; + + static idx_t GetPhysicalSize(const LogicalType &type) { + if (type.id() != LogicalTypeId::STRUCT) { + return GetTypeIdSize(type.InternalType()); + } + idx_t size = 0; + for (const auto &child : StructType::GetChildTypes(type)) { + idx_t child_size = GetPhysicalSize(child.second); + size = AlignValue(size, MinValue(child_size, 8)); + size += child_size; + } + return size; + } + + static void PopulateChildren(const LogicalType &type, AggregateStateField &field) { + if (type.id() != LogicalTypeId::STRUCT) { + return; + } + D_ASSERT(field.children.empty()); + idx_t offset = 0; + for (auto &[name, child_type] : StructType::GetChildTypes(type)) { + idx_t child_size = GetPhysicalSize(child_type); + offset = AlignValue(offset, MinValue(child_size, 8)); + AggregateStateField child_field; + child_field.field_offset = offset; + PopulateChildren(child_type, child_field); + field.children.push_back(std::move(child_field)); + offset += child_size; + } + } +}; + +//! Populate one field's worth of AggregateStateField children from a C++ type T. +//! 
Advances `offset` by the physical size of T (including the is_set bool for OptionalStateType). +template +void PopulateFieldFromType(AggregateStateField &parent, idx_t &offset) { + AggregateStateField child; + if constexpr (IsOptionalStateType::value) { + using V = typename T::value_type; + constexpr idx_t alignment = MinValue(sizeof(V), 8); + offset = AlignValue(offset, alignment); + child.field_offset = offset; + child.is_optional = true; + parent.children.push_back(std::move(child)); + offset += sizeof(V) + sizeof(bool); + } else if constexpr (HasStructStateType::value) { + auto logical = FieldToLogicalType(); + idx_t sz = AggregateStateField::GetPhysicalSize(logical); + offset = AlignValue(offset, MinValue(sz, 8)); + child.field_offset = offset; + AggregateStateField::PopulateChildren(logical, child); + parent.children.push_back(std::move(child)); + offset += sz; + } else { + constexpr idx_t alignment = MinValue(sizeof(T), 8); + offset = AlignValue(offset, alignment); + child.field_offset = offset; + parent.children.push_back(std::move(child)); + offset += sizeof(T); + } +} + +//! Describes a struct-typed aggregate state layout. Intended for use as a nested type inside an aggregate STATE struct: +//! static constexpr const char *STATE_NAMES[] = {"field_a", "field_b"}; +//! using STATE_TYPE = StructStateType; +//! UnaryAggregate/BinaryAggregate/NullaryAggregate detect STATE_TYPE and wire up SetStructStateExport automatically. +//! Fields that themselves have STATE_TYPE are recursively expanded into nested struct types. +//! OptionalStateType fields are exported as nullable T (T value at base+0, bool is_set at base+sizeof(T)). +//! Names are passed at call time (STATE::STATE_NAMES) rather than as template arguments. 
+template +struct StructStateType { + static LogicalType GetLogicalType(const char *const *names) { + child_list_t children; + idx_t i = 0; + (children.emplace_back(names[i++], FieldToLogicalType()), ...); + return LogicalType::STRUCT(std::move(children)); + } + + //! Populate field children using compile-time type knowledge, correctly handling OptionalStateType. + static void PopulateField(AggregateStateField &field) { + idx_t offset = 0; + (PopulateFieldFromType(field, offset), ...); + } +}; + +//! Detection trait: true when T is StructStateType for some Us. +template +struct IsStructStateType : std::false_type {}; +template +struct IsStructStateType> : std::true_type {}; + +//! Top-level description of an aggregate state for export/import purposes. +//! Returned by the aggregate_get_state_type_t callback registered via SetStructStateExport. +//! +//! - Primitive state (e.g. int64_t for count): type=BIGINT, field.children empty, field.is_optional=false +//! - Optional state (e.g. OptionalStateType): type=BOOLEAN, field.is_optional=true, field.children empty +//! - Struct state: type=STRUCT(...), field.children fully populated via StructStateType::PopulateField +//! total_state_size is the aligned size of the full state (stride between consecutive states in a buffer). 
+struct AggregateStateLayout { + AggregateStateLayout() = default; + AggregateStateLayout(LogicalType type_p, idx_t total_state_size_p, bool is_optional = false) + : type(std::move(type_p)), total_state_size(total_state_size_p) { + field.is_optional = is_optional; + AggregateStateField::PopulateChildren(type, field); + } + + LogicalType type; + AggregateStateField field; + idx_t total_state_size = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/compression_function.hpp b/src/duckdb/src/include/duckdb/function/compression_function.hpp index 32ed345d8..c2bd0f3a4 100644 --- a/src/duckdb/src/include/duckdb/function/compression_function.hpp +++ b/src/duckdb/src/include/duckdb/function/compression_function.hpp @@ -63,7 +63,8 @@ class CompressionInfo { }; struct AnalyzeState { - explicit AnalyzeState(const CompressionInfo &info) : info(info) {}; + explicit AnalyzeState(BlockManager &block_manager_p) : block_manager(block_manager_p), info(block_manager) { + } virtual ~AnalyzeState() { } @@ -78,11 +79,12 @@ struct AnalyzeState { return reinterpret_cast(*this); } + BlockManager &block_manager; CompressionInfo info; }; struct CompressionState { - explicit CompressionState(const CompressionInfo &info) : info(info) {}; + explicit CompressionState(ColumnDataCheckpointData &checkpoint_data, CompressionType compression_type); virtual ~CompressionState() { } @@ -96,7 +98,14 @@ struct CompressionState { DynamicCastCheck(this); return reinterpret_cast(*this); } + const LogicalType &GetType(); + + unique_ptr CreateNewSegment(); +public: + ColumnDataCheckpointData &checkpoint_data; + const CompressionFunction &function; + BlockManager &block_manager; CompressionInfo info; }; @@ -154,7 +163,7 @@ struct CompressionAppendState { //! 
The system then decides which compression function to use based on the analyzed score (returned from final_analyze) typedef unique_ptr (*compression_init_analyze_t)(ColumnData &col_data, PhysicalType type); -typedef bool (*compression_analyze_t)(AnalyzeState &state, Vector &input, idx_t count); +typedef bool (*compression_analyze_t)(AnalyzeState &state, const Vector &input); typedef idx_t (*compression_final_analyze_t)(AnalyzeState &state); //===--------------------------------------------------------------------===// @@ -162,7 +171,7 @@ typedef idx_t (*compression_final_analyze_t)(AnalyzeState &state); //===--------------------------------------------------------------------===// typedef unique_ptr (*compression_init_compression_t)(ColumnDataCheckpointData &checkpoint_data, unique_ptr state); -typedef void (*compression_compress_data_t)(CompressionState &state, Vector &scan_vector, idx_t count); +typedef void (*compression_compress_data_t)(CompressionState &state, const Vector &scan_vector); typedef void (*compression_compress_finalize_t)(CompressionState &state); //===--------------------------------------------------------------------===// @@ -199,8 +208,8 @@ typedef unique_ptr (*compression_init_segment_t)( ColumnSegment &segment, block_id_t block_id, optional_ptr segment_state); typedef unique_ptr (*compression_init_append_t)(ColumnSegment &segment); typedef idx_t (*compression_append_t)(CompressionAppendState &append_state, ColumnSegment &segment, - SegmentStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count); -typedef idx_t (*compression_finalize_append_t)(ColumnSegment &segment, SegmentStatistics &stats); + BaseStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count); +typedef idx_t (*compression_finalize_append_t)(ColumnSegment &segment, BaseStatistics &stats); typedef void (*compression_revert_append_t)(ColumnSegment &segment, idx_t new_count); 
//===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/function/copy_function.hpp b/src/duckdb/src/include/duckdb/function/copy_function.hpp index 4b80f42da..4dd20df2b 100644 --- a/src/duckdb/src/include/duckdb/function/copy_function.hpp +++ b/src/duckdb/src/include/duckdb/function/copy_function.hpp @@ -136,7 +136,8 @@ enum class CopyFunctionExecutionMode { REGULAR_COPY_TO_FILE, PARALLEL_COPY_TO_FI typedef BoundStatement (*copy_to_plan_t)(Binder &binder, CopyStatement &stmt); typedef void (*copy_options_t)(ClientContext &context, CopyOptionsInput &input); typedef unique_ptr (*copy_to_bind_t)(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types); + const vector &names, + const vector &sql_types); typedef unique_ptr (*copy_to_initialize_local_t)(ExecutionContext &context, FunctionData &bind_data); typedef unique_ptr (*copy_to_initialize_global_t)(ClientContext &context, FunctionData &bind_data, const string &file_path); @@ -182,7 +183,7 @@ enum class CopyFunctionReturnType : uint8_t { CHANGED_ROWS_AND_FILE_LIST = 1, WRITTEN_FILE_STATISTICS = 2 }; -vector GetCopyFunctionReturnNames(CopyFunctionReturnType return_type); +vector GetCopyFunctionReturnNames(CopyFunctionReturnType return_type); vector GetCopyFunctionReturnLogicalTypes(CopyFunctionReturnType return_type); struct CopyFunctionFileStatistics { @@ -232,7 +233,7 @@ struct CopyFunctionBatchAnalyzer { class CopyFunction : public Function { // NOLINT: work-around bug in clang-tidy public: - explicit CopyFunction(const string &name); + explicit CopyFunction(const Identifier &name); //! Plan rewrite copy function copy_to_plan_t plan; @@ -264,9 +265,6 @@ class CopyFunction : public Function { // NOLINT: work-around bug in clang-tidy //! Additional function info, passed to the bind shared_ptr function_info; - - //! Whether this copy function supports writing SQLNULL (e.g. 
Parquet UNKNOWN/NullType) - bool supports_sql_null = false; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp index b2b5c08c3..e8d08aab3 100644 --- a/src/duckdb/src/include/duckdb/function/create_sort_key.hpp +++ b/src/duckdb/src/include/duckdb/function/create_sort_key.hpp @@ -47,12 +47,14 @@ struct OrderModifiers { struct CreateSortKeyHelpers { static void CreateSortKey(DataChunk &input, const vector &modifiers, Vector &result); - static void CreateSortKey(Vector &input, idx_t input_count, OrderModifiers modifiers, Vector &result); + static void CreateSortKey(const Vector &input, OrderModifiers modifiers, Vector &result); + static void CreateSortKey(const Vector &input, idx_t count, OrderModifiers modifiers, Vector &result); static idx_t DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, OrderModifiers modifiers); static void DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, const vector &modifiers); - static void CreateSortKeyWithValidity(Vector &input, Vector &result, const OrderModifiers &modifiers, - const idx_t count); + static void CreateSortKeyWithValidity(const Vector &input, Vector &result, const OrderModifiers &modifiers); + static void CreateSortKeyWithValidity(const Vector &input, Vector &result, const OrderModifiers &modifiers, + idx_t count); }; //! 
We don't add this function to the catalog, for internal use only diff --git a/src/duckdb/src/include/duckdb/function/function.hpp b/src/duckdb/src/include/duckdb/function/function.hpp index f16f8d607..dd9ab0121 100644 --- a/src/duckdb/src/include/duckdb/function/function.hpp +++ b/src/duckdb/src/include/duckdb/function/function.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/named_parameter_map.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/unordered_set.hpp" @@ -15,6 +16,7 @@ #include "duckdb/parser/column_definition.hpp" #include "duckdb/common/enums/function_errors.hpp" #include "duckdb/common/optional_idx.hpp" +#include "fmt/core.h" namespace duckdb { class CatalogEntry; @@ -101,7 +103,12 @@ struct FunctionParameters { class FunctionParameter { public: - FunctionParameter(string name, LogicalType type) : name(std::move(name)), type(std::move(type)) { + FunctionParameter(Identifier name, LogicalType type) + : name(std::move(name)), type(std::move(type)), default_value(nullptr) { + } + + FunctionParameter(Identifier name, LogicalType type, Value value) + : name(std::move(name)), type(std::move(type)), default_value(make_shared_ptr(std::move(value))) { } string ToString() const; @@ -109,10 +116,10 @@ class FunctionParameter { bool operator==(const FunctionParameter &other) const; bool operator!=(const FunctionParameter &other) const; - auto GetName() const -> const string & { + auto GetName() const -> const Identifier & { return name; } - auto SetName(string name_p) -> void { + auto SetName(Identifier name_p) -> void { name = std::move(name_p); } @@ -123,13 +130,26 @@ class FunctionParameter { type = std::move(type_p); } + auto GetDefaultValue() const -> optional_ptr { + return default_value.get(); + } + auto SetDefaultValue(Value value) -> void { + default_value = make_shared_ptr(std::move(value)); + } + auto HasDefaultValue() const -> bool { + return default_value != nullptr; + } + 
private: - string name; + Identifier name; LogicalType type; + shared_ptr default_value; }; class FunctionSignature { public: + FunctionSignature() = default; + FunctionSignature(vector arguments, LogicalType varargs, LogicalType return_type) : varargs(std::move(varargs)), return_type(std::move(return_type)) { for (auto &arg : arguments) { @@ -178,24 +198,64 @@ class FunctionSignature { varargs = std::move(varargs_p); } - auto AddParameter(string name, LogicalType type) -> void { + auto AddParameter(Identifier name, LogicalType type, Value default_value) -> FunctionSignature & { + parameters.emplace_back(std::move(name), std::move(type), std::move(default_value)); + return *this; + } + + auto AddParameter(Identifier name, LogicalType type) -> FunctionSignature & { parameters.emplace_back(std::move(name), std::move(type)); + return *this; } - auto AddParameter(LogicalType type) -> void { - auto name = parameters.empty() ? string("col") : string("col") + to_string(parameters.size() + 1); + auto AddParameter(LogicalType type) -> FunctionSignature & { + auto name = StringUtil::Format("col%d", parameters.size()); + parameters.emplace_back(Identifier(name), std::move(type)); + return *this; + } + + auto GetParameterIndexByName(const Identifier &name) const -> optional_idx { + // Parameter names are matched case-insensitively, consistent with SQL identifier semantics. 
+ for (idx_t i = 0; i < parameters.size(); i++) { + if (parameters[i].GetName() == name) { + return i; + } + } + return optional_idx(); + } - parameters.emplace_back(name, std::move(type)); + auto GetRequiredParameterCount() const -> idx_t { + idx_t result = 0; + for (const auto ¶m : parameters) { + if (!param.HasDefaultValue()) { + result++; + } + } + return result; } void Verify() const { - case_insensitive_set_t seen_names; + // Check for duplicate parameter names + identifier_set_t seen_names; for (const auto ¶m : parameters) { if (seen_names.find(param.GetName()) != seen_names.end()) { throw InvalidInputException("Duplicate parameter name: %s", param.GetName()); } seen_names.insert(param.GetName()); } + + // Also check for default values that are not at the end of the parameter list + bool found_default_value = false; + for (const auto ¶m : parameters) { + if (param.HasDefaultValue()) { + found_default_value = true; + } else if (found_default_value) { + throw InvalidInputException( + "Parameters with default values must be at the end of the parameter list. Parameter '%s' does not " + "have a default value but follows a parameter with a default value.", + param.GetName()); + } + } } hash_t Hash() const; @@ -209,52 +269,53 @@ class FunctionSignature { //! Function is the base class used for any type of function (scalar, aggregate or simple function) class Function { public: - DUCKDB_API explicit Function(string name); + DUCKDB_API explicit Function(Identifier name); DUCKDB_API virtual ~Function(); //! The name of the function - string name; + Identifier name; //! 
Additional Information to specify function from it's name string extra_info; // Optional catalog name of the function - string catalog_name; + Identifier catalog_name; // Optional schema name of the function - string schema_name; + Identifier schema_name; public: - auto SetName(string name_p) -> void { + auto SetName(Identifier name_p) -> void { name = std::move(name_p); } - auto SetSchemaName(string schema_name_p) -> void { + auto SetSchemaName(Identifier schema_name_p) -> void { schema_name = std::move(schema_name_p); } - auto SetCatalogName(string catalog_name_p) -> void { + auto SetCatalogName(Identifier catalog_name_p) -> void { catalog_name = std::move(catalog_name_p); } - const string &GetName() const { + const Identifier &GetName() const { return name; } - const string &GetSchemaName() const { + const Identifier &GetSchemaName() const { return schema_name; } - const string &GetCatalogName() const { + const Identifier &GetCatalogName() const { return catalog_name; } //! Returns the formatted string name(arg1, arg2, ...) - DUCKDB_API static string CallToString(const string &catalog_name, const string &schema_name, const string &name, - const vector &arguments, + DUCKDB_API static string CallToString(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, const vector &arguments, + const vector> &named_arguments, const LogicalType &varargs = LogicalType::INVALID); //! Returns the formatted string name(arg1, arg2..) -> return_type - DUCKDB_API static string CallToString(const string &catalog_name, const string &schema_name, const string &name, - const vector &arguments, const LogicalType &varargs, - const LogicalType &return_type); + DUCKDB_API static string CallToString(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, const vector &arguments, + const LogicalType &varargs, const LogicalType &return_type); //! Returns the formatted string name(arg1, arg2.., np1=a, np2=b, ...) 
- DUCKDB_API static string CallToString(const string &catalog_name, const string &schema_name, const string &name, - const vector &arguments, + DUCKDB_API static string CallToString(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &name, const vector &arguments, const named_parameter_type_map_t &named_parameters); //! Used in the bind to erase an argument from a function DUCKDB_API static void EraseArgument(BoundSimpleFunction &bound_function, vector> &arguments, @@ -263,7 +324,8 @@ class Function { class SimpleFunction : public Function { public: - DUCKDB_API SimpleFunction(string name, vector arguments, LogicalType return_type, + DUCKDB_API SimpleFunction(Identifier name, FunctionSignature signature); + DUCKDB_API SimpleFunction(Identifier name, vector arguments, LogicalType return_type, LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); DUCKDB_API ~SimpleFunction() override; @@ -303,7 +365,7 @@ class SimpleFunction : public Function { class SimpleNamedParameterFunction : public Function { public: - DUCKDB_API SimpleNamedParameterFunction(string name, vector arguments, + DUCKDB_API SimpleNamedParameterFunction(Identifier name, vector arguments, LogicalType varargs = LogicalType(LogicalTypeId::INVALID)); DUCKDB_API ~SimpleNamedParameterFunction() override; @@ -382,6 +444,13 @@ class FunctionProperties { collation_handling = value; } + auto GetCaptureArgumentAliases() const -> bool { + return capture_argument_aliases; + } + auto SetCaptureArgumentAliases(bool value) -> void { + capture_argument_aliases = value; + } + // Helpers auto SetFallible() -> void { errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; @@ -401,13 +470,17 @@ class FunctionProperties { FunctionErrors errors = FunctionErrors::CANNOT_ERROR; //! Collation handling of the function FunctionCollationHandling collation_handling = FunctionCollationHandling::PROPAGATE_COLLATIONS; + //! 
Whether the binder should capture argument expression aliases as named-argument names when binding this + //! function. This preserves the legacy behavior of functions such as struct_pack/row, which derived their + //! (struct field) names from argument aliases and therefore allowed positional arguments after named ones. + bool capture_argument_aliases = false; }; class BoundSimpleFunction { protected: - string name; - string schema_name; - string catalog_name; + Identifier name; + Identifier schema_name; + Identifier catalog_name; string extra_info; //! The set of arguments of the function @@ -419,17 +492,17 @@ class BoundSimpleFunction { LogicalType return_type; public: - void SetName(string name_p) { + void SetName(Identifier name_p) { name = std::move(name_p); } - const string &GetName() const { + const Identifier &GetName() const { return name; } - const string &GetSchemaName() const { + const Identifier &GetSchemaName() const { return schema_name; } - const string &GetCatalogName() const { + const Identifier &GetCatalogName() const { return catalog_name; } diff --git a/src/duckdb/src/include/duckdb/function/function_binder.hpp b/src/duckdb/src/include/duckdb/function/function_binder.hpp index 1a8f07ef1..c8f11b3ff 100644 --- a/src/duckdb/src/include/duckdb/function/function_binder.hpp +++ b/src/duckdb/src/include/duckdb/function/function_binder.hpp @@ -17,6 +17,8 @@ namespace duckdb { +class WindowFunctionCatalogEntry; + //! The FunctionBinder class is responsible for binding functions class FunctionBinder { public: @@ -29,34 +31,70 @@ class FunctionBinder { public: //! Bind a scalar function from the set of functions and input arguments. Returns the index of the chosen function, //! 
returns optional_idx() and sets error if none could be found - DUCKDB_API optional_idx BindFunction(const string &name, ScalarFunctionSet &functions, - const vector &arguments, ErrorData &error); - DUCKDB_API optional_idx BindFunction(const string &name, ScalarFunctionSet &functions, - vector> &arguments, ErrorData &error); + DUCKDB_API optional_idx BindFunction(const Identifier &name, const ScalarFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error); + DUCKDB_API optional_idx BindFunction(const Identifier &name, const ScalarFunctionSet &functions, + const vector ®ular_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, {}, error); + } + + DUCKDB_API optional_idx BindFunction(const Identifier &name, const ScalarFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, + ErrorData &error); + //! Bind an aggregate function from the set of functions and input arguments. Returns the index of the chosen //! function, returns optional_idx() and sets error if none could be found - DUCKDB_API optional_idx BindFunction(const string &name, AggregateFunctionSet &functions, - const vector &arguments, ErrorData &error); - DUCKDB_API optional_idx BindFunction(const string &name, AggregateFunctionSet &functions, - vector> &arguments, ErrorData &error); + DUCKDB_API optional_idx BindFunction(const Identifier &name, const AggregateFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error); + DUCKDB_API optional_idx BindFunction(const Identifier &name, const AggregateFunctionSet &functions, + const vector ®ular_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, {}, error); + } + + DUCKDB_API optional_idx BindFunction(const Identifier &name, const AggregateFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, + ErrorData &error); + + //! 
Bind an aggregate function from the set of functions and input arguments. Returns the index of the chosen + //! function, returns optional_idx() and sets error if none could be found + DUCKDB_API optional_idx BindFunction(const Identifier &name, const WindowFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error); + + DUCKDB_API optional_idx BindFunction(const Identifier &name, const WindowFunctionSet &functions, + const vector ®ular_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, {}, error); + } + //! Bind a table function from the set of functions and input arguments. Returns the index of the chosen //! function, returns optional_idx() and sets error if none could be found - DUCKDB_API optional_idx BindFunction(const string &name, TableFunctionSet &functions, - const vector &arguments, ErrorData &error); - DUCKDB_API optional_idx BindFunction(const string &name, TableFunctionSet &functions, - vector> &arguments, ErrorData &error); - //! Bind a pragma function from the set of functions and input arguments - DUCKDB_API optional_idx BindFunction(const string &name, PragmaFunctionSet &functions, vector ¶meters, + DUCKDB_API optional_idx BindFunction(const Identifier &name, const TableFunctionSet &functions, + const vector ®ular_args, + const vector> &keyword_args, ErrorData &error); + DUCKDB_API optional_idx BindFunction(const Identifier &name, const TableFunctionSet &functions, + const vector ®ular_args, ErrorData &error) { + return BindFunctionFromArguments(name, functions, regular_args, {}, error); + } + + DUCKDB_API optional_idx BindFunction(const Identifier &name, const TableFunctionSet &functions, + const vector> ®ular_args, + const vector>> &keyword_args, ErrorData &error); - //! 
Bind a window function from the set of functions and input arguments - DUCKDB_API optional_idx BindFunction(const string &name, WindowFunctionSet &functions, - const vector &arguments, ErrorData &error); - DUCKDB_API unique_ptr BindScalarFunction(const string &schema, const string &name, + //! Bind a pragma function from the set of functions and input arguments + DUCKDB_API optional_idx BindFunction(const Identifier &name, const PragmaFunctionSet &functions, + vector ¶meters, ErrorData &error); + + DUCKDB_API unique_ptr BindScalarFunction(const Identifier &schema, const Identifier &name, vector> children, ErrorData &error, bool is_operator = false, optional_ptr binder = nullptr); - DUCKDB_API unique_ptr BindScalarFunction(ScalarFunctionCatalogEntry &function, + + DUCKDB_API unique_ptr BindScalarFunction(const ScalarFunctionCatalogEntry &function, vector> children, ErrorData &error, bool is_operator = false, optional_ptr binder = nullptr); @@ -66,60 +104,124 @@ class FunctionBinder { bool is_operator = false, optional_ptr binder = nullptr); + //! Bind a scalar function from a catalog entry given the full list of (maybe-named) bound arguments. The + //! positional/named split is resolved per candidate overload (overloads flagged to capture argument aliases + //! treat every argument as positional and keep its alias). 
+ DUCKDB_API unique_ptr BindScalarFunction(const ScalarFunctionCatalogEntry &function, + vector>> arguments, + ErrorData &error, bool is_operator = false, + optional_ptr binder = nullptr); + + DUCKDB_API unique_ptr BindScalarFunction(const ScalarFunction &bound_function, + vector> children, + vector>> keyword_args, + bool is_operator = false, + optional_ptr binder = nullptr); + DUCKDB_API unique_ptr BindAggregateFunction(const AggregateFunction &bound_function, vector> children, unique_ptr filter = nullptr, AggregateType aggr_type = AggregateType::NON_DISTINCT); + DUCKDB_API unique_ptr + BindAggregateFunction(const AggregateFunction &function, vector> children, + vector>> keyword_args, unique_ptr filter, + AggregateType aggr_type); + + DUCKDB_API unique_ptr + BindAggregateFunction(const AggregateFunctionCatalogEntry &function, + vector>> arguments, ErrorData &error, + unique_ptr filter = nullptr, + AggregateType aggr_type = AggregateType::NON_DISTINCT); + DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, const vector> &groups, optional_ptr> grouping_sets); DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr); + DUCKDB_API unique_ptr + BindWindowFunction(const WindowFunction &function, vector> children, + vector>> keyword_args, vector &orders, + vector &arg_orders); + DUCKDB_API unique_ptr BindWindowFunction(const WindowFunction &function, vector> children, vector &orders, vector &arg_orders); - //! 
Cast a set of expressions to the arguments of this function - void CastToFunctionArguments(BoundSimpleFunction &function, vector> &children); - - void ResolveTemplateTypes(BoundSimpleFunction &bound_function, const vector> &children); - void CheckTemplateTypesResolved(const BoundSimpleFunction &bound_function); + DUCKDB_API unique_ptr + BindWindowFunction(const WindowFunctionCatalogEntry &function, + vector>> arguments, ErrorData &error, + vector &orders, vector &arg_orders); pair> ResolveFunction(const ScalarFunction &function, - vector> &children); + vector> &children) { + vector>> empty_keyword_args; + return ResolveFunction(function, children, empty_keyword_args); + } + + pair> + ResolveFunction(const ScalarFunction &function, vector> &children, + vector>> &keyword_args); + + pair> + ResolveFunction(const AggregateFunction &function, vector> &children, + vector>> &keyword_args); pair> ResolveFunction(const AggregateFunction &function, - vector> &children); + vector> &children) { + vector>> empty_keyword_args; + return ResolveFunction(function, children, empty_keyword_args); + } pair> ResolveFunction(const WindowFunction &function, vector> &children, + vector>> &keyword_args, optional_ptr> orders = nullptr, optional_ptr> arg_orders = nullptr); + pair> ResolveFunction(const WindowFunction &function, + vector> &children) { + vector>> empty_keyword_args; + return ResolveFunction(function, children, empty_keyword_args); + } + private: - optional_idx BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments); - optional_idx BindFunctionCost(const SimpleFunction &func, const vector &arguments); + //! 
Cast a set of expressions to the arguments of this function + void CastToFunctionArguments(BoundSimpleFunction &function, vector> &children); + + void ResolveTemplateTypes(BoundSimpleFunction &bound_function, const vector> &children); + void CheckTemplateTypesResolved(const BoundSimpleFunction &bound_function); + + optional_idx BindFunctionCost(const SimpleFunction &func, const vector &arguments, + const vector> &named_arguments); optional_idx BindVarArgsFunctionCost(const SimpleNamedParameterFunction &func, const vector &arguments); - optional_idx BindFunctionCost(const SimpleNamedParameterFunction &func, const vector &arguments); + optional_idx BindFunctionCost(const SimpleNamedParameterFunction &func, const vector &arguments, + const vector> &); template - vector BindFunctionsFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error); + vector BindFunctionsFromArguments(const Identifier &name, const FunctionSet &functions, + const vector &arguments, + const vector> &named_arguments, + ErrorData &error); template - optional_idx MultipleCandidateException(const string &catalog_name, const string &schema_name, const string &name, - FunctionSet &functions, vector &candidate_functions, - const vector &arguments, ErrorData &error); + optional_idx BindFunctionFromArguments(const Identifier &name, const FunctionSet &functions, + const vector &arguments, + const vector> &named_arguments, + ErrorData &error); + //! Select the best matching overload for the given full (maybe-named) argument list. 
template - optional_idx BindFunctionFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error); + optional_idx BindFunctionFromArguments(const Identifier &name, const FunctionSet &functions, + vector>> &arguments, + ErrorData &error); - vector GetLogicalTypesFromExpressions(vector> &arguments); + pair, vector>> + GetArgumentsFromExpressions(const vector> ®ular_arguments, + const vector>> &keyword_arguments); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/function_serialization.hpp b/src/duckdb/src/include/duckdb/function/function_serialization.hpp index df778a8f2..dc9b4bcc7 100644 --- a/src/duckdb/src/include/duckdb/function/function_serialization.hpp +++ b/src/duckdb/src/include/duckdb/function/function_serialization.hpp @@ -27,8 +27,8 @@ class FunctionSerializer { // These are optional fields that are written out of numeric order, older // databases won't contain the fields, so the defaults will be used, but if // the fields are present, they will be used. 
- serializer.WritePropertyWithDefault(505, "catalog_name", function.GetCatalogName(), ""); - serializer.WritePropertyWithDefault(506, "schema_name", function.GetSchemaName(), ""); + serializer.WritePropertyWithDefault(505, "catalog_name", function.GetCatalogName(), Identifier()); + serializer.WritePropertyWithDefault(506, "schema_name", function.GetSchemaName(), Identifier()); bool has_serialize = function.HasSerializationCallbacks(); serializer.WriteProperty(503, "has_serialize", has_serialize); @@ -40,16 +40,18 @@ class FunctionSerializer { } template - static FUNC DeserializeFunction(ClientContext &context, CatalogType catalog_type, const string &catalog_name, - const string &schema_name, const string &name, const vector &arguments, + static FUNC DeserializeFunction(ClientContext &context, CatalogType catalog_type, const Identifier &catalog_name, + const Identifier &schema_name, const Identifier &name, + const vector &arguments, const vector &original_arguments) { EntryLookupInfo lookup_info(catalog_type, name); auto &func_catalog = - Catalog::GetEntry(context, catalog_type, catalog_name.empty() ? SYSTEM_CATALOG : catalog_name, - schema_name.empty() ? DEFAULT_SCHEMA : schema_name, name); + Catalog::GetEntry(context, catalog_type, catalog_name.empty() ? Identifier::SystemCatalog() : catalog_name, + schema_name.empty() ? 
Identifier::DefaultSchema() : schema_name, name); if (func_catalog.type != catalog_type) { - throw InternalException("DeserializeFunction - cant find catalog entry for function %s", name); + throw InternalException("DeserializeFunction - cant find catalog entry for function %s", + name.GetIdentifierName()); } auto &functions = func_catalog.Cast(); auto function = functions.functions.GetFunctionByArguments( @@ -61,16 +63,16 @@ class FunctionSerializer { static pair DeserializeBase(Deserializer &deserializer, CatalogType catalog_type, optional_ptr>> children = nullptr) { auto &context = deserializer.Get(); - auto name = deserializer.ReadProperty(500, "name"); + auto name = deserializer.ReadProperty(500, "name"); auto arguments = deserializer.ReadProperty>(501, "arguments"); auto original_arguments = deserializer.ReadProperty>(502, "original_arguments"); - auto catalog_name = deserializer.ReadPropertyWithDefault(505, "catalog_name"); - auto schema_name = deserializer.ReadPropertyWithDefault(506, "schema_name"); + auto catalog_name = deserializer.ReadPropertyWithDefault(505, "catalog_name"); + auto schema_name = deserializer.ReadPropertyWithDefault(506, "schema_name"); if (catalog_name.empty()) { - catalog_name = SYSTEM_CATALOG; + catalog_name = Identifier::SystemCatalog(); } if (schema_name.empty()) { - schema_name = DEFAULT_SCHEMA; + schema_name = Identifier::DefaultSchema(); } if (arguments.empty() && original_arguments.empty() && children && !children->empty()) { @@ -148,18 +150,18 @@ class FunctionSerializer { LogicalType return_type) { // NOLINT: clang-tidy bug auto &context = deserializer.Get(); - auto name = deserializer.ReadProperty(500, "name"); + auto name = deserializer.ReadProperty(500, "name"); auto arguments = deserializer.ReadProperty>(501, "arguments"); auto original_arguments = deserializer.ReadProperty>(502, "original_arguments"); - auto catalog_name = deserializer.ReadPropertyWithDefault(505, "catalog_name"); - auto schema_name = 
deserializer.ReadPropertyWithDefault(506, "schema_name"); + auto catalog_name = deserializer.ReadPropertyWithDefault(505, "catalog_name"); + auto schema_name = deserializer.ReadPropertyWithDefault(506, "schema_name"); auto has_serialize = deserializer.ReadProperty(503, "has_serialize"); if (catalog_name.empty()) { - catalog_name = SYSTEM_CATALOG; + catalog_name = Identifier::SystemCatalog(); } if (schema_name.empty()) { - schema_name = DEFAULT_SCHEMA; + schema_name = Identifier::DefaultSchema(); } if (arguments.empty() && original_arguments.empty() && !children.empty()) { @@ -177,7 +179,8 @@ class FunctionSerializer { auto &func_catalog = Catalog::GetEntry(context, catalog_type, catalog_name, schema_name, name); if (func_catalog.type != catalog_type) { - throw InternalException("DeserializeFunction - cant find catalog entry for function %s", name); + throw InternalException("DeserializeFunction - cant find catalog entry for function %s", + name.GetIdentifierName()); } auto &functions = func_catalog.Cast(); const auto &function = functions.functions.GetFunctionByArguments( diff --git a/src/duckdb/src/include/duckdb/function/function_set.hpp b/src/duckdb/src/include/duckdb/function/function_set.hpp index dfeeb3175..97f41fdc0 100644 --- a/src/duckdb/src/include/duckdb/function/function_set.hpp +++ b/src/duckdb/src/include/duckdb/function/function_set.hpp @@ -19,11 +19,16 @@ namespace duckdb { template class FunctionSet { public: - explicit FunctionSet(string name) : name(std::move(name)) { // NOLINT + explicit FunctionSet(Identifier name) : name(std::move(name)) { // NOLINT } //! The name of the function set - string name; + Identifier name; + +public: + void SetName(Identifier name_p) { + name = std::move(name_p); + } //! The set of functions. 
vector functions; @@ -35,7 +40,7 @@ class FunctionSet { return functions.size(); } - const T &GetFunctionByOffset(idx_t offset) { + const T &GetFunctionByOffset(idx_t offset) const { D_ASSERT(offset < functions.size()); return functions[offset]; } @@ -69,7 +74,7 @@ class FunctionSet { class ScalarFunctionSet : public FunctionSet { public: DUCKDB_API explicit ScalarFunctionSet(); - DUCKDB_API explicit ScalarFunctionSet(string name); + DUCKDB_API explicit ScalarFunctionSet(Identifier name); DUCKDB_API explicit ScalarFunctionSet(ScalarFunction fun); DUCKDB_API const ScalarFunction &GetFunctionByArguments(ClientContext &context, @@ -96,7 +101,7 @@ class ScalarFunctionSet : public FunctionSet { class AggregateFunctionSet : public FunctionSet { public: DUCKDB_API explicit AggregateFunctionSet(); - DUCKDB_API explicit AggregateFunctionSet(string name); + DUCKDB_API explicit AggregateFunctionSet(Identifier name); DUCKDB_API explicit AggregateFunctionSet(AggregateFunction fun); DUCKDB_API const AggregateFunction &GetFunctionByArguments(ClientContext &context, @@ -106,7 +111,7 @@ class AggregateFunctionSet : public FunctionSet { class WindowFunctionSet : public FunctionSet { public: DUCKDB_API explicit WindowFunctionSet(); - DUCKDB_API explicit WindowFunctionSet(string name); + DUCKDB_API explicit WindowFunctionSet(Identifier name); DUCKDB_API explicit WindowFunctionSet(WindowFunction fun); DUCKDB_API const WindowFunction &GetFunctionByArguments(ClientContext &context, @@ -115,7 +120,7 @@ class WindowFunctionSet : public FunctionSet { class TableFunctionSet : public FunctionSet { public: - DUCKDB_API explicit TableFunctionSet(string name); + DUCKDB_API explicit TableFunctionSet(Identifier name); DUCKDB_API explicit TableFunctionSet(TableFunction fun); DUCKDB_API const TableFunction &GetFunctionByArguments(ClientContext &context, @@ -124,7 +129,7 @@ class TableFunctionSet : public FunctionSet { class PragmaFunctionSet : public FunctionSet { public: - DUCKDB_API explicit 
PragmaFunctionSet(string name); + DUCKDB_API explicit PragmaFunctionSet(Identifier name); DUCKDB_API explicit PragmaFunctionSet(PragmaFunction fun); }; diff --git a/src/duckdb/src/include/duckdb/function/lambda_functions.hpp b/src/duckdb/src/include/duckdb/function/lambda_functions.hpp index 53d02e3ac..7fc9753eb 100644 --- a/src/duckdb/src/include/duckdb/function/lambda_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/lambda_functions.hpp @@ -108,7 +108,7 @@ class LambdaFunctions { // get the lambda expression auto &func_expr = state.expr.Cast(); - auto &bind_info = func_expr.bind_info->Cast(); + auto &bind_info = func_expr.BindInfo()->Cast(); lambda_expr = bind_info.lambda_expr; is_volatile = lambda_expr->IsVolatile(); has_index = bind_info.has_index; diff --git a/src/duckdb/src/include/duckdb/function/macro_function.hpp b/src/duckdb/src/include/duckdb/function/macro_function.hpp index d01a1304d..d153e7d6f 100644 --- a/src/duckdb/src/include/duckdb/function/macro_function.hpp +++ b/src/duckdb/src/include/duckdb/function/macro_function.hpp @@ -40,7 +40,7 @@ class MacroFunction { //! The parameters (ColumnRefExpression) vector> parameters; //! The default values of the parameters - InsertionOrderPreservingMap> default_parameters; + InsertionOrderPreservingMap, Identifier, identifier_map_t> default_parameters; //! 
The types of the parameters vector types; @@ -55,15 +55,16 @@ class MacroFunction { vector> GetPositionalParametersForSerialization(Serializer &serializer) const; void FinalizeDeserialization(); - static MacroBindResult BindMacroFunction(Binder &binder, const vector> ¯o_functions, - const string &name, FunctionExpression &function_expr, - vector> &positional_arguments, - InsertionOrderPreservingMap> &named_arguments, - idx_t depth); + static MacroBindResult BindMacroFunction( + Binder &binder, const vector> ¯o_functions, const Identifier &name, + FunctionExpression &function_expr, vector> &positional_arguments, + InsertionOrderPreservingMap, Identifier, identifier_map_t> &named_arguments, + idx_t depth); static unique_ptr - CreateDummyBinding(const MacroFunction ¯o_def, const string &name, + CreateDummyBinding(const MacroFunction ¯o_def, const Identifier &name, vector> &positional_arguments, - InsertionOrderPreservingMap> &named_arguments); + InsertionOrderPreservingMap, Identifier, identifier_map_t> + &named_arguments); virtual string ToSQL() const; diff --git a/src/duckdb/src/include/duckdb/function/pragma_function.hpp b/src/duckdb/src/include/duckdb/function/pragma_function.hpp index 91f55f3af..f06783b57 100644 --- a/src/duckdb/src/include/duckdb/function/pragma_function.hpp +++ b/src/duckdb/src/include/duckdb/function/pragma_function.hpp @@ -31,14 +31,15 @@ typedef void (*pragma_function_t)(ClientContext &context, const FunctionParamete class PragmaFunction : public SimpleNamedParameterFunction { // NOLINT: work-around bug in clang-tidy public: // Call - DUCKDB_API static PragmaFunction PragmaCall(const string &name, pragma_query_t query, vector arguments, + DUCKDB_API static PragmaFunction PragmaCall(const Identifier &name, pragma_query_t query, + vector arguments, LogicalType varargs = LogicalType::INVALID); - DUCKDB_API static PragmaFunction PragmaCall(const string &name, pragma_function_t function, + DUCKDB_API static PragmaFunction PragmaCall(const 
Identifier &name, pragma_function_t function, vector arguments, LogicalType varargs = LogicalType::INVALID); // Statement - DUCKDB_API static PragmaFunction PragmaStatement(const string &name, pragma_query_t query); - DUCKDB_API static PragmaFunction PragmaStatement(const string &name, pragma_function_t function); + DUCKDB_API static PragmaFunction PragmaStatement(const Identifier &name, pragma_query_t query); + DUCKDB_API static PragmaFunction PragmaStatement(const Identifier &name, pragma_function_t function); DUCKDB_API string ToString() const override; @@ -50,7 +51,7 @@ class PragmaFunction : public SimpleNamedParameterFunction { // NOLINT: work-aro named_parameter_type_map_t named_parameters; private: - PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, + PragmaFunction(Identifier name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, vector arguments, LogicalType varargs); }; diff --git a/src/duckdb/src/include/duckdb/function/register_function_list_helper.hpp b/src/duckdb/src/include/duckdb/function/register_function_list_helper.hpp index a0cf0acc7..30deed9e4 100644 --- a/src/duckdb/src/include/duckdb/function/register_function_list_helper.hpp +++ b/src/duckdb/src/include/duckdb/function/register_function_list_helper.hpp @@ -30,26 +30,28 @@ static void FillFunctionParameters(FunctionDescription &function_description, co } } -static vector GetExamplesForFunctionAlias(const string &function_name, const string &alias_of, +static vector GetExamplesForFunctionAlias(const Identifier &function_name, const Identifier &alias_of, vector &all_examples) { vector filtered_examples; - bool is_operator = (!function_name.empty() && !(function_name[0] >= 'a' && function_name[0] <= 'z') && - !(function_name[0] >= 'A' && function_name[0] <= 'Z')); - bool alias_of_is_operator = (!alias_of.empty() && !(alias_of[0] >= 'a' && alias_of[0] <= 'z') && - !(alias_of[0] >= 'A' && alias_of[0] <= 'Z')); + auto 
&function_name_str = function_name.GetIdentifierName(); + auto &alias_of_str = alias_of.GetIdentifierName(); + bool is_operator = (!function_name_str.empty() && !(function_name_str[0] >= 'a' && function_name_str[0] <= 'z') && + !(function_name_str[0] >= 'A' && function_name_str[0] <= 'Z')); + bool alias_of_is_operator = (!alias_of_str.empty() && !(alias_of_str[0] >= 'a' && alias_of_str[0] <= 'z') && + !(alias_of_str[0] >= 'A' && alias_of_str[0] <= 'Z')); // select examples with matching function name for (string &example : all_examples) { - if (example.compare(0, function_name.size(), function_name) == 0 || - (is_operator && example.find(function_name) != string::npos)) { + if (example.compare(0, function_name_str.size(), function_name_str) == 0 || + (is_operator && example.find(function_name_str) != string::npos)) { filtered_examples.emplace_back(std::move(example)); } } // fallback 1: create fitting examples by replacing canonical name by function_name if (filtered_examples.empty() && !alias_of.empty() && !alias_of_is_operator && !is_operator) { for (string &example : all_examples) { - if (example.compare(0, alias_of.size(), alias_of) == 0) { - filtered_examples.emplace_back(function_name + - example.substr(alias_of.size(), example.size() - alias_of.size())); + if (example.compare(0, alias_of_str.size(), alias_of_str) == 0) { + filtered_examples.emplace_back( + function_name + example.substr(alias_of_str.size(), example.size() - alias_of_str.size())); } } } diff --git a/src/duckdb/src/include/duckdb/function/scalar/comparison_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/comparison_functions.hpp index 2235f286a..2f370d99b 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/comparison_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/comparison_functions.hpp @@ -25,11 +25,81 @@ struct BetweenFun { static ScalarFunction GetFunction(); }; -struct ComparisonFun { - static constexpr const char *Name = "__comparison"; +struct 
OperatorEqualFun { + static constexpr const char *Name = "="; static constexpr const char *Parameters = "left,right"; - static constexpr const char *Description = "Performs a comparison between left / right."; - static constexpr const char *Example = "__comparison(3, 0"; + static constexpr const char *Description = "Returns whether left is equal to right."; + static constexpr const char *Example = "1 = 1"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct OperatorNotEqualFun { + static constexpr const char *Name = "!="; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is not equal to right."; + static constexpr const char *Example = "1 != 2"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct OperatorLessThanFun { + static constexpr const char *Name = "<"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is less than right."; + static constexpr const char *Example = "1 < 2"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct OperatorLessThanEqualsFun { + static constexpr const char *Name = "<="; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is less than or equal to right."; + static constexpr const char *Example = "1 <= 2"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct OperatorGreaterThanFun { + static constexpr const char *Name = ">"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is greater than right."; + static constexpr const char *Example = "2 > 1"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + 
+struct OperatorGreaterThanEqualsFun { + static constexpr const char *Name = ">="; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is greater than or equal to right."; + static constexpr const char *Example = "2 >= 1"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct IsDistinctFromFun { + static constexpr const char *Name = "IS DISTINCT FROM"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is distinct from right, treating null as a comparable value."; + static constexpr const char *Example = "1 IS DISTINCT FROM NULL"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct IsNotDistinctFromFun { + static constexpr const char *Name = "IS NOT DISTINCT FROM"; + static constexpr const char *Parameters = "left,right"; + static constexpr const char *Description = "Returns whether left is not distinct from right, treating null as a comparable value."; + static constexpr const char *Example = "NULL IS NOT DISTINCT FROM NULL"; static constexpr const char *Categories = ""; static ScalarFunction GetFunction(); diff --git a/src/duckdb/src/include/duckdb/function/scalar/generic_common.hpp b/src/duckdb/src/include/duckdb/function/scalar/generic_common.hpp index f4a4089e6..72c497eca 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/generic_common.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/generic_common.hpp @@ -31,6 +31,7 @@ struct ExportAggregateFunctionBindData : public FunctionData { struct ExportAggregateFunction { static unique_ptr Bind(unique_ptr child_aggregate); + static void SetStateExport(BoundAggregateExpression &aggregate, LogicalType state_layout); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp 
b/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp index 5ade05404..482eb13cb 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/list/contains_or_position.hpp @@ -7,7 +7,8 @@ namespace duckdb { template -idx_t ListSearchSimpleOp(Vector &input_list, Vector &list_child, Vector &target, Vector &result, const idx_t count) { +idx_t ListSearchSimpleOp(const Vector &input_list, const Vector &list_child, const Vector &target, Vector &result, + const idx_t count) { // If the return type is not a bool, return the position const auto return_pos = std::is_same::value; @@ -85,7 +86,7 @@ idx_t ListSearchSimpleOp(Vector &input_list, Vector &list_child, Vector &target, } template -idx_t ListSearchNestedOp(Vector &list_vec, Vector &source_vec, Vector &target_vec, Vector &result_vec, +idx_t ListSearchNestedOp(const Vector &list_vec, const Vector &source_vec, const Vector &target_vec, Vector &result_vec, const idx_t target_count) { // Set up sort keys for nested types. auto source_count = ListVector::GetListSize(list_vec); @@ -93,6 +94,9 @@ idx_t ListSearchNestedOp(Vector &list_vec, Vector &source_vec, Vector &target_ve Vector target_sort_key_vec(LogicalType::BLOB, target_count); const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); + // Pass explicit counts: source_vec and target_vec may have different sizes than source_count/target_count + // when called via dictionary expression optimization (TryExecuteDictionaryExpression uses the dictionary + // size as the chunk count, while the non-key arguments retain their original size). 
CreateSortKeyHelpers::CreateSortKeyWithValidity(source_vec, source_sort_key_vec, order_modifiers, source_count); CreateSortKeyHelpers::CreateSortKeyWithValidity(target_vec, target_sort_key_vec, order_modifiers, target_count); @@ -105,7 +109,8 @@ idx_t ListSearchNestedOp(Vector &list_vec, Vector &source_vec, Vector &target_ve //! usually the "source" vector is the list child vector, but it is passed separately to enable searching nested //! children, for example when searching the keys of a MAP vectors. template -idx_t ListSearchOp(Vector &list_v, Vector &source_v, Vector &target_v, Vector &result_v, idx_t target_count) { +idx_t ListSearchOp(const Vector &list_v, const Vector &source_v, const Vector &target_v, Vector &result_v, + idx_t target_count) { const auto type = target_v.GetType().InternalType(); switch (type) { case PhysicalType::BOOL: diff --git a/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp index ffbb747b5..e332c581d 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/nested_functions.hpp @@ -50,7 +50,7 @@ struct PositionFunctor { }; struct MapUtil { - static void ReinterpretMap(Vector &target, Vector &other, idx_t count); + static void ReinterpretMap(Vector &target, const Vector &other); }; struct VariableReturnBindData : public FunctionData { diff --git a/src/duckdb/src/include/duckdb/function/scalar/operator_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/operator_functions.hpp index 4b124a44f..bd4473fae 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/operator_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/operator_functions.hpp @@ -105,4 +105,14 @@ struct ModFun { static constexpr const char *Name = "mod"; }; +struct DecimalDivisionFun { + static constexpr const char *Name = "decimal_division"; + static constexpr const char *Parameters = 
"x,y[,scale]"; + static constexpr const char *Description = "Divides two DECIMAL values using exact integer arithmetic, returning a DECIMAL result with scale determined by SQL Server semantics (result_scale = max(6, s1 + p2 + 1))."; + static constexpr const char *Example = "decimal_division(10.00::DECIMAL(10,2), 3.00::DECIMAL(10,2))"; + static constexpr const char *Categories = ""; + + static ScalarFunctionSet GetFunctions(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp b/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp index 549c44393..9bf059d53 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/strftime_format.hpp @@ -105,9 +105,9 @@ struct StrfTimeFormat : public StrTimeFormat { // NOLINT: work-around bug in cla DUCKDB_API static string Format(timestamp_t timestamp, const string &format); - DUCKDB_API void ConvertDateVector(Vector &input, Vector &result, idx_t count); - DUCKDB_API void ConvertTimestampVector(Vector &input, Vector &result, idx_t count); - DUCKDB_API void ConvertTimestampNSVector(Vector &input, Vector &result, idx_t count); + DUCKDB_API void ConvertDateVector(const Vector &input, Vector &result); + DUCKDB_API void ConvertTimestampVector(const Vector &input, Vector &result); + DUCKDB_API void ConvertTimestampNSVector(const Vector &input, Vector &result); protected: //! The variable-length specifiers. To determine total string size, these need to be checked. 
diff --git a/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp index 2be2de359..ca30a91bb 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/string_functions.hpp @@ -239,6 +239,16 @@ struct ArrayLengthFun { static ScalarFunctionSet GetFunctions(); }; +struct OverlayFun { + static constexpr const char *Name = "overlay"; + static constexpr const char *Parameters = "string,replacement,start,count"; + static constexpr const char *Description = "Replaces a substring of `string` starting at character `start` with `replacement`. If `count` is not specified, it defaults to the length of `replacement`. Note that a `start` value of `1` refers to the first character of the `string`."; + static constexpr const char *Example = "overlay('hello' placing 'xyz' from 2 for 3)\002overlay('hello' placing 'xyz' from 2)"; + static constexpr const char *Categories = "string"; + + static ScalarFunctionSet GetFunctions(); +}; + struct SubstringFun { static constexpr const char *Name = "substring"; static constexpr const char *Parameters = "string,start,length"; diff --git a/src/duckdb/src/include/duckdb/function/scalar/struct_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/struct_utils.hpp index 443d4e55f..0461b51a9 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/struct_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/struct_utils.hpp @@ -33,13 +33,14 @@ struct StructExtractBindData : public FunctionData { }; inline bool TryGetStructExtractChildIndex(const BoundFunctionExpression &func, idx_t &child_idx) { - if (func.function.GetName() == "struct_extract_at") { - if (func.bind_info) { - child_idx = func.bind_info->Cast().index; + if (func.Function().GetName() == "struct_extract_at") { + if (func.BindInfo()) { + child_idx = func.BindInfo()->Cast().index; return true; } - if 
(func.children.size() > 1 && func.children[1]->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { - auto &field_value = func.children[1]->Cast().value; + if (func.GetChildren().size() > 1 && + func.GetChildren()[1]->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { + auto &field_value = func.GetChildren()[1]->Cast().GetValue(); if (field_value.IsNull()) { return false; } @@ -52,16 +53,16 @@ inline bool TryGetStructExtractChildIndex(const BoundFunctionExpression &func, i } return false; } - if (func.function.GetName() != "struct_extract" || func.children.size() <= 1 || - func.children[1]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT || - func.children[0]->GetReturnType().id() != LogicalTypeId::STRUCT) { + if (func.Function().GetName() != "struct_extract" || func.GetChildren().size() <= 1 || + func.GetChildren()[1]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT || + func.GetChildren()[0]->GetReturnType().id() != LogicalTypeId::STRUCT) { return false; } - auto &field_value = func.children[1]->Cast().value; + auto &field_value = func.GetChildren()[1]->Cast().GetValue(); if (field_value.IsNull() || field_value.type().id() != LogicalTypeId::VARCHAR) { return false; } - child_idx = StructType::GetChildIndexUnsafe(func.children[0]->GetReturnType(), field_value.GetValue()); + child_idx = StructType::GetChildIndexUnsafe(func.GetChildren()[0]->GetReturnType(), field_value.GetValue()); return true; } diff --git a/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp index 6db5eadf1..c4ae6b3a6 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp @@ -22,7 +22,7 @@ struct FinalizeFun { static constexpr const char *Example = ""; static constexpr const char *Categories = ""; - static ScalarFunctionSet GetFunctions(); + static ScalarFunction GetFunction(); }; struct 
CombineFun { @@ -32,7 +32,17 @@ struct CombineFun { static constexpr const char *Example = ""; static constexpr const char *Categories = ""; - static ScalarFunctionSet GetFunctions(); + static ScalarFunction GetFunction(); +}; + +struct ToAggregateStateFun { + static constexpr const char *Name = "to_aggregate_state"; + static constexpr const char *Parameters = "data,name,signature"; + static constexpr const char *Description = "Converts a value into the aggregate state of the aggregate function with the given name and signature. The type of the value must exactly match the state layout of the aggregate function."; + static constexpr const char *Example = "to_aggregate_state({'count': 1, 'value': 42.0}, 'avg', ['DOUBLE'])"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); }; struct WriteLogFun { diff --git a/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp new file mode 100644 index 000000000..65f83baa3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// function/scalar/tablefilter_functions.hpp +// +// +//===----------------------------------------------------------------------===// +// This file is automatically generated by scripts/generate_functions.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/function/function_set.hpp" + +namespace duckdb { + +struct TableFilterBloomFilterFun { + static constexpr const char *Name = "__internal_tablefilter_bloom_filter"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char 
*Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct TableFilterDynamicFun { + static constexpr const char *Name = "__internal_tablefilter_dynamic"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct TableFilterOptionalFun { + static constexpr const char *Name = "__internal_tablefilter_optional"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct TableFilterPerfectHashJoinFun { + static constexpr const char *Name = "__internal_tablefilter_perfect_hash_join"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct TableFilterPrefixRangeFun { + static constexpr const char *Name = "__internal_tablefilter_prefix_range"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct TableFilterSelectivityOptionalFun { + static constexpr const char *Name = "__internal_tablefilter_selectivity_optional"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp 
index c318a9236..0d33059f9 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp @@ -15,6 +15,16 @@ namespace duckdb { +struct VariantExistsFun { + static constexpr const char *Name = "variant_exists"; + static constexpr const char *Parameters = "input_variant::VARIANT,path::VARCHAR\001input_variant::VARIANT,path::VARCHAR[]"; + static constexpr const char *Description = "Returns whether the specified path exists in the variant.\001Returns a boolean for each specified path, indicating whether that path exists in the variant."; + static constexpr const char *Example = "variant_exists({'a': { 'a': 1, 'b': 2}}::VARIANT, 'a')\001variant_exists({'a': 1, 'b': 2}::VARIANT, ['a', 'c'])"; + static constexpr const char *Categories = "variant\001variant"; + + static ScalarFunctionSet GetFunctions(); +}; + struct VariantExtractFun { static constexpr const char *Name = "variant_extract"; static constexpr const char *Parameters = "input_variant::VARIANT,field::VARCHAR\001input_variant::VARIANT,index::UINTEGER"; @@ -45,4 +55,14 @@ struct VariantTypeofFun { static ScalarFunction GetFunction(); }; +struct VariantKeysFun { + static constexpr const char *Name = "variant_keys"; + static constexpr const char *Parameters = "input_variant::VARIANT\001input_variant::VARIANT,path::VARCHAR\001input_variant::VARIANT,path::VARCHAR[]"; + static constexpr const char *Description = "Returns the keys present at the root level of the object.\001Returns the keys present at the level of the specified path.\001Returns a list of keys present at the level of each specified path."; + static constexpr const char *Example = "variant_keys({'a': { 'a': 1, 'b': 2}}::VARIANT)\001variant_keys({'a': { 'a': 1, 'b': 2}}::VARIANT, 'a')\001variant_keys({'a': { 'a': 1, 'b': 2}, 'b': {'c': 3}}::VARIANT, ['a', 'b'])"; + static constexpr const char *Categories = "variant\001variant\001variant"; + + static 
ScalarFunctionSet GetFunctions(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index 33eeb3270..78181dd6e 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -16,6 +16,21 @@ namespace duckdb { +struct VariantPathBindData : public FunctionData { +public: + explicit VariantPathBindData(); + explicit VariantPathBindData(const string &input_path); + explicit VariantPathBindData(const vector &input_paths); + VariantPathBindData(const VariantPathBindData &other) = default; + +public: + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; + +public: + vector> paths; +}; + struct VariantExtractBindData : public FunctionData { public: explicit VariantExtractBindData(const string &str); @@ -74,7 +89,38 @@ struct VariantChildDataCollectionResult { idx_t nested_data_index; }; +struct VariantPathSelection { + explicit VariantPathSelection(const idx_t count) { + value_index_sel.Initialize(count); + new_value_index_sel.Initialize(count); + + // We start at values[0] for every row. + for (idx_t i = 0; i < count; i++) { + value_index_sel[i] = 0; + } + } + + SelectionVector &Input(const idx_t depth) { + return depth % 2 == 0 ? value_index_sel : new_value_index_sel; + } + + SelectionVector &Output(const idx_t depth) { + return depth % 2 == 0 ? new_value_index_sel : value_index_sel; + } + + //! Input and output buffers used during the object walk, switched per iteration + SelectionVector value_index_sel, new_value_index_sel; +}; + struct VariantUtils { + using unary_path_function_t = void (*)(const Vector &, const vector &, Vector &, idx_t); + using many_path_function_t = void (*)(const Vector &, const vector> &, Vector &, + idx_t); + + //! 
Generic dispatcher for function execution, used by functions that take an (optional) path(s) + DUCKDB_API static void ExecutePathFunction(DataChunk &input, const ExpressionState &state, Vector &result, + const unary_path_function_t &unary_fn, + const many_path_function_t &many_fn); DUCKDB_API static bool IsNestedType(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index); DUCKDB_API static VariantDecimalData DecodeDecimalData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index); @@ -87,21 +133,26 @@ struct VariantUtils { DUCKDB_API static void FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, optional_ptr sel, SelectionVector &res, - ValidityMask &res_validity, const VariantNestedData *nested_data, + ValidityMask &res_validity, const array_ptr &nested_data, const ValidityMask &validity, idx_t count); DUCKDB_API static VariantNestedDataCollectionResult CollectNestedData(const UnifiedVariantVectorData &variant, VariantLogicalType expected_type, const SelectionVector &value_index_sel, idx_t count, optional_idx row, idx_t offset, - VariantNestedData *child_data, ValidityMask &validity); + array_ptr child_data, ValidityMask &validity); + //! 
Generic unshredded traversal over a variant using a provided path + DUCKDB_API static void TraversePath(const UnifiedVariantVectorData &variant, + const vector &components, idx_t count, + array_ptr nested_data, ValidityMask &validity, + VariantPathSelection &path_selection); DUCKDB_API static vector ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, idx_t count, optional_idx row); DUCKDB_API static Value ConvertVariantToValue(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_idx); - DUCKDB_API static bool Verify(Vector &variant, const SelectionVector &sel_p, idx_t count); + DUCKDB_API static bool Verify(const Vector &variant, const SelectionVector &sel_p, idx_t count); DUCKDB_API static void FinalizeVariantKeys(Vector &variant, OrderedOwningStringMap &dictionary, SelectionVector &sel, idx_t sel_size); - DUCKDB_API static void VariantExtract(Vector &input, const vector &components, Vector &result, - idx_t count); + DUCKDB_API static void VariantExtract(const Vector &input, const vector &components, + Vector &result, idx_t count); DUCKDB_API static void UnshredVariantData(Vector &input, Vector &output, idx_t count); //! 
Returns the type of a shredded vector that is shredded on "type" DUCKDB_API static LogicalType ShreddedType(const LogicalType &type); @@ -110,4 +161,9 @@ struct VariantUtils { DUCKDB_API static bool VariantSupportsType(const LogicalType &type); }; +struct VariantBindUtils { + DUCKDB_API static unique_ptr VariantPathBind(BindScalarFunctionInput &input); + DUCKDB_API static bool GetConstantArgument(ClientContext &context, const Expression &expr, Value &constant_arg); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar_function.hpp b/src/duckdb/src/include/duckdb/function/scalar_function.hpp index 23816cfd3..d40363fb2 100644 --- a/src/duckdb/src/include/duckdb/function/scalar_function.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar_function.hpp @@ -74,12 +74,12 @@ class ScalarFunctionCatalogEntry; struct StatementProperties; struct FunctionStatisticsPruneInput { - FunctionStatisticsPruneInput(optional_ptr bind_data_p, BaseStatistics &stats_p) + FunctionStatisticsPruneInput(optional_ptr bind_data_p, const BaseStatistics &stats_p) : bind_data(bind_data_p), stats(stats_p) { } optional_ptr bind_data; - BaseStatistics &stats; + const BaseStatistics &stats; }; struct FunctionStatisticsInput { @@ -238,6 +238,9 @@ class BaseScalarFunction { auto GetCollationHandling() const -> FunctionCollationHandling { return properties.collation_handling; } auto SetCollationHandling(FunctionCollationHandling value) -> void { properties.collation_handling = value; } + auto GetCaptureArgumentAliases() const -> bool { return properties.capture_argument_aliases; } + auto SetCaptureArgumentAliases(bool value) -> void { properties.capture_argument_aliases = value; } + //! Set this functions error-mode as fallible (can throw runtime errors) void SetFallible() { properties.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; } //! 
Set this functions stability as volatile (can not be cached per row) @@ -362,7 +365,8 @@ class BaseScalarFunction { class ScalarFunction : public BaseScalarFunction, public SimpleFunction { // NOLINT: work-around bug in clang-tidy public: - DUCKDB_API ScalarFunction(string name, vector arguments, LogicalType return_type, + DUCKDB_API ScalarFunction(Identifier name, FunctionSignature sig, scalar_function_t function); + DUCKDB_API ScalarFunction(Identifier name, vector arguments, LogicalType return_type, scalar_function_t function, bind_scalar_function_t bind = nullptr, function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, LogicalType varargs = LogicalType(LogicalTypeId::INVALID), @@ -393,20 +397,19 @@ class ScalarFunction : public BaseScalarFunction, template static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() >= 1); - UnaryExecutor::Execute(input.data[0], result, input.size()); + UnaryExecutor::Execute(input.data[0], result); } template static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 2); - BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, input.size()); + BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result); } template static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 3); - TernaryExecutor::ExecuteStandard(input.data[0], input.data[1], input.data[2], result, - input.size()); + TernaryExecutor::ExecuteStandard(input.data[0], input.data[1], input.data[2], result); } public: diff --git a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp index e553a6d41..f853074d3 100644 --- a/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp +++ 
b/src/duckdb/src/include/duckdb/function/table/direct_file_reader.hpp @@ -19,7 +19,7 @@ class DirectFileReader : public BaseFileReader { ~DirectFileReader() override; public: - unique_ptr GetStatistics(ClientContext &context, const string &name) override; + unique_ptr GetStatistics(ClientContext &context, const Identifier &name) override; bool TryInitializeScan(ClientContext &context, GlobalTableFunctionState &gstate, LocalTableFunctionState &lstate) override; diff --git a/src/duckdb/src/include/duckdb/function/table/read_csv.hpp b/src/duckdb/src/include/duckdb/function/table/read_csv.hpp index 62a7f0af8..d64577f11 100644 --- a/src/duckdb/src/include/duckdb/function/table/read_csv.hpp +++ b/src/duckdb/src/include/duckdb/function/table/read_csv.hpp @@ -41,8 +41,8 @@ struct BaseCSVData : public TableFunctionData { }; struct WriteCSVData : public BaseCSVData { - explicit WriteCSVData(vector names) { - options.name_list = std::move(names); + explicit WriteCSVData(const vector &names) { + options.name_list = IdentifiersToStrings(names); if (options.dialect_options.state_machine_options.escape == '\0') { options.dialect_options.state_machine_options.escape = options.dialect_options.state_machine_options.quote; } diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp index cbe462bf6..6e529e589 100644 --- a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -123,6 +123,10 @@ struct DuckDBExternalFileCacheFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBMetricsFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBOptimizersFun { static void RegisterFunction(BuiltinFunctions &set); }; diff --git a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp index f7e684c7e..b7be2d10c 
100644 --- a/src/duckdb/src/include/duckdb/function/table/table_scan.hpp +++ b/src/duckdb/src/include/duckdb/function/table/table_scan.hpp @@ -10,6 +10,7 @@ #include "duckdb/function/table_function.hpp" #include "duckdb/common/atomic.hpp" +#include "duckdb/common/unordered_set.hpp" #include "duckdb/function/built_in_functions.hpp" namespace duckdb { @@ -30,6 +31,8 @@ struct TableScanBindData : public TableFunctionData { bool is_create_index; //! In what order to scan the row groups unique_ptr order_options; + //! Subset of partition indices to scan, if null, scan all + unique_ptr> partitions_to_scan; public: bool Equals(const FunctionData &other_p) const override { @@ -42,6 +45,8 @@ struct TableScanBindData : public TableFunctionData { bind_data->is_create_index = is_create_index; bind_data->column_ids = column_ids; bind_data->order_options = order_options ? make_uniq(*order_options) : nullptr; + bind_data->partitions_to_scan = + partitions_to_scan ? make_uniq>(*partitions_to_scan) : nullptr; return std::move(bind_data); } }; diff --git a/src/duckdb/src/include/duckdb/function/table_function.hpp b/src/duckdb/src/include/duckdb/function/table_function.hpp index a8042f433..1ce75139a 100644 --- a/src/duckdb/src/include/duckdb/function/table_function.hpp +++ b/src/duckdb/src/include/duckdb/function/table_function.hpp @@ -17,6 +17,7 @@ #include "duckdb/storage/statistics/node_statistics.hpp" #include "duckdb/storage/table/row_group_reorderer.hpp" #include "duckdb/common/column_index.hpp" +#include "duckdb/common/enums/metric_type.hpp" #include "duckdb/common/table_column.hpp" #include "duckdb/parallel/async_result.hpp" #include "duckdb/function/partition_stats.hpp" @@ -43,6 +44,7 @@ class SampleOptions; struct MultiFileReader; struct OperatorPartitionData; struct OperatorPartitionInfo; +struct OperatorMetrics; enum class OrderByColumnType : uint8_t; enum class OrderType : uint8_t; enum class OrderByStatistics : uint8_t; @@ -103,7 +105,7 @@ struct LocalTableFunctionState 
{ struct TableFunctionBindInput { TableFunctionBindInput(vector &inputs, named_parameter_map_t &named_parameters, - vector &input_table_types, vector &input_table_names, + vector &input_table_types, vector &input_table_names, optional_ptr info, optional_ptr binder, TableFunction &table_function, const TableFunctionRef &ref, optional_ptr> input_plan = nullptr) @@ -115,7 +117,7 @@ struct TableFunctionBindInput { vector &inputs; named_parameter_map_t &named_parameters; vector &input_table_types; - vector &input_table_names; + vector &input_table_names; optional_ptr info; optional_ptr binder; TableFunction &table_function; @@ -201,20 +203,6 @@ struct TableFunctionToStringInput { optional_ptr bind_data; }; -struct TableFunctionDynamicToStringInput { - TableFunctionDynamicToStringInput(const TableFunction &table_function_p, - optional_ptr bind_data_p, - optional_ptr local_state_p, - optional_ptr global_state_p) - : table_function(table_function_p), bind_data(bind_data_p), local_state(local_state_p), - global_state(global_state_p) { - } - const TableFunction &table_function; - optional_ptr bind_data; - optional_ptr local_state; - optional_ptr global_state; -}; - struct TableFunctionGetPartitionInput { public: TableFunctionGetPartitionInput(optional_ptr bind_data_p, @@ -252,6 +240,23 @@ struct GetPartitionStatsInput { optional_ptr bind_data; }; +struct TableFunctionGetMetricsInput { +public: + TableFunctionGetMetricsInput(ClientContext &context, optional_ptr bind_data_p, + optional_ptr local_state_p, + optional_ptr global_state_p, OperatorMetrics &metrics_p) + : context(context), bind_data(bind_data_p), local_state(local_state_p), global_state(global_state_p), + operator_metrics(metrics_p) { + } + +public: + ClientContext &context; + optional_ptr bind_data; + optional_ptr local_state; + optional_ptr global_state; + OperatorMetrics &operator_metrics; +}; + enum class ScanType : uint8_t { TABLE, PARQUET, EXTERNAL }; struct BindInfo { @@ -331,15 +336,12 @@ typedef double 
(*table_function_progress_t)(ClientContext &context, const Functi typedef void (*table_function_dependency_t)(LogicalDependencyList &dependencies, const FunctionData *bind_data); typedef unique_ptr (*table_function_cardinality_t)(ClientContext &context, const FunctionData *bind_data); -typedef idx_t (*table_function_rows_scanned_t)(GlobalTableFunctionState &global_state, - LocalTableFunctionState &local_state); +typedef void (*table_function_get_metrics_t)(TableFunctionGetMetricsInput &input); typedef void (*table_function_pushdown_complex_filter_t)(ClientContext &context, LogicalGet &get, FunctionData *bind_data, vector> &filters); typedef bool (*table_function_pushdown_expression_t)(ClientContext &context, const LogicalGet &get, Expression &expr); typedef InsertionOrderPreservingMap (*table_function_to_string_t)(TableFunctionToStringInput &input); -typedef InsertionOrderPreservingMap (*table_function_dynamic_to_string_t)( - TableFunctionDynamicToStringInput &input); typedef void (*table_function_serialize_t)(Serializer &serializer, const optional_ptr bind_data, const TableFunction &function); @@ -362,15 +364,20 @@ typedef vector (*table_function_get_row_id_columns)(ClientContext &con typedef void (*table_function_set_scan_order)(unique_ptr order_options, optional_ptr bind_data); +typedef void (*table_function_set_partitions_to_scan_t)(vector partition_indices, + optional_ptr bind_data); + //! 
When to call init_global to initialize the table function enum class TableFunctionInitialization { INITIALIZE_ON_EXECUTE, INITIALIZE_ON_SCHEDULE }; +enum class TableFunctionReturnType { TABLE_RETURNING_FUNCTION, SET_RETURNING_FUNCTION }; + class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-around bug in clang-tidy public: DUCKDB_API TableFunction(); // Overloads taking table_function_t DUCKDB_API - TableFunction(string name, const vector &arguments, table_function_t function, + TableFunction(Identifier name, const vector &arguments, table_function_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); DUCKDB_API @@ -378,7 +385,7 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); // Overloads taking std::nullptr DUCKDB_API - TableFunction(string name, const vector &arguments, std::nullptr_t function, + TableFunction(Identifier name, const vector &arguments, std::nullptr_t function, table_function_bind_t bind = nullptr, table_function_init_global_t init_global = nullptr, table_function_init_local_t init_local = nullptr); DUCKDB_API @@ -447,8 +454,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou //! (Optional) cardinality function //! Returns the expected cardinality of this scan table_function_cardinality_t cardinality; - //! (Optional) returns the number of rows that have been scanned - table_function_rows_scanned_t rows_scanned; + //! (Optional) returns profiling metrics for this table scan operator + table_function_get_metrics_t get_metrics; //! (Optional) pushdown a set of arbitrary filter expressions, rather than only simple comparisons with a constant //! 
Any functions remaining in the expression list will be pushed as a regular filter after the scan table_function_pushdown_complex_filter_t pushdown_complex_filter; @@ -456,8 +463,6 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou table_function_pushdown_expression_t pushdown_expression; //! (Optional) function for rendering the operator to a string in explain/profiling output (invoked pre-execution) table_function_to_string_t to_string; - //! (Optional) function for rendering the operator to a string in profiling output (invoked post-execution) - table_function_dynamic_to_string_t dynamic_to_string; //! (Optional) return how much of the table we have scanned up to this point (% of the data) table_function_progress_t table_scan_progress; //! (Optional) returns the partition info of the current scan operator @@ -482,6 +487,8 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou table_function_get_row_id_columns get_row_id_columns; //! (Optional) sets the order to scan the row groups in table_function_set_scan_order set_scan_order; + //! (Optional) restricts the scan to a specific subset of partitions (by index in get_partition_stats order) + table_function_set_partitions_to_scan_t set_partitions_to_scan = nullptr; table_function_serialize_t serialize; table_function_deserialize_t deserialize; @@ -501,6 +508,7 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou bool sampling_pushdown; //! Whether or not the table function supports late materialization bool late_materialization; + TableFunctionReturnType return_type; //! Additional function info, passed to the bind shared_ptr function_info; //! 
The order preservation type of the table function diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp index 0e8b7d9bd..a70a7dc7b 100644 --- a/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp +++ b/src/duckdb/src/include/duckdb/function/variant/variant_normalize.hpp @@ -83,7 +83,7 @@ struct VariantNormalizer { static void VisitDefault(VariantLogicalType type_id, const_data_ptr_t, VariantNormalizerState &state); public: - static void Normalize(Vector &input, Vector &output, idx_t count); + static void Normalize(const Vector &input, Vector &output); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp index aa31ea84c..74ec66fd8 100644 --- a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp +++ b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp @@ -47,7 +47,7 @@ struct VariantShreddingStats { const VariantColumnStatsData &GetColumnStats(idx_t index) const; public: - void Update(Vector &input, idx_t count); + void Update(const Vector &input, idx_t count); LogicalType GetShreddedType() const; private: diff --git a/src/duckdb/src/include/duckdb/function/window/value_functions.hpp b/src/duckdb/src/include/duckdb/function/window/value_functions.hpp index 2999b0d8b..121266c51 100644 --- a/src/duckdb/src/include/duckdb/function/window/value_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/window/value_functions.hpp @@ -52,7 +52,7 @@ struct LeadFun { static constexpr const char *Example = "LEAD(athlete, 1, NULL) OVER (PARTITION BY event ORDER BY points) AS victorious_over"; static constexpr const char *Categories = ""; - static WindowFunctionSet GetFunctions(); + static WindowFunction GetFunction(); static WindowFunction GetTypedFunction(const LogicalType &type, idx_t nargs); }; @@ -63,7 
+63,7 @@ struct LagFun { static constexpr const char *Example = "LAG(athlete, 1, NULL) OVER (PARTITION BY event ORDER BY points) AS defeated_by"; static constexpr const char *Categories = ""; - static WindowFunctionSet GetFunctions(); + static WindowFunction GetFunction(); }; struct FillFun { diff --git a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp index dc7350b3d..cbc4583ae 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_boundaries_state.hpp @@ -89,7 +89,7 @@ class WindowBoundariesState { WindowBoundariesState(ExecutionContext &context, const WindowExecutorGlobalState &gstate); void Finalize(CollectionPtr collection); - void UpdateBounds(idx_t row_idx, DataChunk &eval_chunk); + void UpdateBounds(idx_t row_idx, DataChunk &eval_chunk, idx_t count); //! The requested bounds. DataChunk bounds; diff --git a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp index 3427d6cc5..05b81bc69 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp @@ -74,7 +74,7 @@ class WindowExecutor { virtual void Finalize(ExecutionContext &context, CollectionPtr collection, OperatorSinkInput &sink) const; void Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, - OperatorSinkInput &sink) const; + OperatorSinkInput &sink, idx_t count) const; // The function const BoundWindowExpression &wexpr; diff --git a/src/duckdb/src/include/duckdb/function/window_function.hpp b/src/duckdb/src/include/duckdb/function/window_function.hpp index 26e12102e..2ef66f079 100644 --- a/src/duckdb/src/include/duckdb/function/window_function.hpp +++ 
b/src/duckdb/src/include/duckdb/function/window_function.hpp @@ -314,7 +314,7 @@ class BaseWindowFunction { class WindowFunction : public BaseWindowFunction, public SimpleFunction { // NOLINT: work-around bug in clang-tidy public: - WindowFunction(const string &name, const vector &arguments, const LogicalType &return_type, + WindowFunction(const Identifier &name, const vector &arguments, const LogicalType &return_type, ExpressionType window_enum, window_bind_function_t bind = nullptr, window_bounds_function_t bounds = nullptr, window_sharing_function_t sharing = nullptr, window_global_function_t global = nullptr, window_local_function_t local = nullptr, @@ -336,7 +336,7 @@ class WindowFunction : public BaseWindowFunction, public SimpleFunction { // NOL window_sharing_function_t sharing = nullptr, window_global_function_t global = nullptr, window_local_function_t local = nullptr, window_sink_function_t sink = nullptr, window_finalize_function_t finalize = nullptr, window_evaluate_function_t evaluate = nullptr) - : WindowFunction(string(), arguments, return_type, window_enum, bind, bounds, sharing, global, local, sink, + : WindowFunction(Identifier(), arguments, return_type, window_enum, bind, bounds, sharing, global, local, sink, finalize, evaluate) { } diff --git a/src/duckdb/src/include/duckdb/logging/log_type.hpp b/src/duckdb/src/include/duckdb/logging/log_type.hpp index dca29940b..e97f98cfa 100644 --- a/src/duckdb/src/include/duckdb/logging/log_type.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_type.hpp @@ -20,8 +20,6 @@ class PhysicalOperator; class AttachedDatabase; class RowGroup; struct DataTableInfo; -enum class MetricType : uint8_t; - //! Log types provide some structure to the formats that the different log messages can have //! For now, this holds a type that the VARCHAR value will be auto-cast into. 
class LogType { @@ -120,7 +118,7 @@ class MetricsLogType : public LogType { static LogicalType GetLogType(); - static string ConstructLogMessage(const MetricType &type, const Value &value); + static string ConstructLogMessage(const string &metric, const Value &value); }; class CheckpointLogType : public LogType { @@ -172,4 +170,30 @@ class AdaptiveFilterLogType : public LogType { const vector> &info); }; +class ParquetPrefetchLogType : public LogType { +public: + static constexpr const char *NAME = "ParquetPrefetch"; + static constexpr LogLevel LEVEL = LogLevel::LOG_DEBUG; + + ParquetPrefetchLogType(); + + static LogicalType GetLogType(); + + static string ConstructLogMessage(const string &file_path, idx_t row_group_id, bool fully_filtered, + const char *strategy, const vector> &prefetch_groups, + const vector &minimal_filters, uint64_t accepted_column_gap); +}; + +class AsyncTaskScheduleLogType : public LogType { +public: + static constexpr const char *NAME = "AsyncTaskSchedule"; + static constexpr LogLevel LEVEL = LogLevel::LOG_DEBUG; + + AsyncTaskScheduleLogType(); + + static LogicalType GetLogType(); + + static string ConstructLogMessage(const string &pool, idx_t task_count); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/appender.hpp b/src/duckdb/src/include/duckdb/main/appender.hpp index 09c3256f6..532cb23cc 100644 --- a/src/duckdb/src/include/duckdb/main/appender.hpp +++ b/src/duckdb/src/include/duckdb/main/appender.hpp @@ -34,8 +34,8 @@ class BaseAppender { public: //! Returns a table reference to the appended data. - static unique_ptr GetColumnDataTableRef(ColumnDataCollection &collection, const string &table_name, - const vector &expected_names); + static unique_ptr GetColumnDataTableRef(ColumnDataCollection &collection, const Identifier &table_name, + const vector &expected_names); //! Parses the statement to append data. 
static unique_ptr ParseStatement(unique_ptr table_ref, const string &query, const string &table_name); @@ -107,7 +107,7 @@ class BaseAppender { virtual void AppendDefault(DataChunk &chunk, idx_t col, idx_t row); //! Appends a column to the active column list. //! Immediately flushes all previous data. - virtual void AddColumn(const string &name); + virtual void AddColumn(const Identifier &name); //! Removes all columns from the active column list. //! Immediately flushes all previous data. virtual void ClearColumns(); @@ -144,24 +144,24 @@ class BaseAppender { class Appender : public BaseAppender { public: - DUCKDB_API Appender(Connection &con, const string &database_name, const string &schema_name, - const string &table_name, const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); - DUCKDB_API Appender(Connection &con, const string &schema_name, const string &table_name, + DUCKDB_API Appender(Connection &con, const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name, const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); + DUCKDB_API Appender(Connection &con, const Identifier &schema_name, const Identifier &table_name, const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); - DUCKDB_API Appender(Connection &con, const string &table_name, + DUCKDB_API Appender(Connection &con, const Identifier &table_name, const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); DUCKDB_API ~Appender() override; public: void AppendDefault() override; void AppendDefault(DataChunk &chunk, idx_t col, idx_t row) override; - void AddColumn(const string &name) override; + void AddColumn(const Identifier &name) override; void ClearColumns() override; //! Get the expected names based on the active columns. - vector GetExpectedNames(); + vector GetExpectedNames(); //! Construct a query that appends data from, typically, a column data collection. 
- static string ConstructQuery(TableDescription &description_p, const string &table_name, - const vector &expected_names); + static string ConstructQuery(TableDescription &description_p, const Identifier &table_name, + const vector &expected_names); private: //! A shared pointer to the context of this appender. @@ -183,7 +183,7 @@ class Appender : public BaseAppender { class QueryAppender : public BaseAppender { public: DUCKDB_API QueryAppender(Connection &con, string query, vector types, - vector names = vector(), string table_name = string(), + vector names = vector(), Identifier table_name = Identifier(), const idx_t flush_memory_threshold = DConstants::INVALID_INDEX); DUCKDB_API ~QueryAppender() override; @@ -193,9 +193,9 @@ class QueryAppender : public BaseAppender { //! The query to run. string query; //! The column names of the to-be-appended data, or "col1, col2, ...", if empty. - vector names; + vector names; //! The table name that we can reference in the query, or "appended_data", if empty. - string table_name; + Identifier table_name; protected: void FlushInternal(ColumnDataCollection &collection) override; diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp index 0d46cf88b..66d56c05e 100644 --- a/src/duckdb/src/include/duckdb/main/attached_database.hpp +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -45,7 +45,8 @@ enum class DatabaseCloseAction { CHECKPOINT, TRY_CHECKPOINT, SKIP_CHECKPOINT }; class DatabaseFilePathManager; struct StoredDatabasePath { - StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path, const string &name); + StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path, + const Identifier &name); ~StoredDatabasePath(); DatabaseManager &db_manager; @@ -79,6 +80,8 @@ struct AttachOptions { AttachVisibility visibility = AttachVisibility::SHOWN; //! 
The stored database path (in the path manager) unique_ptr stored_database_path; + //! Per-database override of vacuum_rebuild_indexes. If not set, the global setting value is used. + optional_idx vacuum_rebuild_indexes_threshold; }; //! The AttachedDatabase represents an attached database instance. @@ -87,10 +90,10 @@ class AttachedDatabase : public CatalogEntry, public enable_shared_from_this &GetAttachOptions() const { return attach_options; } string StoredPath() const; - static bool NameIsReserved(const string &name); - static string ExtractDatabaseName(const string &dbpath, FileSystem &fs); + static bool NameIsReserved(const Identifier &name); + static Identifier ExtractDatabaseName(const string &dbpath, FileSystem &fs); // Invoke Close() on an attached database, if its use count is 1. // Only call this in places where you know that the (last) shared pointer is about to go out of scope. static void InvokeCloseIfLastReference(shared_ptr &attached_database, ClientContext &context); @@ -157,6 +163,7 @@ class AttachedDatabase : public CatalogEntry, public enable_shared_from_this close_lock; + optional_idx vacuum_rebuild_threshold; unordered_map attach_options; private: diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp index 4c15ef88b..2434b9c41 100644 --- a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp @@ -49,7 +49,7 @@ struct CClientArrowOptionsWrapper { struct PreparedStatementWrapper { //! 
Map of name -> values - case_insensitive_map_t values; + identifier_map_t values; unique_ptr statement; bool success = true; ErrorData error_data; @@ -74,6 +74,7 @@ struct ArrowResultWrapper { struct AppenderWrapper { unique_ptr appender; ErrorData error_data; + bool flush_failed = false; }; struct TableDescriptionWrapper { diff --git a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp index 45e2db719..72210f42e 100644 --- a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp @@ -15,9 +15,9 @@ typedef struct { void (*duckdb_interrupt)(duckdb_connection connection); duckdb_query_progress_type (*duckdb_query_progress)(duckdb_connection connection); void (*duckdb_disconnect)(duckdb_connection *connection); - const char *(*duckdb_library_version)(); + const char *(*duckdb_library_version)(void); duckdb_state (*duckdb_create_config)(duckdb_config *out_config); - size_t (*duckdb_config_count)(); + size_t (*duckdb_config_count)(void); duckdb_state (*duckdb_get_config_flag)(size_t index, const char **out_name, const char **out_description); duckdb_state (*duckdb_set_config)(duckdb_config config, const char *name, const char *option); void (*duckdb_destroy_config)(duckdb_config *config); @@ -34,7 +34,7 @@ typedef struct { duckdb_result_type (*duckdb_result_return_type)(duckdb_result result); void *(*duckdb_malloc)(size_t size); void (*duckdb_free)(void *ptr); - idx_t (*duckdb_vector_size)(); + idx_t (*duckdb_vector_size)(void); bool (*duckdb_string_is_inlined)(duckdb_string_t string); uint32_t (*duckdb_string_t_length)(duckdb_string_t string); const char *(*duckdb_string_t_data)(duckdb_string_t *string); @@ -172,7 +172,7 @@ typedef struct { duckdb_value (*duckdb_get_map_key)(duckdb_value value, idx_t index); duckdb_value (*duckdb_get_map_value)(duckdb_value value, idx_t index); bool (*duckdb_is_null_value)(duckdb_value value); - 
duckdb_value (*duckdb_create_null_value)(); + duckdb_value (*duckdb_create_null_value)(void); idx_t (*duckdb_get_list_size)(duckdb_value value); duckdb_value (*duckdb_get_list_child)(duckdb_value value, idx_t index); duckdb_value (*duckdb_create_enum_value)(duckdb_logical_type type, uint64_t value); @@ -234,7 +234,7 @@ typedef struct { void (*duckdb_validity_set_row_validity)(uint64_t *validity, idx_t row, bool valid); void (*duckdb_validity_set_row_invalid)(uint64_t *validity, idx_t row); void (*duckdb_validity_set_row_valid)(uint64_t *validity, idx_t row); - duckdb_scalar_function (*duckdb_create_scalar_function)(); + duckdb_scalar_function (*duckdb_create_scalar_function)(void); void (*duckdb_destroy_scalar_function)(duckdb_scalar_function *scalar_function); void (*duckdb_scalar_function_set_name)(duckdb_scalar_function scalar_function, const char *name); void (*duckdb_scalar_function_set_varargs)(duckdb_scalar_function scalar_function, duckdb_logical_type type); @@ -253,7 +253,7 @@ typedef struct { void (*duckdb_destroy_scalar_function_set)(duckdb_scalar_function_set *scalar_function_set); duckdb_state (*duckdb_add_scalar_function_to_set)(duckdb_scalar_function_set set, duckdb_scalar_function function); duckdb_state (*duckdb_register_scalar_function_set)(duckdb_connection con, duckdb_scalar_function_set set); - duckdb_aggregate_function (*duckdb_create_aggregate_function)(); + duckdb_aggregate_function (*duckdb_create_aggregate_function)(void); void (*duckdb_destroy_aggregate_function)(duckdb_aggregate_function *aggregate_function); void (*duckdb_aggregate_function_set_name)(duckdb_aggregate_function aggregate_function, const char *name); void (*duckdb_aggregate_function_add_parameter)(duckdb_aggregate_function aggregate_function, @@ -280,7 +280,7 @@ typedef struct { duckdb_state (*duckdb_add_aggregate_function_to_set)(duckdb_aggregate_function_set set, duckdb_aggregate_function function); duckdb_state (*duckdb_register_aggregate_function_set)(duckdb_connection 
con, duckdb_aggregate_function_set set); - duckdb_table_function (*duckdb_create_table_function)(); + duckdb_table_function (*duckdb_create_table_function)(void); void (*duckdb_destroy_table_function)(duckdb_table_function *table_function); void (*duckdb_table_function_set_name)(duckdb_table_function table_function, const char *name); void (*duckdb_table_function_add_parameter)(duckdb_table_function table_function, duckdb_logical_type type); @@ -354,7 +354,7 @@ typedef struct { void (*duckdb_destroy_task_state)(duckdb_task_state state); bool (*duckdb_execution_is_finished)(duckdb_connection con); duckdb_data_chunk (*duckdb_fetch_chunk)(duckdb_result result); - duckdb_cast_function (*duckdb_create_cast_function)(); + duckdb_cast_function (*duckdb_create_cast_function)(void); void (*duckdb_cast_function_set_source_type)(duckdb_cast_function cast_function, duckdb_logical_type source_type); void (*duckdb_cast_function_set_target_type)(duckdb_cast_function cast_function, duckdb_logical_type target_type); void (*duckdb_cast_function_set_implicit_cast_cost)(duckdb_cast_function cast_function, int64_t cost); @@ -462,7 +462,7 @@ typedef struct { duckdb_data_chunk (*duckdb_stream_fetch_chunk)(duckdb_result result); // Exposing the instance cache - duckdb_instance_cache (*duckdb_create_instance_cache)(); + duckdb_instance_cache (*duckdb_create_instance_cache)(void); duckdb_state (*duckdb_get_or_create_from_cache)(duckdb_instance_cache instance_cache, const char *path, duckdb_database *out_database, duckdb_config config, char **out_error); @@ -501,7 +501,7 @@ typedef struct { void (*duckdb_destroy_catalog_entry)(duckdb_catalog_entry *entry); // New configuration options functions - duckdb_config_option (*duckdb_create_config_option)(); + duckdb_config_option (*duckdb_create_config_option)(void); void (*duckdb_destroy_config_option)(duckdb_config_option *option); void (*duckdb_config_option_set_name)(duckdb_config_option option, const char *name); void 
(*duckdb_config_option_set_type)(duckdb_config_option option, duckdb_logical_type type); @@ -514,7 +514,7 @@ typedef struct { duckdb_config_option_scope *out_scope); // API to define custom copy functions - duckdb_copy_function (*duckdb_create_copy_function)(); + duckdb_copy_function (*duckdb_create_copy_function)(void); void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, duckdb_delete_callback_t destructor); @@ -579,7 +579,7 @@ typedef struct { duckdb_state (*duckdb_file_system_open)(duckdb_file_system file_system, const char *path, duckdb_file_open_options options, duckdb_file_handle *out_file); duckdb_error_data (*duckdb_file_system_error_data)(duckdb_file_system file_system); - duckdb_file_open_options (*duckdb_create_file_open_options)(); + duckdb_file_open_options (*duckdb_create_file_open_options)(void); duckdb_state (*duckdb_file_open_options_set_flag)(duckdb_file_open_options options, duckdb_file_flag flag, bool value); void (*duckdb_destroy_file_open_options)(duckdb_file_open_options *options); @@ -597,7 +597,7 @@ typedef struct { char *(*duckdb_geometry_type_get_crs)(duckdb_logical_type type); // API to register a custom log storage. 
- duckdb_log_storage (*duckdb_create_log_storage)(); + duckdb_log_storage (*duckdb_create_log_storage)(void); void (*duckdb_destroy_log_storage)(duckdb_log_storage *log_storage); void (*duckdb_log_storage_set_write_log_entry)(duckdb_log_storage log_storage, duckdb_logger_write_log_entry_t function); @@ -684,7 +684,7 @@ typedef struct { //===--------------------------------------------------------------------===// // Struct Create Method //===--------------------------------------------------------------------===// -inline duckdb_ext_api_v1 CreateAPIv1() { +inline duckdb_ext_api_v1 CreateAPIv1(void) { duckdb_ext_api_v1 result; result.duckdb_open = duckdb_open; result.duckdb_open_ext = duckdb_open_ext; diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp index ffe0ac720..6f5e0102d 100644 --- a/src/duckdb/src/include/duckdb/main/client_config.hpp +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -15,7 +15,6 @@ #include "duckdb/common/enums/profiler_format.hpp" #include "duckdb/common/progress_bar/progress_bar.hpp" #include "duckdb/common/types/value.hpp" -#include "duckdb/main/profiling_info.hpp" #include "duckdb/parser/expression/lambda_expression.hpp" #include "duckdb/main/query_profiler.hpp" #include "duckdb/main/user_settings.hpp" @@ -39,11 +38,9 @@ struct ClientConfig { //! The file to save query profiling information to, instead of printing it to the console //! (empty = print to console) string profiler_save_location; - //! The custom settings for the profiler - //! (empty = use the default settings) - profiler_settings_t profiler_settings = MetricsUtils::GetDefaultMetrics(); - //! The input format type of the profiler settings - LogicalTypeId profiler_settings_type = LogicalTypeId::VARCHAR; + //! Glob patterns for tracked_metrics (controls which metrics are gathered and displayed). + //! Default "*" means all metrics are tracked. + vector tracked_metrics = {"*"}; //! 
Allows suppressing profiler output, even if enabled. We turn on the profiler on all test runs but don't want //! to output anything @@ -58,8 +55,6 @@ struct ClientConfig { //! The wait time before showing the progress bar int wait_time = 2000; - //! Enable the running of optimizers - bool enable_optimizer = true; //! Force parallelism of small tables, used for testing bool verify_parallelism = false; //! If this context should also try to use the available replacement scans @@ -89,13 +84,6 @@ struct ClientConfig { //! Function that is used to create the result collector for a materialized result. get_result_collector_t get_result_collector = nullptr; - //! If HTTP logging is enabled or not. - bool enable_http_logging = true; - - //! **DEPRECATED** The file to save query HTTP logging information to, instead of printing it to the console - //! (empty = output to the DuckDB logger) - string http_logging_output; - public: static ClientConfig &GetConfig(ClientContext &context); static const ClientConfig &GetConfig(const ClientContext &context); diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp index 7454dd82c..a70e33da9 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -29,11 +29,13 @@ #include "duckdb/main/table_description.hpp" #include "duckdb/planner/expression/bound_parameter_data.hpp" #include "duckdb/transaction/transaction_context.hpp" +#include "duckdb/main/query_context.hpp" #include "duckdb/main/query_parameters.hpp" namespace duckdb { class Appender; +class AttachedDatabase; class Catalog; class CatalogSearchPath; class ColumnDataCollection; @@ -57,7 +59,7 @@ class RegisteredStateManager; struct PendingQueryParameters { //! Prepared statement parameters (if any) - optional_ptr> parameters; + optional_ptr> parameters; //! 
Whether a stream/buffer-managed result should be allowed QueryParameters query_parameters; }; @@ -97,6 +99,20 @@ class ClientContext : public enable_shared_from_this { TransactionContext transaction; public: + //! Connect this client to a remote-style AttachedDatabase. Subsequent non-control SQL routes via + //! Catalog::GetConnectFunctionName. Use DisconnectFromCatalog() to revert to LOCAL. + DUCKDB_API void ConnectToCatalog(const shared_ptr &target); + //! Clear any active CONNECT; subsequent SQL goes through the normal DuckDB pipeline. + DUCKDB_API void DisconnectFromCatalog(); + //! True iff a CONNECT is currently active (even if the target was detached out from under us). + DUCKDB_API bool IsConnected() const { + return is_connected; + } + //! Resolve the currently-connected AttachedDatabase. Returns nullptr if not connected or if the + //! target has been detached out from under us (in that case IsConnected() is still true — call it + //! directly to disambiguate "never connected" from "was connected, target was detached elsewhere"). + DUCKDB_API shared_ptr TryGetConnectedCatalog() const; + MetaTransaction &ActiveTransaction() { return transaction.ActiveTransaction(); } @@ -132,21 +148,20 @@ class ClientContext : public enable_shared_from_this { //! Create a pending query with a list of parameters DUCKDB_API unique_ptr PendingQuery(unique_ptr statement, - case_insensitive_map_t &values, - QueryParameters query_parameters); - DUCKDB_API unique_ptr PendingQuery(const string &query, - case_insensitive_map_t &values, + identifier_map_t &values, QueryParameters query_parameters); + DUCKDB_API unique_ptr + PendingQuery(const string &query, identifier_map_t &values, QueryParameters query_parameters); DUCKDB_API unique_ptr PendingQuery(const string &query, PendingQueryParameters parameters); //! Destroy the client context DUCKDB_API void Destroy(); //! Get the table info of a specific table, or nullptr if it cannot be found. 
- DUCKDB_API unique_ptr TableInfo(const string &database_name, const string &schema_name, - const string &table_name); + DUCKDB_API unique_ptr TableInfo(const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name); //! Get the table info of a specific table, or nullptr if it cannot be found. Uses INVALID_CATALOG. - DUCKDB_API unique_ptr TableInfo(const string &schema_name, const string &table_name); + DUCKDB_API unique_ptr TableInfo(const Identifier &schema_name, const Identifier &table_name); //! Executes a query with the given collection "attached" to the query using a CTE. DUCKDB_API void Append(unique_ptr stmt); //! Appends a ColumnDataCollection to the described table. @@ -181,7 +196,7 @@ class ClientContext : public enable_shared_from_this { //! modified in between the prepared statement being bound and the prepared statement being run. DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, - case_insensitive_map_t &values, + identifier_map_t &values, QueryParameters query_parameters = QueryResultOutputType::ALLOW_STREAMING); DUCKDB_API unique_ptr Execute(const string &query, shared_ptr &prepared, const PendingQueryParameters ¶meters); @@ -331,6 +346,11 @@ class ClientContext : public enable_shared_from_this { QueryProgress query_progress; //! The connection corresponding to this client context connection_t connection_id; + //! Routing target for SQL execution while CONNECT-ed (CONNECT/DISCONNECT). When is_connected is + //! true and connected_to_database can be locked, the chokepoint dispatches non-control SQL via + //! `Catalog::RemoteExecute(string)` and wraps the returned TableRef into a SelectStatement. + weak_ptr connected_to_database; + bool is_connected = false; }; class ClientContextLock { @@ -345,27 +365,4 @@ class ClientContextLock { lock_guard client_guard; }; -//! The QueryContext wraps an optional client context. -//! It makes query-related information available to operations. 
-class QueryContext { -public: - QueryContext() : context(nullptr) { - } - QueryContext(optional_ptr context) : context(context) { // NOLINT: allow implicit construction - } - QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction - } - -public: - bool Valid() const { - return context != nullptr; - } - optional_ptr GetClientContext() const { - return context; - } - -private: - optional_ptr context; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/client_context_state.hpp b/src/duckdb/src/include/duckdb/main/client_context_state.hpp index fcedae943..a5ce66e29 100644 --- a/src/duckdb/src/include/duckdb/main/client_context_state.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context_state.hpp @@ -42,7 +42,7 @@ struct PreparedStatementCallbackInfo { struct BindPreparedStatementCallbackInfo { PreparedStatementData &prepared_statement; - optional_ptr> parameters; + optional_ptr> parameters; }; //! ClientContextState is virtual base class for ClientContext-local (or Query-Local, using QueryEnd callback) state diff --git a/src/duckdb/src/include/duckdb/main/client_data.hpp b/src/duckdb/src/include/duckdb/main/client_data.hpp index a2068b68e..44b4b71d0 100644 --- a/src/duckdb/src/include/duckdb/main/client_data.hpp +++ b/src/duckdb/src/include/duckdb/main/client_data.hpp @@ -36,7 +36,7 @@ struct ClientData { //! The set of temporary objects belonging to this client. shared_ptr temporary_objects; //! The set of bound prepared statements belonging to this client. - case_insensitive_map_t> prepared_statements; + identifier_map_t> prepared_statements; //! The random generator used by random(). //! Its seed value can be set by setseed(). @@ -52,11 +52,6 @@ struct ClientData { //! The clients' buffer manager wrapper. unique_ptr client_buffer_manager; - //! The Max Line Length Size of Last Query Executed on a CSV File. (Only used for testing) - //! 
FIXME: this should not be done like this - bool debug_set_max_line_length = false; - idx_t debug_max_line_length = 0; - //! The writer used to log queries, if logging is enabled. //! DEPRECATED: Now, queries are written to the generic log. unique_ptr log_query_writer; diff --git a/src/duckdb/src/include/duckdb/main/config.hpp b/src/duckdb/src/include/duckdb/main/config.hpp index fa38063e4..def558ad9 100644 --- a/src/duckdb/src/include/duckdb/main/config.hpp +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -36,7 +36,7 @@ #include "duckdb/main/user_settings.hpp" #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/common/types/type_manager.hpp" -#include "duckdb/common/serialization_compatibility.hpp" +#include "duckdb/common/storage_compatibility.hpp" #include "duckdb/common/enums/debug_verification_mode.hpp" namespace duckdb { @@ -78,8 +78,6 @@ struct DBConfigOptions { AccessMode access_mode = AccessMode::AUTOMATIC; //! Checkpoint when WAL reaches this size (default: 16MiB) idx_t checkpoint_wal_size = 1 << 24; - //! Whether or not to use Direct IO, bypassing operating system buffers - bool use_direct_io = false; //! Whether extensions should be loaded on start-up bool load_extensions = true; //! The maximum memory used by the database system (in bytes). Default: 80% of System available memory @@ -88,6 +86,8 @@ struct DBConfigOptions { idx_t maximum_swap_space = DConstants::INVALID_INDEX; //! The maximum amount of CPU threads used by the database system. Default: all available. idx_t maximum_threads = DConstants::INVALID_INDEX; + //! The maximum amount of async threads used by the database system. Default: all available. + idx_t async_threads = DConstants::INVALID_INDEX; //! Whether or not to create and use a temporary directory to store intermediates that do not fit in memory bool use_temporary_directory = true; //! Directory to store temporary structures that do not fit in memory @@ -102,7 +102,7 @@ struct DBConfigOptions { //! 
Run a checkpoint on successful shutdown and delete the WAL, to leave only a single database file behind bool checkpoint_on_shutdown = true; //! Serialize the metadata on checkpoint with compatibility for a given DuckDB version. - SerializationCompatibility serialization_compatibility = SerializationCompatibility::Default(); + StorageCompatibility storage_compatibility = StorageCompatibility::Default(); //! Initialize the database with the standard set of DuckDB functions //! You should probably not touch this unless you know what you are doing bool initialize_default_database = true; @@ -122,14 +122,14 @@ struct DBConfigOptions { case_insensitive_map_t user_options; //! The set of unrecognized (other) options case_insensitive_map_t unrecognized_options; - //! The peak allocation threshold at which to flush the allocator after completing a task (1 << 27, ~128MB) - idx_t allocator_flush_threshold = 134217728ULL; //! If bulk deallocation larger than this occurs, flush outstanding allocations (1 << 30, ~1GB) idx_t allocator_bulk_deallocation_flush_threshold = 536870912ULL; //! Delta Only! - Fall back to recognizing Variant columns structurally bool variant_legacy_encoding = false; //! Metadata from DuckDB callers string custom_user_agent; + //! HTTP proxy host (defaults to the HTTP_PROXY environment variable when unset) + string http_proxy; //! The default block header size for new duckdb database files. idx_t default_block_header_size = DEFAULT_BLOCK_HEADER_STORAGE_SIZE; //! Whether or not to abort if a serialization exception is thrown during WAL playback (when reading truncated WAL) @@ -139,11 +139,13 @@ struct DBConfigOptions { //! Directories that are explicitly allowed, even if enable_external_access is false set allowed_directories; //! Additional configuration options that are allowed to be changed even when the configuration is locked - case_insensitive_set_t allowed_configs; + identifier_set_t allowed_configs; //! 
The log configuration LogConfig log_config = LogConfig(); //! Physical memory that the block allocator is allowed to use (this memory is never freed and cannot be reduced) idx_t block_allocator_size = 0; + //! Memory limit for the write buffer per row group (optional) + optional_idx write_buffer_row_group_memory_limit; //! Whether to print bindings when printing the plan (debug mode only) static bool debug_print_bindings; // NOLINT: debug setting //! The global verification mode @@ -269,6 +271,7 @@ struct DBConfig { DUCKDB_API CollationBinding &GetCollationBinding(); DUCKDB_API IndexTypeSet &GetIndexTypes(); static idx_t GetSystemMaxThreads(FileSystem &fs); + static idx_t GetSystemMaxAsyncThreads(FileSystem &fs); static idx_t GetSystemAvailableMemory(FileSystem &fs); static optional_idx ParseMemoryLimitSlurm(const string &arg); void SetDefaultMaxMemory(); diff --git a/src/duckdb/src/include/duckdb/main/connection.hpp b/src/duckdb/src/include/duckdb/main/connection.hpp index 91e55441d..32b3f6b92 100644 --- a/src/duckdb/src/include/duckdb/main/connection.hpp +++ b/src/duckdb/src/include/duckdb/main/connection.hpp @@ -19,7 +19,6 @@ #include "duckdb/main/stream_query_result.hpp" #include "duckdb/main/table_description.hpp" #include "duckdb/parser/sql_statement.hpp" -#include "duckdb/main/profiling_node.hpp" namespace duckdb { @@ -54,9 +53,6 @@ class Connection { //! Returns query profiling information for the current query DUCKDB_API string GetProfilingInformation(ProfilerPrintFormat format = ProfilerPrintFormat::QUERY_TREE); - //! Returns the first node of the query profiling tree - DUCKDB_API optional_ptr GetProfilingTree(); - //! 
Interrupt execution of the current query DUCKDB_API void Interrupt(); @@ -109,10 +105,10 @@ class Connection { PendingQuery(unique_ptr statement, QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); DUCKDB_API unique_ptr - PendingQuery(unique_ptr statement, case_insensitive_map_t &named_values, + PendingQuery(unique_ptr statement, identifier_map_t &named_values, QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); DUCKDB_API unique_ptr - PendingQuery(const string &query, case_insensitive_map_t &named_values, + PendingQuery(const string &query, identifier_map_t &named_values, QueryParameters query_parameters = QueryResultOutputType::FORCE_MATERIALIZED); DUCKDB_API unique_ptr PendingQuery(const string &query, vector &values, @@ -128,13 +124,13 @@ class Connection { DUCKDB_API unique_ptr Prepare(unique_ptr statement); //! Get the table info of a specific table, or nullptr if it cannot be found. - DUCKDB_API unique_ptr TableInfo(const string &database_name, const string &schema_name, - const string &table_name); + DUCKDB_API unique_ptr TableInfo(const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name); //! Get the table info of a specific table, or nullptr if it cannot be found. Uses INVALID_CATALOG. - DUCKDB_API unique_ptr TableInfo(const string &schema_name, const string &table_name); + DUCKDB_API unique_ptr TableInfo(const Identifier &schema_name, const Identifier &table_name); //! Get the table info of a specific table, or nullptr if it cannot be found. Uses INVALID_CATALOG and //! DEFAULT_SCHEMA. - DUCKDB_API unique_ptr TableInfo(const string &table_name); + DUCKDB_API unique_ptr TableInfo(const Identifier &table_name); //! Extract a set of SQL statements from a specific query DUCKDB_API vector> ExtractStatements(const string &query); @@ -145,13 +141,13 @@ class Connection { DUCKDB_API void Append(TableDescription &description, ColumnDataCollection &collection); //! 
Returns a relation that produces a table from this connection - DUCKDB_API shared_ptr Table(const string &tname); - DUCKDB_API shared_ptr Table(const string &schema_name, const string &table_name); - DUCKDB_API shared_ptr Table(const string &catalog_name, const string &schema_name, - const string &table_name); + DUCKDB_API shared_ptr Table(const Identifier &tname); + DUCKDB_API shared_ptr Table(const Identifier &schema_name, const Identifier &table_name); + DUCKDB_API shared_ptr Table(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name); //! Returns a relation that produces a view from this connection - DUCKDB_API shared_ptr View(const string &tname); - DUCKDB_API shared_ptr View(const string &schema_name, const string &table_name); + DUCKDB_API shared_ptr View(const Identifier &tname); + DUCKDB_API shared_ptr View(const Identifier &schema_name, const Identifier &table_name); //! Returns a relation that calls a specified table function DUCKDB_API shared_ptr TableFunction(const string &tname); DUCKDB_API shared_ptr TableFunction(const string &tname, const vector &values, diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp index 2fc67d0da..55b0eb5f0 100644 --- a/src/duckdb/src/include/duckdb/main/database.hpp +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -17,6 +17,7 @@ #include "duckdb/main/extension_manager.hpp" namespace duckdb { +class LocalDatabaseFileSystem; class BufferManager; class DatabaseManager; @@ -33,6 +34,7 @@ struct AttachOptions; class DatabaseFileSystem; struct DatabaseCacheEntry; class LogManager; +class MetricsManager; class ExternalFileCache; class ResultSetManager; struct ParserCache; @@ -53,6 +55,7 @@ class DatabaseInstance : public enable_shared_from_this { DUCKDB_API const BufferManager &GetBufferManager() const; DUCKDB_API DatabaseManager &GetDatabaseManager(); DUCKDB_API FileSystem &GetFileSystem(); + DUCKDB_API FileSystem 
&GetLocalFileSystem(); DUCKDB_API ExternalFileCache &GetExternalFileCache(); DUCKDB_API ResultSetManager &GetResultSetManager(); DUCKDB_API TaskScheduler &GetScheduler(); @@ -61,6 +64,7 @@ class DatabaseInstance : public enable_shared_from_this { DUCKDB_API ExtensionManager &GetExtensionManager(); DUCKDB_API ValidChecker &GetValidChecker(); DUCKDB_API LogManager &GetLogManager() const; + DUCKDB_API MetricsManager &GetMetricsManager(); DUCKDB_API ParserCache &GetParserCache(); DUCKDB_API const duckdb_ext_api_v1 GetExtensionAPIV1(); @@ -96,7 +100,9 @@ class DatabaseInstance : public enable_shared_from_this { unique_ptr extension_manager; ValidChecker db_validity; unique_ptr db_file_system; + unique_ptr local_db_file_system; unique_ptr log_manager; + unique_ptr metrics_manager; unique_ptr external_file_cache; unique_ptr result_set_manager; unique_ptr parser_cache; @@ -123,7 +129,7 @@ class DuckDB { void LoadStaticExtension() { T extension; auto &manager = ExtensionManager::Get(*instance); - auto load_info = manager.BeginLoad(extension.Name()); + auto load_info = manager.BeginLoad({extension.Name()}); if (!load_info) { // already loaded - return return; diff --git a/src/duckdb/src/include/duckdb/main/database_file_opener.hpp b/src/duckdb/src/include/duckdb/main/database_file_opener.hpp index 2e6664b6c..2129afa3c 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_opener.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_opener.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/file_opener.hpp" +#include "duckdb/common/local_file_system.hpp" #include "duckdb/common/opener_file_system.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/database.hpp" @@ -68,4 +69,20 @@ class DatabaseFileSystem : public OpenerFileSystem { mutable DatabaseFileOpener database_opener; }; +class LocalDatabaseFileSystem : public OpenerFileSystem { +public: + explicit LocalDatabaseFileSystem(DatabaseInstance &db_p); + + FileSystem &GetFileSystem() const override; 
+ optional_ptr GetOpener() const override { + return &database_opener; + } + +private: + DatabaseInstance &db; + unique_ptr owned_file_system; + FileSystem &local_fs; + mutable DatabaseFileOpener database_opener; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp index ad637cad5..8d021eeb5 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp @@ -23,7 +23,7 @@ class DatabaseManager; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode); + DatabasePathInfo(DatabaseManager &manager, const Identifier &name_p, AccessMode access_mode); string name; AccessMode access_mode; @@ -35,7 +35,7 @@ struct DatabasePathInfo { class DatabaseFilePathManager { public: idx_t ApproxDatabaseCount() const; - InsertDatabasePathResult InsertDatabasePath(DatabaseManager &manager, const string &path, const string &name, + InsertDatabasePathResult InsertDatabasePath(DatabaseManager &manager, const string &path, const Identifier &name, OnCreateConflict on_conflict, AttachOptions &options); //! 
Erase a database path - indicating we are done with using it void EraseDatabasePath(const string &path); diff --git a/src/duckdb/src/include/duckdb/main/database_manager.hpp b/src/duckdb/src/include/duckdb/main/database_manager.hpp index 4e436c0d8..5b0a6167a 100644 --- a/src/duckdb/src/include/duckdb/main/database_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_manager.hpp @@ -14,6 +14,7 @@ #include "duckdb/main/config.hpp" #include "duckdb/parser/parsed_data/attach_info.hpp" #include "duckdb/main/database_file_path_manager.hpp" +#include "duckdb/common/checked_integer.hpp" namespace duckdb { class AttachedDatabase; @@ -47,21 +48,21 @@ class DatabaseManager { //! Finalize starting up the system void FinalizeStartup(); //! Get an attached database by its name - optional_ptr GetDatabase(ClientContext &context, const string &name); - shared_ptr GetDatabase(const string &name); + optional_ptr GetDatabase(ClientContext &context, const Identifier &name); + shared_ptr GetDatabase(const Identifier &name); //! Attach a new database shared_ptr AttachDatabase(ClientContext &context, AttachInfo &info, AttachOptions &options); //! Detach an existing database - void DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found); + void DetachDatabase(ClientContext &context, const Identifier &name, OnEntryNotFound if_not_found); //! Alter operation dispatcher void Alter(ClientContext &context, AlterInfo &info); //! Rollback the attach of a database - shared_ptr DetachInternal(const string &name); + shared_ptr DetachInternal(const Identifier &name); //! Returns a reference to the system catalog Catalog &GetSystemCatalog(); - static const string &GetDefaultDatabase(ClientContext &context); + static Identifier GetDefaultDatabase(ClientContext &context); void SetDefaultDatabase(ClientContext &context, const string &new_value); //! 
Inserts a path to name mapping to the database paths map @@ -79,6 +80,10 @@ class DatabaseManager { vector> GetDatabases(); //! Returns the approximate count of attached databases. idx_t ApproxDatabaseCount(); + //! Returns the number of remote catalogs currently attached. + idx_t GetRemoteCatalogCount() const { + return remote_catalog_count.load(); + } //! Removes all databases from the catalog set. This is necessary for the database instance's destructor, //! as the database manager has to be alive when destroying the catalog set objects. void ResetDatabases(); @@ -104,7 +109,7 @@ class DatabaseManager { //! Gets a list of all attached database paths vector GetAttachedDatabasePaths(); - shared_ptr GetDatabaseInternal(const lock_guard &, const string &name); + shared_ptr GetDatabaseInternal(const lock_guard &, const Identifier &name); private: optional_ptr FinalizeAttach(ClientContext &context, AttachInfo &info, @@ -117,21 +122,23 @@ class DatabaseManager { //! Lock for databases mutex databases_lock; //! The set of attached databases - case_insensitive_map_t> databases; + identifier_map_t> databases; //! The next object id handed out by the NextOid method atomic next_oid; //! The current query number atomic current_query_number; //! The current transaction number atomic current_transaction_id; + //! Count of remote catalogs currently attached; used to skip the remote pushdown optimizer when zero + atomic> remote_catalog_count; //! The current default database - string default_database; + Identifier default_database; //! Manager for ensuring we never open the same database file twice in the same program shared_ptr path_manager; private: //! 
Rename an existing database - void RenameDatabase(ClientContext &context, const string &old_name, const string &new_name, + void RenameDatabase(ClientContext &context, const Identifier &old_name, const Identifier &new_name, OnEntryNotFound if_not_found); }; diff --git a/src/duckdb/src/include/duckdb/main/error_manager.hpp b/src/duckdb/src/include/duckdb/main/error_manager.hpp index 065f6399a..2f1542e0a 100644 --- a/src/duckdb/src/include/duckdb/main/error_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/error_manager.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/map.hpp" +#include "duckdb/common/types/string.hpp" namespace duckdb { class String; diff --git a/src/duckdb/src/include/duckdb/main/extension/extension_loader.hpp b/src/duckdb/src/include/duckdb/main/extension/extension_loader.hpp index f11e204ff..4c21b3c0a 100644 --- a/src/duckdb/src/include/duckdb/main/extension/extension_loader.hpp +++ b/src/duckdb/src/include/duckdb/main/extension/extension_loader.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/constants.hpp" #include "duckdb/function/cast/cast_function_set.hpp" #include "duckdb/function/function_set.hpp" +#include "duckdb/main/metric_info.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/main/extension_install_info.hpp" @@ -26,20 +27,42 @@ struct CreateScalarFunctionInfo; struct CreateTableFunctionInfo; struct CreateWindowFunctionInfo; +struct ExtensionLoaderInfo { + Identifier extension_name; + Identifier extension_alias; + string extension_description; + Identifier extension_schema = Identifier::DefaultSchema(); +}; + class ExtensionLoader { friend class DuckDB; friend class ExtensionHelper; public: - explicit ExtensionLoader(ExtensionActiveLoad &load_info); + explicit ExtensionLoader(const ExtensionActiveLoad &load_info); ExtensionLoader(DatabaseInstance &db, const string &extension_name); //! 
Returns the DatabaseInstance associated with this extension loader - DUCKDB_API DatabaseInstance &GetDatabaseInstance(); + DUCKDB_API DatabaseInstance &GetDatabaseInstance() const; public: //! Set the description of the extension DUCKDB_API void SetDescription(const string &description); + //! Explicitly sets, creates and registers all functions in this dedicated extension schema + DUCKDB_API void UseDedicatedSchemaForExtension(const Identifier &extension_schema_name); + //! Explicitly sets, creates and registers all functions in the registered extension schema + DUCKDB_API void UseDedicatedSchemaForExtension(); + //! Creates a schema in the catalog with the extension name + DUCKDB_API void CreateSchema(const Identifier &extension_schema_name) const; + //! Adds the created extension schema to the search path + DUCKDB_API void AddSchemaToSearchPath(const Identifier &schema_name) const; + //! Sets the default extension schema for this extension + DUCKDB_API void UseDefaultSchema(const Identifier &name = DEFAULT_SCHEMA); + DUCKDB_API static void RefreshSearchPath(ClientContext &context); + //! Gets registered extension name (or alias) + DUCKDB_API const Identifier &GetRegisteredExtensionName() const { + return loader_info.extension_alias.empty() ? loader_info.extension_name : loader_info.extension_alias; + } public: //! Register a new scalar function - merge overloads if the function already exists @@ -83,10 +106,10 @@ class ExtensionLoader { DUCKDB_API void RegisterCoordinateSystem(CreateCoordinateSystemInfo &info); //! 
Returns a reference to the function in the catalog - throws an exception if it does not exist - DUCKDB_API ScalarFunctionCatalogEntry &GetFunction(const string &name); - DUCKDB_API TableFunctionCatalogEntry &GetTableFunction(const string &name); - DUCKDB_API optional_ptr TryGetFunction(const string &name); - DUCKDB_API optional_ptr TryGetTableFunction(const string &name); + DUCKDB_API ScalarFunctionCatalogEntry &GetFunction(const Identifier &name); + DUCKDB_API TableFunctionCatalogEntry &GetTableFunction(const Identifier &name); + DUCKDB_API optional_ptr TryGetFunction(const Identifier &name); + DUCKDB_API optional_ptr TryGetTableFunction(const Identifier &name); //! Add a function overload DUCKDB_API void AddFunctionOverload(ScalarFunction function); @@ -106,13 +129,15 @@ class ExtensionLoader { DUCKDB_API void RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, int64_t implicit_cast_cost = -1); + //! Registers a custom metric so it appears in duckdb_available_metrics. 
+ DUCKDB_API void RegisterMetric(MetricInfo info); + private: void FinalizeLoad(); private: DatabaseInstance &db; - string extension_name; - string extension_description; + ExtensionLoaderInfo loader_info; optional_ptr extension_info; }; diff --git a/src/duckdb/src/include/duckdb/main/extension_callback_manager.hpp b/src/duckdb/src/include/duckdb/main/extension_callback_manager.hpp index a44108f0f..a3502a4c1 100644 --- a/src/duckdb/src/include/duckdb/main/extension_callback_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_callback_manager.hpp @@ -33,6 +33,9 @@ class ExtensionCallbackManager { ExtensionCallbackManager(); ~ExtensionCallbackManager(); + void AddExtensionSchema(const Identifier &schema); + vector GetExtensionSchemas() const; + static ExtensionCallbackManager &Get(ClientContext &context); static ExtensionCallbackManager &Get(DatabaseInstance &db); static const ExtensionCallbackManager &Get(const ClientContext &context); @@ -55,6 +58,7 @@ class ExtensionCallbackManager { private: mutex registry_lock; shared_ptr callback_registry; + vector extension_schemas; }; template diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index e4e4bd5c2..51ec375f1 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -116,6 +116,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"bit_xor", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"bitstring", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"bitstring_agg", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"bitstring_byte_comparable", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"bool_and", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"bool_or", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"broadcast", "inet", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -229,13 
+230,18 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"host", "inet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"html_escape", "inet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"html_unescape", "inet", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"iceberg_bucket", "iceberg", CatalogType::SCALAR_FUNCTION_ENTRY}, {"iceberg_column_stats", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, + {"iceberg_load_table_response", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_metadata", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_partition_stats", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_scan", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, + {"iceberg_schema_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_snapshots", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_table_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"iceberg_to_ducklake", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, + {"iceberg_truncate", "iceberg", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"iceberg_verify_equality_deletes", "iceberg", CatalogType::SCALAR_FUNCTION_ENTRY}, {"icu_calendar_names", "icu", CatalogType::TABLE_FUNCTION_ENTRY}, {"icu_collate_af", "icu", CatalogType::SCALAR_FUNCTION_ENTRY}, {"icu_collate_am", "icu", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -485,7 +491,9 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"mode", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"murmur3_32", "ducklake", CatalogType::SCALAR_FUNCTION_ENTRY}, {"mysql_clear_cache", "mysql_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, + {"mysql_debug_execution_plan", "mysql_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"mysql_execute", "mysql_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, + {"mysql_explain_federated", "mysql_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"mysql_query", "mysql_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"nanosecond", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, 
{"netmask", "inet", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -522,7 +530,10 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"pi", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"position", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"postgres_attach", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, + {"postgres_configure_pool", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"postgres_execute", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, + {"postgres_hstore_get", "postgres_scanner", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"postgres_hstore_to_json", "postgres_scanner", CatalogType::SCALAR_FUNCTION_ENTRY}, {"postgres_query", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"postgres_scan", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"postgres_scan_pushdown", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -532,6 +543,16 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"pragma_rtree_index_info", "spatial", CatalogType::TABLE_FUNCTION_ENTRY}, {"printf", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"product", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"quack_check_token", "quack", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"quack_clear_cache", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_identify", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_nop_authorization", "quack", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"quack_query", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_query_by_name", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_serve", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_server_list", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_stop", "quack", CatalogType::TABLE_FUNCTION_ENTRY}, + {"quack_uri_parser", "quack", CatalogType::SCALAR_FUNCTION_ENTRY}, {"quantile", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"quantile_cont", "core_functions", 
CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"quantile_disc", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, @@ -546,6 +567,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"read_ndjson_auto", "json", CatalogType::TABLE_FUNCTION_ENTRY}, {"read_ndjson_objects", "json", CatalogType::TABLE_FUNCTION_ENTRY}, {"read_parquet", "parquet", CatalogType::TABLE_FUNCTION_ENTRY}, + {"read_postgres_binary", "postgres_scanner", CatalogType::TABLE_FUNCTION_ENTRY}, {"read_xlsx", "excel", CatalogType::TABLE_FUNCTION_ENTRY}, {"reduce", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"reduce_sql_statement", "sqlsmith", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -558,6 +580,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"regr_sxx", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"regr_sxy", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"regr_syy", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"remove_iceberg_schema_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"remove_iceberg_table_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"repeat", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"replace", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -574,6 +597,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"rtrim", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"sem", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"set_bit", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, + {"set_iceberg_schema_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"set_iceberg_table_properties", "iceberg", CatalogType::TABLE_FUNCTION_ENTRY}, {"setseed", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"shapefile_meta", "spatial", CatalogType::TABLE_FUNCTION_ENTRY}, @@ -823,6 +847,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"version", "core_functions", 
CatalogType::SCALAR_FUNCTION_ENTRY}, {"vss_join", "vss", CatalogType::TABLE_MACRO_ENTRY}, {"vss_match", "vss", CatalogType::TABLE_MACRO_ENTRY}, + {"whoami", "quack", CatalogType::TABLE_MACRO_ENTRY}, {"xor", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"|", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"~", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, @@ -1091,15 +1116,42 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"http_retry_wait_ms", "httpfs"}, {"http_timeout", "httpfs"}, {"httpfs_client_implementation", "httpfs"}, + {"httpfs_connection_caching", "httpfs"}, + {"iceberg_logging_post_body_truncate_limit", "iceberg"}, {"iceberg_test_force_token_expiry", "iceberg"}, + {"iceberg_use_metadata_log", "iceberg"}, {"iceberg_via_aws_sdk_for_catalog_interactions", "iceberg"}, + {"ignore_row_group_size_for_partitioned_tables", "iceberg"}, + {"ignore_target_file_size_for_partitioned_tables", "iceberg"}, {"merge_http_secret_into_s3_request", "httpfs"}, + {"mysql_adaptive_replan_enabled", "mysql_scanner"}, + {"mysql_aggregate_pushdown_enabled", "mysql_scanner"}, {"mysql_bit1_as_boolean", "mysql_scanner"}, + {"mysql_compression_aware_costs", "mysql_scanner"}, + {"mysql_compression_ratio", "mysql_scanner"}, {"mysql_debug_show_queries", "mysql_scanner"}, + {"mysql_enable_filter_pushdown", "mysql_scanner"}, + {"mysql_enable_predicate_analyzer", "mysql_scanner"}, {"mysql_enable_transactions", "mysql_scanner"}, - {"mysql_experimental_filter_pushdown", "mysql_scanner"}, + {"mysql_explain_validation_enabled", "mysql_scanner"}, + {"mysql_hint_injection_enabled", "mysql_scanner"}, + {"mysql_hint_staleness_threshold", "mysql_scanner"}, {"mysql_incomplete_dates_as_nulls", "mysql_scanner"}, + {"mysql_order_pushdown_enabled", "mysql_scanner"}, + {"mysql_pool_acquire_mode", "mysql_scanner"}, + {"mysql_pool_connection_idle_timeout_millis", "mysql_scanner"}, + {"mysql_pool_connection_max_lifetime_millis", "mysql_scanner"}, + 
{"mysql_pool_enable_reaper_thread", "mysql_scanner"}, + {"mysql_pool_enable_thread_local_cache", "mysql_scanner"}, + {"mysql_pool_size", "mysql_scanner"}, + {"mysql_pool_wait_timeout_millis", "mysql_scanner"}, + {"mysql_push_threshold_no_index", "mysql_scanner"}, + {"mysql_push_threshold_with_index", "mysql_scanner"}, + {"mysql_query_timeout_enabled", "mysql_scanner"}, + {"mysql_query_timeout_max_ms", "mysql_scanner"}, + {"mysql_query_timeout_min_ms", "mysql_scanner"}, {"mysql_session_time_zone", "mysql_scanner"}, + {"mysql_sql_buffer_result", "mysql_scanner"}, {"mysql_time_as_time", "mysql_scanner"}, {"mysql_tinyint1_as_boolean", "mysql_scanner"}, {"parquet_metadata_cache", "parquet"}, @@ -1108,12 +1160,28 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"pg_connection_limit", "postgres_scanner"}, {"pg_debug_show_queries", "postgres_scanner"}, {"pg_experimental_filter_pushdown", "postgres_scanner"}, + {"pg_idle_in_transaction_timeout_millis", "postgres_scanner"}, {"pg_null_byte_replacement", "postgres_scanner"}, + {"pg_oauth_token", "postgres_scanner"}, + {"pg_order_pushdown", "postgres_scanner"}, {"pg_pages_per_task", "postgres_scanner"}, + {"pg_pool_acquire_mode", "postgres_scanner"}, + {"pg_pool_enable_reaper_thread", "postgres_scanner"}, + {"pg_pool_enable_thread_local_cache", "postgres_scanner"}, + {"pg_pool_health_check_query", "postgres_scanner"}, + {"pg_pool_idle_timeout_millis", "postgres_scanner"}, + {"pg_pool_max_connections", "postgres_scanner"}, + {"pg_pool_max_lifetime_millis", "postgres_scanner"}, + {"pg_pool_wait_timeout_millis", "postgres_scanner"}, + {"pg_statement_timeout_millis", "postgres_scanner"}, {"pg_use_binary_copy", "postgres_scanner"}, {"pg_use_ctid_scan", "postgres_scanner"}, {"pg_use_text_protocol", "postgres_scanner"}, {"prefetch_all_parquet_files", "parquet"}, + {"quack_authentication_function", "quack"}, + {"quack_authorization_function", "quack"}, + {"quack_fetch_batch_chunks", "quack"}, + {"quack_loaded_at_us", 
"quack"}, {"s3_access_key_id", "httpfs"}, {"s3_allow_recursive_globbing", "httpfs"}, {"s3_endpoint", "httpfs"}, @@ -1138,12 +1206,23 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"ui_remote_url", "ui"}, {"unsafe_disable_etag_checks", "httpfs"}, {"unsafe_enable_version_guessing", "iceberg"}, + {"unsafe_iceberg_ignore_sort_order", "iceberg"}, + {"whoami_hostname", "quack"}, + {"whoami_meta", "quack"}, + {"whoami_name", "quack"}, + {"whoami_provider", "quack"}, + {"whoami_region", "quack"}, + {"whoami_started_at", "quack"}, }; // END_OF_EXTENSION_SETTINGS static constexpr ExtensionEntry EXTENSION_SECRET_TYPES[] = { - {"aws", "httpfs"}, {"azure", "azure"}, {"ducklake", "ducklake"}, {"gcs", "httpfs"}, - {"huggingface", "httpfs"}, {"iceberg", "iceberg"}, {"mysql", "mysql_scanner"}, {"postgres", "postgres_scanner"}, - {"r2", "httpfs"}, {"s3", "httpfs"}, + {"aws", "httpfs"}, {"azure", "azure"}, + {"ducklake", "ducklake"}, {"gcs", "httpfs"}, + {"huggingface", "httpfs"}, {"iceberg", "iceberg"}, + {"mysql", "mysql_scanner"}, {"postgres", "postgres_scanner"}, + {"quack", "quack"}, {"r2", "httpfs"}, + {"rds", "postgres_scanner"}, {"s3", "httpfs"}, + {"uc", "unity_catalog"}, {"unity_catalog", "unity_catalog"}, }; // END_OF_EXTENSION_SECRET_TYPES // Note: these are currently hardcoded in scripts/generate_extensions_function.py @@ -1218,6 +1297,7 @@ static constexpr ExtensionEntry EXTENSION_SECRET_PROVIDERS[] = { {"gcs/credential_chain", "aws"}, {"r2/credential_chain", "aws"}, {"aws/credential_chain", "aws"}, + {"rds/credential_chain", "aws"}, {"azure/access_token", "azure"}, {"azure/config", "azure"}, {"azure/credential_chain", "azure"}, @@ -1228,30 +1308,10 @@ static constexpr ExtensionEntry EXTENSION_SECRET_PROVIDERS[] = { {"mysql/config", "mysql_scanner"}, {"postgres/config", "postgres_scanner"}}; // EXTENSION_SECRET_PROVIDERS -static constexpr const char *AUTOLOADABLE_EXTENSIONS[] = {"avro", - "aws", - "azure", - "autocomplete", - "core_functions", - "delta", - 
"ducklake", - "encodings", - "excel", - "fts", - "httpfs", - "iceberg", - "inet", - "icu", - "json", - "motherduck", - "mysql_scanner", - "parquet", - "sqlite_scanner", - "sqlsmith", - "postgres_scanner", - "tpcds", - "tpch", - "unity_catalog", - "ui"}; // END_OF_AUTOLOADABLE_EXTENSIONS +static constexpr const char *AUTOLOADABLE_EXTENSIONS[] = { + "autocomplete", "avro", "aws", "azure", "core_functions", "delta", "ducklake", + "encodings", "excel", "fts", "httpfs", "iceberg", "icu", "inet", + "json", "motherduck", "mysql_scanner", "parquet", "postgres_scanner", "quack", "sqlite_scanner", + "sqlsmith", "tpcds", "tpch", "ui", "unity_catalog"}; // END_OF_AUTOLOADABLE_EXTENSIONS } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_helper.hpp b/src/duckdb/src/include/duckdb/main/extension_helper.hpp index f268b9fdc..4abcbb800 100644 --- a/src/duckdb/src/include/duckdb/main/extension_helper.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_helper.hpp @@ -11,6 +11,7 @@ #include "duckdb.hpp" #include "duckdb/main/extension_entries.hpp" #include "duckdb/main/extension_install_info.hpp" +#include "duckdb/main/extension_load_options.hpp" #include "duckdb/main/settings.hpp" #include @@ -18,6 +19,7 @@ namespace duckdb { class DuckDB; +class ExtensionActiveLoad; enum class ExtensionLoadResult : uint8_t { LOADED_EXTENSION = 0, EXTENSION_UNKNOWN = 1, NOT_LOADED = 2 }; @@ -103,8 +105,8 @@ class ExtensionHelper { static unique_ptr InstallExtension(DatabaseInstance &db, FileSystem &fs, const string &extension, ExtensionInstallOptions &options); //! 
Load an extension - static void LoadExternalExtension(ClientContext &context, const string &extension); - static void LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension); + static void LoadExternalExtension(ClientContext &context, const ExtensionLoadOptions &options); + static void LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const ExtensionLoadOptions &options); //! Autoload an extension (depending on config, potentially a nop. Throws when installation fails) static void AutoLoadExtension(ClientContext &context, const string &extension_name); @@ -156,7 +158,7 @@ class ExtensionHelper { //! Extension can have aliases static idx_t ExtensionAliasCount(); - static ExtensionAlias GetExtensionAlias(idx_t index); + static ExtensionAlias GetInternalExtensionAlias(idx_t index); //! Get public signing keys for extension signing static const vector GetPublicKeys(bool allow_community_extension = false); diff --git a/src/duckdb/src/include/duckdb/main/extension_load_options.hpp b/src/duckdb/src/include/duckdb/main/extension_load_options.hpp new file mode 100644 index 000000000..efc87e8da --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/extension_load_options.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/extension_load_options.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" + +#include "duckdb/common/identifier.hpp" +namespace duckdb { + +struct ExtensionLoadOptions { + string extension_name; + Identifier alias; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_manager.hpp b/src/duckdb/src/include/duckdb/main/extension_manager.hpp index 5f0ed3d3d..b7c0f731b 100644 --- a/src/duckdb/src/include/duckdb/main/extension_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_manager.hpp @@ -10,6 +10,7 
@@ #include "duckdb/common/common.hpp" #include "duckdb/main/extension_install_info.hpp" +#include "duckdb/main/extension_load_options.hpp" namespace duckdb { class ErrorData; @@ -18,6 +19,8 @@ class ExtensionInfo { public: ExtensionInfo(); + string orig_ext_name; + string alias; mutex lock; atomic is_loaded; unique_ptr install_info; @@ -26,12 +29,17 @@ class ExtensionInfo { class ExtensionActiveLoad { public: - ExtensionActiveLoad(DatabaseInstance &db, ExtensionInfo &info, string extension_name); + ExtensionActiveLoad(DatabaseInstance &db, ExtensionInfo &info, Identifier extension_name_p, Identifier alias_p) + : db(db), load_lock(info.lock), info(info), extension_name(std::move(extension_name_p)), + alias(std::move(alias_p)) {}; + + ~ExtensionActiveLoad() = default; DatabaseInstance &db; unique_lock load_lock; ExtensionInfo &info; - string extension_name; + Identifier extension_name; + Identifier alias; public: void FinishLoad(ExtensionInstallInfo &install_info); @@ -45,15 +53,18 @@ class ExtensionManager { DUCKDB_API bool ExtensionIsLoaded(const string &name); DUCKDB_API vector GetExtensions(); DUCKDB_API optional_ptr GetExtensionInfo(const string &name); - DUCKDB_API unique_ptr BeginLoad(const string &extension); + DUCKDB_API unique_ptr BeginLoad(const ExtensionLoadOptions &options); DUCKDB_API static ExtensionManager &Get(DatabaseInstance &db); DUCKDB_API static ExtensionManager &Get(ClientContext &context); + void RemoveExtensionInfo(const string &name); + private: DatabaseInstance &db; mutex lock; - unordered_map> loaded_extensions_info; + identifier_map_t> loaded_extensions_info; + string extension_load_prefix; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/gathered_metrics.hpp b/src/duckdb/src/include/duckdb/main/gathered_metrics.hpp new file mode 100644 index 000000000..5c70f3e22 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/gathered_metrics.hpp @@ -0,0 +1,96 @@ 
+//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/gathered_metrics.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/output_type.hpp" +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/progress_bar/progress_bar.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/constants.hpp" +#include "duckdb/common/enums/metric_type.hpp" +#include "duckdb/common/set.hpp" +#include "duckdb/main/metrics.hpp" + +namespace duckdb_yyjson { +struct yyjson_mut_doc; +struct yyjson_mut_val; +} // namespace duckdb_yyjson + +namespace duckdb { + +struct QueryProfileResult; +enum class ProfilingParameterNames : uint8_t { FORMAT, COVERAGE, SAVE_LOCATION, MODE, METRICS }; + +class GatheredMetrics { +public: + GatheredMetrics() = default; + explicit GatheredMetrics(const vector &tracked_metrics); + GatheredMetrics(GatheredMetrics &) = default; + GatheredMetrics &operator=(GatheredMetrics const &) = default; + +public: + void ResetMetrics(); + //! Returns true if this metric is enabled (and should therefore be collected and output). 
+ bool MetricIsTracked(const string &key) const; + void SetMetric(const string &key, Value new_value); + void SetMetric(const string &key, idx_t value); + void SetMetric(const string &key, double value); + void SetMetric(const string &key, const string &value); + + template + bool MetricIsTracked() const { + return MetricIsTracked(T::Name); + } + + bool AnyOperatorMetricTracked() const { + string operator_prefix = "operator."; + if (MetricIsTracked(operator_prefix)) { + return true; + } + for (const auto &key : tracked_exact) { + if (key.size() > operator_prefix.size() && key.compare(0, operator_prefix.size(), operator_prefix) == 0) { + return true; + } + } + return false; + } + + template + void SetMetric(const typename T::METRIC_TYPE &value) { + SetMetric(T::Name, value); + } + + const profiler_metrics_t &GetMetrics() const { + return metrics; + } + +public: + void WriteMetricsToLog(ClientContext &context) const; + //! Copy all enabled metrics into a QueryProfileResult node using lowercase string keys + void MetricsToProfileResult(QueryProfileResult &result) const; + +private: + void InitTrackedMetrics(const vector &patterns); + +private: + //! When tracked_metrics contains "*", all metrics are enabled and pattern matching is skipped. + bool track_all = false; + //! Exact metric names from tracked_metrics globs with no wildcard characters. + unordered_set tracked_exact; + //! Prefix strings from tracked_metrics patterns of the form "prefix*" (stored without the trailing '*'). + set tracked_prefixes; + //! Arbitrary glob patterns from tracked_metrics that don't fit the prefix category. + vector tracked_globs; + //! Contains all enabled metrics. 
+ profiler_metrics_t metrics; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/metric_info.hpp b/src/duckdb/src/include/duckdb/main/metric_info.hpp new file mode 100644 index 000000000..1ce4d2e25 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/metric_info.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/metric_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/string.hpp" +#include "duckdb/common/winapi.hpp" + +namespace duckdb { + +struct MetricInfo { + string name; + string metric_type; // "double" | "uint64" | "string" | "map" + string description; + string unit; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/metrics.hpp b/src/duckdb/src/include/duckdb/main/metrics.hpp new file mode 100644 index 000000000..ff7e56928 --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/metrics.hpp @@ -0,0 +1,308 @@ +//===----------------------------------------------------------------------===// +// +// DuckDB +// +// duckdb/main/metrics.hpp +// +// This file is automatically generated by scripts/generate_metrics.py +// Do not edit this file manually, your changes will be overwritten +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/value.hpp" + +namespace duckdb { + +// Query metrics +struct MetricQueryCPUTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "query.cpu_time"; + static constexpr const char *Description = "CPU time spent on the query"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricQuerySQL { + using METRIC_TYPE = string; + static constexpr const char *Name = "query.sql"; + static constexpr const char *Description = "The SQL string of the query"; + static 
constexpr const char *Unit = ""; + static constexpr const char *TypeStr = "string"; +}; +struct MetricQueryTotalIntermediateRows { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "query.total_intermediate_rows"; + static constexpr const char *Description = "Cumulative number of intermediate rows produced by the query"; + static constexpr const char *Unit = "rows"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricQueryTotalIntermediateSizeBytes { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "query.total_intermediate_size_bytes"; + static constexpr const char *Description = "Cumulative size in bytes of all intermediate results produced by the query"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricQueryTotalRowGroupsScanned { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "query.total_row_groups_scanned"; + static constexpr const char *Description = "Cumulative number of row groups scanned by the query"; + static constexpr const char *Unit = "row_groups"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricQueryTotalRowGroupsToScan { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "query.total_row_groups_to_scan"; + static constexpr const char *Description = "Cumulative total number of row groups available to scan"; + static constexpr const char *Unit = "row_groups"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricQueryTotalRowsScanned { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "query.total_rows_scanned"; + static constexpr const char *Description = "Cumulative number of rows scanned by the query"; + static constexpr const char *Unit = "rows"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricQueryTotalTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "query.total_time"; + static 
constexpr const char *Description = "Time spent executing the entire query"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +// System metrics +struct MetricSystemBlockedThreadTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "system.blocked_thread_time"; + static constexpr const char *Description = "Time spent waiting for a thread to become available"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricSystemPeakBufferMemory { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "system.peak_buffer_memory"; + static constexpr const char *Description = "Peak memory usage of the buffer manager"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricSystemPeakTempDirSize { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "system.peak_temp_dir_size"; + static constexpr const char *Description = "Peak size of the temporary directory"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricSystemTotalMemoryAllocated { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "system.total_memory_allocated"; + static constexpr const char *Description = "The total memory allocated by the buffer manager"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; + +// Io metrics +struct MetricIOTotalBytesRead { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "io.total_bytes_read"; + static constexpr const char *Description = "The total amount of bytes read from storage"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricIOTotalBytesWritten { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = 
"io.total_bytes_written"; + static constexpr const char *Description = "The total amount of bytes written to storage"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; + +// Storage metrics +struct MetricStorageAttachLoadStorageLatency { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.attach_load_storage_latency"; + static constexpr const char *Description = "Time spent loading from storage"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageAttachReplayWALLatency { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.attach_replay_wal_latency"; + static constexpr const char *Description = "Time spent replaying the WAL file"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageCheckpointLatency { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.checkpoint_latency"; + static constexpr const char *Description = "Time spent running checkpoints"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageCommitLocalStorageLatency { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.commit_local_storage_latency"; + static constexpr const char *Description = "Time spent committing the transaction-local storage"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageTotalVacuumTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.total_vacuum_time"; + static constexpr const char *Description = "Total time spent running vacuum tasks during checkpointing"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageWaitingToAttachLatency { + 
using METRIC_TYPE = double; + static constexpr const char *Name = "storage.waiting_to_attach_latency"; + static constexpr const char *Description = "Time spent waiting to ATTACH a file"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricStorageWALReplayEntryCount { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "storage.wal_replay_entry_count"; + static constexpr const char *Description = "The total number of entries to replay in the WAL"; + static constexpr const char *Unit = "rows"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricStorageWriteToWALLatency { + using METRIC_TYPE = double; + static constexpr const char *Name = "storage.write_to_wal_latency"; + static constexpr const char *Description = "Time spent writing to the WAL"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +// Operator metrics +struct MetricOperatorCPUTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "operator.cpu_time"; + static constexpr const char *Description = "CPU time spent in the operator"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricOperatorExtraInfo { + using METRIC_TYPE = Value; + static constexpr const char *Name = "operator.extra_info"; + static constexpr const char *Description = "Unique operator metrics"; + static constexpr const char *Unit = ""; + static constexpr const char *TypeStr = "map"; +}; +struct MetricOperatorIntermediateRows { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "operator.intermediate_rows"; + static constexpr const char *Description = "Number of intermediate rows produced by the operator"; + static constexpr const char *Unit = "rows"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricOperatorIntermediateSizeBytes { + using METRIC_TYPE = uint64_t; + 
static constexpr const char *Name = "operator.intermediate_size_bytes"; + static constexpr const char *Description = "Intermediate size in bytes produced by the operator"; + static constexpr const char *Unit = "bytes"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricOperatorRowGroupsScanned { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "operator.row_groups_scanned"; + static constexpr const char *Description = "Number of row groups scanned by the operator"; + static constexpr const char *Unit = "row_groups"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricOperatorRowsScanned { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "operator.rows_scanned"; + static constexpr const char *Description = "Number of rows scanned by the operator"; + static constexpr const char *Unit = "rows"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricOperatorTiming { + using METRIC_TYPE = double; + static constexpr const char *Name = "operator.timing"; + static constexpr const char *Description = "Time spent in the operator"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricOperatorTotalRowGroupsToScan { + using METRIC_TYPE = uint64_t; + static constexpr const char *Name = "operator.total_row_groups_to_scan"; + static constexpr const char *Description = "Total number of row groups available for the operator to scan"; + static constexpr const char *Unit = "row_groups"; + static constexpr const char *TypeStr = "uint64"; +}; +struct MetricOperatorType { + using METRIC_TYPE = string; + static constexpr const char *Name = "operator.type"; + static constexpr const char *Description = "Type of the operator"; + static constexpr const char *Unit = ""; + static constexpr const char *TypeStr = "string"; +}; + +// Optimizer metrics +struct MetricOptimizerTotalTime { + using METRIC_TYPE = double; + static constexpr const char 
*Name = "optimizer.total_time"; + static constexpr const char *Description = "Time spent in all optimizers combined"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +// Parser metrics +struct MetricParserTotalTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "parser.total_time"; + static constexpr const char *Description = "Time spent parsing the SQL query"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +// Planner metrics +struct MetricPlannerBindingTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "planner.binding_time"; + static constexpr const char *Description = "Time spent binding the logical plan"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricPlannerTotalTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "planner.total_time"; + static constexpr const char *Description = "Time spent generating the logical plan"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +// Physical Planner metrics +struct MetricPhysicalPlannerColumnBinding { + using METRIC_TYPE = double; + static constexpr const char *Name = "physical_planner.column_binding"; + static constexpr const char *Description = "Time spent binding the columns in the logical plan to physical columns"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricPhysicalPlannerCreatePlan { + using METRIC_TYPE = double; + static constexpr const char *Name = "physical_planner.create_plan"; + static constexpr const char *Description = "Time spent creating the physical plan"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricPhysicalPlannerResolveTypes { + using METRIC_TYPE = 
double; + static constexpr const char *Name = "physical_planner.resolve_types"; + static constexpr const char *Description = "Time spent resolving the types in the logical plan to physical types"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; +struct MetricPhysicalPlannerTotalTime { + using METRIC_TYPE = double; + static constexpr const char *Name = "physical_planner.total_time"; + static constexpr const char *Description = "Time spent generating the physical plan"; + static constexpr const char *Unit = "seconds"; + static constexpr const char *TypeStr = "double"; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/metrics_manager.hpp b/src/duckdb/src/include/duckdb/main/metrics_manager.hpp new file mode 100644 index 000000000..e6994e9cb --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/metrics_manager.hpp @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/metrics_manager.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/main/metric_info.hpp" + +namespace duckdb { + +class ClientContext; +class DatabaseInstance; + +//! MetricsManager holds descriptors for every metric known to the system: +//! the statically compiled set (from src/common/metrics.json), optimizer +//! metrics generated at runtime from the OptimizerType enum, and any +//! additional metrics registered by extensions at runtime. +class MetricsManager { +public: + MetricsManager() = default; + + DUCKDB_API static MetricsManager &Get(ClientContext &context); + DUCKDB_API static MetricsManager &Get(DatabaseInstance &db); + + //! Returns the total number of known metrics (static + optimizer-generated + registered). 
+ DUCKDB_API idx_t GetMetricCount() const; + + //! Returns a snapshot of all known metrics. + DUCKDB_API vector GetAllMetrics() const; + + //! Register an extension-defined metric. Silently ignores duplicates (same name). + DUCKDB_API void RegisterMetric(MetricInfo info); + +private: + mutable mutex registered_metrics_lock; + vector registered_metrics; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp index 0f8b86a07..b2608f61c 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/winapi.hpp" #include "duckdb/main/materialized_query_result.hpp" #include "duckdb/main/pending_query_result.hpp" @@ -24,7 +25,7 @@ class PreparedStatement { public: //! Create a successfully prepared prepared statement object with the given name DUCKDB_API PreparedStatement(shared_ptr context, shared_ptr data, - string query, case_insensitive_map_t named_param_map); + string query, identifier_map_t named_param_map); //! Create a prepared statement that was not successfully prepared DUCKDB_API explicit PreparedStatement(ErrorData error); @@ -42,7 +43,7 @@ class PreparedStatement { //! The error message (if success = false) ErrorData error; //! The parameter mapping - case_insensitive_map_t named_param_map; + identifier_map_t named_param_map; public: //! Returns the stored error message @@ -60,7 +61,7 @@ class PreparedStatement { //! Returns the result SQL types of the prepared statement DUCKDB_API const vector &GetTypes(); //! Returns the result names of the prepared statement - DUCKDB_API const vector &GetNames(); + DUCKDB_API const vector &GetNames(); //! 
Returns the map of parameter index to the expected type of parameter DUCKDB_API case_insensitive_map_t GetExpectedParameterTypes() const; @@ -75,14 +76,14 @@ class PreparedStatement { DUCKDB_API unique_ptr PendingQuery(vector &values, bool allow_stream_result = true); //! Create a pending query result of the prepared statement with the given set named arguments - DUCKDB_API unique_ptr PendingQuery(case_insensitive_map_t &named_values, + DUCKDB_API unique_ptr PendingQuery(identifier_map_t &named_values, bool allow_stream_result = true); //! Execute the prepared statement with the given set of values DUCKDB_API unique_ptr Execute(vector &values, bool allow_stream_result = true); //! Execute the prepared statement with the given set of named+unnamed values - DUCKDB_API unique_ptr Execute(case_insensitive_map_t &named_values, + DUCKDB_API unique_ptr Execute(identifier_map_t &named_values, bool allow_stream_result = true); //! Execute the prepared statement with the given set of arguments @@ -93,14 +94,14 @@ class PreparedStatement { } template - static string ExcessValuesException(const case_insensitive_map_t ¶meters, - const case_insensitive_map_t &values) { + static string ExcessValuesException(const identifier_map_t ¶meters, + const identifier_map_t &values) { // Too many values set excess_set; for (auto &pair : values) { auto &name = pair.first; if (!parameters.count(name)) { - excess_set.insert(name); + excess_set.insert(name.GetIdentifierName()); } } vector excess_values; @@ -112,17 +113,17 @@ class PreparedStatement { } template - static string MissingValuesException(const case_insensitive_map_t ¶meters, - const case_insensitive_map_t &values) { + static string MissingValuesException(const identifier_map_t ¶meters, + const identifier_map_t &values) { // Missing values - set missing_set; + identifier_set_t missing_set; for (auto &pair : parameters) { auto &name = pair.first; if (!values.count(name)) { missing_set.insert(name); } } - vector missing_values; + 
vector missing_values; for (auto &val : missing_set) { missing_values.push_back(val); } @@ -131,8 +132,7 @@ class PreparedStatement { } template - static void VerifyParameters(const case_insensitive_map_t &provided, - const case_insensitive_map_t &expected) { + static void VerifyParameters(const identifier_map_t &provided, const identifier_map_t &expected) { if (expected.size() == provided.size()) { // Same amount of identifiers, if for (auto &pair : expected) { diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp index 84f50d654..b09c183f3 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp @@ -35,7 +35,7 @@ class PreparedStatementData { unique_ptr physical_plan; //! The result names of the transaction - vector names; + vector names; //! The result types of the transaction vector types; @@ -52,13 +52,13 @@ class PreparedStatementData { public: void CheckParameterCount(idx_t parameter_count); //! Whether or not the prepared statement data requires the query to rebound for the given parameters - bool RequireRebind(ClientContext &context, optional_ptr> values); + bool RequireRebind(ClientContext &context, optional_ptr> values); //! Bind a set of values to the prepared statement data - DUCKDB_API void Bind(case_insensitive_map_t values); + DUCKDB_API void Bind(identifier_map_t values); //! Get the expected SQL Type of the bound parameter - DUCKDB_API LogicalType GetType(const string &identifier); + DUCKDB_API LogicalType GetType(const Identifier &identifier); //! 
Try to get the expected SQL Type of the bound parameter - DUCKDB_API bool TryGetType(const string &identifier, LogicalType &result); + DUCKDB_API bool TryGetType(const Identifier &identifier, LogicalType &result); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/profiling_info.hpp b/src/duckdb/src/include/duckdb/main/profiling_info.hpp deleted file mode 100644 index 458c7b766..000000000 --- a/src/duckdb/src/include/duckdb/main/profiling_info.hpp +++ /dev/null @@ -1,122 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/main/profiling_info.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/enums/output_type.hpp" -#include "duckdb/common/enums/profiler_format.hpp" -#include "duckdb/common/progress_bar/progress_bar.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "duckdb/common/constants.hpp" -#include "duckdb/common/enums/metric_type.hpp" - -namespace duckdb_yyjson { -struct yyjson_mut_doc; -struct yyjson_mut_val; -} // namespace duckdb_yyjson - -namespace duckdb { -enum class ProfilingParameterNames : uint8_t { FORMAT, COVERAGE, SAVE_LOCATION, MODE, METRICS }; - -class ProfilingInfo { -public: - //! Enabling a metric adds it to this set. - profiler_settings_t settings; - //! This set contains the expanded to-be-collected metrics, which can differ from 'settings'. - profiler_settings_t expanded_settings; - //! Contains all enabled metrics. - profiler_metrics_t metrics; - -public: - ProfilingInfo() = default; - explicit ProfilingInfo(const profiler_settings_t &n_settings, const idx_t depth = 0); - ProfilingInfo(ProfilingInfo &) = default; - ProfilingInfo &operator=(ProfilingInfo const &) = default; - -public: - void ResetMetrics(); - //! Returns true, if the query profiler must collect this metric. 
- static bool Enabled(const profiler_settings_t &settings, const MetricType metric); - //! Expand metrics depending on the collection of other metrics. - static void Expand(profiler_settings_t &settings, const MetricType metric); - -public: - string GetMetricAsString(const MetricType metric) const; - void WriteMetricsToLog(ClientContext &context); - void WriteMetricsToJSON(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *destination); - -public: - template - METRIC_TYPE GetMetricValue(const MetricType type) const { - auto val = metrics.at(type); - return val.GetValue(); - } - - template - void MetricUpdate(const MetricType type, const Value &value, - const std::function &update_fun) { - if (metrics.find(type) == metrics.end()) { - metrics[type] = value; - return; - } - auto new_value = update_fun(metrics[type].GetValue(), value.GetValue()); - metrics[type] = Value::CreateValue(new_value); - } - - template - void MetricUpdate(const MetricType type, const METRIC_TYPE &value, - const std::function &update_fun) { - auto new_value = Value::CreateValue(value); - MetricUpdate(type, new_value, update_fun); - } - - template - void MetricSum(const MetricType type, const Value &value) { - MetricUpdate(type, value, [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { - return old_value + new_value; - }); - } - - template - void MetricSum(const MetricType type, const METRIC_TYPE &value) { - auto new_value = Value::CreateValue(value); - return MetricSum(type, new_value); - } - - template - void MetricMax(const MetricType type, const Value &value) { - MetricUpdate(type, value, [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { - return MaxValue(old_value, new_value); - }); - } - - template - void MetricMax(const MetricType type, const METRIC_TYPE &value) { - auto new_value = Value::CreateValue(value); - return MetricMax(type, new_value); - } -}; - -// Specialization for InsertionOrderPreservingMap -template <> -inline 
InsertionOrderPreservingMap -ProfilingInfo::GetMetricValue>(const MetricType type) const { - auto val = metrics.at(type); - InsertionOrderPreservingMap result; - auto children = MapValue::GetChildren(val); - for (auto &child : children) { - auto struct_children = StructValue::GetChildren(child); - auto key = struct_children[0].GetValue(); - auto value = struct_children[1].GetValue(); - result.insert(key, value); - } - return result; -} -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/profiling_node.hpp b/src/duckdb/src/include/duckdb/main/profiling_node.hpp index 04d49d108..872fe2bf1 100644 --- a/src/duckdb/src/include/duckdb/main/profiling_node.hpp +++ b/src/duckdb/src/include/duckdb/main/profiling_node.hpp @@ -9,21 +9,93 @@ #pragma once #include "duckdb/common/common.hpp" -#include "duckdb/common/deque.hpp" -#include "duckdb/common/enums/profiler_format.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/profiler.hpp" -#include "duckdb/common/reference_map.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/winapi.hpp" -#include "duckdb/execution/expression_executor_state.hpp" -#include "duckdb/execution/physical_operator.hpp" -#include "duckdb/main/profiling_info.hpp" +#include "duckdb/main/gathered_metrics.hpp" namespace duckdb { +struct OperatorMetrics { + explicit OperatorMetrics() { + ResetMetrics(); + } + + string name; + PhysicalOperatorType operator_type; + + double time; + idx_t elements_returned; + idx_t intermediate_size_bytes; + idx_t system_peak_buffer_manager_memory; + idx_t system_peak_temp_directory_size; + idx_t rows_scanned; + idx_t row_groups_scanned; + idx_t total_row_groups_to_scan; + + profiler_metrics_t GetMetrics(const GatheredMetrics &info) const; + void ResetMetrics() { + time = 0; + elements_returned = 0; + intermediate_size_bytes = 0; + system_peak_buffer_manager_memory = 0; + 
system_peak_temp_directory_size = 0; + rows_scanned = 0; + row_groups_scanned = 0; + total_row_groups_to_scan = 0; + operator_type = PhysicalOperatorType::INVALID; + extra_info.clear(); + } + void AddExtraInfo(string key, string value) { + extra_info.insert(make_pair(std::move(key), std::move(value))); + } + void SetExtraInfo(InsertionOrderPreservingMap info) { + extra_info = std::move(info); + } + const InsertionOrderPreservingMap &GetExtraInfo() const { + return extra_info; + } + void GatherMetrics(ClientContext &context, double elapsed_time, optional_ptr chunk); + void Merge(const OperatorMetrics &other); + void Accumulate(const OperatorMetrics &other); + +private: + InsertionOrderPreservingMap extra_info; + void MergeInternal(const OperatorMetrics &other); +}; + +//! The OperatorProfiler measures timings of individual operators +//! This class exists once for all operators and collects `OperatorInfo` for each operator +class OperatorProfiler { + friend class QueryProfiler; + +public: + DUCKDB_API explicit OperatorProfiler(ClientContext &context); + ~OperatorProfiler() { + } + +public: + DUCKDB_API void StartOperator(optional_ptr phys_op); + DUCKDB_API void EndOperator(optional_ptr chunk); + DUCKDB_API void FinishSource(GlobalSourceState &gstate, LocalSourceState &lstate); + DUCKDB_API void FinishSource(const PhysicalOperator &phys_op, GlobalSourceState &gstate, LocalSourceState &lstate); + + //! Adds the timings in the OperatorProfiler (tree) to the QueryProfiler (tree). + DUCKDB_API void Flush(const PhysicalOperator &phys_op); + DUCKDB_API OperatorMetrics &GetOperatorMetrics(const PhysicalOperator &phys_op); + DUCKDB_API bool OperatorMetricsIsInitialized(const PhysicalOperator &phys_op); + +public: + ClientContext &context; + +private: + //! Whether or not the profiler is enabled + bool enabled; + //! The timer used to time the execution time of the individual Physical Operators + Profiler op; + //! 
The stack of Physical Operators that are currently active + optional_ptr active_operator; + //! A mapping of physical operators to profiled operator information. + reference_map_t operator_metrics; +}; + //! Recursive tree mirroring the operator tree. class ProfilingNode { public: @@ -32,7 +104,7 @@ class ProfilingNode { virtual ~ProfilingNode() {}; private: - ProfilingInfo profiling_info; + OperatorMetrics operator_info; public: idx_t depth = 0; @@ -42,11 +114,11 @@ class ProfilingNode { idx_t GetChildCount() { return children.size(); } - ProfilingInfo &GetProfilingInfo() { - return profiling_info; + OperatorMetrics &GetOperatorMetrics() { + return operator_info; } - const ProfilingInfo &GetProfilingInfo() const { - return profiling_info; + const OperatorMetrics &GetOperatorMetrics() const { + return operator_info; } optional_ptr GetChild(idx_t idx) { return children[idx].get(); diff --git a/src/duckdb/src/include/duckdb/main/profiling_utils.hpp b/src/duckdb/src/include/duckdb/main/profiling_utils.hpp index d7ccb5989..6dd5c7fea 100644 --- a/src/duckdb/src/include/duckdb/main/profiling_utils.hpp +++ b/src/duckdb/src/include/duckdb/main/profiling_utils.hpp @@ -4,15 +4,14 @@ // // duckdb/main/profiling_utils.hpp // -// This file is automatically generated by scripts/generate_metric_enums.py -// Do not edit this file manually, your changes will be overwritten //===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/enums/metric_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/main/gathered_metrics.hpp" #include "duckdb/main/profiling_node.hpp" -#include "duckdb/main/profiling_info.hpp" #include "duckdb/common/profiler.hpp" namespace duckdb_yyjson { @@ -22,113 +21,152 @@ struct yyjson_mut_val; namespace duckdb { -struct ActiveTimer; +struct MetricsTimer; // Top level query metrics struct QueryMetrics { public: - QueryMetrics() { - Reset(); - } + QueryMetrics(); + ~QueryMetrics(); - 
ProfilingInfo query_global_info; + idx_t system_peak_buffer_memory; + idx_t system_peak_temp_dir_size; + double blocked_thread_time; - std::string query_name; - unique_ptr latency_timer; + std::string query_sql; + + // Always-tracked byte counters (used by progress bar even when profiling is disabled) + atomic bytes_read; + atomic bytes_written; + // Thread-safe memory allocation counter (updated from allocator callbacks on any thread) + atomic total_memory_allocated; public: - void UpdateMetric(const MetricType metric, idx_t addition) { - active_metrics[GetMetricsIndex(metric)] += addition; - } - - void ResetMetric(const MetricType metric) { - active_metrics[GetMetricsIndex(metric)] = 0; - } - - idx_t GetMetricValue(const MetricType metric) const { - return active_metrics[GetMetricsIndex(metric)]; - } - - double GetMetricInSeconds(const MetricType metric) const { - return static_cast(active_metrics[GetMetricsIndex(metric)]) / 1e9; - } - - void Reset() { - for(idx_t i = 0; i < ACTIVELY_TRACKED_METRICS; i++) { - active_metrics[i] = 0; - } - - latency_timer.reset(); - query_name = ""; - } - - void Merge(const QueryMetrics &other) { - for(idx_t i = 0; i < ACTIVELY_TRACKED_METRICS; i++) { - active_metrics[i] += other.active_metrics[i]; - } - } - - static idx_t GetMetricsIndex(MetricType type) { - switch(type) { - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: return 0; - case MetricType::ATTACH_REPLAY_WAL_LATENCY: return 1; - case MetricType::CHECKPOINT_LATENCY: return 2; - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: return 3; - case MetricType::LATENCY: return 4; - case MetricType::WAITING_TO_ATTACH_LATENCY: return 5; - case MetricType::WRITE_TO_WAL_LATENCY: return 6; - case MetricType::TOTAL_BYTES_READ: return 7; - case MetricType::TOTAL_BYTES_WRITTEN: return 8; - case MetricType::TOTAL_MEMORY_ALLOCATED: return 9; - case MetricType::WAL_REPLAY_ENTRY_COUNT: return 10; - default: - throw InternalException("MetricType %s is not actively tracked.", 
EnumUtil::ToString(type)); + void UpdateMetric(const string &key, idx_t addition) { + string_timings[key] += addition; + } + + void UpdateMetricCounter(const string &key, idx_t addition) { + string_counters[key] += addition; + } + + void UpdateBytesRead(idx_t n) { + bytes_read += n; + } + + void UpdateBytesWritten(idx_t n) { + bytes_written += n; + } + + void UpdateTotalMemoryAllocated(idx_t n) { + total_memory_allocated += n; + } + + double GetStringMetricInSeconds(const string &key) const { + auto it = string_timings.find(key); + if (it == string_timings.end()) { + return 0.0; } + return static_cast(it->second) / 1e9; } -private: - static constexpr const idx_t ACTIVELY_TRACKED_METRICS = 11; + idx_t GetStringCounter(const string &key) const { + auto it = string_counters.find(key); + if (it == string_counters.end()) { + return 0; + } + return it->second; + } - atomic active_metrics[ACTIVELY_TRACKED_METRICS]; -}; + idx_t GetBytesRead() const { + return bytes_read.load(); + } + + idx_t GetBytesWritten() const { + return bytes_written.load(); + } + + idx_t GetTotalMemoryAllocated() const { + return total_memory_allocated.load(); + } + + const unordered_map &GetMetricTimings() const { + return string_timings; + } + + const unordered_map &GetMetricCounters() const { + return string_counters; + } + + void Reset() { + string_timings.clear(); + string_counters.clear(); + bytes_read = 0; + bytes_written = 0; + total_memory_allocated = 0; + + latency_timer.reset(); + query_sql = ""; + system_peak_buffer_memory = 0; + system_peak_temp_dir_size = 0; + blocked_thread_time = 0; + } + + //! Write all query-level metrics into the given GatheredMetrics. 
+ void FinalizeMetrics(GatheredMetrics &info); + + void Merge(const QueryMetrics &other) { + for (const auto &entry : other.string_timings) { + string_timings[entry.first] += entry.second; + } + for (const auto &entry : other.string_counters) { + string_counters[entry.first] += entry.second; + } + bytes_read += other.bytes_read.load(); + bytes_written += other.bytes_written.load(); + total_memory_allocated += other.total_memory_allocated.load(); + } + +private: + // String-keyed timings for optimizer, storage and phase timing metrics + unordered_map string_timings; + // String-keyed counters for storage counter metrics (e.g. "storage.wal_replay_entry_count") + unordered_map string_counters; -class ProfilingUtils { public: - static void SetMetricToDefault(profiler_metrics_t &metrics, const MetricType &type); - static void MetricToJson(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *dest, const char *key_ptr, profiler_metrics_t &metrics, const MetricType &type); - static void CollectMetrics(const MetricType &type, QueryMetrics &query_metrics, Value &metric, ProfilingNode &node, ProfilingInfo &child_info); + // Declared after string_timings so it is destroyed first; its destructor writes to string_timings. 
+ unique_ptr latency_timer; }; -struct ActiveTimer { +struct MetricsTimer { public: - ActiveTimer() : metric(MetricType::EXTRA_INFO), is_active(false) { + MetricsTimer() : metric_name(""), is_active(false) { } - ActiveTimer(QueryMetrics &query_metrics, const MetricType metric, const bool is_active = true) : query_metrics(query_metrics), metric(metric), is_active(is_active) { - // start on constructor + MetricsTimer(QueryMetrics &query_metrics, string key, const bool is_active = true) + : query_metrics(query_metrics), metric_name(std::move(key)), is_active(is_active) { if (!is_active) { return; } profiler.Start(); } - ~ActiveTimer() { - if (is_active) { - // automatically end in destructor + ~MetricsTimer() { + if (is_active && !Exception::UncaughtException()) { EndTimer(); } } // disable copy constructors - ActiveTimer(const ActiveTimer &other) = delete; - ActiveTimer &operator=(const ActiveTimer &) = delete; + MetricsTimer(const MetricsTimer &other) = delete; + MetricsTimer &operator=(const MetricsTimer &) = delete; //! 
enable move constructors - ActiveTimer(ActiveTimer &&other) noexcept : is_active(false) { + MetricsTimer(MetricsTimer &&other) noexcept : is_active(false) { std::swap(query_metrics, other.query_metrics); - std::swap(metric, other.metric); + std::swap(metric_name, other.metric_name); std::swap(profiler, other.profiler); std::swap(is_active, other.is_active); } - ActiveTimer &operator=(ActiveTimer &&other) noexcept { + MetricsTimer &operator=(MetricsTimer &&other) noexcept { std::swap(query_metrics, other.query_metrics); - std::swap(metric, other.metric); + std::swap(metric_name, other.metric_name); std::swap(profiler, other.profiler); std::swap(is_active, other.is_active); return *this; @@ -139,14 +177,13 @@ struct ActiveTimer { if (!is_active) { return; } - // stop profiling and report is_active = false; profiler.End(); - query_metrics->UpdateMetric(metric, profiler.ElapsedNanos()); + query_metrics->UpdateMetric(metric_name, profiler.ElapsedNanos()); } void Reset() { - if (!is_active) { + if (!is_active) { return; } profiler.Reset(); @@ -155,9 +192,9 @@ struct ActiveTimer { private: optional_ptr query_metrics; - MetricType metric; + string metric_name; Profiler profiler; bool is_active; }; -} +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_context.hpp b/src/duckdb/src/include/duckdb/main/query_context.hpp new file mode 100644 index 000000000..899815f5d --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/query_context.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/query_context.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/optional_ptr.hpp" + +namespace duckdb { + +class ClientContext; + +//! The QueryContext wraps an optional client context. +//! It makes query-related information available to operations. 
+class QueryContext { +public: + QueryContext() : context(nullptr) { + } + QueryContext(optional_ptr context) : context(context) { // NOLINT: allow implicit construction + } + QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction + } + +public: + bool Valid() const { + return context != nullptr; + } + optional_ptr GetClientContext() const { + return context; + } + +private: + optional_ptr context; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_profiler.hpp b/src/duckdb/src/include/duckdb/main/query_profiler.hpp index d3d1fa6a2..95fd94aab 100644 --- a/src/duckdb/src/include/duckdb/main/query_profiler.hpp +++ b/src/duckdb/src/include/duckdb/main/query_profiler.hpp @@ -34,93 +34,43 @@ class ExpressionExecutor; class ProfilingNode; class PhysicalOperator; class SQLStatement; -struct ActiveTimer; +struct MetricsTimer; +class OperatorProfiler; enum class ProfilingCoverage : uint8_t { SELECT = 0, ALL = 1 }; -struct OperatorInformation { - explicit OperatorInformation() { - } - - string name; - - double time = 0; - idx_t elements_returned = 0; - idx_t result_set_size = 0; - idx_t system_peak_buffer_manager_memory = 0; - idx_t system_peak_temp_directory_size = 0; - idx_t rows_scanned = 0; - - InsertionOrderPreservingMap extra_info; - - template - void AddMetric(MetricType type, T metric) { - switch (type) { - case MetricType::OPERATOR_TIMING: - time += static_cast(metric); - break; - case MetricType::OPERATOR_CARDINALITY: - elements_returned += LossyNumericCast(metric); - break; - case MetricType::RESULT_SET_SIZE: - result_set_size += LossyNumericCast(metric); - break; - case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: { - if (metric > static_cast(system_peak_buffer_manager_memory)) { - system_peak_buffer_manager_memory = LossyNumericCast(metric); - } - break; - } - case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: { - if (metric > static_cast(system_peak_temp_directory_size)) { - 
system_peak_temp_directory_size = LossyNumericCast(metric); - } - break; - } - case MetricType::OPERATOR_ROWS_SCANNED: - rows_scanned = LossyNumericCast(metric); - break; - default: - throw InternalException("OperatorProfiler: Unknown metric type"); - } - } +//! A JSON-like recursive profiling value. +//! FIXME: this should at some point be replaced by a "Value" - but that's not easily possible until our VARIANT Value +//! is extended to be able to easily hold arbitrary values +enum class QueryProfileResultKind { + VALUE, + LIST, + OBJECT, }; -//! The OperatorProfiler measures timings of individual operators -//! This class exists once for all operators and collects `OperatorInfo` for each operator -class OperatorProfiler { - friend class QueryProfiler; - -public: - DUCKDB_API explicit OperatorProfiler(ClientContext &context); - ~OperatorProfiler() { +struct QueryProfileResult { + QueryProfileResultKind kind = QueryProfileResultKind::OBJECT; + //! Non-empty for children of an OBJECT node; empty for LIST items and the root + string key; + //! Valid when kind == VALUE + Value value; + //! Ordered children; valid when kind == LIST or OBJECT + vector> children; + + //! Add a named VALUE leaf to this OBJECT node + void AddValue(const string &k, Value val); + //! Add a named OBJECT child to this OBJECT node; returns the new child + QueryProfileResult &AddObject(const string &k); + //! Add a named LIST child to this OBJECT node; returns the new child + QueryProfileResult &AddList(const string &k); + //! Append an anonymous OBJECT item to this LIST node; returns the new item + QueryProfileResult &AppendObject(); + //! Append an anonymous LIST item to this node; returns the new item + QueryProfileResult &AppendList(); + //! 
Returns true if this node is a nested type (OBJECT or LIST) + bool IsNested() const { + return kind == QueryProfileResultKind::OBJECT || kind == QueryProfileResultKind::LIST; } - -public: - DUCKDB_API void StartOperator(optional_ptr phys_op); - DUCKDB_API void EndOperator(optional_ptr chunk); - DUCKDB_API void FinishSource(GlobalSourceState &gstate, LocalSourceState &lstate); - - //! Adds the timings in the OperatorProfiler (tree) to the QueryProfiler (tree). - DUCKDB_API void Flush(const PhysicalOperator &phys_op); - DUCKDB_API OperatorInformation &GetOperatorInfo(const PhysicalOperator &phys_op); - DUCKDB_API bool OperatorInfoIsInitialized(const PhysicalOperator &phys_op); - -public: - ClientContext &context; - -private: - //! Whether or not the profiler is enabled - bool enabled; - //! Sub-settings for the operator profiler - profiler_settings_t settings; - - //! The timer used to time the execution time of the individual Physical Operators - Profiler op; - //! The stack of Physical Operators that are currently active - optional_ptr active_operator; - //! A mapping of physical operators to profiled operator information. - reference_map_t operator_infos; }; //! QueryProfiler collects the profiling metrics of a query. @@ -147,11 +97,28 @@ class QueryProfiler { //! Finalize query metrics for output; safe to call multiple times. DUCKDB_API void FinalizeMetrics(); - //! Adds amount to a specific metric type. - DUCKDB_API void AddToCounter(MetricType type, const idx_t amount); - - //! Start/End a timer for a specific metric type. - DUCKDB_API ActiveTimer StartTimer(MetricType type); + //! Track bytes read (always tracked, even when profiling disabled). + DUCKDB_API void TrackBytesRead(idx_t amount); + //! Track bytes written (always tracked, even when profiling disabled). + DUCKDB_API void TrackBytesWritten(idx_t amount); + //! Track memory allocated (thread-safe; always tracked). + DUCKDB_API void TrackTotalMemoryAllocated(idx_t amount); + //! 
Add to a metric counter (profiling-only). + DUCKDB_API void AddToMetricCounter(const string &key, idx_t amount); + + //! Set an arbitrary metric value (profiling-only; no-op when profiling is disabled). + DUCKDB_API void SetMetric(const string &key, Value new_value); + //! Returns true if the given metric is currently being tracked. + //! Always returns false when profiling is disabled. + DUCKDB_API bool MetricIsTracked(const string &key) const; + + //! Start a timer for a metric identified by its struct type. + template + MetricsTimer StartTimer() { + return StartTimerInternal(METRIC::Name); + } + //! Start a timer for a string-keyed metric (use the template overload when possible). + DUCKDB_API MetricsTimer StartTimerInternal(const string &key); DUCKDB_API void StartExplainAnalyze(); @@ -160,9 +127,6 @@ class QueryProfiler { //! Adds the top level query information to the global profiler. DUCKDB_API void SetBlockedTime(const double &blocked_thread_time); - DUCKDB_API void StartPhase(MetricType phase_metric); - DUCKDB_API void EndPhase(); - DUCKDB_API void Initialize(const PhysicalOperator &root); DUCKDB_API string QueryTreeToString() const; @@ -184,27 +148,13 @@ class QueryProfiler { DUCKDB_API idx_t GetBytesRead() const; DUCKDB_API idx_t GetBytesWritten() const; - idx_t OperatorSize() { - return tree_map.size(); - } - - void Finalize(ProfilingNode &node); - - //! Return the root of the query tree. - optional_ptr GetRoot() { - return root.get(); - } - - //! Provides access to the root of the query tree, but ensures there are no concurrent modifications. - //! This can be useful when implementing continuous profiling or making customizations. - DUCKDB_API void GetRootUnderLock(const std::function)> &callback) { - lock_guard guard(lock); - callback(GetRoot()); - } + //! Return the result tree (generating it if it does not yet exist) + QueryProfileResult &GetResult(); + //! Returns true if the last query produced a profiling tree (i.e. 
profiling was enabled and the query succeeded) + bool HasRoot() const; private: - unique_ptr CreateTree(const PhysicalOperator &root, const profiler_settings_t &settings, - const idx_t depth = 0); + unique_ptr CreateTree(const PhysicalOperator &root, const idx_t depth = 0); void Render(const ProfilingNode &node, std::ostream &str) const; string RenderDisabledMessage(ProfilerPrintFormat format) const; @@ -222,9 +172,13 @@ class QueryProfiler { //! The root of the query tree unique_ptr root; + unique_ptr metrics; + //! Top level query information. QueryMetrics query_metrics; + unique_ptr result_tree; + //! A map of a Physical Operator pointer to a tree node TreeMap tree_map; //! Whether or not we are running as part of a explain_analyze query @@ -238,18 +192,12 @@ class QueryProfiler { } private: - //! The timer used to time the individual phases of the planning process - Profiler phase_profiler; - //! A mapping of the phase names to the timings - using PhaseTimingStorage = unordered_map; - PhaseTimingStorage phase_timings; - using PhaseTimingItem = PhaseTimingStorage::value_type; - //! The stack of currently active phases - vector phase_stack; - -private: - void MoveOptimizerPhasesToRoot(); void FinalizeMetricsInternal(); + //! Write metrics to log without acquiring the lock (must be called with lock held). + void ToLogInternal() const; + + unique_ptr ToResultTree() const; + unique_ptr ToLegacyResultTree() const; //! Check whether or not an operator type requires query profiling. If none of the ops in a query require profiling //! no profiling information is output. diff --git a/src/duckdb/src/include/duckdb/main/query_result.hpp b/src/duckdb/src/include/duckdb/main/query_result.hpp index 00937a214..60e9f2df2 100644 --- a/src/duckdb/src/include/duckdb/main/query_result.hpp +++ b/src/duckdb/src/include/duckdb/main/query_result.hpp @@ -92,6 +92,7 @@ class QueryResult : public BaseQueryResult { public: //! 
Deduplicate column names for interop with external libraries + static void DeduplicateColumns(vector &names); static void DeduplicateColumns(vector &names); public: diff --git a/src/duckdb/src/include/duckdb/main/relation.hpp b/src/duckdb/src/include/duckdb/main/relation.hpp index 94450b0be..6e7267dbc 100644 --- a/src/duckdb/src/include/duckdb/main/relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation.hpp @@ -81,7 +81,7 @@ class Relation : public enable_shared_from_this { DUCKDB_API virtual unique_ptr GetQueryNode() = 0; DUCKDB_API virtual string GetQuery(); DUCKDB_API virtual BoundStatement Bind(Binder &binder); - DUCKDB_API virtual string GetAlias(); + DUCKDB_API virtual Identifier GetAlias(); DUCKDB_API unique_ptr ExecuteOrThrow(); DUCKDB_API unique_ptr Execute(); @@ -91,11 +91,11 @@ class Relation : public enable_shared_from_this { DUCKDB_API void Print(); DUCKDB_API void Head(idx_t limit = 10); - DUCKDB_API shared_ptr CreateView(const string &name, bool replace = true, bool temporary = false); - DUCKDB_API shared_ptr CreateView(const string &schema_name, const string &name, bool replace = true, - bool temporary = false); + DUCKDB_API shared_ptr CreateView(const Identifier &name, bool replace = true, bool temporary = false); + DUCKDB_API shared_ptr CreateView(const Identifier &schema_name, const Identifier &name, + bool replace = true, bool temporary = false); DUCKDB_API unique_ptr Query(const string &sql) const; - DUCKDB_API unique_ptr Query(const string &name, const string &sql); + DUCKDB_API unique_ptr Query(const Identifier &name, const string &sql); //! Explain the query plan of this relation DUCKDB_API unique_ptr Explain(ExplainType type = ExplainType::EXPLAIN_STANDARD, @@ -161,27 +161,27 @@ class Relation : public enable_shared_from_this { DUCKDB_API shared_ptr Alias(const string &alias); //! 
Insert the data from this relation into a table - DUCKDB_API shared_ptr InsertRel(const string &schema_name, const string &table_name); - DUCKDB_API shared_ptr InsertRel(const string &catalog_name, const string &schema_name, - const string &table_name); - DUCKDB_API void Insert(const string &table_name); - DUCKDB_API void Insert(const string &schema_name, const string &table_name); - DUCKDB_API void Insert(const string &catalog_name, const string &schema_name, const string &table_name); + DUCKDB_API shared_ptr InsertRel(const Identifier &schema_name, const Identifier &table_name); + DUCKDB_API shared_ptr InsertRel(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name); + DUCKDB_API void Insert(const Identifier &table_name); + DUCKDB_API void Insert(const Identifier &schema_name, const Identifier &table_name); + DUCKDB_API void Insert(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &table_name); //! Insert a row (i.e.,list of values) into a table DUCKDB_API virtual void Insert(const vector> &values); DUCKDB_API virtual void Insert(vector>> &&expressions); //! 
Create a table and insert the data from this relation into that table - DUCKDB_API shared_ptr CreateRel(const string &schema_name, const string &table_name, + DUCKDB_API shared_ptr CreateRel(const Identifier &schema_name, const Identifier &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); - DUCKDB_API shared_ptr CreateRel(const string &catalog_name, const string &schema_name, - const string &table_name, bool temporary = false, + DUCKDB_API shared_ptr CreateRel(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); - DUCKDB_API void Create(const string &table_name, bool temporary = false, + DUCKDB_API void Create(const Identifier &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); - DUCKDB_API void Create(const string &schema_name, const string &table_name, bool temporary = false, + DUCKDB_API void Create(const Identifier &schema_name, const Identifier &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); - DUCKDB_API void Create(const string &catalog_name, const string &schema_name, const string &table_name, + DUCKDB_API void Create(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &table_name, bool temporary = false, OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT); //! 
Write a relation to a CSV file diff --git a/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp index 45daad4a0..215fe2535 100644 --- a/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/aggregate_relation.hpp @@ -32,7 +32,7 @@ class AggregateRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp index cfc0e243a..1f2b6fd97 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp @@ -14,15 +14,15 @@ namespace duckdb { class CreateTableRelation : public Relation { public: - CreateTableRelation(shared_ptr child, string schema_name, string table_name, bool temporary, + CreateTableRelation(shared_ptr child, Identifier schema_name, Identifier table_name, bool temporary, OnCreateConflict on_conflict); - CreateTableRelation(shared_ptr child, string catalog_name, string schema_name, string table_name, - bool temporary, OnCreateConflict on_conflict); + CreateTableRelation(shared_ptr child, Identifier catalog_name, Identifier schema_name, + Identifier table_name, bool temporary, OnCreateConflict on_conflict); shared_ptr child; - string catalog_name; - string schema_name; - string table_name; + Identifier catalog_name; + Identifier schema_name; + Identifier table_name; vector columns; bool temporary; OnCreateConflict on_conflict; diff --git a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp index aa09b0def..69baad0df 100644 --- 
a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp @@ -14,12 +14,13 @@ namespace duckdb { class CreateViewRelation : public Relation { public: - CreateViewRelation(shared_ptr child, string view_name, bool replace, bool temporary); - CreateViewRelation(shared_ptr child, string schema_name, string view_name, bool replace, bool temporary); + CreateViewRelation(shared_ptr child, Identifier view_name, bool replace, bool temporary); + CreateViewRelation(shared_ptr child, Identifier schema_name, Identifier view_name, bool replace, + bool temporary); shared_ptr child; - string schema_name; - string view_name; + Identifier schema_name; + Identifier view_name; bool replace; bool temporary; vector columns; diff --git a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp index 0c25c6576..851900c62 100644 --- a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp @@ -16,13 +16,13 @@ namespace duckdb { class DeleteRelation : public Relation { public: DeleteRelation(shared_ptr &context, unique_ptr condition, - string catalog_name, string schema_name, string table_name); + Identifier catalog_name, Identifier schema_name, Identifier table_name); vector columns; unique_ptr condition; - string catalog_name; - string schema_name; - string table_name; + Identifier catalog_name; + Identifier schema_name; + Identifier table_name; public: BoundStatement Bind(Binder &binder) override; diff --git a/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp index 18209e1a5..98677c7ce 100644 --- a/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/distinct_relation.hpp @@ -23,7 +23,7 @@ class 
DistinctRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; public: bool InheritsColumnBindings() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp index 2a87a02dc..a9ccc981f 100644 --- a/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/filter_relation.hpp @@ -25,7 +25,7 @@ class FilterRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; public: bool InheritsColumnBindings() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp index 41756488f..6e7782fae 100644 --- a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp @@ -14,13 +14,13 @@ namespace duckdb { class InsertRelation : public Relation { public: - InsertRelation(shared_ptr child, string schema_name, string table_name); - InsertRelation(shared_ptr child, string catalog_name, string schema_name, string table_name); + InsertRelation(shared_ptr child, Identifier schema_name, Identifier table_name); + InsertRelation(shared_ptr child, Identifier catalog_name, Identifier schema_name, Identifier table_name); shared_ptr child; - string catalog_name; - string schema_name; - string table_name; + Identifier catalog_name; + Identifier schema_name; + Identifier table_name; vector columns; public: diff --git a/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp index f7d4671ee..ded9aeaeb 100644 --- a/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp +++ 
b/src/duckdb/src/include/duckdb/main/relation/join_relation.hpp @@ -18,13 +18,13 @@ class JoinRelation : public Relation { DUCKDB_API JoinRelation(shared_ptr left, shared_ptr right, unique_ptr condition, JoinType type, JoinRefType join_ref_type = JoinRefType::REGULAR); - DUCKDB_API JoinRelation(shared_ptr left, shared_ptr right, vector using_columns, + DUCKDB_API JoinRelation(shared_ptr left, shared_ptr right, vector using_columns, JoinType type, JoinRefType join_ref_type = JoinRefType::REGULAR); shared_ptr left; shared_ptr right; unique_ptr condition; - vector using_columns; + vector using_columns; JoinType join_type; JoinRefType join_ref_type; vector columns; diff --git a/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp index 4edc1ae4b..7536e25a0 100644 --- a/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/limit_relation.hpp @@ -25,7 +25,7 @@ class LimitRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; public: bool InheritsColumnBindings() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/materialized_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/materialized_relation.hpp index 2911fae4e..7c8dc2425 100644 --- a/src/duckdb/src/include/duckdb/main/relation/materialized_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/materialized_relation.hpp @@ -16,7 +16,7 @@ namespace duckdb { class MaterializedRelation : public Relation { public: MaterializedRelation(const shared_ptr &context, unique_ptr &&collection, - vector names, string alias = "materialized"); + vector names, const Identifier &alias = "materialized"); vector columns; string alias; shared_ptr collection; @@ -24,7 +24,7 @@ class MaterializedRelation : public Relation { public: const vector &Columns() override; 
string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; unique_ptr GetTableRef() override; unique_ptr GetQueryNode() override; }; diff --git a/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp index 604ea6895..7d560d8d8 100644 --- a/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/order_relation.hpp @@ -27,7 +27,7 @@ class OrderRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; public: bool InheritsColumnBindings() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp index 11110a7c4..648e2d8cd 100644 --- a/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/projection_relation.hpp @@ -16,7 +16,7 @@ namespace duckdb { class ProjectionRelation : public Relation { public: DUCKDB_API ProjectionRelation(shared_ptr child, vector> expressions, - vector aliases); + vector aliases); vector> expressions; vector columns; @@ -27,7 +27,7 @@ class ProjectionRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp index b1be001b9..4c6cd2a3d 100644 --- a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp @@ -34,7 +34,7 @@ class QueryRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; 
+ Identifier GetAlias() override; private: unique_ptr GetSelectStatement(); diff --git a/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp index 0590cf6a9..561435b90 100644 --- a/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/read_csv_relation.hpp @@ -26,7 +26,7 @@ class ReadCSVRelation : public TableFunctionRelation { void InitializeAlias(const vector &input); public: - string GetAlias() override; + Identifier GetAlias() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp index aee96ac19..f292bedb8 100644 --- a/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/read_json_relation.hpp @@ -20,7 +20,7 @@ class ReadJSONRelation : public TableFunctionRelation { string alias; public: - string GetAlias() override; + Identifier GetAlias() override; private: void InitializeAlias(const vector &input); diff --git a/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp index 77920bf9e..66dbc2ac8 100644 --- a/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/setop_relation.hpp @@ -29,7 +29,7 @@ class SetOpRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp index 8781d7125..efd84979d 100644 --- a/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp +++ 
b/src/duckdb/src/include/duckdb/main/relation/table_function_relation.hpp @@ -25,7 +25,7 @@ class TableFunctionRelation : public Relation { ~TableFunctionRelation() override { } - string name; + Identifier name; vector parameters; named_parameter_map_t named_parameters; vector columns; @@ -37,7 +37,7 @@ class TableFunctionRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; void AddNamedParameter(const string &name, Value argument); void RemoveNamedParameterIfExists(const string &name); void SetNamedParameters(named_parameter_map_t &&named_parameters); diff --git a/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp index a3184dafa..d736d765c 100644 --- a/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/table_relation.hpp @@ -25,7 +25,7 @@ class TableRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; unique_ptr GetTableRef() override; diff --git a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp index 91eac246e..b51dcc1c1 100644 --- a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp @@ -16,15 +16,15 @@ namespace duckdb { class UpdateRelation : public Relation { public: UpdateRelation(shared_ptr &context, unique_ptr condition, - string catalog_name, string schema_name, string table_name, vector update_columns, - vector> expressions); + Identifier catalog_name, Identifier schema_name, Identifier table_name, + vector update_columns, vector> expressions); vector columns; unique_ptr condition; - string catalog_name; - string schema_name; - string 
table_name; - vector update_columns; + Identifier catalog_name; + Identifier schema_name; + Identifier table_name; + vector update_columns; vector> expressions; public: diff --git a/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp index 9b5ae5257..7289bf3ac 100644 --- a/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/value_relation.hpp @@ -37,7 +37,7 @@ class ValueRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; unique_ptr GetTableRef() override; }; diff --git a/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp index 529817d60..c4a9163b2 100644 --- a/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/view_relation.hpp @@ -14,12 +14,12 @@ namespace duckdb { class ViewRelation : public Relation { public: - ViewRelation(const shared_ptr &context, string schema_name, string view_name); - ViewRelation(const shared_ptr &context, string schema_name, string view_name); + ViewRelation(const shared_ptr &context, Identifier schema_name, Identifier view_name); + ViewRelation(const shared_ptr &context, Identifier schema_name, Identifier view_name); ViewRelation(const shared_ptr &context, unique_ptr ref, const string &view_name); - string schema_name; - string view_name; + Identifier schema_name; + Identifier view_name; vector columns; unique_ptr premade_tableref; @@ -29,7 +29,7 @@ class ViewRelation : public Relation { const vector &Columns() override; string ToString(idx_t depth) override; - string GetAlias() override; + Identifier GetAlias() override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/secret/secret.hpp 
b/src/duckdb/src/include/duckdb/main/secret/secret.hpp index 6b96b3e19..bb8bd3eb0 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/named_parameter_map.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -27,13 +28,13 @@ enum class SecretPersistType : uint8_t { DEFAULT, TEMPORARY, PERSISTENT }; //! Input passed to a CreateSecretFunction struct CreateSecretInput { //! type - string type; + Identifier type; //! mode - string provider; + Identifier provider; //! should the secret be persisted? - string storage_type; + Identifier storage_type; //! (optional) alias provided by user - string name; + Identifier name; //! (optional) scope provided by user vector scope; //! (optional) named parameter map, each create secret function has defined it's own set of these @@ -51,7 +52,7 @@ typedef unique_ptr (*create_secret_function_t)(ClientContext &contex class CreateSecretFunction { public: string secret_type; - string provider; + Identifier provider; create_secret_function_t function; named_parameter_type_map_t named_parameters; }; @@ -63,15 +64,15 @@ class CreateSecretFunctionSet { explicit CreateSecretFunctionSet(string &name) : name(name) {}; public: - bool ProviderExists(const string &provider_name); + bool ProviderExists(const Identifier &provider_name); void AddFunction(CreateSecretFunction &function, OnCreateConflict on_conflict); CreateSecretFunction &GetFunction(const string &provider); protected: //! Create Secret Function type name - string name; + Identifier name; //! Maps of provider -> function - case_insensitive_map_t functions; + identifier_map_t functions; }; //! Determines whether the secrets are allowed to be shown @@ -80,7 +81,7 @@ enum class SecretDisplayType : uint8_t { REDACTED, UNREDACTED }; //! 
Secret types contain the base settings of a secret struct SecretType { //! Unique name identifying the secret type - string name; + Identifier name; //! The deserialization function for the type secret_deserializer_t deserializer; //! Provider to use when non is specified @@ -101,7 +102,7 @@ class BaseSecret { friend class SecretManager; public: - BaseSecret(vector prefix_paths_p, string type_p, string provider_p, string name_p) + BaseSecret(vector prefix_paths_p, Identifier type_p, Identifier provider_p, Identifier name_p) : prefix_paths(std::move(prefix_paths_p)), type(std::move(type_p)), provider(std::move(provider_p)), name(std::move(name_p)), serializable(false) { D_ASSERT(!type.empty()); @@ -130,13 +131,13 @@ class BaseSecret { const vector &GetScope() const { return prefix_paths; } - const string &GetType() const { + const Identifier &GetType() const { return type; } - const string &GetProvider() const { + const Identifier &GetProvider() const { return provider; } - const string &GetName() const { + const Identifier &GetName() const { return name; } bool IsSerializable() const { @@ -151,11 +152,11 @@ class BaseSecret { vector prefix_paths; //! Type of secret - string type; + Identifier type; //! Provider of the secret - string provider; + Identifier provider; //! Name of the secret - string name; + Identifier name; //! Whether the secret can be serialized/deserialized bool serializable; }; @@ -164,7 +165,8 @@ class BaseSecret { //! for most use-cases of secrets as secrets generally tend to fit in a key value map. 
class KeyValueSecret : public BaseSecret { public: - KeyValueSecret(const vector &prefix_paths, const string &type, const string &provider, const string &name) + KeyValueSecret(const vector &prefix_paths, const Identifier &type, const Identifier &provider, + const Identifier &name) : BaseSecret(prefix_paths, type, provider, name) { D_ASSERT(!type.empty()); serializable = true; @@ -192,7 +194,7 @@ class KeyValueSecret : public BaseSecret { void Serialize(Serializer &serializer) const override; //! Tries to get the value at key , depending on error_on_missing will throw or return Value() - Value TryGetValue(const string &key, bool error_on_missing = false) const; + Value TryGetValue(const Identifier &key, bool error_on_missing = false) const; // FIXME: use serialization scripts template @@ -203,13 +205,13 @@ class KeyValueSecret : public BaseSecret { for (const auto &entry : ListValue::GetChildren(secret_map_value)) { auto kv_struct = StructValue::GetChildren(entry); - result->secret_map[kv_struct[0].ToString()] = kv_struct[1]; + result->secret_map[Identifier(kv_struct[0].ToString())] = kv_struct[1]; } Value redact_set_value; deserializer.ReadProperty(202, "redact_keys", redact_set_value); for (const auto &entry : ListValue::GetChildren(redact_set_value)) { - result->redact_keys.insert(entry.ToString()); + result->redact_keys.insert(Identifier(entry.ToString())); } return duckdb::unique_ptr_cast(std::move(result)); @@ -220,7 +222,7 @@ class KeyValueSecret : public BaseSecret { } // Get a value from the secret - bool TryGetValue(const string &key, Value &result) const { + bool TryGetValue(const Identifier &key, Value &result) const { auto lookup = secret_map.find(key); if (lookup == secret_map.end()) { return false; @@ -229,8 +231,8 @@ class KeyValueSecret : public BaseSecret { return true; } - bool TrySetValue(const string &key, const CreateSecretInput &input) { - auto lookup = input.options.find(key); + bool TrySetValue(const Identifier &key, const CreateSecretInput 
&input) { + auto lookup = input.options.find(key.GetIdentifierName()); if (lookup != input.options.end()) { secret_map[key] = lookup->second; return true; @@ -239,9 +241,9 @@ class KeyValueSecret : public BaseSecret { } //! the map of key -> values that make up the secret - case_insensitive_tree_t secret_map; + identifier_tree_t secret_map; //! keys that are sensitive and should be redacted - case_insensitive_set_t redact_keys; + identifier_set_t redact_keys; }; // Helper class to fetch secret parameters in a cascading way. The idea being that in many cases there is a direct diff --git a/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp b/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp index 8ba703dcd..71c7fc13c 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/catalog/default/default_generator.hpp" #include "duckdb/common/common.hpp" #include "duckdb/main/secret/secret.hpp" @@ -130,10 +131,10 @@ class SecretManager { DUCKDB_API unique_ptr GetSecretByName(CatalogTransaction transaction, const string &name, const string &storage = ""); //! Delete a secret by name, optionally by providing the storage to drop from - DUCKDB_API void DropSecretByName(CatalogTransaction transaction, const string &name, + DUCKDB_API void DropSecretByName(CatalogTransaction transaction, const Identifier &name, OnEntryNotFound on_entry_not_found, SecretPersistType persist_type = SecretPersistType::DEFAULT, - const string &storage = ""); + const Identifier &storage = ""); //! List all secrets from all secret storages DUCKDB_API vector AllSecrets(CatalogTransaction transaction); @@ -154,25 +155,25 @@ class SecretManager { DUCKDB_API virtual string PersistentSecretPath(); //! 
Utility functions - DUCKDB_API void DropSecretByName(ClientContext &context, const string &name, OnEntryNotFound on_entry_not_found, + DUCKDB_API void DropSecretByName(ClientContext &context, const Identifier &name, OnEntryNotFound on_entry_not_found, SecretPersistType persist_type = SecretPersistType::DEFAULT, - const string &storage = ""); + const Identifier &storage = ""); private: //! Register a secret type void RegisterSecretTypeInternal(SecretType &type); //! Lookup a SecretType, throws if not found - SecretType LookupTypeInternal(const string &type); + SecretType LookupTypeInternal(const Identifier &type); //! Try to lookup a SecretType bool TryLookupTypeInternal(const string &type, SecretType &type_out); //! Register a secret provider void RegisterSecretFunctionInternal(CreateSecretFunction function, OnCreateConflict on_conflict); //! Lookup a CreateSecretFunction - optional_ptr LookupFunctionInternal(const string &type, const string &provider); + optional_ptr LookupFunctionInternal(const Identifier &type, const Identifier &provider); //! Register a new Secret unique_ptr RegisterSecretInternal(CatalogTransaction transaction, unique_ptr secret, OnCreateConflict on_conflict, SecretPersistType persist_type, - const string &storage = ""); + const Identifier &storage = Identifier()); //! Initialize the secret catalog_set and persistent secrets (lazily) void InitializeSecrets(CatalogTransaction transaction); //! Load a secret storage @@ -184,12 +185,12 @@ class SecretManager { void AutoloadExtensionForFunction(const string &type, const string &provider); //! Will throw appropriate error message when type not found - [[noreturn]] void ThrowTypeNotFoundError(const string &type, const string &secret_path = ""); + [[noreturn]] void ThrowTypeNotFoundError(const Identifier &type, const string &secret_path = ""); [[noreturn]] void ThrowProviderNotFoundError(const string &type, const string &provider, bool was_default = false); //! 
Thread-safe accessors for secret_storages vector> GetSecretStorages(); - optional_ptr GetSecretStorage(const string &name); + optional_ptr GetSecretStorage(const Identifier &name); //! Throw an exception if the secret manager is initialized void ThrowOnSettingChangeIfInitialized(); @@ -197,11 +198,11 @@ class SecretManager { //! Lock for types, functions, settings and storages mutex manager_lock; //! Secret functions; - case_insensitive_map_t secret_functions; + identifier_map_t secret_functions; //! Secret types; - case_insensitive_map_t secret_types; + identifier_map_t secret_types; //! Map of all registered SecretStorages - case_insensitive_map_t> secret_storages; + identifier_map_t> secret_storages; //! While false, secret manager settings can still be changed atomic initialized {false}; //! Configuration for secret manager @@ -213,22 +214,22 @@ class SecretManager { //! The DefaultGenerator for persistent secrets. This is used to store lazy loaded secrets in the catalog class DefaultSecretGenerator : public DefaultGenerator { public: - DefaultSecretGenerator(Catalog &catalog, SecretManager &secret_manager, case_insensitive_set_t &persistent_secrets); + DefaultSecretGenerator(Catalog &catalog, SecretManager &secret_manager, identifier_set_t &persistent_secrets); public: - unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const string &entry_name) override; - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override; - vector GetDefaultEntries() override; + unique_ptr CreateDefaultEntry(CatalogTransaction transaction, const Identifier &entry_name) override; + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override; + vector GetDefaultEntries() override; bool LockDuringCreate() const override { return true; } protected: - unique_ptr CreateDefaultEntryInternal(const string &entry_name); + unique_ptr CreateDefaultEntryInternal(const Identifier &entry_name); SecretManager 
&secret_manager; mutex lock; - case_insensitive_set_t persistent_secrets; + identifier_set_t persistent_secrets; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/secret/secret_storage.hpp b/src/duckdb/src/include/duckdb/main/secret/secret_storage.hpp index 5649411d8..a02c5a903 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret_storage.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret_storage.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_set.hpp" #include "duckdb/common/common.hpp" @@ -54,7 +55,7 @@ class SecretStorage { //! Get all secrets virtual vector AllSecrets(optional_ptr transaction = nullptr) = 0; //! Drop secret by name - virtual void DropSecretByName(const string &name, OnEntryNotFound on_entry_not_found, + virtual void DropSecretByName(const Identifier &name, OnEntryNotFound on_entry_not_found, optional_ptr transaction = nullptr) = 0; //! Get best match virtual SecretMatch LookupSecret(const string &path, const string &type, @@ -111,7 +112,7 @@ class CatalogSetSecretStorage : public SecretStorage { unique_ptr StoreSecret(unique_ptr secret, OnCreateConflict on_conflict, optional_ptr transaction = nullptr) override; vector AllSecrets(optional_ptr transaction = nullptr) override; - void DropSecretByName(const string &name, OnEntryNotFound on_entry_not_found, + void DropSecretByName(const Identifier &name, OnEntryNotFound on_entry_not_found, optional_ptr transaction = nullptr) override; SecretMatch LookupSecret(const string &path, const string &type, optional_ptr transaction = nullptr) override; @@ -154,7 +155,7 @@ class LocalFileSecretStorage : public CatalogSetSecretStorage { void RemoveSecret(const string &secret, OnEntryNotFound on_entry_not_found) override; //! Set of persistent secrets that are lazily loaded - case_insensitive_set_t persistent_secrets; + identifier_set_t persistent_secrets; //! 
Path that is searched for secrets; string secret_path; }; diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 27f1bb942..a29a7b470 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -74,10 +74,24 @@ struct Settings { return StringUtil::ToDouble(OP::DefaultValue); } + template + static void Set(SOURCE &source, idx_t setting_index, SetScope scope, Value target_value) { + SetSettingInternal(source, setting_index, scope, std::move(target_value)); + } + + template + static void Set(SOURCE &source, SetScope scope, Value target_value) { + Set(source, OP::SettingIndex, scope, std::move(target_value)); + } + private: static bool TryGetSettingInternal(const DatabaseInstance &db, idx_t setting_index, Value &result); static bool TryGetSettingInternal(const DBConfig &config, idx_t setting_index, Value &result); static bool TryGetSettingInternal(const ClientContext &context, idx_t setting_index, Value &result); + + static void SetSettingInternal(DatabaseInstance &db, idx_t setting_index, SetScope scope, Value target_value); + static void SetSettingInternal(DBConfig &config, idx_t setting_index, SetScope scope, Value target_value); + static void SetSettingInternal(ClientContext &context, idx_t setting_index, SetScope scope, Value target_value); }; static constexpr idx_t SETTING_INDEX_BASE = __COUNTER__ + 1; @@ -140,9 +154,10 @@ struct AllocatorFlushThresholdSetting { static constexpr const char *Description = "Peak allocation threshold at which to flush the allocator after completing a task."; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "134217728B"; + static constexpr SettingScopeTarget 
Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct AllowCommunityExtensionsSetting { @@ -300,6 +315,17 @@ struct AsofLoopJoinThresholdSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct AsyncThreadsSetting { + using RETURN_TYPE = int64_t; + static constexpr const char *Name = "async_threads"; + static constexpr const char *Description = + "The number of total async threads used by the system for tasks like I/O."; + static constexpr const char *InputType = "BIGINT"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct AutoCheckpointSkipWalThresholdSetting { using RETURN_TYPE = idx_t; static constexpr const char *Name = "auto_checkpoint_skip_wal_threshold"; @@ -569,6 +595,16 @@ struct DebugVerificationProjectionSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct DebugVerifyAggregateStateExportSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "debug_verify_aggregate_state_export"; + static constexpr const char *Description = "DEBUG SETTING: enable verification of aggregate state export"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct DebugVerifyBlocksSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "debug_verify_blocks"; @@ -666,6 +702,18 @@ struct DefaultCollationSetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; +struct DefaultIoModeSetting { + using RETURN_TYPE = FileIOMode; + static constexpr const char *Name = "default_io_mode"; + 
static constexpr const char *Description = + "The default IO_MODE for newly attached database files when no explicit IO_MODE is given (BUFFERED_IO or MMAP)"; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = "BUFFERED_IO"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct DefaultNullOrderSetting { using RETURN_TYPE = DefaultOrderByNullType; static constexpr const char *Name = "default_null_order"; @@ -698,6 +746,30 @@ struct DefaultSecretStorageSetting { static Value GetSetting(const ClientContext &context); }; +struct DefaultTransactionInvalidationPolicySetting { + using RETURN_TYPE = TransactionInvalidationPolicy; + static constexpr const char *Name = "default_transaction_invalidation_policy"; + static constexpr const char *Description = + "When to invalidate transactions when errors occur (SYNTACTIC_ERRORS_DO_NOT_INVALIDATE, i.e. 
parser and binder " + "exceptions do not invalidate, or ALL_ERRORS_INVALIDATE_TRANSACTION)"; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = "ALL_ERRORS_INVALIDATE_TRANSACTION"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + +struct DelimJoinAsCteSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "delim_join_as_cte"; + static constexpr const char *Description = + "Rewrite delim joins to materialized CTEs during dependent join flattening"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct DeprecatedUsingKeySyntaxSetting { using RETURN_TYPE = DeprecatedUsingKeySyntax; static constexpr const char *Name = "deprecated_using_key_syntax"; @@ -839,16 +911,6 @@ struct EnableFSSTVectorsSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; -struct EnableHTTPLoggingSetting { - using RETURN_TYPE = bool; - static constexpr const char *Name = "enable_http_logging"; - static constexpr const char *Description = "(deprecated) Enables HTTP logging"; - static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); -}; - struct EnableHTTPMetadataCacheSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "enable_http_metadata_cache"; @@ -890,6 +952,16 @@ struct EnableObjectCacheSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct EnableOptimizerSetting { + using RETURN_TYPE = bool; + static constexpr const char 
*Name = "enable_optimizer"; + static constexpr const char *Description = "Whether or not query optimization is enabled"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "true"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct EnableProfilingSetting { using RETURN_TYPE = string; static constexpr const char *Name = "enable_profiling"; @@ -1053,6 +1125,18 @@ struct ForceBitpackingModeSetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; +struct ForceColumnMetadataReuseSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "force_column_metadata_reuse"; + static constexpr const char *Description = + "Force re-use of row group metadata on a column-level when checkpointing on older storage versions 6 and 7. " + "This breaks storage backward-compatibility with older DuckDB versions."; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_ONLY; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct ForceCompressionSetting { using RETURN_TYPE = CompressionType; static constexpr const char *Name = "force_compression"; @@ -1107,25 +1191,15 @@ struct HomeDirectorySetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; -struct HTTPLoggingOutputSetting { - using RETURN_TYPE = string; - static constexpr const char *Name = "http_logging_output"; - static constexpr const char *Description = - "(deprecated) The file to which HTTP logging output should be saved, or empty to print to the terminal"; - static constexpr const char *InputType = "VARCHAR"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); -}; 
- struct HTTPProxySetting { using RETURN_TYPE = string; static constexpr const char *Name = "http_proxy"; - static constexpr const char *Description = "HTTP proxy host"; + static constexpr const char *Description = + "HTTP proxy host (defaults to the HTTP_PROXY environment variable when unset)"; static constexpr const char *InputType = "VARCHAR"; - static constexpr const char *DefaultValue = ""; - static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_ONLY; - static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); }; struct HTTPProxyPasswordSetting { @@ -1206,6 +1280,20 @@ struct IndexScanPercentageSetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; +struct InitialColumnSegmentSizeSetting { + using RETURN_TYPE = idx_t; + static constexpr const char *Name = "initial_column_segment_size"; + static constexpr const char *Description = + "The initial memory (in bytes) reserved for the first transient column segment. Must be a power of two. " + "Internally, we subtract the block header size (typically 8 bytes) for segments with or exceeding 1024 bytes. 
" + "Subsequent segments double in size until reaching the block size."; + static constexpr const char *InputType = "UBIGINT"; + static constexpr const char *DefaultValue = "2048"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct IntegerDivisionSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "integer_division"; @@ -1240,6 +1328,28 @@ struct LateMaterializationMaxRowsSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct LegacyDisableNullTypeSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "legacy_disable_null_type"; + static constexpr const char *Description = + "When enabled, prevent the NULL type from leaving the binder (< v2.0 default behavior)"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + +struct LegacyMetricsFormatSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "legacy_metrics_format"; + static constexpr const char *Description = + "When enabled, profiling output uses the legacy flat format instead of the current grouped format"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct LockConfigurationSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "lock_configuration"; @@ -1298,7 +1408,7 @@ struct MaxExecutionTimeSetting { static constexpr const char *Description = "The maximum execution time per query in milliseconds (0 = no limit)"; static constexpr 
const char *InputType = "BIGINT"; static constexpr const char *DefaultValue = "0"; - static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_ONLY; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; @@ -1639,6 +1749,16 @@ struct SecretDirectorySetting { static Value GetSetting(const ClientContext &context); }; +struct StandardVectorSizeSetting { + using RETURN_TYPE = idx_t; + static constexpr const char *Name = "standard_vector_size"; + static constexpr const char *Description = "The compiled-in STANDARD_VECTOR_SIZE (read-only)"; + static constexpr const char *InputType = "UBIGINT"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct StorageBlockPrefetchSetting { using RETURN_TYPE = StorageBlockPrefetch; static constexpr const char *Name = "storage_block_prefetch"; @@ -1702,6 +1822,17 @@ struct ThreadsSetting { static Value GetSetting(const ClientContext &context); }; +struct TrackedMetricsSetting { + using RETURN_TYPE = vector; + static constexpr const char *Name = "tracked_metrics"; + static constexpr const char *Description = + "A list of metric glob patterns to enable for collection (e.g. 
['query.*', 'optimizer.*'])"; + static constexpr const char *InputType = "VARCHAR[]"; + static void SetLocal(ClientContext &context, const Value ¶meter); + static void ResetLocal(ClientContext &context); + static Value GetSetting(const ClientContext &context); +}; + struct UsernameSetting { using RETURN_TYPE = string; static constexpr const char *Name = "username"; @@ -1717,7 +1848,8 @@ struct VacuumRebuildIndexesSetting { static constexpr const char *Name = "vacuum_rebuild_indexes"; static constexpr const char *Description = "(Experimental) Allow vacuum to compact row groups on tables with bound ART indexes, rebuilding the indexes " - "afterward. Tables with a row count exceeding this threshold are skipped. 0 = disabled."; + "afterward. Tables with a row count exceeding this threshold are skipped. 0 = disabled. Can also be set " + "per-database via the 'vacuum_rebuild_indexes' ATTACH option, which overrides this default."; static constexpr const char *InputType = "UBIGINT"; static constexpr const char *DefaultValue = "0"; static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; @@ -1782,6 +1914,19 @@ struct WriteBufferRowGroupCountSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct WriteBufferRowGroupMemoryLimitSetting { + using RETURN_TYPE = string; + static constexpr const char *Name = "write_buffer_row_group_memory_limit"; + static constexpr const char *Description = + "The maximum data to buffer in row groups (in bytes) to buffer prior to flushing them together. When either " + "this limit is reached, or write_buffer_row_group_count is reached, we flush the data to disk. 
Defaults to 20% " + "of memory limit divided by thread count."; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct ZstdMinStringLengthSetting { using RETURN_TYPE = idx_t; static constexpr const char *Name = "zstd_min_string_length"; diff --git a/src/duckdb/src/include/duckdb/main/table_description.hpp b/src/duckdb/src/include/duckdb/main/table_description.hpp index d1b7f6561..e139019a3 100644 --- a/src/duckdb/src/include/duckdb/main/table_description.hpp +++ b/src/duckdb/src/include/duckdb/main/table_description.hpp @@ -9,23 +9,24 @@ #pragma once #include "duckdb/parser/column_definition.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class TableDescription { public: - TableDescription(const string &database_name, const string &schema_name, const string &table_name) - : database(database_name), schema(schema_name), table(table_name) {}; + TableDescription(Identifier database_name, Identifier schema_name, Identifier table_name) + : database(std::move(database_name)), schema(std::move(schema_name)), table(std::move(table_name)) {}; TableDescription() = delete; public: //! The database of the table. - string database; + Identifier database; //! The schema of the table. - string schema; + Identifier schema; //! The name of the table. - string table; + Identifier table; //! True, if the catalog is readonly. bool readonly; //! The columns of the table. 
diff --git a/src/duckdb/src/include/duckdb/optimizer/build_probe_side_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/build_probe_side_optimizer.hpp index d013074f6..b59229bb1 100644 --- a/src/duckdb/src/include/duckdb/optimizer/build_probe_side_optimizer.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/build_probe_side_optimizer.hpp @@ -24,6 +24,8 @@ struct BuildSize { } }; +enum class RecursiveProbeSidePreference : uint8_t { NONE, KEEP, SWAP }; + class BuildProbeSideOptimizer : LogicalOperatorVisitor { private: static constexpr idx_t COLUMN_COUNT_PENALTY = 2; @@ -37,6 +39,9 @@ class BuildProbeSideOptimizer : LogicalOperatorVisitor { private: bool TryFlipJoinChildren(LogicalOperator &op) const; static idx_t ChildHasJoins(LogicalOperator &op); + bool ContainsActiveRecursiveReference(const LogicalOperator &op) const; + bool ContainsCorrelationSensitiveOperators(const LogicalOperator &op) const; + RecursiveProbeSidePreference GetRecursiveProbeSidePreference(const LogicalOperator &op) const; static BuildSize GetBuildSizes(const LogicalOperator &op, idx_t lhs_cardinality, idx_t rhs_cardinality); static double GetBuildSize(vector types, idx_t cardinality); @@ -44,6 +49,7 @@ class BuildProbeSideOptimizer : LogicalOperatorVisitor { private: ClientContext &context; vector preferred_on_probe_side; + vector active_recursive_cte_indexes; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp index b1b573c16..8757c895f 100644 --- a/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp @@ -18,6 +18,14 @@ class Optimizer; class ClientContext; class LogicalOperator; +enum class CompressedMaterializationType : uint8_t { + INVALID = 0, + //! Use the opaque internal (de-)compress functions + FUNCTION = 1, + //! 
Use down- and upcasts that are transparent to subsequent optimizers + CAST = 2 +}; + struct CMChildInfo { public: CMChildInfo(LogicalOperator &op, const column_binding_set_t &referenced_bindings); @@ -42,7 +50,7 @@ struct CMBindingInfo { //! Type before compressing LogicalType type; - bool needs_decompression; + CompressedMaterializationType materialization_type; unique_ptr stats; }; @@ -62,11 +70,13 @@ struct CompressedMaterializationInfo { struct CompressExpression { public: - CompressExpression(unique_ptr expression, unique_ptr stats); + CompressExpression(unique_ptr expression, unique_ptr stats, + CompressedMaterializationType materialization_type); public: unique_ptr expression; unique_ptr stats; + CompressedMaterializationType materialization_type; }; typedef column_binding_map_t> statistics_map_t; @@ -103,7 +113,8 @@ class CompressedMaterialization { //! Adds bindings referenced in expression to referenced_bindings static void GetReferencedBindings(const Expression &expression, column_binding_set_t &referenced_bindings); //! Updates CMBindingInfo in the binding_map in info - void UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, bool needs_decompression); + void UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, + CompressedMaterializationType materialization_type); //! Create (de)compress projections around the operator void CreateProjections(unique_ptr &op, CompressedMaterializationInfo &info); @@ -113,6 +124,8 @@ class CompressedMaterialization { vector> compress_exprs, CompressedMaterializationInfo &info, CMChildInfo &child_info); void CreateDecompressProjection(unique_ptr &op, CompressedMaterializationInfo &info); + unique_ptr CreateRestoreExpression(unique_ptr input, const CMBindingInfo &binding_info, + const BaseStatistics &stats); //! 
Create expressions that apply a scalar compression function unique_ptr GetCompressExpression(const ColumnBinding &binding, const LogicalType &type, diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp index a4b1e2bb3..cde41eca7 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_combiner.hpp @@ -12,7 +12,6 @@ #include "duckdb/common/unordered_map.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/planner/expression.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/storage/data_table.hpp" #include @@ -49,10 +48,11 @@ class FilterCombiner { //! If this returns true - this sorts "in_list" as a side-effect static bool IsDenseRange(vector &in_list); static bool ContainsNull(vector &in_list); - static bool FindNextLegalUTF8(string &prefix_string); void GenerateFilters(const std::function filter)> &callback); bool HasFilters(); + void GenerateEquivalentFilters(const Expression &filter, + const std::function filter)> &callback); TableFilterSet GenerateTableScanFilters(const vector &column_ids, vector &pushdown_results); diff --git a/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp index 51af6e127..0e5ff262b 100644 --- a/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/in_clause_rewriter.hpp @@ -27,8 +27,9 @@ class InClauseRewriter : public LogicalOperatorVisitor { unique_ptr root; public: + // Minimum number of values to use a hash join + static constexpr idx_t IN_CLAUSE_REWRITE_THRESHOLD = 6; unique_ptr Rewrite(unique_ptr op); - unique_ptr VisitReplace(BoundOperatorExpression &expr, unique_ptr *expr_ptr) override; }; diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp 
b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp index 42a337c0f..96b7a5cca 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/cardinality_estimator.hpp @@ -7,101 +7,32 @@ //===----------------------------------------------------------------------===// #pragma once -#include "duckdb/planner/column_binding_map.hpp" -#include "duckdb/optimizer/join_order/query_graph.hpp" - +#include "duckdb/common/reference_map.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" namespace duckdb { class FilterInfo; - -struct DenomInfo { - DenomInfo(JoinRelationSet &numerator_relations, double filter_strength, double denominator) - : numerator_relations(numerator_relations), filter_strength(filter_strength), denominator(denominator) { - } - - JoinRelationSet &numerator_relations; - double filter_strength; - double denominator; -}; - -struct RelationsSetToStats { - //! column binding sets that are equivalent in a join plan. - //! if you have A.x = B.y and B.y = C.z, then one set is {A.x, B.y, C.z}. - column_binding_set_t equivalent_relations; - //! the estimated total domains of the equivalent relations determined using HLL - idx_t distinct_count_hll; - //! 
the estimated total domains of each relation without using HLL - idx_t distinct_count_no_hll; - bool has_distinct_count_hll; - vector> filters; - vector column_names; - - explicit RelationsSetToStats(const column_binding_set_t &column_binding_set) - : equivalent_relations(column_binding_set), distinct_count_hll(0), - distinct_count_no_hll(NumericLimits::Maximum()), has_distinct_count_hll(false) {}; -}; - -// class to wrap a join Filter along with some statistical information about the joined columns -class FilterInfoWithTotalDomains { -public: - FilterInfoWithTotalDomains(optional_ptr filter_info, RelationsSetToStats &relation_set_to_stats) - : filter_info(filter_info), distinct_count_hll(relation_set_to_stats.distinct_count_hll), - distinct_count_no_hll(relation_set_to_stats.distinct_count_no_hll), - has_distinct_count_hll(relation_set_to_stats.has_distinct_count_hll) { - } - - optional_ptr filter_info; - //! the estimated distinct count the joined columns determined using HLL - idx_t distinct_count_hll; - //! the estimated total domains of each relation without using HLL - idx_t distinct_count_no_hll; - bool has_distinct_count_hll; -}; - -struct Subgraph2Denominator { - optional_ptr relations; - optional_ptr numerator_relations; - double denom; - - Subgraph2Denominator() : relations(nullptr), numerator_relations(nullptr), denom(1) {}; -}; - -class CardinalityHelper { -public: - CardinalityHelper() { - } - explicit CardinalityHelper(double cardinality_before_filters) - : cardinality_before_filters(cardinality_before_filters) {}; - -public: - // must be a double. Otherwise we can lose significance between different join orders. - // our cardinality estimator severely underestimates cardinalities for 3+ joins. However, - // if one join order has an estimate of 0.8, and another has an estimate of 0.6, rounding - // them means there is no estimated difference, when in reality there could be a very large - // difference. 
- double cardinality_before_filters; - - vector table_names_joined; - vector column_names; -}; +class JoinPredicateModel; +struct CardinalityEstimatorState; +struct CompositeJoinPairStats; +struct DenomInfo; +struct DenominatorState; +struct FilterInfoWithTotalDomains; +struct LeftJoinDenomInfo; +struct Subgraph2Denominator; class CardinalityEstimator { public: static constexpr double DEFAULT_SEMI_ANTI_SELECTIVITY = 5; - explicit CardinalityEstimator() {}; - -private: - vector relation_set_stats; - unordered_map relation_set_2_cardinality; - JoinRelationSetManager set_manager; - vector relation_stats; + CardinalityEstimator(JoinRelationSetManager &set_manager, const JoinPredicateModel &predicate_model); + ~CardinalityEstimator(); public: void RemoveEmptyTotalDomains(); void UpdateTotalDomains(optional_ptr set, RelationStats &stats); - void InitEquivalentRelations(const vector> &filter_infos); + void InitEquivalentRelations(); void InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats); @@ -117,25 +48,55 @@ class CardinalityEstimator { private: double GetNumerator(JoinRelationSet &set); DenomInfo GetDenominator(JoinRelationSet &set); + void ProcessDenominatorEdge(FilterInfoWithTotalDomains &edge, JoinRelationSet &requested_set, + DenominatorState &state); + void CreateDenominatorSubgraph(FilterInfoWithTotalDomains &edge, JoinRelationSet &edge_left_set, + JoinRelationSet &edge_right_set, DenominatorState &state); + void ExtendDenominatorSubgraph(idx_t subgraph_index, FilterInfoWithTotalDomains &edge, + JoinRelationSet &edge_left_set, JoinRelationSet &edge_right_set, + bool can_increment_existing_join, DenominatorState &state); + void MergeDenominatorSubgraphs(const vector &subgraph_connections, FilterInfoWithTotalDomains &edge, + DenominatorState &state); + void MergeDisconnectedDenominatorSubgraphs(DenominatorState &state); + void AddCrossProductRelations(JoinRelationSet &set, DenominatorState &state); + DenomInfo 
CreateDenominatorResult(JoinRelationSet &set, DenominatorState &state); //! Applied outside the cardinality cache so stored values stay pre-OR. double ApplyOrFilterSelectivities(JoinRelationSet &new_set, double cardinality) const; - vector> or_filters; - - bool SingleColumnFilter(FilterInfo &filter_info); - vector DetermineMatchingEquivalentSets(optional_ptr filter_info); - //! Given a filter, add the column bindings to the matching equivalent set at the index - //! given in matching equivalent sets. - //! If there are multiple equivalence sets, they are merged. - void AddToEquivalenceSets(optional_ptr filter_info, vector matching_equivalent_sets); + bool SingleColumnFilter(const FilterInfo &filter_info); + //! Denom calculation double CalculateUpdatedDenom(Subgraph2Denominator left, Subgraph2Denominator right, FilterInfoWithTotalDomains &filter); + double CalculateInnerJoinDenom(double base_denom, FilterInfoWithTotalDomains &filter); + LeftJoinDenomInfo CalculateLeftJoinDenomInfo(Subgraph2Denominator &left, Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter); + double CalculateLeftJoinDenom(Subgraph2Denominator &left, Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter); + double CalculateSemiAntiJoinDenom(double base_denom, Subgraph2Denominator &left, Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter); + bool ApplyJoinIncrement(double &target_denom, FilterInfoWithTotalDomains &edge, + reference_map_t &inner_join_pair_stats, + reference_set_t &capped_join_pairs, JoinRelationSet &scope, + optional_ptr join_pair = nullptr); + bool ApplyJoinPairCap(double &target_denom, JoinRelationSet &join_pair, + reference_map_t &inner_join_pair_stats, + reference_set_t &capped_join_pairs); + bool ApplyCompositeJoinPairCaps(double &target_denom, JoinRelationSet &scope, + reference_map_t &inner_join_pair_stats, + reference_set_t &capped_join_pairs); + double GetJoinPairCap(JoinRelationSet &join_pair); + JoinRelationSet 
&UpdateNumeratorRelations(Subgraph2Denominator left, Subgraph2Denominator right, FilterInfoWithTotalDomains &filter); - void AddRelationStats(FilterInfo &filter_info); - bool EmptyFilter(FilterInfo &filter_info); + void AddRelationStats(const FilterInfo &filter_info); + bool EmptyFilter(const FilterInfo &filter_info); + +private: + unique_ptr state; + JoinRelationSetManager &set_manager; + const JoinPredicateModel &predicate_model; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp index d6a9245b8..4b8c83acf 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/cost_model.hpp @@ -8,7 +8,6 @@ #pragma once #include "duckdb/optimizer/join_order/join_node.hpp" -#include "duckdb/common/enums/join_type.hpp" #include "duckdb/optimizer/join_order/cardinality_estimator.hpp" namespace duckdb { @@ -17,22 +16,18 @@ class QueryGraphManager; class CostModel { public: - explicit CostModel(QueryGraphManager &query_graph_manager); - -private: - //! query graph storing relation manager information - QueryGraphManager &query_graph_manager; + explicit CostModel(QueryGraphManager &query_graph_manager, CardinalityEstimator &cardinality_estimator); public: - void InitCostModel(); - //! Compute cost of a join relation set - double ComputeCost(DPJoinNode &left, DPJoinNode &right); - - //! Cardinality Estimator used to calculate cost - CardinalityEstimator cardinality_estimator; + double ComputeCost(DPJoinNode &left, DPJoinNode &right, JoinRelationSet &combination, + const vector> &possible_connections); + CardinalityEstimator &GetCardinalityEstimator(); private: + //! 
query graph storing relation manager information + QueryGraphManager &query_graph_manager; + CardinalityEstimator &cardinality_estimator; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/filter_info.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/filter_info.hpp new file mode 100644 index 000000000..04cf8c162 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/filter_info.hpp @@ -0,0 +1,48 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/filter_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" +#include "duckdb/planner/column_binding.hpp" + +namespace duckdb { + +//! Filter info struct that is used by the cardinality estimator to set the initial cardinality +//! but is also eventually transformed into a query edge. +class FilterInfo { +public: + FilterInfo(unique_ptr filter, JoinRelationSet &set, idx_t filter_index, + JoinType join_type = JoinType::INNER); + ~FilterInfo(); + +public: + void SetLeftSet(optional_ptr left_set_new); + void SetRightSet(optional_ptr right_set_new); + +public: + unique_ptr filter; + reference set; + idx_t filter_index; + JoinType join_type; + optional_ptr left_set; + optional_ptr right_set; + ColumnBinding left_binding; + ColumnBinding right_binding; + bool from_residual_predicate = false; + //! Index of the equivalence group for INNER equality/IS NOT DISTINCT FROM join filters. + //! All filters transitively connected by equality (a=b, b=c -> a=c all share the same index). + //! Used by cardinality estimation to skip redundant transitive conditions. 
+ optional_idx edge_equivalence_index; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp index b68f9a0a4..5ab505178 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_node.hpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/query_graph.hpp" namespace duckdb { @@ -15,6 +15,15 @@ namespace duckdb { struct NeighborInfo; class DPJoinNode { +public: + //! Create an intermediate node in the join tree. base_cardinality = estimated_props.cardinality + DPJoinNode(JoinRelationSet &set, optional_ptr info, JoinRelationSet &left, JoinRelationSet &right, + double cost); + //! Create a leaf node in the join tree + //! set cost to 0 for leaf nodes + //! cost will be the cost to *produce* an intermediate table + explicit DPJoinNode(JoinRelationSet &set); + public: //! Represents a node in the join plan JoinRelationSet &set; @@ -31,16 +40,7 @@ class DPJoinNode { //! the whole Join Order Optimizer can start exhibiting undesired behavior. double cost; //! used only to populate logical operators with estimated cardinalities after the best join plan has been found. - idx_t cardinality; - - //! Create an intermediate node in the join tree. base_cardinality = estimated_props.cardinality - DPJoinNode(JoinRelationSet &set, optional_ptr info, JoinRelationSet &left, JoinRelationSet &right, - double cost); - - //! Create a leaf node in the join tree - //! set cost to 0 for leaf nodes - //! 
cost will be the cost to *produce* an intermediate table - explicit DPJoinNode(JoinRelationSet &set); + idx_t cardinality = DConstants::INVALID_INDEX; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp index 34e96deb2..0647b333d 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_order_optimizer.hpp @@ -10,25 +10,18 @@ #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" -#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" -#include "duckdb/optimizer/join_order/join_node.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" -#include "duckdb/optimizer/join_order/query_graph.hpp" #include "duckdb/optimizer/join_order/query_graph_manager.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/logical_operator_visitor.hpp" - -#include namespace duckdb { class JoinOrderOptimizer { public: explicit JoinOrderOptimizer(ClientContext &context); - JoinOrderOptimizer CreateChildOptimizer(); public: + JoinOrderOptimizer CreateChildOptimizer(); //! Perform join reordering inside a plan unique_ptr Optimize(unique_ptr plan, optional_ptr stats = nullptr); //! Adds/gets materialized CTE stats @@ -52,8 +45,6 @@ class JoinOrderOptimizer { //! i.e. in the join A=B AND B=C, the equivalence set of {B} is {A, C}, thus we can add an implied join edge {A = C} expression_map_t> equivalence_sets; - CardinalityEstimator cardinality_estimator; - unordered_set join_nodes_in_full_plan; //! 
Mapping from materialized CTE index to stats diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_predicate.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_predicate.hpp new file mode 100644 index 000000000..7052f7cc2 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_predicate.hpp @@ -0,0 +1,133 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/join_order/join_predicate.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/join_type.hpp" +#include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/optimizer/join_order/filter_info.hpp" +#include "duckdb/planner/column_binding_map.hpp" + +namespace duckdb { + +enum class JoinPredicateClass { INNER_EQUALITY, INNER_NON_EQUALITY, LEFT_JOIN, SEMI_ANTI_JOIN, OTHER }; + +struct JoinOrderUtil { + //! Return the comparison type when the predicate expression itself is a comparison. + static ExpressionType GetJoinPredicateComparisonType(const FilterInfo &filter); + //! Classify the predicate for equivalence closure, denominator assembly and cost-model use. 
+ static JoinPredicateClass ClassifyJoinPredicate(const FilterInfo &filter); + static bool HasAnyValidJoinBinding(const FilterInfo &filter); + static bool HasValidJoinEndpoints(const FilterInfo &filter); + static bool HasDisjointJoinEndpoints(const FilterInfo &filter); + static RelationIndex GetBindingRelation(const ColumnBinding &binding); +}; + +class JoinPredicate { +public: + JoinPredicate(idx_t index, FilterInfo &filter, JoinPredicateClass predicate_class, + ColumnBinding left_equality_binding, ColumnBinding right_equality_binding); + +public: + idx_t GetIndex() const; + FilterInfo &GetFilter() const; + JoinPredicateClass GetPredicateClass() const; + JoinType GetJoinType() const; + ExpressionType GetComparisonType() const; + JoinRelationSet &GetSet() const; + JoinRelationSet &GetLeftSet() const; + JoinRelationSet &GetRightSet() const; + optional_ptr GetLeftSetOptional() const; + optional_ptr GetRightSetOptional() const; + const ColumnBinding &GetStatsBinding(bool left) const; + const ColumnBinding &GetEqualityBinding(bool left) const; + bool HasValidJoinEndpoints() const; + bool HasDisjointJoinEndpoints() const; + bool HasAnyValidStatsBinding() const; + bool HasValidEqualityBindings() const; + bool CanBuildEqualityClosure() const; + bool CanBuildSelectivityDomain() const; + bool IsEquivalencePredicate() const; + void SetEqualityClassIndex(optional_idx equality_class_index); + optional_idx GetEqualityClassIndex() const; + +private: + idx_t index; + reference filter; + JoinPredicateClass predicate_class; + ExpressionType comparison_type; + ColumnBinding left_equality_binding; + ColumnBinding right_equality_binding; + optional_idx equality_class_index; +}; + +struct JoinEqualityPredicateEdge { +public: + JoinEqualityPredicateEdge(JoinPredicate &predicate, RelationIndex left_relation, RelationIndex right_relation, + ColumnBinding left_binding, ColumnBinding right_binding); + +public: + reference predicate; + RelationIndex left_relation; + RelationIndex 
right_relation; + ColumnBinding left_binding; + ColumnBinding right_binding; +}; + +struct JoinEqualityClass { + idx_t index = DConstants::INVALID_INDEX; + column_binding_set_t columns; + unordered_set relations; + vector edges; +}; + +struct RelationPairEqualitySummary { + bool HasDirectCompositeEquality() const; + + vector direct_equality_class_indices; + column_binding_set_t first_relation_bindings; + column_binding_set_t second_relation_bindings; +}; + +class JoinPredicateModel { +public: + void Clear(); + JoinPredicate &RegisterPredicate(FilterInfo &filter, JoinPredicateClass predicate_class, + ColumnBinding left_equality_binding, ColumnBinding right_equality_binding); + void AddEqualityClass(JoinEqualityClass equality_class); + void AddDirectEqualityPairClass(JoinRelationSet &pair, idx_t equality_class_index, ColumnBinding first_binding, + ColumnBinding second_binding); + + const vector> &GetPredicates() const; + const vector> &GetEqualityJoinPredicates() const; + const vector> &GetSelectivityPredicates() const; + const vector> &GetGraphPredicates() const; + const vector &GetEqualityClasses() const; + bool HasLeftJoinPredicates() const; + + bool HasDirectCompositeEquality(JoinRelationSet &pair) const; + +private: + static bool ContainsClassIndex(const vector &class_indices, idx_t equality_class_index); + +private: + vector> predicates; + vector> all_predicates; + vector> equality_join_predicates; + vector> selectivity_predicates; + vector> graph_predicates; + bool has_left_join_predicates = false; + vector equality_classes; + reference_map_t equality_pairs; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation_set.hpp similarity index 84% rename from src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp rename to src/duckdb/src/include/duckdb/optimizer/join_order/join_relation_set.hpp index 10e45af71..684a9d447 100644 --- 
a/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/join_relation_set.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/optimizer/join_order/join_relation.hpp +// duckdb/optimizer/join_order/join_relation_set.hpp // // //===----------------------------------------------------------------------===// @@ -9,24 +9,25 @@ #pragma once #include "duckdb/common/common.hpp" -#include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/optional_ptr.hpp" #include "duckdb/optimizer/join_order/relation_index.hpp" namespace duckdb { //! Set of relations, used in the join graph. struct JoinRelationSet { - JoinRelationSet(unsafe_unique_array relations, idx_t count) - : relations(std::move(relations)), count(count) { - } +public: + JoinRelationSet(unsafe_unique_array relations, idx_t count); +public: string ToString() const; + bool Empty() const; + static bool IsSubset(JoinRelationSet &super, JoinRelationSet &sub); +public: unsafe_unique_array relations; idx_t count; - - static bool IsSubset(JoinRelationSet &super, JoinRelationSet &sub); }; //! The JoinRelationTree is a structure holding all the created JoinRelationSet objects and allowing fast lookup on to @@ -41,6 +42,8 @@ class JoinRelationSetManager { }; public: + //! Get an empty JoinRelationSet + JoinRelationSet &GetEmptyJoinRelationSet(); //! Create or get a JoinRelationSet from a single node with the given index JoinRelationSet &GetJoinRelation(RelationIndex index); //! Create or get a JoinRelationSet from a set of relation bindings @@ -49,13 +52,12 @@ class JoinRelationSetManager { JoinRelationSet &GetJoinRelation(unsafe_unique_array relations, idx_t count); //! Union two sets of relations together and create a new relation set JoinRelationSet &Union(JoinRelationSet &left, JoinRelationSet &right); - // //! 
Create the set difference of left \ right (i.e. all elements in left that are not in right) - // JoinRelationSet *Difference(JoinRelationSet *left, JoinRelationSet *right); string ToString() const; void Print(); private: JoinRelationTreeNode root; + optional_ptr empty_relation_set; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp index 7d5a7424b..59ac6255e 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/plan_enumerator.hpp @@ -10,7 +10,7 @@ #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/cardinality_estimator.hpp" #include "duckdb/optimizer/join_order/query_graph.hpp" #include "duckdb/optimizer/join_order/join_node.hpp" @@ -28,13 +28,12 @@ class QueryGraphManager; class PlanEnumerator { public: - explicit PlanEnumerator(QueryGraphManager &query_graph_manager, CostModel &cost_model, - const QueryGraphEdges &query_graph) - : query_graph(query_graph), query_graph_manager(query_graph_manager), cost_model(cost_model) { - } - static constexpr idx_t THRESHOLD_TO_SWAP_TO_APPROXIMATE = 12; + explicit PlanEnumerator(QueryGraphManager &query_graph_manager, CostModel &cost_model, + const QueryGraphEdges &query_graph); + +public: //! Perform the join order solving void SolveJoinOrder(); void InitLeafPlans(); @@ -42,19 +41,6 @@ class PlanEnumerator { const reference_map_t> &GetPlans() const; private: - //! The set of edges used in the join optimizer - QueryGraphEdges const &query_graph; - //! The total amount of join pairs that have been considered - idx_t pairs = 0; - //! 
Grant access to the set manager and the relation manager - QueryGraphManager &query_graph_manager; - //! Cost model to evaluate cost of joins - CostModel &cost_model; - //! A map to store the optimal join plan found for a specific JoinRelationSet* - reference_map_t> plans; - - unordered_set join_nodes_in_full_plan; - unique_ptr CreateJoinTree(JoinRelationSet &set, const vector> &possible_connections, DPJoinNode &left, DPJoinNode &right); @@ -65,6 +51,8 @@ class PlanEnumerator { //! Tries to emit a potential join candidate pair. Returns false if too many pairs have already been emitted, //! cancelling the dynamic programming step. bool TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, const vector> &info); + const vector> &GetConnections(JoinRelationSet &left, JoinRelationSet &right); + const vector> &GetAllNeighborRelationSets(vector neighbors); bool EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSet &right, unordered_set &exclusion_set); @@ -80,6 +68,23 @@ class PlanEnumerator { bool SolveJoinOrderExactly(); //! Solve the join order approximately using a greedy algorithm void SolveJoinOrderApproximately(); + +private: + //! The set of edges used in the join optimizer + QueryGraphEdges const &query_graph; + //! The total amount of join pairs that have been considered + idx_t pairs = 0; + //! Grant access to the set manager and the relation manager + QueryGraphManager &query_graph_manager; + //! Cost model to evaluate cost of joins + CostModel &cost_model; + //! 
A map to store the optimal join plan found for a specific JoinRelationSet* + reference_map_t> plans; + reference_map_t>>> + connection_cache; + unordered_map>> neighbor_set_cache; + + unordered_set join_nodes_in_full_plan; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp index e71fb9dba..4f714f6d1 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph.hpp @@ -10,7 +10,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/optional_ptr.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/join_node.hpp" #include "duckdb/optimizer/join_order/relation_manager.hpp" #include "duckdb/common/pair.hpp" @@ -23,14 +23,15 @@ namespace duckdb { -class FilterInfo; +class JoinPredicate; struct NeighborInfo { - explicit NeighborInfo(optional_ptr neighbor) : neighbor(neighbor) { - } +public: + explicit NeighborInfo(optional_ptr neighbor); +public: optional_ptr neighbor; - vector> filters; + vector> predicates; }; //! The QueryGraph contains edges between relations and allows edges to be created/queried @@ -55,7 +56,7 @@ class QueryGraphEdges { //! Enumerate all neighbors of a given JoinRelationSet node void EnumerateNeighbors(JoinRelationSet &node, const std::function &callback) const; //! Create an edge in the edge_set - void CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr info); + void CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr predicate); private: //! 
Get the QueryEdge of a specific node @@ -64,6 +65,7 @@ class QueryGraphEdges { void EnumerateNeighborsDFS(JoinRelationSet &node, reference info, idx_t index, const std::function &callback) const; +private: QueryEdge root; }; diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp index f1215fc81..f8d0ea703 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp @@ -9,56 +9,26 @@ #pragma once #include "duckdb/common/common.hpp" -#include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/unordered_set.hpp" +#include "duckdb/common/reference_map.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/optimizer/join_order/join_node.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_predicate.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/query_graph.hpp" #include "duckdb/optimizer/join_order/relation_manager.hpp" -#include "duckdb/planner/column_binding.hpp" -#include "duckdb/planner/logical_operator.hpp" - -#include namespace duckdb { class QueryGraphEdges; struct GenerateJoinRelation { - GenerateJoinRelation(optional_ptr set, unique_ptr op_p) - : set(set), op(std::move(op_p)) { - } - - optional_ptr set; - unique_ptr op; -}; - -//! Filter info struct that is used by the cardinality estimator to set the initial cardinality -//! but is also eventually transformed into a query edge. 
-class FilterInfo { public: - FilterInfo(unique_ptr filter, JoinRelationSet &set, idx_t filter_index, - JoinType join_type = JoinType::INNER) - : filter(std::move(filter)), set(set), filter_index(filter_index), join_type(join_type) { - } + GenerateJoinRelation(optional_ptr set, unique_ptr op_p); public: - unique_ptr filter; - reference set; - idx_t filter_index; - JoinType join_type; - optional_ptr left_set; - optional_ptr right_set; - ColumnBinding left_binding; - ColumnBinding right_binding; - bool from_residual_predicate = false; - - void SetLeftSet(optional_ptr left_set_new); - void SetRightSet(optional_ptr right_set_new); + optional_ptr set; + unique_ptr op; }; //! The QueryGraphManager manages the process of extracting the reorderable and nonreorderable operations @@ -66,36 +36,39 @@ class FilterInfo { //! When the plan enumerator finishes, the Query Graph Manger can then recreate the logical plan. class QueryGraphManager { public: - explicit QueryGraphManager(ClientContext &context) : relation_manager(context), context(context) { - } - - //! manage relations and the logical operators they represent - RelationManager relation_manager; - - //! A structure holding all the created JoinRelationSet objects - JoinRelationSetManager set_manager; - - ClientContext &context; + explicit QueryGraphManager(ClientContext &context); +public: //! Extract the join relations, optimizing non-reoderable relations when encountered bool Build(JoinOrderOptimizer &optimizer, LogicalOperator &op); - //! Reconstruct the logical plan using the plan found by the plan enumerator unique_ptr Reconstruct(unique_ptr plan); + //! Plan enumerator may not find a full plan and therefore will need to create cross products to create edges. + void CreateQueryGraphCrossProduct(JoinRelationSet &left, JoinRelationSet &right); - //! Get a reference to the QueryGraphEdges structure that stores edges between - //! nodes and hypernodes. + //! 
Get a reference to the QueryGraphEdges structure that stores edges between nodes and hypernodes. const QueryGraphEdges &GetQueryGraphEdges() const; + const JoinPredicateModel &GetPredicateModel() const; - //! Get a list of the join filters in the join plan than eventually are - //! transformed into the query graph edges - const vector> &GetFilterBindings() const; +private: + void GetColumnBinding(const Expression &expression, ColumnBinding &binding); + void GetEquivalenceBinding(const Expression &expression, ColumnBinding &binding); - //! Plan enumerator may not find a full plan and therefore will need to create cross - //! products to create edges. - void CreateQueryGraphCrossProduct(JoinRelationSet &left, JoinRelationSet &right); + void BindFilterEndpoints(); + void CreateHyperGraphEdges(); + + //! Build the normalized predicate model after filter endpoints and stats bindings are populated. + void BuildPredicateModel(); - //! A map to store the optimal join plan found for a specific JoinRelationSet* + GenerateJoinRelation GenerateJoins(vector> &extracted_relations, JoinRelationSet &set); + +public: + ClientContext &context; + //! manage relations and the logical operators they represent + RelationManager relation_manager; + //! A structure holding all the created JoinRelationSet objects + JoinRelationSetManager set_manager; + //! 
A map to store the optimal join plan found for a specific JoinRelationSet optional_ptr>> plans; private: @@ -106,12 +79,7 @@ class QueryGraphManager { vector> filters_and_bindings; QueryGraphEdges query_graph; - - void GetColumnBinding(const Expression &expression, ColumnBinding &binding); - - void CreateHyperGraphEdges(); - - GenerateJoinRelation GenerateJoins(vector> &extracted_relations, JoinRelationSet &set); + JoinPredicateModel predicate_model; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_index.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_index.hpp index 714d16ae3..12fabe524 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_index.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_index.hpp @@ -19,12 +19,12 @@ struct RelationIndex { idx_t index; - inline bool operator==(const RelationIndex &rhs) const { + bool operator==(const RelationIndex &rhs) const { return index == rhs.index; - }; - inline bool operator<(const RelationIndex &rhs) const { + } + bool operator<(const RelationIndex &rhs) const { return index < rhs.index; - }; + } bool operator!=(const RelationIndex &other) const { return !(*this == other); } diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp index 3ab7da754..fe31dcb0d 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_manager.hpp @@ -11,10 +11,10 @@ #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/optimizer/join_order/cardinality_estimator.hpp" -#include "duckdb/optimizer/join_order/join_node.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include 
"duckdb/optimizer/join_order/relation_statistics_helper.hpp" #include "duckdb/parser/expression_map.hpp" +#include "duckdb/planner/column_binding_map.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/planner/logical_operator_visitor.hpp" @@ -25,22 +25,21 @@ class FilterInfo; //! Represents a single relation and any metadata accompanying that relation struct SingleJoinRelation { +public: + SingleJoinRelation(LogicalOperator &op, optional_ptr parent); + SingleJoinRelation(LogicalOperator &op, optional_ptr parent, RelationStats stats); + +public: LogicalOperator &op; optional_ptr parent; RelationStats stats; - - SingleJoinRelation(LogicalOperator &op, optional_ptr parent) : op(op), parent(parent) { - } - SingleJoinRelation(LogicalOperator &op, optional_ptr parent, RelationStats stats) - : op(op), parent(parent), stats(std::move(stats)) { - } }; class RelationManager { public: - explicit RelationManager(ClientContext &context) : context(context) { - } + explicit RelationManager(ClientContext &context); +public: idx_t NumRelations(); bool ExtractJoinRelations(JoinOrderOptimizer &optimizer, LogicalOperator &input_op, @@ -73,6 +72,12 @@ class RelationManager { void PrintRelationStats(); +private: + optional_ptr GetJoinRelations(column_binding_set_t &column_bindings, + JoinRelationSetManager &set_manager); + void GetColumnBindingsFromExpression(const Expression &expression, column_binding_set_t &column_bindings); + void GetColumnBindingsFromOperator(LogicalOperator &op, column_binding_set_t &column_bindings); + private: ClientContext &context; //! 
Set of all relations considered in the join optimizer diff --git a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp index 82ac8332e..f32af4dff 100644 --- a/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/join_order/relation_statistics_helper.hpp @@ -7,29 +7,26 @@ //===----------------------------------------------------------------------===// #pragma once -#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/logical_operator.hpp" namespace duckdb { class CardinalityEstimator; +enum class DistinctCountSource : uint8_t { CARDINALITY, MIN_MAX, HLL, EXACT }; + struct DistinctCount { + DistinctCount(idx_t distinct_count, DistinctCountSource source); + idx_t distinct_count; - bool from_hll; + DistinctCountSource source; }; struct ExpressionBinding { public: - bool FoundExpression() const { - return expression; - } - bool FoundColumnRef() const { - if (!FoundExpression()) { - return false; - } - return expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF; - } + bool FoundExpression() const; + bool FoundColumnRef() const; public: optional_ptr expression; @@ -38,18 +35,19 @@ struct ExpressionBinding { }; struct RelationStats { - // column_id -> estimated distinct count for column +public: + RelationStats(); + +public: + //! column_id -> estimated distinct count for column vector column_distinct_count; idx_t cardinality; double filter_strength = 1; bool stats_initialized = false; - // for debug, column names and tables - vector column_names; - string table_name; - - RelationStats() : cardinality(1), filter_strength(1), stats_initialized(false) { - } + //! 
for debug, column names and tables + vector column_names; + Identifier table_name; }; class RelationStatisticsHelper { @@ -58,8 +56,6 @@ class RelationStatisticsHelper { public: static idx_t InspectTableFilter(idx_t cardinality, const TableFilter &filter, BaseStatistics &base_stats); - // static idx_t InspectConjunctionOR(idx_t cardinality, idx_t column_index, ConjunctionOrFilter &filter, - // BaseStatistics &base_stats); //! Extract Statistics from a LogicalGet. static RelationStats ExtractGetStats(LogicalGet &get, ClientContext &context); static RelationStats ExtractDelimGetStats(LogicalDelimGet &delim_get, ClientContext &context); @@ -81,7 +77,10 @@ class RelationStatisticsHelper { static void CopyRelationStats(RelationStats &to, const RelationStats &from); private: - static idx_t GetDistinctCount(LogicalGet &get, ClientContext &context, const ColumnIndex &column_id); + static unique_ptr GetColumnStatistics(LogicalGet &get, ClientContext &context, + const ColumnIndex &column_id); + static DistinctCount GetDistinctCount(LogicalGet &get, ClientContext &context, const ColumnIndex &column_id, + idx_t base_table_cardinality); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp index 3a96c4f73..55407358c 100644 --- a/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/expression_matcher.hpp @@ -99,6 +99,18 @@ class InClauseExpressionMatcher : public ExpressionMatcher { bool Match(Expression &expr, vector> &bindings) override; }; +class InUniformExpressionMatcher : public ExpressionMatcher { +public: + InUniformExpressionMatcher() : ExpressionMatcher(ExpressionClass::BOUND_OPERATOR) { + } + + //! 
The matchers for the probe and child expressions + unique_ptr probe_matcher; + unique_ptr child_matcher; + + bool Match(Expression &expr, vector> &bindings) override; +}; + class ConjunctionExpressionMatcher : public ExpressionMatcher { public: ConjunctionExpressionMatcher() diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp index 162789c3e..e63b91ffd 100644 --- a/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/function_matcher.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/unordered_set.hpp" #include @@ -20,9 +21,9 @@ class FunctionMatcher { virtual ~FunctionMatcher() { } - virtual bool Match(const string &name) = 0; + virtual bool Match(const Identifier &name) = 0; - static bool Match(unique_ptr &matcher, const string &name) { + static bool Match(unique_ptr &matcher, const Identifier &name) { if (!matcher) { return true; } @@ -33,29 +34,29 @@ class FunctionMatcher { //! The SpecificFunctionMatcher class matches a single specified function name class SpecificFunctionMatcher : public FunctionMatcher { public: - explicit SpecificFunctionMatcher(string name_p) : name(std::move(name_p)) { + explicit SpecificFunctionMatcher(Identifier name_p) : name(std::move(name_p)) { } - bool Match(const string &matched_name) override { + bool Match(const Identifier &matched_name) override { return matched_name == this->name; } private: - string name; + Identifier name; }; //! 
The ManyFunctionMatcher class matches a set of functions class ManyFunctionMatcher : public FunctionMatcher { public: - explicit ManyFunctionMatcher(unordered_set names_p) : names(std::move(names_p)) { + explicit ManyFunctionMatcher(identifier_set_t names_p) : names(std::move(names_p)) { } - bool Match(const string &name) override { + bool Match(const Identifier &name) override { return names.find(name) != names.end(); } private: - unordered_set names; + identifier_set_t names; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp index 892f35f30..6e15e9cb5 100644 --- a/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/matcher/type_matcher.hpp @@ -53,6 +53,20 @@ class SetTypesMatcher : public TypeMatcher { vector types; }; +//! The SpecificTypeIdMatcher class matches only a single specified LogicalTypeId +class SpecificTypeIdMatcher : public TypeMatcher { +public: + explicit SpecificTypeIdMatcher(LogicalTypeId id_p) : id(id_p) { + } + + bool Match(const LogicalType &type_p) override { + return type_p.id() == this->id; + } + +private: + LogicalTypeId id; +}; + //! The NumericTypeMatcher class matches any numeric type (DECIMAL, INTEGER, etc...) class NumericTypeMatcher : public TypeMatcher { public: diff --git a/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp index 8c7d086f2..26646b0dd 100644 --- a/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/optimizer.hpp @@ -17,6 +17,7 @@ namespace duckdb { class Binder; +class SQLStatement; class Optimizer { public: @@ -30,6 +31,9 @@ class Optimizer { bool OptimizerDisabled(OptimizerType type); static bool OptimizerDisabled(ClientContext &context, OptimizerType type); + //! 
Pre-binder statement-level optimization pass + void OptimizeStatement(unique_ptr &statement); + public: ClientContext &context; Binder &binder; @@ -42,16 +46,17 @@ class Optimizer { public: // helper functions - unique_ptr BindScalarFunction(const string &name, unique_ptr c1); - unique_ptr BindScalarFunction(const string &name, unique_ptr c1, unique_ptr c2); - unique_ptr BindScalarFunction(const string &name, unique_ptr c1, unique_ptr c2, - unique_ptr c3); + unique_ptr BindScalarFunction(const Identifier &name, unique_ptr c1); + unique_ptr BindScalarFunction(const Identifier &name, unique_ptr c1, + unique_ptr c2); + unique_ptr BindScalarFunction(const Identifier &name, unique_ptr c1, + unique_ptr c2, unique_ptr c3); private: unique_ptr plan; private: - unique_ptr BindScalarFunction(const string &name, vector> children); + unique_ptr BindScalarFunction(const Identifier &name, vector> children); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/partial_aggregate_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/partial_aggregate_pushdown.hpp new file mode 100644 index 000000000..2295e6003 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/partial_aggregate_pushdown.hpp @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/partial_aggregate_pushdown.hpp +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { +class Optimizer; + +//! 
The PartialAggregatePushdown optimizer pushes SUM aggregates below joins when this can reduce join work +class PartialAggregatePushdown : public LogicalOperatorVisitor { +public: + explicit PartialAggregatePushdown(Optimizer &optimizer); + + void VisitOperator(unique_ptr &op) override; + +private: + bool TryPushdownAggregate(unique_ptr &op); + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + +private: + Optimizer &optimizer; + column_binding_map_t replacement_map; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/remote_pushdown_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/remote_pushdown_optimizer.hpp new file mode 100644 index 000000000..b3d5daa73 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/remote_pushdown_optimizer.hpp @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/remote_pushdown_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/parser/tokens.hpp" + +namespace duckdb { +class Binder; +class Catalog; +class CatalogEntry; +class ExpressionListRef; +class FunctionExpression; +class JoinRef; +class SubqueryRef; +class TableFunctionRef; +class TableRef; +class QueryNode; +class RecursiveCTENode; +class SetOperationNode; +class InsertQueryNode; +class DeleteQueryNode; +class UpdateQueryNode; + +enum class CatalogReferenceType { NO_CATALOG_REFERENCED, SINGLE_REMOTE_CATALOG, UNKNOWN_CATALOG_REFERENCE }; + +struct CatalogPushdownResult { + explicit CatalogPushdownResult( + CatalogReferenceType reference_type = CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE); + static CatalogPushdownResult Unknown(); + static CatalogPushdownResult NoCatalogReference(); + 
static CatalogPushdownResult RemoteReference(Catalog &catalog); + + CatalogReferenceType reference_type = CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE; + optional_ptr catalog; + vector> used_expressions; + vector> used_table_constructs; + vector> used_nodes; +}; + +//! Whether an expression (tree) can be constant-folded +enum class ExpressionFoldability { FOLDABLE, NOT_FOLDABLE }; + +//! Whether an expression itself may be constant-folded +enum class ExpressionFoldingMode { + //! The expression (or any subtree within) can be folded + FOLD_EXPRESSION, + //! Only subtrees within the expression can be folded, the expression itself must be kept - + //! used for positional contexts (ORDER BY / GROUP BY / DISTINCT ON, where a folded integer + //! literal would become a positional reference) and after a failed folding attempt + FOLD_CHILDREN_ONLY +}; + +//! The result of analyzing or rewriting a single expression +struct ExpressionPushdownResult { + //! The catalog analysis result + CatalogPushdownResult result {CatalogReferenceType::NO_CATALOG_REFERENCED}; + ExpressionFoldability foldability = ExpressionFoldability::NOT_FOLDABLE; +}; + +struct RemotePushdownState { + bool search_path_initialized = false; + vector> remote_catalogs_in_search_path; + vector local_catalogs_in_search_path; +}; + +class RemotePushdownOptimizer { +public: + explicit RemotePushdownOptimizer(Binder &binder); + explicit RemotePushdownOptimizer(optional_ptr parent); + + void Rewrite(unique_ptr &statement); + +private: + void FindRemoteCatalogsInSearchPath(); + CatalogPushdownResult Rewrite(QueryNode &node); + //! The per-type query node handlers are deliberately NOT overloads of Rewrite: calling + //! Rewrite with a statically-typed node (e.g. *InsertStatement::node, which is an + //! InsertQueryNode) must dispatch through Rewrite(QueryNode &) so the node-level checks + //! 
(CTE handling, catalog support verification) are applied + CatalogPushdownResult RewriteNode(SelectNode &node); + CatalogPushdownResult RewriteNode(SetOperationNode &node); + CatalogPushdownResult RewriteNode(InsertQueryNode &node); + CatalogPushdownResult RewriteNode(DeleteQueryNode &node); + CatalogPushdownResult RewriteNode(UpdateQueryNode &node); + CatalogPushdownResult Rewrite(unique_ptr &ref); + CatalogPushdownResult Rewrite(ExpressionListRef &ref); + CatalogPushdownResult RewriteNode(RecursiveCTENode &node); + CatalogPushdownResult Rewrite(JoinRef &ref); + CatalogPushdownResult Rewrite(SubqueryRef &ref); + CatalogPushdownResult Rewrite(TableFunctionRef &ref); + CatalogPushdownResult Rewrite(BaseTableRef &ref); + + enum class ConstantFoldResult { + //! The expression is not a foldable constant expression (contains columns, is volatile, ...) + NOT_FOLDABLE, + //! The expression was replaced with its locally-evaluated constant result + FOLDED, + //! The expression is constant but evaluating it raises an error - the query must be + //! executed locally so the user sees DuckDB's error message + FOLD_ERROR + }; + //! Rewrite an expression, constant-folding maximal foldable subtrees + CatalogPushdownResult Rewrite(unique_ptr &expr, + ExpressionFoldingMode mode = ExpressionFoldingMode::FOLD_EXPRESSION); + //! Rewrite an expression, deferring the folding of foldable subtrees to the parent (or the root) + ExpressionPushdownResult RewriteExpression(unique_ptr &expr, ExpressionFoldingMode mode); + //! Fold a maximal foldable subtree and record the resulting constant + CatalogPushdownResult FoldExpression(unique_ptr &expr); + //! Per-expression-class catalog analysis (catalog qualification, subqueries, local tables), + //! 
which also determines whether the expression can be constant-folded + ExpressionPushdownResult AnalyzeExpression(const ParsedExpression &expr); + ExpressionPushdownResult AnalyzeExpression(const SubqueryExpression &expr); + ExpressionPushdownResult AnalyzeExpression(const FunctionExpression &expr); + ExpressionPushdownResult AnalyzeExpression(const WindowExpression &expr); + ExpressionPushdownResult AnalyzeExpression(const TypeExpression &expr); + ExpressionPushdownResult AnalyzeExpression(const ColumnRefExpression &expr); + //! Bind and evaluate an expression locally, replacing it with the resulting constant + ConstantFoldResult TryConstantFold(unique_ptr &expr); + + CatalogPushdownResult CheckCatalogQualification(const ParsedExpression &expr, const Identifier &catalog_name, + const Identifier &schema_name); + CatalogPushdownResult RewriteTableFunctionOnly(TableFunctionRef &ref); + + //! Records a BaseTableRef's name, alias and columns as local for correlated subquery detection + void TrackLocalTable(const TableRef &ref); + void TrackLocalTable(const BaseTableRef &ref); + void TrackLocalTable(const TableFunctionRef &ref); + void TrackLocalTable(const SubqueryRef &ref); + + void FinishPushdown(unique_ptr &statement, CatalogPushdownResult result); + void FinishPushdown(unique_ptr &node, CatalogPushdownResult result); + + static CatalogPushdownResult Merge(CatalogPushdownResult a, CatalogPushdownResult b); + unique_ptr CreateRemoteFunctionRef(CatalogPushdownResult &result, unique_ptr node); + static void StripCatalogName(SQLStatement &statement, const Identifier &catalog_name); + static void StripCatalogName(QueryNode &node, const Identifier &catalog_name); + static void StripCatalogName(TableRef &ref, const Identifier &catalog_name); + //! Strip catalog prefix from expression column refs. When strip_subquery_bodies=false, leaves subquery + //! bodies untouched (used for partial pushdown where inner subqueries are not being pushed). 
+ static void StripCatalogName(ParsedExpression &expr, const Identifier &catalog_name); + bool RefersToLocalTable(const ColumnRefExpression &col_ref) const; + + bool RefersToCTE(const Identifier &cte_name, CatalogPushdownResult &result) const; + +private: + Binder &binder; + optional_ptr parent; + unique_ptr owned_pushdown_state; + RemotePushdownState &pushdown_state; + //! Names/aliases of non-remote tables seen in the current FROM scope, used to detect correlated subqueries + identifier_set_t local_table_names; + //! CTE name to catalog pushdown result, populated as CTEs are analyzed (inner scopes restore on exit) + identifier_map_t cte_results; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp index 775ba01fb..1875af2c3 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/constant_order_normalization.hpp @@ -12,8 +12,9 @@ namespace duckdb { -// Move constant expression parameters to the left in expression(i.e. x + 2 + y + 2 => 2 + 2 + x + y) -// for convenience of other rules(i.e. ConstantFoldingRule). +// Group constant expression parameters together in commutative arithmetic expressions to expose +// constant folding opportunities. After folding collapses multiplication constants into a single +// value, keep that constant on the right to match the canonical binary operator shape. 
class ConstantOrderNormalizationRule : public Rule { public: explicit ConstantOrderNormalizationRule(ExpressionRewriter &rewriter); diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/contains_to_in_clause.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/contains_to_in_clause.hpp new file mode 100644 index 000000000..c792167e9 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/rule/contains_to_in_clause.hpp @@ -0,0 +1,13 @@ +#pragma once +#include "duckdb/optimizer/rule.hpp" + +namespace duckdb { + +class ContainsToInClauseRule : public Rule { +public: + explicit ContainsToInClauseRule(ExpressionRewriter &rewriter); + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp b/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp index 0afabeae8..6f6e48e14 100644 --- a/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/rule/in_clause_simplification.hpp @@ -21,4 +21,13 @@ class InClauseSimplificationRule : public Rule { bool is_root) override; }; +// The in enum simplification rule rewrites cases where a left enum::VARCHAR is compared with right string literals +class InEnumSimplificationRule : public Rule { +public: + explicit InEnumSimplificationRule(ExpressionRewriter &rewriter); + + unique_ptr Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, + bool is_root) override; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp index 461e1eccc..a528fa1cd 100644 --- a/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/unnest_rewriter.hpp @@ -30,7 +30,7 @@ struct LHSBinding { } ColumnBinding binding; LogicalType type; - string 
alias; + Identifier alias; }; //! The UnnestRewriterPlanUpdater updates column bindings after changing the operator plan diff --git a/src/duckdb/src/include/duckdb/optimizer/window_self_join.hpp b/src/duckdb/src/include/duckdb/optimizer/window_self_join.hpp index bac7e9976..3ed80d451 100644 --- a/src/duckdb/src/include/duckdb/optimizer/window_self_join.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/window_self_join.hpp @@ -16,6 +16,8 @@ namespace duckdb { class WindowSelfJoinOptimizer { public: + static bool CanOptimize(const LogicalOperator &op); + explicit WindowSelfJoinOptimizer(Optimizer &optimizer); unique_ptr Optimize(unique_ptr op); diff --git a/src/duckdb/src/include/duckdb/parallel/async_result.hpp b/src/duckdb/src/include/duckdb/parallel/async_result.hpp index fada0854d..9915a6f60 100644 --- a/src/duckdb/src/include/duckdb/parallel/async_result.hpp +++ b/src/duckdb/src/include/duckdb/parallel/async_result.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/enums/operator_result_type.hpp" +#include "duckdb/common/enums/task_scheduler_type.hpp" namespace duckdb { @@ -20,8 +21,8 @@ class TaskExecutor; class Executor; enum class AsyncResultsExecutionMode : uint8_t { - SYNCHRONOUS, // BLOCKED should not bubble up, and they should be executed synchronously - TASK_EXECUTOR // BLOCKED is allowed + SYNCHRONOUS, //! BLOCKED should not bubble up, and they should be executed synchronously + TASK_EXECUTOR //! 
BLOCKED is allowed }; class AsyncTask { @@ -37,25 +38,26 @@ class AsyncResult { AsyncResult() = default; AsyncResult(AsyncResult &&) = default; AsyncResult(SourceResultType t); // NOLINT - explicit AsyncResult(vector> &&task); + explicit AsyncResult(vector> &&task, + TaskSchedulerType pool_type = TaskSchedulerType::REGULAR); AsyncResult &operator=(SourceResultType t); AsyncResult &operator=(AsyncResultType t); AsyncResult &operator=(AsyncResult &&) noexcept; - // Schedule held async_tasks into the Executor, eventually unblocking InterruptState - // needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform - // into INVALID + //! Schedule held async_tasks into the Executor, eventually unblocking InterruptState + //! needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + //! into INVALID void ScheduleTasks(InterruptState &interrupt_state, Executor &executor); - // Execute tasks synchronously at callsite - // needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform - // into HAVE_MORE_OUTPUT + //! Execute tasks synchronously at callsite + //! needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + //! into HAVE_MORE_OUTPUT void ExecuteTasksSynchronously(); static AsyncResultType GetAsyncResultType(SourceResultType s); - // Check whether there are tasks associated + //! Check whether there are tasks associated bool HasTasks() const; AsyncResultType GetResultType() const; - // Extract associated tasks, moving them away, will empty async_tasks and transform to INVALID + //! 
Extract associated tasks, moving them away, will empty async_tasks and transform to INVALID vector> &&ExtractAsyncTasks(); #ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE @@ -68,5 +70,7 @@ class AsyncResult { private: AsyncResultType result_type {AsyncResultType::INVALID}; vector> async_tasks {}; + //! The thread pool that the async_tasks are scheduled onto when BLOCKED + TaskSchedulerType pool_type {TaskSchedulerType::REGULAR}; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp b/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp index 4ab2a8165..7ccd1656e 100644 --- a/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp +++ b/src/duckdb/src/include/duckdb/parallel/concurrentqueue.hpp @@ -25,12 +25,14 @@ class BlockingConcurrentQueue; struct ProducerToken { //! Constructor - template - explicit ProducerToken(ConcurrentQueue &); - //! Constructor - template - explicit ProducerToken(BlockingConcurrentQueue &); + template + explicit ProducerToken(ConcurrentQueue &) { + } //! Constructor + template + explicit ProducerToken(BlockingConcurrentQueue &) { + } + //! Move constructor ProducerToken(ProducerToken &&) { } //! Is valid token? @@ -39,6 +41,20 @@ struct ProducerToken { } }; +struct ConsumerToken { + //! Constructor + template + explicit ConsumerToken(ConcurrentQueue &) { + } + //! Constructor + template + explicit ConsumerToken(BlockingConcurrentQueue &) { + } + //! Move constructor + ConsumerToken(ConsumerToken &&) { + } +}; + template class ConcurrentQueue { private: @@ -84,6 +100,11 @@ class ConcurrentQueue { } return max; } + //! Dequeues several elements from the queue using a consumer token. + template + size_t try_dequeue_bulk(ConsumerToken &, It itemFirst, size_t max) { + return try_dequeue_bulk(itemFirst, max); + } template bool enqueue_bulk(It itemFirst, size_t count) { @@ -92,6 +113,11 @@ class ConcurrentQueue { } return true; } + //! Enqueues several elements using a producer token. 
+ template + bool enqueue_bulk(ProducerToken const &, It itemFirst, size_t count) { + return enqueue_bulk(itemFirst, count); + } }; } // namespace duckdb_moodycamel diff --git a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp index a2e3b10f4..4e691c714 100644 --- a/src/duckdb/src/include/duckdb/parallel/interrupt.hpp +++ b/src/duckdb/src/include/duckdb/parallel/interrupt.hpp @@ -68,6 +68,11 @@ class StateWithBlockableTasks { can_block = false; } + void ResetBlocking() DUCKDB_REQUIRES(lock) { + can_block = true; + blocked_tasks.clear(); + } + //! Add a task to 'blocked_tasks' before returning SourceResultType::BLOCKED (must hold the lock) bool BlockTask(const InterruptState &interrupt_state) DUCKDB_REQUIRES(lock) { if (can_block) { diff --git a/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp b/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp index f5f9b4ade..38f520c48 100644 --- a/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp +++ b/src/duckdb/src/include/duckdb/parallel/meta_pipeline.hpp @@ -76,6 +76,8 @@ class MetaPipeline : public enable_shared_from_this { bool HasFinishEvent(Pipeline &pipeline) const; //! Whether this pipeline is part of a PipelineFinishEvent optional_ptr GetFinishGroup(Pipeline &pipeline) const; + //! Get the child MetaPipelines + const vector> &GetChildren() const; public: //! Build the MetaPipeline with 'op' as the first operator (excl. 
the shared sink) diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline.hpp index 0d8a665a5..ee1f3510a 100644 --- a/src/duckdb/src/include/duckdb/parallel/pipeline.hpp +++ b/src/duckdb/src/include/duckdb/parallel/pipeline.hpp @@ -91,11 +91,17 @@ class Pipeline : public enable_shared_from_this { void Ready(); void Reset(); void ResetSink(); + void ResetSinkForReschedule(); + void ResetForReschedule(bool reset_sink); void ResetSource(bool force); void ClearSource(); void Schedule(shared_ptr &event); void PrepareFinalize(); + //! Compute the maximum number of threads for parallel execution of this pipeline + //! Returns 1 if the pipeline cannot be parallelized + idx_t GetMaxThreads(); + string ToString() const; void Print() const; void PrintDependencies() const; @@ -159,6 +165,7 @@ class Pipeline : public enable_shared_from_this { void ScheduleSequentialTask(shared_ptr &event); bool LaunchScanTasks(shared_ptr &event, idx_t max_threads); + bool TryGetMaxThreads(idx_t &max_threads); bool ScheduleParallel(shared_ptr &event); }; diff --git a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp index 73cd38895..1f871c5b2 100644 --- a/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp +++ b/src/duckdb/src/include/duckdb/parallel/pipeline_executor.hpp @@ -78,6 +78,15 @@ class PipelineExecutor { //! Registers the task in the interrupt_state to allow Source/Sink operators to block the task void SetTaskForInterrupts(weak_ptr current_task); + //! Replaces the interrupt state used by source/sink/finalize calls + void SetInterruptState(InterruptState interrupt_state_p); + + //! Resets the executor for re-execution while reusing allocated intermediate buffers. + //! Reuses local source/sink/operator states where operators provide explicit reset hooks, + //! and falls back to recreation otherwise. + void Reset(); + //! 
Prepare the executor for another execution, skipping Reset() on the very first run. + void PrepareForExecution(); private: //! The pipeline to process @@ -112,6 +121,8 @@ class PipelineExecutor { //! Partition info that is used by this executor OperatorPartitionInfo required_partition_info; + //! Source operator indicated that there is no more output possible + bool exhausted_source = false; //! Source or intermediate operator indicated that there is no more output possible bool exhausted_pipeline = false; //! Flushing of intermediate operators has started @@ -119,6 +130,9 @@ class PipelineExecutor { //! Flushing of caching operators is done bool done_flushing = false; + //! Whether FinishSource has already been called (so FinalizeSource is skipped in PushFinalize) + bool source_profiling_finalized = false; + //! This flag is set when the pipeline gets interrupted by the Sink -> the final_chunk should be re-sink-ed. bool remaining_sink_chunk = false; @@ -130,6 +144,8 @@ class PipelineExecutor { idx_t flushing_idx; //! Whether the current flushing_idx should be flushed: this needs to be stored to make flushing code re-entrant bool should_flush_current_idx = true; + //! Whether this executor has already run at least once + bool has_executed = false; private: void StartOperator(PhysicalOperator &op); diff --git a/src/duckdb/src/include/duckdb/parallel/task_executor.hpp b/src/duckdb/src/include/duckdb/parallel/task_executor.hpp index 912ad80b3..40d2f8dfe 100644 --- a/src/duckdb/src/include/duckdb/parallel/task_executor.hpp +++ b/src/duckdb/src/include/duckdb/parallel/task_executor.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/optional_ptr.hpp" #include "duckdb/parallel/task.hpp" #include "duckdb/execution/task_error_manager.hpp" +#include "duckdb/common/enums/task_scheduler_type.hpp" namespace duckdb { class TaskScheduler; @@ -20,8 +21,8 @@ class TaskScheduler; //! 
The TaskExecutor is a helper class that enables parallel scheduling and execution of tasks class TaskExecutor { public: - explicit TaskExecutor(ClientContext &context); - explicit TaskExecutor(TaskScheduler &scheduler); + explicit TaskExecutor(ClientContext &context, TaskSchedulerType type = TaskSchedulerType::REGULAR); + explicit TaskExecutor(TaskScheduler &scheduler, TaskSchedulerType type = TaskSchedulerType::REGULAR); ~TaskExecutor(); //! Push an error into the TaskExecutor @@ -44,6 +45,7 @@ class TaskExecutor { private: TaskScheduler &scheduler; + const TaskSchedulerType type; TaskErrorManager error_manager; unique_ptr token; atomic completed_tasks; diff --git a/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp index 3ed5cb2fc..834ed8d2a 100644 --- a/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp +++ b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp @@ -13,103 +13,110 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/parallel/task.hpp" +#include "duckdb/common/array.hpp" +#include "duckdb/common/enums/task_scheduler_type.hpp" namespace duckdb { -struct ConcurrentQueue; struct QueueProducerToken; class ClientContext; +struct DBConfig; class DatabaseInstance; class TaskScheduler; - -struct SchedulerThread; +class TaskSchedulerPool; +class TaskSchedulerQueue; struct ProducerToken { - ProducerToken(TaskScheduler &scheduler, unique_ptr token); +public: + explicit ProducerToken(array, TASK_SCHEDULER_TYPE_COUNT> &queues); ~ProducerToken(); - TaskScheduler &scheduler; - unique_ptr token; +public: + QueueProducerToken &GetQueueProducerToken(TaskSchedulerType pool_type); + +public: mutex producer_lock; + +private: + array, TASK_SCHEDULER_TYPE_COUNT> tokens; }; //! The TaskScheduler is responsible for managing tasks and threads class TaskScheduler { - // timeout for semaphore wait, default 5ms + //! 
Timeout for semaphore wait, default 5ms constexpr static int64_t TASK_TIMEOUT_USECS = 5000; public: explicit TaskScheduler(DatabaseInstance &db); ~TaskScheduler(); +public: DUCKDB_API static TaskScheduler &GetScheduler(ClientContext &context); DUCKDB_API static TaskScheduler &GetScheduler(DatabaseInstance &db); unique_ptr CreateProducer(); - //! Schedule a task to be executed by the task scheduler + //! Returns the number of threads + DUCKDB_API int32_t NumberOfThreads(); + + idx_t GetNumberOfTasks() const; + idx_t GetProducerCount() const; + idx_t GetTaskCountForProducer(ProducerToken &token) const; + + //! Schedule a task to be executed by the task scheduler in the given pool + void ScheduleTask(ProducerToken &producer, shared_ptr task, TaskSchedulerType pool_type); + void ScheduleTasks(ProducerToken &producer, vector> &tasks, TaskSchedulerType pool_type); + //! Helpers for regular tasks void ScheduleTask(ProducerToken &producer, shared_ptr task); void ScheduleTasks(ProducerToken &producer, vector> &tasks); //! Fetches a task from a specific producer, returns true if successful or false if no tasks were available bool GetTaskFromProducer(ProducerToken &token, shared_ptr &task); //! Run tasks forever until "marker" is set to false, "marker" must remain valid until the thread is joined void ExecuteForever(atomic *marker); + void ExecuteForever(atomic *marker, TaskSchedulerType pool_type); //! Run tasks until `marker` is set to false, `max_tasks` have been completed, or until there are no more tasks //! available. Returns the number of tasks that were completed. idx_t ExecuteTasks(atomic *marker, idx_t max_tasks); //! Run tasks until `max_tasks` have been completed, or until there are no more tasks available void ExecuteTasks(idx_t max_tasks); + //! Send signals to n threads, signalling for them to wake up and attempt to execute a task + void Signal(idx_t n); //! Sets the amount of background threads to be used for execution, based on the number of total threads //! 
and the number of external threads. External threads, e.g. the main thread, will also be used for execution. //! Launches `total_threads - external_threads` background worker threads. void SetThreads(idx_t total_threads, idx_t external_threads); - + void SetAsyncThreads(idx_t n); void RelaunchThreads(); - //! Returns the number of threads - DUCKDB_API int32_t NumberOfThreads(); - - idx_t GetNumberOfTasks() const; - idx_t GetProducerCount() const; - idx_t GetTaskCountForProducer(ProducerToken &token) const; - - //! Send signals to n threads, signalling for them to wake up and attempt to execute a task - void Signal(idx_t n); - //! Yield to other threads static void YieldThread(); - - //! Set the allocator flush threshold - void SetAllocatorFlushThreshold(idx_t threshold); - //! Sets the allocator background thread - void SetAllocatorBackgroundThreads(bool enable); - //! Get the number of the CPU on which the calling thread is currently executing. //! Fallback to calling thread id if CPU number is not available. //! Result do not need to be exact 'return 0' is a valid fallback strategy static idx_t GetEstimatedCPUId(); private: - void RelaunchThreadsInternal(int32_t n, bool destroy); + TaskSchedulerPool &GetPool(TaskSchedulerType pool_type); + TaskSchedulerQueue &GetQueue(TaskSchedulerType pool_type) const; + + //! Fetches a task, returns true if successful or false if no tasks were available + bool GetTaskInternal(shared_ptr &task); + bool GetTaskInternal(shared_ptr &task, TaskSchedulerType pool_type); + bool TryDequeueAndProcessTask(const DBConfig &config, TaskSchedulerQueue &queue, shared_ptr &task); + + void SetThreadsInternal(TaskSchedulerType pool_type, idx_t n); + void Signal(TaskSchedulerType pool_type, idx_t n); + void SignalAllPools(idx_t n); + void SignalForTaskType(TaskSchedulerType task_type, idx_t n); private: DatabaseInstance &db; - //! The task queue - unique_ptr queue; //! Lock for modifying the thread count mutex thread_lock; - //! 
The active background threads of the task scheduler - vector> threads; - //! Markers used by the various threads, if the markers are set to "false" the thread execution is stopped - vector>> markers; - //! The threshold after which to flush the allocator after completing a task - atomic allocator_flush_threshold; - //! Whether allocator background threads are enabled - atomic allocator_background_threads; - //! Requested thread count (set by the 'threads' setting) - atomic requested_thread_count; - //! The amount of threads currently running - atomic current_thread_count; + //! The thread pools + array, TASK_SCHEDULER_TYPE_COUNT> pools; + //! The task queues + array, TASK_SCHEDULER_TYPE_COUNT> queues; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task_scheduler_pool.hpp b/src/duckdb/src/include/duckdb/parallel/task_scheduler_pool.hpp new file mode 100644 index 000000000..d3d679761 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/task_scheduler_pool.hpp @@ -0,0 +1,57 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/task_scheduler_pool.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/enums/task_scheduler_type.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/atomic.hpp" + +namespace duckdb { + +class TaskScheduler; +class DatabaseInstance; +struct LightWeightSemaphoreWrapper; +struct QueueProducerToken; +struct TaskSchedulerThread; + +class TaskSchedulerPool { +public: + explicit TaskSchedulerPool(DatabaseInstance &db, TaskSchedulerType pool_type); + ~TaskSchedulerPool(); + +public: + void SetThreads(idx_t n); + int32_t NumberOfThreads(); + void RelaunchThreads(TaskScheduler &scheduler, bool destroy); + void Signal(idx_t n); +#ifndef DUCKDB_NO_THREADS + void Wait(); + bool Wait(int64_t timeout_usecs); +#endif + 
+private: + DatabaseInstance &db; + //! The type of this pool + const TaskSchedulerType pool_type; + //! The active background threads of the task scheduler + vector> threads; + //! Markers used by the various threads, if the markers are set to "false" the thread execution is stopped + vector>> markers; + //! Requested thread count (set by the 'threads' setting) + atomic requested_thread_count; + //! The amount of threads currently running + atomic current_thread_count; +#ifndef DUCKDB_NO_THREADS + //! Semaphore to signal threads in this pool to wake up and execute a task + unique_ptr semaphore; +#endif +}; + +}; // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task_scheduler_queue.hpp b/src/duckdb/src/include/duckdb/parallel/task_scheduler_queue.hpp new file mode 100644 index 000000000..6fcda5558 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parallel/task_scheduler_queue.hpp @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parallel/task_scheduler_queue.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/atomic.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/array.hpp" +#include "duckdb/common/reference_map.hpp" +#include "duckdb/common/enums/task_scheduler_type.hpp" + +#ifdef DUCKDB_NO_THREADS +#include +#endif + +namespace duckdb { + +class Task; +class TaskSchedulerQueue; +struct ConcurrentQueueWrapper; +struct ProducerToken; +struct QueueProducerToken; + +class TaskSchedulerQueue { +public: + explicit TaskSchedulerQueue(TaskSchedulerType pool_type_p); + ~TaskSchedulerQueue(); + +public: + TaskSchedulerType GetPoolType(); + void Enqueue(ProducerToken &token, shared_ptr task); + void EnqueueBulk(ProducerToken &token, vector> &tasks); + bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); + bool 
Dequeue(shared_ptr &task); + idx_t GetTasksInQueue() const; + idx_t GetApproxSize() const; + idx_t GetProducerCount() const; + idx_t GetTaskCountForProducer(ProducerToken &token) const; + +#ifndef DUCKDB_NO_THREADS + ConcurrentQueueWrapper &GetQueue(); +#else + void RemoveToken(QueueProducerToken &token); +#endif + +private: + const TaskSchedulerType pool_type; +#ifndef DUCKDB_NO_THREADS + unique_ptr queue; + atomic tasks_in_queue {0}; +#else + reference_map_t>> q; + mutable mutex qlock; +#endif +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/base_expression.hpp b/src/duckdb/src/include/duckdb/parser/base_expression.hpp index bef8287c4..de17e817b 100644 --- a/src/duckdb/src/include/duckdb/parser/base_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/base_expression.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/optional_idx.hpp" namespace duckdb { @@ -62,17 +63,12 @@ class BaseExpression { } //! Returns the alias of the expression - const string &GetAlias() const { + const Identifier &GetAlias() const { return alias; } //! Sets the alias of the expression - void SetAlias(const string &alias_p) { - alias = alias_p; - } - - //! Sets the alias of the expression - void SetAlias(string &&alias_p) { + void SetAlias(Identifier alias_p) { alias = std::move(alias_p); } @@ -89,7 +85,7 @@ class BaseExpression { ExpressionClass expression_class; //! The alias of the expression, - string alias; + Identifier alias; //! The location in the query (if any) optional_idx query_location; @@ -121,7 +117,7 @@ class BaseExpression { virtual bool HasParameter() const = 0; //! Get the name of the expression - virtual string GetName() const; + virtual Identifier GetName() const; //! Convert the Expression to a String virtual string ToString() const = 0; //! 
Print the expression to stdout diff --git a/src/duckdb/src/include/duckdb/parser/column_definition.hpp b/src/duckdb/src/include/duckdb/parser/column_definition.hpp index 5ae61ef06..aacd3d85d 100644 --- a/src/duckdb/src/include/duckdb/parser/column_definition.hpp +++ b/src/duckdb/src/include/duckdb/parser/column_definition.hpp @@ -24,8 +24,8 @@ class ColumnDefinition; //! A column of a table. class ColumnDefinition { public: - DUCKDB_API ColumnDefinition(string name, LogicalType type); - DUCKDB_API ColumnDefinition(string name, LogicalType type, unique_ptr expression, + DUCKDB_API ColumnDefinition(Identifier name, LogicalType type); + DUCKDB_API ColumnDefinition(Identifier name, LogicalType type, unique_ptr expression, TableColumnType category); public: @@ -40,8 +40,8 @@ class ColumnDefinition { void SetType(const LogicalType &type); //! name - DUCKDB_API const string &Name() const; - void SetName(const string &name); + DUCKDB_API const Identifier &Name() const; + void SetName(const Identifier &name); //! comment DUCKDB_API const Value &Comment() const; @@ -87,13 +87,13 @@ class ColumnDefinition { void ChangeGeneratedExpressionType(const LogicalType &type); void GetListOfDependencies(vector &dependencies) const; - string GetName() const; + Identifier GetName() const; LogicalType GetType() const; private: //! The name of the entry - string name; + Identifier name; //! The type of the column LogicalType type; //! 
Compression Type used for this column diff --git a/src/duckdb/src/include/duckdb/parser/column_list.hpp b/src/duckdb/src/include/duckdb/parser/column_list.hpp index 7bf9cc173..7a32ded3d 100644 --- a/src/duckdb/src/include/duckdb/parser/column_list.hpp +++ b/src/duckdb/src/include/duckdb/parser/column_list.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/parser/column_definition.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { @@ -26,16 +27,16 @@ class ColumnList { DUCKDB_API const ColumnDefinition &GetColumn(LogicalIndex index) const; DUCKDB_API const ColumnDefinition &GetColumn(PhysicalIndex index) const; - DUCKDB_API const ColumnDefinition &GetColumn(const string &name) const; + DUCKDB_API const ColumnDefinition &GetColumn(const Identifier &name) const; DUCKDB_API ColumnDefinition &GetColumnMutable(LogicalIndex index); DUCKDB_API ColumnDefinition &GetColumnMutable(PhysicalIndex index); - DUCKDB_API ColumnDefinition &GetColumnMutable(const string &name); + DUCKDB_API ColumnDefinition &GetColumnMutable(const Identifier &name); DUCKDB_API vector GetColumnNames() const; DUCKDB_API vector GetColumnTypes() const; - DUCKDB_API bool ColumnExists(const string &name) const; + DUCKDB_API bool ColumnExists(const Identifier &name) const; - DUCKDB_API LogicalIndex GetColumnIndex(string &column_name) const; + DUCKDB_API LogicalIndex GetColumnIndex(Identifier &column_name) const; DUCKDB_API PhysicalIndex LogicalToPhysical(LogicalIndex index) const; DUCKDB_API LogicalIndex PhysicalToLogical(PhysicalIndex index) const; @@ -63,7 +64,7 @@ class ColumnList { private: vector columns; //! A map of column name to column index - case_insensitive_map_t name_map; + identifier_map_t name_map; //! The set of physical columns vector physical_columns; //! 
Allow duplicate names or not diff --git a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp index b1a2e546a..ccb38c95c 100644 --- a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp @@ -25,7 +25,7 @@ struct CommonTableExpressionInfo { //! Used by deserialization: prefers query_node; if null, uses query->node. CommonTableExpressionInfo(unique_ptr query, unique_ptr query_node); - vector aliases; + vector aliases; vector> key_targets; vector> payload_aggregates; diff --git a/src/duckdb/src/include/duckdb/parser/constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraint.hpp index 90eacaf89..e9e3c01a2 100644 --- a/src/duckdb/src/include/duckdb/parser/constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraint.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/constants.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/exception.hpp" @@ -36,10 +37,10 @@ enum class ForeignKeyType : uint8_t { struct ForeignKeyInfo { ForeignKeyType type; - string schema; + Identifier schema; //! if type is FK_TYPE_FOREIGN_KEY_TABLE, means main key table, if type is FK_TYPE_PRIMARY_KEY_TABLE, means foreign //! key table - string table; + Identifier table; //! The set of main key table's column's index vector pk_keys; //! 
The set of foreign key table's column's index diff --git a/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp index 7601a2322..372c4ef9b 100644 --- a/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraints/foreign_key_constraint.hpp @@ -18,12 +18,12 @@ class ForeignKeyConstraint : public Constraint { static constexpr const ConstraintType TYPE = ConstraintType::FOREIGN_KEY; public: - DUCKDB_API ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info); + DUCKDB_API ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info); //! The set of main key table's columns - vector pk_columns; + vector pk_columns; //! The set of foreign key table's columns - vector fk_columns; + vector fk_columns; ForeignKeyInfo info; public: diff --git a/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp index dd73a87a0..98a97a077 100644 --- a/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp +++ b/src/duckdb/src/include/duckdb/parser/constraints/unique_constraint.hpp @@ -20,8 +20,8 @@ class UniqueConstraint : public Constraint { public: DUCKDB_API UniqueConstraint(const LogicalIndex index, const bool is_primary_key); - DUCKDB_API UniqueConstraint(const LogicalIndex index, string column_name, const bool is_primary_key); - DUCKDB_API UniqueConstraint(vector columns, const bool is_primary_key); + DUCKDB_API UniqueConstraint(const LogicalIndex index, Identifier column_name, const bool is_primary_key); + DUCKDB_API UniqueConstraint(vector columns, const bool is_primary_key); public: DUCKDB_API string ToString() const override; @@ -38,13 +38,13 @@ class UniqueConstraint : public Constraint { //! Sets the column index of the constraint. 
void SetIndex(const LogicalIndex new_index); //! Returns a constant reference to the column names on which the constraint is defined. - const vector &GetColumnNames() const; + const vector &GetColumnNames() const; //! Returns a mutable reference to the column names on which the constraint is defined. - vector &GetColumnNamesMutable(); + vector &GetColumnNamesMutable(); //! Returns the column indexes on which the constraint is defined. vector GetLogicalIndexes(const ColumnList &columns) const; //! Get the name of the constraint. - string GetName(const string &table_name) const; + Identifier GetName(const Identifier &table_name) const; private: UniqueConstraint(); @@ -58,7 +58,7 @@ class UniqueConstraint : public Constraint { //! The indexed column of the constraint. Only used for single-column constraints, invalid otherwise. LogicalIndex index; //! The names of the columns on which this constraint is defined. Only set if the index field is not set. - vector columns; + vector columns; //! Whether this is a PRIMARY KEY constraint, or a UNIQUE constraint. 
bool is_primary_key; }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp index 12132bc47..6b50619b9 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/between_expression.hpp @@ -20,14 +20,10 @@ class BetweenExpression : public ParsedExpression { DUCKDB_API BetweenExpression(unique_ptr input, unique_ptr lower, unique_ptr upper); - unique_ptr input; - unique_ptr lower; - unique_ptr upper; - public: string ToString() const override; - static bool Equal(const BetweenExpression &a, const BetweenExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; @@ -43,6 +39,15 @@ class BetweenExpression : public ParsedExpression { const ParsedExpression &UpperBound() const { return *upper; } + unique_ptr &InputMutable() { + return input; + } + unique_ptr &LowerBoundMutable() { + return lower; + } + unique_ptr &UpperBoundMutable() { + return upper; + } public: template @@ -52,5 +57,10 @@ class BetweenExpression : public ParsedExpression { private: BetweenExpression(); + +private: + unique_ptr input; + unique_ptr lower; + unique_ptr upper; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp index 8478f05b8..44403b6a8 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/case_expression.hpp @@ -29,13 +29,24 @@ class CaseExpression : public ParsedExpression { public: DUCKDB_API CaseExpression(); - vector case_checks; - unique_ptr else_expr; +public: + const vector &CaseChecks() const { + return case_checks; + } + const ParsedExpression &Else() const { + return *else_expr; + } + vector &CaseChecksMutable() { + return case_checks; + } + 
unique_ptr &ElseMutable() { + return else_expr; + } public: string ToString() const override; - static bool Equal(const CaseExpression &a, const CaseExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; @@ -46,13 +57,17 @@ class CaseExpression : public ParsedExpression { template static string ToString(const T &entry) { string case_str = "CASE "; - for (auto &check : entry.case_checks) { + for (auto &check : entry.CaseChecks()) { case_str += " WHEN (" + check.when_expr->ToString() + ")"; case_str += " THEN (" + check.then_expr->ToString() + ")"; } - case_str += " ELSE " + entry.else_expr->ToString(); + case_str += " ELSE " + entry.Else().ToString(); case_str += " END"; return case_str; } + +private: + vector case_checks; + unique_ptr else_expr; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp index 7c48e0dde..8f6713bcb 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/cast_expression.hpp @@ -21,17 +21,27 @@ class CastExpression : public ParsedExpression { public: DUCKDB_API CastExpression(LogicalType target, unique_ptr child, bool try_cast = false); - //! The child of the cast expression - unique_ptr child; - //! The type to cast to - LogicalType cast_type; - //! 
Whether or not this is a try_cast expression - bool try_cast; - public: + const LogicalType &TargetType() const { + return cast_type; + } + LogicalType &TargetTypeMutable() { + return cast_type; + } + const ParsedExpression &Child() const { + return *child; + } + unique_ptr &ChildMutable() { + return child; + } + bool IsTryCast() const { + return try_cast; + } + string ToString() const override; - static bool Equal(const CastExpression &a, const CastExpression &b); + bool Equals(const ParsedExpression &other) const override; + hash_t Hash() const override; unique_ptr Copy() const override; @@ -41,10 +51,18 @@ class CastExpression : public ParsedExpression { public: template static string ToString(const T &entry) { - return (entry.try_cast ? "TRY_CAST(" : "CAST(") + entry.child->ToString() + " AS " + - entry.cast_type.ToString() + ")"; + return (entry.IsTryCast() ? "TRY_CAST(" : "CAST(") + entry.Child().ToString() + " AS " + + entry.TargetType().ToString() + ")"; } +private: + //! The child of the cast expression + unique_ptr child; + //! The type to cast to + LogicalType cast_type; + //! Whether or not this is a try_cast expression + bool try_cast; + private: CastExpression(); }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp index 6f3bcac6e..0c29a65f0 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/collate_expression.hpp @@ -20,21 +20,33 @@ class CollateExpression : public ParsedExpression { public: CollateExpression(string collation, unique_ptr child); - //! The child of the cast expression - unique_ptr child; - //! 
The collation clause - string collation; - public: + const ParsedExpression &Child() const { + return *child; + } + unique_ptr &ChildMutable() { + return child; + } + const string &Collation() const { + return collation; + } + string ToString() const override; - static bool Equal(const CollateExpression &a, const CollateExpression &b); + bool Equals(const ParsedExpression &other) const override; + hash_t Hash() const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + //! The child of the cast expression + unique_ptr child; + //! The collation clause + string collation; + private: CollateExpression(); }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp index 408dcaab3..39046ee18 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/columnref_expression.hpp @@ -22,29 +22,33 @@ class ColumnRefExpression : public ParsedExpression { public: //! Specify both the column and table name - ColumnRefExpression(string column_name, string table_name); + ColumnRefExpression(Identifier column_name, Identifier table_name); //! Specify both the column and table alias - ColumnRefExpression(string column_name, const BindingAlias &alias); + ColumnRefExpression(Identifier column_name, const BindingAlias &alias); //! Only specify the column name, the table name will be derived later - explicit ColumnRefExpression(string column_name); + explicit ColumnRefExpression(Identifier column_name); //! Specify a set of names - explicit ColumnRefExpression(vector column_names); - - //! The stack of names in order of which they appear (column_names[0].column_names[1].column_names[2]....) 
- vector column_names; + explicit ColumnRefExpression(vector column_names); public: + const vector &ColumnNames() const { + return column_names; + } + vector &ColumnNamesMutable() { + return column_names; + } + bool IsQualified() const; - const string &GetColumnName() const; - const string &GetTableName() const; + const Identifier &GetColumnName() const; + const Identifier &GetTableName() const; bool IsScalar() const override { return false; } - string GetName() const override; + Identifier GetName() const override; string ToString() const override; - static bool Equal(const ColumnRefExpression &a, const ColumnRefExpression &b); + bool Equals(const ParsedExpression &other) const override; hash_t Hash() const override; unique_ptr Copy() const override; @@ -52,6 +56,10 @@ class ColumnRefExpression : public ParsedExpression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + //! The stack of names in order of which they appear (column_names[0].column_names[1].column_names[2]....) 
+ vector column_names; + private: ColumnRefExpression(); }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp index 52b36ce47..ba45b442c 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/comparison_expression.hpp @@ -21,13 +21,22 @@ class ComparisonExpression : public ParsedExpression { DUCKDB_API ComparisonExpression(ExpressionType type, unique_ptr left, unique_ptr right); - unique_ptr left; - unique_ptr right; - public: + const ParsedExpression &Left() const { + return *left; + } + unique_ptr &LeftMutable() { + return left; + } + const ParsedExpression &Right() const { + return *right; + } + unique_ptr &RightMutable() { + return right; + } string ToString() const override; - static bool Equal(const ComparisonExpression &a, const ComparisonExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; @@ -40,7 +49,12 @@ class ComparisonExpression : public ParsedExpression { return StringUtil::Format("(%s %s %s)", left.ToString(), ExpressionTypeToOperator(type), right.ToString()); } +private: + unique_ptr left; + unique_ptr right; + private: explicit ComparisonExpression(ExpressionType type); + ComparisonExpression(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp index 3e343df9f..948400273 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/conjunction_expression.hpp @@ -24,26 +24,37 @@ class ConjunctionExpression : public ParsedExpression { DUCKDB_API ConjunctionExpression(ExpressionType type, unique_ptr left, unique_ptr right); - vector> children; - public: + const vector> 
&GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } void AddExpression(unique_ptr expr); string ToString() const override; - static bool Equal(const ConjunctionExpression &a, const ConjunctionExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + ConjunctionExpression(); + +private: + vector> children; + public: template static string ToString(const T &entry) { - string result = "(" + entry.children[0]->ToString(); - for (idx_t i = 1; i < entry.children.size(); i++) { - result += " " + ExpressionTypeToOperator(entry.GetExpressionType()) + " " + entry.children[i]->ToString(); + auto &children = entry.GetChildren(); + string result = "(" + children[0]->ToString(); + for (idx_t i = 1; i < children.size(); i++) { + result += " " + ExpressionTypeToOperator(entry.GetExpressionType()) + " " + children[i]->ToString(); } return result + ")"; } diff --git a/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp index 549fb10e8..861501a6a 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/constant_expression.hpp @@ -28,7 +28,7 @@ class ConstantExpression : public ParsedExpression { public: string ToString() const override; - static bool Equal(const ConstantExpression &a, const ConstantExpression &b); + bool Equals(const ParsedExpression &other) const override; hash_t Hash() const override; unique_ptr Copy() const override; diff --git a/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp index e7bc89b71..4790c36d5 100644 --- 
a/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/default_expression.hpp @@ -25,6 +25,7 @@ class DefaultExpression : public ParsedExpression { } string ToString() const override; + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; diff --git a/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp index faab6a4c9..d511f123a 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/function_expression.hpp @@ -8,53 +8,154 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/result_modifier.hpp" #include "duckdb/parser/keyword_helper.hpp" namespace duckdb { + +// An argument passed to a function, consisting of an optional name and an expression. +// The name is used for named arguments, and the expression is the value of the argument. +class FunctionArgument { +public: + // NOLINTNEXTLINE (allow implicit conversions) + FunctionArgument(unique_ptr expression_p) : expression(std::move(expression_p)) { + } + + FunctionArgument(Identifier name_p, unique_ptr expression_p) + : name(std::move(name_p)), expression(std::move(expression_p)) { + } + + const Identifier &GetName() const { + return name; + } + + bool HasName() const { + return !name.empty(); + } + + unique_ptr &GetExpressionMutable() { + return expression; + } + const ParsedExpression &GetExpression() const { + return *expression; + } + + FunctionArgument Copy() const { + return FunctionArgument(name, expression ? 
expression->Copy() : nullptr); + } + + bool Equals(const FunctionArgument &other) const { + if (name != other.name) { + return false; + } + if (expression && other.expression) { + return expression->Equals(*other.expression); + } + return !expression && !other.expression; + } + + hash_t Hash() const; + + string ToString() const { + if (name.empty()) { + return expression->ToString(); + } + return StringUtil::Format("%s := %s", SQLIdentifier(name), expression->ToString()); + } + + void Serialize(Serializer &serializer) const; + static FunctionArgument Deserialize(Deserializer &deserializer); + +private: + Identifier name; + unique_ptr expression; +}; + //! Represents a function call class FunctionExpression : public ParsedExpression { public: static constexpr const ExpressionClass TYPE = ExpressionClass::FUNCTION; public: - DUCKDB_API FunctionExpression(string catalog_name, string schema_name, const string &function_name, + DUCKDB_API FunctionExpression(Identifier catalog_name, Identifier schema_name, const Identifier &function_name, vector> children, unique_ptr filter = nullptr, unique_ptr order_bys = nullptr, bool distinct = false, bool is_operator = false, bool export_state = false); - DUCKDB_API FunctionExpression(const string &function_name, vector> children, + DUCKDB_API FunctionExpression(const Identifier &function_name, vector> children, + unique_ptr filter = nullptr, + unique_ptr order_bys = nullptr, bool distinct = false, + bool is_operator = false, bool export_state = false); + DUCKDB_API FunctionExpression(Identifier catalog_name, Identifier schema_name, const Identifier &function_name, + vector children, unique_ptr filter = nullptr, + unique_ptr order_bys = nullptr, bool distinct = false, + bool is_operator = false, bool export_state = false); + DUCKDB_API FunctionExpression(const Identifier &function_name, vector children, unique_ptr filter = nullptr, unique_ptr order_bys = nullptr, bool distinct = false, bool is_operator = false, bool export_state = 
false); - - //! Catalog of the function - string catalog; - //! Schema of the function - string schema; - //! Function name - string function_name; - //! Whether or not the function is an operator, only used for rendering - bool is_operator; - //! List of arguments to the function - vector> children; - //! Whether or not the aggregate function is distinct, only used for aggregates - bool distinct; - //! Expression representing a filter, only used for aggregates - unique_ptr filter; - //! Modifier representing an ORDER BY, only used for aggregates - unique_ptr order_bys; - //! whether this function should export its state or not - bool export_state; public: + const Identifier &Catalog() const { + return catalog; + } + Identifier &CatalogMutable() { + return catalog; + } + const Identifier &Schema() const { + return schema; + } + Identifier &SchemaMutable() { + return schema; + } + const Identifier &FunctionName() const { + return function_name; + } + Identifier &FunctionNameMutable() { + return function_name; + } + void SetFunctionName(string function_name_p) { + function_name = Identifier(std::move(function_name_p)); + } + bool IsOperator() const { + return is_operator; + } + bool &IsOperatorMutable() { + return is_operator; + } + bool Distinct() const { + return distinct; + } + bool &DistinctMutable() { + return distinct; + } + const unique_ptr &Filter() const { + return filter; + } + unique_ptr &FilterMutable() { + return filter; + } + const unique_ptr &OrderBy() const { + return order_bys; + } + unique_ptr &OrderByMutable() { + return order_bys; + } + bool ExportState() const { + return export_state; + } + bool &ExportStateMutable() { + return export_state; + } + string ToString() const override; unique_ptr Copy() const override; - static bool Equal(const FunctionExpression &a, const FunctionExpression &b); + bool Equals(const ParsedExpression &other) const override; hash_t Hash() const override; void Serialize(Serializer &serializer) const override; @@ -63,72 
+164,42 @@ class FunctionExpression : public ParsedExpression { void Verify() const override; //! Returns a pointer to the lambda expression, if the function has a lambda expression as a child, else nullptr. - optional_ptr IsLambdaFunction() const; + optional_ptr IsLambdaFunction(); -public: - template - static string ToString(const T &entry, const string &catalog, const string &schema, const string &function_name, - bool is_operator = false, bool distinct = false, BASE *filter = nullptr, - ORDER_MODIFIER *order_bys = nullptr, bool export_state = false, bool add_alias = false) { - if (is_operator) { - // built-in operator - D_ASSERT(!distinct); - if (entry.children.size() == 1) { - if (StringUtil::Contains(function_name, "__postfix")) { - return "((" + entry.children[0]->ToString() + ")" + - StringUtil::Replace(function_name, "__postfix", "") + ")"; - } else { - return function_name + "(" + entry.children[0]->ToString() + ")"; - } - } else if (entry.children.size() == 2) { - return StringUtil::Format("(%s %s %s)", entry.children[0]->ToString(), function_name, - entry.children[1]->ToString()); - } - } - // standard function call - string result; - if (!catalog.empty()) { - result += SQLIdentifier(catalog) + "."; - } - if (!schema.empty()) { - result += SQLIdentifier(schema) + "."; - } - result += SQLIdentifier(function_name); - result += "("; - if (distinct) { - result += "DISTINCT "; - } - result += StringUtil::Join(entry.children, entry.children.size(), ", ", [&](const unique_ptr &child) { - return child->GetAlias().empty() || !add_alias - ? 
child->ToString() - : StringUtil::Format("%s := %s", SQLIdentifier(child->GetAlias()), child->ToString()); - }); - // ordered aggregate - if (order_bys && !order_bys->orders.empty()) { - if (entry.children.empty()) { - result += ") WITHIN GROUP ("; - } - result += " ORDER BY "; - for (idx_t i = 0; i < order_bys->orders.size(); i++) { - if (i > 0) { - result += ", "; - } - result += order_bys->orders[i].ToString(); - } - } - result += ")"; + bool IsLegacyFunctionCall() const { + return is_legacy_function_call; + } - // filtered aggregate - if (filter) { - result += " FILTER (WHERE " + filter->ToString() + ")"; - } + const vector &GetArguments() const { + return arguments; + } + vector &GetArgumentsMutable() { + return arguments; + } - if (export_state) { - result += " EXPORT_STATE"; - } +private: + //! Catalog of the function + Identifier catalog; + //! Schema of the function + Identifier schema; + //! Function name + Identifier function_name; + //! Whether or not the function is an operator, only used for rendering + bool is_operator; + //! List of arguments to the function + vector arguments; + //! Whether or not the aggregate function is distinct, only used for aggregates + bool distinct; + //! Expression representing a filter, only used for aggregates + unique_ptr filter; + //! Modifier representing an ORDER BY, only used for aggregates + unique_ptr order_bys; + //! whether this function should export its state or not + bool export_state; - return result; - } + //! Whether this function is a legacy function call, which means it was parsed from a function call that does not + //! use the new function argument syntax. This is used to determine how to handle named arguments during binding. 
+ bool is_legacy_function_call = false; private: FunctionExpression(); diff --git a/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp index 08c02f835..9ac278865 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/lambda_expression.hpp @@ -28,33 +28,55 @@ class LambdaExpression : public ParsedExpression { LambdaExpression(vector named_parameters_p, unique_ptr expr); LambdaExpression(unique_ptr lhs, unique_ptr expr); - //! The syntax type. - LambdaSyntaxType syntax_type; - //! The LHS of a lambda expression or the JSON "->"-operator. We need the context - //! to determine if the LHS is a list of column references (lambda parameters) or an expression (JSON) - unique_ptr lhs; - //! The lambda or JSON expression (RHS) - unique_ptr expr; - //! Band-aid for conflicts between lambda binding and JSON binding. - unique_ptr copied_expr; - public: + LambdaSyntaxType GetLambdaSyntaxType() const { + return syntax_type; + } + LambdaSyntaxType &GetLambdaSyntaxTypeMutable() { + return syntax_type; + } + const ParsedExpression &Left() const { + return *lhs; + } + unique_ptr &LeftMutable() { + return lhs; + } + const ParsedExpression &Right() const { + return *expr; + } + unique_ptr &RightMutable() { + return expr; + } + unique_ptr &CopiedExprMutable() { + return copied_expr; + } + //! Returns a vector to the column references in the LHS expression, and fills the error message, //! if the LHS is not a valid lambda parameter list vector> ExtractColumnRefExpressions(string &error_message) const; //! Returns the error message for an invalid lambda parameter list static string InvalidParametersErrorMessage(); //! 
Returns true, if the column_name is a lambda parameter name - static bool IsLambdaParameter(const vector> &lambda_params, const string &column_name); + static bool IsLambdaParameter(const vector &lambda_params, const Identifier &column_name); string ToString() const override; - static bool Equal(const LambdaExpression &a, const LambdaExpression &b); - hash_t Hash() const override; + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + //! The syntax type. + LambdaSyntaxType syntax_type; + //! The LHS of a lambda expression or the JSON "->"-operator. We need the context + //! to determine if the LHS is a list of column references (lambda parameters) or an expression (JSON) + unique_ptr lhs; + //! The lambda or JSON expression (RHS) + unique_ptr expr; + //! Band-aid for conflicts between lambda binding and JSON binding. + unique_ptr copied_expr; + private: LambdaExpression(); }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/lambdaref_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/lambdaref_expression.hpp index 6fafaa5eb..63f767ce7 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/lambdaref_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/lambdaref_expression.hpp @@ -23,25 +23,40 @@ class LambdaRefExpression : public ParsedExpression { public: //! Constructs a LambdaRefExpression from a lambda_idx and a column_name. We do not specify a table name, //! because we use dummy tables to bind lambda parameters - LambdaRefExpression(idx_t lambda_idx, string column_name_p); - - //! The index of the lambda parameter in the lambda_bindings vector - idx_t lambda_idx; - //! 
The name of the lambda parameter (in a specific Binding in lambda_bindings) - string column_name; + LambdaRefExpression(idx_t lambda_idx, Identifier column_name_p); public: + idx_t LambdaIndex() const { + return lambda_idx; + } + idx_t &LambdaIndexMutable() { + return lambda_idx; + } + const Identifier &ColumnName() const { + return column_name; + } + Identifier &ColumnNameMutable() { + return column_name; + } + bool IsScalar() const override; - string GetName() const override; + Identifier GetName() const override; string ToString() const override; + bool Equals(const ParsedExpression &other) const override; hash_t Hash() const override; unique_ptr Copy() const override; //! Traverses the lambda_bindings to find a matching binding for the column_name static unique_ptr FindMatchingBinding(optional_ptr> &lambda_bindings, - const string ¶meter_name); + const Identifier ¶meter_name); void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! The index of the lambda parameter in the lambda_bindings vector + idx_t lambda_idx; + //! 
The name of the lambda parameter (in a specific Binding in lambda_bindings) + Identifier column_name; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp index 9ef98bc4b..07883f7d9 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/operator_expression.hpp @@ -23,53 +23,61 @@ class OperatorExpression : public ParsedExpression { unique_ptr right = nullptr); DUCKDB_API OperatorExpression(ExpressionType type, vector> children); - vector> children; - public: + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } string ToString() const override; - static bool Equal(const OperatorExpression &a, const OperatorExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + OperatorExpression(); + public: template static string ToString(const T &entry) { + auto &children = entry.GetChildren(); auto op = ExpressionTypeToOperator(entry.GetExpressionType()); if (!op.empty()) { // use the operator string to represent the operator - D_ASSERT(entry.children.size() == 2); - return entry.children[0]->ToString() + " " + op + " " + entry.children[1]->ToString(); + D_ASSERT(children.size() == 2); + return children[0]->ToString() + " " + op + " " + children[1]->ToString(); } switch (entry.GetExpressionType()) { case ExpressionType::COMPARE_IN: case ExpressionType::COMPARE_NOT_IN: { string op_type = entry.GetExpressionType() == ExpressionType::COMPARE_IN ? 
" IN " : " NOT IN "; - string in_child = entry.children[0]->ToString(); + string in_child = children[0]->ToString(); string child_list = "("; - for (idx_t i = 1; i < entry.children.size(); i++) { + for (idx_t i = 1; i < children.size(); i++) { if (i > 1) { child_list += ", "; } - child_list += entry.children[i]->ToString(); + child_list += children[i]->ToString(); } child_list += ")"; return "(" + in_child + op_type + child_list + ")"; } case ExpressionType::OPERATOR_UNPACK: { - return StringUtil::Format("UNPACK(%s)", entry.children[0]->ToString()); + return StringUtil::Format("UNPACK(%s)", children[0]->ToString()); } case ExpressionType::OPERATOR_TRY: { - return StringUtil::Format("TRY(%s)", entry.children[0]->ToString()); + return StringUtil::Format("TRY(%s)", children[0]->ToString()); } case ExpressionType::OPERATOR_NOT: { string result = "("; result += ExpressionTypeToString(entry.GetExpressionType()); result += " "; - result += StringUtil::Join(entry.children, entry.children.size(), ", ", + result += StringUtil::Join(children, children.size(), ", ", [](const unique_ptr &child) { return child->ToString(); }); result += ")"; return result; @@ -78,49 +86,48 @@ class OperatorExpression : public ParsedExpression { case ExpressionType::OPERATOR_COALESCE: { string result = ExpressionTypeToString(entry.GetExpressionType()); result += "("; - result += StringUtil::Join(entry.children, entry.children.size(), ", ", + result += StringUtil::Join(children, children.size(), ", ", [](const unique_ptr &child) { return child->ToString(); }); result += ")"; return result; } case ExpressionType::OPERATOR_IS_NULL: - return "(" + entry.children[0]->ToString() + " IS NULL)"; + return "(" + children[0]->ToString() + " IS NULL)"; case ExpressionType::OPERATOR_IS_NOT_NULL: - return "(" + entry.children[0]->ToString() + " IS NOT NULL)"; + return "(" + children[0]->ToString() + " IS NOT NULL)"; case ExpressionType::ARRAY_EXTRACT: - return entry.children[0]->ToString() + "[" + 
entry.children[1]->ToString() + "]"; + return children[0]->ToString() + "[" + children[1]->ToString() + "]"; case ExpressionType::ARRAY_SLICE: { - string begin = entry.children[1]->ToString(); + string begin = children[1]->ToString(); if (begin == "[]") { begin = ""; } - string end = entry.children[2]->ToString(); + string end = children[2]->ToString(); if (end == "[]") { - if (entry.children.size() == 4) { + if (children.size() == 4) { end = "-"; } else { end = ""; } } - if (entry.children.size() == 4) { - return entry.children[0]->ToString() + "[" + begin + ":" + end + ":" + entry.children[3]->ToString() + - "]"; + if (children.size() == 4) { + return children[0]->ToString() + "[" + begin + ":" + end + ":" + children[3]->ToString() + "]"; } - return entry.children[0]->ToString() + "[" + begin + ":" + end + "]"; + return children[0]->ToString() + "[" + begin + ":" + end + "]"; } case ExpressionType::STRUCT_EXTRACT: { - if (entry.children[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + if (children[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { return string(); } - auto child_string = entry.children[1]->ToString(); + auto child_string = children[1]->ToString(); D_ASSERT(child_string.size() >= 3); D_ASSERT(child_string[0] == '\'' && child_string[child_string.size() - 1] == '\''); - return StringUtil::Format("(%s).%s", entry.children[0]->ToString(), + return StringUtil::Format("(%s).%s", children[0]->ToString(), SQLIdentifier(child_string.substr(1, child_string.size() - 2))); } case ExpressionType::ARRAY_CONSTRUCTOR: { string result = "(ARRAY["; - result += StringUtil::Join(entry.children, entry.children.size(), ", ", + result += StringUtil::Join(children, children.size(), ", ", [](const unique_ptr &child) { return child->ToString(); }); result += "])"; return result; @@ -129,6 +136,9 @@ class OperatorExpression : public ParsedExpression { throw InternalException("Unrecognized operator type"); } } + +private: + vector> children; }; } // 
namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp index ac97ecbbf..8b242f79e 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/parameter_expression.hpp @@ -34,9 +34,13 @@ class ParameterExpression : public ParsedExpression { public: ParameterExpression(); - string identifier; - public: + const duckdb::Identifier &Identifier() const { + return identifier; + } + duckdb::Identifier &IdentifierMutable() { + return identifier; + } bool IsScalar() const override { return true; } @@ -46,12 +50,15 @@ class ParameterExpression : public ParsedExpression { string ToString() const override; - static bool Equal(const ParameterExpression &a, const ParameterExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; hash_t Hash() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + duckdb::Identifier identifier; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp index 9943551b0..cb2fe529c 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/positional_reference_expression.hpp @@ -18,22 +18,29 @@ class PositionalReferenceExpression : public ParsedExpression { public: DUCKDB_API explicit PositionalReferenceExpression(idx_t index); - idx_t index; - public: + idx_t Index() const { + return index; + } + idx_t &IndexMutable() { + return index; + } bool IsScalar() const override { return false; } string ToString() const override; - static bool Equal(const 
PositionalReferenceExpression &a, const PositionalReferenceExpression &b); + bool Equals(const ParsedExpression &other) const override; unique_ptr Copy() const override; hash_t Hash() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); +private: + idx_t index; + private: PositionalReferenceExpression(); }; diff --git a/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp index dbe7c9cfc..a8bfbfde5 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/star_expression.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/parser/qualified_name_set.hpp" @@ -22,25 +23,50 @@ class StarExpression : public ParsedExpression { static constexpr const ExpressionClass TYPE = ExpressionClass::STAR; public: - explicit StarExpression(string relation_name = string()); - - //! The relation name in case of tbl.*, or empty if this is a normal * - string relation_name; - //! List of columns to exclude from the STAR expression - qualified_column_set_t exclude_list; - //! List of columns to replace with another expression - case_insensitive_map_t> replace_list; - //! List of columns to rename - qualified_column_map_t rename_list; - //! The expression to select the columns (regular expression or list) - unique_ptr expr; - //! 
Whether or not this is a COLUMNS expression - bool columns = false; + explicit StarExpression(Identifier relation_name = Identifier()); public: + const Identifier &RelationName() const { + return relation_name; + } + Identifier &RelationNameMutable() { + return relation_name; + } + const qualified_column_set_t &ExcludeList() const { + return exclude_list; + } + qualified_column_set_t &ExcludeListMutable() { + return exclude_list; + } + const identifier_map_t> &ReplaceList() const { + return replace_list; + } + identifier_map_t> &ReplaceListMutable() { + return replace_list; + } + const qualified_column_map_t &RenameList() const { + return rename_list; + } + qualified_column_map_t &RenameListMutable() { + return rename_list; + } + const unique_ptr &Expression() const { + return expr; + } + unique_ptr &ExpressionMutable() { + return expr; + } + bool IsColumns() const { + return columns; + } + bool &IsColumnsMutable() { + return columns; + } + string ToString() const override; - static bool Equal(const StarExpression &a, const StarExpression &b); + bool Equals(const ParsedExpression &other) const override; + hash_t Hash() const override; static bool IsStar(const ParsedExpression &a); static bool IsColumns(const ParsedExpression &a); static bool IsColumnsUnpacked(const ParsedExpression &a); @@ -48,19 +74,33 @@ class StarExpression : public ParsedExpression { unique_ptr Copy() const override; static unique_ptr - DeserializeStarExpression(string &&relation_name, const case_insensitive_set_t &exclude_list, - case_insensitive_map_t> &&replace_list, bool columns, + DeserializeStarExpression(Identifier &&relation_name, const identifier_set_t &exclude_list, + identifier_map_t> &&replace_list, bool columns, unique_ptr expr, bool unpacked, const qualified_column_set_t &qualified_exclude_list, - qualified_column_map_t &&rename_list); + qualified_column_map_t &&rename_list); void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer 
&deserializer); +private: + //! The relation name in case of tbl.*, or empty if this is a normal * + Identifier relation_name; + //! List of columns to exclude from the STAR expression + qualified_column_set_t exclude_list; + //! List of columns to replace with another expression + identifier_map_t> replace_list; + //! List of columns to rename + qualified_column_map_t rename_list; + //! The expression to select the columns (regular expression or list) + unique_ptr expr; + //! Whether or not this is a COLUMNS expression + bool columns = false; + public: // these methods exist for backwards compatibility of (de)serialization - StarExpression(const case_insensitive_set_t &exclude_list, qualified_column_set_t qualified_set); + StarExpression(const identifier_set_t &exclude_list, qualified_column_set_t qualified_set); - case_insensitive_set_t SerializedExcludeList() const; + identifier_set_t SerializedExcludeList() const; qualified_column_set_t SerializedQualifiedExcludeList() const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp index b01bda765..45dc0a90e 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/subquery_expression.hpp @@ -22,17 +22,32 @@ class SubqueryExpression : public ParsedExpression { public: SubqueryExpression(); - //! The actual subquery - unique_ptr subquery; - //! The subquery type - SubqueryType subquery_type; - //! the child expression to compare with (in case of IN, ANY, ALL operators, empty for EXISTS queries and scalar - //! subquery) - unique_ptr child; - //! 
The comparison type of the child expression with the subquery (in case of ANY, ALL operators), empty otherwise - ExpressionType comparison_type; - public: + const unique_ptr &Subquery() const { + return subquery; + } + unique_ptr &SubqueryMutable() { + return subquery; + } + SubqueryType GetSubqueryType() const { + return subquery_type; + } + SubqueryType &GetSubqueryTypeMutable() { + return subquery_type; + } + const unique_ptr &GetChild() const { + return child; + } + unique_ptr &GetChildMutable() { + return child; + } + ExpressionType GetComparisonType() const { + return comparison_type; + } + ExpressionType &GetComparisonTypeMutable() { + return comparison_type; + } + bool HasSubquery() const override { return true; } @@ -42,11 +57,23 @@ class SubqueryExpression : public ParsedExpression { string ToString() const override; - static bool Equal(const SubqueryExpression &a, const SubqueryExpression &b); + bool Equals(const ParsedExpression &other) const override; + hash_t Hash() const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! The actual subquery + unique_ptr subquery; + //! The subquery type + SubqueryType subquery_type; + //! the child expression to compare with (in case of IN, ANY, ALL operators, empty for EXISTS queries and scalar + //! subquery) + unique_ptr child; + //! 
The comparison type of the child expression with the subquery (in case of ANY, ALL operators), empty otherwise + ExpressionType comparison_type; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/expression/type_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/type_expression.hpp index e26e2fc4e..1cf7b52e1 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/type_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/type_expression.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/keyword_helper.hpp" @@ -18,19 +19,27 @@ class TypeExpression : public ParsedExpression { public: static constexpr const ExpressionClass TYPE = ExpressionClass::TYPE; - TypeExpression(string catalog, string schema, string type_name, vector> children); - TypeExpression(string type_name, vector> children); + TypeExpression(Identifier catalog, Identifier schema, Identifier type_name, + vector> children); + TypeExpression(Identifier type_name, vector> children); + TypeExpression(const string &type_name, vector> children); public: - const string &GetTypeName() const { + const Identifier &GetTypeName() const { return type_name; } - const string &GetSchema() const { + const Identifier &GetSchema() const { return schema; } - const string &GetCatalog() const { + void SetSchema(Identifier new_schema) { + schema = std::move(new_schema); + } + const Identifier &GetCatalog() const { return catalog; } + void SetCatalog(Identifier new_catalog) { + catalog = std::move(new_catalog); + } const vector> &GetChildren() const { return children; } @@ -43,7 +52,7 @@ class TypeExpression : public ParsedExpression { unique_ptr Copy() const override; - static bool Equal(const TypeExpression &a, const TypeExpression &b); + bool Equals(const ParsedExpression &other) const override; hash_t Hash() const override; void 
Serialize(Serializer &serializer) const override; @@ -55,9 +64,9 @@ class TypeExpression : public ParsedExpression { TypeExpression(); //! Qualified name parts - string catalog; - string schema; - string type_name; + Identifier catalog; + Identifier schema; + Identifier type_name; //! Children of the type expression (e.g. type parameters) vector> children; diff --git a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp index b8dfceaba..ae76e407c 100644 --- a/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/expression/window_expression.hpp @@ -8,8 +8,10 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/expression/function_expression.hpp" namespace duckdb { @@ -41,40 +43,6 @@ class WindowExpression : public ParsedExpression { public: WindowExpression(const string &catalog_name, const string &schema, const string &function_name); - //! Catalog of the aggregate function - string catalog; - //! Schema of the aggregate function - string schema; - //! Name of the aggregate function - string function_name; - //! The child expression of the main window function - vector> children; - //! The set of expressions to partition by - vector> partitions; - //! The set of ordering clauses - vector orders; - //! Expression representing a filter, only used for aggregates - unique_ptr filter_expr; - //! True if we parsed IGNORE/RESPECT NULLS - bool has_ignore_nulls = false; - //! True to ignore NULL values - bool ignore_nulls = false; - //! Whether or not the aggregate function is distinct, only used for aggregates - bool distinct = false; - //! The window boundaries - WindowBoundary start = WindowBoundary::INVALID; - WindowBoundary end = WindowBoundary::INVALID; - //! 
The EXCLUDE clause - WindowExcludeMode exclude_clause = WindowExcludeMode::NO_OTHER; - - unique_ptr start_expr; - unique_ptr end_expr; - - //! The set of argument ordering clauses - //! These are distinct from the frame ordering clauses e.g., the "x" in - //! FIRST_VALUE(a ORDER BY x) OVER (PARTITION BY p ORDER BY s) - vector arg_orders; - public: bool IsWindow() const override { return true; @@ -83,9 +51,10 @@ class WindowExpression : public ParsedExpression { //! Convert the Expression to a String string ToString() const override; - static bool Equal(const WindowExpression &a, const WindowExpression &b); + bool Equals(const ParsedExpression &other) const override; + hash_t Hash() const override; - bool HasBoundedParts(); + bool HasBoundedParts() const; unique_ptr Copy() const override; @@ -98,6 +67,105 @@ class WindowExpression : public ParsedExpression { static ExpressionType WindowToExpressionType(const string &fun_name); public: + const Identifier &Catalog() const { + return catalog; + } + Identifier &CatalogMutable() { + return catalog; + } + const Identifier &Schema() const { + return schema; + } + Identifier &SchemaMutable() { + return schema; + } + const Identifier &FunctionName() const { + return function_name; + } + Identifier &FunctionNameMutable() { + return function_name; + } + const vector> &Partitions() const { + return partitions; + } + vector> &PartitionsMutable() { + return partitions; + } + const vector &OrderBy() const { + return orders; + } + vector &OrderByMutable() { + return orders; + } + const unique_ptr &Filter() const { + return filter_expr; + } + unique_ptr &FilterMutable() { + return filter_expr; + } + bool HasIgnoreNulls() const { + return has_ignore_nulls; + } + bool &HasIgnoreNullsMutable() { + return has_ignore_nulls; + } + bool IgnoreNulls() const { + return ignore_nulls; + } + bool &IgnoreNullsMutable() { + return ignore_nulls; + } + bool Distinct() const { + return distinct; + } + bool &DistinctMutable() { + return distinct; + } 
+ WindowBoundary WindowStart() const { + return start; + } + WindowBoundary &WindowStartMutable() { + return start; + } + WindowBoundary WindowEnd() const { + return end; + } + WindowBoundary &WindowEndMutable() { + return end; + } + WindowExcludeMode WindowExclude() const { + return exclude_clause; + } + WindowExcludeMode &WindowExcludeMutable() { + return exclude_clause; + } + const unique_ptr &StartExpr() const { + return start_expr; + } + unique_ptr &StartExprMutable() { + return start_expr; + } + const unique_ptr &EndExpr() const { + return end_expr; + } + unique_ptr &EndExprMutable() { + return end_expr; + } + const vector &ArgOrders() const { + return arg_orders; + } + vector &ArgOrdersMutable() { + return arg_orders; + } + + const vector &GetArguments() const { + return arguments; + } + + vector &GetArgumentsMutable() { + return arguments; + } + static inline string ToUnits(const WindowBoundary boundary, const WindowBoundary rows, const WindowBoundary range, const WindowBoundary groups) { if (boundary == rows) { @@ -114,27 +182,42 @@ class WindowExpression : public ParsedExpression { // Start with function call string result = schema.empty() ? function_name : schema + "." + function_name; result += "("; - if (entry.children.size()) { - // Only one DISTINCT is allowed (on the first argument) - int distincts = entry.distinct ? 0 : 1; - result += StringUtil::Join(entry.children, entry.children.size(), ", ", [&](const unique_ptr &child) { - return (distincts++ ? "" : "DISTINCT ") + child->ToString(); - }); + + if constexpr (std::is_same_v) { + auto &children = entry.GetArguments(); + if (children.size()) { + // Only one DISTINCT is allowed (on the first argument) + int distincts = entry.Distinct() ? 0 : 1; + result += StringUtil::Join(children, children.size(), ", ", [&](const FunctionArgument &child) { + return (distincts++ ? 
"" : "DISTINCT ") + child.ToString(); + }); + } + } else { + auto &children = entry.GetChildren(); + if (children.size()) { + // Only one DISTINCT is allowed (on the first argument) + int distincts = entry.Distinct() ? 0 : 1; + result += StringUtil::Join(children, children.size(), ", ", [&](const unique_ptr &child) { + return (distincts++ ? "" : "DISTINCT ") + child->ToString(); + }); + } } + // ORDER BY arguments - if (!entry.arg_orders.empty()) { + auto &arg_orders = entry.ArgOrders(); + if (!arg_orders.empty()) { result += " ORDER BY "; - result += StringUtil::Join(entry.arg_orders, entry.arg_orders.size(), ", ", + result += StringUtil::Join(arg_orders, arg_orders.size(), ", ", [](const ORDER_NODE &order) { return order.ToString(); }); } // IGNORE NULLS - if (entry.ignore_nulls) { + if (entry.IgnoreNulls()) { result += " IGNORE NULLS"; } // FILTER - if (entry.filter_expr) { - result += ") FILTER (WHERE " + entry.filter_expr->ToString(); + if (entry.Filter()) { + result += ") FILTER (WHERE " + entry.Filter()->ToString(); } // Over clause @@ -142,50 +225,56 @@ class WindowExpression : public ParsedExpression { string sep; // Partitions - if (!entry.partitions.empty()) { + auto &partitions = entry.Partitions(); + if (!partitions.empty()) { result += "PARTITION BY "; - result += StringUtil::Join(entry.partitions, entry.partitions.size(), ", ", + result += StringUtil::Join(partitions, partitions.size(), ", ", [](const unique_ptr &partition) { return partition->ToString(); }); sep = " "; } // Orders - if (!entry.orders.empty()) { + auto &orders = entry.OrderBy(); + if (!orders.empty()) { result += sep; result += "ORDER BY "; - result += StringUtil::Join(entry.orders, entry.orders.size(), ", ", - [](const ORDER_NODE &order) { return order.ToString(); }); + result += + StringUtil::Join(orders, orders.size(), ", ", [](const ORDER_NODE &order) { return order.ToString(); }); sep = " "; } // Rows/Range string units = "ROWS"; string from; - switch (entry.start) { + auto 
window_start = entry.WindowStart(); + auto window_end = entry.WindowEnd(); + auto &start_expr = entry.StartExpr(); + auto &end_expr = entry.EndExpr(); + switch (window_start) { case WindowBoundary::CURRENT_ROW_RANGE: case WindowBoundary::CURRENT_ROW_ROWS: case WindowBoundary::CURRENT_ROW_GROUPS: from = "CURRENT ROW"; - units = ToUnits(entry.start, WindowBoundary::CURRENT_ROW_ROWS, WindowBoundary::CURRENT_ROW_RANGE, + units = ToUnits(window_start, WindowBoundary::CURRENT_ROW_ROWS, WindowBoundary::CURRENT_ROW_RANGE, WindowBoundary::CURRENT_ROW_GROUPS); break; case WindowBoundary::UNBOUNDED_PRECEDING: - if (entry.end != WindowBoundary::CURRENT_ROW_RANGE) { + if (window_end != WindowBoundary::CURRENT_ROW_RANGE) { from = "UNBOUNDED PRECEDING"; } break; case WindowBoundary::EXPR_PRECEDING_ROWS: case WindowBoundary::EXPR_PRECEDING_RANGE: case WindowBoundary::EXPR_PRECEDING_GROUPS: - from = entry.start_expr->ToString() + " PRECEDING"; - units = ToUnits(entry.start, WindowBoundary::EXPR_PRECEDING_ROWS, WindowBoundary::EXPR_PRECEDING_RANGE, + from = start_expr->ToString() + " PRECEDING"; + units = ToUnits(window_start, WindowBoundary::EXPR_PRECEDING_ROWS, WindowBoundary::EXPR_PRECEDING_RANGE, WindowBoundary::EXPR_PRECEDING_GROUPS); break; case WindowBoundary::EXPR_FOLLOWING_ROWS: case WindowBoundary::EXPR_FOLLOWING_RANGE: case WindowBoundary::EXPR_FOLLOWING_GROUPS: - from = entry.start_expr->ToString() + " FOLLOWING"; - units = ToUnits(entry.start, WindowBoundary::EXPR_FOLLOWING_ROWS, WindowBoundary::EXPR_FOLLOWING_RANGE, + from = start_expr->ToString() + " FOLLOWING"; + units = ToUnits(window_start, WindowBoundary::EXPR_FOLLOWING_ROWS, WindowBoundary::EXPR_FOLLOWING_RANGE, WindowBoundary::EXPR_FOLLOWING_GROUPS); break; case WindowBoundary::UNBOUNDED_FOLLOWING: @@ -194,9 +283,9 @@ class WindowExpression : public ParsedExpression { } string to; - switch (entry.end) { + switch (window_end) { case WindowBoundary::CURRENT_ROW_RANGE: - if (entry.start != 
WindowBoundary::UNBOUNDED_PRECEDING) { + if (window_start != WindowBoundary::UNBOUNDED_PRECEDING) { to = "CURRENT ROW"; units = "RANGE"; } @@ -204,7 +293,7 @@ class WindowExpression : public ParsedExpression { case WindowBoundary::CURRENT_ROW_ROWS: case WindowBoundary::CURRENT_ROW_GROUPS: to = "CURRENT ROW"; - units = ToUnits(entry.end, WindowBoundary::CURRENT_ROW_ROWS, WindowBoundary::CURRENT_ROW_RANGE, + units = ToUnits(window_end, WindowBoundary::CURRENT_ROW_ROWS, WindowBoundary::CURRENT_ROW_RANGE, WindowBoundary::CURRENT_ROW_GROUPS); break; case WindowBoundary::UNBOUNDED_PRECEDING: @@ -216,21 +305,22 @@ class WindowExpression : public ParsedExpression { case WindowBoundary::EXPR_PRECEDING_ROWS: case WindowBoundary::EXPR_PRECEDING_RANGE: case WindowBoundary::EXPR_PRECEDING_GROUPS: - to = entry.end_expr->ToString() + " PRECEDING"; - units = ToUnits(entry.end, WindowBoundary::EXPR_PRECEDING_ROWS, WindowBoundary::EXPR_PRECEDING_RANGE, + to = end_expr->ToString() + " PRECEDING"; + units = ToUnits(window_end, WindowBoundary::EXPR_PRECEDING_ROWS, WindowBoundary::EXPR_PRECEDING_RANGE, WindowBoundary::EXPR_PRECEDING_GROUPS); break; case WindowBoundary::EXPR_FOLLOWING_ROWS: case WindowBoundary::EXPR_FOLLOWING_RANGE: case WindowBoundary::EXPR_FOLLOWING_GROUPS: - to = entry.end_expr->ToString() + " FOLLOWING"; - units = ToUnits(entry.end, WindowBoundary::EXPR_FOLLOWING_ROWS, WindowBoundary::EXPR_FOLLOWING_RANGE, + to = end_expr->ToString() + " FOLLOWING"; + units = ToUnits(window_end, WindowBoundary::EXPR_FOLLOWING_ROWS, WindowBoundary::EXPR_FOLLOWING_RANGE, WindowBoundary::EXPR_FOLLOWING_GROUPS); break; case WindowBoundary::INVALID: throw InternalException("Unrecognized TO in WindowExpression"); } - if (entry.exclude_clause != WindowExcludeMode::NO_OTHER) { + auto exclude_clause = entry.WindowExclude(); + if (exclude_clause != WindowExcludeMode::NO_OTHER) { // if we have an explicit EXCLUDE we always need to fill in from/to if (from.empty()) { from = "UNBOUNDED 
PRECEDING"; @@ -257,10 +347,10 @@ class WindowExpression : public ParsedExpression { result += to; } - if (entry.exclude_clause != WindowExcludeMode::NO_OTHER) { + if (exclude_clause != WindowExcludeMode::NO_OTHER) { result += " EXCLUDE "; } - switch (entry.exclude_clause) { + switch (exclude_clause) { case WindowExcludeMode::CURRENT_ROW: result += "CURRENT ROW"; break; @@ -279,11 +369,53 @@ class WindowExpression : public ParsedExpression { return result; } + bool IsLegacyFunctionCall() const { + return is_legacy_function_call; + } + private: - // Backwards-compatible serialization interface - WindowExpression(ExpressionType type, vector> children, - unique_ptr offset_expr, unique_ptr default_expr); + //! Catalog of the aggregate function + Identifier catalog; + //! Schema of the aggregate function + Identifier schema; + //! Name of the aggregate function + Identifier function_name; + //! The child expression of the main window function + vector arguments; + //! The set of expressions to partition by + vector> partitions; + //! The set of ordering clauses + vector orders; + //! Expression representing a filter, only used for aggregates + unique_ptr filter_expr; + //! True if we parsed IGNORE/RESPECT NULLS + bool has_ignore_nulls = false; + //! True to ignore NULL values + bool ignore_nulls = false; + //! Whether or not the aggregate function is distinct, only used for aggregates + bool distinct = false; + //! The window boundaries + WindowBoundary start = WindowBoundary::INVALID; + WindowBoundary end = WindowBoundary::INVALID; + //! The EXCLUDE clause + WindowExcludeMode exclude_clause = WindowExcludeMode::NO_OTHER; + + unique_ptr start_expr; + unique_ptr end_expr; + //! The set of argument ordering clauses + //! These are distinct from the frame ordering clauses e.g., the "x" in + //! FIRST_VALUE(a ORDER BY x) OVER (PARTITION BY p ORDER BY s) + vector arg_orders; + + //! 
Whether this function is a legacy function call, which means it was parsed from a function call that does not + //! use the new function argument syntax. This is used to determine how to handle named arguments during binding. + bool is_legacy_function_call = false; + +private: + WindowExpression(); + + // Backwards-compatible serialization interface // Remove LEAD/LAG offset/default vector> SerializedChildren(Serializer &serializer) const; unique_ptr SerializedOffset(Serializer &serializer) const; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_database_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_database_info.hpp index fc7be14d2..ceb928613 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_database_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_database_info.hpp @@ -10,6 +10,7 @@ #include "duckdb/parser/parsed_data/alter_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { enum class AlterDatabaseType : uint8_t { RENAME_DATABASE = 0 }; @@ -17,7 +18,7 @@ enum class AlterDatabaseType : uint8_t { RENAME_DATABASE = 0 }; struct AlterDatabaseInfo : public AlterInfo { public: explicit AlterDatabaseInfo(AlterDatabaseType alter_database_type); - AlterDatabaseInfo(AlterDatabaseType alter_database_type, string catalog_p, OnEntryNotFound if_not_found); + AlterDatabaseInfo(AlterDatabaseType alter_database_type, Identifier catalog_p, OnEntryNotFound if_not_found); ~AlterDatabaseInfo() override; AlterDatabaseType alter_database_type; @@ -35,9 +36,9 @@ struct AlterDatabaseInfo : public AlterInfo { struct RenameDatabaseInfo : public AlterDatabaseInfo { public: RenameDatabaseInfo(); - RenameDatabaseInfo(string catalog_p, string new_name_p, OnEntryNotFound if_not_found); + RenameDatabaseInfo(Identifier catalog_p, Identifier new_name_p, OnEntryNotFound if_not_found); - string new_name; + Identifier new_name; public: unique_ptr Copy() const override; diff --git 
a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp index 5ba8680fc..3cea24761 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" @@ -33,14 +34,14 @@ enum class AlterBindMode { BIND_ON_ALTER, SKIP_BINDING }; struct AlterEntryData { AlterEntryData() { } - AlterEntryData(string catalog_p, string schema_p, string name_p, OnEntryNotFound if_not_found) + AlterEntryData(Identifier catalog_p, Identifier schema_p, Identifier name_p, OnEntryNotFound if_not_found) : catalog(std::move(catalog_p)), schema(std::move(schema_p)), name(std::move(name_p)), if_not_found(if_not_found) { } - string catalog; - string schema; - string name; + Identifier catalog; + Identifier schema; + Identifier name; OnEntryNotFound if_not_found; }; @@ -49,18 +50,18 @@ struct AlterInfo : public ParseInfo { static constexpr const ParseInfoType TYPE = ParseInfoType::ALTER_INFO; public: - AlterInfo(AlterType type, string catalog, string schema, string name, OnEntryNotFound if_not_found); + AlterInfo(AlterType type, Identifier catalog, Identifier schema, Identifier name, OnEntryNotFound if_not_found); ~AlterInfo() override; AlterType type; //! if exists OnEntryNotFound if_not_found; //! Catalog name to alter - string catalog; + Identifier catalog; //! Schema name to alter - string schema; + Identifier schema; //! Entry name to alter - string name; + Identifier name; //! Allow altering internal entries bool allow_internal; //! 
Determine whether to skip Bind @@ -76,8 +77,8 @@ struct AlterInfo : public ParseInfo { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); - virtual string GetColumnName() const { - return ""; + virtual Identifier GetColumnName() const { + return Identifier(); }; AlterEntryData GetAlterEntryData() const; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp index 0730f7642..78a9bcd05 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp @@ -13,6 +13,7 @@ #include "duckdb/parser/constraint.hpp" #include "duckdb/parser/result_modifier.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { enum class AlterForeignKeyType : uint8_t { AFT_ADD = 0, AFT_DELETE = 1 }; @@ -21,16 +22,17 @@ enum class AlterForeignKeyType : uint8_t { AFT_ADD = 0, AFT_DELETE = 1 }; // Change Ownership //===--------------------------------------------------------------------===// struct ChangeOwnershipInfo : public AlterInfo { - ChangeOwnershipInfo(CatalogType entry_catalog_type, string entry_catalog, string entry_schema, string entry_name, - string owner_schema, string owner_name, OnEntryNotFound if_not_found); + ChangeOwnershipInfo(CatalogType entry_catalog_type, Identifier entry_catalog, Identifier entry_schema, + Identifier entry_name, Identifier owner_schema, Identifier owner_name, + OnEntryNotFound if_not_found); // Catalog type refers to the entry type, since this struct is usually built from an // ALTER . OWNED BY . 
statement // here it is only possible to know the type of who is to be owned CatalogType entry_catalog_type; - string owner_schema; - string owner_name; + Identifier owner_schema; + Identifier owner_name; public: CatalogType GetCatalogType() const override; @@ -47,8 +49,8 @@ struct ChangeOwnershipInfo : public AlterInfo { // Set Comment //===--------------------------------------------------------------------===// struct SetCommentInfo : public AlterInfo { - SetCommentInfo(CatalogType entry_catalog_type, string entry_catalog, string entry_schema, string entry_name, - Value new_comment_value_p, OnEntryNotFound if_not_found); + SetCommentInfo(CatalogType entry_catalog_type, Identifier entry_catalog, Identifier entry_schema, + Identifier entry_name, Value new_comment_value_p, OnEntryNotFound if_not_found); CatalogType entry_catalog_type; Value comment_value; @@ -109,13 +111,13 @@ struct AlterTableInfo : public AlterInfo { // RenameColumnInfo //===--------------------------------------------------------------------===// struct RenameColumnInfo : public AlterTableInfo { - RenameColumnInfo(AlterEntryData data, string old_name_p, string new_name_p); + RenameColumnInfo(AlterEntryData data, Identifier old_name_p, Identifier new_name_p); ~RenameColumnInfo() override; //! Column old name - string old_name; + Identifier old_name; //! Column new name - string new_name; + Identifier new_name; public: unique_ptr Copy() const override; @@ -132,18 +134,18 @@ struct RenameColumnInfo : public AlterTableInfo { // RenameFieldInfo //===--------------------------------------------------------------------===// struct RenameFieldInfo : public AlterTableInfo { - RenameFieldInfo(AlterEntryData data, vector column_path, string new_name_p); + RenameFieldInfo(AlterEntryData data, vector column_path, Identifier new_name_p); ~RenameFieldInfo() override; //! Path to source field. - vector column_path; + vector column_path; //! New name of the column (field). 
- string new_name; + Identifier new_name; public: unique_ptr Copy() const override; string ToString() const override; - string GetColumnName() const override { + Identifier GetColumnName() const override { return column_path[0]; } @@ -158,11 +160,11 @@ struct RenameFieldInfo : public AlterTableInfo { // RenameTableInfo //===--------------------------------------------------------------------===// struct RenameTableInfo : public AlterTableInfo { - RenameTableInfo(AlterEntryData data, string new_name); + RenameTableInfo(AlterEntryData data, Identifier new_name); ~RenameTableInfo() override; //! Relation new name - string new_table_name; + Identifier new_table_name; public: unique_ptr Copy() const override; @@ -202,11 +204,12 @@ struct AddColumnInfo : public AlterTableInfo { // AddFieldInfo //===--------------------------------------------------------------------===// struct AddFieldInfo : public AlterTableInfo { - AddFieldInfo(AlterEntryData data, vector column_path, ColumnDefinition new_field, bool if_field_not_exists); + AddFieldInfo(AlterEntryData data, vector column_path, ColumnDefinition new_field, + bool if_field_not_exists); ~AddFieldInfo() override; //! Path to source field. - vector column_path; + vector column_path; //! New field to add. ColumnDefinition new_field; //! Whether or not an error should be thrown if the field does not exist. @@ -215,7 +218,7 @@ struct AddFieldInfo : public AlterTableInfo { public: unique_ptr Copy() const override; string ToString() const override; - string GetColumnName() const override { + Identifier GetColumnName() const override { return column_path[0]; } @@ -234,7 +237,7 @@ struct RemoveColumnInfo : public AlterTableInfo { ~RemoveColumnInfo() override; //! The column to remove - string removed_column; + Identifier removed_column; //! Whether or not an error should be thrown if the column does not exist bool if_column_exists; //! 
Whether or not the column should be removed if a dependency conflict arises (used by GENERATED columns) @@ -245,7 +248,7 @@ struct RemoveColumnInfo : public AlterTableInfo { string ToString() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); - string GetColumnName() const override { + Identifier GetColumnName() const override { return removed_column; } @@ -257,11 +260,11 @@ struct RemoveColumnInfo : public AlterTableInfo { // RemoveFieldInfo //===--------------------------------------------------------------------===// struct RemoveFieldInfo : public AlterTableInfo { - RemoveFieldInfo(AlterEntryData data, vector column_path, bool if_column_exists, bool cascade); + RemoveFieldInfo(AlterEntryData data, vector column_path, bool if_column_exists, bool cascade); ~RemoveFieldInfo() override; //! Path to source field. - vector column_path; + vector column_path; //! Whether or not an error should be thrown if the column does not exist. bool if_column_exists; //! Whether or not the column should be removed if a dependency conflict arises (used by GENERATED columns). @@ -270,7 +273,7 @@ struct RemoveFieldInfo : public AlterTableInfo { public: unique_ptr Copy() const override; string ToString() const override; - string GetColumnName() const override { + Identifier GetColumnName() const override { return column_path[0]; } @@ -284,12 +287,12 @@ struct RemoveFieldInfo : public AlterTableInfo { // ChangeColumnTypeInfo //===--------------------------------------------------------------------===// struct ChangeColumnTypeInfo : public AlterTableInfo { - ChangeColumnTypeInfo(AlterEntryData data, string column_name, LogicalType target_type, + ChangeColumnTypeInfo(AlterEntryData data, Identifier column_name, LogicalType target_type, unique_ptr expression); ~ChangeColumnTypeInfo() override; //! The column name to alter - string column_name; + Identifier column_name; //! 
The target type of the column LogicalType target_type; //! The expression used for data conversion @@ -300,9 +303,9 @@ struct ChangeColumnTypeInfo : public AlterTableInfo { string ToString() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); - string GetColumnName() const override { + Identifier GetColumnName() const override { return column_name; - }; + } private: ChangeColumnTypeInfo(); @@ -312,11 +315,11 @@ struct ChangeColumnTypeInfo : public AlterTableInfo { // SetDefaultInfo //===--------------------------------------------------------------------===// struct SetDefaultInfo : public AlterTableInfo { - SetDefaultInfo(AlterEntryData data, string column_name, unique_ptr new_default); + SetDefaultInfo(AlterEntryData data, Identifier column_name, unique_ptr new_default); ~SetDefaultInfo() override; //! The column name to alter - string column_name; + Identifier column_name; //! The expression used for data conversion unique_ptr expression; @@ -334,13 +337,14 @@ struct SetDefaultInfo : public AlterTableInfo { // AlterForeignKeyInfo //===--------------------------------------------------------------------===// struct AlterForeignKeyInfo : public AlterTableInfo { - AlterForeignKeyInfo(AlterEntryData data, string fk_table, vector pk_columns, vector fk_columns, - vector pk_keys, vector fk_keys, AlterForeignKeyType type); + AlterForeignKeyInfo(AlterEntryData data, Identifier fk_table, vector pk_columns, + vector fk_columns, vector pk_keys, vector fk_keys, + AlterForeignKeyType type); ~AlterForeignKeyInfo() override; - string fk_table; - vector pk_columns; - vector fk_columns; + Identifier fk_table; + vector pk_columns; + vector fk_columns; vector pk_keys; vector fk_keys; AlterForeignKeyType type; @@ -359,11 +363,11 @@ struct AlterForeignKeyInfo : public AlterTableInfo { // SetNotNullInfo //===--------------------------------------------------------------------===// struct SetNotNullInfo : 
public AlterTableInfo { - SetNotNullInfo(AlterEntryData data, string column_name); + SetNotNullInfo(AlterEntryData data, Identifier column_name); ~SetNotNullInfo() override; //! The column name to alter - string column_name; + Identifier column_name; public: unique_ptr Copy() const override; @@ -379,11 +383,11 @@ struct SetNotNullInfo : public AlterTableInfo { // DropNotNullInfo //===--------------------------------------------------------------------===// struct DropNotNullInfo : public AlterTableInfo { - DropNotNullInfo(AlterEntryData data, string column_name); + DropNotNullInfo(AlterEntryData data, Identifier column_name); ~DropNotNullInfo() override; //! The column name to alter - string column_name; + Identifier column_name; public: unique_ptr Copy() const override; @@ -419,11 +423,11 @@ struct AlterViewInfo : public AlterInfo { // RenameViewInfo //===--------------------------------------------------------------------===// struct RenameViewInfo : public AlterViewInfo { - RenameViewInfo(AlterEntryData data, string new_name); + RenameViewInfo(AlterEntryData data, Identifier new_name); ~RenameViewInfo() override; //! 
Relation new name - string new_view_name; + Identifier new_view_name; public: unique_ptr Copy() const override; @@ -519,10 +523,10 @@ struct SetTableOptionsInfo : public AlterTableInfo { // ResetOptionsInfo //===--------------------------------------------------------------------===// struct ResetTableOptionsInfo : public AlterTableInfo { - ResetTableOptionsInfo(AlterEntryData data, case_insensitive_set_t table_options); + ResetTableOptionsInfo(AlterEntryData data, identifier_set_t table_options); ~ResetTableOptionsInfo() override; - case_insensitive_set_t table_options; + identifier_set_t table_options; public: unique_ptr Copy() const override; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp index 9afa75d0b..e0b8b55d6 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/types/value.hpp" @@ -25,7 +26,7 @@ struct AttachInfo : public ParseInfo { } //! The alias of the attached database - string name; + Identifier name; //! The path to the attached database string path; //! 
The path expression to the attached database diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp index cbaf8e7ba..42e102a60 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/comment_on_column_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/catalog_type.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" @@ -23,14 +24,14 @@ struct SetColumnCommentInfo : public AlterInfo { public: SetColumnCommentInfo(); - SetColumnCommentInfo(string catalog, string schema, string name, string column_name, Value comment_value, - OnEntryNotFound if_not_found); + SetColumnCommentInfo(Identifier catalog, Identifier schema, Identifier name, Identifier column_name, + Value comment_value, OnEntryNotFound if_not_found); //! The resolved Catalog Type CatalogType catalog_entry_type; //! name of the column to comment on - string column_name; + Identifier column_name; //! 
The comment, can be NULL or a string Value comment_value; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/connect_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/connect_info.hpp new file mode 100644 index 000000000..e61b41b21 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/connect_info.hpp @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/connect_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/identifier.hpp" +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { + +struct ConnectInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::CONNECT_INFO; + +public: + ConnectInfo() : ParseInfo(TYPE) { + } + + //! Target — empty for bare `CONNECT;` and for `CONNECT LOCAL;` (in the latter case, + //! `target_is_local` is true). Otherwise either an identifier (attached-db name) or the + //! contents of a string literal (connection-string form). + Identifier name; + //! True iff the target was parsed as the LOCAL keyword (`CONNECT LOCAL;`). When true, `name` + //! is empty and `name_is_string_literal` is false. + bool target_is_local = false; + //! True iff the target was parsed as a StringLiteral. Differentiates `CONNECT 'foo'` + //! (connection-string form) from `CONNECT foo` (attached-db identifier) at the AST level, + //! so ToString roundtrips the source form and downstream impl can dispatch correctly. 
+ bool name_is_string_literal = false; + +public: + unique_ptr Copy() const; + string ToString() const; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_database_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_database_info.hpp index 09b375bd5..7238dcf18 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_database_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_database_info.hpp @@ -13,6 +13,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CopyDatabaseInfo : public ParseInfo { @@ -23,11 +24,11 @@ struct CopyDatabaseInfo : public ParseInfo { explicit CopyDatabaseInfo() : ParseInfo(TYPE), target_database(INVALID_CATALOG) { } - explicit CopyDatabaseInfo(const string &target_database) : ParseInfo(TYPE), target_database(target_database) { + explicit CopyDatabaseInfo(const Identifier &target_database) : ParseInfo(TYPE), target_database(target_database) { } // The destination database to which catalog entries are being copied - string target_database; + Identifier target_database; // The catalog entries that are going to be created in the destination DB vector> entries; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp index 5caa823b2..606d8654a 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/unordered_map.hpp" @@ -27,13 +28,13 @@ struct CopyInfo : public ParseInfo { 
CopyInfo(); //! The catalog name to copy to/from - string catalog; + Identifier catalog; //! The schema name to copy to/from - string schema; + Identifier schema; //! The table name to copy to/from - string table; + Identifier table; //! List of columns to copy to/from - vector select_list; + vector select_list; //! Whether or not this is a copy to file (false) or copy from a file (true) bool is_from; //! The file format of the external file diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp index 09dea0570..bc3f85cba 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_collation_info.hpp @@ -11,14 +11,15 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/function/scalar_function.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateCollationInfo : public CreateInfo { - DUCKDB_API CreateCollationInfo(string name_p, ScalarFunction function_p, bool combinable_p, + DUCKDB_API CreateCollationInfo(Identifier name_p, ScalarFunction function_p, bool combinable_p, bool not_required_for_equality_p); //! The name of the collation - string name; + Identifier name; //! The collation function to push in case collation is required ScalarFunction function; //! Whether or not the collation can be combined with other collations. 
diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_coordinate_system_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_coordinate_system_info.hpp index 19d2786ac..179dac63a 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_coordinate_system_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_coordinate_system_info.hpp @@ -10,15 +10,16 @@ #include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateCoordinateSystemInfo : public CreateInfo { - DUCKDB_API CreateCoordinateSystemInfo(string name_p, string authority, string code, string projjson, + DUCKDB_API CreateCoordinateSystemInfo(Identifier name_p, string authority, string code, string projjson, string wkt2_2019); //! The name of the coordinate system //! This is typically in the format "AUTH:CODE", e.g. "OGC:CRS84" - string name; + Identifier name; //! The authority identifier of the coordinate system (e.g. "EPSG") string authority; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp index 40eafbc91..97d77fb8a 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_copy_function_info.hpp @@ -11,13 +11,14 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/function/copy_function.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateCopyFunctionInfo : public CreateInfo { DUCKDB_API explicit CreateCopyFunctionInfo(CopyFunction function); //! Function name - string name; + Identifier name; //! 
The table function CopyFunction function; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp index 10f7e271a..b42e9d3dd 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_function_info.hpp @@ -10,6 +10,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct FunctionDescription { @@ -26,12 +27,12 @@ struct FunctionDescription { }; struct CreateFunctionInfo : public CreateInfo { - explicit CreateFunctionInfo(CatalogType type, string schema = DEFAULT_SCHEMA); + explicit CreateFunctionInfo(CatalogType type, Identifier schema = Identifier::DefaultSchema()); //! Function name - string name; + Identifier name; //! The function name of which this function is an alias - string alias_of; + Identifier alias_of; //! Function description vector descriptions; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp index 6cc6e84ce..995654a3b 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_index_info.hpp @@ -15,6 +15,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateIndexInfo : public CreateInfo { @@ -22,9 +23,9 @@ struct CreateIndexInfo : public CreateInfo { CreateIndexInfo(const CreateIndexInfo &info); //! The table name of the underlying table - string table; + Identifier table; //! The name of the index - string index_name; + Identifier index_name; //! Options values (WITH ...) case_insensitive_map_t options; @@ -42,7 +43,7 @@ struct CreateIndexInfo : public CreateInfo { //! 
The types of the logical columns (necessary for scanning the table during CREATE INDEX) vector scan_types; //! The names of the logical columns (necessary for scanning the table during CREATE INDEX) - vector names; + vector names; public: DUCKDB_API unique_ptr Copy() const override; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp index cd76dff26..77c39c6e7 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_info.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/enum_util.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" @@ -23,7 +24,8 @@ struct CreateInfo : public ParseInfo { static constexpr const ParseInfoType TYPE = ParseInfoType::CREATE_INFO; public: - explicit CreateInfo(CatalogType type, string schema = DEFAULT_SCHEMA, string catalog_p = INVALID_CATALOG) + explicit CreateInfo(CatalogType type, Identifier schema = Identifier::DefaultSchema(), + Identifier catalog_p = Identifier::InvalidCatalog()) : ParseInfo(TYPE), type(type), catalog(std::move(catalog_p)), schema(std::move(schema)), on_conflict(OnCreateConflict::ERROR_ON_CONFLICT), temporary(false), internal(false) { } @@ -33,9 +35,9 @@ struct CreateInfo : public ParseInfo { //! The to-be-created catalog type CatalogType type; //! The catalog name of the entry - string catalog; + Identifier catalog; //! The schema name of the entry - string schema; + Identifier schema; //! What to do on create conflict OnCreateConflict on_conflict; //! Whether or not the entry is temporary @@ -43,7 +45,7 @@ struct CreateInfo : public ParseInfo { //! Whether or not the entry is an internal entry bool internal; //! 
The name of the extension that registered this entry (empty for core entries) - string extension_name; + Identifier extension_name; //! The SQL string of the CREATE statement string sql; //! The inherent dependencies of the created entry diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp index 272ac378f..256339d90 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_pragma_function_info.hpp @@ -11,11 +11,13 @@ #include "duckdb/parser/parsed_data/create_function_info.hpp" #include "duckdb/function/function_set.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreatePragmaFunctionInfo : public CreateFunctionInfo { DUCKDB_API explicit CreatePragmaFunctionInfo(PragmaFunction function); - DUCKDB_API CreatePragmaFunctionInfo(string name, PragmaFunctionSet functions); + DUCKDB_API explicit CreatePragmaFunctionInfo(PragmaFunctionSet functions); + DUCKDB_API CreatePragmaFunctionInfo(Identifier name, PragmaFunctionSet functions); PragmaFunctionSet functions; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp index 229015f10..2dfee3158 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_secret_info.hpp @@ -13,6 +13,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/common/named_parameter_map.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateSecretInfo : public CreateInfo { // NOLINT: work-around bug in clang-tidy @@ -28,11 +29,11 @@ struct CreateSecretInfo : public CreateInfo { // NOLINT: work-around bug in clan //! The type of secret unique_ptr type; //! 
Which storage to use (empty for default) - string storage_type; + Identifier storage_type; //! (optionally) the provider of the secret credentials unique_ptr provider; //! (optionally) the name of the secret - string name; + Identifier name; //! (optionally) the scope of the secret unique_ptr scope; //! Named parameter list (if any) diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp index 17ca7f736..066ef7e06 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_sequence_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/common/optional.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { enum class SequenceInfo : uint8_t { @@ -32,7 +33,7 @@ struct CreateSequenceInfo : public CreateInfo { CreateSequenceInfo(); //! Sequence name to create - string name; + Identifier name; //! Usage count of the sequence uint64_t usage_count; //! 
The increment value diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp index ebbd2a6bf..fcc3e723c 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_table_info.hpp @@ -13,16 +13,17 @@ #include "duckdb/parser/statement/select_statement.hpp" #include "duckdb/parser/column_list.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class SchemaCatalogEntry; struct CreateTableInfo : public CreateInfo { DUCKDB_API CreateTableInfo(); - DUCKDB_API CreateTableInfo(string catalog, string schema, string name); - DUCKDB_API CreateTableInfo(SchemaCatalogEntry &schema, string name); + DUCKDB_API CreateTableInfo(Identifier catalog, Identifier schema, Identifier name); + DUCKDB_API CreateTableInfo(SchemaCatalogEntry &schema, Identifier name); //! Table name to insert to - string table; + Identifier table; //! List of columns of the table ColumnList columns; //! List of constraints on the table diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_trigger_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_trigger_info.hpp index 27bdee735..74bdacaab 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_trigger_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_trigger_info.hpp @@ -13,13 +13,14 @@ #include "duckdb/parser/query_node.hpp" #include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct CreateTriggerInfo : public CreateInfo { CreateTriggerInfo(); //! Trigger name - string trigger_name; + Identifier trigger_name; //! The table the trigger is on unique_ptr base_table; //! When the trigger fires (BEFORE/AFTER/INSTEAD OF) @@ -27,9 +28,13 @@ struct CreateTriggerInfo : public CreateInfo { //! 
The event that fires the trigger (INSERT/DELETE/UPDATE) TriggerEventType event_type; //! Columns for UPDATE OF - vector columns; + vector columns; //! Whether this fires FOR EACH ROW or FOR EACH STATEMENT TriggerForEach for_each; + //! Alias for the NEW TABLE transition table (REFERENCING NEW TABLE AS ) + Identifier referencing_new_table; + //! Alias for the OLD TABLE transition table (REFERENCING OLD TABLE AS ) + Identifier referencing_old_table; //! The trigger action (INSERT/UPDATE/DELETE as QueryNode) unique_ptr trigger_action; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp index 8d25b626a..3ff18a58c 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_type_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class TypeArgument { @@ -55,7 +56,7 @@ struct CreateTypeInfo : public CreateInfo { CreateTypeInfo(string name_p, LogicalType type_p, bind_logical_type_function_t bind_function_p = nullptr); //! Name of the Type - string name; + Identifier name; //! Logical Type LogicalType type; //! 
Used by create enum from query diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp index 270a2e055..604bf3462 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/create_view_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class SchemaCatalogEntry; @@ -19,20 +20,20 @@ enum class CreateViewBindingMode { BIND_ON_CREATE, SKIP_BINDING }; struct CreateViewInfo : public CreateInfo { public: CreateViewInfo(); - CreateViewInfo(SchemaCatalogEntry &schema, string view_name); - CreateViewInfo(string catalog_p, string schema_p, string view_name); + CreateViewInfo(SchemaCatalogEntry &schema, Identifier view_name); + CreateViewInfo(Identifier catalog_p, Identifier schema_p, Identifier view_name); public: //! View name - string view_name; + Identifier view_name; //! Aliases of the view - vector aliases; + vector aliases; //! Return types vector types; //! Names of the query - vector names; + vector names; //! Comments on columns of the query. Note: vector can be empty when no comments are set - unordered_map column_comments_map; + identifier_map_t column_comments_map; //! The SelectStatement of the view unique_ptr query; //! 
Whether or not to bind the view on create @@ -55,7 +56,7 @@ struct CreateViewInfo : public CreateInfo { string ToString() const override; private: - CreateViewInfo(vector names, vector comments, unordered_map column_comments); + CreateViewInfo(vector names, vector comments, identifier_map_t column_comments); vector GetColumnCommentsList() const; }; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp index a97bafcb9..67f3ec9b6 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/detach_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" @@ -21,7 +22,7 @@ struct DetachInfo : public ParseInfo { DetachInfo(); //! The alias of the attached database - string name; + Identifier name; //! Whether to throw an exception if alias is not found OnEntryNotFound if_not_found; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/disconnect_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/disconnect_info.hpp new file mode 100644 index 000000000..cd6931ef8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/disconnect_info.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/parsed_data/disconnect_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/parse_info.hpp" + +namespace duckdb { + +struct DisconnectInfo : public ParseInfo { +public: + static constexpr const ParseInfoType TYPE = ParseInfoType::DISCONNECT_INFO; + +public: + DisconnectInfo() : ParseInfo(TYPE) { + } + +public: + unique_ptr Copy() const; + string ToString() const; + + void Serialize(Serializer 
&serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp index 40955edcb..58dbff964 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/drop_info.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/parsed_data/extra_drop_info.hpp" #include "duckdb/common/enums/on_entry_not_found.hpp" @@ -27,11 +28,11 @@ struct DropInfo : public ParseInfo { //! The catalog type to drop CatalogType type; //! Catalog name to drop from, if any - string catalog; + Identifier catalog; //! Schema name to drop from, if any - string schema; + Identifier schema; //! Element name to drop - string name; + Identifier name; //! Ignore if the entry does not exist instead of failing OnEntryNotFound if_not_found = OnEntryNotFound::THROW_EXCEPTION; //! Cascade drop (drop all dependents instead of throwing an error if there diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp index 192def09f..02c233852 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/exported_table_data.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/types/value.hpp" @@ -16,25 +17,25 @@ class TableCatalogEntry; struct ExportedTableData { //! Name of the exported table - string table_name; + Identifier table_name; //! Name of the schema - string schema_name; + Identifier schema_name; //! 
Name of the database - string database_name; + Identifier database_name; //! Path to be exported string file_path; //! Not Null columns, if any - vector not_null_columns; + vector not_null_columns; void Serialize(Serializer &serializer) const; static ExportedTableData Deserialize(Deserializer &deserializer); }; struct ExportedTableInfo { - ExportedTableInfo(TableCatalogEntry &entry, ExportedTableData table_data_p, vector ¬_null_columns_p); + ExportedTableInfo(TableCatalogEntry &entry, ExportedTableData table_data_p, vector ¬_null_columns_p); ExportedTableInfo(ClientContext &context, ExportedTableData table_data); TableCatalogEntry &entry; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp index bc092a413..0011e8cbe 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/load_info.hpp @@ -10,9 +10,10 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { -enum class LoadType : uint8_t { LOAD, INSTALL, FORCE_INSTALL }; +enum class LoadType : uint8_t { LOAD, INSTALL, FORCE_INSTALL, LOAD_AS }; struct LoadInfo : public ParseInfo { public: @@ -27,6 +28,7 @@ struct LoadInfo : public ParseInfo { bool repo_is_alias; string version; LoadType load_type; + Identifier alias; public: unique_ptr Copy() const; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp index 8810c6c98..453245e2e 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/parse_info.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { @@ -32,7 +33,9 @@ enum class ParseInfoType : uint8_t { COMMENT_ON_INFO, 
COMMENT_ON_COLUMN_INFO, COPY_DATABASE_INFO, - UPDATE_EXTENSIONS_INFO + UPDATE_EXTENSIONS_INFO, + CONNECT_INFO, + DISCONNECT_INFO }; struct ParseInfo { @@ -58,7 +61,7 @@ struct ParseInfo { virtual void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); - static string QualifierToString(const string &catalog, const string &schema, const string &name); + static string QualifierToString(const Identifier &catalog, const Identifier &schema, const Identifier &name); static string TypeToString(CatalogType type); }; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp index c76b2fcbf..335d0ea01 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/pragma_info.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/parser/parsed_expression.hpp" @@ -25,7 +26,7 @@ struct PragmaInfo : public ParseInfo { } //! Name of the PRAGMA statement - string name; + Identifier name; //! Parameter list (if any) vector> parameters; //! 
Named parameter list (if any) diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp index f5ca9731f..0e1df39ca 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/transaction_info.hpp @@ -20,7 +20,11 @@ enum class TransactionModifierType : uint8_t { TRANSACTION_READ_WRITE }; -enum class TransactionInvalidationPolicy : uint8_t { STANDARD_POLICY, ALL_ERRORS_INVALIDATE_TRANSACTION }; +enum class TransactionInvalidationPolicy : uint8_t { + STANDARD_POLICY, + SYNTACTIC_ERRORS_DO_NOT_INVALIDATE, + ALL_ERRORS_INVALIDATE_TRANSACTION +}; struct TransactionInfo : public ParseInfo { public: diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/update_extensions_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/update_extensions_info.hpp index b43ab8953..b4c5e6eed 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/update_extensions_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/update_extensions_info.hpp @@ -10,6 +10,7 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct UpdateExtensionsInfo : public ParseInfo { @@ -20,7 +21,7 @@ struct UpdateExtensionsInfo : public ParseInfo { UpdateExtensionsInfo() : ParseInfo(TYPE) { } - vector extensions_to_update; + vector extensions_to_update; public: unique_ptr Copy() const { diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp index 9b5ce8812..c6c07b41a 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp @@ -11,6 +11,7 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/tableref.hpp" +#include "duckdb/common/identifier.hpp" 
namespace duckdb { class Serializer; class Deserializer; @@ -34,7 +35,7 @@ struct VacuumInfo : public ParseInfo { explicit VacuumInfo(VacuumOptions options); const VacuumOptions options; - vector columns; + vector columns; bool has_table; unique_ptr ref; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp index 3421bca88..3398ab01b 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_expression.hpp @@ -14,8 +14,151 @@ namespace duckdb { class Deserializer; +class ParsedExpression; class Serializer; +//! Lightweight const view of a ParsedExpression's direct children. +//! Stores up to INLINE_CAPACITY children on the stack; spills to heap beyond that. +//! Iterating yields const ParsedExpression& — read-only access, works on const expressions. +class ConstChildrenView { +public: + static constexpr idx_t INLINE_CAPACITY = 8; + + ConstChildrenView() : count(0) { + } + ConstChildrenView(const ConstChildrenView &) = delete; + ConstChildrenView &operator=(const ConstChildrenView &) = delete; + ConstChildrenView(ConstChildrenView &&other) noexcept : count(other.count), overflow(std::move(other.overflow)) { + if (count <= INLINE_CAPACITY) { + for (idx_t i = 0; i < count; i++) { + inline_storage[i] = other.inline_storage[i]; + } + } + other.count = 0; + } + + void Append(const ParsedExpression &child) { + if (count < INLINE_CAPACITY) { + inline_storage[count] = &child; + } else { + if (overflow.empty()) { + overflow.reserve(count + 8); + for (idx_t i = 0; i < count; i++) { + overflow.push_back(inline_storage[i]); + } + } + overflow.push_back(&child); + } + count++; + } + + struct iterator { + const ParsedExpression **ptr; + const ParsedExpression &operator*() const { + return **ptr; + } + iterator &operator++() { + ++ptr; + return *this; + } + bool operator!=(const iterator &other) const { + return ptr != other.ptr; + } + }; + 
+ iterator begin() { + auto *ptr = count <= INLINE_CAPACITY ? inline_storage : overflow.data(); + return {ptr}; + } + iterator end() { + auto *ptr = count <= INLINE_CAPACITY ? inline_storage : overflow.data(); + return {ptr + count}; + } + + idx_t size() const { + return count; + } + bool empty() const { + return count == 0; + } + +private: + idx_t count; + const ParsedExpression *inline_storage[INLINE_CAPACITY]; + vector overflow; +}; + +//! Lightweight mutable view of a ParsedExpression's direct children (as owning-pointer references). +//! Stores up to INLINE_CAPACITY children on the stack; spills to heap beyond that. +//! Iterating yields unique_ptr& — supports both reading and in-place replacement. +class ChildrenView { +public: + static constexpr idx_t INLINE_CAPACITY = 8; + + ChildrenView() : count(0) { + } + ChildrenView(const ChildrenView &) = delete; + ChildrenView &operator=(const ChildrenView &) = delete; + ChildrenView(ChildrenView &&other) noexcept : count(other.count), overflow(std::move(other.overflow)) { + if (count <= INLINE_CAPACITY) { + for (idx_t i = 0; i < count; i++) { + inline_storage[i] = other.inline_storage[i]; + } + } + other.count = 0; + } + + void Append(unique_ptr &child) { + if (count < INLINE_CAPACITY) { + inline_storage[count] = &child; + } else { + if (overflow.empty()) { + overflow.reserve(count + 8); + for (idx_t i = 0; i < count; i++) { + overflow.push_back(inline_storage[i]); + } + } + overflow.push_back(&child); + } + count++; + } + + struct iterator { + unique_ptr **ptr; + unique_ptr &operator*() const { + return **ptr; + } + iterator &operator++() { + ++ptr; + return *this; + } + bool operator!=(const iterator &other) const { + return ptr != other.ptr; + } + }; + + iterator begin() { + auto *ptr = count <= INLINE_CAPACITY ? inline_storage : overflow.data(); + return {ptr}; + } + iterator end() { + auto *ptr = count <= INLINE_CAPACITY ? 
inline_storage : overflow.data(); + return {ptr + count}; + } + + idx_t size() const { + return count; + } + bool empty() const { + return count == 0; + } + +private: + idx_t count; + unique_ptr *inline_storage[INLINE_CAPACITY]; + vector *> overflow; +}; + //! The ParsedExpression class is a base class that can represent any expression //! part of a SQL statement. /*! @@ -39,27 +182,26 @@ class ParsedExpression : public BaseExpression { bool HasParameter() const override; bool Equals(const BaseExpression &other) const override; + virtual bool Equals(const ParsedExpression &other) const; hash_t Hash() const override; //! Create a copy of this expression virtual unique_ptr Copy() const = 0; + //! Returns a read-only view over all direct ParsedExpression children of this expression + ConstChildrenView Children() const; + //! Returns a mutable view over all direct ParsedExpression children (as unique_ptr references) + ChildrenView ChildrenMutable(); + virtual void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); + //! Copies the base class properties (type, expression_class, alias, query_location) from other into this + void CopyBase(const ParsedExpression &other); + static bool Equals(const unique_ptr &left, const unique_ptr &right); static bool ListEquals(const vector> &left, const vector> &right); - -protected: - //! Copy base Expression properties from another expression to this one, - //! 
used in Copy method - void CopyProperties(const ParsedExpression &other) { - type = other.type; - expression_class = other.expression_class; - alias = other.alias; - query_location = other.query_location; - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parser.hpp b/src/duckdb/src/include/duckdb/parser/parser.hpp index 5a1deff58..1864a580d 100644 --- a/src/duckdb/src/include/duckdb/parser/parser.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser.hpp @@ -20,6 +20,7 @@ namespace duckdb { struct ParserCache; +struct MatcherToken; class GroupByNode; struct UnicodeSpace { UnicodeSpace(idx_t pos, idx_t bytes) : pos(pos), bytes(bytes) { @@ -47,6 +48,15 @@ class Parser { //! variable. void ParseQuery(const string &query); + //! Parse a single TopLevelStatement from an already-tokenized stream starting at + //! `token_cursor`. On success advances `token_cursor` past the consumed tokens and returns + //! the SQLStatement. Returns nullptr at end-of-input or when the matched TLS was a + //! separator-only run (no statement). Throws ParserException on syntax error. + //! + //! Does NOT populate `stmt->query` — the caller owns the source string and can slice it + //! using `stmt->stmt_location` / `stmt->stmt_length` if needed. + DUCKDB_API unique_ptr ParseTopLevelStatement(vector &tokens, idx_t &token_cursor); + //! Tokenize a query, returning the raw tokens together with their locations static vector Tokenize(const string &query); @@ -67,7 +77,7 @@ class Parser { //! Parses a list as found in an ORDER BY expression (i.e. including optional ASCENDING/DESCENDING modifiers) static vector ParseOrderList(const string &select_list, ParserOptions options = ParserOptions()); //! Parses an update list (i.e. 
the list found in the SET clause of an UPDATE statement) - static void ParseUpdateList(const string &update_list, vector &update_columns, + static void ParseUpdateList(const string &update_list, vector &update_columns, vector> &expressions, ParserOptions options = ParserOptions()); //! Parses a VALUES list (i.e. the list of expressions after a VALUES clause) diff --git a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp index 391ba5f00..4fce53f13 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_extension.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_extension.hpp @@ -13,10 +13,23 @@ #include "duckdb/common/enums/statement_type.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/parser/sql_statement.hpp" +#include "duckdb/parser/peg/token_type.hpp" namespace duckdb { struct DBConfig; +//! A minimal token view handed to parser extensions: the token text together with its classified +//! TokenType. +struct SimpleToken { + SimpleToken(string text_p, TokenType type_p) : text(std::move(text_p)), type(type_p) { + } + + //! The raw token text as it appears in the query + string text; + //! The classified type of the token + TokenType type; +}; + //! The ParserExtensionInfo holds static information relevant to the parser extension //! It is made available in the parse_function, and will be kept alive as long as the database system is kept alive struct ParserExtensionInfo { @@ -74,9 +87,24 @@ struct ParserExtensionParseResult { string error; //! The error location (if unsuccessful) optional_idx error_location; + //! How many leading tokens (from the start of the `tokens` view) the extension claimed: + //! > 0 : SUCCESS — the extension accepted that many tokens; the peeler resumes parsing + //! after them. + //! == 0 : the extension ran fine but did not claim any tokens (it's not taking this input) — + //! the peeler lets the next extension try. + //! 
< 0 : the extension wants to surface an error — the peeler throws (using `error` / + //! `error_location` if set). + //! A positive count must not exceed the number of tokens in the view, or the peeler throws. + int64_t consumed_tokens = 0; }; -typedef ParserExtensionParseResult (*parse_function_t)(ParserExtensionInfo *info, const string &query); +//! Called when the PEG parser fails to parse a statement. +//! +//! `tokens` is the tokenized view of the source tail from the PEG failure point onward — one +//! SimpleToken (text + TokenType) per token, in source order. The extension dispatches on this +//! token stream and reports, via `ParserExtensionParseResult::consumed_tokens`, how many leading +//! tokens it claimed (> 0 success, 0 = ran but claimed nothing, < 0 = throw). +typedef ParserExtensionParseResult (*parse_function_t)(ParserExtensionInfo *info, const vector &tokens); //===--------------------------------------------------------------------===// // Plan //===--------------------------------------------------------------------===// @@ -86,7 +114,7 @@ struct ParserExtensionPlanResult { // NOLINT: work-around bug in clang-tidy //! Parameters to the function vector parameters; //! The set of databases that will be modified by this statement (empty for a read-only statement) - unordered_map modified_databases; + identifier_map_t modified_databases; //! Whether or not the statement requires a valid transaction to be executed bool requires_valid_transaction = true; //! 
What type of result set the statement returns diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/add_column_entry.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/add_column_entry.hpp index f875bbab0..4df700ae4 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/add_column_entry.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/add_column_entry.hpp @@ -4,20 +4,13 @@ #include "duckdb/parser/column_definition.hpp" #include "duckdb/parser/constraint.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct AddColumnEntry { LogicalType type; - vector column_path; + vector column_path; unique_ptr default_value; - - AddColumnEntry Copy() { - AddColumnEntry result; - result.type = type; - result.column_path = column_path; - result.default_value = default_value->Copy(); - return result; - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/analyze_target.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/analyze_target.hpp index 6920d4914..1d795f848 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/analyze_target.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/analyze_target.hpp @@ -1,9 +1,10 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct AnalyzeTarget { unique_ptr ref; - vector columns; + vector columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/column_constraint_entry.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/column_constraint_entry.hpp new file mode 100644 index 000000000..327502208 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/column_constraint_entry.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "duckdb/common/enums/compression_type.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/parser/constraint.hpp" +#include "duckdb/parser/parsed_expression.hpp" + +namespace duckdb { + +struct ColumnConstraintEntry { + string constraint_name; + pair 
constraint_type_info; + unique_ptr expression; + unique_ptr constraint; + CompressionType compression_type; + + ColumnConstraintEntry() + : constraint_type_info(false, ConstraintType::INVALID), compression_type(CompressionType::COMPRESSION_AUTO) { + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/column_elements.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/column_elements.hpp index 9a5a34a98..12049530b 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/column_elements.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/column_elements.hpp @@ -7,7 +7,7 @@ namespace duckdb { struct ColumnElements { - ColumnList columns; + ColumnList columns {false}; vector> constraints; vector> partition_keys; vector> sort_keys; diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_column_element.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_column_element.hpp new file mode 100644 index 000000000..3f523e725 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_column_element.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "duckdb/parser/peg/ast/generated_column_definition.hpp" + +namespace duckdb { +struct CreateTableColumnElement { + unique_ptr column_definition; + unique_ptr constraint; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_as.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_definition.hpp similarity index 77% rename from src/duckdb/src/include/duckdb/parser/peg/ast/create_table_as.hpp rename to src/duckdb/src/include/duckdb/parser/peg/ast/create_table_definition.hpp index ebcd4d8d2..747624325 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_as.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/create_table_definition.hpp @@ -1,17 +1,18 @@ #pragma once + #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/parser/column_list.hpp" +#include 
"duckdb/parser/constraint.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { -struct CreateTableAs { - bool with_data = false; +struct CreateTableDefinition { unique_ptr select_statement; - ColumnList column_names; + ColumnList columns; + vector> constraints; vector> partition_keys; vector> sort_keys; case_insensitive_map_t> options; }; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/describe_target.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/describe_target.hpp new file mode 100644 index 000000000..402f3564d --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/describe_target.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "duckdb/common/identifier.hpp" +#include "duckdb/common/common.hpp" +#include "duckdb/common/string.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" + +namespace duckdb { + +struct DescribeTarget { + bool is_table_name = false; + Identifier table_name; + unique_ptr table_ref; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/extension_repository_info.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/extension_repository_info.hpp index 05e399db5..72f62fdd6 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/extension_repository_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/extension_repository_info.hpp @@ -1,9 +1,10 @@ #pragma once #include "duckdb/common/string.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct ExtensionRepositoryInfo { - string name; + Identifier name; bool repository_is_alias = false; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option.hpp index 40876fcf7..d614e0abe 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option.hpp +++ 
b/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option.hpp @@ -5,10 +5,11 @@ #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct GenericCopyOption { - string name; + Identifier name; vector children; // Default value unique_ptr expression = nullptr; @@ -39,7 +40,7 @@ struct GenericCopyOption { GenericCopyOption(GenericCopyOption &&other) noexcept = default; GenericCopyOption &operator=(GenericCopyOption &&other) noexcept = default; - unique_ptr GetFirstChildOrExpression() { + unique_ptr GetFirstChildOrExpression() const { if (!children.empty()) { return make_uniq(children[0]); } diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option_value.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option_value.hpp new file mode 100644 index 000000000..c2c376d58 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/generic_copy_option_value.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "duckdb/common/vector.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/result_modifier.hpp" + +namespace duckdb { + +struct GenericCopyOptionValue { + bool has_value = false; + bool is_order_list = false; + vector order_list; + unique_ptr expression; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/join_qualifier.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/join_qualifier.hpp index 600e60c8a..3a44662cf 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/join_qualifier.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/join_qualifier.hpp @@ -1,10 +1,11 @@ #pragma once #include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct JoinQualifier { unique_ptr on_clause; - vector using_columns; + vector using_columns; }; } // namespace duckdb diff --git 
a/src/duckdb/src/include/duckdb/parser/peg/ast/macro_parameter.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/macro_parameter.hpp index 4bb8d8219..76eb3329e 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/macro_parameter.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/macro_parameter.hpp @@ -1,10 +1,11 @@ #pragma once #include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct MacroParameter { unique_ptr expression; - string name; + Identifier name; LogicalType type = LogicalType::UNKNOWN; bool is_default = false; }; diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/on_conflict_expression_target.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/on_conflict_expression_target.hpp index db0c862c3..51a72f377 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/on_conflict_expression_target.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/on_conflict_expression_target.hpp @@ -4,10 +4,11 @@ #include "duckdb/common/string.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct OnConflictExpressionTarget { - vector indexed_columns; + vector indexed_columns; unique_ptr where_clause; // Default value is defined here }; diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/setting_info.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/setting_info.hpp index 16e41d813..39831841c 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/setting_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/setting_info.hpp @@ -3,10 +3,11 @@ #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/common/string.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct SettingInfo { - string name; + Identifier name; SetScope scope = SetScope::AUTOMATIC; // Default value is defined here }; diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/table_alias.hpp 
b/src/duckdb/src/include/duckdb/parser/peg/ast/table_alias.hpp index 9201f9266..1aa8c3757 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/table_alias.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/table_alias.hpp @@ -1,9 +1,10 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct TableAlias { - string name; - vector column_name_alias; + Identifier name; + vector column_name_alias; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_event_info.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_event_info.hpp index 7fd2da2fa..d8ad6963f 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_event_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_event_info.hpp @@ -2,9 +2,10 @@ #include "duckdb/common/enums/trigger_type.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct TriggerEventInfo { TriggerEventType event_type; - vector columns; + vector columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_table_referencing_info.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_table_referencing_info.hpp new file mode 100644 index 000000000..786c2c671 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/ast/trigger_table_referencing_info.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "duckdb/common/string.hpp" + +#include "duckdb/common/identifier.hpp" +namespace duckdb { +struct TriggerTableReferencingInfo { + Identifier new_table; + Identifier old_table; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/ast/unpivot_name_values.hpp b/src/duckdb/src/include/duckdb/parser/peg/ast/unpivot_name_values.hpp index bb2617525..ae3a75872 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/ast/unpivot_name_values.hpp +++ 
b/src/duckdb/src/include/duckdb/parser/peg/ast/unpivot_name_values.hpp @@ -1,9 +1,10 @@ #pragma once #include "duckdb/parser/tableref/pivotref.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { struct UnpivotNameValues { - vector unpivot_names; + vector unpivot_names; PivotColumn column; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp index 5789febca..f84886966 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp @@ -4,1481 +4,1642 @@ namespace duckdb { const char INLINED_PEG_GRAMMAR[] = { - "Program <- Statement? (';'+ Statement)* ';'*\n" - "Statement <-\n" - " CreateStatement /\n" - " SelectStatement /\n" - " SetStatement /\n" - " PragmaStatement /\n" - " CallStatement /\n" - " InsertStatement /\n" - " DropStatement /\n" - " CopyStatement /\n" - " ExplainStatement /\n" - " UpdateStatement /\n" - " PrepareStatement /\n" - " ExecuteStatement /\n" - " AlterStatement /\n" - " TransactionStatement /\n" - " DeleteStatement /\n" - " AttachStatement /\n" - " UseStatement /\n" - " DetachStatement /\n" - " CheckpointStatement /\n" - " VacuumStatement /\n" - " ResetStatement /\n" - " ExportStatement /\n" - " ImportStatement /\n" - " CommentStatement /\n" - " DeallocateStatement /\n" - " TruncateStatement /\n" - " LoadStatement /\n" - " InstallStatement /\n" - " UpdateExtensionsStatement /\n" - " AnalyzeStatement /\n" - " MergeIntoStatement /\n" - " ExpressionStatement\n" - "ExpressionStatement <- List(ExpressionAlias)\n" - "ExpressionAlias <- ColIdExpression / ExpressionAsCollabel / Expression\n" - "SemiList1(D) <- D (';' D)*\n" - "SemiList(D) <- SemiList1(D) ';'?\n" - "CatalogName <- Identifier\n" - "SchemaName <- Identifier\n" - "ReservedSchemaName <- Identifier\n" - "TableName <- Identifier\n" - "ReservedTableName <- Identifier\n" - "ReservedIdentifier <- 
Identifier\n" - "ColumnName <- Identifier\n" - "ReservedColumnName <- Identifier\n" - "IndexName <- Identifier\n" - "SettingName <- Identifier\n" - "PragmaName <- Identifier\n" - "FunctionName <- Identifier\n" - "ReservedFunctionName <- Identifier\n" - "TableFunctionName <- Identifier\n" - "ConstraintName <- ColIdOrString\n" - "SequenceName <- Identifier\n" - "TriggerName <- Identifier\n" - "CollationName <- Identifier\n" - "CopyOptionName <- ColLabel\n" - "SecretName <- ColId\n" - "NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >\n" - "StringLiteral <- '\\'' [^\\']* '\\''\n" - "Type <- (TimeType / IntervalType / BitType / RowType / VariantType / MapType / GeometryType / UnionType / NumericType / SetofType / SimpleType) ArrayBounds*\n" - "SimpleType <- (CharacterType / QualifiedTypeName) TypeModifiers?\n" - "CharacterType <- ('CHARACTER' 'VARYING'?) /\n" - " ('CHAR' 'VARYING'?) /\n" - " ('NATIONAL' 'CHARACTER' 'VARYING'?) /\n" - " ('NATIONAL' 'CHAR' 'VARYING'?) /\n" - " ('NCHAR' 'VARYING'?) /\n" - " 'VARCHAR'\n" - "IntervalType <- IntervalInterval / IntervalNumber\n" - "IntervalInterval <- 'INTERVAL' Interval?\n" - "IntervalNumber <- 'INTERVAL' Parens(NumberLiteral)\n" - "YearKeyword <- 'YEAR' / 'YEARS'\n" - "MonthKeyword <- 'MONTH' / 'MONTHS'\n" - "DayKeyword <- 'DAY' / 'DAYS'\n" - "HourKeyword <- 'HOUR' / 'HOURS'\n" - "MinuteKeyword <- 'MINUTE' / 'MINUTES'\n" - "SecondKeyword <- 'SECOND' / 'SECONDS'\n" - "MillisecondKeyword <- 'MILLISECOND' / 'MILLISECONDS'\n" - "MicrosecondKeyword <- 'MICROSECOND' / 'MICROSECONDS'\n" - "WeekKeyword <- 'WEEK' / 'WEEKS'\n" - "QuarterKeyword <- 'QUARTER' / 'QUARTERS'\n" - "DecadeKeyword <- 'DECADE' / 'DECADES'\n" - "CenturyKeyword <- 'CENTURY' / 'CENTURIES'\n" - "MillenniumKeyword <- 'MILLENNIUM' / 'MILLENNIA'\n" - "Interval <- IntervalToInterval /\n" - " YearKeyword /\n" - " MonthKeyword /\n" - " DayKeyword /\n" - " HourKeyword /\n" - " MinuteKeyword /\n" - " SecondKeyword /\n" - " MillisecondKeyword /\n" - " MicrosecondKeyword /\n" - 
" WeekKeyword /\n" - " QuarterKeyword /\n" - " DecadeKeyword /\n" - " CenturyKeyword /\n" - " MillenniumKeyword\n" - "IntervalToInterval <- (YearKeyword 'TO' MonthKeyword) /\n" - " (DayKeyword 'TO' HourKeyword) /\n" - " (DayKeyword 'TO' MinuteKeyword) /\n" - " (DayKeyword 'TO' SecondKeyword) /\n" - " (HourKeyword 'TO' MinuteKeyword) /\n" - " (HourKeyword 'TO' SecondKeyword) /\n" - " (MinuteKeyword 'TO' SecondKeyword)\n" - "BitType <- 'BIT' 'VARYING'? Parens(List(Expression))?\n" - "GeometryType <- 'GEOMETRY' Parens(Expression)?\n" - "VariantType <- 'VARIANT'\n" - "NumericType <- SimpleNumericType / DecimalNumericType\n" - "SimpleNumericType <- IntType / IntegerType / SmallintType /\n" - " BigintType / RealType / BooleanType / DoubleType\n" - "DecimalNumericType <- FloatType / DecimalType / DecType / NumericModType\n" - "IntType <- 'INT'\n" - "IntegerType <- 'INTEGER'\n" - "SmallintType <- 'SMALLINT'\n" - "BigintType <- 'BIGINT'\n" - "RealType <- 'REAL'\n" - "BooleanType <- 'BOOLEAN'\n" - "DoubleType <- ('DOUBLE' 'PRECISION')\n" - "FloatType <- 'FLOAT' Parens(NumberLiteral)?\n" - "DecimalType <- 'DECIMAL' TypeModifiers?\n" - "DecType <- 'DEC' TypeModifiers?\n" - "NumericModType <- 'NUMERIC' TypeModifiers?\n" - "QualifiedTypeName <- (Identifier '.')* TypeName\n" - "TypeModifiers <- Parens(List(Expression)?)\n" - "RowType <- RowOrStruct ColIdTypeList\n" - "SetofType <- 'SETOF' Type\n" - "UnionType <- 'UNION' ColIdTypeList\n" - "ColIdTypeList <- Parens(List(ColIdType))\n" - "MapType <- 'MAP' Parens(List(Type))\n" - "ColIdType <- ColId Type\n" - "ArrayBounds <- SquareBracketsArray / 'ARRAY'\n" - "SquareBracketsArray <- '[' Expression? ']'\n" - "TimeType <- TimeOrTimestamp TypeModifiers? 
TimeZone?\n" - "TimeOrTimestamp <- TimeTypeId / TimestampTypeId\n" - "TimeTypeId <- 'TIME'\n" - "TimestampTypeId <- 'TIMESTAMP'\n" - "TimeZone <- WithOrWithout 'TIME' 'ZONE'\n" - "WithOrWithout <- WithRule / WithoutRule\n" - "WithRule <- 'WITH'\n" - "WithoutRule <- 'WITHOUT'\n" - "RowOrStruct <- 'ROW' / 'STRUCT'\n" - "# internal definitions\n" - "%whitespace <- [ \\t\\n\\r]*\n" - "List(D) <- D (',' D)* ','?\n" - "Parens(D) <- '(' D ')'\n" - "UnreservedKeyword <- 'ABORT' /\n" - "'ABSOLUTE' /\n" - "'ACCESS' /\n" - "'ACTION' /\n" - "'ADD' /\n" - "'ADMIN' /\n" - "'AFTER' /\n" - "'AGGREGATE' /\n" - "'ALSO' /\n" - "'ALTER' /\n" - "'ALWAYS' /\n" - "'ASSERTION' /\n" - "'ASSIGNMENT' /\n" - "'ATTACH' /\n" - "'ATTRIBUTE' /\n" - "'BACKWARD' /\n" - "'BEFORE' /\n" - "'BEGIN' /\n" - "'CACHE' /\n" - "'CALL' /\n" - "'CALLED' /\n" - "'CASCADE' /\n" - "'CASCADED' /\n" - "'CATALOG' /\n" - "'CENTURY' /\n" - "'CENTURIES' /\n" - "'CHAIN' /\n" - "'CHARACTERISTICS' /\n" - "'CHECKPOINT' /\n" - "'CLASS' /\n" - "'CLOSE' /\n" - "'CLUSTER' /\n" - "'COMMENT' /\n" - "'COMMENTS' /\n" - "'COMMIT' /\n" - "'COMMITTED' /\n" - "'COMPRESSION' /\n" - "'CONFIGURATION' /\n" - "'CONFLICT' /\n" - "'CONNECTION' /\n" - "'CONSTRAINTS' /\n" - "'CONTENT' /\n" - "'CONTINUE' /\n" - "'CONVERSION' /\n" - "'COPY' /\n" - "'COST' /\n" - "'CSV' /\n" - "'CUBE' /\n" - "'CURRENT' /\n" - "'CURSOR' /\n" - "'CYCLE' /\n" - "'DATA' /\n" - "'DATABASE' /\n" - "'DAY' /\n" - "'DAYS' /\n" - "'DEALLOCATE' /\n" - "'DECADE' /\n" - "'DECADES' /\n" - "'DECLARE' /\n" - "'DEFAULTS' /\n" - "'DEFERRED' /\n" - "'DEFINER' /\n" - "'DELETE' /\n" - "'DELIMITER' /\n" - "'DELIMITERS' /\n" - "'DEPENDS' /\n" - "'DETACH' /\n" - "'DICTIONARY' /\n" - "'DISABLE' /\n" - "'DISCARD' /\n" - "'DOCUMENT' /\n" - "'DOMAIN' /\n" - "'DOUBLE' /\n" - "'DROP' /\n" - "'EACH' /\n" - "'ENABLE' /\n" - "'ENCODING' /\n" - "'ENCRYPTED' /\n" - "'ENUM' /\n" - "'ERROR' /\n" - "'ESCAPE' /\n" - "'EVENT' /\n" - "'EXCLUDE' /\n" - "'EXCLUDING' /\n" - "'EXCLUSIVE' /\n" - "'EXECUTE' 
/\n" - "'EXPLAIN' /\n" - "'EXPORT' /\n" - "'EXPORT_STATE' /\n" - "'EXTENSION' /\n" - "'EXTENSIONS' /\n" - "'EXTERNAL' /\n" - "'FAMILY' /\n" - "'FILTER' /\n" - "'FIRST' /\n" - "'FOLLOWING' /\n" - "'FORCE' /\n" - "'FORWARD' /\n" - "'FUNCTION' /\n" - "'FUNCTIONS' /\n" - "'GLOBAL' /\n" - "'GRANT' /\n" - "'GRANTED' /\n" - "'GROUPS' /\n" - "'HANDLER' /\n" - "'HEADER' /\n" - "'HOLD' /\n" - "'HOUR' /\n" - "'HOURS' /\n" - "'IDENTITY' /\n" - "'IF' /\n" - "'IGNORE' /\n" - "'IMMEDIATE' /\n" - "'IMMUTABLE' /\n" - "'IMPLICIT' /\n" - "'IMPORT' /\n" - "'INCLUDE' /\n" - "'INCLUDING' /\n" - "'INCREMENT' /\n" - "'INDEX' /\n" - "'INDEXES' /\n" - "'INHERIT' /\n" - "'INHERITS' /\n" - "'INLINE' /\n" - "'INPUT' /\n" - "'INSENSITIVE' /\n" - "'INSERT' /\n" - "'INSTALL' /\n" - "'INSTEAD' /\n" - "'INVOKER' /\n" - "'JSON' /\n" - "'ISOLATION' /\n" - "'KEY' /\n" - "'LABEL' /\n" - "'LANGUAGE' /\n" - "'LARGE' /\n" - "'LAST' /\n" - "'LEAKPROOF' /\n" - "'LEVEL' /\n" - "'LISTEN' /\n" - "'LOAD' /\n" - "'LOCAL' /\n" - "'LOCATION' /\n" - "'LOCK' /\n" - "'LOCKED' /\n" - "'LOGGED' /\n" - "'MACRO' /\n" - "'MAPPING' /\n" - "'MATCH' /\n" - "'MATCHED' /\n" - "'MATERIALIZED' /\n" - "'MAXVALUE' /\n" - "'MERGE' /\n" - "'METHOD' /\n" - "'MICROSECOND' /\n" - "'MICROSECONDS' /\n" - "'MILLENNIUM' /\n" - "'MILLENNIA' /\n" - "'MILLISECOND' /\n" - "'MILLISECONDS' /\n" - "'MINUTE' /\n" - "'MINUTES' /\n" - "'MINVALUE' /\n" - "'MODE' /\n" - "'MONTH' /\n" - "'MONTHS' /\n" - "'MOVE' /\n" - "'NAME' /\n" - "'NAMES' /\n" - "'NEW' /\n" - "'NEXT' /\n" - "'NO' /\n" - "'NOTHING' /\n" - "'NOTIFY' /\n" - "'NOWAIT' /\n" - "'NULLS' /\n" - "'OBJECT' /\n" - "'OF' /\n" - "'OFF' /\n" - "'OIDS' /\n" - "'OLD' /\n" - "'OPERATOR' /\n" - "'OPTION' /\n" - "'OPTIONS' /\n" - "'ORDINALITY' /\n" - "'OTHERS' /\n" - "'OVER' /\n" - "'OVERRIDING' /\n" - "'OWNED' /\n" - "'OWNER' /\n" - "'PARALLEL' /\n" - "'PARSER' /\n" - "'PARTIAL' /\n" - "'PARTITION' /\n" - "'PARTITIONED' /\n" - "'PASSING' /\n" - "'PASSWORD' /\n" - "'PERCENT' /\n" - "'PERSISTENT' /\n" 
- "'PLANS' /\n" - "'POLICY' /\n" - "'PRAGMA' /\n" - "'PRECEDING' /\n" - "'PREPARE' /\n" - "'PREPARED' /\n" - "'PRESERVE' /\n" - "'PRIOR' /\n" - "'PRIVILEGES' /\n" - "'PROCEDURAL' /\n" - "'PROCEDURE' /\n" - "'PROGRAM' /\n" - "'PUBLICATION' /\n" - "'QUARTER' /\n" - "'QUARTERS' /\n" - "'QUOTE' /\n" - "'RANGE' /\n" - "'READ' /\n" - "'REASSIGN' /\n" - "'RECHECK' /\n" - "'RECURSIVE' /\n" - "'REF' /\n" - "'REFERENCING' /\n" - "'REFRESH' /\n" - "'REINDEX' /\n" - "'RELATIVE' /\n" - "'RELEASE' /\n" - "'RENAME' /\n" - "'REPEATABLE' /\n" - "'REPLACE' /\n" - "'REPLICA' /\n" - "'RESET' /\n" - "'RESPECT' /\n" - "'RESTART' /\n" - "'RESTRICT' /\n" - "'RETURNS' /\n" - "'REVOKE' /\n" - "'ROLE' /\n" - "'ROLLBACK' /\n" - "'ROLLUP' /\n" - "'ROWS' /\n" - "'RULE' /\n" - "'SAMPLE' /\n" - "'SAVEPOINT' /\n" - "'SCHEMA' /\n" - "'SCHEMAS' /\n" - "'SCOPE' /\n" - "'SCROLL' /\n" - "'SEARCH' /\n" - "'SECRET' /\n" - "'SECOND' /\n" - "'SECONDS' /\n" - "'SECURITY' /\n" - "'SEQUENCE' /\n" - "'SEQUENCES' /\n" - "'SERIALIZABLE' /\n" - "'SERVER' /\n" - "'SESSION' /\n" - "'SET' /\n" - "'SETS' /\n" - "'SHARE' /\n" - "'SIMPLE' /\n" - "'SKIP' /\n" - "'SNAPSHOT' /\n" - "'SORTED' /\n" - "'SOURCE' /\n" - "'SQL' /\n" - "'STABLE' /\n" - "'STANDALONE' /\n" - "'START' /\n" - "'STATEMENT' /\n" - "'STATISTICS' /\n" - "'STDIN' /\n" - "'STDOUT' /\n" - "'STORAGE' /\n" - "'STORED' /\n" - "'STRICT' /\n" - "'STRIP' /\n" - "'SUBSCRIPTION' /\n" - "'SYSID' /\n" - "'SYSTEM' /\n" - "'TABLES' /\n" - "'TABLESPACE' /\n" - "'TARGET' /\n" - "'TEMP' /\n" - "'TEMPLATE' /\n" - "'TEMPORARY' /\n" - "'TEXT' /\n" - "'TIES' /\n" - "'TRANSACTION' /\n" - "'TRANSFORM' /\n" - "'TRIGGER' /\n" - "'TRUNCATE' /\n" - "'TRUSTED' /\n" - "'TYPE' /\n" - "'TYPES' /\n" - "'UNBOUNDED' /\n" - "'UNCOMMITTED' /\n" - "'UNENCRYPTED' /\n" - "'UNKNOWN' /\n" - "'UNLISTEN' /\n" - "'UNLOGGED' /\n" - "'UNTIL' /\n" - "'UPDATE' /\n" - "'USE' /\n" - "'USER' /\n" - "'VACUUM' /\n" - "'VALID' /\n" - "'VALIDATE' /\n" - "'VALIDATOR' /\n" - "'VALUE' /\n" - "'VARIABLE' /\n" - 
"'VARYING' /\n" - "'VERSION' /\n" - "'VIEW' /\n" - "'VIEWS' /\n" - "'VIRTUAL' /\n" - "'VOLATILE' /\n" - "'WEEK' /\n" - "'WEEKS' /\n" - "'WHITESPACE' /\n" - "'WITHIN' /\n" - "'WITHOUT' /\n" - "'WORK' /\n" - "'WRAPPER' /\n" - "'WRITE' /\n" - "'XML' /\n" - "'YEAR' /\n" - "'YEARS' /\n" - "'YES' /\n" - "'ZONE'\n" - "ReservedKeyword <- 'ALL' /\n" - "'ANALYSE' /\n" - "'ANALYZE' /\n" - "'AND' /\n" - "'ANY' /\n" - "'ARRAY' /\n" - "'AS' /\n" - "'ASC' /\n" - "'ASYMMETRIC' /\n" - "'BOTH' /\n" - "'CASE' /\n" - "'CAST' /\n" - "'CHECK' /\n" - "'COLLATE' /\n" - "'COLUMN' /\n" - "'CONSTRAINT' /\n" - "'CREATE' /\n" - "'DEFAULT' /\n" - "'DEFERRABLE' /\n" - "'DESC' /\n" - "'DESCRIBE' /\n" - "'DISTINCT' /\n" - "'DO' /\n" - "'ELSE' /\n" - "'END' /\n" - "'EXCEPT' /\n" - "'FALSE' /\n" - "'FETCH' /\n" - "'FOR' /\n" - "'FOREIGN' /\n" - "'FROM' /\n" - "'GROUP' /\n" - "'HAVING' /\n" - "'QUALIFY' /\n" - "'IN' /\n" - "'INITIALLY' /\n" - "'INTERSECT' /\n" - "'INTO' /\n" - "'LAMBDA' /\n" - "'LATERAL' /\n" - "'LEADING' /\n" - "'LIMIT' /\n" - "'NOT' /\n" - "'NULL' /\n" - "'OFFSET' /\n" - "'ON' /\n" - "'ONLY' /\n" - "'OR' /\n" - "'ORDER' /\n" - "'PIVOT' /\n" - "'PIVOT_WIDER' /\n" - "'PIVOT_LONGER' /\n" - "'PLACING' /\n" - "'PRIMARY' /\n" - "'REFERENCES' /\n" - "'RETURNING' /\n" - "'SELECT' /\n" - "'SHOW' /\n" - "'SOME' /\n" - "'SUMMARIZE' /\n" - "'SYMMETRIC' /\n" - "'TABLE' /\n" - "'THEN' /\n" - "'TO' /\n" - "'TRAILING' /\n" - "'TRUE' /\n" - "'UNION' /\n" - "'UNIQUE' /\n" - "'UNPIVOT' /\n" - "'USING' /\n" - "'VARIADIC' /\n" - "'WHEN' /\n" - "'WHERE' /\n" - "'WINDOW' /\n" - "'WITH'\n" - "ColumnNameKeyword <- 'BETWEEN' /\n" - "'BIGINT' /\n" - "'BIT' /\n" - "'BOOLEAN' /\n" - "'CHAR' /\n" - "'CHARACTER' /\n" - "'COALESCE' /\n" - "'COLUMNS' /\n" - "'DEC' /\n" - "'DECIMAL' /\n" - "'EXISTS' /\n" - "'EXTRACT' /\n" - "'FLOAT' /\n" - "'GENERATED' /\n" - "'GROUPING' /\n" - "'GROUPING_ID' /\n" - "'INOUT' /\n" - "'INT' /\n" - "'INTEGER' /\n" - "'INTERVAL' /\n" - "'MAP' /\n" - "'NATIONAL' /\n" - "'NCHAR' /\n" - 
"'NONE' /\n" - "'NULLIF' /\n" - "'NUMERIC' /\n" - "'OUT' /\n" - "'OVERLAY' /\n" - "'POSITION' /\n" - "'PRECISION' /\n" - "'REAL' /\n" - "'ROW' /\n" - "'SETOF' /\n" - "'SMALLINT' /\n" - "'SUBSTRING' /\n" - "'STRUCT' /\n" - "'TIME' /\n" - "'TIMESTAMP' /\n" - "'TREAT' /\n" - "'TRIM' /\n" - "'TRY_CAST' /\n" - "'VALUES' /\n" - "'VARCHAR' /\n" - "'XMLATTRIBUTES' /\n" - "'XMLCONCAT' /\n" - "'XMLELEMENT' /\n" - "'XMLEXISTS' /\n" - "'XMLFOREST' /\n" - "'XMLNAMESPACES' /\n" - "'XMLPARSE' /\n" - "'XMLPI' /\n" - "'XMLROOT' /\n" - "'XMLSERIALIZE' /\n" - "'XMLTABLE'\n" - "FuncNameKeyword <- 'ASOF' /\n" - "'AT' /\n" - "'AUTHORIZATION' /\n" - "'BINARY' /\n" - "'COLLATION' /\n" - "'CONCURRENTLY' /\n" - "'CROSS' /\n" - "'FREEZE' /\n" - "'FULL' /\n" - "'GENERATED' /\n" - "'GLOB' /\n" - "'ILIKE' /\n" - "'INNER' /\n" - "'IS' /\n" - "'ISNULL' /\n" - "'JOIN' /\n" - "'LEFT' /\n" - "'LIKE' /\n" - "'MAP' /\n" - "'NATURAL' /\n" - "'NOTNULL' /\n" - "'OUTER' /\n" - "'OVERLAPS' /\n" - "'POSITIONAL' /\n" - "'RIGHT' /\n" - "'SIMILAR' /\n" - "'STRUCT' /\n" - "'TABLESAMPLE' /\n" - "'VERBOSE'\n" - "TypeNameKeyword <- 'ASOF' /\n" - "'AT' /\n" - "'AUTHORIZATION' /\n" - "'BINARY' /\n" - "'BY' /\n" - "'COLLATION' /\n" - "'COLUMNS' /\n" - "'CONCURRENTLY' /\n" - "'CROSS' /\n" - "'FREEZE' /\n" - "'FULL' /\n" - "'GLOB' /\n" - "'ILIKE' /\n" - "'INNER' /\n" - "'IS' /\n" - "'ISNULL' /\n" - "'JOIN' /\n" - "'LEFT' /\n" - "'LIKE' /\n" - "'NATURAL' /\n" - "'NOTNULL' /\n" - "'OUTER' /\n" - "'OVERLAPS' /\n" - "'POSITIONAL' /\n" - "'RIGHT' /\n" - "'UNPACK' /\n" - "'SIMILAR' /\n" - "'TABLESAMPLE' /\n" - "'TRY_CAST' /\n" - "'VERBOSE' /\n" - "'SEMI' /\n" - "'ANTI'\n" - "PivotStatement <- PivotKeyword TableRef PivotOn? PivotUsing? 
PivotGroupByList?\n" - "PivotOn <- 'ON' PivotColumnList\n" - "PivotUsing <- 'USING' TargetList\n" - "PivotColumnList <- List(PivotColumnEntry)\n" - "PivotColumnEntry <- PivotColumnSubquery / PivotValueList / PivotColumnEntryInternal\n" - "PivotColumnEntryInternal <- Expression\n" - "PivotColumnSubquery <- BaseExpression 'IN' Parens(SelectStatementInternal)\n" - "PivotKeyword <- 'PIVOT' / 'PIVOT_WIDER'\n" - "UnpivotKeyword <- 'UNPIVOT' / 'PIVOT_LONGER'\n" - "UnpivotStatement <- UnpivotKeyword TableRef 'ON' TargetList IntoNameValues?\n" - "IntoNameValues <- 'INTO' 'NAME' ColIdOrString ValueOrValues List(Identifier)\n" - "ValueOrValues <- 'VALUE' / 'VALUES'\n" - "IncludeOrExcludeNulls <- IncludeNulls / ExcludeNulls\n" - "IncludeNulls <- 'INCLUDE' 'NULLS'\n" - "ExcludeNulls <- 'EXCLUDE' 'NULLS'\n" - "UnpivotHeader <- UnpivotHeaderSingle / UnpivotHeaderList\n" - "UnpivotHeaderSingle <- ColIdOrString\n" - "UnpivotHeaderList <- Parens(List(ColIdOrString))\n" - "ColumnReference <- CatalogReservedSchemaTableColumnName / SchemaReservedTableColumnName / TableReservedColumnName / NestedColumnName\n" - "CatalogReservedSchemaTableColumnName <- CatalogQualification ReservedSchemaQualification ReservedTableQualification ReservedColumnName\n" - "SchemaReservedTableColumnName <- SchemaQualification ReservedTableQualification ReservedColumnName\n" - "TableReservedColumnName <- TableQualification ReservedColumnName\n" - "FunctionExpression <- FunctionIdentifier Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? IgnoreOrRespectNulls?) WithinGroupClause? FilterClause? ExportClause? OverClause?\n" - "FunctionIdentifier <- CatalogReservedSchemaFunctionName / SchemaReservedFunctionName / FunctionName\n" - "CatalogReservedSchemaFunctionName <- CatalogQualification ReservedSchemaQualification? 
ReservedFunctionName\n" - "SchemaReservedFunctionName <- SchemaQualification ReservedFunctionName\n" - "DistinctOrAll <- 'DISTINCT' / 'ALL'\n" - "ExportClause <- 'EXPORT_STATE'\n" - "WithinGroupClause <- 'WITHIN' 'GROUP' Parens(OrderByClause)\n" - "FilterClause <- 'FILTER' Parens('WHERE'? Expression)\n" - "IgnoreOrRespectNulls <- IgnoreNulls / RespectNulls\n" - "IgnoreNulls <- 'IGNORE' 'NULLS'\n" - "RespectNulls <- 'RESPECT' 'NULLS'\n" - "ParenthesisExpression <- Parens(List(Expression))\n" - "LiteralExpression <- StringLiteral / NumberLiteral / ConstantLiteral\n" - "ConstantLiteral <- NullLiteral / TrueLiteral / FalseLiteral\n" - "NullLiteral <- 'NULL'\n" - "TrueLiteral <- 'TRUE'\n" - "FalseLiteral <- 'FALSE'\n" - "CastExpression <- CastOrTryCast Parens(Expression 'AS' Type)\n" - "CastOrTryCast <- 'CAST' / 'TRY_CAST'\n" - "StarExpression <- (ColId '.')* '*' ExcludeList? ReplaceList? RenameList?\n" - "ExcludeList <- 'EXCLUDE' ExcludeNameList / ExcludeNameSingle\n" - "ExcludeNameList <- Parens(List(ExcludeName))\n" - "ExcludeNameSingle <- ExcludeName\n" - "ExcludeName <- DottedIdentifier / ColIdOrString\n" - "ReplaceList <- 'REPLACE' ReplaceEntries\n" - "ReplaceEntries <- ReplaceEntrySingle / ReplaceEntryList\n" - "ReplaceEntrySingle <- ReplaceEntry\n" - "ReplaceEntryList <- Parens(List(ReplaceEntry))\n" - "ReplaceEntry <- Expression 'AS' ColumnReference\n" - "RenameList <- 'RENAME' (RenameEntryList / SingleRenameEntry)\n" - "RenameEntryList <- Parens(List(RenameEntry))\n" - "SingleRenameEntry <- RenameEntry\n" - "RenameEntry <- ExcludeName 'AS' Identifier\n" - "SubqueryExpression <- 'NOT'? 'EXISTS'? SubqueryReference\n" - "CaseExpression <- 'CASE' Expression? CaseWhenThen+ CaseElse? 
'END'\n" - "CaseWhenThen <- 'WHEN' Expression 'THEN' Expression\n" - "CaseElse <- 'ELSE' Expression\n" - "TypeLiteral <- ColId StringLiteral\n" - "IntervalLiteral <- 'INTERVAL' IntervalParameter Interval?\n" - "IntervalParameter <- StringLiteral / NumberLiteral / ParensExpression\n" - "FrameClause <- Framing FrameExtent WindowExcludeClause?\n" - "Framing <- 'ROWS' / 'RANGE' / 'GROUPS'\n" - "FrameExtent <- BetweenFrameExtent / SingleFrameExtent\n" - "SingleFrameExtent <- FrameBound\n" - "BetweenFrameExtent <- 'BETWEEN' FrameBound 'AND' FrameBound\n" - "FrameBound <- FrameUnbounded / FrameCurrentRow / FrameExpression\n" - "FrameUnbounded <- 'UNBOUNDED' PrecedingOrFollowing\n" - "FrameExpression <- Expression PrecedingOrFollowing\n" - "FrameCurrentRow <- 'CURRENT' 'ROW'\n" - "PrecedingOrFollowing <- 'PRECEDING' / 'FOLLOWING'\n" - "WindowExcludeClause <- 'EXCLUDE' WindowExcludeElement\n" - "WindowExcludeElement <- ExcludeCurrentRow / ExcludeGroup / ExcludeTies / ExcludeNoOthers\n" - "ExcludeCurrentRow <- 'CURRENT' 'ROW'\n" - "ExcludeGroup <- 'GROUP'\n" - "ExcludeTies <- 'TIES'\n" - "ExcludeNoOthers <- 'NO' 'OTHERS'\n" - "OverClause <- 'OVER' WindowFrame\n" - "WindowFrame <- ParensIdentifier / WindowFrameDefinition / Identifier\n" - "ParensIdentifier <- Parens(Identifier)\n" - "WindowFrameDefinition <- WindowFrameNameContentsParens / WindowFrameContentsParens\n" - "WindowFrameNameContentsParens <- Parens(BaseWindowName? WindowFrameContents)\n" - "WindowFrameContentsParens <- Parens(WindowFrameContents)\n" - "WindowFrameContents <- WindowPartition? OrderByClause? FrameClause?\n" - "BaseWindowName <- Identifier\n" - "WindowPartition <- 'PARTITION' 'BY' List(Expression)\n" - "ListExpression <- ArrayBoundedListExpression / ArrayParensSelect\n" - "ArrayBoundedListExpression <- 'ARRAY'? BoundedListExpression\n" - "ArrayParensSelect <- 'ARRAY' Parens(SelectStatementInternal)\n" - "BoundedListExpression <- '[' List(Expression)? 
']'\n" - "StructExpression <- '{' List(StructField) '}'\n" - "StructField <- ColIdOrString ':' Expression\n" - "MapExpression <- 'MAP' MapStructExpression\n" - "MapStructExpression <- '{' List(MapStructField)? '}'\n" - "MapStructField <- Expression ':' Expression\n" - "GroupingExpression <- GroupingOrGroupingId Parens(List(Expression)?)\n" - "GroupingOrGroupingId <- 'GROUPING' / 'GROUPING_ID'\n" - "Parameter <- QuestionMarkNumberedParameter / AnonymousParameter / NumberedParameter / ColLabelParameter\n" - "QuestionMarkNumberedParameter <- '?' NumberLiteral\n" - "AnonymousParameter <- '?'\n" - "NumberedParameter <- '$' NumberLiteral\n" - "ColLabelParameter <- '$' ColLabel\n" - "PositionalExpression <- '#' NumberLiteral\n" - "DefaultExpression <- 'DEFAULT'\n" - "ListComprehensionExpression <- '[' Expression 'FOR' List(ColIdOrString) 'IN' Expression ListComprehensionFilter? ']'\n" - "ListComprehensionFilter <- 'IF' Expression\n" - "ParensExpression <- Parens(Expression)\n" - "SingleExpression <-\n" - " ParensExpression /\n" - " LiteralExpression /\n" - " Parameter /\n" - " SubqueryExpression /\n" - " SpecialFunctionExpression /\n" - " ParenthesisExpression /\n" - " IntervalLiteral /\n" - " TypeLiteral /\n" - " CaseExpression /\n" - " StarExpression /\n" - " CastExpression /\n" - " GroupingExpression /\n" - " MapExpression /\n" - " FunctionExpression /\n" - " ColumnReference /\n" - " ListComprehensionExpression /\n" - " ListExpression /\n" - " StructExpression /\n" - " PositionalExpression /\n" - " DefaultExpression\n" - "# LEVEL 1 (Lowest)\n" - "Expression <- LambdaArrowExpression\n" - "ColumnDefaultExpr <- ColDefOrExpr\n" - "# LEVEL 1.5\n" - "LambdaArrowExpression <- LogicalOrExpression SingleArrowPair*\n" - "SingleArrowPair <- '->' LogicalOrExpression\n" - "LogicalOrExpression <- LogicalAndExpression ('OR' LogicalAndExpression)*\n" - "ColDefOrExpr <- ColDefAndExpr ('OR' ColDefAndExpr)*\n" - "# LEVEL 2\n" - "LogicalAndExpression <- LogicalNotExpression ('AND' 
LogicalNotExpression)*\n" - "ColDefAndExpr <- IsDistinctFromExpression ('AND' IsDistinctFromExpression)*\n" - "# LEVEL 3\n" - "LogicalNotExpression <- 'NOT'* IsExpression\n" - "# LEVEL 4\n" - "IsExpression <- IsDistinctFromExpression IsTest*\n" - "IsTest <- IsLiteral / NotNull / IsNull\n" - "IsLiteral <- 'IS' 'NOT'? (TrueLiteral / FalseLiteral / NullLiteral / UnknownLiteral)\n" - "UnknownLiteral <- 'UNKNOWN'\n" - "NotNull <- ('NOT' 'NULL') / 'NOTNULL'\n" - "IsNull <- 'ISNULL'\n" - "# LEVEL 5 (Split because IsDistinctFromExpression allows post expression while IsOperator does not)\n" - "IsDistinctFromExpression <- ComparisonExpression (IsDistinctFromOp ComparisonExpression)*\n" - "IsDistinctFromOp <- 'IS' 'NOT'? 'DISTINCT' 'FROM'\n" - "# LEVEL 6\n" - "ComparisonExpression <- BetweenInLikeExpression (ComparisonOperator 'NOT'* BetweenInLikeExpression)*\n" - "ComparisonOperator <-\n" - " OperatorEqual /\n" - " OperatorNotEqual /\n" - " OperatorLessThan /\n" - " OperatorGreaterThan /\n" - " OperatorLessThanEquals /\n" - " OperatorGreaterThanEquals\n" - "OperatorEqual <- '=' / '=='\n" - "OperatorNotEqual <- '!=' / '<>'\n" - "OperatorLessThan <- '<'\n" - "OperatorGreaterThan <- '>'\n" - "OperatorLessThanEquals <- '<='\n" - "OperatorGreaterThanEquals <- '>='\n" - "# LEVEL 7\n" - "BetweenInLikeExpression <- OtherOperatorExpression BetweenInLikeOp?\n" - "BetweenInLikeOp <- 'NOT'? 
(BetweenClause / InClause / LikeClause)\n" - "LikeClause <- LikeVariations OtherOperatorExpression EscapeClause?\n" - "EscapeClause <- 'ESCAPE' ComparisonExpression\n" - "LikeVariations <- SimilarToToken / ILikeToken / LikeToken / GlobToken / NotILikeOp / NotLikeOp / NotSimilarToOp\n" - "LikeToken <- 'LIKE' / '~~'\n" - "ILikeToken <- 'ILIKE' / '~~*'\n" - "GlobToken <- 'GLOB' / '~~~'\n" - "SimilarToToken <- ('SIMILAR' 'TO') / '~'\n" - "NotILikeOp <- '!~~*'\n" - "NotLikeOp <- '!~~'\n" - "NotSimilarToOp <- '!~'\n" - "InClause <- 'IN' InExpression\n" - "InExpression <- InExpressionList / InSelectStatement / OtherOperatorExpression\n" - "InExpressionList <- Parens(List(Expression))\n" - "InSelectStatement <- Parens(SelectStatementInternal)\n" - "BetweenClause <- 'BETWEEN' OtherOperatorExpression 'AND' OtherOperatorExpression\n" - "# LEVEL 8\n" - "OtherOperatorExpression <- BitwiseExpression (OtherOperator BitwiseExpression)*\n" - "OtherOperator <-\n" - " QualifiedOperator / AnyAllOperator / InetOperator / JsonOperator / ListOperator / StringOperator / OperatorLiteral\n" - "OperatorLiteral <- Identifier\n" - "AnyAllOperator <- AnyOp AnyOrAll\n" - "AnyOrAll <- SubqueryAny / SubqueryAll\n" - "SubqueryAny <- 'ANY'\n" - "SubqueryAll <- 'ALL'\n" - "InetOperator <- '>>=' / '<<='\n" - "JsonOperator <- '->>'\n" - "ListOperator <- '&&' / '@>' / '<@'\n" - "StringOperator <- '^@' / '||'\n" - "QualifiedOperator <- 'OPERATOR' Parens(AnyOp)\n" - "AnyOp <- '!~~*' / '>>=' / '<<=' / '->>' / '!~~' / '~~*' / '~~~' / '!~' / '^@' / '||' / '&&' / '@>' / '<@' / '<=' / '>=' / '<>' / '!=' / '==' / '<<' / '>>' / '//' / '**' / '->' / '~~' / '+' / '-' / '*' / '/' / '%' / '^' / '<' / '>' / '=' / '&' / '|' / '~' / '!'\n" - "# LEVEL 9\n" - "BitwiseExpression <- AdditiveExpression (BitOperator AdditiveExpression)*\n" - "BitOperator <- '&' / '|' / '<<' / '>>'\n" - "# LEVEL 10\n" - "AdditiveExpression <- MultiplicativeExpression (Term MultiplicativeExpression)*\n" - "Term <- '+' / '-'\n" - "# LEVEL 11\n" 
- "MultiplicativeExpression <- ExponentiationExpression (Factor ExponentiationExpression)*\n" - "Factor <- '*' / '/' / '//' / '%'\n" - "# LEVEL 12\n" - "ExponentiationExpression <- CollateExpression (ExponentOperator CollateExpression)*\n" - "ExponentOperator <- '^' / '**'\n" - "# LEVEL 13\n" - "CollateExpression <- AtTimeZoneExpression (CollateOperator AtTimeZoneExpression)*\n" - "CollateOperator <- 'COLLATE'\n" - "# LEVEL 14\n" - "AtTimeZoneExpression <- PrefixExpression (AtTimeZoneOperator PrefixExpression)*\n" - "AtTimeZoneOperator <- 'AT' 'TIME' 'ZONE'\n" - "# LEVEL 15\n" - "PrefixExpression <- PrefixOperator* BaseExpression\n" - "PrefixOperator <- QualifiedOperator / MinusPrefixOperator / PlusPrefixOperator / TildePrefixOperator\n" - "MinusPrefixOperator <- '-'\n" - "PlusPrefixOperator <- '+'\n" - "TildePrefixOperator <- '~'\n" - "# LEVEL 16 (Highest)\n" - "BaseExpression <- SingleExpression Indirection*\n" - "Indirection <- CastOperator / DotOperator / SliceExpression / PostfixOperator\n" - "CastOperator <- '::' Type\n" - "DotOperator <- '.' (MethodExpression / ColLabel)\n" - "MethodExpression <- ColLabel Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? IgnoreOrRespectNulls?)\n" - "SliceExpression <- '[' SliceBound ']'\n" - "SliceBound <- Expression? EndSliceBound? StepSliceBound?\n" - "EndSliceBound <- ':' (Expression / '-')?\n" - "StepSliceBound <- ':' Expression?\n" - "PostfixOperator <- '!'\n" - "SpecialFunctionExpression <- CoalesceExpression / UnpackExpression / TryExpression / ColumnsExpression / ExtractExpression / LambdaExpression / NullIfExpression / PositionExpression / RowExpression / SubstringExpression / TrimExpression\n" - "CoalesceExpression <- 'COALESCE' Parens(List(Expression))\n" - "UnpackExpression <- 'UNPACK' Parens(Expression)\n" - "TryExpression <- 'TRY' Parens(Expression)\n" - "ColumnsExpression <- '*'? 
'COLUMNS' Parens(Expression)\n" - "ExtractExpression <- 'EXTRACT' Parens(ExtractArgument 'FROM' Expression)\n" - "LambdaExpression <- 'LAMBDA' List(ColIdOrString) ':' Expression\n" - "NullIfExpression <- 'NULLIF' Parens(Expression ',' Expression)\n" - "PositionExpression <- 'POSITION' Parens(SingleExpression 'IN' SingleExpression)\n" - "RowExpression <- 'ROW' Parens(List(Expression)?)\n" - "SubstringExpression <- 'SUBSTRING' Parens(SubstringArguments)\n" - "SubstringArguments <- SubstringParameters / SubstringExpressionList\n" - "SubstringExpressionList <- List(Expression)\n" - "SubstringParameters <- Expression 'FROM' Expression 'FOR' Expression\n" - "TrimExpression <- 'TRIM' Parens(TrimDirection? TrimSource? List(Expression))\n" - "TrimDirection <- TrimBoth / TrimLeading / TrimTrailing\n" - "TrimBoth <- 'BOTH'\n" - "TrimLeading <- 'LEADING'\n" - "TrimTrailing <- 'TRAILING'\n" - "TrimSource <- Expression? 'FROM'\n" - "ExtractArgument <-\n" - " YearKeyword / MonthKeyword / DayKeyword / HourKeyword / MinuteKeyword / SecondKeyword /\n" - " MillisecondKeyword / MicrosecondKeyword / WeekKeyword / QuarterKeyword / DecadeKeyword /\n" - " CenturyKeyword / MillenniumKeyword / Identifier / StringLiteral\n" - "ExecuteStatement <- 'EXECUTE' Identifier TableFunctionArguments?\n" - "CreateSecretStmt <- 'SECRET' IfNotExists? SecretName? SecretStorageSpecifier? GenericCopyOptionList\n" - "SecretStorageSpecifier <- 'IN' Identifier\n" - "CreateViewStmt <- 'RECURSIVE'? 'VIEW' IfNotExists? QualifiedName InsertColumnList? WithList? 
'AS' SelectStatementInternal\n" - "DescribeStatement <- ShowTables / ShowSelect / ShowAllTables / ShowQualifiedName\n" - "ShowSelect <- ShowOrDescribeOrSummarize SelectStatementInternal\n" - "ShowAllTables <- ShowOrDescribe 'ALL' 'TABLES'\n" - "ShowQualifiedName <- ShowOrDescribeOrSummarize (BaseTableName / StringLiteral)?\n" - "ShowTables <- ShowOrDescribe 'TABLES' 'FROM' QualifiedName\n" - "ShowOrDescribeOrSummarize <- ShowOrDescribe / Summarize\n" - "Summarize <- SummarizeRule\n" - "SummarizeRule <- 'SUMMARIZE'\n" - "ShowOrDescribe <- ShowRule / DescribeRule\n" - "ShowRule <- 'SHOW'\n" - "DescribeRule <- 'DESCRIBE' / 'DESC'\n" - "VacuumStatement <- 'VACUUM' VacuumOptions? AnalyzeTarget?\n" - "VacuumOptions <- VacuumParensOptions / VacuumLegacyOptions\n" - "VacuumParensOptions <- Parens(List(VacuumOption))\n" - "VacuumLegacyOptions <- OptFull? OptFreeze? OptVerbose? OptAnalyze?\n" - "VacuumOption <- OptAnalyze / OptFreeze / OptFull / OptVerbose / Identifier\n" - "OptAnalyze <- 'ANALYZE'\n" - "OptFull <- 'FULL'\n" - "OptFreeze <- 'FREEZE'\n" - "OptVerbose <- 'VERBOSE'\n" - "NameList <- Parens(List(ColId))\n" - "MergeIntoStatement <- WithClause? 'MERGE' 'INTO' TargetOptAlias MergeIntoUsingClause JoinQualifier MergeMatch+ ReturningClause?\n" - "MergeIntoUsingClause <- 'USING' TableRef\n" - "MergeMatch <- MatchedClause / NotMatchedClause\n" - "MatchedClause <- 'WHEN' 'MATCHED' AndExpression? 'THEN' MatchedClauseAction\n" - "MatchedClauseAction <- UpdateMatchClause / DeleteMatchClause / InsertMatchClause / DoNothingMatchClause / ErrorMatchClause\n" - "UpdateMatchClause <- 'UPDATE' UpdateMatchInfo?\n" - "UpdateMatchInfo <- UpdateMatchSetClause / ByNameOrPosition\n" - "DeleteMatchClause <- 'DELETE'\n" - "InsertMatchClause <- 'INSERT' InsertMatchInfo?\n" - "InsertMatchInfo <- InsertValuesList / InsertDefaultValues / InsertByNameOrPosition\n" - "InsertDefaultValues <- 'DEFAULT' 'VALUES'\n" - "InsertByNameOrPosition <- ByNameOrPosition? 
'*'?\n" - "InsertValuesList <- InsertColumnList? 'VALUES' Parens(List(Expression))\n" - "DoNothingMatchClause <- 'DO' 'NOTHING'\n" - "ErrorMatchClause <- 'ERROR' Expression?\n" - "UpdateMatchSetClause <- 'SET' (UpdateSetClause / '*')\n" - "AndExpression <- 'AND' Expression\n" - "NotMatchedClause <- 'WHEN' 'NOT' 'MATCHED' BySourceOrTarget? AndExpression? 'THEN' MatchedClauseAction\n" - "BySourceOrTarget <- BySource / ByTarget\n" - "BySource <- 'BY' 'SOURCE'\n" - "ByTarget <- 'BY' 'TARGET'\n" - "PragmaStatement <- 'PRAGMA' (PragmaAssign / PragmaFunction)\n" - "PragmaAssign <- SettingName '=' VariableList\n" - "PragmaFunction <- PragmaName PragmaParameters?\n" - "PragmaParameters <- Parens(List(Expression))\n" - "DeallocateStatement <- 'DEALLOCATE' 'PREPARE'? Identifier\n" - "PrepareStatement <- 'PREPARE' Identifier TypeList? 'AS' Statement\n" - "TypeList <- Parens(List(Type))\n" - "CreateStatement <- 'CREATE' OrReplace? Temporary? CreateStatementVariation\n" - "CreateStatementVariation <- CreateTableStmt / CreateMacroStmt / CreateSequenceStmt / CreateTypeStmt / CreateSchemaStmt / CreateViewStmt / CreateIndexStmt / CreateSecretStmt / CreateTriggerStmt\n" - "OrReplace <- 'OR' 'REPLACE'\n" - "Temporary <- Persistent / TempPersistent / TemporaryPersistent\n" - "Persistent <- 'PERSISTENT'\n" - "TempPersistent <- 'TEMP'\n" - "TemporaryPersistent <- 'TEMPORARY'\n" - "CreateTableStmt <- 'TABLE' IfNotExists? QualifiedName (CreateTableAs / CreateColumnList) CommitAction?\n" - "CreateTableAs <- IdentifierList? PartitionSortedOptions? WithList? 'AS' Statement WithData?\n" - "PartitionSortedOptions <- PartitionOptSortedOptions / SortedOptPartitionOptions\n" - "PartitionOptSortedOptions <- PartitionOptions SortedOptions?\n" - "SortedOptPartitionOptions <- SortedOptions PartitionOptions?\n" - "PartitionOptions <- 'PARTITIONED' 'BY' Parens(List(Expression))\n" - "SortedOptions <- 'SORTED' 'BY' Parens(List(Expression))\n" - "WithData <- 'WITH' 'NO'? 
'DATA'\n" - "IdentifierList <- Parens(List(Identifier))\n" - "CreateColumnList <- Parens(CreateTableColumnList?) PartitionSortedOptions? WithList?\n" - "IfNotExists <- 'IF' 'NOT' 'EXISTS'\n" - "QualifiedName <- CatalogReservedSchemaIdentifier / SchemaReservedIdentifierOrStringLiteral / IdentifierOrStringLiteral\n" - "SchemaReservedIdentifierOrStringLiteral <- SchemaQualification ReservedIdentifierOrStringLiteral\n" - "CatalogReservedSchemaIdentifier <- CatalogQualification ReservedSchemaQualification ReservedIdentifierOrStringLiteral\n" - "IdentifierOrStringLiteral <- Identifier / StringLiteral\n" - "ReservedIdentifierOrStringLiteral <- ReservedIdentifier / StringLiteral\n" - "CatalogQualification <- CatalogName '.'\n" - "SchemaQualification <- SchemaName '.'\n" - "ReservedSchemaQualification <- ReservedSchemaName '.'\n" - "TableQualification <- TableName '.'\n" - "ReservedTableQualification <- ReservedTableName '.'\n" - "CreateTableColumnList <- List(CreateTableColumnElement)\n" - "CreateTableColumnElement <- ColumnDefinition / TopLevelConstraint\n" - "ColumnDefinition <- DottedIdentifier Type? GeneratedColumn? ConstraintNameClause? ColumnConstraint*\n" - "ColumnConstraint <- NotNullConstraint / UniqueConstraint / PrimaryKeyConstraint / DefaultValue / CheckConstraint / ForeignKeyConstraint / ColumnCollation / ColumnCompression\n" - "NotNullConstraint <- 'NOT'? 'NULL'\n" - "UniqueConstraint <- 'UNIQUE'\n" - "PrimaryKeyConstraint <- 'PRIMARY' 'KEY'\n" - "DefaultValue <- 'DEFAULT' ColumnDefaultExpr\n" - "CheckConstraint <- 'CHECK' Parens(Expression)\n" - "ForeignKeyConstraint <- 'REFERENCES' BaseTableName Parens(ColumnList)? KeyActions\n" - "ColumnCollation <- 'COLLATE' DottedIdentifier\n" - "ColumnCompression <- 'USING' 'COMPRESSION' ColIdOrString\n" - "KeyActions <- UpdateAction? 
DeleteAction?\n" - "UpdateAction <- 'ON' 'UPDATE' KeyAction\n" - "DeleteAction <- 'ON' 'DELETE' KeyAction\n" - "KeyAction <- NoKeyAction / RestrictKeyAction / CascadeKeyAction / SetNullKeyAction / SetDefaultKeyAction\n" - "NoKeyAction <- 'NO' 'ACTION'\n" - "RestrictKeyAction <- 'RESTRICT'\n" - "CascadeKeyAction <- 'CASCADE'\n" - "SetNullKeyAction <- 'SET' 'NULL'\n" - "SetDefaultKeyAction <- 'SET' 'DEFAULT'\n" - "TopLevelConstraint <- ConstraintNameClause? TopLevelConstraintList\n" - "TopLevelConstraintList <- TopPrimaryKeyConstraint / CheckConstraint / TopUniqueConstraint / TopForeignKeyConstraint\n" - "ConstraintNameClause <- 'CONSTRAINT' Identifier\n" - "TopPrimaryKeyConstraint <- 'PRIMARY' 'KEY' ColumnIdList\n" - "TopUniqueConstraint <- 'UNIQUE' ColumnIdList\n" - "TopForeignKeyConstraint <- 'FOREIGN' 'KEY' ColumnIdList ForeignKeyConstraint\n" - "ColumnIdList <- Parens(List(ColId))\n" - "PlainIdentifier <- !ReservedKeyword <[a-z_]i[a-z0-9_]i*>\n" - "QuotedIdentifier <- '\"' [^\"]* '\"'\n" - "DottedIdentifier <- Identifier ('.' ColLabel)*\n" - "Identifier <- QuotedIdentifier / PlainIdentifier\n" - "ColId <- UnreservedKeyword / ColumnNameKeyword / Identifier\n" - "ColIdOrString <- ColId / StringLiteral\n" - "TypeFuncName <- UnreservedKeyword / TypeNameKeyword / FuncNameKeyword / Identifier\n" - "TypeName <- UnreservedKeyword / TypeNameKeyword / Identifier\n" - "ColLabel <- ReservedKeyword / UnreservedKeyword / ColumnNameKeyword / FuncNameKeyword / TypeNameKeyword / Identifier\n" - "ColLabelOrString <- ColLabel / StringLiteral\n" - "GeneratedColumn <- Generated? 'AS' Parens(Expression) GeneratedColumnType?\n" - "Generated <- 'GENERATED' AlwaysOrByDefault?\n" - "AlwaysOrByDefault <- 'ALWAYS' / ('BY' 'DEFAULT')\n" - "GeneratedColumnType <- 'VIRTUAL' / 'STORED'\n" - "CommitAction <- 'ON' 'COMMIT' PreserveOrDelete 'ROWS'\n" - "PreserveOrDelete <- 'PRESERVE' / 'DELETE'\n" - "CreateIndexStmt <- Unique? 'INDEX' IfNotExists? IndexName? 'ON' BaseTableName InsertColumnList? 
IndexType? Parens(List(IndexElement))? WithList? WhereClause?\n" - "WithList <- 'WITH' RelOptionOrOids\n" - "RelOptionOrOids <- RelOptionList / Oids\n" - "RelOptionList <- Parens(List(RelOption))\n" - "Oids <- ('WITH' / 'WITHOUT') 'OIDS'\n" - "IndexElement <- Expression DescOrAsc? NullsFirstOrLast?\n" - "Unique <- 'UNIQUE'\n" - "IndexType <- 'USING' Identifier\n" - "RelOption <- RelOptionName RelOptionArgumentOpt?\n" - "RelOptionName <- DottedIdentifier / StringLiteral\n" - "RelOptionArgumentOpt <- '=' DefArg\n" - "DefArg <- NullLiteral / ReservedKeyword / StringLiteral / NumberLiteral / NoneLiteral / Expression\n" - "NoneLiteral <- 'NONE'\n" - "LoadStatement <- 'LOAD' ColIdOrString\n" - "InstallStatement <- 'FORCE'? 'INSTALL' IdentifierOrStringLiteral FromSource? VersionNumber?\n" - "UpdateExtensionsStatement <- 'UPDATE' 'EXTENSIONS' Parens(List(Identifier))?\n" - "FromSource <- 'FROM' (Identifier / StringLiteral)\n" - "VersionNumber <- 'VERSION' IdentifierOrStringLiteral\n" - "DropStatement <- 'DROP' DropEntries DropBehavior?\n" - "DropEntries <-\n" - " DropTable /\n" - " DropTableFunction /\n" - " DropFunction /\n" - " DropSchema /\n" - " DropIndex /\n" - " DropSequence /\n" - " DropCollation /\n" - " DropType /\n" - " DropSecret /\n" - " DropTrigger\n" - "DropTrigger <- 'TRIGGER' IfExists? TriggerName 'ON' BaseTableName\n" - "DropTable <- TableOrView IfExists? List(BaseTableName)\n" - "DropTableFunction <- CommentMacroTable IfExists? List(TableFunctionName)\n" - "DropFunction <- FunctionType IfExists? List(FunctionIdentifier)\n" - "DropSchema <- 'SCHEMA' IfExists? List(QualifiedSchemaName)\n" - "DropIndex <- 'INDEX' IfExists? List(QualifiedIndexName)\n" - "QualifiedIndexName <- CatalogQualification? SchemaQualification? IndexName\n" - "DropSequence <- 'SEQUENCE' IfExists? List(QualifiedSequenceName)\n" - "DropCollation <- 'COLLATION' IfExists? List(CollationName)\n" - "DropType <- 'TYPE' IfExists? List(QualifiedTypeName)\n" - "DropSecret <- Temporary? 
'SECRET' IfExists? SecretName DropSecretStorage?\n" - "TableOrView <- CommentTable / CommentView / MaterializedViewEntry\n" - "MaterializedViewEntry <- 'MATERIALIZED' 'VIEW'\n" - "FunctionType <- 'MACRO' / 'FUNCTION'\n" - "DropBehavior <- 'CASCADE' / 'RESTRICT'\n" - "IfExists <- 'IF' 'EXISTS'\n" - "QualifiedSchemaName <- CatalogQualification? SchemaName\n" - "DropSecretStorage <- 'FROM' Identifier\n" - "UpdateStatement <- WithClause? 'UPDATE' UpdateTarget UpdateSetClause FromClause? WhereClause? ReturningClause?\n" - "UpdateTarget <- BaseTableSet / BaseTableAliasSet\n" - "BaseTableSet <- BaseTableName 'SET'\n" - "BaseTableAliasSet <- BaseTableName UpdateAlias? 'SET'\n" - "UpdateAlias <- 'AS'? ColId\n" - "UpdateSetClause <- UpdateSetElementList / UpdateSetTuple\n" - "UpdateSetTuple <- Parens(List(ColumnName)) '=' Expression\n" - "UpdateSetElementList <- List(UpdateSetElement)\n" - "UpdateSetElement <- UpdateSetColumnTarget '=' Expression\n" - "UpdateSetColumnTarget <- ColumnName ('.' Identifier)*\n" - "InsertStatement <- WithClause? 'INSERT' OrAction? 'INTO' InsertTarget ByNameOrPosition? InsertColumnList? InsertValues OnConflictClause? ReturningClause?\n" - "OrAction <- 'OR' 'REPLACE' / 'IGNORE'\n" - "ByNameOrPosition <- 'BY' InsertByName / InsertByPosition\n" - "InsertByName <- 'NAME'\n" - "InsertByPosition <- 'POSITION'\n" - "InsertTarget <- BaseTableName InsertAlias?\n" - "InsertAlias <- 'AS' Identifier\n" - "ColumnList <- List(ColId)\n" - "InsertColumnList <- Parens(ColumnList)\n" - "InsertValues <- SelectStatementInternal / DefaultValues\n" - "DefaultValues <- 'DEFAULT' 'VALUES'\n" - "OnConflictClause <- 'ON' 'CONFLICT' OnConflictTarget? 
OnConflictAction\n" - "OnConflictTarget <- OnConflictExpressionTarget / OnConflictIndexTarget\n" - "OnConflictExpressionTarget <- ColumnIdList WhereClause?\n" - "OnConflictIndexTarget <- 'ON' 'CONSTRAINT' ConstraintName\n" - "OnConflictAction <- OnConflictUpdate / OnConflictNothing\n" - "OnConflictUpdate <- 'DO' 'UPDATE' 'SET' UpdateSetClause WhereClause?\n" - "OnConflictNothing <- 'DO' 'NOTHING'\n" - "ReturningClause <- 'RETURNING' TargetList\n" - "CreateSchemaStmt <- 'SCHEMA' IfNotExists? QualifiedName\n" - "SelectStatement <- SelectStatementInternal\n" - "SelectStatementInternal <- WithClause? SelectSetOpChain ResultModifiers?\n" - "SelectSetOpChain <- IntersectChain (SetopClause IntersectChain)*\n" - "IntersectChain <- SelectAtom (SetIntersectClause SelectAtom)*\n" - "SetIntersectClause <- 'INTERSECT' DistinctOrAll?\n" - "SelectAtom <- SelectParens / SelectStatementType\n" - "SelectParens <- Parens(SelectStatementInternal)\n" - "SetopClause <- SetopType DistinctOrAll? ByName?\n" - "SetopType <- SetopUnion / SetopExcept\n" - "SetopUnion <- 'UNION'\n" - "SetopExcept <- 'EXCEPT'\n" - "ByName <- 'BY' 'NAME'\n" - "SelectStatementType <- OptionalParensSimpleSelect / ValuesClause / DescribeStatement / TableStatement / PivotStatement / UnpivotStatement\n" - "ResultModifiers <- OrderByClause? LimitOffset?\n" - "LimitOffset <- LimitOffsetClause / OffsetLimitClause\n" - "LimitOffsetClause <- LimitClause OffsetClause?\n" - "OffsetLimitClause <- OffsetClause LimitClause?\n" - "TableStatement <- 'TABLE' BaseTableName\n" - "OptionalParensSimpleSelect <- SimpleSelectParens / SimpleSelect\n" - "SimpleSelectParens <- Parens(SimpleSelect)\n" - "SimpleSelect <- SelectFrom WhereClause? GroupByClause? HavingClause? WindowClause? QualifyClause? SampleClause?\n" - "SelectFrom <- SelectFromClause / FromSelectClause\n" - "SelectFromClause <- SelectClause FromClause?\n" - "FromSelectClause <- FromClause SelectClause?\n" - "WithStatement <- ColIdOrString InsertColumnList? UsingKey? 
'AS' Materialized? CTEBody\n" - "CTEBody <- Parens(CTEBodyContent)\n" - "CTEBodyContent <- SelectStatementInternal / Statement\n" - "UsingKey <- 'USING' 'KEY' Parens(TargetList)\n" - "Materialized <- 'NOT'? 'MATERIALIZED'\n" - "WithClause <- 'WITH' Recursive? List(WithStatement)\n" - "Recursive <- 'RECURSIVE'\n" - "SelectClause <- 'SELECT' DistinctClause? TargetList?\n" - "TargetList <- List(AliasedExpression)\n" - "ColumnAliases <- Parens(List(ColIdOrString))\n" - "DistinctClause <- DistinctOn / DistinctAll\n" - "DistinctAll <- 'ALL'\n" - "DistinctOn <- 'DISTINCT' DistinctOnTargets?\n" - "DistinctOnTargets <- 'ON' Parens(List(Expression))\n" - "InnerTableRef <- ValuesRef / TableFunction / TableSubquery / BaseTableRef / ParensTableRef\n" - "TableRef <- InnerTableRef JoinOrPivot*\n" - "TableSubquery <- Lateral? SubqueryReference TableAlias?\n" - "BaseTableRef <- TableAliasColon? BaseTableName TableAlias? AtClause? SampleClause?\n" - "TableAliasColon <- ColIdOrString ':'\n" - "ValuesRef <- ValuesClause TableAlias?\n" - "ParensTableRef <- TableAliasColon? Parens(TableRef) TableAlias? SampleClause?\n" - "JoinOrPivot <- JoinClause / TablePivotClause / TableUnpivotClause\n" - "TablePivotClause <- 'PIVOT' Parens(TargetList 'FOR' PivotValueList+ PivotGroupByList?) TableAlias?\n" - "PivotGroupByList <- 'GROUP' 'BY' List(ColIdOrString)\n" - "TableUnpivotClause <- 'UNPIVOT' IncludeOrExcludeNulls? 
Parens(UnpivotHeader 'FOR' UnpivotValueList+) TableAlias?\n" - "PivotHeader <- BaseExpression\n" - "PivotValueList <- PivotHeader 'IN' (Identifier / PivotTargetList)\n" - "UnpivotValueList <- UnpivotHeader 'IN' UnpivotTargetList\n" - "PivotTargetList <- Parens(TargetList)\n" - "UnpivotTargetList <- Parens(TargetList)\n" - "Lateral <- 'LATERAL'\n" - "BaseTableName <- CatalogReservedSchemaTable / SchemaReservedTable / TableName\n" - "SchemaReservedTable <- SchemaQualification ReservedTableName\n" - "CatalogReservedSchemaTable <- CatalogQualification ReservedSchemaQualification ReservedTableName\n" - "TableFunction <- TableFunctionLateralOpt / TableFunctionAliasColon\n" - "TableFunctionLateralOpt <- Lateral? QualifiedTableFunction TableFunctionArguments WithOrdinality? TableAlias?\n" - "TableFunctionAliasColon <- TableAliasColon QualifiedTableFunction TableFunctionArguments WithOrdinality? SampleClause?\n" - "WithOrdinality <- 'WITH' 'ORDINALITY'\n" - "QualifiedTableFunction <- CatalogQualification? SchemaQualification? TableFunctionName\n" - "TableFunctionArguments <- Parens(List(FunctionArgument)?)\n" - "FunctionArgument <- NamedParameter / Expression\n" - "NamedParameter <- TypeFuncName Type? NamedParameterAssignment Expression\n" - "NamedParameterAssignment <- ':=' / '=>'\n" - "TableAlias <- TableAliasAs / TableAliasWithoutAs\n" - "TableAliasAs <- 'AS' IdentifierOrStringLiteral ColumnAliases?\n" - "TableAliasWithoutAs <- Identifier ColumnAliases?\n" - "AtClause <- 'AT' Parens(AtSpecifier)\n" - "AtSpecifier <- AtUnit '=>' Expression\n" - "AtUnit <- 'VERSION' / 'TIMESTAMP'\n" - "JoinClause <- RegularJoinClause / JoinWithoutOnClause\n" - "RegularJoinClause <- 'ASOF'? JoinType? 
'JOIN' TableRef JoinQualifier\n" - "JoinWithoutOnClause <- JoinPrefix 'JOIN' TableRef\n" - "JoinQualifier <- OnClause / UsingClause\n" - "OnClause <- 'ON' Expression\n" - "UsingClause <- 'USING' Parens(List(ColumnName))\n" - "JoinType <- FullJoin / LeftJoin / RightJoin / SemiJoin / AntiJoin / InnerJoin\n" - "JoinPrefix <- CrossJoinPrefix / NaturalJoinPrefix / PositionalJoinPrefix\n" - "CrossJoinPrefix <- 'CROSS'\n" - "NaturalJoinPrefix <- 'NATURAL' JoinType?\n" - "PositionalJoinPrefix <- 'POSITIONAL'\n" - "FullJoin <- 'FULL' 'OUTER'?\n" - "LeftJoin <- 'LEFT' 'OUTER'?\n" - "RightJoin <- 'RIGHT' 'OUTER'?\n" - "SemiJoin <- 'SEMI'\n" - "AntiJoin <- 'ANTI'\n" - "InnerJoin <- 'INNER'\n" - "FromClause <- 'FROM' List(TableRef)\n" - "WhereClause <- 'WHERE' Expression\n" - "GroupByClause <- 'GROUP' 'BY' GroupByExpressions\n" - "HavingClause <- 'HAVING' Expression\n" - "QualifyClause <- 'QUALIFY' Expression\n" - "SampleClause <- (TableSample / UsingSample) SampleEntry\n" - "UsingSample <- 'USING' 'SAMPLE'\n" - "TableSample <- 'TABLESAMPLE'\n" - "WindowClause <- 'WINDOW' List(WindowDefinition)\n" - "WindowDefinition <- Identifier 'AS' WindowFrameDefinition\n" - "SampleEntry <- SampleEntryFunction / SampleEntryCount\n" - "SampleEntryCount <- SampleCount Parens(SampleProperties)?\n" - "SampleEntryFunction <- SampleFunction? 
Parens(SampleCount) RepeatableSample?\n" - "SampleFunction <- ColId\n" - "SampleProperties <- ColId (',' SampleSeed)?\n" - "RepeatableSample <- 'REPEATABLE' Parens(SampleSeed)\n" - "SampleSeed <- NumberLiteral\n" - "SampleCount <- SampleValue SampleUnit?\n" - "SampleValue <- NumberLiteral / Parameter\n" - "SampleUnit <- SamplePercentage / SampleRows\n" - "SamplePercentage <- '%' / 'PERCENT'\n" - "SampleRows <- 'ROWS'\n" - "GroupByExpressions <- GroupByList / GroupByAll\n" - "GroupByAll <- 'ALL'\n" - "GroupByList <- List(GroupByExpression)\n" - "GroupByExpression <- EmptyGroupingItem / CubeOrRollupClause / GroupingSetsClause / Expression\n" - "EmptyGroupingItem <- '(' ')'\n" - "CubeOrRollupClause <- CubeOrRollup Parens(List(Expression)?)\n" - "CubeOrRollup <- 'CUBE' / 'ROLLUP'\n" - "GroupingSetsClause <- 'GROUPING' 'SETS' Parens(GroupByList)\n" - "SubqueryReference <- Parens(SelectStatementInternal)\n" - "OrderByExpression <- Expression DescOrAsc? NullsFirstOrLast?\n" - "DescOrAsc <- DescendingOrder / AscendingOrder\n" - "DescendingOrder <- 'DESC' / 'DESCENDING'\n" - "AscendingOrder <- 'ASC' / 'ASCENDING'\n" - "NullsFirstOrLast <- NullsFirst / NullsLast\n" - "NullsFirst <- 'NULLS' 'FIRST'\n" - "NullsLast <- 'NULLS' 'LAST'\n" - "OrderByClause <- 'ORDER' 'BY' OrderByExpressions\n" - "OrderByExpressions <- OrderByAll / OrderByExpressionList\n" - "OrderByExpressionList <- List(OrderByExpression)\n" - "OrderByAll <- 'ALL' DescOrAsc? 
NullsFirstOrLast?\n" - "LimitClause <- 'LIMIT' LimitValue\n" - "OffsetClause <- 'OFFSET' OffsetValue\n" - "OffsetValue <- Expression RowOrRows?\n" - "RowOrRows <- 'ROW' / 'ROWS'\n" - "LimitValue <- LimitAll / LimitLiteralPercent / LimitExpression\n" - "LimitAll <- 'ALL'\n" - "LimitLiteralPercent <- NumberLiteral 'PERCENT'\n" - "LimitExpression <- Expression '%'?\n" - "AliasedExpression <- ColIdExpression / ExpressionAsCollabel / ExpressionOptIdentifier\n" - "ColIdExpression <- ColId ':' Expression\n" - "ExpressionAsCollabel <- Expression 'AS' ColLabelOrString\n" - "ExpressionOptIdentifier <- Expression Identifier?\n" - "ValuesClause <- 'VALUES' List(ValuesExpressions)\n" - "ValuesExpressions <- Parens(List(Expression))\n" - "TransactionStatement <- BeginTransaction / RollbackTransaction / CommitTransaction\n" - "BeginTransaction <- StartOrBegin Transaction? ReadOrWrite?\n" - "RollbackTransaction <- AbortOrRollback Transaction?\n" - "CommitTransaction <- CommitOrEnd Transaction?\n" - "StartOrBegin <- 'START' / 'BEGIN'\n" - "Transaction <- 'WORK' / 'TRANSACTION'\n" - "ReadOrWrite <- 'READ' ReadOnlyOrReadWrite\n" - "ReadOnlyOrReadWrite <- ReadOnly / ReadWrite\n" - "ReadOnly <- 'ONLY'\n" - "ReadWrite <- 'WRITE'\n" - "AbortOrRollback <- 'ABORT' / 'ROLLBACK'\n" - "CommitOrEnd <- 'COMMIT' / 'END'\n" - "DeleteStatement <- WithClause? 'DELETE' 'FROM' TargetOptAlias DeleteUsingClause? WhereClause? ReturningClause?\n" - "TruncateStatement <- 'TRUNCATE' 'TABLE'? BaseTableName\n" - "TargetOptAlias <- BaseTableName 'AS'? ColId?\n" - "DeleteUsingClause <- 'USING' List(TableRef)\n" - "CreateTypeStmt <- 'TYPE' IfNotExists? 
QualifiedName 'AS' CreateType\n" - "CreateType <- EnumSelectType / EnumStringLiteralList / Type\n" - "EnumSelectType <- 'ENUM' Parens(SelectStatementInternal)\n" - "EnumStringLiteralList <- 'ENUM' Parens(List(StringLiteral)?)\n" - "SetStatement <- 'SET' (StandardAssignment / SetTimeZone)\n" - "StandardAssignment <- (SetVariable / SetSetting) SetAssignment\n" - "SetTimeZone <- 'TIME' 'ZONE' ZoneValue\n" - "ZoneValue <- ZoneIntervalWithPrecision / ZoneIntervalWithInterval / 'LOCAL' / 'DEFAULT' / StringLiteral / Identifier / NumberLiteral\n" - "ZoneIntervalWithInterval <- 'INTERVAL' StringLiteral Interval?\n" - "ZoneIntervalWithPrecision <- 'INTERVAL' Parens(NumberLiteral) StringLiteral\n" - "SetSetting <- SettingScope? SettingName\n" - "SetVariable <- VariableScope Identifier\n" - "VariableScope <- 'VARIABLE'\n" - "SettingScope <- LocalScope / SessionScope / GlobalScope\n" - "LocalScope <- 'LOCAL'\n" - "SessionScope <- 'SESSION'\n" - "GlobalScope <- 'GLOBAL'\n" - "SetAssignment <- VariableAssign VariableList\n" - "VariableAssign <- '=' / 'TO'\n" - "VariableList <- List(Expression)\n" - "ResetStatement <- 'RESET' (SetVariable / SetSetting)\n" - "ExportStatement <- 'EXPORT' 'DATABASE' ExportSource? StringLiteral GenericCopyOptionList?\n" - "ExportSource <- CatalogName 'TO'\n" - "ImportStatement <- 'IMPORT' 'DATABASE' StringLiteral\n" - "CheckpointStatement <- 'FORCE'? 'CHECKPOINT' CatalogName?\n" - "CopyStatement <- 'COPY' (CopyTable / CopySelect / CopyFromDatabase)\n" - "CopyTable <- BaseTableName InsertColumnList? FromOrTo CopyFileName CopyOptions?\n" - "FromOrTo <- 'FROM' / 'TO'\n" - "CopySelect <- Parens(SelectStatementInternal) 'TO' CopyFileName CopyOptions?\n" - "CopyFileName <- Parameter / ParensExpression / StringLiteral / Identifier / (Identifier '.' ColId)\n" - "CopyOptions <- 'WITH'? 
GenericCopyOptionList / SpecializedOptionList\n" - "SpecializedOptionList <- SpecializedOption*\n" - "SpecializedOption <- SingleOption / NullAsOption / DelimiterAsOption / EscapeAsOption /\n" - " EncodingOption / QuoteAsOption / ForceQuoteOption / PartitionByOption / ForceNullOption\n" - "SingleOption <- BinaryOption / FreezeOption / OidsOption / CsvOption / HeaderOption\n" - "BinaryOption <- 'BINARY'\n" - "FreezeOption <- 'FREEZE'\n" - "OidsOption <- 'OIDS'\n" - "CsvOption <- 'CSV'\n" - "HeaderOption <- 'HEADER'\n" - "NullAsOption <- 'NULL' 'AS'? StringLiteral\n" - "DelimiterAsOption <- 'DELIMITER' 'AS'? StringLiteral\n" - "QuoteAsOption <- 'QUOTE' 'AS'? StringLiteral\n" - "EscapeAsOption <- 'ESCAPE' 'AS'? StringLiteral\n" - "EncodingOption <- 'ENCODING' StringLiteral\n" - "ForceQuoteOption <- 'FORCE'? 'QUOTE' (StarSymbol / ColumnList)\n" - "PartitionByOption <- 'PARTITION' 'BY' (StarSymbol / ColumnList)\n" - "ForceNullOption <- 'FORCE' 'NOT'? 'NULL' ColumnList\n" - "StarSymbol <- '*'\n" - "GenericCopyOptionList <- Parens(List(GenericCopyOption))\n" - "GenericCopyOption <- CopyOptionName Expression?\n" - "CopyFromDatabase <- 'FROM' 'DATABASE' ColId 'TO' ColId CopyDatabaseFlag?\n" - "CopyDatabaseFlag <- Parens(SchemaOrData)\n" - "SchemaOrData <- CopySchema / CopyData\n" - "CopySchema <- 'SCHEMA'\n" - "CopyData <- 'DATA'\n" - "AlterStatement <- 'ALTER' AlterOptions\n" - "AlterOptions <- AlterTableStmt / AlterViewStmt / AlterSequenceStmt / AlterDatabaseStmt / AlterSchemaStmt\n" - "AlterTableStmt <- 'TABLE' IfExists? BaseTableName List(AlterTableOptions)\n" - "AlterSchemaStmt <- 'SCHEMA' IfExists? QualifiedName RenameAlter\n" - "AlterTableOptions <- AddColumn / DropColumn / AlterColumn / AddConstraint / ChangeNullability /\n" - " RenameColumn / RenameAlter / SetPartitionedBy / ResetPartitionedBy / SetSortedBy / ResetSortedBy / SetOptions / ResetOptions\n" - "AddConstraint <- 'ADD' TopLevelConstraint\n" - "AddColumn <- 'ADD' 'COLUMN'? IfNotExists? 
AddColumnEntry\n" - "AddColumnEntry <- DottedIdentifier Type? GeneratedColumn? ColumnConstraint*\n" - "DropColumn <- 'DROP' 'COLUMN'? IfExists? NestedColumnName DropBehavior?\n" - "AlterColumn <- 'ALTER' 'COLUMN'? NestedColumnName AlterColumnEntry\n" - "RenameColumn <- 'RENAME' 'COLUMN'? NestedColumnName 'TO' Identifier\n" - "NestedColumnName <- (Identifier '.')* ColumnName\n" - "RenameAlter <- 'RENAME' 'TO' Identifier\n" - "SetPartitionedBy <- 'SET' 'PARTITIONED' 'BY' Parens(List(Expression))\n" - "ResetPartitionedBy <- 'RESET' 'PARTITIONED' 'BY'\n" - "SetSortedBy <- 'SET' 'SORTED' 'BY' Parens(OrderByExpressions)\n" - "ResetSortedBy <- 'RESET' 'SORTED' 'BY'\n" - "SetOptions <- 'SET' RelOptionList\n" - "ResetOptions <- 'RESET' RelOptionList\n" - "AlterColumnEntry <- AddOrDropDefault / ChangeNullability / AlterType\n" - "AddOrDropDefault <- AddDefault / DropDefault\n" - "AddDefault <- 'SET' 'DEFAULT' Expression\n" - "DropDefault <- 'DROP' 'DEFAULT'\n" - "ChangeNullability <- DropOrSet 'NOT' 'NULL'\n" - "DropOrSet <- 'DROP' / 'SET'\n" - "AlterType <- SetData? 'TYPE' Type? UsingExpression?\n" - "SetData <- 'SET' 'DATA'?\n" - "UsingExpression <- 'USING' Expression\n" - "AlterViewStmt <- 'VIEW' IfExists? BaseTableName RenameAlter\n" - "AlterSequenceStmt <- 'SEQUENCE' IfExists? QualifiedSequenceName AlterSequenceOptions\n" - "QualifiedSequenceName <- CatalogQualification? SchemaQualification? SequenceName\n" - "AlterSequenceOptions <- RenameAlter / SetSequenceOption\n" - "SetSequenceOption <- SequenceOption+\n" - "AlterDatabaseStmt <- 'DATABASE' IfExists? Identifier 'SET' 'ALIAS' 'TO' Identifier\n" - "CreateSequenceStmt <- 'SEQUENCE' IfNotExists? QualifiedName SequenceOption*\n" - "SequenceOption <-\n" - " SeqSetCycle /\n" - " SeqSetIncrement /\n" - " SeqSetMinMax /\n" - " SeqNoMinMax /\n" - " SeqStartWith /\n" - " SeqOwnedBy\n" - "SeqSetCycle <- 'NO'? 'CYCLE'\n" - "SeqSetIncrement <- 'INCREMENT' 'BY'? 
Expression\n" - "SeqSetMinMax <- SeqMinOrMax Expression\n" - "SeqNoMinMax <- 'NO' SeqMinOrMax\n" - "SeqStartWith <- 'START' 'WITH'? Expression\n" - "SeqOwnedBy <- 'OWNED' 'BY' QualifiedName\n" - "SeqMinOrMax <- MinValue / MaxValue\n" - "MinValue <- 'MINVALUE'\n" - "MaxValue <- 'MAXVALUE'\n" - "ExplainStatement <- 'EXPLAIN' 'ANALYZE'? ExplainOptionList? ExplainableStatements\n" - "ExplainOptionList <- Parens(List(ExplainOption))\n" - "ExplainOption <- ExplainOptionName Expression?\n" - "ExplainOptionName <- 'ANALYZE' / 'ANALYSE' / ColId / FuncNameKeyword / TypeNameKeyword\n" - "ExplainableStatements <-\n" - " AlterStatement /\n" - " AnalyzeStatement /\n" - " CallStatement /\n" - " CheckpointStatement /\n" - " CopyStatement /\n" - " CreateStatement /\n" - " DeallocateStatement /\n" - " DeleteStatement /\n" - " DropStatement /\n" - " ExecuteStatement /\n" - " InsertStatement /\n" - " LoadStatement /\n" - " MergeIntoStatement /\n" - " PragmaStatement /\n" - " PrepareStatement /\n" - " SelectStatementInternal /\n" - " TransactionStatement /\n" - " UpdateStatement /\n" - " VacuumStatement /\n" - " SetStatement /\n" - " ResetStatement\n" - "AnalyzeStatement <- 'ANALYZE' 'VERBOSE'? AnalyzeTarget?\n" - "AnalyzeTarget <- BaseTableName NameList?\n" - "CreateMacroStmt <- MacroOrFunction IfNotExists? QualifiedName List(MacroDefinition)\n" - "MacroOrFunction <- 'MACRO' / 'FUNCTION'\n" - "MacroDefinition <- Parens(MacroParameters?) 
'AS' (TableMacroDefinition / ScalarMacroDefinition)\n" - "MacroParameters <- List(MacroParameter)\n" - "MacroParameter <- NamedParameter / SimpleParameter\n" - "SimpleParameter <- TypeFuncName Type?\n" - "ScalarMacroDefinition <- Expression\n" - "TableMacroDefinition <- 'TABLE' SelectStatementInternal\n" - "CommentStatement <- 'COMMENT' 'ON' CommentOnType DottedIdentifier 'IS' CommentValue\n" - "CommentOnType <- CommentTable / CommentSequence / CommentFunction / CommentMacroTable / CommentMacro /\n" - " CommentView / CommentDatabase / CommentIndex / CommentSchema / CommentType / CommentColumn\n" - "CommentTable <- 'TABLE'\n" - "CommentSequence <- 'SEQUENCE'\n" - "CommentFunction <- 'FUNCTION'\n" - "CommentMacroTable <- 'MACRO' 'TABLE'\n" - "CommentMacro <- 'MACRO'\n" - "CommentView <- 'VIEW'\n" - "CommentDatabase <- 'DATABASE'\n" - "CommentIndex <- 'INDEX'\n" - "CommentSchema <- 'SCHEMA'\n" - "CommentType <- 'TYPE'\n" - "CommentColumn <- 'COLUMN'\n" - "CommentValue <- 'NULL' / StringLiteral\n" - "CreateTriggerStmt <- 'TRIGGER' IfNotExists? TriggerName TriggerTiming TriggerEvent 'ON' BaseTableName ForEachClause? TriggerBody\n" - "TriggerBody <- InsertStatement / UpdateStatement / DeleteStatement\n" - "TriggerTiming <- TriggerBefore / TriggerAfter / TriggerInsteadOf\n" - "TriggerBefore <- 'BEFORE'\n" - "TriggerAfter <- 'AFTER'\n" - "TriggerInsteadOf <- 'INSTEAD' 'OF'\n" - "TriggerEvent <- TriggerEventUpdateOf / TriggerEventInsert / TriggerEventDelete / TriggerEventUpdate\n" - "TriggerEventInsert <- 'INSERT'\n" - "TriggerEventDelete <- 'DELETE'\n" - "TriggerEventUpdate <- 'UPDATE'\n" - "TriggerEventUpdateOf <- 'UPDATE' 'OF' TriggerColumnList\n" - "TriggerColumnList <- List(ColId)\n" - "ForEachClause <- ForEachRow / ForEachStatement\n" - "ForEachRow <- 'FOR' 'EACH' 'ROW'\n" - "ForEachStatement <- 'FOR' 'EACH' 'STATEMENT'\n" - "AttachStatement <- 'ATTACH' OrReplace? IfNotExists? Database? DatabasePath AttachAlias? 
AttachOptions?\n" - "Database <- 'DATABASE'\n" - "DatabasePath <- Expression\n" - "AttachAlias <- 'AS' ColId\n" - "AttachOptions <- GenericCopyOptionList\n" - "DetachStatement <- 'DETACH' Database? IfExists? CatalogName\n" - "UseStatement <- 'USE' UseTarget\n" - "UseTarget <- UseTargetCatalogSchema / SchemaName / CatalogName\n" - "UseTargetCatalogSchema <- CatalogName '.' ReservedSchemaName DotIdentifier*\n" - "DotIdentifier <- '.' Identifier\n" - "CallStatement <- 'CALL' QualifiedTableFunction TableFunctionArguments\n" + "Program <- TopLevelStatement*\n" + "TopLevelStatement <- Statement? (';'+ / EndOfInput)\n" + "Statement <-\n" + " CreateStatement /\n" + " SelectStatement /\n" + " SetStatement /\n" + " PragmaStatement /\n" + " CallStatement /\n" + " InsertStatement /\n" + " DropStatement /\n" + " CopyStatement /\n" + " ExplainStatement /\n" + " UpdateStatement /\n" + " PrepareStatement /\n" + " ExecuteStatement /\n" + " AlterStatement /\n" + " TransactionStatement /\n" + " DeleteStatement /\n" + " AttachStatement /\n" + " UseStatement /\n" + " DetachStatement /\n" + " CheckpointStatement /\n" + " VacuumStatement /\n" + " ResetStatement /\n" + " ExportStatement /\n" + " ImportStatement /\n" + " CommentStatement /\n" + " DeallocateStatement /\n" + " TruncateStatement /\n" + " LoadStatement /\n" + " InstallStatement /\n" + " UpdateExtensionsStatement /\n" + " AnalyzeStatement /\n" + " MergeIntoStatement /\n" + " ConnectStatement /\n" + " DisconnectStatement /\n" + " ExpressionStatement\n" + "ExpressionStatement <- List(ExpressionAlias)\n" + "ExpressionAlias <- ColIdExpression / ExpressionAsCollabel / Expression\n" + "CatalogName <- Identifier\n" + "SchemaName <- Identifier\n" + "ReservedSchemaName <- Identifier\n" + "TableName <- Identifier\n" + "ReservedTableName <- Identifier\n" + "ReservedIdentifier <- Identifier\n" + "ColumnName <- Identifier\n" + "ReservedColumnName <- Identifier\n" + "IndexName <- Identifier\n" + "ReservedIndexName <- Identifier\n" + 
"SettingName <- Identifier\n" + "PragmaName <- Identifier\n" + "FunctionName <- Identifier\n" + "ReservedFunctionName <- Identifier\n" + "ReservedTypeName <- Identifier\n" + "TableFunctionName <- Identifier\n" + "ConstraintName <- ColIdOrString\n" + "SequenceName <- Identifier\n" + "CollationName <- Identifier\n" + "CopyOptionName <- ColLabel\n" + "NumberLiteral <- < [+-]?[0-9]*([.][0-9]*)? >\n" + "StringLiteral <- '\\'' [^\\']* '\\''\n" + "Type <- TypeVariations ArrayBounds*\n" + "TypeVariations <- TimeType / IntervalType / BitType / RowType / VariantType / MapType / GeometryType / UnionType " + "/ NumericType / SetofType / SimpleType\n" + "SimpleType <- CharacterSimpleType / QualifiedSimpleType\n" + "CharacterSimpleType <- CharacterType TypeModifiers?\n" + "QualifiedSimpleType <- QualifiedTypeName TypeModifiers?\n" + "CharacterType <- ('CHARACTER' 'VARYING'?) /\n" + " ('CHAR' 'VARYING'?) /\n" + " ('NATIONAL' 'CHARACTER' 'VARYING'?) /\n" + " ('NATIONAL' 'CHAR' 'VARYING'?) /\n" + " ('NCHAR' 'VARYING'?) 
/\n" + " 'VARCHAR'\n" + "IntervalType <- IntervalInterval / IntervalNumber\n" + "IntervalInterval <- IntervalWithSpecifier / IntervalWithoutSpecifier\n" + "IntervalWithSpecifier <- IntervalWithRangeSpecifier / IntervalWithSimpleSpecifier\n" + "IntervalWithRangeSpecifier <- 'INTERVAL' IntervalToIntervalAsType\n" + "IntervalWithSimpleSpecifier <- 'INTERVAL' Interval\n" + "IntervalWithoutSpecifier <- 'INTERVAL'\n" + "IntervalToIntervalAsType <- (YearKeyword 'TO' MonthKeyword) /\n" + " (DayKeyword 'TO' HourKeyword) /\n" + " (DayKeyword 'TO' MinuteKeyword) /\n" + " (DayKeyword 'TO' SecondKeyword) /\n" + " (HourKeyword 'TO' MinuteKeyword) /\n" + " (HourKeyword 'TO' SecondKeyword) /\n" + " (MinuteKeyword 'TO' SecondKeyword)\n" + "IntervalNumber <- 'INTERVAL' Parens(NumberLiteral)\n" + "YearKeyword <- 'YEAR' / 'YEARS'\n" + "MonthKeyword <- 'MONTH' / 'MONTHS'\n" + "DayKeyword <- 'DAY' / 'DAYS'\n" + "HourKeyword <- 'HOUR' / 'HOURS'\n" + "MinuteKeyword <- 'MINUTE' / 'MINUTES'\n" + "SecondKeyword <- 'SECOND' / 'SECONDS'\n" + "MillisecondKeyword <- 'MILLISECOND' / 'MILLISECONDS'\n" + "MicrosecondKeyword <- 'MICROSECOND' / 'MICROSECONDS'\n" + "WeekKeyword <- 'WEEK' / 'WEEKS'\n" + "QuarterKeyword <- 'QUARTER' / 'QUARTERS'\n" + "DecadeKeyword <- 'DECADE' / 'DECADES'\n" + "CenturyKeyword <- 'CENTURY' / 'CENTURIES'\n" + "MillenniumKeyword <- 'MILLENNIUM' / 'MILLENNIA'\n" + "Interval <- IntervalToInterval /\n" + " YearKeyword /\n" + " MonthKeyword /\n" + " DayKeyword /\n" + " HourKeyword /\n" + " MinuteKeyword /\n" + " SecondKeyword /\n" + " MillisecondKeyword /\n" + " MicrosecondKeyword /\n" + " WeekKeyword /\n" + " QuarterKeyword /\n" + " DecadeKeyword /\n" + " CenturyKeyword /\n" + " MillenniumKeyword\n" + "IntervalToInterval <- YearToMonth /\n" + " DayToHour /\n" + " DayToMinute /\n" + " DayToSecond /\n" + " HourToMinute /\n" + " HourToSecond /\n" + " MinuteToSecond\n" + "YearToMonth <- YearKeyword 'TO' MonthKeyword\n" + "DayToHour <- DayKeyword 'TO' HourKeyword\n" + "DayToMinute 
<- DayKeyword 'TO' MinuteKeyword\n" + "DayToSecond <- DayKeyword 'TO' SecondKeyword\n" + "HourToMinute <- HourKeyword 'TO' MinuteKeyword\n" + "HourToSecond <- HourKeyword 'TO' SecondKeyword\n" + "MinuteToSecond <- MinuteKeyword 'TO' SecondKeyword\n" + "BitType <- 'BIT' 'VARYING'? Parens(List(Expression))?\n" + "GeometryType <- 'GEOMETRY' Parens(Expression)?\n" + "VariantType <- 'VARIANT'\n" + "NumericType <- SimpleNumericType / DecimalNumericType\n" + "SimpleNumericType <- IntType / IntegerType / SmallintType /\n" + " BigintType / RealType / BooleanType / DoubleType\n" + "DecimalNumericType <- FloatType / DecimalType / DecType / NumericModType\n" + "IntType <- 'INT'\n" + "IntegerType <- 'INTEGER'\n" + "SmallintType <- 'SMALLINT'\n" + "BigintType <- 'BIGINT'\n" + "RealType <- 'REAL'\n" + "BooleanType <- 'BOOLEAN'\n" + "DoubleType <- ('DOUBLE' 'PRECISION')\n" + "FloatType <- 'FLOAT' Parens(NumberLiteral)?\n" + "DecimalType <- 'DECIMAL' TypeModifiers?\n" + "DecType <- 'DEC' TypeModifiers?\n" + "NumericModType <- 'NUMERIC' TypeModifiers?\n" + "QualifiedTypeName <- CatalogReservedSchemaTypeName / SchemaReservedTypeName / TypeNameAsQualifiedName\n" + "TypeNameAsQualifiedName <- TypeName\n" + "CatalogReservedSchemaTypeName <- CatalogQualification ReservedSchemaQualification ReservedTypeName\n" + "SchemaReservedTypeName <- SchemaQualification ReservedTypeName\n" + "TypeModifiers <- Parens(List(Expression)?)\n" + "RowType <- RowOrStruct ColIdTypeList\n" + "SetofType <- 'SETOF' Type\n" + "UnionType <- 'UNION' ColIdTypeList\n" + "ColIdTypeList <- Parens(List(ColIdType))\n" + "MapType <- 'MAP' Parens(List(Type))\n" + "ColIdType <- ColId Type\n" + "ArrayBounds <- SquareBracketsArray / ArrayKeyword\n" + "ArrayKeyword <- 'ARRAY'\n" + "SquareBracketsArray <- '[' Expression? ']'\n" + "TimeType <- TimeOrTimestamp TypeModifiers? 
TimeZone?\n" + "TimeOrTimestamp <- TimeTypeId / TimestampTypeId\n" + "TimeTypeId <- 'TIME'\n" + "TimestampTypeId <- 'TIMESTAMP'\n" + "TimeZone <- WithOrWithout 'TIME' 'ZONE'\n" + "WithOrWithout <- WithRule / WithoutRule\n" + "WithRule <- 'WITH'\n" + "WithoutRule <- 'WITHOUT'\n" + "RowOrStruct <- 'ROW' / 'STRUCT'\n" + "# internal definitions\n" + "%whitespace <- [ \\t\\n\\r]*\n" + "List(D) <- D (',' D)* ','?\n" + "Parens(D) <- '(' D ')'\n" + "UnreservedKeyword <- 'ABORT' /\n" + "'ABSOLUTE' /\n" + "'ACCESS' /\n" + "'ACTION' /\n" + "'ADD' /\n" + "'ADMIN' /\n" + "'AFTER' /\n" + "'AGGREGATE' /\n" + "'ALSO' /\n" + "'ALTER' /\n" + "'ALWAYS' /\n" + "'ASSERTION' /\n" + "'ASSIGNMENT' /\n" + "'ATTACH' /\n" + "'ATTRIBUTE' /\n" + "'BACKWARD' /\n" + "'BEFORE' /\n" + "'BEGIN' /\n" + "'CACHE' /\n" + "'CALL' /\n" + "'CALLED' /\n" + "'CASCADE' /\n" + "'CASCADED' /\n" + "'CATALOG' /\n" + "'CENTURY' /\n" + "'CENTURIES' /\n" + "'CHAIN' /\n" + "'CHARACTERISTICS' /\n" + "'CHECKPOINT' /\n" + "'CLASS' /\n" + "'CLOSE' /\n" + "'CLUSTER' /\n" + "'COMMENT' /\n" + "'COMMENTS' /\n" + "'COMMIT' /\n" + "'COMMITTED' /\n" + "'COMPRESSION' /\n" + "'CONFIGURATION' /\n" + "'CONFLICT' /\n" + "'CONNECT' /\n" + "'CONNECTION' /\n" + "'CONSTRAINTS' /\n" + "'CONTENT' /\n" + "'CONTINUE' /\n" + "'CONVERSION' /\n" + "'COPY' /\n" + "'COST' /\n" + "'CSV' /\n" + "'CUBE' /\n" + "'CURRENT' /\n" + "'CURSOR' /\n" + "'CYCLE' /\n" + "'DATA' /\n" + "'DATABASE' /\n" + "'DAY' /\n" + "'DAYS' /\n" + "'DEALLOCATE' /\n" + "'DECADE' /\n" + "'DECADES' /\n" + "'DECLARE' /\n" + "'DEFAULTS' /\n" + "'DEFERRED' /\n" + "'DEFINER' /\n" + "'DELETE' /\n" + "'DELIMITER' /\n" + "'DELIMITERS' /\n" + "'DEPENDS' /\n" + "'DETACH' /\n" + "'DICTIONARY' /\n" + "'DISABLE' /\n" + "'DISCARD' /\n" + "'DISCONNECT' /\n" + "'DOCUMENT' /\n" + "'DOMAIN' /\n" + "'DOUBLE' /\n" + "'DROP' /\n" + "'EACH' /\n" + "'ENABLE' /\n" + "'ENCODING' /\n" + "'ENCRYPTED' /\n" + "'ENUM' /\n" + "'ERROR' /\n" + "'ESCAPE' /\n" + "'EVENT' /\n" + "'EXCLUDE' /\n" + "'EXCLUDING' 
/\n" + "'EXCLUSIVE' /\n" + "'EXECUTE' /\n" + "'EXPLAIN' /\n" + "'EXPORT' /\n" + "'EXPORT_STATE' /\n" + "'EXTENSION' /\n" + "'EXTENSIONS' /\n" + "'EXTERNAL' /\n" + "'FAMILY' /\n" + "'FILTER' /\n" + "'FIRST' /\n" + "'FOLLOWING' /\n" + "'FORCE' /\n" + "'FORWARD' /\n" + "'FUNCTION' /\n" + "'FUNCTIONS' /\n" + "'GLOBAL' /\n" + "'GRANT' /\n" + "'GRANTED' /\n" + "'GROUPS' /\n" + "'HANDLER' /\n" + "'HEADER' /\n" + "'HOLD' /\n" + "'HOUR' /\n" + "'HOURS' /\n" + "'IDENTITY' /\n" + "'IF' /\n" + "'IGNORE' /\n" + "'IMMEDIATE' /\n" + "'IMMUTABLE' /\n" + "'IMPLICIT' /\n" + "'IMPORT' /\n" + "'INCLUDE' /\n" + "'INCLUDING' /\n" + "'INCREMENT' /\n" + "'INDEX' /\n" + "'INDEXES' /\n" + "'INHERIT' /\n" + "'INHERITS' /\n" + "'INLINE' /\n" + "'INPUT' /\n" + "'INSENSITIVE' /\n" + "'INSERT' /\n" + "'INSTALL' /\n" + "'INSTEAD' /\n" + "'INVOKER' /\n" + "'JSON' /\n" + "'ISOLATION' /\n" + "'KEY' /\n" + "'LABEL' /\n" + "'LANGUAGE' /\n" + "'LARGE' /\n" + "'LAST' /\n" + "'LEAKPROOF' /\n" + "'LEVEL' /\n" + "'LISTEN' /\n" + "'LOAD' /\n" + "'LOCAL' /\n" + "'LOCATION' /\n" + "'LOCK' /\n" + "'LOCKED' /\n" + "'LOGGED' /\n" + "'MACRO' /\n" + "'MAPPING' /\n" + "'MATCH' /\n" + "'MATCHED' /\n" + "'MATERIALIZED' /\n" + "'MAXVALUE' /\n" + "'MERGE' /\n" + "'METHOD' /\n" + "'MICROSECOND' /\n" + "'MICROSECONDS' /\n" + "'MILLENNIUM' /\n" + "'MILLENNIA' /\n" + "'MILLISECOND' /\n" + "'MILLISECONDS' /\n" + "'MINUTE' /\n" + "'MINUTES' /\n" + "'MINVALUE' /\n" + "'MODE' /\n" + "'MONTH' /\n" + "'MONTHS' /\n" + "'MOVE' /\n" + "'NAME' /\n" + "'NAMES' /\n" + "'NEW' /\n" + "'NEXT' /\n" + "'NO' /\n" + "'NOTHING' /\n" + "'NOTIFY' /\n" + "'NOWAIT' /\n" + "'NULLS' /\n" + "'OBJECT' /\n" + "'OF' /\n" + "'OFF' /\n" + "'OIDS' /\n" + "'OLD' /\n" + "'OPERATOR' /\n" + "'OPTION' /\n" + "'OPTIONS' /\n" + "'ORDINALITY' /\n" + "'OTHERS' /\n" + "'OVER' /\n" + "'OVERRIDING' /\n" + "'OWNED' /\n" + "'OWNER' /\n" + "'PARALLEL' /\n" + "'PARSER' /\n" + "'PARTIAL' /\n" + "'PARTITION' /\n" + "'PARTITIONED' /\n" + "'PASSING' /\n" + "'PASSWORD' /\n" + 
"'PERCENT' /\n" + "'PERSISTENT' /\n" + "'PLANS' /\n" + "'POLICY' /\n" + "'PRAGMA' /\n" + "'PRECEDING' /\n" + "'PREPARE' /\n" + "'PREPARED' /\n" + "'PRESERVE' /\n" + "'PRIOR' /\n" + "'PRIVILEGES' /\n" + "'PROCEDURAL' /\n" + "'PROCEDURE' /\n" + "'PROGRAM' /\n" + "'PUBLICATION' /\n" + "'QUARTER' /\n" + "'QUARTERS' /\n" + "'QUOTE' /\n" + "'RANGE' /\n" + "'READ' /\n" + "'REASSIGN' /\n" + "'RECHECK' /\n" + "'RECURSIVE' /\n" + "'REF' /\n" + "'REFERENCING' /\n" + "'REFRESH' /\n" + "'REINDEX' /\n" + "'RELATIVE' /\n" + "'RELEASE' /\n" + "'RENAME' /\n" + "'REPEATABLE' /\n" + "'REPLACE' /\n" + "'REPLICA' /\n" + "'RESET' /\n" + "'RESPECT' /\n" + "'RESTART' /\n" + "'RESTRICT' /\n" + "'RETURNS' /\n" + "'REVOKE' /\n" + "'ROLE' /\n" + "'ROLLBACK' /\n" + "'ROLLUP' /\n" + "'ROWS' /\n" + "'RULE' /\n" + "'SAMPLE' /\n" + "'SAVEPOINT' /\n" + "'SCHEMA' /\n" + "'SCHEMAS' /\n" + "'SCOPE' /\n" + "'SCROLL' /\n" + "'SEARCH' /\n" + "'SECRET' /\n" + "'SECOND' /\n" + "'SECONDS' /\n" + "'SECURITY' /\n" + "'SEQUENCE' /\n" + "'SEQUENCES' /\n" + "'SERIALIZABLE' /\n" + "'SERVER' /\n" + "'SESSION' /\n" + "'SET' /\n" + "'SETS' /\n" + "'SHARE' /\n" + "'SIMPLE' /\n" + "'SKIP' /\n" + "'SNAPSHOT' /\n" + "'SORTED' /\n" + "'SOURCE' /\n" + "'SQL' /\n" + "'STABLE' /\n" + "'STANDALONE' /\n" + "'START' /\n" + "'STATEMENT' /\n" + "'STATISTICS' /\n" + "'STDIN' /\n" + "'STDOUT' /\n" + "'STORAGE' /\n" + "'STORED' /\n" + "'STRICT' /\n" + "'STRIP' /\n" + "'SUBSCRIPTION' /\n" + "'SYSID' /\n" + "'SYSTEM' /\n" + "'TABLES' /\n" + "'TABLESPACE' /\n" + "'TARGET' /\n" + "'TEMP' /\n" + "'TEMPLATE' /\n" + "'TEMPORARY' /\n" + "'TEXT' /\n" + "'TIES' /\n" + "'TRANSACTION' /\n" + "'TRANSFORM' /\n" + "'TRIGGER' /\n" + "'TRUNCATE' /\n" + "'TRUSTED' /\n" + "'TYPE' /\n" + "'TYPES' /\n" + "'UNBOUNDED' /\n" + "'UNCOMMITTED' /\n" + "'UNENCRYPTED' /\n" + "'UNKNOWN' /\n" + "'UNLISTEN' /\n" + "'UNLOGGED' /\n" + "'UNTIL' /\n" + "'UPDATE' /\n" + "'USE' /\n" + "'USER' /\n" + "'VACUUM' /\n" + "'VALID' /\n" + "'VALIDATE' /\n" + "'VALIDATOR' /\n" 
+ "'VALUE' /\n" + "'VARIABLE' /\n" + "'VARYING' /\n" + "'VERSION' /\n" + "'VIEW' /\n" + "'VIEWS' /\n" + "'VIRTUAL' /\n" + "'VOLATILE' /\n" + "'WEEK' /\n" + "'WEEKS' /\n" + "'WHITESPACE' /\n" + "'WITHIN' /\n" + "'WITHOUT' /\n" + "'WORK' /\n" + "'WRAPPER' /\n" + "'WRITE' /\n" + "'XML' /\n" + "'YEAR' /\n" + "'YEARS' /\n" + "'YES' /\n" + "'ZONE'\n" + "ReservedKeyword <- 'ALL' /\n" + "'ANALYSE' /\n" + "'ANALYZE' /\n" + "'AND' /\n" + "'ANY' /\n" + "'ARRAY' /\n" + "'AS' /\n" + "'ASC' /\n" + "'ASYMMETRIC' /\n" + "'BOTH' /\n" + "'CASE' /\n" + "'CAST' /\n" + "'CHECK' /\n" + "'COLLATE' /\n" + "'COLUMN' /\n" + "'CONSTRAINT' /\n" + "'CREATE' /\n" + "'DEFAULT' /\n" + "'DEFERRABLE' /\n" + "'DESC' /\n" + "'DESCRIBE' /\n" + "'DISTINCT' /\n" + "'DO' /\n" + "'ELSE' /\n" + "'END' /\n" + "'EXCEPT' /\n" + "'FALSE' /\n" + "'FETCH' /\n" + "'FOR' /\n" + "'FOREIGN' /\n" + "'FROM' /\n" + "'GROUP' /\n" + "'HAVING' /\n" + "'QUALIFY' /\n" + "'IN' /\n" + "'INITIALLY' /\n" + "'INTERSECT' /\n" + "'INTO' /\n" + "'LAMBDA' /\n" + "'LATERAL' /\n" + "'LEADING' /\n" + "'LIMIT' /\n" + "'NOT' /\n" + "'NULL' /\n" + "'OFFSET' /\n" + "'ON' /\n" + "'ONLY' /\n" + "'OR' /\n" + "'ORDER' /\n" + "'PIVOT' /\n" + "'PIVOT_WIDER' /\n" + "'PIVOT_LONGER' /\n" + "'PLACING' /\n" + "'PRIMARY' /\n" + "'REFERENCES' /\n" + "'RETURNING' /\n" + "'SELECT' /\n" + "'SHOW' /\n" + "'SOME' /\n" + "'SUMMARIZE' /\n" + "'SYMMETRIC' /\n" + "'TABLE' /\n" + "'THEN' /\n" + "'TO' /\n" + "'TRAILING' /\n" + "'TRUE' /\n" + "'UNION' /\n" + "'UNIQUE' /\n" + "'UNPIVOT' /\n" + "'USING' /\n" + "'VARIADIC' /\n" + "'WHEN' /\n" + "'WHERE' /\n" + "'WINDOW' /\n" + "'WITH'\n" + "ColumnNameKeyword <- 'BETWEEN' /\n" + "'BIGINT' /\n" + "'BIT' /\n" + "'BOOLEAN' /\n" + "'CHAR' /\n" + "'CHARACTER' /\n" + "'COALESCE' /\n" + "'COLUMNS' /\n" + "'DEC' /\n" + "'DECIMAL' /\n" + "'EXISTS' /\n" + "'EXTRACT' /\n" + "'FLOAT' /\n" + "'GENERATED' /\n" + "'GROUPING' /\n" + "'GROUPING_ID' /\n" + "'INOUT' /\n" + "'INT' /\n" + "'INTEGER' /\n" + "'INTERVAL' /\n" + "'MAP' /\n" + 
"'NATIONAL' /\n" + "'NCHAR' /\n" + "'NONE' /\n" + "'NULLIF' /\n" + "'NUMERIC' /\n" + "'OUT' /\n" + "'OVERLAY' /\n" + "'POSITION' /\n" + "'PRECISION' /\n" + "'REAL' /\n" + "'ROW' /\n" + "'SETOF' /\n" + "'SMALLINT' /\n" + "'SUBSTRING' /\n" + "'STRUCT' /\n" + "'TIME' /\n" + "'TIMESTAMP' /\n" + "'TREAT' /\n" + "'TRIM' /\n" + "'TRY_CAST' /\n" + "'VALUES' /\n" + "'VARCHAR' /\n" + "'XMLATTRIBUTES' /\n" + "'XMLCONCAT' /\n" + "'XMLELEMENT' /\n" + "'XMLEXISTS' /\n" + "'XMLFOREST' /\n" + "'XMLNAMESPACES' /\n" + "'XMLPARSE' /\n" + "'XMLPI' /\n" + "'XMLROOT' /\n" + "'XMLSERIALIZE' /\n" + "'XMLTABLE'\n" + "FuncNameKeyword <- 'ASOF' /\n" + "'AT' /\n" + "'AUTHORIZATION' /\n" + "'BINARY' /\n" + "'COLLATION' /\n" + "'CONCURRENTLY' /\n" + "'CROSS' /\n" + "'FREEZE' /\n" + "'FULL' /\n" + "'GENERATED' /\n" + "'GLOB' /\n" + "'ILIKE' /\n" + "'INNER' /\n" + "'IS' /\n" + "'ISNULL' /\n" + "'JOIN' /\n" + "'LEFT' /\n" + "'LIKE' /\n" + "'MAP' /\n" + "'NATURAL' /\n" + "'NOTNULL' /\n" + "'OUTER' /\n" + "'OVERLAPS' /\n" + "'POSITIONAL' /\n" + "'RIGHT' /\n" + "'SIMILAR' /\n" + "'STRUCT' /\n" + "'TABLESAMPLE' /\n" + "'VERBOSE'\n" + "TypeNameKeyword <- 'ASOF' /\n" + "'AT' /\n" + "'AUTHORIZATION' /\n" + "'BINARY' /\n" + "'BY' /\n" + "'COLLATION' /\n" + "'COLUMNS' /\n" + "'CONCURRENTLY' /\n" + "'CROSS' /\n" + "'FREEZE' /\n" + "'FULL' /\n" + "'GLOB' /\n" + "'ILIKE' /\n" + "'INNER' /\n" + "'IS' /\n" + "'ISNULL' /\n" + "'JOIN' /\n" + "'LEFT' /\n" + "'LIKE' /\n" + "'NATURAL' /\n" + "'NOTNULL' /\n" + "'OUTER' /\n" + "'OVERLAPS' /\n" + "'POSITIONAL' /\n" + "'RIGHT' /\n" + "'UNPACK' /\n" + "'SIMILAR' /\n" + "'TABLESAMPLE' /\n" + "'TRY_CAST' /\n" + "'VERBOSE' /\n" + "'SEMI' /\n" + "'ANTI'\n" + "PivotStatement <- PivotKeyword TableRef PivotOn? PivotUsing? 
PivotGroupByList?\n" + "PivotOn <- 'ON' PivotColumnList\n" + "PivotUsing <- 'USING' TargetList\n" + "PivotColumnList <- List(PivotColumnEntry)\n" + "PivotColumnEntry <- PivotColumnSubquery / PivotValueList / PivotColumnExpression\n" + "PivotColumnExpression <- Expression\n" + "PivotColumnSubquery <- BaseExpression 'IN' Parens(SelectStatementInternal)\n" + "PivotKeyword <- 'PIVOT' / 'PIVOT_WIDER'\n" + "UnpivotKeyword <- 'UNPIVOT' / 'PIVOT_LONGER'\n" + "UnpivotStatement <- UnpivotKeyword TableRef 'ON' TargetList IntoNameValues?\n" + "IntoNameValues <- 'INTO' 'NAME' ColIdOrString ValueOrValues List(Identifier)\n" + "ValueOrValues <- 'VALUE' / 'VALUES'\n" + "IncludeOrExcludeNulls <- IncludeNulls / ExcludeNulls\n" + "IncludeNulls <- 'INCLUDE' 'NULLS'\n" + "ExcludeNulls <- 'EXCLUDE' 'NULLS'\n" + "UnpivotHeader <- UnpivotHeaderSingle / UnpivotHeaderList\n" + "UnpivotHeaderSingle <- ColIdOrString\n" + "UnpivotHeaderList <- Parens(List(ColIdOrString))\n" + "ColumnReference <- CatalogReservedSchemaTableColumnName / SchemaReservedTableColumnName / TableReservedColumnName " + "/ NestedColumnName\n" + "CatalogReservedSchemaTableColumnName <- CatalogQualification ReservedSchemaQualification " + "ReservedTableQualification ReservedColumnName\n" + "SchemaReservedTableColumnName <- SchemaQualification ReservedTableQualification ReservedColumnName\n" + "TableReservedColumnName <- TableQualification ReservedColumnName\n" + "FunctionExpression <- FunctionIdentifier Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? " + "IgnoreOrRespectNulls?) WithinGroupClause? FilterClause? ExportClause? OverClause?\n" + "FunctionIdentifier <- CatalogReservedSchemaFunctionName / SchemaReservedFunctionName / FunctionName\n" + "CatalogReservedSchemaFunctionName <- CatalogQualification ReservedSchemaQualification? 
ReservedFunctionName\n" + "SchemaReservedFunctionName <- SchemaQualification ReservedFunctionName\n" + "DistinctOrAll <- 'DISTINCT' / 'ALL'\n" + "ExportClause <- 'EXPORT_STATE'\n" + "WithinGroupClause <- 'WITHIN' 'GROUP' Parens(OrderByClause)\n" + "FilterClause <- 'FILTER' Parens('WHERE'? Expression)\n" + "IgnoreOrRespectNulls <- IgnoreNulls / RespectNulls\n" + "IgnoreNulls <- 'IGNORE' 'NULLS'\n" + "RespectNulls <- 'RESPECT' 'NULLS'\n" + "ParenthesisExpression <- Parens(List(Expression))\n" + "LiteralExpression <- StringLiteral / NumberLiteral / ConstantLiteral\n" + "ConstantLiteral <- NullLiteral / TrueLiteral / FalseLiteral\n" + "NullLiteral <- 'NULL'\n" + "TrueLiteral <- 'TRUE'\n" + "FalseLiteral <- 'FALSE'\n" + "CastExpression <- CastOrTryCast Parens(Expression 'AS' Type)\n" + "CastOrTryCast <- 'CAST' / 'TRY_CAST'\n" + "ColIdDot <- ColId '.'\n" + "StarExpression <- ColIdDot* '*' ExcludeList? ReplaceList? RenameList?\n" + "ExcludeList <- ExcludeOrExcept ExcludeNameList / ExcludeNameSingle\n" + "ExcludeOrExcept <- 'EXCLUDE' / 'EXCEPT'\n" + "ExcludeNameList <- Parens(List(ExcludeName))\n" + "ExcludeNameSingle <- ExcludeName\n" + "ExcludeName <- DottedIdentifier / ColIdOrString\n" + "ReplaceList <- 'REPLACE' ReplaceEntries\n" + "ReplaceEntries <- ReplaceEntrySingle / ReplaceEntryList\n" + "ReplaceEntrySingle <- ReplaceEntry\n" + "ReplaceEntryList <- Parens(List(ReplaceEntry))\n" + "ReplaceEntry <- Expression 'AS' ColumnReference\n" + "RenameList <- 'RENAME' (RenameEntryList / SingleRenameEntry)\n" + "RenameEntryList <- Parens(List(RenameEntry))\n" + "SingleRenameEntry <- RenameEntry\n" + "RenameEntry <- ExcludeName 'AS' Identifier\n" + "SubqueryExpression <- 'NOT'? 'EXISTS'? SubqueryReference\n" + "CaseExpression <- 'CASE' Expression? CaseWhenThen+ CaseElse? 
'END'\n" + "CaseWhenThen <- 'WHEN' Expression 'THEN' Expression\n" + "CaseElse <- 'ELSE' Expression\n" + "TypeLiteral <- ColId StringLiteral\n" + "IntervalLiteral <- 'INTERVAL' IntervalParameter Interval?\n" + "IntervalParameter <- StringLiteral / NumberLiteral / ParensExpression\n" + "FrameClause <- Framing FrameExtent WindowExcludeClause?\n" + "Framing <- 'ROWS' / 'RANGE' / 'GROUPS'\n" + "FrameExtent <- BetweenFrameExtent / SingleFrameExtent\n" + "SingleFrameExtent <- FrameBound\n" + "BetweenFrameExtent <- 'BETWEEN' FrameBound 'AND' FrameBound\n" + "FrameBound <- FrameUnbounded / FrameCurrentRow / FrameExpression\n" + "FrameUnbounded <- 'UNBOUNDED' PrecedingOrFollowing\n" + "FrameExpression <- Expression PrecedingOrFollowing\n" + "FrameCurrentRow <- 'CURRENT' 'ROW'\n" + "PrecedingOrFollowing <- 'PRECEDING' / 'FOLLOWING'\n" + "WindowExcludeClause <- 'EXCLUDE' WindowExcludeElement\n" + "WindowExcludeElement <- ExcludeCurrentRow / ExcludeGroup / ExcludeTies / ExcludeNoOthers\n" + "ExcludeCurrentRow <- 'CURRENT' 'ROW'\n" + "ExcludeGroup <- 'GROUP'\n" + "ExcludeTies <- 'TIES'\n" + "ExcludeNoOthers <- 'NO' 'OTHERS'\n" + "OverClause <- 'OVER' WindowFrame\n" + "WindowFrame <- ParensIdentifier / WindowFrameDefinition / Identifier\n" + "ParensIdentifier <- Parens(Identifier)\n" + "WindowFrameDefinition <- WindowFrameNameContentsParens / WindowFrameContentsParens\n" + "WindowFrameNameContentsParens <- Parens(BaseWindowName? WindowFrameContents)\n" + "WindowFrameContentsParens <- Parens(WindowFrameContents)\n" + "WindowFrameContents <- WindowPartition? OrderByClause? FrameClause?\n" + "BaseWindowName <- Identifier\n" + "WindowPartition <- 'PARTITION' 'BY' List(Expression)\n" + "ListExpression <- ArrayBoundedListExpression / ArrayParensSelect\n" + "ArrayBoundedListExpression <- 'ARRAY'? BoundedListExpression\n" + "ArrayParensSelect <- 'ARRAY' Parens(SelectStatementInternal)\n" + "BoundedListExpression <- '[' List(Expression)? 
']'\n" + "StructExpression <- '{' List(StructField) '}'\n" + "StructField <- ColIdOrString ':' Expression\n" + "MapExpression <- 'MAP' MapStructExpression\n" + "MapStructExpression <- '{' List(MapStructField)? '}'\n" + "MapStructField <- Expression ':' Expression\n" + "GroupingExpression <- GroupingOrGroupingId Parens(List(Expression)?)\n" + "GroupingOrGroupingId <- 'GROUPING' / 'GROUPING_ID'\n" + "Parameter <- QuestionMarkNumberedParameter / AnonymousParameter / NumberedParameter / ColLabelParameter\n" + "QuestionMarkNumberedParameter <- '?' NumberLiteral\n" + "AnonymousParameter <- '?'\n" + "NumberedParameter <- '$' NumberLiteral\n" + "ColLabelParameter <- '$' ColLabel\n" + "PositionalExpression <- '#' NumberLiteral\n" + "DefaultExpression <- 'DEFAULT'\n" + "ListComprehensionExpression <- '[' Expression 'FOR' List(ColIdOrString) 'IN' Expression ListComprehensionFilter? " + "']'\n" + "ListComprehensionFilter <- 'IF' Expression\n" + "ParensExpression <- Parens(Expression)\n" + "SingleExpression <-\n" + " ParensExpression /\n" + " LiteralExpression /\n" + " Parameter /\n" + " SubqueryExpression /\n" + " SpecialFunctionExpression /\n" + " ParenthesisExpression /\n" + " IntervalLiteral /\n" + " TypeLiteral /\n" + " CaseExpression /\n" + " StarExpression /\n" + " CastExpression /\n" + " GroupingExpression /\n" + " MapExpression /\n" + " FunctionExpression /\n" + " ColumnReference /\n" + " ListComprehensionExpression /\n" + " ListExpression /\n" + " StructExpression /\n" + " PositionalExpression /\n" + " DefaultExpression\n" + "# LEVEL 1 (Lowest)\n" + "Expression <- LambdaArrowExpression\n" + "ColumnDefaultExpr <- ColDefOrExpr\n" + "# LEVEL 1.5\n" + "LambdaArrowExpression <- LogicalOrExpression SingleArrowPair*\n" + "SingleArrowPair <- '->' LogicalOrExpression\n" + "LogicalOrExpression <- LogicalAndExpression ('OR' LogicalAndExpression)*\n" + "ColDefOrExpr <- ColDefAndExpr ('OR' ColDefAndExpr)*\n" + "# LEVEL 2\n" + "LogicalAndExpression <- LogicalNotExpression ('AND' 
LogicalNotExpression)*\n" + "ColDefAndExpr <- IsDistinctFromExpression ('AND' IsDistinctFromExpression)*\n" + "# LEVEL 3\n" + "LogicalNotExpression <- 'NOT'* IsExpression\n" + "# LEVEL 4\n" + "IsExpression <- IsDistinctFromExpression IsTest*\n" + "IsTest <- IsLiteral / NotNull / IsNull\n" + "IsLiteral <- 'IS' 'NOT'? (TrueLiteral / FalseLiteral / NullLiteral / UnknownLiteral)\n" + "UnknownLiteral <- 'UNKNOWN'\n" + "NotNull <- ('NOT' 'NULL') / 'NOTNULL'\n" + "IsNull <- 'ISNULL'\n" + "# LEVEL 5 (Split because IsDistinctFromExpression allows post expression while IsOperator does not)\n" + "IsDistinctFromExpression <- ComparisonExpression (IsDistinctFromOp ComparisonExpression)*\n" + "IsDistinctFromOp <- 'IS' 'NOT'? 'DISTINCT' 'FROM'\n" + "# LEVEL 6\n" + "ComparisonExpression <- BetweenInLikeExpression (ComparisonOperator 'NOT'* BetweenInLikeExpression)*\n" + "ComparisonOperator <-\n" + " OperatorEqual /\n" + " OperatorNotEqual /\n" + " OperatorLessThan /\n" + " OperatorGreaterThan /\n" + " OperatorLessThanEquals /\n" + " OperatorGreaterThanEquals\n" + "OperatorEqual <- '=' / '=='\n" + "OperatorNotEqual <- '!=' / '<>'\n" + "OperatorLessThan <- '<'\n" + "OperatorGreaterThan <- '>'\n" + "OperatorLessThanEquals <- '<='\n" + "OperatorGreaterThanEquals <- '>='\n" + "# LEVEL 7\n" + "BetweenInLikeExpression <- OtherOperatorExpression BetweenInLikeOp?\n" + "BetweenInLikeOp <- 'NOT'? 
(BetweenClause / InClause / LikeClause)\n" + "LikeClause <- LikeVariations OtherOperatorExpression EscapeClause?\n" + "EscapeClause <- 'ESCAPE' ComparisonExpression\n" + "LikeVariations <- SimilarToToken / ILikeToken / LikeToken / GlobToken / NotILikeOp / NotLikeOp / NotSimilarToOp\n" + "LikeToken <- 'LIKE' / '~~'\n" + "ILikeToken <- 'ILIKE' / '~~*'\n" + "GlobToken <- 'GLOB' / '~~~'\n" + "SimilarToToken <- ('SIMILAR' 'TO') / '~'\n" + "NotILikeOp <- '!~~*'\n" + "NotLikeOp <- '!~~'\n" + "NotSimilarToOp <- '!~'\n" + "InClause <- 'IN' InExpression\n" + "InExpression <- InExpressionList / InSelectStatement / OtherOperatorExpression\n" + "InExpressionList <- Parens(List(Expression))\n" + "InSelectStatement <- Parens(SelectStatementInternal)\n" + "BetweenClause <- 'BETWEEN' OtherOperatorExpression 'AND' OtherOperatorExpression\n" + "# LEVEL 8\n" + "OtherOperatorExpression <- BitwiseExpression (OtherOperator BitwiseExpression)*\n" + "OtherOperator <-\n" + " QualifiedOperator / AnyAllOperator / InetOperator / JsonOperator / ListOperator / StringOperator / " + "OperatorLiteral\n" + "OperatorLiteral <- Identifier\n" + "AnyAllOperator <- AnyOp AnyOrAll\n" + "AnyOrAll <- SubqueryAny / SubqueryAll\n" + "SubqueryAny <- 'ANY'\n" + "SubqueryAll <- 'ALL'\n" + "InetOperator <- '>>=' / '<<='\n" + "JsonOperator <- '->>'\n" + "ListOperator <- '&&' / '@>' / '<@'\n" + "StringOperator <- '^@' / '||'\n" + "QualifiedOperator <- 'OPERATOR' Parens(ColIdDot* AnyOp)\n" + "AnyOp <- '!~~*' / '>>=' / '<<=' / '->>' / '!~~' / '~~*' / '~~~' / '!~' / '^@' / '||' / '&&' / '@>' / '<@' / '<=' " + "/ '>=' / '<>' / '!=' / '==' / '<<' / '>>' / '//' / '**' / '->' / '~~' / '+' / '-' / '*' / '/' / '%' / '^' / '<' / " + "'>' / '=' / '&' / '|' / '~' / '!'\n" + "# LEVEL 9\n" + "BitwiseExpression <- AdditiveExpression (BitOperator AdditiveExpression)*\n" + "BitOperator <- '&' / '|' / '<<' / '>>'\n" + "# LEVEL 10\n" + "AdditiveExpression <- MultiplicativeExpression (Term MultiplicativeExpression)*\n" + "Term <- '+' 
/ '-'\n" + "# LEVEL 11\n" + "MultiplicativeExpression <- ExponentiationExpression (Factor ExponentiationExpression)*\n" + "Factor <- '*' / '/' / '//' / '%'\n" + "# LEVEL 12\n" + "ExponentiationExpression <- CollateExpression (ExponentOperator CollateExpression)*\n" + "ExponentOperator <- '^' / '**'\n" + "# LEVEL 13\n" + "CollateExpression <- AtTimeZoneExpression (CollateOperator AtTimeZoneExpression)*\n" + "CollateOperator <- 'COLLATE'\n" + "# LEVEL 14\n" + "AtTimeZoneExpression <- PrefixExpression (AtTimeZoneOperator PrefixExpression)*\n" + "AtTimeZoneOperator <- 'AT' 'TIME' 'ZONE'\n" + "# LEVEL 15\n" + "PrefixExpression <- PrefixOperator* BaseExpression\n" + "PrefixOperator <- QualifiedOperator / MinusPrefixOperator / PlusPrefixOperator / TildePrefixOperator\n" + "MinusPrefixOperator <- '-'\n" + "PlusPrefixOperator <- '+'\n" + "TildePrefixOperator <- '~'\n" + "# LEVEL 16 (Highest)\n" + "BaseExpression <- SingleExpression Indirection*\n" + "Indirection <- CastOperator / DotOperator / SliceExpression / PostfixOperator\n" + "CastOperator <- '::' Type\n" + "DotOperator <- '.' (MethodExpression / ColLabel)\n" + "MethodExpression <- ColLabel Parens(DistinctOrAll? List(FunctionArgument)? OrderByClause? IgnoreOrRespectNulls?)\n" + "SliceExpression <- '[' SliceBound ']'\n" + "SliceBound <- Expression? EndSliceBound? StepSliceBound?\n" + "EndSliceBound <- ':' (Expression / '-')?\n" + "StepSliceBound <- ':' Expression?\n" + "PostfixOperator <- '!'\n" + "SpecialFunctionExpression <- CoalesceExpression / UnpackExpression / TryExpression / ColumnsExpression / " + "ExtractExpression / LambdaExpression / NullIfExpression / PositionExpression / RowExpression / " + "SubstringExpression / TrimExpression / OverlayExpression\n" + "CoalesceExpression <- 'COALESCE' Parens(List(Expression))\n" + "UnpackExpression <- 'UNPACK' Parens(Expression)\n" + "TryExpression <- 'TRY' Parens(Expression)\n" + "ColumnsExpression <- '*'? 
'COLUMNS' Parens(Expression)\n" + "ExtractExpression <- 'EXTRACT' Parens(ExtractArgument 'FROM' Expression)\n" + "LambdaExpression <- 'LAMBDA' List(ColIdOrString) ':' Expression\n" + "NullIfExpression <- 'NULLIF' Parens(Expression ',' Expression)\n" + "PositionExpression <- 'POSITION' Parens(SingleExpression 'IN' SingleExpression)\n" + "RowExpression <- 'ROW' Parens(List(Expression)?)\n" + "SubstringExpression <- 'SUBSTRING' Parens(SubstringArguments)\n" + "SubstringArguments <- SubstringParameters / SubstringExpressionList\n" + "SubstringExpressionList <- List(Expression)\n" + "SubstringParameters <- Expression SubstringFromFor\n" + "SubstringFromFor <- SubstringFromOptionalFor / SubstringFor\n" + "SubstringFromOptionalFor <- FromExpression ForExpression?\n" + "SubstringFor <- ForExpression\n" + "TrimExpression <- 'TRIM' Parens(TrimDirection? TrimSource? List(Expression))\n" + "TrimDirection <- TrimBoth / TrimLeading / TrimTrailing\n" + "TrimBoth <- 'BOTH'\n" + "TrimLeading <- 'LEADING'\n" + "TrimTrailing <- 'TRAILING'\n" + "TrimSource <- Expression? 'FROM'\n" + "OverlayExpression <- 'OVERLAY' Parens(OverlayArguments)\n" + "OverlayArguments <- OverlayParameters / OverlayExpressionList\n" + "OverlayParameters <- Expression 'PLACING' Expression FromExpression ForExpression?\n" + "FromExpression <- 'FROM' Expression\n" + "ForExpression <- 'FOR' Expression\n" + "OverlayExpressionList <- List(Expression)\n" + "ExtractArgument <-\n" + " YearKeyword / MonthKeyword / DayKeyword / HourKeyword / MinuteKeyword / SecondKeyword /\n" + " MillisecondKeyword / MicrosecondKeyword / WeekKeyword / QuarterKeyword / DecadeKeyword /\n" + " CenturyKeyword / MillenniumKeyword / Identifier / StringLiteral\n" + "ExecuteStatement <- 'EXECUTE' Identifier TableFunctionArguments?\n" + "CreateSecretStmt <- 'SECRET' IfNotExists? SecretName? SecretStorageSpecifier? 
GenericCopyOptionList\n" + "SecretStorageSpecifier <- 'IN' Identifier\n" + "SecretName <- ColId\n" + "CreateViewStmt <- CreateRecursive? 'VIEW' IfNotExists? QualifiedName InsertColumnList? WithList? 'AS' " + "SelectStatementInternal\n" + "CreateRecursive <- 'RECURSIVE'\n" + "DescribeStatement <- ShowTables / ShowSelect / ShowAllTables / ShowQualifiedName\n" + "ShowSelect <- ShowOrDescribeOrSummarize SelectStatementInternal\n" + "ShowAllTables <- ShowOrDescribe 'ALL' 'TABLES'\n" + "ShowQualifiedName <- ShowOrDescribeOrSummarize DescribeTarget?\n" + "ShowTables <- ShowOrDescribe 'TABLES' 'FROM' QualifiedName\n" + "DescribeTarget <- DescribeBaseTableName / DescribeStringLiteral\n" + "DescribeBaseTableName <- BaseTableName\n" + "DescribeStringLiteral <- StringLiteral\n" + "ShowOrDescribeOrSummarize <- ShowOrDescribe / Summarize\n" + "Summarize <- SummarizeRule\n" + "SummarizeRule <- 'SUMMARIZE'\n" + "ShowOrDescribe <- ShowRule / DescribeRule\n" + "ShowRule <- 'SHOW'\n" + "DescribeRule <- DescribeLongRule / DescRule\n" + "DescribeLongRule <- 'DESCRIBE'\n" + "DescRule <- 'DESC'\n" + "VacuumStatement <- 'VACUUM' VacuumOptions? AnalyzeTarget?\n" + "VacuumOptions <- VacuumParensOptions / VacuumLegacyOptions\n" + "VacuumParensOptions <- Parens(List(VacuumOption))\n" + "VacuumLegacyOptions <- OptFull? OptFreeze? OptVerbose? OptAnalyze?\n" + "VacuumOption <- OptAnalyze / OptFreeze / OptFull / OptVerbose / Identifier\n" + "OptAnalyze <- 'ANALYZE'\n" + "OptFull <- 'FULL'\n" + "OptFreeze <- 'FREEZE'\n" + "OptVerbose <- 'VERBOSE'\n" + "NameList <- Parens(List(ColId))\n" + "MergeIntoStatement <- WithClause? 'MERGE' 'INTO' TargetOptAlias MergeIntoUsingClause JoinQualifier MergeMatch+ " + "ReturningClause?\n" + "MergeIntoUsingClause <- 'USING' TableRef\n" + "MergeMatch <- MatchedClause / NotMatchedClause\n" + "MatchedClause <- 'WHEN' 'MATCHED' AndExpression? 
'THEN' MatchedClauseAction\n" + "MatchedClauseAction <- UpdateMatchClause / DeleteMatchClause / InsertMatchClause / DoNothingMatchClause / " + "ErrorMatchClause\n" + "UpdateMatchClause <- 'UPDATE' UpdateMatchInfo?\n" + "UpdateMatchInfo <- UpdateMatchSetAction / UpdateByNameOrPosition\n" + "UpdateMatchSetAction <- UpdateMatchSetClause\n" + "UpdateByNameOrPosition <- ByNameOrPosition\n" + "DeleteMatchClause <- 'DELETE'\n" + "InsertMatchClause <- 'INSERT' InsertMatchInfo?\n" + "InsertMatchInfo <- InsertValuesList / InsertDefaultValues / InsertByNameOrPosition\n" + "InsertDefaultValues <- 'DEFAULT' 'VALUES'\n" + "InsertByNameOrPosition <- ByNameOrPosition? '*'?\n" + "InsertValuesList <- InsertColumnList? 'VALUES' Parens(List(Expression))\n" + "DoNothingMatchClause <- 'DO' 'NOTHING'\n" + "ErrorMatchClause <- 'ERROR' Expression?\n" + "UpdateMatchSetClause <- 'SET' UpdateMatchSetInfo\n" + "UpdateMatchSetInfo <- UpdateSetClause / StarSymbol\n" + "AndExpression <- 'AND' Expression\n" + "NotMatchedClause <- 'WHEN' 'NOT' 'MATCHED' BySourceOrTarget? AndExpression? 'THEN' MatchedClauseAction\n" + "BySourceOrTarget <- BySource / ByTarget\n" + "BySource <- 'BY' 'SOURCE'\n" + "ByTarget <- 'BY' 'TARGET'\n" + "PragmaStatement <- 'PRAGMA' PragmaAssignOrFunction\n" + "PragmaAssignOrFunction <- PragmaAssign / PragmaFunction\n" + "PragmaAssign <- SettingName '=' VariableList\n" + "PragmaFunction <- PragmaName PragmaParameters?\n" + "PragmaParameters <- Parens(List(Expression))\n" + "DeallocateStatement <- 'DEALLOCATE' DeallocatePrepare? Identifier\n" + "DeallocatePrepare <- 'PREPARE'\n" + "PrepareStatement <- 'PREPARE' Identifier TypeList? 'AS' Statement\n" + "TypeList <- Parens(List(Type))\n" + "CreateStatement <- 'CREATE' OrReplace? Temporary? 
CreateStatementVariation\n" + "CreateStatementVariation <- CreateTableStmt / CreateMacroStmt / CreateSequenceStmt / CreateTypeStmt / " + "CreateSchemaStmt / CreateViewStmt / CreateIndexStmt / CreateSecretStmt / CreateTriggerStmt\n" + "OrReplace <- 'OR' 'REPLACE'\n" + "Temporary <- Persistent / TempPersistent / TemporaryPersistent\n" + "Persistent <- 'PERSISTENT'\n" + "TempPersistent <- 'TEMP'\n" + "TemporaryPersistent <- 'TEMPORARY'\n" + "CreateTableStmt <- 'TABLE' IfNotExists? QualifiedName CreateTableDefinition CommitAction?\n" + "CreateTableDefinition <- CreateTableAs / CreateColumnList\n" + "CreateTableAs <- IdentifierList? PartitionSortedOptions? WithList? 'AS' Statement WithData?\n" + "PartitionSortedOptions <- PartitionOptSortedOptions / SortedOptPartitionOptions\n" + "PartitionOptSortedOptions <- PartitionOptions SortedOptions?\n" + "SortedOptPartitionOptions <- SortedOptions PartitionOptions?\n" + "PartitionOptions <- 'PARTITIONED' 'BY' Parens(List(Expression))\n" + "SortedOptions <- 'SORTED' 'BY' Parens(List(Expression))\n" + "WithData <- WithDataOnly / WithNoData\n" + "WithDataOnly <- 'WITH' 'DATA'\n" + "WithNoData <- 'WITH' 'NO' 'DATA'\n" + "IdentifierList <- Parens(List(Identifier))\n" + "CreateColumnList <- Parens(CreateTableColumnList?) PartitionSortedOptions? 
WithList?\n" + "IfNotExists <- 'IF' 'NOT' 'EXISTS'\n" + "QualifiedName <- CatalogReservedSchemaIdentifier / SchemaReservedIdentifierOrStringLiteral / " + "IdentifierOrStringLiteral\n" + "SchemaReservedIdentifierOrStringLiteral <- SchemaQualification ReservedIdentifierOrStringLiteral\n" + "CatalogReservedSchemaIdentifier <- CatalogQualification ReservedSchemaQualification " + "ReservedIdentifierOrStringLiteral\n" + "IdentifierOrStringLiteral <- Identifier / StringLiteral\n" + "ReservedIdentifierOrStringLiteral <- ReservedIdentifier / StringLiteral\n" + "CatalogQualification <- CatalogName '.'\n" + "SchemaQualification <- SchemaName '.'\n" + "ReservedSchemaQualification <- ReservedSchemaName '.'\n" + "TableQualification <- TableName '.'\n" + "ReservedTableQualification <- ReservedTableName '.'\n" + "CreateTableColumnList <- List(CreateTableColumnElement)\n" + "CreateTableColumnElement <- CreateTableColumnDefinition / CreateTableConstraint\n" + "CreateTableColumnDefinition <- ColumnDefinition\n" + "CreateTableConstraint <- TopLevelConstraint\n" + "ColumnDefinition <- DottedIdentifier Type? GeneratedColumn? ConstraintNameClause? ColumnConstraint*\n" + "ColumnConstraint <- NotNullConstraint / UniqueConstraint / PrimaryKeyConstraint / DefaultValue / CheckConstraint " + "/ ForeignKeyConstraint / ColumnCollation / ColumnCompression\n" + "NotNullConstraint <- NullConstraint / NotNullColumnConstraint\n" + "NullConstraint <- 'NULL'\n" + "NotNullColumnConstraint <- 'NOT' 'NULL'\n" + "UniqueConstraint <- 'UNIQUE'\n" + "PrimaryKeyConstraint <- 'PRIMARY' 'KEY'\n" + "DefaultValue <- 'DEFAULT' ColumnDefaultExpr\n" + "CheckConstraint <- 'CHECK' Parens(Expression)\n" + "ForeignKeyConstraint <- 'REFERENCES' BaseTableName Parens(ColumnList)? KeyActions\n" + "ColumnCollation <- 'COLLATE' DottedIdentifier\n" + "ColumnCompression <- 'USING' 'COMPRESSION' ColIdOrString\n" + "KeyActions <- UpdateAction? 
DeleteAction?\n" + "UpdateAction <- 'ON' 'UPDATE' KeyAction\n" + "DeleteAction <- 'ON' 'DELETE' KeyAction\n" + "KeyAction <- NoKeyAction / RestrictKeyAction / CascadeKeyAction / SetNullKeyAction / SetDefaultKeyAction\n" + "NoKeyAction <- 'NO' 'ACTION'\n" + "RestrictKeyAction <- 'RESTRICT'\n" + "CascadeKeyAction <- 'CASCADE'\n" + "SetNullKeyAction <- 'SET' 'NULL'\n" + "SetDefaultKeyAction <- 'SET' 'DEFAULT'\n" + "TopLevelConstraint <- ConstraintNameClause? TopLevelConstraintList\n" + "TopLevelConstraintList <- TopPrimaryKeyConstraint / CheckConstraint / TopUniqueConstraint / " + "TopForeignKeyConstraint\n" + "ConstraintNameClause <- 'CONSTRAINT' Identifier\n" + "TopPrimaryKeyConstraint <- 'PRIMARY' 'KEY' ColumnIdList\n" + "TopUniqueConstraint <- 'UNIQUE' ColumnIdList\n" + "TopForeignKeyConstraint <- 'FOREIGN' 'KEY' ColumnIdList ForeignKeyConstraint\n" + "ColumnIdList <- Parens(List(ColId))\n" + "PlainIdentifier <- !ReservedKeyword <[a-z_]i[a-z0-9_]i*>\n" + "QuotedIdentifier <- '\"' [^\"]* '\"'\n" + "DottedIdentifier <- Identifier DotColLabel*\n" + "DotColLabel <- '.' ColLabel\n" + "Identifier <- QuotedIdentifier / PlainIdentifier\n" + "ColId <- UnreservedKeyword / ColumnNameKeyword / Identifier\n" + "ColIdOrString <- ColId / StringLiteral\n" + "TypeFuncName <- UnreservedKeyword / TypeNameKeyword / FuncNameKeyword / Identifier\n" + "TypeName <- UnreservedKeyword / TypeNameKeyword / Identifier\n" + "ColLabel <- ReservedKeyword / UnreservedKeyword / ColumnNameKeyword / FuncNameKeyword / TypeNameKeyword / " + "Identifier\n" + "ColLabelOrString <- ColLabel / StringLiteral\n" + "GeneratedColumn <- Generated? 
'AS' Parens(Expression) GeneratedColumnType?\n" + "Generated <- 'GENERATED' AlwaysOrByDefault?\n" + "AlwaysOrByDefault <- 'ALWAYS' / ('BY' 'DEFAULT')\n" + "GeneratedColumnType <- VirtualGeneratedColumn / StoredGeneratedColumn\n" + "CommitAction <- 'ON' 'COMMIT' PreserveOrDelete 'ROWS'\n" + "PreserveOrDelete <- PreserveRows / DeleteRows\n" + "PreserveRows <- 'PRESERVE'\n" + "DeleteRows <- 'DELETE'\n" + "VirtualGeneratedColumn <- 'VIRTUAL'\n" + "StoredGeneratedColumn <- 'STORED'\n" + "CreateIndexStmt <- UniqueIndex? 'INDEX' IfNotExists? IndexName? 'ON' BaseTableName InsertColumnList? IndexType? " + "Parens(List(IndexElement))? WithList? WhereClause?\n" + "WithList <- 'WITH' RelOptionOrOids\n" + "RelOptionOrOids <- RelOptionList / Oids\n" + "RelOptionList <- Parens(List(RelOption))\n" + "Oids <- WithOrWithoutOids 'OIDS'\n" + "WithOrWithoutOids <- WithOids / WithoutOids\n" + "WithOids <- 'WITH'\n" + "WithoutOids <- 'WITHOUT'\n" + "IndexElement <- Expression DescOrAsc? NullsFirstOrLast?\n" + "UniqueIndex <- 'UNIQUE'\n" + "IndexType <- 'USING' Identifier\n" + "RelOption <- RelOptionName RelOptionArgumentOpt?\n" + "RelOptionName <- DottedIdentifierString / StringLiteral\n" + "DottedIdentifierString <- DottedIdentifier\n" + "RelOptionArgumentOpt <- '=' DefArg\n" + "DefArg <- DefArgNull / DefArgKeyword / DefArgStringLiteral / NumberLiteral / NoneLiteral / Expression\n" + "DefArgNull <- NullLiteral\n" + "DefArgKeyword <- ReservedKeyword\n" + "DefArgStringLiteral <- StringLiteral\n" + "NoneLiteral <- 'NONE'\n" + "LoadStatement <- 'LOAD' ColIdOrString ExtensionAlias?\n" + "ExtensionAlias <- 'AS' Identifier\n" + "InstallStatement <- 'FORCE'? 'INSTALL' IdentifierOrStringLiteral FromSource? 
VersionNumber?\n" + "UpdateExtensionsStatement <- 'UPDATE' 'EXTENSIONS' Parens(List(Identifier))?\n" + "FromSource <- FromSourceIdentifier / FromSourceString\n" + "FromSourceIdentifier <- 'FROM' Identifier\n" + "FromSourceString <- 'FROM' StringLiteral\n" + "VersionNumber <- 'VERSION' IdentifierOrStringLiteral\n" + "DropStatement <- 'DROP' DropEntries DropBehavior?\n" + "DropEntries <-\n" + " DropTable /\n" + " DropTableFunction /\n" + " DropFunction /\n" + " DropSchema /\n" + " DropIndex /\n" + " DropSequence /\n" + " DropCollation /\n" + " DropType /\n" + " DropSecret /\n" + " DropTrigger\n" + "DropTrigger <- 'TRIGGER' IfExists? TriggerName 'ON' BaseTableName\n" + "DropTable <- TableOrView IfExists? List(BaseTableName)\n" + "DropTableFunction <- CommentMacroTable IfExists? List(TableFunctionName)\n" + "DropFunction <- FunctionTypeMacro IfExists? List(FunctionIdentifier)\n" + "DropSchema <- 'SCHEMA' IfExists? List(QualifiedSchemaName)\n" + "DropIndex <- 'INDEX' IfExists? List(QualifiedIndexName)\n" + "QualifiedIndexName <- CatalogReservedSchemaIndex / SchemaReservedIndex / QualifiedIndexNameString\n" + "QualifiedIndexNameString <- IndexName\n" + "SchemaReservedIndex <- SchemaQualification ReservedIndexName\n" + "CatalogReservedSchemaIndex <- CatalogQualification ReservedSchemaQualification ReservedIndexName\n" + "DropSequence <- 'SEQUENCE' IfExists? List(QualifiedSequenceName)\n" + "DropCollation <- 'COLLATION' IfExists? List(CollationName)\n" + "DropType <- 'TYPE' IfExists? List(QualifiedTypeName)\n" + "DropSecret <- Temporary? 'SECRET' IfExists? 
SecretName DropSecretStorage?\n" + "TableOrView <- CommentTable / CommentView / MaterializedViewEntry\n" + "MaterializedViewEntry <- 'MATERIALIZED' 'VIEW'\n" + "FunctionTypeMacro <- FunctionTypeMacroKeyword / FunctionTypeFunction\n" + "FunctionTypeMacroKeyword <- 'MACRO'\n" + "FunctionTypeFunction <- 'FUNCTION'\n" + "DropBehavior <- CascadeDropBehavior / RestrictDropBehavior\n" + "CascadeDropBehavior <- 'CASCADE'\n" + "RestrictDropBehavior <- 'RESTRICT'\n" + "IfExists <- 'IF' 'EXISTS'\n" + "QualifiedSchemaName <- CatalogReservedSchema / QualifiedSchemaNameString\n" + "QualifiedSchemaNameString <- SchemaName\n" + "CatalogReservedSchema <- CatalogQualification ReservedSchemaName\n" + "DropSecretStorage <- 'FROM' Identifier\n" + "UpdateStatement <- WithClause? 'UPDATE' UpdateTarget UpdateSetClause FromClause? WhereClause? ReturningClause?\n" + "UpdateTarget <- BaseTableSet / BaseTableAliasSet\n" + "BaseTableSet <- BaseTableName 'SET'\n" + "BaseTableAliasSet <- BaseTableName UpdateAlias? 'SET'\n" + "UpdateAlias <- 'AS'? ColId\n" + "UpdateSetClause <- UpdateSetElementList / UpdateSetTuple\n" + "UpdateSetTuple <- Parens(List(ColumnName)) '=' Expression\n" + "UpdateSetElementList <- List(UpdateSetElement)\n" + "UpdateSetElement <- UpdateSetColumnTarget '=' Expression\n" + "UpdateSetColumnTarget <- ColumnName DotIdentifier*\n" + "InsertStatement <- WithClause? 'INSERT' OrAction? 'INTO' InsertTarget ByNameOrPosition? InsertColumnList? " + "InsertValues OnConflictClause? 
ReturningClause?\n" + "OrAction <- InsertOrReplace / InsertOrIgnore\n" + "InsertOrReplace <- 'OR' 'REPLACE'\n" + "InsertOrIgnore <- 'OR' 'IGNORE'\n" + "ByNameOrPosition <- InsertByNameOrder / InsertByPositionOrder\n" + "InsertByNameOrder <- 'BY' InsertByName\n" + "InsertByPositionOrder <- 'BY' InsertByPosition\n" + "InsertByName <- 'NAME'\n" + "InsertByPosition <- 'POSITION'\n" + "InsertTarget <- BaseTableName InsertAlias?\n" + "InsertAlias <- 'AS' Identifier\n" + "ColumnList <- List(ColId)\n" + "InsertColumnList <- Parens(ColumnList)\n" + "InsertValues <- SelectInsertValues / DefaultValues\n" + "SelectInsertValues <- SelectStatementInternal\n" + "DefaultValues <- 'DEFAULT' 'VALUES'\n" + "OnConflictClause <- 'ON' 'CONFLICT' OnConflictTarget? OnConflictAction\n" + "OnConflictTarget <- OnConflictExpressionTarget / OnConflictIndexTarget\n" + "OnConflictExpressionTarget <- ColumnIdList WhereClause?\n" + "OnConflictIndexTarget <- 'ON' 'CONSTRAINT' ConstraintName\n" + "OnConflictAction <- OnConflictUpdate / OnConflictNothing\n" + "OnConflictUpdate <- 'DO' 'UPDATE' 'SET' UpdateSetClause WhereClause?\n" + "OnConflictNothing <- 'DO' 'NOTHING'\n" + "ReturningClause <- 'RETURNING' TargetList\n" + "CreateSchemaStmt <- 'SCHEMA' IfNotExists? QualifiedName\n" + "SelectStatement <- SelectStatementInternal\n" + "SelectStatementInternal <- WithClause? SelectSetOpChain ResultModifiers?\n" + "SelectSetOpChain <- IntersectChain (SetopClause IntersectChain)*\n" + "IntersectChain <- SelectAtom (SetIntersectClause SelectAtom)*\n" + "SetIntersectClause <- 'INTERSECT' DistinctOrAll?\n" + "SelectAtom <- SelectParens / SelectStatementType\n" + "SelectParens <- Parens(SelectStatementInternal)\n" + "SetopClause <- SetopType DistinctOrAll? 
ByName?\n" + "SetopType <- SetopUnion / SetopExcept\n" + "SetopUnion <- 'UNION'\n" + "SetopExcept <- 'EXCEPT'\n" + "ByName <- 'BY' 'NAME'\n" + "SelectStatementType <- OptionalParensSimpleSelect / ValuesClause / DescribeStatement / TableStatement / " + "PivotStatement / UnpivotStatement\n" + "ResultModifiers <- OrderByClause? LimitOffset?\n" + "LimitOffset <- LimitOffsetClause / OffsetLimitClause\n" + "LimitOffsetClause <- LimitClause OffsetClause?\n" + "OffsetLimitClause <- OffsetClause LimitClause?\n" + "TableStatement <- 'TABLE' BaseTableName\n" + "OptionalParensSimpleSelect <- SimpleSelectParens / SimpleSelect\n" + "SimpleSelectParens <- Parens(SimpleSelect)\n" + "SimpleSelect <- SelectFrom WhereClause? GroupByClause? HavingClause? WindowClause? QualifyClause? SampleClause?\n" + "SelectFrom <- SelectFromClause / FromSelectClause\n" + "SelectFromClause <- SelectClause FromClause?\n" + "FromSelectClause <- FromClause SelectClause?\n" + "WithStatement <- ColIdOrString InsertColumnList? UsingKey? 'AS' Materialized? CTEBody\n" + "CTEBody <- Parens(CTEBodyContent)\n" + "CTEBodyContent <- SelectStatementInternal / Statement\n" + "UsingKey <- 'USING' 'KEY' Parens(TargetList)\n" + "Materialized <- 'NOT'? 'MATERIALIZED'\n" + "WithClause <- 'WITH' Recursive? List(WithStatement)\n" + "Recursive <- 'RECURSIVE'\n" + "SelectClause <- 'SELECT' DistinctClause? TargetList?\n" + "TargetList <- List(AliasedExpression)\n" + "ColumnAliases <- Parens(List(ColIdOrString))\n" + "DistinctClause <- DistinctOn / DistinctAll\n" + "DistinctAll <- 'ALL'\n" + "DistinctOn <- 'DISTINCT' DistinctOnTargets?\n" + "DistinctOnTargets <- 'ON' Parens(List(Expression))\n" + "InnerTableRef <- ValuesRef / TableFunction / TableSubquery / BaseTableRef / ParensTableRef\n" + "TableRef <- InnerTableRef JoinOrPivot*\n" + "TableSubquery <- Lateral? SubqueryReference TableAlias?\n" + "BaseTableRef <- TableAliasColon? BaseTableName TableAlias? AtClause? 
SampleClause?\n" + "TableAliasColon <- ColIdOrString ':'\n" + "ValuesRef <- ValuesClause TableAlias?\n" + "ParensTableRef <- TableAliasColon? Parens(TableRef) TableAlias? SampleClause?\n" + "JoinOrPivot <- JoinClause / TablePivotClause / TableUnpivotClause\n" + "TablePivotClause <- 'PIVOT' Parens(TargetList 'FOR' PivotValueList+ PivotGroupByList?) TableAlias?\n" + "PivotGroupByList <- 'GROUP' 'BY' List(ColIdOrString)\n" + "TableUnpivotClause <- 'UNPIVOT' IncludeOrExcludeNulls? Parens(UnpivotHeader 'FOR' UnpivotValueList+) TableAlias?\n" + "PivotHeader <- BaseExpression\n" + "PivotValueList <- PivotHeader 'IN' (Identifier / PivotTargetList)\n" + "UnpivotValueList <- UnpivotHeader 'IN' UnpivotTargetList\n" + "PivotTargetList <- Parens(TargetList)\n" + "UnpivotTargetList <- Parens(TargetList)\n" + "Lateral <- 'LATERAL'\n" + "BaseTableName <- CatalogReservedSchemaTable / SchemaReservedTable / TableName\n" + "SchemaReservedTable <- SchemaQualification ReservedTableName\n" + "CatalogReservedSchemaTable <- CatalogQualification ReservedSchemaQualification ReservedTableName\n" + "TableFunction <- TableFunctionLateralOpt / TableFunctionAliasColon\n" + "TableFunctionLateralOpt <- Lateral? QualifiedTableFunction TableFunctionArguments WithOrdinality? TableAlias?\n" + "TableFunctionAliasColon <- TableAliasColon QualifiedTableFunction TableFunctionArguments WithOrdinality? " + "SampleClause?\n" + "WithOrdinality <- 'WITH' 'ORDINALITY'\n" + "QualifiedTableFunction <- CatalogQualification? SchemaQualification? TableFunctionName\n" + "TableFunctionArguments <- Parens(List(FunctionArgument)?)\n" + "FunctionArgument <- NamedParameter / Expression\n" + "NamedParameter <- TypeFuncName Type? 
NamedParameterAssignment Expression\n" + "NamedParameterAssignment <- ':=' / '=>'\n" + "TableAlias <- TableAliasAs / TableAliasWithoutAs\n" + "TableAliasAs <- 'AS' IdentifierOrStringLiteral ColumnAliases?\n" + "TableAliasWithoutAs <- Identifier ColumnAliases?\n" + "AtClause <- 'AT' Parens(AtSpecifier)\n" + "AtSpecifier <- AtUnit '=>' Expression\n" + "AtUnit <- 'VERSION' / 'TIMESTAMP'\n" + "JoinClause <- RegularJoinClause / JoinWithoutOnClause\n" + "RegularJoinClause <- 'ASOF'? JoinType? 'JOIN' TableRef JoinQualifier\n" + "JoinWithoutOnClause <- JoinPrefix 'JOIN' TableRef\n" + "JoinQualifier <- OnClause / UsingClause\n" + "OnClause <- 'ON' Expression\n" + "UsingClause <- 'USING' Parens(List(ColumnName))\n" + "JoinType <- FullJoin / LeftJoin / RightJoin / SemiJoin / AntiJoin / InnerJoin\n" + "JoinPrefix <- CrossJoinPrefix / NaturalJoinPrefix / PositionalJoinPrefix\n" + "CrossJoinPrefix <- 'CROSS'\n" + "NaturalJoinPrefix <- 'NATURAL' JoinType?\n" + "PositionalJoinPrefix <- 'POSITIONAL'\n" + "FullJoin <- 'FULL' 'OUTER'?\n" + "LeftJoin <- 'LEFT' 'OUTER'?\n" + "RightJoin <- 'RIGHT' 'OUTER'?\n" + "SemiJoin <- 'SEMI'\n" + "AntiJoin <- 'ANTI'\n" + "InnerJoin <- 'INNER'\n" + "FromClause <- 'FROM' List(TableRef)\n" + "WhereClause <- 'WHERE' Expression\n" + "GroupByClause <- 'GROUP' 'BY' GroupByExpressions\n" + "HavingClause <- 'HAVING' Expression\n" + "QualifyClause <- 'QUALIFY' Expression\n" + "SampleClause <- (TableSample / UsingSample) SampleEntry\n" + "UsingSample <- 'USING' 'SAMPLE'\n" + "TableSample <- 'TABLESAMPLE'\n" + "WindowClause <- 'WINDOW' List(WindowDefinition)\n" + "WindowDefinition <- Identifier 'AS' WindowFrameDefinition\n" + "SampleEntry <- SampleEntryFunction / SampleEntryCount\n" + "SampleEntryCount <- SampleCount Parens(SampleProperties)?\n" + "SampleEntryFunction <- SampleFunction? 
Parens(SampleCount) RepeatableSample?\n" + "SampleFunction <- ColId\n" + "SampleProperties <- ColId (',' SampleSeed)?\n" + "RepeatableSample <- 'REPEATABLE' Parens(SampleSeed)\n" + "SampleSeed <- NumberLiteral\n" + "SampleCount <- SampleValue SampleUnit?\n" + "SampleValue <- NumberLiteral / Parameter\n" + "SampleUnit <- SamplePercentage / SampleRows\n" + "SamplePercentage <- '%' / 'PERCENT'\n" + "SampleRows <- 'ROWS'\n" + "GroupByExpressions <- GroupByList / GroupByAll\n" + "GroupByAll <- 'ALL'\n" + "GroupByList <- List(GroupByExpression)\n" + "GroupByExpression <- EmptyGroupingItem / CubeOrRollupClause / GroupingSetsClause / Expression\n" + "EmptyGroupingItem <- '(' ')'\n" + "CubeOrRollupClause <- CubeOrRollup Parens(List(Expression)?)\n" + "CubeOrRollup <- 'CUBE' / 'ROLLUP'\n" + "GroupingSetsClause <- 'GROUPING' 'SETS' Parens(GroupByList)\n" + "SubqueryReference <- Parens(SelectStatementInternal)\n" + "OrderByExpression <- Expression DescOrAsc? NullsFirstOrLast?\n" + "DescOrAsc <- DescendingOrder / AscendingOrder\n" + "DescendingOrder <- 'DESC' / 'DESCENDING'\n" + "AscendingOrder <- 'ASC' / 'ASCENDING'\n" + "NullsFirstOrLast <- NullsFirst / NullsLast\n" + "NullsFirst <- 'NULLS' 'FIRST'\n" + "NullsLast <- 'NULLS' 'LAST'\n" + "OrderByClause <- 'ORDER' 'BY' OrderByExpressions\n" + "OrderByExpressions <- OrderByAll / OrderByExpressionList\n" + "OrderByExpressionList <- List(OrderByExpression)\n" + "OrderByAll <- 'ALL' DescOrAsc? 
NullsFirstOrLast?\n" + "LimitClause <- 'LIMIT' LimitValue\n" + "OffsetClause <- 'OFFSET' OffsetValue\n" + "OffsetValue <- Expression RowOrRows?\n" + "RowOrRows <- 'ROW' / 'ROWS'\n" + "LimitValue <- LimitAll / LimitLiteralPercent / LimitExpression\n" + "LimitAll <- 'ALL'\n" + "LimitLiteralPercent <- NumberLiteral 'PERCENT'\n" + "LimitExpression <- Expression '%'?\n" + "AliasedExpression <- ColIdExpression / ExpressionAsCollabel / ExpressionOptIdentifier\n" + "ColIdExpression <- ColId ':' Expression\n" + "ExpressionAsCollabel <- Expression 'AS' ColLabelOrString\n" + "ExpressionOptIdentifier <- Expression Identifier?\n" + "ValuesClause <- 'VALUES' List(ValuesExpressions)\n" + "ValuesExpressions <- Parens(List(Expression))\n" + "TransactionStatement <- BeginTransaction / RollbackTransaction / CommitTransaction\n" + "BeginTransaction <- StartOrBegin Transaction? ReadOrWrite?\n" + "RollbackTransaction <- AbortOrRollback Transaction?\n" + "CommitTransaction <- CommitOrEnd Transaction?\n" + "StartOrBegin <- 'START' / 'BEGIN'\n" + "Transaction <- 'WORK' / 'TRANSACTION'\n" + "ReadOrWrite <- 'READ' ReadOnlyOrReadWrite\n" + "ReadOnlyOrReadWrite <- ReadOnly / ReadWrite\n" + "ReadOnly <- 'ONLY'\n" + "ReadWrite <- 'WRITE'\n" + "AbortOrRollback <- 'ABORT' / 'ROLLBACK'\n" + "CommitOrEnd <- 'COMMIT' / 'END'\n" + "DeleteStatement <- WithClause? 'DELETE' 'FROM' TargetOptAlias DeleteUsingClause? WhereClause? ReturningClause?\n" + "TruncateStatement <- 'TRUNCATE' 'TABLE'? BaseTableName\n" + "TargetOptAlias <- BaseTableName 'AS'? ColId?\n" + "DeleteUsingClause <- 'USING' List(TableRef)\n" + "ConnectStatement <- 'CONNECT' SessionTarget?\n" + "DisconnectStatement <- 'DISCONNECT'\n" + "SessionTarget <- 'LOCAL' / StringLiteral / CatalogName\n" + "CreateTypeStmt <- 'TYPE' IfNotExists? 
QualifiedName 'AS' CreateType\n" + "CreateType <- EnumSelectType / EnumStringLiteralList / CreateTypeFromType\n" + "CreateTypeFromType <- Type\n" + "EnumSelectType <- 'ENUM' Parens(SelectStatementInternal)\n" + "EnumStringLiteralList <- 'ENUM' Parens(List(StringLiteral)?)\n" + "SetStatement <- 'SET' SetAssignmentOrTimeZone\n" + "SetAssignmentOrTimeZone <- StandardAssignment / SetTimeZone\n" + "ResetStatement <- 'RESET' SetVariableOrSetting\n" + "StandardAssignment <- SetVariableOrSetting SetAssignment\n" + "SetVariableOrSetting <- SetVariable / SetSetting\n" + "SetTimeZone <- 'TIME' 'ZONE' ZoneValue\n" + "ZoneValue <- ZoneIntervalWithPrecision / ZoneIntervalWithInterval / ZoneLocal / ZoneDefault / ZoneStringLiteral / " + "ZoneIdentifier / NumberLiteral\n" + "ZoneLocal <- 'LOCAL'\n" + "ZoneDefault <- 'DEFAULT'\n" + "ZoneStringLiteral <- StringLiteral\n" + "ZoneIdentifier <- Identifier\n" + "ZoneIntervalWithInterval <- 'INTERVAL' StringLiteral Interval?\n" + "ZoneIntervalWithPrecision <- 'INTERVAL' Parens(NumberLiteral) StringLiteral\n" + "SetSetting <- SettingScope? SettingName\n" + "SetVariable <- VariableScope Identifier\n" + "VariableScope <- 'VARIABLE'\n" + "SettingScope <- LocalScope / SessionScope / GlobalScope\n" + "LocalScope <- 'LOCAL'\n" + "SessionScope <- 'SESSION'\n" + "GlobalScope <- 'GLOBAL'\n" + "SetAssignment <- VariableAssign VariableList\n" + "VariableAssign <- '=' / 'TO'\n" + "VariableList <- List(Expression)\n" + "ExportStatement <- 'EXPORT' 'DATABASE' ExportSource? StringLiteral GenericCopyOptionList?\n" + "ExportSource <- CatalogName 'TO'\n" + "ImportStatement <- 'IMPORT' 'DATABASE' StringLiteral\n" + "CheckpointStatement <- CheckpointForce? 'CHECKPOINT' CatalogName?\n" + "CheckpointForce <- 'FORCE'\n" + "CopyStatement <- 'COPY' CopyVariations\n" + "CopyVariations <- CopyTable / CopySelect / CopyFromDatabase\n" + "CopyTable <- BaseTableName InsertColumnList? 
FromOrTo CopyFileName CopyOptions?\n" + "FromOrTo <- CopyFrom / CopyTo\n" + "CopyFrom <- 'FROM'\n" + "CopyTo <- 'TO'\n" + "CopySelect <- Parens(SelectStatementInternal) 'TO' CopyFileName CopyOptions?\n" + "CopyFileName <- CopyFileNameExpression / CopyFileNameStringLiteral / CopyFileNameIdentifier / " + "CopyFileNameIdentifierColId\n" + "CopyFileNameExpression <- Parameter / ParensExpression\n" + "CopyFileNameStringLiteral <- StringLiteral\n" + "CopyFileNameIdentifier <- Identifier\n" + "CopyFileNameIdentifierColId <- IdentifierColId\n" + "IdentifierColId <- Identifier '.' ColId\n" + "CopyOptions <- 'WITH'? CopyOptionList\n" + "CopyOptionList <- GenericCopyOptionList / SpecializedOptionList\n" + "SpecializedOptionList <- SpecializedOption*\n" + "SpecializedOption <- SingleOption / NullAsOption / DelimiterAsOption / EscapeAsOption /\n" + " EncodingOption / QuoteAsOption / ForceQuoteOption / PartitionByOption / ForceNullOption\n" + "SingleOption <- BinaryOption / FreezeOption / OidsOption / CsvOption / HeaderOption\n" + "BinaryOption <- 'BINARY'\n" + "FreezeOption <- 'FREEZE'\n" + "OidsOption <- 'OIDS'\n" + "CsvOption <- 'CSV'\n" + "HeaderOption <- 'HEADER'\n" + "NullAsOption <- 'NULL' 'AS'? StringLiteral\n" + "DelimiterAsOption <- 'DELIMITER' 'AS'? StringLiteral\n" + "QuoteAsOption <- 'QUOTE' 'AS'? StringLiteral\n" + "EscapeAsOption <- 'ESCAPE' 'AS'? StringLiteral\n" + "EncodingOption <- 'ENCODING' StringLiteral\n" + "ForceQuoteOption <- ForceQuote? 'QUOTE' StarSymbolColumnList\n" + "StarSymbolColumnList <- StarSymbol / ColumnList\n" + "ForceQuote <- 'FORCE'\n" + "PartitionByOption <- 'PARTITION' 'BY' StarSymbolColumnList\n" + "ForceNullOption <- 'FORCE' ForceNotNull? 
'NULL' ColumnList\n" + "ForceNotNull <- 'NOT'\n" + "StarSymbol <- '*'\n" + "GenericCopyOptionList <- Parens(List(GenericCopyOption))\n" + "GenericCopyOption <- CopyOptionName GenericCopyOptionValue?\n" + "GenericCopyOptionValue <- GenericCopyOptionOrderList / GenericCopyOptionExpression\n" + "GenericCopyOptionOrderList <- GenericCopyOptionParenthesizedExpressionList\n" + "GenericCopyOptionExpression <- Expression\n" + "GenericCopyOptionParenthesizedExpressionList <- Parens(OrderByExpressionList)\n" + "CopyFromDatabase <- CopyFromDatabaseWithFlag / CopyFromDatabaseWithoutFlag\n" + "CopyFromDatabaseWithFlag <- 'FROM' 'DATABASE' ColId 'TO' ColId CopyDatabaseFlag\n" + "CopyFromDatabaseWithoutFlag <- 'FROM' 'DATABASE' ColId 'TO' ColId\n" + "CopyDatabaseFlag <- Parens(SchemaOrData)\n" + "SchemaOrData <- CopySchema / CopyData\n" + "CopySchema <- 'SCHEMA'\n" + "CopyData <- 'DATA'\n" + "AlterStatement <- 'ALTER' AlterOptions\n" + "AlterOptions <- AlterTableStmt / AlterViewStmt / AlterSequenceStmt / AlterDatabaseStmt / AlterSchemaStmt\n" + "AlterTableStmt <- 'TABLE' IfExists? BaseTableName List(AlterTableOptions)\n" + "AlterSchemaStmt <- 'SCHEMA' IfExists? QualifiedName RenameAlter\n" + "AlterTableOptions <- AddColumn / DropColumn / AlterColumn / AddConstraint / ChangeNullability /\n" + " RenameColumn / RenameAlter / SetPartitionedBy / ResetPartitionedBy / SetSortedBy / ResetSortedBy / " + "SetOptions / ResetOptions\n" + "AddConstraint <- 'ADD' TopLevelConstraint\n" + "AddColumn <- 'ADD' 'COLUMN'? IfNotExists? AddColumnEntry\n" + "AddColumnEntry <- DottedIdentifier Type? GeneratedColumn? ColumnConstraint*\n" + "DropColumn <- 'DROP' 'COLUMN'? IfExists? NestedColumnName DropBehavior?\n" + "AlterColumn <- 'ALTER' 'COLUMN'? NestedColumnName AlterColumnEntry\n" + "RenameColumn <- 'RENAME' 'COLUMN'? 
NestedColumnName 'TO' Identifier\n" + "NestedColumnName <- IdentifierDot* ColumnName\n" + "IdentifierDot <- Identifier '.'\n" + "RenameAlter <- 'RENAME' 'TO' Identifier\n" + "SetPartitionedBy <- 'SET' 'PARTITIONED' 'BY' Parens(List(Expression))\n" + "ResetPartitionedBy <- 'RESET' 'PARTITIONED' 'BY'\n" + "SetSortedBy <- 'SET' 'SORTED' 'BY' Parens(OrderByExpressions)\n" + "ResetSortedBy <- 'RESET' 'SORTED' 'BY'\n" + "SetOptions <- 'SET' RelOptionList\n" + "ResetOptions <- 'RESET' RelOptionList\n" + "AlterColumnEntry <- AddOrDropDefault / ChangeNullability / AlterType\n" + "AddOrDropDefault <- AddDefault / DropDefault\n" + "AddDefault <- 'SET' 'DEFAULT' Expression\n" + "DropDefault <- 'DROP' 'DEFAULT'\n" + "ChangeNullability <- DropOrSet 'NOT' 'NULL'\n" + "DropOrSet <- DropNullability / SetNullability\n" + "DropNullability <- 'DROP'\n" + "SetNullability <- 'SET'\n" + "AlterType <- SetData? 'TYPE' Type? UsingExpression?\n" + "SetData <- 'SET' 'DATA'?\n" + "UsingExpression <- 'USING' Expression\n" + "AlterViewStmt <- 'VIEW' IfExists? BaseTableName RenameAlter\n" + "AlterSequenceStmt <- 'SEQUENCE' IfExists? QualifiedSequenceName AlterSequenceOptions\n" + "QualifiedSequenceName <- CatalogQualification? SchemaQualification? SequenceName\n" + "AlterSequenceOptions <- RenameAlter / SetSequenceOption\n" + "SetSequenceOption <- SequenceOption+\n" + "AlterDatabaseStmt <- 'DATABASE' IfExists? Identifier 'SET' 'ALIAS' 'TO' Identifier\n" + "CreateSequenceStmt <- 'SEQUENCE' IfNotExists? QualifiedName SequenceOption*\n" + "SequenceOption <-\n" + " SeqSetCycle /\n" + " SeqSetIncrement /\n" + " SeqSetMinMax /\n" + " SeqNoMinMax /\n" + " SeqStartWith /\n" + " SeqOwnedBy\n" + "SeqSetCycle <- SeqCycle / SeqNoCycle\n" + "SeqCycle <- 'CYCLE'\n" + "SeqNoCycle <- 'NO' 'CYCLE'\n" + "SeqSetIncrement <- 'INCREMENT' 'BY'? Expression\n" + "SeqSetMinMax <- SeqMinOrMax Expression\n" + "SeqNoMinMax <- 'NO' SeqMinOrMax\n" + "SeqStartWith <- 'START' 'WITH'? 
Expression\n" + "SeqOwnedBy <- 'OWNED' 'BY' QualifiedName\n" + "SeqMinOrMax <- MinValue / MaxValue\n" + "MinValue <- 'MINVALUE'\n" + "MaxValue <- 'MAXVALUE'\n" + "ExplainStatement <- 'EXPLAIN' ExplainAnalyze? ExplainOptionList? ExplainableStatements\n" + "ExplainAnalyze <- 'ANALYZE'\n" + "ExplainOptionList <- Parens(List(ExplainOption))\n" + "ExplainOption <- ExplainOptionName Expression?\n" + "ExplainOptionName <- 'ANALYZE' / 'ANALYSE' / ColId / FuncNameKeyword / TypeNameKeyword\n" + "ExplainSelectStatement <- SelectStatementInternal\n" + "ExplainableStatements <-\n" + " AlterStatement /\n" + " AnalyzeStatement /\n" + " CallStatement /\n" + " CheckpointStatement /\n" + " CopyStatement /\n" + " CreateStatement /\n" + " DeallocateStatement /\n" + " DeleteStatement /\n" + " DropStatement /\n" + " ExecuteStatement /\n" + " InsertStatement /\n" + " LoadStatement /\n" + " MergeIntoStatement /\n" + " PragmaStatement /\n" + " PrepareStatement /\n" + " ExplainSelectStatement /\n" + " TransactionStatement /\n" + " UpdateStatement /\n" + " VacuumStatement /\n" + " SetStatement /\n" + " ResetStatement\n" + "AnalyzeStatement <- 'ANALYZE' AnalyzeVerbose? AnalyzeTarget?\n" + "AnalyzeTarget <- BaseTableName NameList?\n" + "AnalyzeVerbose <- 'VERBOSE'\n" + "CreateMacroStmt <- MacroOrFunction IfNotExists? QualifiedName List(MacroDefinition)\n" + "MacroOrFunction <- MacroKeyword / FunctionKeyword\n" + "MacroKeyword <- 'MACRO'\n" + "FunctionKeyword <- 'FUNCTION'\n" + "MacroDefinition <- Parens(MacroParameters?) 
'AS' MacroDefinitionBody\n" + "MacroDefinitionBody <- TableMacroDefinition / ScalarMacroDefinition\n" + "MacroParameters <- List(MacroParameter)\n" + "MacroParameter <- NamedParameter / SimpleParameter\n" + "SimpleParameter <- TypeFuncName Type?\n" + "ScalarMacroDefinition <- Expression\n" + "TableMacroDefinition <- 'TABLE' SelectStatementInternal\n" + "CommentStatement <- 'COMMENT' 'ON' CommentOnType DottedIdentifier 'IS' CommentValue\n" + "CommentOnType <- CommentTable / CommentSequence / CommentFunction / CommentMacroTable / CommentMacro /\n" + " CommentView / CommentDatabase / CommentIndex / CommentSchema / CommentType / CommentColumn\n" + "CommentTable <- 'TABLE'\n" + "CommentSequence <- 'SEQUENCE'\n" + "CommentFunction <- 'FUNCTION'\n" + "CommentMacroTable <- 'MACRO' 'TABLE'\n" + "CommentMacro <- 'MACRO'\n" + "CommentView <- 'VIEW'\n" + "CommentDatabase <- 'DATABASE'\n" + "CommentIndex <- 'INDEX'\n" + "CommentSchema <- 'SCHEMA'\n" + "CommentType <- 'TYPE'\n" + "CommentColumn <- 'COLUMN'\n" + "CommentValue <- NullLiteral / StringLiteral\n" + "CreateTriggerStmt <- 'TRIGGER' IfNotExists? TriggerName TriggerTiming TriggerEvent 'ON' BaseTableName " + "ReferencingClause? ForEachClause? 
TriggerBody\n" + "TriggerBody <- InsertStatement / UpdateStatement / DeleteStatement\n" + "TriggerName <- Identifier\n" + "ReferencingClause <- 'REFERENCING' ReferencingItem ReferencingItem?\n" + "ReferencingItem <- ReferencingNewTableAs / ReferencingOldTableAs\n" + "ReferencingNewTableAs <- 'NEW' 'TABLE' 'AS' ColId\n" + "ReferencingOldTableAs <- 'OLD' 'TABLE' 'AS' ColId\n" + "TriggerTiming <- TriggerBefore / TriggerAfter / TriggerInsteadOf\n" + "TriggerBefore <- 'BEFORE'\n" + "TriggerAfter <- 'AFTER'\n" + "TriggerInsteadOf <- 'INSTEAD' 'OF'\n" + "TriggerEvent <- TriggerEventUpdateOf / TriggerEventInsert / TriggerEventDelete / TriggerEventUpdate\n" + "TriggerEventInsert <- 'INSERT'\n" + "TriggerEventDelete <- 'DELETE'\n" + "TriggerEventUpdate <- 'UPDATE'\n" + "TriggerEventUpdateOf <- 'UPDATE' 'OF' TriggerColumnList\n" + "TriggerColumnList <- List(ColId)\n" + "ForEachClause <- ForEachRow / ForEachStatement\n" + "ForEachRow <- 'FOR' 'EACH' 'ROW'\n" + "ForEachStatement <- 'FOR' 'EACH' 'STATEMENT'\n" + "AttachStatement <- 'ATTACH' OrReplace? IfNotExists? Database? DatabasePath AttachAlias? AttachOptions?\n" + "Database <- 'DATABASE'\n" + "DatabasePath <- Expression\n" + "AttachAlias <- 'AS' ColId\n" + "AttachOptions <- GenericCopyOptionList\n" + "DetachStatement <- 'DETACH' Database? IfExists? CatalogName\n" + "UseStatement <- 'USE' UseTarget\n" + "UseTarget <- UseTargetCatalogSchema / SchemaNameAsUseTarget / CatalogNameAsUseTarget\n" + "SchemaNameAsUseTarget <- SchemaName\n" + "CatalogNameAsUseTarget <- CatalogName\n" + "UseTargetCatalogSchema <- CatalogName '.' ReservedSchemaName DotIdentifier*\n" + "DotIdentifier <- '.' 
Identifier\n" + "CallStatement <- 'CALL' QualifiedTableFunction TableFunctionArguments\n" }; diff --git a/src/duckdb/src/include/duckdb/parser/peg/matcher.hpp b/src/duckdb/src/include/duckdb/parser/peg/matcher.hpp index 592749573..400dba3eb 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/matcher.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/matcher.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/string_util.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/reference_map.hpp" #include "duckdb/main/client_context.hpp" @@ -53,6 +54,11 @@ struct AutoCompleteCandidate { CandidateType candidate_type = CandidateType::IDENTIFIER) : AutoCompleteCandidate(string(candidate_p), suggestion_type, score_bonus, candidate_type) { } + // NOLINTNEXTLINE: allow implicit conversion from Identifier + AutoCompleteCandidate(const Identifier &candidate_p, SuggestionState suggestion_type, int32_t score_bonus = 0, + CandidateType candidate_type = CandidateType::IDENTIFIER) + : AutoCompleteCandidate(candidate_p.GetIdentifierName(), suggestion_type, score_bonus, candidate_type) { + } string candidate; //! 
Type being suggested @@ -116,8 +122,8 @@ struct MatcherSuggestion { struct MatchState { MatchState(vector &tokens, vector &suggestions, ParseResultAllocator &allocator, - idx_t &max_token_index, bool preserve_identifier_case_p = true) - : tokens(tokens), suggestions(suggestions), token_index(0), allocator(allocator), + idx_t &max_token_index, bool preserve_identifier_case_p = true, idx_t starting_token_index = 0) + : tokens(tokens), suggestions(suggestions), token_index(starting_token_index), allocator(allocator), max_token_index(max_token_index), preserve_identifier_case(preserve_identifier_case_p) { } MatchState(MatchState &state) @@ -147,7 +153,18 @@ struct MatchState { void AddSuggestion(MatcherSuggestion suggestion); }; -enum class MatcherType { KEYWORD, LIST, OPTIONAL, CHOICE, REPEAT, VARIABLE, STRING_LITERAL, NUMBER_LITERAL, OPERATOR }; +enum class MatcherType { + KEYWORD, + LIST, + OPTIONAL, + CHOICE, + REPEAT, + VARIABLE, + STRING_LITERAL, + NUMBER_LITERAL, + OPERATOR, + END_OF_INPUT +}; class Matcher { public: @@ -212,8 +229,11 @@ class ParseResultAllocator { struct PEGMatcher { MatcherAllocator allocator; - Matcher &Root() { - return *root; + Matcher &ProgramMatcher() { + return *program_matcher; + } + Matcher &TopLevelStatementMatcher() { + return *top_level_statement_matcher; } static shared_ptr Get(ClientContext &context); @@ -221,7 +241,8 @@ struct PEGMatcher { private: friend struct ParserCache; - optional_ptr root; + optional_ptr program_matcher; + optional_ptr top_level_statement_matcher; }; //! Per-database cache holder for the compiled PEG root matcher and transformer factory. 
diff --git a/src/duckdb/src/include/duckdb/parser/peg/token_type.hpp b/src/duckdb/src/include/duckdb/parser/peg/token_type.hpp new file mode 100644 index 000000000..e25e10884 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/peg/token_type.hpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/peg/token_type.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace duckdb { + +enum class TokenType { + INVALID, + KEYWORD, + STRING_LITERAL, + NUMBER_LITERAL, + OPERATOR, + IDENTIFIER, + COMMENT, + TERMINATOR, + CATALOG_NAME, + SCHEMA_NAME, + TABLE_NAME, + TYPE_NAME, + COLUMN_NAME, + SCALAR_FUNCTION, + TABLE_FUNCTION, + PRAGMA_FUNCTION, + SETTING_NAME, + //! Named TOKEN_ERROR (not ERROR) to avoid colliding with the Win32 `ERROR` macro. + TOKEN_ERROR, + //! Sentinel for real end of input — consumed by EndOfInputMatcher. + END_OF_INPUT, + //! Sentinel for cursor position in autocomplete mode — List/Repeat fire suggestion walk. + END_OF_INPUT_AUTOCOMPLETE +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/tokenizer/base_tokenizer.hpp b/src/duckdb/src/include/duckdb/parser/peg/tokenizer/base_tokenizer.hpp index 783d304b2..5b8789faf 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/tokenizer/base_tokenizer.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/tokenizer/base_tokenizer.hpp @@ -31,12 +31,34 @@ class BaseTokenizer { virtual ~BaseTokenizer() = default; public: - bool TokenizeInput(); + void TokenizeInput(); + //! True iff `TokenizeInput()` finished in a state where autocomplete could be offered (the + //! input wasn't truncated mid-comment / mid-dollar-quoted-string). Derived by inspecting the + //! trailing sentinel: the clean-exit paths append `GetTerminator()` (which becomes + //! 
`END_OF_INPUT_AUTOCOMPLETE` for autocomplete tokenizers); the dirty-exit paths always append + //! `END_OF_INPUT`. + bool CanAutocomplete() const; + +private: + //! Core tokenization loop. Returns true on a clean exit, false if the input ended inside an + //! unterminated comment / dollar-quoted string. Does NOT append the trailing sentinel — + //! `TokenizeInput()` is the one that appends `GetTerminator()` (clean) or `END_OF_INPUT` + //! (dirty) based on the return value. + bool TokenizeInputInternal(); + +public: virtual void PushToken(idx_t start, idx_t end, TokenType type, bool unterminated = false); virtual void OnStatementEnd(idx_t pos); virtual void OnLastToken(TokenizeState state, string last_word, idx_t last_pos); + //! Sentinel appended at the end of the token vector on a clean exit. Override to return + //! `END_OF_INPUT_AUTOCOMPLETE` in autocomplete tokenizers. Dirty exits (unterminated comment / + //! dollar-quote) always append `END_OF_INPUT` regardless of this hook. + virtual TokenType GetTerminator() const { + return TokenType::END_OF_INPUT; + } + bool IsSpecialOperator(idx_t pos, idx_t &op_len) const; static bool IsSingleByteOperator(char c); static bool CharacterIsInitialNumber(char c); diff --git a/src/duckdb/src/include/duckdb/parser/peg/tokenizer/parser_tokenizer.hpp b/src/duckdb/src/include/duckdb/parser/peg/tokenizer/parser_tokenizer.hpp index 7fabbda79..1a2f489cf 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/tokenizer/parser_tokenizer.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/tokenizer/parser_tokenizer.hpp @@ -10,6 +10,7 @@ class ParserTokenizer : public BaseTokenizer { ~ParserTokenizer() override = default; void OnStatementEnd(idx_t pos) override; + void OnLastToken(TokenizeState state, string last_word, idx_t last_pos) override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/parse_result.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/parse_result.hpp index 
aaa3a08de..ea3647de4 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/parse_result.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/transformer/parse_result.hpp @@ -10,31 +10,11 @@ #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/peg/special_string_utils.hpp" +#include "duckdb/parser/peg/token_type.hpp" #include "duckdb/common/windows_undefs.hpp" namespace duckdb { -enum class TokenType { - INVALID, - KEYWORD, - STRING_LITERAL, - NUMBER_LITERAL, - OPERATOR, - IDENTIFIER, - COMMENT, - TERMINATOR, - CATALOG_NAME, - SCHEMA_NAME, - TABLE_NAME, - TYPE_NAME, - COLUMN_NAME, - SCALAR_FUNCTION, - TABLE_FUNCTION, - PRAGMA_FUNCTION, - SETTING_NAME, - ERROR -}; - inline string TokenTypeToString(TokenType type) { switch (type) { case TokenType::KEYWORD: @@ -51,7 +31,7 @@ inline string TokenTypeToString(TokenType type) { return "COMMENT"; case TokenType::TERMINATOR: return "TERMINATOR"; - case TokenType::ERROR: + case TokenType::TOKEN_ERROR: return "ERROR"; case TokenType::CATALOG_NAME: return "CATALOG_NAME"; @@ -71,6 +51,10 @@ inline string TokenTypeToString(TokenType type) { return "PRAGMA_FUNCTION"; case TokenType::SETTING_NAME: return "SETTING_NAME"; + case TokenType::END_OF_INPUT: + return "END_OF_INPUT"; + case TokenType::END_OF_INPUT_AUTOCOMPLETE: + return "END_OF_INPUT_AUTOCOMPLETE"; default: return "UNKNOWN"; } @@ -91,6 +75,7 @@ enum class ParseResultType : uint8_t { EXTENSION, NUMBER, STRING, + END_OF_INPUT, INVALID }; @@ -120,6 +105,8 @@ inline const char *ParseResultToString(ParseResultType type) { return "NUMBER"; case ParseResultType::STRING: return "STRING"; + case ParseResultType::END_OF_INPUT: + return "END_OF_INPUT"; case ParseResultType::INVALID: return "INVALID"; } @@ -168,7 +155,7 @@ class ParseResult { struct IdentifierParseResult : ParseResult { static constexpr ParseResultType TYPE = ParseResultType::IDENTIFIER; - string identifier; + Identifier 
identifier; explicit IdentifierParseResult(string identifier_p, optional_idx offset) : ParseResult(TYPE, offset), identifier(std::move(identifier_p)) { @@ -177,7 +164,20 @@ struct IdentifierParseResult : ParseResult { void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, const std::string &indent, bool is_last) const override { ParseResult::ToStringInternal(ss, visited, indent, is_last); - ss << ": " << identifier << "\n"; + ss << ": " << identifier.GetIdentifierName() << "\n"; + } +}; + +struct EndOfInputParseResult : ParseResult { + static constexpr ParseResultType TYPE = ParseResultType::END_OF_INPUT; + + EndOfInputParseResult() : ParseResult(TYPE, optional_idx()) { + } + + void ToStringInternal(std::stringstream &ss, std::unordered_set &visited, + const std::string &indent, bool is_last) const override { + ParseResult::ToStringInternal(ss, visited, indent, is_last); + ss << "\n"; } }; diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp index 3d5f4fce1..57d8aa48f 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp @@ -5,14 +5,18 @@ #include "duckdb/parser/peg/transformer/transform_enum_result.hpp" #include "duckdb/parser/peg/transformer/transform_result.hpp" #include "duckdb/parser/peg/ast/add_column_entry.hpp" +#include "duckdb/parser/peg/ast/column_constraint_entry.hpp" #include "duckdb/parser/peg/ast/analyze_target.hpp" #include "duckdb/parser/peg/ast/column_elements.hpp" -#include "duckdb/parser/peg/ast/create_table_as.hpp" +#include "duckdb/parser/peg/ast/create_table_column_element.hpp" +#include "duckdb/parser/peg/ast/create_table_definition.hpp" #include "duckdb/parser/peg/ast/partition_sorted_options.hpp" #include "duckdb/parser/peg/ast/distinct_clause.hpp" +#include "duckdb/parser/peg/ast/describe_target.hpp" 
#include "duckdb/parser/peg/ast/extension_repository_info.hpp" #include "duckdb/parser/peg/ast/generated_column_definition.hpp" #include "duckdb/parser/peg/ast/generic_copy_option.hpp" +#include "duckdb/parser/peg/ast/generic_copy_option_value.hpp" #include "duckdb/parser/peg/ast/insert_values.hpp" #include "duckdb/parser/peg/ast/create_pivot_entry.hpp" #include "duckdb/parser/peg/ast/join_prefix.hpp" @@ -25,6 +29,7 @@ #include "duckdb/parser/peg/ast/setting_info.hpp" #include "duckdb/parser/peg/ast/table_alias.hpp" #include "duckdb/parser/peg/ast/trigger_event_info.hpp" +#include "duckdb/parser/peg/ast/trigger_table_referencing_info.hpp" #include "duckdb/parser/peg/ast/window_frame.hpp" #include "duckdb/function/macro_function.hpp" #include "duckdb/parser/parser_options.hpp" @@ -46,6 +51,7 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/statement/drop_statement.hpp" #include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" #include "duckdb/parser/tableref/pivotref.hpp" namespace duckdb { @@ -93,6 +99,13 @@ class PEGTransformer { auto *typed_result_ptr = dynamic_cast *>(base_result.get()); if (!typed_result_ptr) { + // allow transparent bridging between string-typed and Identifier-typed rules + auto bridged = TryBridgeTransformResult(*base_result); + if (bridged) { + auto bridged_result = std::move(bridged->value); + SetResultLocation(bridged_result, parse_result.offset); + return bridged_result; + } throw InternalException("Transformer for rule '" + parse_result.name + "' returned an unexpected type."); } @@ -101,6 +114,13 @@ class PEGTransformer { return result; } + //! Bridge between string-typed and Identifier-typed rule results (and their vector forms). + //! The generic form performs no bridging; the specializations below convert transparently. 
+ template + static unique_ptr> TryBridgeTransformResult(TransformResultValue &base_result) { + return nullptr; + } + template T Transform(ListParseResult &parse_result, idx_t child_index) { auto &child_parse_result = parse_result.GetChild(child_index); @@ -141,8 +161,8 @@ class PEGTransformer { void Clear(); void ClearParameters(); static void ParamTypeCheck(PreparedParamType last_type, PreparedParamType new_type); - void SetParam(const string &name, idx_t index, PreparedParamType type); - bool GetParam(const string &name, idx_t &index, PreparedParamType type); + void SetParam(const Identifier &name, idx_t index, PreparedParamType type); + bool GetParam(const Identifier &name, idx_t &index, PreparedParamType type); void SetParamCount(idx_t new_count); idx_t ParamCount() const; unique_ptr CreatePivotStatement(unique_ptr statement); @@ -150,7 +170,7 @@ class PEGTransformer { void PivotEntryCheck(const string &type); void ExtractCTEsRecursive(CommonTableExpressionMap &cte_map); bool IsWindowFrameDefault(WindowBoundary start, WindowBoundary end); - unique_ptr GetWindowClause(const string &window_name); + unique_ptr GetWindowClause(const Identifier &window_name); void SetQueryLocation(ParsedExpression &expr, optional_idx query_location); void SetQueryLocation(TableRef &ref, optional_idx query_location); @@ -181,11 +201,11 @@ class PEGTransformer { const case_insensitive_map_t &grammar_rules; const case_insensitive_map_t &transform_functions; const case_insensitive_map_t> &enum_mappings; - case_insensitive_map_t named_parameter_map; + identifier_map_t named_parameter_map; idx_t prepared_statement_parameter_index = 0; PreparedParamType last_param_type = PreparedParamType::INVALID; - case_insensitive_map_t> window_clauses; + identifier_map_t> window_clauses; vector> pivot_entries; vector> stored_cte_map; @@ -208,6 +228,43 @@ class PEGTransformer { ParserOptions options; }; +//! Transparent bridging between string-typed and Identifier-typed transform results. 
+template <> +inline unique_ptr> +PEGTransformer::TryBridgeTransformResult(TransformResultValue &base_result) { + if (auto *ident = dynamic_cast *>(&base_result)) { + return make_uniq>(ident->value.GetIdentifierName()); + } + return nullptr; +} + +template <> +inline unique_ptr> +PEGTransformer::TryBridgeTransformResult(TransformResultValue &base_result) { + if (auto *str = dynamic_cast *>(&base_result)) { + return make_uniq>(Identifier(str->value)); + } + return nullptr; +} + +template <> +inline unique_ptr>> +PEGTransformer::TryBridgeTransformResult>(TransformResultValue &base_result) { + if (auto *idents = dynamic_cast> *>(&base_result)) { + return make_uniq>>(IdentifiersToStrings(idents->value)); + } + return nullptr; +} + +template <> +inline unique_ptr>> +PEGTransformer::TryBridgeTransformResult>(TransformResultValue &base_result) { + if (auto *strs = dynamic_cast> *>(&base_result)) { + return make_uniq>>(StringsToIdentifiers(strs->value)); + } + return nullptr; +} + typedef unique_ptr (*transform_function_t)(PEGTransformer &transformer, ParseResult &parse_result); @@ -220,12 +277,15 @@ class PEGTransformerFactory { public: explicit PEGTransformerFactory(); - //! Helper functions - vector> Transform(vector &tokens, ParserOptions &options, - Matcher &root_matcher); + //! Match a single TopLevelStatement from `tokens` starting at `token_cursor` and transform it + //! into a SQLStatement. Returns nullptr if the matched TLS was separator-only (no statement). + //! Throws on syntax error. `token_cursor` is in/out: it's the token index where matching + //! starts, and on return holds the token index immediately past the last consumed token. 
+ unique_ptr TransformTopLevelStatement(vector &tokens, ParserOptions &options, + Matcher &root_matcher, idx_t &token_cursor); static ParseResult &ExtractResultFromParens(ParseResult &parse_result); static vector> ExtractParseResultsFromList(ParseResult &parse_result); - static bool ExpressionIsEmptyStar(ParsedExpression &expr); + static bool ExpressionIsEmptyStar(const ParsedExpression &expr); static QualifiedName StringToQualifiedName(vector input); static LogicalType GetIntervalTargetType(DatePartSpecifier date_part); static bool ConstructConstantFromExpression(const ParsedExpression &expr, Value &value); @@ -236,7 +296,8 @@ class PEGTransformerFactory { static vector GroupByExpressionUnfolding(PEGTransformer &transformer, ParseResult &group_by_expr, GroupingExpressionMap &map, GroupByNode &result); static unique_ptr VerifyLimitOffset(LimitPercentResult &limit, LimitPercentResult &offset); - static unique_ptr ToRecursiveCTE(unique_ptr node, const string &name, vector &aliases, + static unique_ptr ToRecursiveCTE(unique_ptr node, const Identifier &name, + vector &aliases, vector> &key_targets); static void WrapRecursiveView(unique_ptr &info, unique_ptr inner_node); static void ConvertToRecursiveView(unique_ptr &info, unique_ptr &node); @@ -258,41 +319,14 @@ class PEGTransformerFactory { unique_ptr expression); // Registration methods - void RegisterAlter(); - void RegisterAnalyze(); - void RegisterAttach(); - void RegisterCall(); - void RegisterCheckpoint(); void RegisterComment(); void RegisterCommon(); - void RegisterCopy(); - void RegisterCreateIndex(); void RegisterCreateMacro(); - void RegisterCreateSchema(); - void RegisterCreateSequence(); - void RegisterCreateSecret(); void RegisterCreateTable(); - void RegisterCreateType(); - void RegisterCreateView(); - void RegisterCreateTrigger(); - void RegisterDeallocate(); - void RegisterDelete(); - void RegisterDescribe(); - void RegisterDrop(); - void RegisterExecute(); - void RegisterExplain(); void 
RegisterExpression(); - void RegisterInsert(); - void RegisterLoad(); - void RegisterMergeInto(); + void RegisterConnect(); void RegisterPivot(); - void RegisterPragma(); - void RegisterPrepare(); void RegisterSelect(); - void RegisterSet(); - void RegisterTransaction(); - void RegisterUpdate(); - void RegisterVacuum(); void RegisterKeywordsAndIdentifiers(); void RegisterEnums(); void RegisterGenerated(); @@ -324,331 +358,27 @@ class PEGTransformerFactory { static unique_ptr TransformStatement(PEGTransformer &, ParseResult &list); - // alter.gram - static unique_ptr TransformAlterStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterOptions(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterTableStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterDatabaseStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterViewStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterSchemaStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterSequenceStmt(PEGTransformer &transformer, ParseResult &parse_result); - static QualifiedName TransformQualifiedSequenceName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterSequenceOptions(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetSequenceOption(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterTableOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformAddColumn(PEGTransformer &transformer, ParseResult &parse_result); - static AddColumnEntry TransformAddColumnEntry(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropColumn(PEGTransformer &transformer, ParseResult &parse_result); - static 
unique_ptr TransformSetPartitionedBy(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformResetPartitionedBy(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformAlterColumn(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterColumnEntry(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropDefault(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformChangeNullability(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformAlterType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUsingExpression(PEGTransformer &transformer, - ParseResult &parse_result); - static string TransformDropOrSet(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAddOrDropDefault(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAddDefault(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRenameColumn(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRenameAlter(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAddConstraint(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformSequenceName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetSortedBy(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformResetSortedBy(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetOptions(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformResetOptions(PEGTransformer &transformer, ParseResult &parse_result); - - // analyze.gram - static unique_ptr TransformAnalyzeStatement(PEGTransformer &transformer, ParseResult 
&parse_result); - static AnalyzeTarget TransformAnalyzeTarget(PEGTransformer &transformer, ParseResult &parse_result); - - // attach.gram - static unique_ptr TransformAttachStatement(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformAttachAlias(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformAttachOptions(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformGenericCopyOptionList(PEGTransformer &transformer, - ParseResult &parse_result); - static GenericCopyOption TransformGenericCopyOption(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDatabasePath(PEGTransformer &transformer, ParseResult &parse_result); - - // call.gram - static unique_ptr TransformCallStatement(PEGTransformer &transformer, ParseResult &parse_result); - - // checkpoint.gram - static unique_ptr TransformCheckpointStatement(PEGTransformer &transformer, - ParseResult &parse_result); - + // connect.gram — both rules have optional sub-clauses, so the generator skips them and we + // hand-write the (PEGTransformer&, ParseResult&) entry points. 
+ static unique_ptr TransformConnectStatement(PEGTransformer &transformer, ParseResult &parse_result); // comment.gram - static unique_ptr TransformCommentStatement(PEGTransformer &transformer, ParseResult &parse_result); - static CatalogType TransformCommentOnType(PEGTransformer &transformer, ParseResult &parse_result); static Value TransformCommentValue(PEGTransformer &transformer, ParseResult &parse_result); // common.gram static unique_ptr TransformNumberLiteral(PEGTransformer &transformer, ParseResult &parse_result); static string TransformStringLiteral(PEGTransformer &transformer, ParseResult &parse_result); - static LogicalType TransformType(PEGTransformer &transformer, ParseResult &parse_result); - static int64_t TransformArrayBounds(PEGTransformer &transformer, ParseResult &parse_result); - static int64_t TransformSquareBracketsArray(PEGTransformer &transformer, ParseResult &parse_result); - - static unique_ptr TransformTimeType(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformTimeZone(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformWithOrWithout(PEGTransformer &transformer, ParseResult &parse_result); - static LogicalTypeId TransformTimeOrTimestamp(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformNumericType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSimpleNumericType(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformDecimalNumericType(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformFloatType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDecimalType(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformTypeModifiers(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformSimpleType(PEGTransformer &transformer, ParseResult &parse_result); - 
static QualifiedName TransformQualifiedTypeName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCharacterType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformMapType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRowType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformGeometryType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformVariantType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUnionType(PEGTransformer &transformer, ParseResult &parse_result); - static child_list_t TransformColIdTypeList(PEGTransformer &transformer, ParseResult &parse_result); - static pair TransformColIdType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformBitType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformIntervalType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformIntervalInterval(PEGTransformer &transformer, - ParseResult &parse_result); - static DatePartSpecifier TransformInterval(PEGTransformer &transformer, ParseResult &parse_result); - static DatePartSpecifier TransformIntervalToInterval(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetofType(PEGTransformer &transformer, ParseResult &parse_result); + static DatePartSpecifier TransformIntervalToIntervalAsType(PEGTransformer &transformer, ParseResult &parse_result); - // copy.gram - static unique_ptr TransformCopyStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCopySelect(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCopyFromDatabase(PEGTransformer &transformer, ParseResult &parse_result); - static CopyDatabaseType TransformCopyDatabaseFlag(PEGTransformer &transformer, 
ParseResult &parse_result); static string ExtractFormat(const string &file_path); - static unique_ptr TransformCopyTable(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformFromOrTo(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCopyFileName(PEGTransformer &transformer, ParseResult &parse_result); - string TransformIdentifierColId(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformCopyOptions(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformSpecializedOptionList(PEGTransformer &transformer, - ParseResult &parse_result); - static GenericCopyOption TransformSpecializedOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformSingleOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformEncodingOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformForceQuoteOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformQuoteAsOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformForceNullOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformPartitionByOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformNullAsOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformDelimiterAsOption(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformEscapeAsOption(PEGTransformer &transformer, ParseResult &parse_result); - static CopyDatabaseType TransformSchemaOrData(PEGTransformer &transformer, ParseResult &parse_result); - - // create_index.gram - static unique_ptr TransformCreateIndexStmt(PEGTransformer &transformer, ParseResult &parse_result); - static string 
TransformIndexType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformIndexElement(PEGTransformer &transformer, ParseResult &parse_result); - static case_insensitive_map_t> TransformWithList(PEGTransformer &transformer, - ParseResult &parse_result); - static case_insensitive_map_t> TransformRelOptionOrOids(PEGTransformer &transformer, - ParseResult &parse_result); - static case_insensitive_map_t> TransformRelOptionList(PEGTransformer &transformer, - ParseResult &parse_result); - static case_insensitive_map_t> TransformOids(PEGTransformer &transformer, - ParseResult &parse_result); - static string TransformRelOptionName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRelOptionArgumentOpt(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformRelOption(PEGTransformer &transformer, - ParseResult &parse_result); - static string TransformIndexName(PEGTransformer &transformer, ParseResult &parse_result); - - // create_macro.gram - static unique_ptr TransformCreateMacroStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformMacroDefinition(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTableMacroDefinition(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformScalarMacroDefinition(PEGTransformer &transformer, - ParseResult &parse_result); - static vector TransformMacroParameters(PEGTransformer &transformer, ParseResult &parse_result); - static MacroParameter TransformMacroParameter(PEGTransformer &transformer, ParseResult &parse_result); - static MacroParameter TransformSimpleParameter(PEGTransformer &transformer, ParseResult &parse_result); - - // create_schema.gram - static unique_ptr TransformCreateSchemaStmt(PEGTransformer &transformer, - ParseResult &parse_result); - - // create_secret.gram - static unique_ptr TransformCreateSecretStmt(PEGTransformer 
&transformer, - ParseResult &parse_result); - static string TransformSecretStorageSpecifier(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformSecretName(PEGTransformer &transformer, ParseResult &parse_result); - - // create_sequence.gram - static unique_ptr TransformCreateSequenceStmt(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSequenceOption(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSeqSetCycle(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSeqSetIncrement(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSeqSetMinMax(PEGTransformer &transformer, - ParseResult &parse_result); - static string TransformSeqMinOrMax(PEGTransformer &transformer, ParseResult &parse_result); - static pair> TransformSeqNoMinMax(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSeqStartWith(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformSeqOwnedBy(PEGTransformer &transformer, - ParseResult &parse_result); // create_table.gram - static unique_ptr TransformCreateStatement(PEGTransformer &transformer, ParseResult &parse_result); - static SecretPersistType TransformTemporary(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateStatementVariation(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformCreateTableStmt(PEGTransformer &transformer, ParseResult &parse_result); - static CreateTableAs TransformCreateTableAs(PEGTransformer &transformer, ParseResult &parse_result); - static ColumnList TransformIdentifierList(PEGTransformer &transformer, ParseResult &parse_result); - static ColumnElements TransformCreateColumnList(PEGTransformer &transformer, ParseResult &parse_result); - static ColumnElements TransformCreateTableColumnList(PEGTransformer &transformer, ParseResult 
&parse_result); - static PartitionSortedOptions TransformPartitionSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static PartitionSortedOptions TransformPartitionOptSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static PartitionSortedOptions TransformSortedOptPartitionOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static vector> TransformPartitionOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static vector> TransformSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result); - static QualifiedName TransformIdentifierOrStringLiteral(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformColIdOrString(PEGTransformer &transformer, ParseResult &parse_result); static string TransformColLabelOrString(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformColId(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformColumnIdList(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformTypeFuncName(PEGTransformer &transformer, ParseResult &parse_result); static string TransformIdentifier(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformDottedIdentifier(PEGTransformer &transformer, ParseResult &parse_result); - static ConstraintColumnDefinition TransformColumnDefinition(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTopLevelConstraint(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTopLevelConstraintList(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformTopPrimaryKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformTopUniqueConstraint(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCheckConstraint(PEGTransformer &transformer, 
ParseResult &parse_result); - static unique_ptr TransformTopForeignKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformDefaultValue(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformForeignKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result); - static GeneratedColumnDefinition TransformGeneratedColumn(PEGTransformer &transformer, ParseResult &parse_result); - static CompressionType TransformColumnCompression(PEGTransformer &transformer, ParseResult &parse_result); - static pair TransformPrimaryKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result); - static pair TransformUniqueConstraint(PEGTransformer &transformer, ParseResult &parse_result); - static pair TransformNotNullConstraint(PEGTransformer &transformer, - ParseResult &parse_result); - static KeyActions TransformKeyActions(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformUpdateAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformDeleteAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformNoKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformRestrictKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformCascadeKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformSetNullKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformSetDefaultKeyAction(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformColumnCollation(PEGTransformer &transformer, - ParseResult &parse_result); - static bool TransformWithData(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformCommitAction(PEGTransformer 
&transformer, ParseResult &parse_result); - static bool TransformPreserveOrDelete(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformGeneratedColumnType(PEGTransformer &transformer, ParseResult &parse_result); - - // create_type.gram - static unique_ptr TransformCreateTypeStmt(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateType(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformEnumSelectType(PEGTransformer &transformer, ParseResult &parse_result); - static LogicalType TransformEnumStringLiteralList(PEGTransformer &transformer, ParseResult &parse_result); - - // create_view.gram - static unique_ptr TransformCreateViewStmt(PEGTransformer &transformer, ParseResult &parse_result); // create_trigger.gram - static unique_ptr TransformCreateTriggerStmt(PEGTransformer &transformer, - ParseResult &parse_result); static TriggerForEach TransformForEachClause(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformTriggerName(PEGTransformer &transformer, ParseResult &parse_result); static TriggerTiming TransformTriggerTiming(PEGTransformer &transformer, ParseResult &parse_result); static TriggerEventInfo TransformTriggerEvent(PEGTransformer &transformer, ParseResult &parse_result); - static TriggerEventInfo TransformTriggerEventInsert(PEGTransformer &transformer, ParseResult &parse_result); - static TriggerEventInfo TransformTriggerEventDelete(PEGTransformer &transformer, ParseResult &parse_result); - static TriggerEventInfo TransformTriggerEventUpdate(PEGTransformer &transformer, ParseResult &parse_result); - static TriggerEventInfo TransformTriggerEventUpdateOf(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformTriggerColumnList(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTriggerBody(PEGTransformer &transformer, ParseResult &parse_result); - - // deallocate.gram - 
static unique_ptr TransformDeallocateStatement(PEGTransformer &transformer, - ParseResult &parse_result); - - // describe.gram - static unique_ptr TransformDescribeStatement(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformShowSelect(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformShowTables(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformShowAllTables(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformShowQualifiedName(PEGTransformer &transformer, ParseResult &parse_result); - static ShowType TransformShowOrDescribeOrSummarize(PEGTransformer &transformer, ParseResult &parse_result); - static ShowType TransformShowOrDescribe(PEGTransformer &transformer, ParseResult &parse_result); - static ShowType TransformSummarize(PEGTransformer &transformer, ParseResult &parse_result); - // - // delete.gram - static unique_ptr TransformDeleteStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTargetOptAlias(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformDeleteUsingClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformTruncateStatement(PEGTransformer &transformer, ParseResult &parse_result); - - // drop.gram - static unique_ptr TransformDropStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropEntries(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropTable(PEGTransformer &transformer, ParseResult &parse_result); - static CatalogType TransformTableOrView(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropTableFunction(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropFunction(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr 
TransformDropSchema(PEGTransformer &transformer, ParseResult &parse_result); - static QualifiedName TransformQualifiedSchemaName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropIndex(PEGTransformer &transformer, ParseResult &parse_result); - static QualifiedName TransformQualifiedIndexName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropSequence(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformCollationName(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropCollation(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropType(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformDropBehavior(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformIfExists(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropSecret(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformDropSecretStorage(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropTrigger(PEGTransformer &transformer, ParseResult &parse_result); - - // execute.gram - static unique_ptr TransformExecuteStatement(PEGTransformer &transformer, ParseResult &parse_result); - - // explain.gram - static unique_ptr TransformExplainStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformExplainableStatements(PEGTransformer &transformer, - ParseResult &parse_result); - static vector TransformExplainOptionList(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformExplainOption(PEGTransformer &transformer, ParseResult &parse_result); // expression.gram static unique_ptr TransformExpressionStatement(PEGTransformer &transformer, @@ -723,15 +453,16 @@ class PEGTransformerFactory { ParseResult &parse_result); 
static unique_ptr TransformConstantLiteral(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformNestedColumnName(PEGTransformer &transformer, - ParseResult &parse_result); + static Value TransformFalseLiteral(PEGTransformer &transformer, ParseResult &parse_result); + static Value TransformTrueLiteral(PEGTransformer &transformer, ParseResult &parse_result); + static Value TransformNullLiteral(PEGTransformer &transformer, ParseResult &parse_result); + static Value TransformUnknownLiteral(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformColumnReference(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCatalogReservedSchemaTableColumnName(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSchemaReservedTableColumnName(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformReservedTableQualification(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformParameter(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformAnonymousParameter(PEGTransformer &transformer, ParseResult &parse_result); @@ -757,7 +488,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformStructExpression(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformStructField(PEGTransformer &transformer, ParseResult &parse_result); + static FunctionArgument TransformStructField(PEGTransformer &transformer, ParseResult &parse_result); static vector> TransformBoundedListExpression(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformFunctionExpression(PEGTransformer &transformer, @@ -783,7 +514,7 @@ class PEGTransformerFactory { static unique_ptr TransformStepSliceBound(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTableReservedColumnName(PEGTransformer 
&transformer, ParseResult &parse_result); - static string TransformTableQualification(PEGTransformer &transformer, ParseResult &parse_result); + static string TransformColIdDot(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformStarExpression(PEGTransformer &transformer, ParseResult &parse_result); static qualified_column_set_t TransformExcludeList(PEGTransformer &transformer, ParseResult &parse_result); static qualified_column_set_t TransformExcludeNameList(PEGTransformer &transformer, ParseResult &parse_result); @@ -852,6 +583,12 @@ class PEGTransformerFactory { ParseResult &parse_result); static vector> TransformSubstringParameters(PEGTransformer &transformer, ParseResult &parse_result); + static vector> TransformSubstringFromFor(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformSubstringFromOptionalFor(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformSubstringFor(PEGTransformer &transformer, + ParseResult &parse_result); static vector> TransformSubstringArguments(PEGTransformer &transformer, ParseResult &parse_result); static vector> TransformSubstringExpressionList(PEGTransformer &transformer, @@ -859,6 +596,16 @@ class PEGTransformerFactory { static unique_ptr TransformTrimExpression(PEGTransformer &transformer, ParseResult &parse_result); static string TransformTrimDirection(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTrimSource(PEGTransformer &transformer, ParseResult &parse_result); + static unique_ptr TransformOverlayExpression(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformOverlayArguments(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformOverlayParameters(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformFromExpression(PEGTransformer &transformer, ParseResult &parse_result); + static unique_ptr 
TransformForExpression(PEGTransformer &transformer, ParseResult &parse_result); + static vector> TransformOverlayExpressionList(PEGTransformer &transformer, + ParseResult &parse_result); static unique_ptr TransformPositionExpression(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCastExpression(PEGTransformer &transformer, ParseResult &parse_result); @@ -897,97 +644,10 @@ class PEGTransformerFactory { ParseResult &parse_result); static bool TransformIgnoreOrRespectNulls(PEGTransformer &transformer, ParseResult &parse_result); - // insert.gram - static unique_ptr TransformInsertStatement(PEGTransformer &transformer, ParseResult &parse_result); - static OnConflictAction TransformOrAction(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInsertTarget(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformInsertAlias(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformOnConflictClause(PEGTransformer &transformer, ParseResult &parse_result); - static OnConflictExpressionTarget TransformOnConflictTarget(PEGTransformer &transformer, ParseResult &parse_result); - static OnConflictExpressionTarget TransformOnConflictExpressionTarget(PEGTransformer &transformer, - ParseResult &parse_result); - static OnConflictExpressionTarget TransformOnConflictIndexTarget(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformOnConflictAction(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformOnConflictUpdate(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformOnConflictNothing(PEGTransformer &transformer, - ParseResult &parse_result); - static InsertValues TransformInsertValues(PEGTransformer &transformer, ParseResult &parse_result); - static InsertColumnOrder TransformByNameOrPosition(PEGTransformer &transformer, ParseResult &parse_result); - static vector 
TransformInsertColumnList(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformColumnList(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformReturningClause(PEGTransformer &transformer, - ParseResult &parse_result); - - // load.gram - static unique_ptr TransformLoadStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInstallStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateExtensionsStatement(PEGTransformer &transformer, - ParseResult &parse_result); - static ExtensionRepositoryInfo TransformFromSource(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformVersionNumber(PEGTransformer &transformer, ParseResult &parse_result); - - // merge_into.gram - static unique_ptr TransformMergeIntoStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformMergeIntoUsingClause(PEGTransformer &transformer, ParseResult &parse_result); - static pair> TransformMergeMatch(PEGTransformer &transformer, - ParseResult &parse_result); - static pair> TransformMatchedClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformMatchedClauseAction(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformUpdateMatchClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformUpdateMatchInfo(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDeleteMatchClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformInsertMatchClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformDoNothingMatchClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformErrorMatchClause(PEGTransformer &transformer, - ParseResult &parse_result); - static 
unique_ptr TransformUpdateMatchSetClause(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformAndExpression(PEGTransformer &transformer, ParseResult &parse_result); - static pair> - TransformNotMatchedClause(PEGTransformer &transformer, ParseResult &parse_result); - static MergeActionCondition TransformBySourceOrTarget(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInsertMatchInfo(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInsertDefaultValues(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformInsertByNameOrPosition(PEGTransformer &transformer, - ParseResult &parse_result); - static unique_ptr TransformInsertValuesList(PEGTransformer &transformer, - ParseResult &parse_result); - // pivot.gram static unique_ptr TransformPivotStatement(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformPivotUsing(PEGTransformer &transformer, - ParseResult &parse_result); - static vector TransformPivotOn(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformPivotColumnList(PEGTransformer &transformer, ParseResult &parse_result); - static PivotColumn TransformPivotColumnEntry(PEGTransformer &transformer, ParseResult &parse_result); - static PivotColumn TransformPivotColumnSubquery(PEGTransformer &transformer, ParseResult &parse_result); - static PivotColumn TransformPivotColumnEntryInternal(PEGTransformer &transformer, ParseResult &parse_result); - // - static vector TransformUnpivotHeader(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformUnpivotHeaderSingle(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformUnpivotHeaderList(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUnpivotStatement(PEGTransformer &transformer, ParseResult &parse_result); - static UnpivotNameValues 
TransformIntoNameValues(PEGTransformer &transformer, ParseResult &parse_result); - static bool TransformIncludeOrExcludeNulls(PEGTransformer &transformer, ParseResult &parse_result); - - // pragma.gram - static unique_ptr TransformPragmaStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformPragmaAssign(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformPragmaFunction(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformPragmaParameters(PEGTransformer &transformer, - ParseResult &parse_result); - - // prepare.gram - static unique_ptr TransformPrepareStatement(PEGTransformer &transformer, ParseResult &parse_result); // select.gram static unique_ptr TransformSelectStatement(PEGTransformer &transformer, ParseResult &parse_result); @@ -1020,30 +680,20 @@ class PEGTransformerFactory { static vector> TransformDistinctOnTargets(PEGTransformer &transformer, ParseResult &parse_result); static DistinctClause TransformDistinctAll(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformFunctionArgument(PEGTransformer &transformer, - ParseResult &parse_result); + static FunctionArgument TransformFunctionArgument(PEGTransformer &transformer, ParseResult &parse_result); static MacroParameter TransformNamedParameter(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformTableFunctionArguments(PEGTransformer &transformer, - ParseResult &parse_result); + static vector TransformTableFunctionArguments(PEGTransformer &transformer, + ParseResult &parse_result); static unique_ptr TransformBaseTableName(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSchemaReservedTable(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCatalogReservedSchemaTable(PEGTransformer &transformer, ParseResult &parse_result); - static string 
TransformSchemaQualification(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformCatalogQualification(PEGTransformer &transformer, ParseResult &parse_result); - static QualifiedName TransformQualifiedName(PEGTransformer &transformer, ParseResult &parse_result); QualifiedName TransformCatalogReservedSchemaIdentifierOrStringLiteral(PEGTransformer &transformer, optional_ptr parse_result); - static QualifiedName TransformCatalogReservedSchemaIdentifier(PEGTransformer &transformer, - ParseResult &parse_result); - - static QualifiedName TransformSchemaReservedIdentifierOrStringLiteral(PEGTransformer &transformer, - ParseResult &parse_result); static QualifiedName TransformTableNameIdentifierOrStringLiteral(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformReservedIdentifierOrStringLiteral(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformWhereClause(PEGTransformer &transformer, ParseResult &parse_result); static vector> TransformTargetList(PEGTransformer &transformer, @@ -1168,53 +818,1866 @@ class PEGTransformerFactory { static SampleMethod TransformSampleFunction(PEGTransformer &transformer, ParseResult &parse_result); static optional_idx TransformRepeatableSample(PEGTransformer &transformer, ParseResult &parse_result); - // set.gram - static unique_ptr TransformResetStatement(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformSetAssignment(PEGTransformer &transformer, + static string TransformIdentifierOrKeyword(PEGTransformer &transformer, ParseResult &parse_result); + + //===--------------------------------------------------------------------===// + // START GENERATED RULES + //===--------------------------------------------------------------------===// + static unique_ptr TransformAlterStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterStatement(PEGTransformer &transformer, + 
unique_ptr alter_options); + static unique_ptr TransformAlterOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterTableStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterTableStmt(PEGTransformer &transformer, const bool &if_exists, + unique_ptr base_table_name, + vector> alter_table_options); + static unique_ptr TransformAlterSchemaStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterSchemaStmt(PEGTransformer &transformer, const bool &if_exists, + const QualifiedName &qualified_name, + unique_ptr rename_alter); + static unique_ptr TransformAlterTableOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAddConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAddConstraint(PEGTransformer &transformer, + unique_ptr top_level_constraint); + static unique_ptr TransformAddColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static SettingInfo TransformSetSetting(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformSetTimeZone(PEGTransformer &transformer, ParseResult &parse_result); - static SettingInfo TransformSetVariable(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformZoneValue(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformZoneIntervalWithInterval(PEGTransformer &transformer, + static unique_ptr TransformAddColumn(PEGTransformer &transformer, const bool &if_not_exists, + AddColumnEntry add_column_entry); + static unique_ptr TransformAddColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static AddColumnEntry TransformAddColumnEntry(PEGTransformer 
&transformer, const vector &dotted_identifier, + const LogicalType &type, GeneratedColumnDefinition generated_column, + vector column_constraint); + static unique_ptr TransformDropColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropColumn(PEGTransformer &transformer, const bool &if_exists, + unique_ptr nested_column_name, + const bool &drop_behavior); + static unique_ptr TransformAlterColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterColumn(PEGTransformer &transformer, + unique_ptr nested_column_name, + unique_ptr alter_column_entry); + static unique_ptr TransformRenameColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformZoneIntervalWithPrecision(PEGTransformer &transformer, + static unique_ptr TransformRenameColumn(PEGTransformer &transformer, + unique_ptr nested_column_name, + const Identifier &identifier); + static unique_ptr TransformNestedColumnNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformNestedColumnName(PEGTransformer &transformer, + const vector &identifier_dot, + const Identifier &column_name); + static unique_ptr TransformIdentifierDotInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformStandardAssignment(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformVariableList(PEGTransformer &transformer, + static Identifier TransformIdentifierDot(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformRenameAlterInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformRenameAlter(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformSetPartitionedByInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformSetPartitionedBy(PEGTransformer &transformer, + vector> expression); + static unique_ptr TransformResetPartitionedByInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformResetPartitionedBy(PEGTransformer &transformer); + static unique_ptr TransformSetSortedByInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSetSortedBy(PEGTransformer &transformer, + vector order_by_expressions); + static unique_ptr TransformResetSortedByInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformResetSortedBy(PEGTransformer &transformer); + static unique_ptr TransformSetOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformSetOptions(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_list); + static unique_ptr TransformResetOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformResetOptions(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_list); + static unique_ptr TransformAlterColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAddOrDropDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAddDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAddDefault(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformDropDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropDefault(PEGTransformer &transformer); + static unique_ptr TransformChangeNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformChangeNullability(PEGTransformer &transformer, + const string &drop_or_set); + static unique_ptr 
TransformDropOrSetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformDropNullability(PEGTransformer &transformer); + static unique_ptr TransformSetNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformSetNullability(PEGTransformer &transformer); + static unique_ptr TransformAlterTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterType(PEGTransformer &transformer, const LogicalType &type, + unique_ptr using_expression); + static unique_ptr TransformUsingExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUsingExpression(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformAlterViewStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterViewStmt(PEGTransformer &transformer, const bool &if_exists, + unique_ptr base_table_name, + unique_ptr rename_alter); + static unique_ptr TransformAlterSequenceStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterSequenceStmt(PEGTransformer &transformer, const bool &if_exists, + const QualifiedName &qualified_sequence_name, + unique_ptr alter_sequence_options); + static unique_ptr TransformQualifiedSequenceNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformQualifiedSequenceName(PEGTransformer &transformer, + const Identifier &catalog_qualification, + const Identifier &schema_qualification, + const Identifier &sequence_name); + static unique_ptr TransformAlterSequenceOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterSequenceOptions(PEGTransformer &transformer, ParseResult 
&choice_result); + static unique_ptr TransformSetSequenceOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformSetSequenceOption(PEGTransformer &transformer, + vector>> sequence_option); + static unique_ptr TransformAlterDatabaseStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAlterDatabaseStmt(PEGTransformer &transformer, const bool &if_exists, + const Identifier &identifier, + const Identifier &identifier_1); + static unique_ptr TransformAnalyzeStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAnalyzeStatement(PEGTransformer &transformer, const bool &analyze_verbose, + AnalyzeTarget analyze_target); + static unique_ptr TransformAnalyzeTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static AnalyzeTarget TransformAnalyzeTarget(PEGTransformer &transformer, unique_ptr base_table_name, + const vector &name_list); + static unique_ptr TransformAnalyzeVerboseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformAnalyzeVerbose(PEGTransformer &transformer); + static unique_ptr TransformAttachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAttachStatement(PEGTransformer &transformer, const bool &or_replace, + const bool &if_not_exists, + unique_ptr database_path, + const Identifier &attach_alias, + const vector &attach_options); + static unique_ptr TransformDatabasePathInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDatabasePath(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformAttachAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformAttachAlias(PEGTransformer &transformer, const Identifier &col_id); + static unique_ptr TransformAttachOptionsInternal(PEGTransformer 
&transformer, + ParseResult &parse_result); + static vector TransformAttachOptions(PEGTransformer &transformer, + const vector &generic_copy_option_list); + static unique_ptr TransformCallStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCallStatement(PEGTransformer &transformer, + const QualifiedName &qualified_table_function, + vector table_function_arguments); + static unique_ptr TransformCheckpointStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCheckpointStatement(PEGTransformer &transformer, + const bool &checkpoint_force, + const Identifier &catalog_name); + static unique_ptr TransformCheckpointForceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCheckpointForce(PEGTransformer &transformer); + static unique_ptr TransformCommentStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCommentStatement(PEGTransformer &transformer, + const CatalogType &comment_on_type, + const vector &dotted_identifier, + const Value &comment_value); + static unique_ptr TransformCommentOnTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCommentTableInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentTable(PEGTransformer &transformer); + static unique_ptr TransformCommentSequenceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentSequence(PEGTransformer &transformer); + static unique_ptr TransformCommentFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentFunction(PEGTransformer &transformer); + static unique_ptr TransformCommentMacroTableInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType 
TransformCommentMacroTable(PEGTransformer &transformer); + static unique_ptr TransformCommentMacroInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentMacro(PEGTransformer &transformer); + static unique_ptr TransformCommentViewInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentView(PEGTransformer &transformer); + static unique_ptr TransformCommentDatabaseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentDatabase(PEGTransformer &transformer); + static unique_ptr TransformCommentIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentIndex(PEGTransformer &transformer); + static unique_ptr TransformCommentSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentSchema(PEGTransformer &transformer); + static unique_ptr TransformCommentTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentType(PEGTransformer &transformer); + static unique_ptr TransformCommentColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType TransformCommentColumn(PEGTransformer &transformer); + static unique_ptr TransformExpressionStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformExpressionStatement(PEGTransformer &transformer, + vector> expression_alias); + static unique_ptr TransformExpressionAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformConstraintNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformConstraintName(PEGTransformer &transformer, const Identifier &col_id_or_string); + static unique_ptr TransformCollationNameInternal(PEGTransformer &transformer, + ParseResult 
&parse_result); + static Identifier TransformCollationName(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static LogicalType TransformType(PEGTransformer &transformer, unique_ptr type_variations, + const vector &array_bounds); + static unique_ptr TransformTypeVariationsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSimpleTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCharacterSimpleTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCharacterSimpleType(PEGTransformer &transformer, const string &character_type, + vector> type_modifiers); + static unique_ptr TransformQualifiedSimpleTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformQualifiedSimpleType(PEGTransformer &transformer, const QualifiedName &qualified_type_name, + vector> type_modifiers); + static unique_ptr TransformCharacterTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformCharacterType(PEGTransformer &transformer); + static unique_ptr TransformIntervalTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIntervalIntervalInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIntervalWithSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIntervalWithRangeSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformIntervalWithRangeSpecifier(PEGTransformer &transformer, + const DatePartSpecifier &interval_to_interval_as_type); + static unique_ptr TransformIntervalWithSimpleSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static 
unique_ptr TransformIntervalWithSimpleSpecifier(PEGTransformer &transformer, + const DatePartSpecifier &interval); + static unique_ptr TransformIntervalWithoutSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIntervalWithoutSpecifier(PEGTransformer &transformer); + static unique_ptr TransformYearKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformYearKeyword(PEGTransformer &transformer); + static unique_ptr TransformMonthKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMonthKeyword(PEGTransformer &transformer); + static unique_ptr TransformDayKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformDayKeyword(PEGTransformer &transformer); + static unique_ptr TransformHourKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformHourKeyword(PEGTransformer &transformer); + static unique_ptr TransformMinuteKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMinuteKeyword(PEGTransformer &transformer); + static unique_ptr TransformSecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformSecondKeyword(PEGTransformer &transformer); + static unique_ptr TransformMillisecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMillisecondKeyword(PEGTransformer &transformer); + static unique_ptr TransformMicrosecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMicrosecondKeyword(PEGTransformer &transformer); + static unique_ptr TransformWeekKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier 
TransformWeekKeyword(PEGTransformer &transformer); + static unique_ptr TransformQuarterKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformQuarterKeyword(PEGTransformer &transformer); + static unique_ptr TransformDecadeKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformDecadeKeyword(PEGTransformer &transformer); + static unique_ptr TransformCenturyKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformCenturyKeyword(PEGTransformer &transformer); + static unique_ptr TransformMillenniumKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMillenniumKeyword(PEGTransformer &transformer); + static unique_ptr TransformIntervalInternal(PEGTransformer &transformer, ParseResult &parse_result); - - static string TransformIdentifierOrKeyword(PEGTransformer &transformer, ParseResult &parse_result); - - // transaction.gram - static TransactionModifierType TransformReadOnlyOrReadWrite(PEGTransformer &transformer, ParseResult &parse_result); - - // update.gram - static unique_ptr TransformUpdateStatement(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateTarget(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformBaseTableSet(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformBaseTableAliasSet(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformUpdateAlias(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateSetClause(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateSetTuple(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateSetElementList(PEGTransformer &transformer, - ParseResult 
&parse_result); - static pair> TransformUpdateSetElement(PEGTransformer &transformer, + static unique_ptr TransformIntervalToIntervalInternal(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformUpdateSetColumnTarget(PEGTransformer &transformer, ParseResult &parse_result); - - // vacuum.gram - static unique_ptr TransformVacuumStatement(PEGTransformer &transformer, ParseResult &parse_result); - static VacuumOptions TransformVacuumOptions(PEGTransformer &transformer, ParseResult &parse_result); - static VacuumOptions TransformVacuumLegacyOptions(PEGTransformer &transformer, ParseResult &parse_result); - static VacuumOptions TransformVacuumParensOptions(PEGTransformer &transformer, ParseResult &parse_result); - static string TransformVacuumOption(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformNameList(PEGTransformer &transformer, ParseResult &parse_result); - -#define DUCKDB_INSIDE_PEG_TRANSFORMER_HPP -#include "duckdb/parser/peg/transformer/peg_transformer_generated.hpp" -#undef DUCKDB_INSIDE_PEG_TRANSFORMER_HPP + static unique_ptr TransformYearToMonthInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformYearToMonth(PEGTransformer &transformer, const DatePartSpecifier &year_keyword, + const DatePartSpecifier &month_keyword); + static unique_ptr TransformDayToHourInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformDayToHour(PEGTransformer &transformer, const DatePartSpecifier &day_keyword, + const DatePartSpecifier &hour_keyword); + static unique_ptr TransformDayToMinuteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformDayToMinute(PEGTransformer &transformer, const DatePartSpecifier &day_keyword, + const DatePartSpecifier &minute_keyword); + static unique_ptr TransformDayToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result); + 
static DatePartSpecifier TransformDayToSecond(PEGTransformer &transformer, const DatePartSpecifier &day_keyword, + const DatePartSpecifier &second_keyword); + static unique_ptr TransformHourToMinuteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformHourToMinute(PEGTransformer &transformer, const DatePartSpecifier &hour_keyword, + const DatePartSpecifier &minute_keyword); + static unique_ptr TransformHourToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformHourToSecond(PEGTransformer &transformer, const DatePartSpecifier &hour_keyword, + const DatePartSpecifier &second_keyword); + static unique_ptr TransformMinuteToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DatePartSpecifier TransformMinuteToSecond(PEGTransformer &transformer, + const DatePartSpecifier &minute_keyword, + const DatePartSpecifier &second_keyword); + static unique_ptr TransformBitTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBitType(PEGTransformer &transformer, + vector> expression); + static unique_ptr TransformGeometryTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformGeometryType(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformVariantTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformVariantType(PEGTransformer &transformer); + static unique_ptr TransformNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSimpleNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSimpleNumericType(PEGTransformer &transformer, const string &child); + static unique_ptr TransformDecimalNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static 
unique_ptr TransformIntTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformIntType(PEGTransformer &transformer); + static unique_ptr TransformIntegerTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformIntegerType(PEGTransformer &transformer); + static unique_ptr TransformSmallintTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformSmallintType(PEGTransformer &transformer); + static unique_ptr TransformBigintTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformBigintType(PEGTransformer &transformer); + static unique_ptr TransformRealTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformRealType(PEGTransformer &transformer); + static unique_ptr TransformBooleanTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformBooleanType(PEGTransformer &transformer); + static unique_ptr TransformDoubleTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformDoubleType(PEGTransformer &transformer); + static unique_ptr TransformFloatTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformFloatType(PEGTransformer &transformer, + unique_ptr number_literal); + static unique_ptr TransformDecimalTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDecimalType(PEGTransformer &transformer, + vector> type_modifiers); + static unique_ptr TransformDecTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDecType(PEGTransformer &transformer, + vector> type_modifiers); + static unique_ptr TransformNumericModTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformNumericModType(PEGTransformer &transformer, + vector> 
type_modifiers); + static unique_ptr TransformQualifiedTypeNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTypeNameAsQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformTypeNameAsQualifiedName(PEGTransformer &transformer, const Identifier &type_name); + static unique_ptr TransformCatalogReservedSchemaTypeNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformCatalogReservedSchemaTypeName(PEGTransformer &transformer, + const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, + const Identifier &reserved_type_name); + static unique_ptr TransformSchemaReservedTypeNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformSchemaReservedTypeName(PEGTransformer &transformer, + const Identifier &schema_qualification, + const Identifier &reserved_type_name); + static unique_ptr TransformTypeModifiersInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformTypeModifiers(PEGTransformer &transformer, + vector> expression); + static unique_ptr TransformRowTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformRowType(PEGTransformer &transformer, + const child_list_t &col_id_type_list); + static unique_ptr TransformSetofTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSetofType(PEGTransformer &transformer, const LogicalType &type); + static unique_ptr TransformUnionTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUnionType(PEGTransformer &transformer, + const child_list_t &col_id_type_list); + static unique_ptr TransformColIdTypeListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static child_list_t 
TransformColIdTypeList(PEGTransformer &transformer, + const vector> &col_id_type); + static unique_ptr TransformMapTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMapType(PEGTransformer &transformer, const vector &type); + static unique_ptr TransformColIdTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair TransformColIdType(PEGTransformer &transformer, const Identifier &col_id, + const LogicalType &type); + static unique_ptr TransformArrayBoundsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformArrayKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static int64_t TransformArrayKeyword(PEGTransformer &transformer); + static unique_ptr TransformSquareBracketsArrayInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static int64_t TransformSquareBracketsArray(PEGTransformer &transformer, unique_ptr expression); + static unique_ptr TransformTimeTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTimeType(PEGTransformer &transformer, + const LogicalTypeId &time_or_timestamp, + vector> type_modifiers, + const bool &time_zone); + static unique_ptr TransformTimeOrTimestampInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTimeTypeIdInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static LogicalTypeId TransformTimeTypeId(PEGTransformer &transformer); + static unique_ptr TransformTimestampTypeIdInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static LogicalTypeId TransformTimestampTypeId(PEGTransformer &transformer); + static unique_ptr TransformTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformTimeZone(PEGTransformer &transformer, const bool &with_or_without); + static unique_ptr 
TransformWithOrWithoutInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformWithRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithRule(PEGTransformer &transformer); + static unique_ptr TransformWithoutRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithoutRule(PEGTransformer &transformer); + static unique_ptr TransformDisconnectStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDisconnectStatement(PEGTransformer &transformer); + static unique_ptr TransformCopyStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyStatement(PEGTransformer &transformer, + unique_ptr copy_variations); + static unique_ptr TransformCopyVariationsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyTableInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyTable(PEGTransformer &transformer, + unique_ptr base_table_name, + const vector &insert_column_list, const bool &from_or_to, + unique_ptr copy_file_name, + const vector ©_options); + static unique_ptr TransformFromOrToInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFromInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCopyFrom(PEGTransformer &transformer); + static unique_ptr TransformCopyToInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCopyTo(PEGTransformer &transformer); + static unique_ptr TransformCopySelectInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopySelect(PEGTransformer &transformer, + unique_ptr select_statement_internal, + unique_ptr copy_file_name, + const vector ©_options); + static unique_ptr 
TransformCopyFileNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFileNameExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFileNameStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFileNameStringLiteral(PEGTransformer &transformer, + const string &string_literal); + static unique_ptr TransformCopyFileNameIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFileNameIdentifier(PEGTransformer &transformer, + const Identifier &identifier); + static unique_ptr TransformCopyFileNameIdentifierColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFileNameIdentifierColId(PEGTransformer &transformer, + const Identifier &identifier_col_id); + static unique_ptr TransformIdentifierColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformIdentifierColId(PEGTransformer &transformer, const Identifier &identifier, + const Identifier &col_id); + static unique_ptr TransformCopyOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformCopyOptions(PEGTransformer &transformer, + const vector ©_option_list); + static unique_ptr TransformCopyOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSpecializedOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector + TransformSpecializedOptionList(PEGTransformer &transformer, const vector &specialized_option); + static unique_ptr TransformSpecializedOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSingleOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformBinaryOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformBinaryOption(PEGTransformer &transformer); + static unique_ptr TransformFreezeOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformFreezeOption(PEGTransformer &transformer); + static unique_ptr TransformOidsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformOidsOption(PEGTransformer &transformer); + static unique_ptr TransformCsvOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformCsvOption(PEGTransformer &transformer); + static unique_ptr TransformHeaderOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformHeaderOption(PEGTransformer &transformer); + static unique_ptr TransformNullAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformNullAsOption(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformDelimiterAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformDelimiterAsOption(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformQuoteAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformQuoteAsOption(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformEscapeAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformEscapeAsOption(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformEncodingOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformEncodingOption(PEGTransformer &transformer, const 
string &string_literal); + static unique_ptr TransformForceQuoteOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformForceQuoteOption(PEGTransformer &transformer, const bool &force_quote, + const vector &star_symbol_column_list); + static unique_ptr TransformStarSymbolColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformForceQuoteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformForceQuote(PEGTransformer &transformer); + static unique_ptr TransformPartitionByOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformPartitionByOption(PEGTransformer &transformer, + const vector &star_symbol_column_list); + static unique_ptr TransformForceNullOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformForceNullOption(PEGTransformer &transformer, const bool &force_not_null, + const vector &column_list); + static unique_ptr TransformForceNotNullInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformForceNotNull(PEGTransformer &transformer); + static unique_ptr TransformGenericCopyOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector + TransformGenericCopyOptionList(PEGTransformer &transformer, const vector &generic_copy_option); + static unique_ptr TransformGenericCopyOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformGenericCopyOption(PEGTransformer &transformer, const Identifier ©_option_name, + GenericCopyOptionValue generic_copy_option_value); + static unique_ptr TransformGenericCopyOptionValueInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformGenericCopyOptionOrderListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + 
static GenericCopyOptionValue + TransformGenericCopyOptionOrderList(PEGTransformer &transformer, + vector generic_copy_option_parenthesized_expression_list); + static unique_ptr TransformGenericCopyOptionExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOptionValue TransformGenericCopyOptionExpression(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr + TransformGenericCopyOptionParenthesizedExpressionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector + TransformGenericCopyOptionParenthesizedExpressionList(PEGTransformer &transformer, + vector order_by_expression_list); + static unique_ptr TransformCopyFromDatabaseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFromDatabaseWithFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFromDatabaseWithFlag(PEGTransformer &transformer, + const Identifier &col_id, + const Identifier &col_id_1, + const CopyDatabaseType ©_database_flag); + static unique_ptr TransformCopyFromDatabaseWithoutFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopyFromDatabaseWithoutFlag(PEGTransformer &transformer, + const Identifier &col_id, + const Identifier &col_id_1); + static unique_ptr TransformCopyDatabaseFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CopyDatabaseType TransformCopyDatabaseFlag(PEGTransformer &transformer, + const CopyDatabaseType &schema_or_data); + static unique_ptr TransformSchemaOrDataInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCopySchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CopyDatabaseType TransformCopySchema(PEGTransformer &transformer); + static unique_ptr TransformCopyDataInternal(PEGTransformer &transformer, + ParseResult 
&parse_result); + static CopyDatabaseType TransformCopyData(PEGTransformer &transformer); + static unique_ptr TransformCreateIndexStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateIndexStmt( + PEGTransformer &transformer, const bool &unique_index, const bool &if_not_exists, const Identifier &index_name, + unique_ptr base_table_name, const vector &insert_column_list, + const Identifier &index_type, vector> index_element, + case_insensitive_map_t> with_list, unique_ptr where_clause); + static unique_ptr TransformWithListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static case_insensitive_map_t> + TransformWithList(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_or_oids); + static unique_ptr TransformRelOptionOrOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformRelOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static case_insensitive_map_t> + TransformRelOptionList(PEGTransformer &transformer, + vector>> rel_option); + static unique_ptr TransformOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static case_insensitive_map_t> TransformOids(PEGTransformer &transformer, + const bool &with_or_without_oids); + static unique_ptr TransformWithOrWithoutOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformWithOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithOids(PEGTransformer &transformer); + static unique_ptr TransformWithoutOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithoutOids(PEGTransformer &transformer); + static unique_ptr TransformIndexElementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIndexElement(PEGTransformer &transformer, + unique_ptr expression, + const 
OrderType &desc_or_asc, + const OrderByNullType &nulls_first_or_last); + static unique_ptr TransformUniqueIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformUniqueIndex(PEGTransformer &transformer); + static unique_ptr TransformIndexTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformIndexType(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformRelOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> + TransformRelOption(PEGTransformer &transformer, const Identifier &rel_option_name, + unique_ptr rel_option_argument_opt); + static unique_ptr TransformRelOptionNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformRelOptionName(PEGTransformer &transformer, const string &child); + static unique_ptr TransformDottedIdentifierStringInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformDottedIdentifierString(PEGTransformer &transformer, const vector &dotted_identifier); + static unique_ptr TransformRelOptionArgumentOptInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformRelOptionArgumentOpt(PEGTransformer &transformer, + unique_ptr def_arg); + static unique_ptr TransformDefArgInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDefArgNullInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDefArgNull(PEGTransformer &transformer, const Value &null_literal); + static unique_ptr TransformDefArgKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDefArgKeyword(PEGTransformer &transformer, + const string &reserved_keyword); + static unique_ptr TransformDefArgStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static 
unique_ptr TransformDefArgStringLiteral(PEGTransformer &transformer, + const string &string_literal); + static unique_ptr TransformNoneLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformNoneLiteral(PEGTransformer &transformer); + static unique_ptr TransformCreateMacroStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCreateMacroStmt(PEGTransformer &transformer, const bool ¯o_or_function, const bool &if_not_exists, + const QualifiedName &qualified_name, vector> macro_definition); + static unique_ptr TransformMacroOrFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMacroKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformMacroKeyword(PEGTransformer &transformer); + static unique_ptr TransformFunctionKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformFunctionKeyword(PEGTransformer &transformer); + static unique_ptr TransformMacroDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMacroDefinition(PEGTransformer &transformer, + vector macro_parameters, + unique_ptr macro_definition_body); + static unique_ptr TransformMacroDefinitionBodyInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMacroParametersInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformMacroParameters(PEGTransformer &transformer, + vector macro_parameter); + static unique_ptr TransformMacroParameterInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSimpleParameterInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static MacroParameter TransformSimpleParameter(PEGTransformer &transformer, const Identifier &type_func_name, + const LogicalType &type); + 
static unique_ptr TransformScalarMacroDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformScalarMacroDefinition(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformTableMacroDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformTableMacroDefinition(PEGTransformer &transformer, unique_ptr select_statement_internal); + static unique_ptr TransformCreateSchemaStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateSchemaStmt(PEGTransformer &transformer, const bool &if_not_exists, + const QualifiedName &qualified_name); + static unique_ptr TransformCreateSecretStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCreateSecretStmt(PEGTransformer &transformer, const bool &if_not_exists, const Identifier &secret_name, + const Identifier &secret_storage_specifier, + const vector &generic_copy_option_list); + static unique_ptr TransformSecretStorageSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformSecretStorageSpecifier(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformSecretNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformSecretName(PEGTransformer &transformer, const Identifier &col_id); + static unique_ptr TransformCreateSequenceStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCreateSequenceStmt(PEGTransformer &transformer, const bool &if_not_exists, + const QualifiedName &qualified_name, + vector>> sequence_option); + static unique_ptr TransformSequenceOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSeqSetCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + 
static unique_ptr TransformSeqCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqCycle(PEGTransformer &transformer); + static unique_ptr TransformSeqNoCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqNoCycle(PEGTransformer &transformer); + static unique_ptr TransformSeqSetIncrementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqSetIncrement(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformSeqSetMinMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqSetMinMax(PEGTransformer &transformer, + const string &seq_min_or_max, + unique_ptr expression); + static unique_ptr TransformSeqNoMinMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqNoMinMax(PEGTransformer &transformer, + const string &seq_min_or_max); + static unique_ptr TransformSeqStartWithInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqStartWith(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformSeqOwnedByInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> TransformSeqOwnedBy(PEGTransformer &transformer, + const QualifiedName &qualified_name); + static unique_ptr TransformSeqMinOrMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMinValueInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformMinValue(PEGTransformer &transformer); + static unique_ptr TransformMaxValueInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformMaxValue(PEGTransformer &transformer); + static unique_ptr TransformCreateStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformCreateStatement(PEGTransformer &transformer, const bool &or_replace, + const SecretPersistType &temporary, + unique_ptr create_statement_variation); + static unique_ptr TransformCreateStatementVariationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOrReplaceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformOrReplace(PEGTransformer &transformer); + static unique_ptr TransformTemporaryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPersistentInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SecretPersistType TransformPersistent(PEGTransformer &transformer); + static unique_ptr TransformTempPersistentInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SecretPersistType TransformTempPersistent(PEGTransformer &transformer); + static unique_ptr TransformTemporaryPersistentInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SecretPersistType TransformTemporaryPersistent(PEGTransformer &transformer); + static unique_ptr TransformCreateTableStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateTableStmt(PEGTransformer &transformer, const bool &if_not_exists, + const QualifiedName &qualified_name, + CreateTableDefinition create_table_definition, + const bool &commit_action); + static unique_ptr TransformCreateTableDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateTableAsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CreateTableDefinition TransformCreateTableAs(PEGTransformer &transformer, ColumnList identifier_list, + PartitionSortedOptions partition_sorted_options, + case_insensitive_map_t> with_list, + unique_ptr statement, const bool &with_data); + static unique_ptr 
TransformPartitionSortedOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPartitionOptSortedOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static PartitionSortedOptions + TransformPartitionOptSortedOptions(PEGTransformer &transformer, + vector> partition_options, + vector> sorted_options); + static unique_ptr TransformSortedOptPartitionOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static PartitionSortedOptions + TransformSortedOptPartitionOptions(PEGTransformer &transformer, vector> sorted_options, + vector> partition_options); + static unique_ptr TransformPartitionOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> + TransformPartitionOptions(PEGTransformer &transformer, vector> expression); + static unique_ptr TransformSortedOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformSortedOptions(PEGTransformer &transformer, + vector> expression); + static unique_ptr TransformWithDataInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformWithDataOnlyInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithDataOnly(PEGTransformer &transformer); + static unique_ptr TransformWithNoDataInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformWithNoData(PEGTransformer &transformer); + static unique_ptr TransformIdentifierListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnList TransformIdentifierList(PEGTransformer &transformer, const vector &identifier); + static unique_ptr TransformCreateColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CreateTableDefinition + TransformCreateColumnList(PEGTransformer &transformer, ColumnElements create_table_column_list, + PartitionSortedOptions 
partition_sorted_options, + case_insensitive_map_t> with_list); + static unique_ptr TransformIfNotExistsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformIfNotExists(PEGTransformer &transformer); + static unique_ptr TransformQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformSchemaReservedIdentifierOrStringLiteralInternal(PEGTransformer &transformer, ParseResult &parse_result); + static QualifiedName + TransformSchemaReservedIdentifierOrStringLiteral(PEGTransformer &transformer, + const Identifier &schema_qualification, + const Identifier &reserved_identifier_or_string_literal); + static unique_ptr + TransformCatalogReservedSchemaIdentifierInternal(PEGTransformer &transformer, ParseResult &parse_result); + static QualifiedName + TransformCatalogReservedSchemaIdentifier(PEGTransformer &transformer, const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, + const Identifier &reserved_identifier_or_string_literal); + static unique_ptr TransformIdentifierOrStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformIdentifierOrStringLiteral(PEGTransformer &transformer, const string &child); + static unique_ptr + TransformReservedIdentifierOrStringLiteralInternal(PEGTransformer &transformer, ParseResult &parse_result); + static unique_ptr TransformCatalogQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformCatalogQualification(PEGTransformer &transformer, const Identifier &catalog_name); + static unique_ptr TransformSchemaQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformSchemaQualification(PEGTransformer &transformer, const Identifier &schema_name); + static unique_ptr TransformReservedSchemaQualificationInternal(PEGTransformer &transformer, + ParseResult 
&parse_result); + static unique_ptr TransformTableQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformTableQualification(PEGTransformer &transformer, const Identifier &table_name); + static unique_ptr TransformReservedTableQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformReservedTableQualification(PEGTransformer &transformer, + const Identifier &reserved_table_name); + static unique_ptr TransformCreateTableColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnElements TransformCreateTableColumnList(PEGTransformer &transformer, + vector create_table_column_element); + static unique_ptr TransformCreateTableColumnElementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateTableColumnDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CreateTableColumnElement TransformCreateTableColumnDefinition(PEGTransformer &transformer, + ConstraintColumnDefinition column_definition); + static unique_ptr TransformCreateTableConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CreateTableColumnElement TransformCreateTableConstraint(PEGTransformer &transformer, + unique_ptr top_level_constraint); + static unique_ptr TransformColumnDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ConstraintColumnDefinition TransformColumnDefinition(PEGTransformer &transformer, + const vector &dotted_identifier, + const LogicalType &type, + GeneratedColumnDefinition generated_column, + vector column_constraint); + static unique_ptr TransformColumnConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformNotNullConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry 
TransformNotNullConstraint(PEGTransformer &transformer, const bool &child); + static unique_ptr TransformNullConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformNullConstraint(PEGTransformer &transformer); + static unique_ptr TransformNotNullColumnConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformNotNullColumnConstraint(PEGTransformer &transformer); + static unique_ptr TransformUniqueConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformUniqueConstraint(PEGTransformer &transformer); + static unique_ptr TransformPrimaryKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformPrimaryKeyConstraint(PEGTransformer &transformer); + static unique_ptr TransformDefaultValueInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformDefaultValue(PEGTransformer &transformer, + unique_ptr column_default_expr); + static unique_ptr TransformCheckConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformCheckConstraint(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformForeignKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformForeignKeyConstraint(PEGTransformer &transformer, + unique_ptr base_table_name, + const vector &column_list, + const KeyActions &key_actions); + static unique_ptr TransformColumnCollationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry TransformColumnCollation(PEGTransformer &transformer, + const vector &dotted_identifier); + static unique_ptr TransformColumnCompressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ColumnConstraintEntry 
TransformColumnCompression(PEGTransformer &transformer, + const Identifier &col_id_or_string); + static unique_ptr TransformKeyActionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static KeyActions TransformKeyActions(PEGTransformer &transformer, const string &update_action, + const string &delete_action); + static unique_ptr TransformUpdateActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformUpdateAction(PEGTransformer &transformer, const string &key_action); + static unique_ptr TransformDeleteActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformDeleteAction(PEGTransformer &transformer, const string &key_action); + static unique_ptr TransformKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformNoKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformNoKeyAction(PEGTransformer &transformer); + static unique_ptr TransformRestrictKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformRestrictKeyAction(PEGTransformer &transformer); + static unique_ptr TransformCascadeKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformCascadeKeyAction(PEGTransformer &transformer); + static unique_ptr TransformSetNullKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformSetNullKeyAction(PEGTransformer &transformer); + static unique_ptr TransformSetDefaultKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformSetDefaultKeyAction(PEGTransformer &transformer); + static unique_ptr TransformTopLevelConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTopLevelConstraint(PEGTransformer &transformer, + unique_ptr top_level_constraint_list); 
+ static unique_ptr TransformTopLevelConstraintListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTopLevelConstraintList(PEGTransformer &transformer, + ParseResult &choice_result); + static unique_ptr TransformTopPrimaryKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTopPrimaryKeyConstraint(PEGTransformer &transformer, + const vector &column_id_list); + static unique_ptr TransformTopUniqueConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTopUniqueConstraint(PEGTransformer &transformer, + const vector &column_id_list); + static unique_ptr TransformTopForeignKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTopForeignKeyConstraint(PEGTransformer &transformer, + const vector &column_id_list, + ColumnConstraintEntry foreign_key_constraint); + static unique_ptr TransformColumnIdListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformColumnIdList(PEGTransformer &transformer, const vector &col_id); + static unique_ptr TransformDottedIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformDottedIdentifier(PEGTransformer &transformer, const Identifier &identifier, + const vector &dot_col_label); + static unique_ptr TransformDotColLabelInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformColIdOrStringInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformColIdOrString(PEGTransformer &transformer, ParseResult &choice_result); + static unique_ptr TransformTypeFuncNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformGeneratedColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GeneratedColumnDefinition TransformGeneratedColumn(PEGTransformer &transformer, + unique_ptr expression, + const bool &generated_column_type); + static unique_ptr TransformGeneratedColumnTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCommitActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCommitAction(PEGTransformer &transformer, const bool &preserve_or_delete); + static unique_ptr TransformPreserveOrDeleteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPreserveRowsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformPreserveRows(PEGTransformer &transformer); + static unique_ptr TransformDeleteRowsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformDeleteRows(PEGTransformer &transformer); + static unique_ptr TransformVirtualGeneratedColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformVirtualGeneratedColumn(PEGTransformer &transformer); + static unique_ptr TransformStoredGeneratedColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformStoredGeneratedColumn(PEGTransformer &transformer); + static unique_ptr TransformCreateTriggerStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCreateTriggerStmt(PEGTransformer &transformer, const bool &if_not_exists, const Identifier &trigger_name, + const TriggerTiming &trigger_timing, const TriggerEventInfo &trigger_event, + unique_ptr base_table_name, + const TriggerTableReferencingInfo &referencing_clause, + const TriggerForEach &for_each_clause, unique_ptr trigger_body); + static unique_ptr TransformTriggerBodyInternal(PEGTransformer &transformer, + ParseResult 
&parse_result); + static unique_ptr TransformTriggerNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformTriggerName(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformReferencingClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTableReferencingInfo + TransformReferencingClause(PEGTransformer &transformer, const TriggerTableReferencingInfo &referencing_item, + const TriggerTableReferencingInfo &referencing_item_1); + static unique_ptr TransformReferencingItemInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformReferencingNewTableAsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTableReferencingInfo TransformReferencingNewTableAs(PEGTransformer &transformer, + const Identifier &col_id); + static unique_ptr TransformReferencingOldTableAsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTableReferencingInfo TransformReferencingOldTableAs(PEGTransformer &transformer, + const Identifier &col_id); + static unique_ptr TransformTriggerTimingInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTriggerBeforeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTiming TransformTriggerBefore(PEGTransformer &transformer); + static unique_ptr TransformTriggerAfterInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTiming TransformTriggerAfter(PEGTransformer &transformer); + static unique_ptr TransformTriggerInsteadOfInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerTiming TransformTriggerInsteadOf(PEGTransformer &transformer); + static unique_ptr TransformTriggerEventInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTriggerEventInsertInternal(PEGTransformer 
&transformer, + ParseResult &parse_result); + static TriggerEventInfo TransformTriggerEventInsert(PEGTransformer &transformer); + static unique_ptr TransformTriggerEventDeleteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerEventInfo TransformTriggerEventDelete(PEGTransformer &transformer); + static unique_ptr TransformTriggerEventUpdateInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerEventInfo TransformTriggerEventUpdate(PEGTransformer &transformer); + static unique_ptr TransformTriggerEventUpdateOfInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerEventInfo TransformTriggerEventUpdateOf(PEGTransformer &transformer, + const vector &trigger_column_list); + static unique_ptr TransformTriggerColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformTriggerColumnList(PEGTransformer &transformer, const vector &col_id); + static unique_ptr TransformForEachClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformForEachRowInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerForEach TransformForEachRow(PEGTransformer &transformer); + static unique_ptr TransformForEachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TriggerForEach TransformForEachStatement(PEGTransformer &transformer); + static unique_ptr TransformCreateTypeStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateTypeStmt(PEGTransformer &transformer, const bool &if_not_exists, + const QualifiedName &qualified_name, + unique_ptr create_type); + static unique_ptr TransformCreateTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCreateTypeFromTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformCreateTypeFromType(PEGTransformer &transformer, const LogicalType &type); + static unique_ptr TransformEnumSelectTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformEnumSelectType(PEGTransformer &transformer, + unique_ptr select_statement_internal); + static unique_ptr TransformEnumStringLiteralListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformEnumStringLiteralList(PEGTransformer &transformer, + const vector &string_literal); + static unique_ptr TransformCreateViewStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformCreateViewStmt(PEGTransformer &transformer, const bool &create_recursive, const bool &if_not_exists, + const QualifiedName &qualified_name, const vector &insert_column_list, + case_insensitive_map_t> with_list, + unique_ptr select_statement_internal); + static unique_ptr TransformCreateRecursiveInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCreateRecursive(PEGTransformer &transformer); + static unique_ptr TransformDeallocateStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDeallocateStatement(PEGTransformer &transformer, + const bool &deallocate_prepare, + const Identifier &identifier); + static unique_ptr TransformDeallocatePrepareInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformDeallocatePrepare(PEGTransformer &transformer); + static unique_ptr TransformDeleteStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDeleteStatement(PEGTransformer &transformer, + CommonTableExpressionMap with_clause, + unique_ptr target_opt_alias, + vector> delete_using_clause, + unique_ptr where_clause, + vector> returning_clause); + static unique_ptr TransformTruncateStatementInternal(PEGTransformer &transformer, + 
ParseResult &parse_result); + static unique_ptr TransformTruncateStatement(PEGTransformer &transformer, + unique_ptr base_table_name); + static unique_ptr TransformTargetOptAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformTargetOptAlias(PEGTransformer &transformer, + unique_ptr base_table_name, + const Identifier &col_id); + static unique_ptr TransformDeleteUsingClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformDeleteUsingClause(PEGTransformer &transformer, + vector> table_ref); + static unique_ptr TransformDescribeStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDescribeStatement(PEGTransformer &transformer, + unique_ptr child); + static unique_ptr TransformShowSelectInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformShowSelect(PEGTransformer &transformer, + const ShowType &show_or_describe_or_summarize, + unique_ptr select_statement_internal); + static unique_ptr TransformShowAllTablesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformShowAllTables(PEGTransformer &transformer, const ShowType &show_or_describe); + static unique_ptr TransformShowQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformShowQualifiedName(PEGTransformer &transformer, + const ShowType &show_or_describe_or_summarize, + DescribeTarget describe_target); + static unique_ptr TransformShowTablesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformShowTables(PEGTransformer &transformer, const ShowType &show_or_describe, + const QualifiedName &qualified_name); + static unique_ptr TransformDescribeTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDescribeBaseTableNameInternal(PEGTransformer 
&transformer, + ParseResult &parse_result); + static DescribeTarget TransformDescribeBaseTableName(PEGTransformer &transformer, + unique_ptr base_table_name); + static unique_ptr TransformDescribeStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static DescribeTarget TransformDescribeStringLiteral(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformShowOrDescribeOrSummarizeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSummarizeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSummarizeRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ShowType TransformSummarizeRule(PEGTransformer &transformer); + static unique_ptr TransformShowOrDescribeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformShowRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ShowType TransformShowRule(PEGTransformer &transformer); + static unique_ptr TransformDescribeRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDescribeLongRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ShowType TransformDescribeLongRule(PEGTransformer &transformer); + static unique_ptr TransformDescRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ShowType TransformDescRule(PEGTransformer &transformer); + static unique_ptr TransformDetachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDetachStatement(PEGTransformer &transformer, const bool &if_exists, + const Identifier &catalog_name); + static unique_ptr TransformDropStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropStatement(PEGTransformer &transformer, + unique_ptr 
drop_entries, + const bool &drop_behavior); + static unique_ptr TransformDropEntriesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropTriggerInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropTrigger(PEGTransformer &transformer, const bool &if_exists, + const Identifier &trigger_name, + unique_ptr base_table_name); + static unique_ptr TransformDropTableInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropTable(PEGTransformer &transformer, const CatalogType &table_or_view, + const bool &if_exists, + vector> base_table_name); + static unique_ptr TransformDropTableFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropTableFunction(PEGTransformer &transformer, + const CatalogType &comment_macro_table, + const bool &if_exists, + const vector &table_function_name); + static unique_ptr TransformDropFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropFunction(PEGTransformer &transformer, const bool &function_type_macro, + const bool &if_exists, + const vector &function_identifier); + static unique_ptr TransformDropSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropSchema(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_schema_name); + static unique_ptr TransformDropIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropIndex(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_index_name); + static unique_ptr TransformQualifiedIndexNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformQualifiedIndexNameStringInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName 
TransformQualifiedIndexNameString(PEGTransformer &transformer, const Identifier &index_name); + static unique_ptr TransformSchemaReservedIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformSchemaReservedIndex(PEGTransformer &transformer, + const Identifier &schema_qualification, + const Identifier &reserved_index_name); + static unique_ptr TransformCatalogReservedSchemaIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformCatalogReservedSchemaIndex(PEGTransformer &transformer, + const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, + const Identifier &reserved_index_name); + static unique_ptr TransformDropSequenceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropSequence(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_sequence_name); + static unique_ptr TransformDropCollationInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropCollation(PEGTransformer &transformer, const bool &if_exists, + const vector &collation_name); + static unique_ptr TransformDropTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropType(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_type_name); + static unique_ptr TransformDropSecretInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDropSecret(PEGTransformer &transformer, + const SecretPersistType &temporary, const bool &if_exists, + const Identifier &secret_name, + const Identifier &drop_secret_storage); + static unique_ptr TransformTableOrViewInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMaterializedViewEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static CatalogType 
TransformMaterializedViewEntry(PEGTransformer &transformer); + static unique_ptr TransformFunctionTypeMacroInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformFunctionTypeMacroKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformFunctionTypeMacroKeyword(PEGTransformer &transformer); + static unique_ptr TransformFunctionTypeFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformFunctionTypeFunction(PEGTransformer &transformer); + static unique_ptr TransformDropBehaviorInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCascadeDropBehaviorInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformCascadeDropBehavior(PEGTransformer &transformer); + static unique_ptr TransformRestrictDropBehaviorInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformRestrictDropBehavior(PEGTransformer &transformer); + static unique_ptr TransformIfExistsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformIfExists(PEGTransformer &transformer); + static unique_ptr TransformQualifiedSchemaNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformQualifiedSchemaNameStringInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformQualifiedSchemaNameString(PEGTransformer &transformer, const Identifier &schema_name); + static unique_ptr TransformCatalogReservedSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformCatalogReservedSchema(PEGTransformer &transformer, + const Identifier &catalog_qualification, + const Identifier &reserved_schema_name); + static unique_ptr TransformDropSecretStorageInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier 
TransformDropSecretStorage(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformExecuteStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformExecuteStatement(PEGTransformer &transformer, const Identifier &identifier, + vector table_function_arguments); + static unique_ptr TransformExplainStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformExplainStatement(PEGTransformer &transformer, const bool &explain_analyze, + const vector &explain_option_list, + unique_ptr explainable_statements); + static unique_ptr TransformExplainAnalyzeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformExplainAnalyze(PEGTransformer &transformer); + static unique_ptr TransformExplainOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformExplainOptionList(PEGTransformer &transformer, + const vector &explain_option); + static unique_ptr TransformExplainOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static GenericCopyOption TransformExplainOption(PEGTransformer &transformer, const Identifier &explain_option_name, + unique_ptr expression); + static unique_ptr TransformExplainSelectStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformExplainSelectStatement(PEGTransformer &transformer, unique_ptr select_statement_internal); + static unique_ptr TransformExplainableStatementsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformExportStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformExportStatement(PEGTransformer &transformer, const string &export_source, + const string &string_literal, + const vector &generic_copy_option_list); + static unique_ptr TransformExportSourceInternal(PEGTransformer 
&transformer, + ParseResult &parse_result); + static string TransformExportSource(PEGTransformer &transformer, const Identifier &catalog_name); + static unique_ptr TransformImportStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformImportStatement(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformInsertStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformInsertStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, + const OnConflictAction &or_action, unique_ptr insert_target, + const InsertColumnOrder &by_name_or_position, const vector &insert_column_list, + InsertValues insert_values, unique_ptr on_conflict_clause, + vector> returning_clause); + static unique_ptr TransformOrActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertOrReplaceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static OnConflictAction TransformInsertOrReplace(PEGTransformer &transformer); + static unique_ptr TransformInsertOrIgnoreInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static OnConflictAction TransformInsertOrIgnore(PEGTransformer &transformer); + static unique_ptr TransformByNameOrPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertByNameOrderInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertByPositionOrderInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertByNameInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static InsertColumnOrder TransformInsertByName(PEGTransformer &transformer); + static unique_ptr TransformInsertByPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static InsertColumnOrder 
TransformInsertByPosition(PEGTransformer &transformer); + static unique_ptr TransformInsertTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertTarget(PEGTransformer &transformer, + unique_ptr base_table_name, + const Identifier &insert_alias); + static unique_ptr TransformInsertAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformInsertAlias(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformColumnList(PEGTransformer &transformer, const vector &col_id); + static unique_ptr TransformInsertColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformInsertColumnList(PEGTransformer &transformer, const vector &column_list); + static unique_ptr TransformInsertValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSelectInsertValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static InsertValues TransformSelectInsertValues(PEGTransformer &transformer, + unique_ptr select_statement_internal); + static unique_ptr TransformDefaultValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static InsertValues TransformDefaultValues(PEGTransformer &transformer); + static unique_ptr TransformOnConflictClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOnConflictClause(PEGTransformer &transformer, + OnConflictExpressionTarget on_conflict_target, + unique_ptr on_conflict_action); + static unique_ptr TransformOnConflictTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOnConflictExpressionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static OnConflictExpressionTarget 
TransformOnConflictExpressionTarget(PEGTransformer &transformer, + const vector &column_id_list, + unique_ptr where_clause); + static unique_ptr TransformOnConflictIndexTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static OnConflictExpressionTarget TransformOnConflictIndexTarget(PEGTransformer &transformer, + const Identifier &constraint_name); + static unique_ptr TransformOnConflictActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOnConflictUpdateInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOnConflictUpdate(PEGTransformer &transformer, + unique_ptr update_set_clause, + unique_ptr where_clause); + static unique_ptr TransformOnConflictNothingInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOnConflictNothing(PEGTransformer &transformer); + static unique_ptr TransformReturningClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> + TransformReturningClause(PEGTransformer &transformer, vector> target_list); + static unique_ptr TransformLoadStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformLoadStatement(PEGTransformer &transformer, + const Identifier &col_id_or_string, + const Identifier &extension_alias); + static unique_ptr TransformExtensionAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformExtensionAlias(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformInstallStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInstallStatement(PEGTransformer &transformer, + const QualifiedName &identifier_or_string_literal, + const ExtensionRepositoryInfo &from_source, + const string &version_number); + static unique_ptr 
TransformUpdateExtensionsStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateExtensionsStatement(PEGTransformer &transformer, + const vector &identifier); + static unique_ptr TransformFromSourceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformFromSourceIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ExtensionRepositoryInfo TransformFromSourceIdentifier(PEGTransformer &transformer, + const Identifier &identifier); + static unique_ptr TransformFromSourceStringInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static ExtensionRepositoryInfo TransformFromSourceString(PEGTransformer &transformer, const string &string_literal); + static unique_ptr TransformVersionNumberInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformVersionNumber(PEGTransformer &transformer, + const QualifiedName &identifier_or_string_literal); + static unique_ptr TransformMergeIntoStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformMergeIntoStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, + unique_ptr target_opt_alias, unique_ptr merge_into_using_clause, + JoinQualifier join_qualifier, + vector>> merge_match, + vector> returning_clause); + static unique_ptr TransformMergeIntoUsingClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMergeIntoUsingClause(PEGTransformer &transformer, + unique_ptr table_ref); + static unique_ptr TransformMergeMatchInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformMatchedClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> + TransformMatchedClause(PEGTransformer &transformer, unique_ptr and_expression, + unique_ptr matched_clause_action); + static 
unique_ptr TransformMatchedClauseActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateMatchClause(PEGTransformer &transformer, + unique_ptr update_match_info); + static unique_ptr TransformUpdateMatchInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateMatchSetActionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateMatchSetAction(PEGTransformer &transformer, + unique_ptr update_match_set_clause); + static unique_ptr TransformUpdateByNameOrPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateByNameOrPosition(PEGTransformer &transformer, + const InsertColumnOrder &by_name_or_position); + static unique_ptr TransformDeleteMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDeleteMatchClause(PEGTransformer &transformer); + static unique_ptr TransformInsertMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertMatchClause(PEGTransformer &transformer, + unique_ptr insert_match_info); + static unique_ptr TransformInsertMatchInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertDefaultValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertDefaultValues(PEGTransformer &transformer); + static unique_ptr TransformInsertByNameOrPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformInsertByNameOrPosition(PEGTransformer &transformer, + const InsertColumnOrder &by_name_or_position); + static unique_ptr TransformInsertValuesListInternal(PEGTransformer &transformer, + ParseResult 
&parse_result); + static unique_ptr TransformInsertValuesList(PEGTransformer &transformer, + const vector &insert_column_list, + vector> expression); + static unique_ptr TransformDoNothingMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformDoNothingMatchClause(PEGTransformer &transformer); + static unique_ptr TransformErrorMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformErrorMatchClause(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformUpdateMatchSetClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateMatchSetInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAndExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformAndExpression(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformNotMatchedClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> + TransformNotMatchedClause(PEGTransformer &transformer, const MergeActionCondition &by_source_or_target, + unique_ptr and_expression, + unique_ptr matched_clause_action); + static unique_ptr TransformBySourceOrTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBySourceInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static MergeActionCondition TransformBySource(PEGTransformer &transformer); + static unique_ptr TransformByTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static MergeActionCondition TransformByTarget(PEGTransformer &transformer); + static unique_ptr TransformPivotOnInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformPivotOn(PEGTransformer &transformer, vector pivot_column_list); + static unique_ptr 
TransformPivotUsingInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> TransformPivotUsing(PEGTransformer &transformer, + vector> target_list); + static unique_ptr TransformPivotColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformPivotColumnList(PEGTransformer &transformer, + vector pivot_column_entry); + static unique_ptr TransformPivotColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPivotColumnExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static PivotColumn TransformPivotColumnExpression(PEGTransformer &transformer, + unique_ptr expression); + static unique_ptr TransformPivotColumnSubqueryInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static PivotColumn TransformPivotColumnSubquery(PEGTransformer &transformer, + unique_ptr base_expression, + unique_ptr select_statement_internal); + static unique_ptr TransformIntoNameValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static UnpivotNameValues TransformIntoNameValues(PEGTransformer &transformer, const Identifier &col_id_or_string, + const vector &identifier); + static unique_ptr TransformIncludeOrExcludeNullsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformIncludeNullsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformIncludeNulls(PEGTransformer &transformer); + static unique_ptr TransformExcludeNullsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static bool TransformExcludeNulls(PEGTransformer &transformer); + static unique_ptr TransformUnpivotHeaderInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUnpivotHeaderSingleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector 
TransformUnpivotHeaderSingle(PEGTransformer &transformer, const Identifier &col_id_or_string); + static unique_ptr TransformUnpivotHeaderListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformUnpivotHeaderList(PEGTransformer &transformer, + const vector &col_id_or_string); + static unique_ptr TransformPragmaStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPragmaStatement(PEGTransformer &transformer, + unique_ptr pragma_assign_or_function); + static unique_ptr TransformPragmaAssignOrFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPragmaAssignInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPragmaAssign(PEGTransformer &transformer, const Identifier &setting_name, + vector> variable_list); + static unique_ptr TransformPragmaFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPragmaFunction(PEGTransformer &transformer, const Identifier &pragma_name, + vector> pragma_parameters); + static unique_ptr TransformPragmaParametersInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> + TransformPragmaParameters(PEGTransformer &transformer, vector> expression); + static unique_ptr TransformPrepareStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformPrepareStatement(PEGTransformer &transformer, const Identifier &identifier, + const vector &type_list, + unique_ptr statement); + static unique_ptr TransformTypeListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformTypeList(PEGTransformer &transformer, const vector &type); + static unique_ptr TransformSetStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSetStatement(PEGTransformer &transformer, + 
unique_ptr set_assignment_or_time_zone); + static unique_ptr TransformSetAssignmentOrTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformResetStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformResetStatement(PEGTransformer &transformer, + const SettingInfo &set_variable_or_setting); + static unique_ptr TransformStandardAssignmentInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformStandardAssignment(PEGTransformer &transformer, + const SettingInfo &set_variable_or_setting, + vector> set_assignment); + static unique_ptr TransformSetVariableOrSettingInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSetTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSetTimeZone(PEGTransformer &transformer, + unique_ptr zone_value); + static unique_ptr TransformZoneValueInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneLocalInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneLocal(PEGTransformer &transformer); + static unique_ptr TransformZoneDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneDefault(PEGTransformer &transformer); + static unique_ptr TransformZoneStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneStringLiteral(PEGTransformer &transformer, + const string &string_literal); + static unique_ptr TransformZoneIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneIdentifier(PEGTransformer &transformer, + const Identifier &identifier); + static unique_ptr TransformZoneIntervalWithIntervalInternal(PEGTransformer &transformer, + ParseResult &parse_result); 
+ static unique_ptr TransformZoneIntervalWithInterval(PEGTransformer &transformer, + const string &string_literal, + const DatePartSpecifier &interval); + static unique_ptr TransformZoneIntervalWithPrecisionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformZoneIntervalWithPrecision(PEGTransformer &transformer, + unique_ptr number_literal, + const string &string_literal); + static unique_ptr TransformSetSettingInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SettingInfo TransformSetSetting(PEGTransformer &transformer, const SetScope &setting_scope, + const Identifier &setting_name); + static unique_ptr TransformSetVariableInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SettingInfo TransformSetVariable(PEGTransformer &transformer, const SetScope &variable_scope, + const Identifier &identifier); + static unique_ptr TransformVariableScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SetScope TransformVariableScope(PEGTransformer &transformer); + static unique_ptr TransformSettingScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformLocalScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SetScope TransformLocalScope(PEGTransformer &transformer); + static unique_ptr TransformSessionScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SetScope TransformSessionScope(PEGTransformer &transformer); + static unique_ptr TransformGlobalScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static SetScope TransformGlobalScope(PEGTransformer &transformer); + static unique_ptr TransformSetAssignmentInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector> + TransformSetAssignment(PEGTransformer &transformer, vector> variable_list); + static unique_ptr TransformVariableListInternal(PEGTransformer 
&transformer, + ParseResult &parse_result); + static vector> TransformVariableList(PEGTransformer &transformer, + vector> expression); + static unique_ptr TransformTransactionStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBeginTransactionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBeginTransaction(PEGTransformer &transformer, + const TransactionModifierType &read_or_write); + static unique_ptr TransformRollbackTransactionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformRollbackTransaction(PEGTransformer &transformer); + static unique_ptr TransformCommitTransactionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCommitTransaction(PEGTransformer &transformer); + static unique_ptr TransformReadOrWriteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TransactionModifierType TransformReadOrWrite(PEGTransformer &transformer, + const TransactionModifierType &read_only_or_read_write); + static unique_ptr TransformReadOnlyOrReadWriteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformReadOnlyInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TransactionModifierType TransformReadOnly(PEGTransformer &transformer); + static unique_ptr TransformReadWriteInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static TransactionModifierType TransformReadWrite(PEGTransformer &transformer); + static unique_ptr TransformUpdateStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformUpdateStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, + unique_ptr update_target, unique_ptr update_set_clause, + unique_ptr from_clause, unique_ptr where_clause, + vector> returning_clause); + static unique_ptr 
TransformUpdateTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBaseTableSetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBaseTableSet(PEGTransformer &transformer, + unique_ptr base_table_name); + static unique_ptr TransformBaseTableAliasSetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformBaseTableAliasSet(PEGTransformer &transformer, + unique_ptr base_table_name, + const Identifier &update_alias); + static unique_ptr TransformUpdateAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformUpdateAlias(PEGTransformer &transformer, const Identifier &col_id); + static unique_ptr TransformUpdateSetClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateSetTupleInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformUpdateSetTuple(PEGTransformer &transformer, + const vector &column_name, + unique_ptr expression); + static unique_ptr TransformUpdateSetElementListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr + TransformUpdateSetElementList(PEGTransformer &transformer, + vector>> update_set_element); + static unique_ptr TransformUpdateSetElementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static pair> + TransformUpdateSetElement(PEGTransformer &transformer, const string &update_set_column_target, + unique_ptr expression); + static unique_ptr TransformUpdateSetColumnTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformUpdateSetColumnTarget(PEGTransformer &transformer, const Identifier &column_name, + const vector &dot_identifier); + static unique_ptr TransformUseStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr 
TransformUseStatement(PEGTransformer &transformer, const QualifiedName &use_target); + static unique_ptr TransformUseTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformSchemaNameAsUseTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformSchemaNameAsUseTarget(PEGTransformer &transformer, const Identifier &schema_name); + static unique_ptr TransformCatalogNameAsUseTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformCatalogNameAsUseTarget(PEGTransformer &transformer, const Identifier &catalog_name); + static unique_ptr TransformUseTargetCatalogSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static QualifiedName TransformUseTargetCatalogSchema(PEGTransformer &transformer, const Identifier &catalog_name, + const Identifier &reserved_schema_name, + const vector &dot_identifier); + static unique_ptr TransformDotIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static Identifier TransformDotIdentifier(PEGTransformer &transformer, const Identifier &identifier); + static unique_ptr TransformVacuumStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformVacuumStatement(PEGTransformer &transformer, + const VacuumOptions &vacuum_options, + AnalyzeTarget analyze_target); + static unique_ptr TransformVacuumOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformVacuumParensOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static VacuumOptions TransformVacuumParensOptions(PEGTransformer &transformer, const vector &vacuum_option); + static unique_ptr TransformVacuumLegacyOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static VacuumOptions TransformVacuumLegacyOptions(PEGTransformer &transformer, const string &opt_full, 
+ const string &opt_freeze, const string &opt_verbose, + const string &opt_analyze); + static unique_ptr TransformVacuumOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformOptAnalyzeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformOptAnalyze(PEGTransformer &transformer); + static unique_ptr TransformOptFullInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformOptFull(PEGTransformer &transformer); + static unique_ptr TransformOptFreezeInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformOptFreeze(PEGTransformer &transformer); + static unique_ptr TransformOptVerboseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static string TransformOptVerbose(PEGTransformer &transformer); + static unique_ptr TransformNameListInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static vector TransformNameList(PEGTransformer &transformer, const vector &col_id); + //===--------------------------------------------------------------------===// + // END GENERATED RULES + //===--------------------------------------------------------------------===// private: PEGParser parser; diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer_generated.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer_generated.hpp deleted file mode 100644 index 288eb86db..000000000 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer_generated.hpp +++ /dev/null @@ -1,43 +0,0 @@ -// AUTO-GENERATED by scripts/parser/gen_transformer_v2.py -- DO NOT EDIT -#ifdef DUCKDB_INSIDE_PEG_TRANSFORMER_HPP -static unique_ptr TransformUseStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformUseStatement(const QualifiedName &use_target); -static unique_ptr TransformUseTargetInternal(PEGTransformer 
&transformer, - ParseResult &parse_result); -static QualifiedName TransformUseTarget(PEGTransformer &transformer, ParseResult &choice_result); -static unique_ptr TransformUseTargetCatalogSchemaInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static QualifiedName TransformUseTargetCatalogSchema(const string &catalog_name, const string &reserved_schema_name, - const vector &dot_identifier); -static unique_ptr TransformDotIdentifierInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static string TransformDotIdentifier(const string &identifier); -static unique_ptr TransformTransactionStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformBeginTransactionInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformBeginTransaction(const TransactionModifierType &read_or_write); -static unique_ptr TransformRollbackTransactionInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformRollbackTransaction(); -static unique_ptr TransformCommitTransactionInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformCommitTransaction(); -static unique_ptr TransformReadOrWriteInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static TransactionModifierType TransformReadOrWrite(const TransactionModifierType &read_only_or_read_write); -static unique_ptr TransformDetachStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformDetachStatement(const bool &if_exists, const string &catalog_name); -static unique_ptr TransformExportStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformExportStatement(const string &export_source, const string &string_literal, - vector generic_copy_option_list); -static unique_ptr TransformExportSourceInternal(PEGTransformer &transformer, - ParseResult 
&parse_result); -static string TransformExportSource(const string &catalog_name); -static unique_ptr TransformImportStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result); -static unique_ptr TransformImportStatement(const string &string_literal); -#endif // DUCKDB_INSIDE_PEG_TRANSFORMER_HPP diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp index 38f3a43cb..e19f62c88 100644 --- a/src/duckdb/src/include/duckdb/parser/qualified_name.hpp +++ b/src/duckdb/src/include/duckdb/parser/qualified_name.hpp @@ -15,27 +15,27 @@ namespace duckdb { struct QualifiedName { - string catalog; - string schema; - string name; + Identifier catalog; + Identifier schema; + Identifier name; //! Parse the (optional) schema and a name from a string in the format of e.g. "schema"."table"; if there is no dot //! the schema will be set to INVALID_SCHEMA static QualifiedName Parse(const string &input); - static vector ParseComponents(const string &input); + static vector ParseComponents(const string &input); string ToString() const; }; struct QualifiedColumnName { QualifiedColumnName(); - QualifiedColumnName(string column_p); // NOLINT: allow implicit conversion from string to column name - QualifiedColumnName(string table_p, string column_p); - QualifiedColumnName(const BindingAlias &alias, string column_p); - - string catalog; - string schema; - string table; - string column; + QualifiedColumnName(Identifier column_p); // NOLINT: allow implicit conversion from string to column name + QualifiedColumnName(Identifier table_p, Identifier column_p); + QualifiedColumnName(const BindingAlias &alias, Identifier column_p); + + Identifier catalog; + Identifier schema; + Identifier table; + Identifier column; static QualifiedColumnName Parse(string &input); diff --git a/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp index 
5dd28c78c..199100140 100644 --- a/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp +++ b/src/duckdb/src/include/duckdb/parser/qualified_name_set.hpp @@ -16,7 +16,7 @@ namespace duckdb { struct QualifiedColumnHashFunction { uint64_t operator()(const QualifiedColumnName &a) const { // hash only on the column name - since we match based on the shortest possible match - return StringUtil::CIHash(a.column); + return a.column.Hash(); } }; @@ -25,16 +25,16 @@ struct QualifiedColumnEquality { // qualified column names follow a prefix comparison // so "tbl.i" and "i" are equivalent, as are "schema.tbl.i" and "i" // but "tbl.i" and "tbl2.i" are not equivalent - if (!a.catalog.empty() && !b.catalog.empty() && !StringUtil::CIEquals(a.catalog, b.catalog)) { + if (!a.catalog.empty() && !b.catalog.empty() && a.catalog != b.catalog) { return false; } - if (!a.schema.empty() && !b.schema.empty() && !StringUtil::CIEquals(a.schema, b.schema)) { + if (!a.schema.empty() && !b.schema.empty() && a.schema != b.schema) { return false; } - if (!a.table.empty() && !b.table.empty() && !StringUtil::CIEquals(a.table, b.table)) { + if (!a.table.empty() && !b.table.empty() && a.table != b.table) { return false; } - return StringUtil::CIEquals(a.column, b.column); + return a.column == b.column; } }; diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp index fa012561e..02344161c 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -28,7 +28,8 @@ enum class QueryNodeType : uint8_t { STATEMENT_NODE = 6, UPDATE_QUERY_NODE = 7, DELETE_QUERY_NODE = 8, - INSERT_QUERY_NODE = 9 + INSERT_QUERY_NODE = 9, + MERGE_QUERY_NODE = 10 }; struct CommonTableExpressionInfo; @@ -37,7 +38,7 @@ class CommonTableExpressionMap { public: CommonTableExpressionMap(); - InsertionOrderPreservingMap> map; + InsertionOrderPreservingMap, Identifier, identifier_map_t> map; public: 
string ToString() const; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp index 3bca2cd06..a62582fc2 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" @@ -22,10 +23,10 @@ class CTENode : public QueryNode { CTENode() : QueryNode(QueryNodeType::CTE_NODE) { } - string ctename; + Identifier ctename; unique_ptr query; unique_ptr child; - vector aliases; + vector aliases; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/insert_query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/insert_query_node.hpp index 0e0b76da6..137d984b4 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/insert_query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/insert_query_node.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/query_node.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/parsed_expression.hpp" @@ -33,14 +34,14 @@ class InsertQueryNode : public QueryNode { //! The select statement to insert from unique_ptr select_statement; //! Column names to insert into - vector columns; + vector columns; //! Table name to insert to - string table; + Identifier table; //! Schema name to insert to - string schema; + Identifier schema; //! The catalog name to insert to - string catalog; + Identifier catalog; //! 
keep track of optional returningList if statement contains a RETURNING keyword vector> returning_list; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp index d136ccceb..152b57c7a 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp @@ -6,3 +6,4 @@ #include "duckdb/parser/query_node/update_query_node.hpp" #include "duckdb/parser/query_node/delete_query_node.hpp" #include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/query_node/merge_query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/merge_query_node.hpp new file mode 100644 index 000000000..61737ece3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/merge_query_node.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/query_node/merge_query_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/identifier.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/tableref.hpp" +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/merge_action_type.hpp" + +namespace duckdb { +class Serializer; +class Deserializer; +class MergeIntoAction; + +//! MergeQueryNode represents a MERGE INTO statement +class MergeQueryNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::MERGE_QUERY_NODE; + +public: + MergeQueryNode(); + + //! The table to merge into + unique_ptr target; + //! The source of the merge + unique_ptr source; + //! 
The join condition (ON clause) - or NULL if a USING column list is used + unique_ptr join_condition; + //! The columns used in the USING clause (if join_condition is NULL) + vector using_columns; + //! The merge actions, grouped by their condition + map>> actions; + //! keep track of optional returningList if statement contains a RETURNING keyword + vector> returning_list; + +public: + string ToString() const override; + bool Equals(const QueryNode *other) const override; + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + + static string ActionConditionToString(MergeActionCondition condition); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp index e4c8c619c..c04986afd 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" @@ -21,14 +22,14 @@ class RecursiveCTENode : public QueryNode { RecursiveCTENode() : QueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { } - string ctename; + Identifier ctename; bool union_all; //! The left side of the set operation unique_ptr left; //! The right side of the set operation unique_ptr right; //! Aliases of the recursive CTE node - vector aliases; + vector aliases; //! 
targets for key variants vector> key_targets; diff --git a/src/duckdb/src/include/duckdb/parser/result_modifier.hpp b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp index 8e777c082..45b7a8aad 100644 --- a/src/duckdb/src/include/duckdb/parser/result_modifier.hpp +++ b/src/duckdb/src/include/duckdb/parser/result_modifier.hpp @@ -20,9 +20,11 @@ enum class ResultModifierType : uint8_t { LIMIT_MODIFIER = 1, ORDER_MODIFIER = 2, DISTINCT_MODIFIER = 3, - LIMIT_PERCENT_MODIFIER = 4 + LEGACY_LIMIT_PERCENT_MODIFIER = 4 }; +enum class LimitValueType : uint8_t { ROW_COUNT = 0, PERCENTAGE = 1 }; + const char *ToString(ResultModifierType value); ResultModifierType ResultModifierFromString(const char *value); @@ -92,15 +94,20 @@ class LimitModifier : public ResultModifier { LimitModifier() : ResultModifier(ResultModifierType::LIMIT_MODIFIER) { } - //! LIMIT count + //! LIMIT unique_ptr limit; //! OFFSET unique_ptr offset; + //! Limit type (row count or percentage) + LimitValueType limit_type = LimitValueType::ROW_COUNT; public: bool Equals(const ResultModifier &other) const override; unique_ptr Copy() const override; + bool UseLegacySerialization() const; + void LegacySerialize(Serializer &serializer) const; + void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); }; @@ -145,12 +152,12 @@ class DistinctModifier : public ResultModifier { static unique_ptr Deserialize(Deserializer &deserializer); }; -class LimitPercentModifier : public ResultModifier { +class LegacyLimitPercentModifier : public ResultModifier { public: - static constexpr const ResultModifierType TYPE = ResultModifierType::LIMIT_PERCENT_MODIFIER; + static constexpr const ResultModifierType TYPE = ResultModifierType::LEGACY_LIMIT_PERCENT_MODIFIER; public: - LimitPercentModifier() : ResultModifier(ResultModifierType::LIMIT_PERCENT_MODIFIER) { + LegacyLimitPercentModifier() : ResultModifier(ResultModifierType::LEGACY_LIMIT_PERCENT_MODIFIER) { } //! 
LIMIT % @@ -161,9 +168,11 @@ class LimitPercentModifier : public ResultModifier { public: bool Equals(const ResultModifier &other) const override; unique_ptr Copy() const override; - void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + + static unique_ptr DeserializeLegacyLimitPercentModifier(unique_ptr limit, + unique_ptr offset); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp index 2e8278d64..d2cbb25e1 100644 --- a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/statement_type.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/printer.hpp" @@ -33,7 +34,7 @@ class SQLStatement { //! The statement length within the query string idx_t stmt_length = 0; //! The map of named parameter to param index - case_insensitive_map_t named_param_map; + identifier_map_t named_param_map; //! 
The query text that corresponds to this SQL statement string query; diff --git a/src/duckdb/src/include/duckdb/parser/statement/connect_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/connect_statement.hpp new file mode 100644 index 000000000..8b87aaed6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/connect_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/connect_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/connect_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class ConnectStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::CONNECT_STATEMENT; + +public: + ConnectStatement(); + + unique_ptr info; + +protected: + ConnectStatement(const ConnectStatement &other); + +public: + unique_ptr Copy() const override; + string ToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp index a70ca1193..cbf4fcf36 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/copy_database_statement.hpp @@ -19,10 +19,10 @@ class CopyDatabaseStatement : public SQLStatement { static constexpr const StatementType TYPE = StatementType::COPY_DATABASE_STATEMENT; public: - CopyDatabaseStatement(string from_database, string to_database, CopyDatabaseType copy_type); + CopyDatabaseStatement(Identifier from_database, Identifier to_database, CopyDatabaseType copy_type); - string from_database; - string to_database; + Identifier from_database; + Identifier to_database; CopyDatabaseType copy_type; protected: diff --git 
a/src/duckdb/src/include/duckdb/parser/statement/disconnect_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/disconnect_statement.hpp new file mode 100644 index 000000000..c1c6b733c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/statement/disconnect_statement.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/parser/statement/disconnect_statement.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_data/disconnect_info.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class DisconnectStatement : public SQLStatement { +public: + static constexpr const StatementType TYPE = StatementType::DISCONNECT_STATEMENT; + +public: + DisconnectStatement(); + + unique_ptr info; + +protected: + DisconnectStatement(const DisconnectStatement &other); + +public: + unique_ptr Copy() const override; + string ToString() const override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp index 297ab031a..9971aa290 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/execute_statement.hpp @@ -20,8 +20,8 @@ class ExecuteStatement : public SQLStatement { public: ExecuteStatement(); - string name; - case_insensitive_map_t> named_values; + Identifier name; + identifier_map_t> named_values; protected: ExecuteStatement(const ExecuteStatement &other); diff --git a/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp index 80acbbd2e..5b1f12e9a 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp +++ 
b/src/duckdb/src/include/duckdb/parser/statement/insert_statement.hpp @@ -45,7 +45,7 @@ class OnConflictInfo { public: OnConflictAction action_type; - vector indexed_columns; + vector indexed_columns; //! The SET information (if action_type == UPDATE) unique_ptr set_info; //! The condition determining whether we apply the DO .. for conflicts that arise diff --git a/src/duckdb/src/include/duckdb/parser/statement/list.hpp b/src/duckdb/src/include/duckdb/parser/statement/list.hpp index 117a3e44c..77a8f7d20 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/list.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/list.hpp @@ -1,11 +1,13 @@ #include "duckdb/parser/statement/alter_statement.hpp" #include "duckdb/parser/statement/attach_statement.hpp" #include "duckdb/parser/statement/call_statement.hpp" +#include "duckdb/parser/statement/connect_statement.hpp" #include "duckdb/parser/statement/copy_statement.hpp" #include "duckdb/parser/statement/copy_database_statement.hpp" #include "duckdb/parser/statement/create_statement.hpp" #include "duckdb/parser/statement/delete_statement.hpp" #include "duckdb/parser/statement/detach_statement.hpp" +#include "duckdb/parser/statement/disconnect_statement.hpp" #include "duckdb/parser/statement/drop_statement.hpp" #include "duckdb/parser/statement/execute_statement.hpp" #include "duckdb/parser/statement/explain_statement.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp index 36d9899d2..d4a9f5ea5 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp @@ -17,6 +17,9 @@ #include "duckdb/parser/statement/insert_statement.hpp" namespace duckdb { +class Serializer; +class Deserializer; +class MergeQueryNode; class MergeIntoAction { public: @@ -27,7 +30,7 @@ class MergeIntoAction { //! 
The SET information (if action_type == MERGE_UPDATE) unique_ptr update_info; //! Column names to insert into (if action_type == MERGE_INSERT) - vector insert_columns; + vector insert_columns; //! Set of expressions for INSERT vector> expressions; //! INSERT BY POSITION or INSERT BY NAME @@ -35,10 +38,14 @@ class MergeIntoAction { //! Whether or not this is a INSERT DEFAULT VALUES bool default_values = false; //! INSERT_BY_NAME exclude column list - columns to skip when generating the SET list - unordered_set exclude_columns; + identifier_set_t exclude_columns; string ToString() const; unique_ptr Copy() const; + static bool Equals(const MergeIntoAction &left, const MergeIntoAction &right); + + void Serialize(Serializer &serializer) const; + static unique_ptr Deserialize(Deserializer &deserializer); }; class MergeIntoStatement : public SQLStatement { @@ -48,18 +55,7 @@ class MergeIntoStatement : public SQLStatement { public: MergeIntoStatement(); - unique_ptr target; - unique_ptr source; - unique_ptr join_condition; - vector using_columns; - - map>> actions; - - //! keep track of optional returningList if statement contains a RETURNING keyword - vector> returning_list; - - //! 
CTEs - CommonTableExpressionMap cte_map; + unique_ptr node; protected: MergeIntoStatement(const MergeIntoStatement &other); @@ -67,8 +63,6 @@ class MergeIntoStatement : public SQLStatement { public: string ToString() const override; unique_ptr Copy() const override; - - static string ActionConditionToString(MergeActionCondition condition); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp index c47cb60f0..26942fd98 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/prepare_statement.hpp @@ -20,7 +20,7 @@ class PrepareStatement : public SQLStatement { PrepareStatement(); unique_ptr statement; - string name; + Identifier name; protected: PrepareStatement(const PrepareStatement &other); diff --git a/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp index f7051f893..7c87ddf79 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/set_statement.hpp @@ -20,18 +20,18 @@ class SetStatement : public SQLStatement { static constexpr const StatementType TYPE = StatementType::SET_STATEMENT; protected: - SetStatement(string name_p, SetScope scope_p, SetType type_p); + SetStatement(Identifier name_p, SetScope scope_p, SetType type_p); SetStatement(const SetStatement &other) = default; public: - string name; + Identifier name; SetScope scope; SetType set_type; }; class SetVariableStatement : public SetStatement { public: - SetVariableStatement(string name_p, unique_ptr value_p, SetScope scope_p); + SetVariableStatement(Identifier name_p, unique_ptr value_p, SetScope scope_p); protected: SetVariableStatement(const SetVariableStatement &other); @@ -46,7 +46,7 @@ class SetVariableStatement : public SetStatement { class 
ResetVariableStatement : public SetStatement { public: - ResetVariableStatement(std::string name_p, SetScope scope_p); + ResetVariableStatement(Identifier name_p, SetScope scope_p); protected: ResetVariableStatement(const ResetVariableStatement &other) = default; diff --git a/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp index 9f9a11e1f..62c1499cb 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/update_statement.hpp @@ -35,7 +35,7 @@ class UpdateSetInfo { // The condition that needs to be met to perform the update unique_ptr condition; // The columns to update - vector columns; + vector columns; // The set expressions to execute vector> expressions; diff --git a/src/duckdb/src/include/duckdb/parser/tableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref.hpp index 1a0d81692..2bbff0927 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/enums/tableref_type.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" @@ -28,7 +29,7 @@ class TableRef { } TableReferenceType type; - string alias; + Identifier alias; //! Sample options (if any) unique_ptr sample; //! The location in the query (if any) @@ -36,7 +37,7 @@ class TableRef { //! External dependencies of this table function shared_ptr external_dependency; //! Aliases for the column names - vector column_name_alias; + vector column_name_alias; public: //! 
Convert the object to a string @@ -73,8 +74,8 @@ class TableRef { protected: string BaseToString(string result) const; - string BaseToString(string result, const vector &column_name_alias) const; - string AliasToString(const vector &column_name_alias) const; + string BaseToString(string result, const vector &column_name_alias) const; + string AliasToString(const vector &column_name_alias) const; string SampleToString() const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp index b2339d9bc..89ae24aea 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/basetableref.hpp @@ -29,11 +29,11 @@ class BaseTableRef : public TableRef { } //! The catalog name. - string catalog_name; + Identifier catalog_name; //! The schema name. - string schema_name; + Identifier schema_name; //! The table name. - string table_name; + Identifier table_name; //! The timestamp/version at which to read this table entry (if any) unique_ptr at_clause; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/column_data_ref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/column_data_ref.hpp index 7d6f93a57..a01bd126d 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/column_data_ref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/column_data_ref.hpp @@ -21,11 +21,11 @@ class ColumnDataRef : public TableRef { public: explicit ColumnDataRef(optionally_owned_ptr collection_p, - vector expected_names = vector()); + vector expected_names = vector()); public: //! The set of expected names - vector expected_names; + vector expected_names; //! 
(Optionally) the owned collection optionally_owned_ptr collection; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp index 1cf31d1eb..9d936b215 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/delimgetref.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/tableref.hpp" namespace duckdb { @@ -21,7 +22,7 @@ class DelimGetRef : public TableRef { } } - vector internal_aliases; + vector internal_aliases; vector types; public: diff --git a/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp index a38bd422b..7e98a109b 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/expressionlistref.hpp @@ -30,7 +30,7 @@ class ExpressionListRef : public TableRef { //! Expected table types. vector expected_types; //! Expected table names. - vector expected_names; + vector expected_names; public: string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp index f31992aef..c49f41441 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/joinref_type.hpp" #include "duckdb/parser/parsed_expression.hpp" @@ -37,7 +38,7 @@ class JoinRef : public TableRef { //! Join condition type JoinRefType ref_type; //! The set of USING columns (if any) - vector using_columns; + vector using_columns; //! Duplicate eliminated columns (if any) vector> duplicate_eliminated_columns; //! 
If we have duplicate eliminated columns if the delim is flipped diff --git a/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp index 6ad7bea7a..87c105016 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/pivotref.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/query_node/select_node.hpp" @@ -19,7 +20,7 @@ struct PivotColumnEntry { //! The expression (UNPIVOT only) unique_ptr expr; //! The alias of the pivot column entry - string alias; + Identifier alias; bool Equals(const PivotColumnEntry &other) const; PivotColumnEntry Copy() const; @@ -37,11 +38,11 @@ struct PivotColumn { //! The set of expressions to pivot on vector> pivot_expressions; //! The set of unpivot names - vector unpivot_names; + vector unpivot_names; //! The set of values to pivot on vector entries; //! The enum to read pivot values from (if any) - string pivot_enum; + Identifier pivot_enum; //! Subquery (if any) - used during transform only unique_ptr subquery; @@ -69,19 +70,19 @@ class PivotRef : public TableRef { //! The aggregates to compute over the pivot (PIVOT only) vector> aggregates; //! The names of the unpivot expressions (UNPIVOT only) - vector unpivot_names; + vector unpivot_names; //! The set of pivots vector pivots; //! The groups to pivot over. If none are specified all columns not included in the pivots/aggregate are chosen. - vector groups; + vector groups; //! Whether or not to include nulls in the result (UNPIVOT only) bool include_nulls; //! The set of values to pivot on (bound pivot only) vector bound_pivot_values; //! The set of bound group names (bound pivot only) - vector bound_group_names; + vector bound_group_names; //! 
The set of bound aggregate names (bound pivot only) - vector bound_aggregate_names; + vector bound_aggregate_names; public: string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp index 78a1e1266..1d63b64a7 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/parser/tableref.hpp" #include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/query_node.hpp" @@ -25,11 +26,11 @@ class ShowRef : public TableRef { ShowRef(); //! The table name (if any) - string table_name; + Identifier table_name; //! The catalog name (if any) - string catalog_name; + Identifier catalog_name; //! The schema name (if any) - string schema_name; + Identifier schema_name; //! The QueryNode of select query (if any) unique_ptr query; //! Whether or not we are requesting a summary or a describe diff --git a/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp index f93b45ab3..086661f7a 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/subqueryref.hpp @@ -21,7 +21,7 @@ class SubqueryRef : public TableRef { SubqueryRef(); public: - DUCKDB_API explicit SubqueryRef(unique_ptr subquery, string alias = string()); + DUCKDB_API explicit SubqueryRef(unique_ptr subquery, Identifier alias = Identifier()); //! 
The subquery unique_ptr subquery; diff --git a/src/duckdb/src/include/duckdb/parser/tokens.hpp b/src/duckdb/src/include/duckdb/parser/tokens.hpp index d66bdcc08..8ad75b94f 100644 --- a/src/duckdb/src/include/duckdb/parser/tokens.hpp +++ b/src/duckdb/src/include/duckdb/parser/tokens.hpp @@ -18,9 +18,11 @@ class SQLStatement; class AlterStatement; class AttachStatement; class CallStatement; +class ConnectStatement; class CopyStatement; class CreateStatement; class DetachStatement; +class DisconnectStatement; class DeleteStatement; class DropStatement; class ExtensionStatement; @@ -57,6 +59,7 @@ class StatementNode; class UpdateQueryNode; class DeleteQueryNode; class InsertQueryNode; +class MergeQueryNode; //===--------------------------------------------------------------------===// // Expressions @@ -79,6 +82,7 @@ class ParameterExpression; class PositionalReferenceExpression; class StarExpression; class SubqueryExpression; +class TypeExpression; class WindowExpression; //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index 6dc3e0a55..f27c22be6 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/column_index.hpp" @@ -46,36 +47,36 @@ class BindContext { public: //! Given a column name, find the matching table it belongs to. Throws an //! exception if no table has a column of the given name. - optional_ptr GetMatchingBinding(const string &column_name, + optional_ptr GetMatchingBinding(const Identifier &column_name, QueryErrorContext context = QueryErrorContext()); //! 
Like GetMatchingBinding, but instead of throwing an error if multiple tables have the same binding it will //! return a list of all the matching ones - vector> GetMatchingBindings(const string &column_name); + vector> GetMatchingBindings(const Identifier &column_name); //! Like GetMatchingBindings, but returns the top 3 most similar bindings (in levenshtein distance) instead of the //! matching ones - vector GetSimilarBindings(const string &column_name); + vector GetSimilarBindings(const Identifier &column_name); optional_ptr GetCTEBinding(const BindingAlias &ctename); //! Binds a column expression to the base table. Returns the bound expression //! or throws an exception if the column could not be bound. BindResult BindColumn(ColumnRefExpression &colref, idx_t depth); - string BindColumn(PositionalReferenceExpression &ref, string &table_name, string &column_name); + string BindColumn(PositionalReferenceExpression &ref, Identifier &table_name, Identifier &column_name); unique_ptr PositionToColumn(PositionalReferenceExpression &ref); - unique_ptr ExpandGeneratedColumn(TableBinding &table_binding, const string &column_name); + unique_ptr ExpandGeneratedColumn(TableBinding &table_binding, const Identifier &column_name); unique_ptr - CreateColumnReference(const string &table_name, const string &column_name, + CreateColumnReference(const Identifier &table_name, const Identifier &column_name, ColumnBindType bind_type = ColumnBindType::EXPAND_GENERATED_COLUMNS); unique_ptr - CreateColumnReference(const string &schema_name, const string &table_name, const string &column_name, + CreateColumnReference(const Identifier &schema_name, const Identifier &table_name, const Identifier &column_name, ColumnBindType bind_type = ColumnBindType::EXPAND_GENERATED_COLUMNS); unique_ptr - CreateColumnReference(const string &catalog_name, const string &schema_name, const string &table_name, - const string &column_name, + CreateColumnReference(const Identifier &catalog_name, const Identifier 
&schema_name, const Identifier &table_name, + const Identifier &column_name, ColumnBindType bind_type = ColumnBindType::EXPAND_GENERATED_COLUMNS); unique_ptr - CreateColumnReference(const BindingAlias &table_alias, const string &column_name, + CreateColumnReference(const BindingAlias &table_alias, const Identifier &column_name, ColumnBindType bind_type = ColumnBindType::EXPAND_GENERATED_COLUMNS); //! Generate column expressions for all columns that are present in the @@ -88,68 +89,68 @@ class BindContext { } vector GetBindingAliases(); - void GetTypesAndNames(vector &result_names, vector &result_types); + void GetTypesAndNames(vector &result_names, vector &result_types); //! Adds a base table with the given alias to the BindContext. - void AddBaseTable(TableIndex index, const string &alias, const vector &names, + void AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, TableCatalogEntry &entry, bool add_row_id = true); - void AddBaseTable(TableIndex index, const string &alias, const vector &names, + void AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, - const string &table_name); - void AddBaseTable(TableIndex index, const string &alias, const vector &names, + const Identifier &table_name); + void AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, TableCatalogEntry &entry, virtual_column_map_t virtual_columns); //! Adds a call to a table function with the given alias to the BindContext. - void AddTableFunction(TableIndex index, const string &alias, const vector &names, + void AddTableFunction(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, optional_ptr entry, virtual_column_map_t virtual_columns); //! Adds a table view with a given alias to the BindContext. 
- void AddView(TableIndex index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, + void AddView(TableIndex index, const Identifier &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(TableIndex index, const string &alias, SubqueryRef &ref, BoundStatement &subquery); + void AddSubquery(TableIndex index, const Identifier &alias, SubqueryRef &ref, BoundStatement &subquery); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(TableIndex index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery); + void AddSubquery(TableIndex index, const Identifier &alias, TableFunctionRef &ref, BoundStatement &subquery); //! Adds a binding to a catalog entry with a given alias to the BindContext. - void AddEntryBinding(TableIndex index, const string &alias, const vector &names, + void AddEntryBinding(TableIndex index, const Identifier &alias, const vector &names, const vector &types, StandardEntry &entry); //! Adds a base table with the given alias to the BindContext. - void AddGenericBinding(TableIndex index, const string &alias, const vector &names, + void AddGenericBinding(TableIndex index, const Identifier &alias, const vector &names, const vector &types); //! Adds a base table with the given alias to the CTE BindContext. //! We need this to correctly bind recursive CTEs with multiple references. - void AddCTEBinding(TableIndex index, BindingAlias alias, const vector &names, + void AddCTEBinding(TableIndex index, BindingAlias alias, const vector &names, const vector &types, CTEType cte_type = CTEType::CAN_BE_REFERENCED); void AddCTEBinding(unique_ptr binding); //! Add an implicit join condition (e.g. USING (x)) - void AddUsingBinding(const string &column_name, UsingColumnSet &set); + void AddUsingBinding(const Identifier &column_name, UsingColumnSet &set); void AddUsingBindingSet(unique_ptr set); //! 
Returns any using column set for the given column name, or nullptr if there is none. On conflict (multiple using //! column sets with the same name) throw an exception. - optional_ptr GetUsingBinding(const string &column_name); + optional_ptr GetUsingBinding(const Identifier &column_name); //! Returns any using column set for the given column name, or nullptr if there is none - optional_ptr GetUsingBinding(const string &column_name, const BindingAlias &binding); + optional_ptr GetUsingBinding(const Identifier &column_name, const BindingAlias &binding); //! Erase a using binding from the set of using bindings - void RemoveUsingBinding(const string &column_name, UsingColumnSet &set); + void RemoveUsingBinding(const Identifier &column_name, UsingColumnSet &set); //! Transfer a using binding from one bind context to this bind context void TransferUsingBinding(BindContext ¤t_context, optional_ptr current_set, - UsingColumnSet &new_set, const string &using_column); + UsingColumnSet &new_set, const Identifier &using_column); //! Fetch the actual column name from the given binding, or throws if none exists //! This can be different from "column_name" because of case insensitivity //! (e.g. "column_name" might return "COLUMN_NAME") - string GetActualColumnName(const BindingAlias &binding_alias, const string &column_name); - string GetActualColumnName(Binding &binding, const string &column_name); + string GetActualColumnName(const BindingAlias &binding_alias, const Identifier &column_name); + string GetActualColumnName(Binding &binding, const Identifier &column_name); //! Alias a set of column names for the specified table, using the original names if there are not enough aliases //! specified. - static vector AliasColumnNames(const string &table_name, const vector &names, - const vector &column_aliases); + static vector AliasColumnNames(const Identifier &table_name, const vector &names, + const vector &column_aliases); //! 
Add all the bindings from a BindContext to this BindContext. The other BindContext is destroyed in the process. void AddContext(BindContext other); @@ -158,11 +159,11 @@ class BindContext { //! Gets a binding of the specified name. Returns a nullptr and sets the out_error if the binding could not be //! found. - optional_ptr GetBinding(const string &name, ErrorData &out_error); + optional_ptr GetBinding(const Identifier &name, ErrorData &out_error); optional_ptr GetBinding(const BindingAlias &alias, ErrorData &out_error); - optional_ptr GetBinding(const BindingAlias &alias, const string &column_name, ErrorData &out_error); + optional_ptr GetBinding(const BindingAlias &alias, const Identifier &column_name, ErrorData &out_error); //! Get all bindings that match a specific binding alias - returns an error if none match vector> GetBindings(const BindingAlias &alias, ErrorData &out_error); @@ -176,7 +177,7 @@ class BindContext { //! The list of bindings in insertion order vector> bindings_list; //! The set of columns used in USING join conditions - case_insensitive_map_t> using_columns; + identifier_map_t> using_columns; //! 
The set of CTE bindings vector> cte_bindings; }; diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 76b9dede1..725254973 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/merge_action_type.hpp" @@ -90,15 +91,15 @@ enum class BinderType : uint8_t { REGULAR_BINDER, VIEW_BINDER }; struct CorrelatedColumnInfo { ColumnBinding binding; LogicalType type; - string name; + Identifier name; idx_t depth; // NOLINTNEXTLINE - work-around bug in clang-tidy - CorrelatedColumnInfo(ColumnBinding binding, LogicalType type_p, string name_p, idx_t depth) + CorrelatedColumnInfo(ColumnBinding binding, LogicalType type_p, Identifier name_p, idx_t depth) : binding(binding), type(std::move(type_p)), name(std::move(name_p)), depth(depth) { } explicit CorrelatedColumnInfo(BoundColumnRefExpression &expr) - : CorrelatedColumnInfo(expr.binding, expr.GetReturnType(), expr.GetName(), expr.depth) { + : CorrelatedColumnInfo(expr.Binding(), expr.GetReturnType(), expr.GetName(), expr.Depth()) { } bool operator==(const CorrelatedColumnInfo &rhs) const { @@ -180,7 +181,7 @@ struct GlobalBinderState { //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. unordered_set table_names; //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS - case_insensitive_map_t> replacement_scans; + identifier_map_t> replacement_scans; //! Using column sets vector> using_column_sets; //! The set of parameter expressions bound by this binder @@ -190,13 +191,7 @@ struct GlobalBinderState { //! Set during CREATE TRIGGER body validation to detect self-recursive writes optional_ptr trigger_creation_table; //! 
Name of the trigger being created (for error messages) - string trigger_creation_name; -}; - -// QueryBinderState is state shared WITHIN a query, a new query-binder state is created when binding inside e.g. a view -struct QueryBinderState { - //! The vector of active binders - vector> active_binders; + Identifier trigger_creation_name; }; //! Bind the parsed query tree to the actual columns present in the catalog. @@ -206,6 +201,7 @@ struct QueryBinderState { all expressions. */ class Binder : public enable_shared_from_this { + friend class ColumnQualifier; friend class ExpressionBinder; friend class RecursiveDependentJoinPlanner; @@ -221,7 +217,7 @@ class Binder : public enable_shared_from_this { //! vector) CorrelatedColumns correlated_columns; //! The alias for the currently processing subquery, if it exists - string alias; + Identifier alias; //! Macro parameter bindings (if any) optional_ptr macro_binding; //! The intermediate lambda bindings to bind nested lambdas (if any) @@ -244,25 +240,25 @@ class Binder : public enable_shared_from_this { static vector> BindConstraints(ClientContext &context, const vector> &constraints, - const string &table_name, const ColumnList &columns); + const Identifier &table_name, const ColumnList &columns); vector> BindConstraints(const vector> &constraints, - const string &table_name, const ColumnList &columns); + const Identifier &table_name, const ColumnList &columns); vector> BindConstraints(const TableCatalogEntry &table); vector> BindNewConstraints(vector> &constraints, - const string &table_name, const ColumnList &columns); - unique_ptr BindConstraint(const Constraint &constraint, const string &table, + const Identifier &table_name, const ColumnList &columns); + unique_ptr BindConstraint(const Constraint &constraint, const Identifier &table, const ColumnList &columns); - unique_ptr BindUniqueConstraint(const Constraint &constraint, const string &table, + unique_ptr BindUniqueConstraint(const Constraint &constraint, const 
Identifier &table, const ColumnList &columns); BoundStatement BindAlterAddIndex(BoundStatement &result, CatalogEntry &entry, unique_ptr alter_info); void SetCatalogLookupCallback(catalog_entry_callback_t callback); void BindCreateViewInfo(CreateViewInfo &base); - static void BindView(ClientContext &context, const SelectStatement &stmt, const string &catalog_name, - const string &schema_name, optional_ptr dependencies, - const vector &aliases, vector &result_types, - vector &result_names); + static void BindView(ClientContext &context, const SelectStatement &stmt, const Identifier &catalog_name, + const Identifier &schema_name, optional_ptr dependencies, + const vector &aliases, vector &result_types, + vector &result_names); void SearchSchema(CreateInfo &info); SchemaCatalogEntry &BindSchema(CreateInfo &info); @@ -271,7 +267,7 @@ class Binder : public enable_shared_from_this { //! Check usage, and cast named parameters to their types static void BindNamedParameters(named_parameter_type_map_t &types, named_parameter_map_t &values, - QueryErrorContext &error_context, string &func_name); + QueryErrorContext &error_context, const Identifier &func_name); unique_ptr BindPragma(PragmaInfo &info, QueryErrorContext error_context); BoundStatement Bind(TableRef &ref); @@ -279,7 +275,7 @@ class Binder : public enable_shared_from_this { //! Generates an unused index for a table TableIndex GenerateTableIndex(); - optional_ptr GetCatalogEntry(const string &catalog, const string &schema, + optional_ptr GetCatalogEntry(const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found); //! Find all candidate common table expression by name; returns empty vector if none exists @@ -288,11 +284,10 @@ class Binder : public enable_shared_from_this { //! 
Add the view to the set of currently bound views - used for detecting recursive view definitions void AddBoundView(ViewCatalogEntry &view); - void PushExpressionBinder(ExpressionBinder &binder); - void PopExpressionBinder(); - void SetActiveBinder(ExpressionBinder &binder); + void BeginSubqueryBind(Binder &parent, ExpressionBinder &binder); ExpressionBinder &GetActiveBinder(); bool HasActiveBinder(); + void FinishSubqueryBind(); vector> &GetActiveBinders(); @@ -312,22 +307,24 @@ class Binder : public enable_shared_from_this { void BindVacuumTable(LogicalVacuum &vacuum, unique_ptr &root); - static void BindSchemaOrCatalog(ClientContext &context, string &catalog, string &schema); + static void BindSchemaOrCatalog(ClientContext &context, Identifier &catalog, Identifier &schema); void BindLogicalType(LogicalType &type); - optional_ptr GetMatchingBinding(const string &table_name, const string &column_name, ErrorData &error); - optional_ptr GetMatchingBinding(const string &schema_name, const string &table_name, - const string &column_name, ErrorData &error); - optional_ptr GetMatchingBinding(const string &catalog_name, const string &schema_name, - const string &table_name, const string &column_name, ErrorData &error); + optional_ptr GetMatchingBinding(const Identifier &table_name, const Identifier &column_name, + ErrorData &error); + optional_ptr GetMatchingBinding(const Identifier &schema_name, const Identifier &table_name, + const Identifier &column_name, ErrorData &error); + optional_ptr GetMatchingBinding(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name, const Identifier &column_name, + ErrorData &error); void SetBindingMode(BindingMode mode); BindingMode GetBindingMode(); void AddTableName(string table_name); - void AddReplacementScan(const string &table_name, unique_ptr replacement); + void AddReplacementScan(const Identifier &table_name, unique_ptr replacement); const unordered_set &GetTableNames(); - 
case_insensitive_map_t> &GetReplacementScans(); + identifier_map_t> &GetReplacementScans(); CatalogEntryRetriever &EntryRetriever() { return entry_retriever; } @@ -337,7 +334,10 @@ class Binder : public enable_shared_from_this { static optional_ptr GetResolvedColumnExpression(ParsedExpression &root_expr); void SetCanContainNulls(bool can_contain_nulls); + bool CanContainNulls() const; void SetAlwaysRequireRebind(); + void SetInsideSubquery(); + bool IsInsideSubquery() const; StatementProperties &GetStatementProperties(); static void ReplaceStarExpression(unique_ptr &expr, unique_ptr &replacement); @@ -346,10 +346,12 @@ class Binder : public enable_shared_from_this { unique_ptr UnionOperators(vector> nodes); - void SetSearchPath(Catalog &catalog, const string &schema); + void SetSearchPath(Catalog &catalog, const Identifier &schema); void BindDefaultValue(const ColumnDefinition &column, vector> &bound_defaults, const string &catalog = "", const string &schema = ""); + unique_ptr GetSQLValueFunction(const Identifier &column_name); + Identifier GetExpressionName(const ParsedExpression &expr); private: //! The parent binder (if any) @@ -358,14 +360,16 @@ class Binder : public enable_shared_from_this { BinderType binder_type = BinderType::REGULAR_BINDER; //! Global binder state shared_ptr global_binder_state; - //! Query binder state - shared_ptr query_binder_state; + //! Active binders + vector> active_binders; //! Whether or not the binder has any unplanned dependent joins that still need to be planned/flattened bool has_unplanned_dependent_joins = false; //! Whether or not outside dependent joins have been planned and flattened bool is_outside_flattened = true; - //! Whether or not the binder can contain NULLs as the root of expressions - bool can_contain_nulls = false; + //! LEGACY: Whether or not the binder can contain NULLs as the root of expressions + bool legacy_can_contain_nulls = false; + //! 
Whether this binder is inside a subquery boundary + bool inside_subquery = false; //! The set of bound views reference_set_t bound_views; //! Used to retrieve CatalogEntry's @@ -386,8 +390,8 @@ class Binder : public enable_shared_from_this { void BindDefaultValues(const ColumnList &columns, vector> &bound_defaults, const string &catalog = "", const string &schema = ""); //! Bind a limit value (LIMIT or OFFSET) - BoundLimitNode BindLimitValue(OrderBinder &order_binder, unique_ptr limit_val, bool is_percentage, - bool is_offset); + BoundLimitNode BindLimitValue(OrderBinder &order_binder, unique_ptr limit_val, + LimitValueType value_type, bool is_offset); //! Move correlated expressions from the child binder to this binder void MoveCorrelatedExpressions(Binder &other); @@ -395,8 +399,6 @@ class Binder : public enable_shared_from_this { //! Tries to bind the table name with replacement scans BoundStatement BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); - template - BoundStatement BindWithCTE(T &statement); BoundStatement Bind(SelectStatement &stmt); BoundStatement Bind(InsertStatement &stmt); BoundStatement Bind(CopyStatement &stmt, CopyToType copy_to_type); @@ -425,6 +427,8 @@ class Binder : public enable_shared_from_this { BoundStatement Bind(CopyDatabaseStatement &stmt); BoundStatement Bind(UpdateExtensionsStatement &stmt); BoundStatement Bind(MergeIntoStatement &stmt); + BoundStatement Bind(ConnectStatement &stmt); + BoundStatement Bind(DisconnectStatement &stmt); //! Resolves the base table for DROP TRIGGER, stamps catalog/schema onto stmt.info, //! and registers the catalog modification. IF EXISTS only guards the trigger, not the table. 
@@ -447,13 +451,13 @@ class Binder : public enable_shared_from_this { vector> &projection_expressions, LogicalOperator &target_binding); BoundStatement BindReturning(vector> returning_list, TableCatalogEntry &table, - const string &alias, TableIndex update_table_index, + const Identifier &alias, TableIndex update_table_index, unique_ptr child_operator, virtual_column_map_t virtual_columns = virtual_column_map_t()); unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - BoundStatement BindCTE(const string &ctename, CommonTableExpressionInfo &info); + BoundStatement BindCTE(const Identifier &ctename, CommonTableExpressionInfo &info); BoundStatement BindNode(SelectNode &node); BoundStatement BindNode(SetOperationNode &node); @@ -461,13 +465,14 @@ class Binder : public enable_shared_from_this { BoundStatement BindNode(QueryNode &node); BoundStatement BindNode(StatementNode &node); BoundStatement BindNode(InsertQueryNode &node); - unique_ptr TryExpandAfterTriggers(QueryNode &node, - vector> &returning_list, - TableCatalogEntry &table, TriggerEventType event_type); - BoundStatement ExpandAfterTriggers(QueryNode &node, vector> &returning_list, - const vector> &triggers); + unique_ptr TryExpandTriggers(QueryNode &node, vector> &returning_list, + TableCatalogEntry &table, TriggerEventType event_type); + BoundStatement ExpandTriggers(QueryNode &node, vector> &returning_list, + const vector> &before_triggers, + const vector> &after_triggers); BoundStatement BindNode(UpdateQueryNode &node); BoundStatement BindNode(DeleteQueryNode &node); + BoundStatement BindNode(MergeQueryNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); unique_ptr CreatePlan(BoundSelectNode &statement); @@ -495,7 +500,7 @@ class Binder : public enable_shared_from_this { unique_ptr &where_clause); BoundStatement BindBoundPivot(PivotRef &expr); void ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector 
&unpivot_entries); - void ExtractUnpivotColumnName(ParsedExpression &expr, vector &result); + void ExtractUnpivotColumnName(ParsedExpression &expr, vector &result); unique_ptr BindAtClause(optional_ptr at_clause); @@ -507,7 +512,8 @@ class Binder : public enable_shared_from_this { BoundStatement BindTableFunction(TableFunction &function, vector parameters); BoundStatement BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, vector parameters, named_parameter_map_t named_parameters, - vector input_table_types, vector input_table_names, + vector input_table_types, + vector input_table_names, optional_ptr> input_plan); unique_ptr CreatePlan(BoundJoinRef &ref); @@ -518,11 +524,10 @@ class Binder : public enable_shared_from_this { case_insensitive_map_t GetFullCopyOptionsList(const CopyFunction &function, CopyOptionMode mode); void PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, BoundQueryNode &result); - void BindModifiers(BoundQueryNode &result, TableIndex table_index, const vector &names, + void BindModifiers(BoundQueryNode &result, TableIndex table_index, const vector &names, const vector &sql_types, const SelectBindState &bind_state); unique_ptr BindLimit(OrderBinder &order_binder, LimitModifier &limit_mod); - unique_ptr BindLimitPercent(OrderBinder &order_binder, LimitPercentModifier &limit_mod); unique_ptr BindOrderExpression(OrderBinder &order_binder, unique_ptr expr); unique_ptr PlanFilter(unique_ptr condition, unique_ptr root); @@ -538,12 +543,12 @@ class Binder : public enable_shared_from_this { const vector &target_types, unique_ptr op); - BindingAlias FindBinding(const string &using_column, const string &join_side); - bool TryFindBinding(const string &using_column, const string &join_side, BindingAlias &result); + BindingAlias FindBinding(const Identifier &using_column, const string &join_side); + bool TryFindBinding(const Identifier &using_column, const string &join_side, BindingAlias &result); void 
AddUsingBindingSet(unique_ptr set); BindingAlias RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, - const string &column_name, const string &join_side); + const Identifier &column_name, const string &join_side); void ExpandStarExpressions(vector> &select_list, vector> &new_select_list); @@ -556,48 +561,47 @@ class Binder : public enable_shared_from_this { void BindWhereStarExpression(unique_ptr &expr); //! If only a schema name is provided (e.g. "a.b") then figure out if "a" is a schema or a catalog name - void BindSchemaOrCatalog(string &catalog_name, string &schema_name); - static void BindSchemaOrCatalog(CatalogEntryRetriever &retriever, string &catalog, string &schema); - const string BindCatalog(string &catalog_name); + void BindSchemaOrCatalog(Identifier &catalog_name, Identifier &schema_name); + static void BindSchemaOrCatalog(CatalogEntryRetriever &retriever, Identifier &catalog, Identifier &schema); + Identifier BindCatalog(const Identifier &catalog_name); SchemaCatalogEntry &BindCreateSchema(CreateInfo &info); - vector GetSearchPath(Catalog &catalog, const string &schema_name); + vector GetSearchPath(Catalog &catalog, const Identifier &schema_name); LogicalType BindLogicalTypeInternal(const unique_ptr &type_expr); BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); - unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); - unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); + unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const Identifier &target_database_name); + unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const Identifier &target_database_name); BoundStatement BindShowQuery(ShowRef &ref); BoundStatement BindShowTable(ShowRef &ref); BoundStatement BindSummarize(ShowRef &ref); - void BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, + void 
BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, vector &named_column_map, vector &expected_types, IndexVector &column_index_map); void TryReplaceDefaultExpression(unique_ptr &expr, const ColumnDefinition &column); void ExpandDefaultInValuesList(InsertQueryNode &node, TableCatalogEntry &table, optional_ptr values_list, const vector &named_column_map); - unique_ptr BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, - LogicalGet &get, TableIndex proj_index, - vector> &expressions, - unique_ptr &root, MergeIntoAction &action, - const vector &source_aliases, - const vector &source_names, MergeActionCondition condition, - const unordered_set &source_table_indices); + unique_ptr + BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, LogicalGet &get, TableIndex proj_index, + vector> &expressions, MergeIntoAction &action, + const vector &source_aliases, const vector &source_names); unique_ptr GenerateMergeInto(InsertQueryNode &node, TableCatalogEntry &table); static void CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool columns_provided, - const string &tname); + const Identifier &tname); - BoundCTEData PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement); + BoundCTEData PrepareCTE(const Identifier &ctename, CommonTableExpressionInfo &statement); BoundStatement FinishCTE(BoundCTEData &bound_cte, BoundStatement child_data); - shared_ptr CreateBinderWithSearchPath(const string &catalog_name, const string &schema_name); + shared_ptr CreateBinderWithSearchPath(const Identifier &catalog_name, const Identifier &schema_name); + + bool DebugAggregateStateExportVerify(BoundSelectNode &statement, unique_ptr &root); private: Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); diff --git a/src/duckdb/src/include/duckdb/planner/binding_alias.hpp b/src/duckdb/src/include/duckdb/planner/binding_alias.hpp index 9d75738bc..803afd732 100644 
--- a/src/duckdb/src/include/duckdb/planner/binding_alias.hpp +++ b/src/duckdb/src/include/duckdb/planner/binding_alias.hpp @@ -9,24 +9,25 @@ #pragma once #include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { class StandardEntry; struct BindingAlias { BindingAlias(); - explicit BindingAlias(string alias); - BindingAlias(string schema, string alias); - BindingAlias(string catalog, string schema, string alias); + explicit BindingAlias(Identifier alias); + BindingAlias(Identifier schema, Identifier alias); + BindingAlias(Identifier catalog, Identifier schema, Identifier alias); explicit BindingAlias(const StandardEntry &entry); bool IsSet() const; - const string &GetAlias() const; + const Identifier &GetAlias() const; - const string &GetCatalog() const { + const Identifier &GetCatalog() const { return catalog; } - const string &GetSchema() const { + const Identifier &GetSchema() const { return schema; } @@ -35,9 +36,9 @@ struct BindingAlias { string ToString() const; private: - string catalog; - string schema; - string alias; + Identifier catalog; + Identifier schema; + Identifier alias; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp b/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp index 33ce5cf00..a5e3cbdc2 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_parameter_map.hpp @@ -18,20 +18,20 @@ namespace duckdb { class ParameterExpression; class BoundParameterExpression; -using bound_parameter_map_t = case_insensitive_map_t>; +using bound_parameter_map_t = identifier_map_t>; struct BoundParameterMap { public: - explicit BoundParameterMap(case_insensitive_map_t ¶meter_data); + explicit BoundParameterMap(identifier_map_t ¶meter_data); public: - LogicalType GetReturnType(const string &identifier); + LogicalType GetReturnType(const Identifier &identifier); 
bound_parameter_map_t *GetParametersPtr(); const bound_parameter_map_t &GetParameters(); - const case_insensitive_map_t &GetParameterData(); + const identifier_map_t &GetParameterData(); unique_ptr BindParameterExpression(ParameterExpression &expr); @@ -39,13 +39,13 @@ struct BoundParameterMap { bool rebind = false; private: - shared_ptr CreateOrGetData(const string &identifier); + shared_ptr CreateOrGetData(const Identifier &identifier); void CreateNewParameter(const string &id, const shared_ptr ¶m_data); private: bound_parameter_map_t parameters; // Pre-provided parameter data if populated - case_insensitive_map_t ¶meter_data; + identifier_map_t ¶meter_data; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp index 898a0161f..22a3f4628 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp @@ -23,7 +23,7 @@ class BoundQueryNode { vector> modifiers; //! The names returned by this QueryNode. - vector names; + vector names; //! The types returned by this QueryNode. 
vector types; diff --git a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp index 23fae54d6..474399aa7 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp @@ -32,7 +32,7 @@ struct ExtraBoundInfo { struct BoundStatement { unique_ptr plan; vector types; - vector names; + vector names; ExtraBoundInfo extra_info; }; diff --git a/src/duckdb/src/include/duckdb/planner/column_qualifier.hpp b/src/duckdb/src/include/duckdb/planner/column_qualifier.hpp new file mode 100644 index 000000000..e06f053b8 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/column_qualifier.hpp @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/column_qualifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" + +namespace duckdb { +class ColumnAliasBinder; +class HavingBinder; + +class ColumnQualifier { +public: + explicit ColumnQualifier(Binder &binder, optional_ptr> lambda_bindings = nullptr, + optional_ptr alias_binder = nullptr, + optional_ptr having_binder = nullptr); + virtual ~ColumnQualifier() = default; + + //! Returns the STRUCT_EXTRACT operator expression + unique_ptr CreateStructExtract(unique_ptr base, const Identifier &field_name); + //! Returns a STRUCT_PACK function expression + unique_ptr CreateStructPack(ColumnRefExpression &col_ref); + + //! Returns a qualified column reference from a column name + unique_ptr QualifyColumnName(const ParsedExpression &expr, const Identifier &column_name, + ErrorData &error); + //! Returns a qualified column reference from a column reference with column_names.size() > 2 + unique_ptr QualifyColumnNameWithManyDots(ColumnRefExpression &col_ref, ErrorData &error); + //! 
Returns a qualified column reference from a column reference + unique_ptr QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error); + //! Enables special-handling of lambda parameters by tracking them in the lambda_params vector + void QualifyColumnNamesInLambda(FunctionExpression &function, vector &lambda_params); + //! Recursively qualifies the column references in the (children) of the expression. Passes on the + //! within_function_expression state from outer expressions, or sets it + void QualifyColumnNames(unique_ptr &expr, vector &lambda_params, + const bool within_function_expression = false); + + optional_ptr QualifyFunction(FunctionExpression &function); + +private: + Binder &binder; + optional_ptr> lambda_bindings; + optional_ptr alias_binder; + optional_ptr having_binder; + +private: + unique_ptr QualifyColumnNameWithManyDotsInternal(ColumnRefExpression &col_ref, ErrorData &error, + idx_t &struct_extract_start); + + unique_ptr QualifyColumnNameInternal(ColumnRefExpression &col_ref, ErrorData &error); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp index d5a5b5a90..343a06b67 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_aggregate_expression.hpp @@ -23,25 +23,54 @@ class BoundAggregateExpression : public Expression { unique_ptr filter, unique_ptr bind_info, AggregateType aggr_type); - //! The bound function expression - BoundAggregateFunction function; - //! List of arguments to the function - vector> children; - //! The bound function data (if any) - unique_ptr bind_info; - //! The aggregate type (distinct or non-distinct) - AggregateType aggr_type; - - //! Filter for this aggregate - unique_ptr filter; - //! 
The order by expression for this aggregate - if any - unique_ptr order_bys; - public: bool IsDistinct() const { return aggr_type == AggregateType::DISTINCT; } + const BoundAggregateFunction &Function() const { + return function; + } + BoundAggregateFunction &FunctionMutable() { + return function; + } + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } + const unique_ptr &BindInfo() const { + return bind_info; + } + unique_ptr &BindInfoMutable() { + return bind_info; + } + AggregateType GetAggregateType() const { + return aggr_type; + } + AggregateType &GetAggregateTypeMutable() { + return aggr_type; + } + const unique_ptr &GetFilter() const { + return filter; + } + unique_ptr &GetFilterMutable() { + return filter; + } + const unique_ptr &GetOrderBys() const { + return order_bys; + } + unique_ptr &GetOrderBysMutable() { + return order_bys; + } + AggregateStateExportMode StateExportMode() const { + return state_export_mode; + } + AggregateStateExportMode &StateExportModeMutable() { + return state_export_mode; + } + bool IsAggregate() const override { return true; } @@ -58,5 +87,22 @@ class BoundAggregateExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! The bound function expression + BoundAggregateFunction function; + //! List of arguments to the function + vector> children; + //! The bound function data (if any) + unique_ptr bind_info; + //! The aggregate type (distinct or non-distinct) + AggregateType aggr_type; + //! Whether or not we are exporting state in this aggregate + AggregateStateExportMode state_export_mode; + + //! Filter for this aggregate + unique_ptr filter; + //! 
The order by expression for this aggregate - if any + unique_ptr order_bys; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp index fc554019b..cba22c14c 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_case_expression.hpp @@ -29,8 +29,19 @@ class BoundCaseExpression : public Expression { BoundCaseExpression(unique_ptr when_expr, unique_ptr then_expr, unique_ptr else_expr); - vector case_checks; - unique_ptr else_expr; +public: + const vector &CaseChecks() const { + return case_checks; + } + vector &CaseChecksMutable() { + return case_checks; + } + const Expression &Else() const { + return *else_expr; + } + unique_ptr &ElseMutable() { + return else_expr; + } public: string ToString() const override; @@ -41,5 +52,9 @@ class BoundCaseExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + vector case_checks; + unique_ptr else_expr; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp index c48ba6bda..04f530898 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_cast_expression.hpp @@ -21,14 +21,25 @@ class BoundCastExpression : public Expression { BoundCastExpression(unique_ptr child, LogicalType target_type, BoundCastInfo bound_cast, bool try_cast = false); - //! The child type - unique_ptr child; - //! Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of throwing an error. - bool try_cast; - //! 
The bound cast info - BoundCastInfo bound_cast; - public: + const LogicalType &TargetType() const { + return return_type; + } + const Expression &Child() const { + return *child; + } + unique_ptr &ChildMutable() { + return child; + } + bool IsTryCast() const { + return try_cast; + } + const BoundCastInfo &GetBoundCast() const { + return bound_cast; + } + BoundCastInfo &GetBoundCastMutable() { + return bound_cast; + } LogicalType source_type() const { // NOLINT: allow casing for legacy reasons D_ASSERT(child->GetReturnType().IsValid()); return child->GetReturnType(); @@ -61,5 +72,13 @@ class BoundCastExpression : public Expression { private: BoundCastExpression(ClientContext &context, unique_ptr child, LogicalType target_type); + +private: + //! The child type + unique_ptr child; + //! Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of throwing an error. + bool try_cast; + //! The bound cast info + BoundCastInfo bound_cast; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp index 3f6e79a4e..c4614ec59 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_columnref_expression.hpp @@ -22,13 +22,7 @@ class BoundColumnRefExpression : public Expression { public: BoundColumnRefExpression(LogicalType type, ColumnBinding binding, idx_t depth = 0); - BoundColumnRefExpression(string alias, LogicalType type, ColumnBinding binding, idx_t depth = 0); - - //! Column index set by the binder, used to generate the final BoundExpression - ColumnBinding binding; - //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). - //! This is only non-zero for correlated expressions inside subqueries. 
- idx_t depth; + BoundColumnRefExpression(Identifier alias, LogicalType type, ColumnBinding binding, idx_t depth = 0); public: bool IsScalar() const override { @@ -38,8 +32,21 @@ class BoundColumnRefExpression : public Expression { return false; } + const ColumnBinding &Binding() const { + return binding; + } + ColumnBinding &BindingMutable() { + return binding; + } + idx_t Depth() const { + return depth; + } + idx_t &DepthMutable() { + return depth; + } + string ToString() const override; - string GetName() const override; + Identifier GetName() const override; bool Equals(const BaseExpression &other) const override; hash_t Hash() const override; @@ -48,5 +55,12 @@ class BoundColumnRefExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! Column index set by the binder, used to generate the final BoundExpression + ColumnBinding binding; + //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). + //! This is only non-zero for correlated expressions inside subqueries. 
+ idx_t depth; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp index e83f78f57..c99a7d153 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_conjunction_expression.hpp @@ -20,9 +20,13 @@ class BoundConjunctionExpression : public Expression { explicit BoundConjunctionExpression(ExpressionType type); BoundConjunctionExpression(ExpressionType type, unique_ptr left, unique_ptr right); - vector> children; - public: + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } string ToString() const override; bool Equals(const BaseExpression &other) const override; @@ -33,5 +37,8 @@ class BoundConjunctionExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + vector> children; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp index 40b8ac8af..270f98d68 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_constant_expression.hpp @@ -20,9 +20,13 @@ class BoundConstantExpression : public Expression { public: explicit BoundConstantExpression(Value value); - Value value; - public: + const Value &GetValue() const { + return value; + } + Value &GetValueMutable() { + return value; + } string ToString() const override; bool Equals(const BaseExpression &other) const override; @@ -32,5 +36,8 @@ class BoundConstantExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr 
Deserialize(Deserializer &deserializer); + +private: + Value value; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_expanded_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_expanded_expression.hpp index 390e641c0..a777e2409 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_expanded_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_expanded_expression.hpp @@ -21,14 +21,21 @@ class BoundExpandedExpression : public Expression { public: explicit BoundExpandedExpression(vector> expanded_expressions); - vector> expanded_expressions; - public: + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } string ToString() const override; bool Equals(const BaseExpression &other) const override; unique_ptr Copy() const override; + +private: + vector> children; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp index 18509ef70..29e1b17a9 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_function_expression.hpp @@ -23,16 +23,32 @@ class BoundFunctionExpression : public Expression { BoundFunctionExpression(BoundScalarFunction bound_function, vector> arguments, unique_ptr bind_info, bool is_operator = false); - //! The bound function expression - BoundScalarFunction function; - //! List of child-expressions of the function - vector> children; - //! The bound function data (if any) - unique_ptr bind_info; - //! 
Whether or not the function is an operator, only used for rendering - bool is_operator; - public: + const BoundScalarFunction &Function() const { + return function; + } + BoundScalarFunction &FunctionMutable() { + return function; + } + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } + const unique_ptr &BindInfo() const { + return bind_info; + } + unique_ptr &BindInfoMutable() { + return bind_info; + } + bool IsOperator() const { + return is_operator; + } + bool &IsOperatorMutable() { + return is_operator; + } + bool IsVolatile() const override; bool IsConsistent() const override; bool IsFoldable() const override; @@ -52,6 +68,16 @@ class BoundFunctionExpression : public Expression { static ExpressionType GetFunctionExpressionType(const BoundScalarFunction &bound_function, const vector> &arguments, optional_ptr bind_info); + +private: + //! The bound function expression + BoundScalarFunction function; + //! List of child-expressions of the function + vector> children; + //! The bound function data (if any) + unique_ptr bind_info; + //! Whether or not the function is an operator, only used for rendering + bool is_operator; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp index 8206ee9c8..79384137d 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_lambda_expression.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/planner/expression.hpp" -#include "duckdb/parser/expression/lambda_expression.hpp" namespace duckdb { @@ -21,20 +20,40 @@ class BoundLambdaExpression : public Expression { BoundLambdaExpression(ExpressionType type_p, LogicalType return_type_p, unique_ptr lambda_expr_p, idx_t parameter_count_p); - //! 
The lambda expression that we'll use in the expression executor during execution - unique_ptr lambda_expr; - //! Non-lambda constants, column references, and outer lambda parameters that we need to pass - //! into the execution chunk - vector> captures; - //! The number of lhs parameters of the lambda function - idx_t parameter_count; - public: + const unique_ptr &LambdaExpr() const { + return lambda_expr; + } + unique_ptr &LambdaExprMutable() { + return lambda_expr; + } + const vector> &Captures() const { + return captures; + } + vector> &CapturesMutable() { + return captures; + } + idx_t ParameterCount() const { + return parameter_count; + } + idx_t &ParameterCountMutable() { + return parameter_count; + } + string ToString() const override; bool Equals(const BaseExpression &other) const override; unique_ptr Copy() const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! The lambda expression that we'll use in the expression executor during execution + unique_ptr lambda_expr; + //! Non-lambda constants, column references, and outer lambda parameters that we need to pass + //! into the execution chunk + vector> captures; + //! 
The number of lhs parameters of the lambda function + idx_t parameter_count; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp index f3fd9f20c..c13560e6d 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_lambdaref_expression.hpp @@ -23,16 +23,28 @@ class BoundLambdaRefExpression : public Expression { public: BoundLambdaRefExpression(LogicalType type, ColumnBinding binding, idx_t lambda_idx, idx_t depth = 0); - BoundLambdaRefExpression(string alias, LogicalType type, ColumnBinding binding, idx_t lambda_idx, idx_t depth = 0); - //! Column index set by the binder, used to generate the final BoundExpression - ColumnBinding binding; - //! The index of the lambda parameter in the lambda bindings vector - idx_t lambda_idx; - //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). - //! This is only non-zero for correlated expressions inside subqueries. - idx_t depth; + BoundLambdaRefExpression(Identifier alias, LogicalType type, ColumnBinding binding, idx_t lambda_idx, + idx_t depth = 0); public: + const ColumnBinding &Binding() const { + return binding; + } + ColumnBinding &BindingMutable() { + return binding; + } + idx_t LambdaIndex() const { + return lambda_idx; + } + idx_t &LambdaIndexMutable() { + return lambda_idx; + } + idx_t Depth() const { + return depth; + } + idx_t &DepthMutable() { + return depth; + } bool IsScalar() const override { return false; } @@ -49,5 +61,14 @@ class BoundLambdaRefExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! Column index set by the binder, used to generate the final BoundExpression + ColumnBinding binding; + //! 
The index of the lambda parameter in the lambda bindings vector + idx_t lambda_idx; + //! The subquery depth (i.e. depth 0 = current query, depth 1 = parent query, depth 2 = parent of parent, etc...). + //! This is only non-zero for correlated expressions inside subqueries. + idx_t depth; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp index c0a2e2e0d..84cb9936c 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_operator_expression.hpp @@ -19,9 +19,13 @@ class BoundOperatorExpression : public Expression { public: BoundOperatorExpression(ExpressionType type, LogicalType return_type); - vector> children; - public: + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } string ToString() const override; bool Equals(const BaseExpression &other) const override; @@ -30,5 +34,8 @@ class BoundOperatorExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + vector> children; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp index 92db0a041..ed472fdd3 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_parameter_expression.hpp @@ -18,12 +18,22 @@ class BoundParameterExpression : public Expression { static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_PARAMETER; public: - explicit BoundParameterExpression(const string &identifier); - - string identifier; - shared_ptr parameter_data; + explicit 
BoundParameterExpression(const duckdb::Identifier &identifier); public: + const duckdb::Identifier &Identifier() const { + return identifier; + } + duckdb::Identifier &IdentifierMutable() { + return identifier; + } + const shared_ptr &ParameterData() const { + return parameter_data; + } + shared_ptr &ParameterDataMutable() { + return parameter_data; + } + //! Invalidate a bound parameter expression - forcing a rebind on any subsequent filters DUCKDB_API static void Invalidate(Expression &expr); //! Invalidate all parameters within an expression @@ -44,8 +54,12 @@ class BoundParameterExpression : public Expression { static unique_ptr Deserialize(Deserializer &deserializer); private: - BoundParameterExpression(bound_parameter_map_t &global_parameter_set, string identifier, LogicalType return_type, - shared_ptr parameter_data); + BoundParameterExpression(bound_parameter_map_t &global_parameter_set, duckdb::Identifier identifier, + LogicalType return_type, shared_ptr parameter_data); + +private: + duckdb::Identifier identifier; + shared_ptr parameter_data; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp index 23c894a50..f5c5a3a05 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_reference_expression.hpp @@ -18,13 +18,16 @@ class BoundReferenceExpression : public Expression { static constexpr const ExpressionClass TYPE = ExpressionClass::BOUND_REF; public: - BoundReferenceExpression(string alias, LogicalType type, idx_t index); + BoundReferenceExpression(Identifier alias, LogicalType type, idx_t index); BoundReferenceExpression(LogicalType type, storage_t index); - //! 
Index used to access data in the chunks - storage_t index; - public: + idx_t Index() const { + return index; + } + idx_t &IndexMutable() { + return index; + } bool IsScalar() const override { return false; } @@ -41,5 +44,9 @@ class BoundReferenceExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! Index used to access data in the chunks + storage_t index; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp index 35792c8d4..6553086d1 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp @@ -26,23 +26,49 @@ class BoundSubqueryExpression : public Expression { return !binder->correlated_columns.empty(); } - //! The binder used to bind the subquery node - shared_ptr binder; - //! The bound subquery node - BoundStatement subquery; - //! The subquery type - SubqueryType subquery_type; - //! the child expressions to compare with (in case of IN, ANY, ALL operators) - vector> children; - //! The comparison type of the child expression with the subquery (in case of ANY, ALL operators) - ExpressionType comparison_type; - //! The LogicalTypes of the subquery result. Only used for ANY expressions. - vector child_types; - //! The target LogicalType of the subquery result (i.e. to which type it should be casted, if child_type <> - //! child_target). Only used for ANY expressions. 
- vector child_targets; - public: + const shared_ptr &GetBinder() const { + return binder; + } + shared_ptr &GetBinderMutable() { + return binder; + } + const BoundStatement &Subquery() const { + return subquery; + } + BoundStatement &SubqueryMutable() { + return subquery; + } + SubqueryType GetSubqueryType() const { + return subquery_type; + } + SubqueryType &SubqueryTypeMutable() { + return subquery_type; + } + const vector> &GetChildren() const { + return children; + } + vector> &GetChildrenMutable() { + return children; + } + const vector &GetChildTypes() const { + return child_types; + } + vector &ChildTypesMutable() { + return child_types; + } + const vector &GetChildTargets() const { + return child_targets; + } + vector &ChildTargetsMutable() { + return child_targets; + } + ExpressionType ComparisonType() const { + return comparison_type; + } + ExpressionType &ComparisonTypeMutable() { + return comparison_type; + } bool HasSubquery() const override { return true; } @@ -60,5 +86,22 @@ class BoundSubqueryExpression : public Expression { unique_ptr Copy() const override; bool PropagatesNullValues() const override; + +private: + //! The binder used to bind the subquery node + shared_ptr binder; + //! The bound subquery node + BoundStatement subquery; + //! The subquery type + SubqueryType subquery_type; + //! the child expressions to compare with (in case of IN, ANY, ALL operators) + vector> children; + //! The comparison type of the child expression with the subquery (in case of ANY, ALL operators) + ExpressionType comparison_type; + //! The LogicalTypes of the subquery result. Only used for ANY expressions. + vector child_types; + //! The target LogicalType of the subquery result (i.e. to which type it should be casted, if child_type <> + //! child_target). Only used for ANY expressions. 
+ vector child_targets; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp index 51906548c..7e82cfc9a 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_unnest_expression.hpp @@ -20,9 +20,13 @@ class BoundUnnestExpression : public Expression { public: explicit BoundUnnestExpression(LogicalType return_type); - unique_ptr child; - public: + const unique_ptr &Child() const { + return child; + } + unique_ptr &ChildMutable() { + return child; + } bool IsFoldable() const override; string ToString() const override; @@ -33,5 +37,8 @@ class BoundUnnestExpression : public Expression { void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); + +private: + unique_ptr child; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp index 522c70972..51422739e 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_window_expression.hpp @@ -26,6 +26,139 @@ class BoundWindowExpression : public Expression { BoundWindowExpression(LogicalType return_type, unique_ptr aggregate, unique_ptr window, unique_ptr bind_info); +public: + const unique_ptr &AggregateFunction() const { + return aggregate; + } + unique_ptr &AggregateFunctionMutable() { + return aggregate; + } + const unique_ptr &WindowFunction() const { + return window; + } + unique_ptr &WindowFunctionMutable() { + return window; + } + const unique_ptr &BindInfo() const { + return bind_info; + } + unique_ptr &BindInfoMutable() { + return bind_info; + } + const vector> &GetChildren() const { + return children; 
+ } + vector> &GetChildrenMutable() { + return children; + } + const vector> &Partitions() const { + return partitions; + } + vector> &PartitionsMutable() { + return partitions; + } + const vector &OrderBy() const { + return orders; + } + vector &OrderByMutable() { + return orders; + } + const unique_ptr &Filter() const { + return filter_expr; + } + unique_ptr &FilterMutable() { + return filter_expr; + } + bool IgnoreNulls() const { + return ignore_nulls; + } + bool &IgnoreNullsMutable() { + return ignore_nulls; + } + bool Distinct() const { + return distinct; + } + bool &DistinctMutable() { + return distinct; + } + WindowBoundary WindowStart() const { + return start; + } + WindowBoundary &WindowStartMutable() { + return start; + } + WindowBoundary WindowEnd() const { + return end; + } + WindowBoundary &WindowEndMutable() { + return end; + } + WindowExcludeMode WindowExclude() const { + return exclude_clause; + } + WindowExcludeMode &WindowExcludeMutable() { + return exclude_clause; + } + const unique_ptr &StartExpr() const { + return start_expr; + } + unique_ptr &StartExprMutable() { + return start_expr; + } + const unique_ptr &EndExpr() const { + return end_expr; + } + unique_ptr &EndExprMutable() { + return end_expr; + } + const vector &ArgOrders() const { + return arg_orders; + } + vector &ArgOrdersMutable() { + return arg_orders; + } + const vector> &PartitionsStats() const { + return partitions_stats; + } + vector> &PartitionsStatsMutable() { + return partitions_stats; + } + const vector> &ExprStats() const { + return expr_stats; + } + vector> &ExprStatsMutable() { + return expr_stats; + } + + bool IsWindow() const override { + return true; + } + bool IsFoldable() const override { + return false; + } + + string ToString() const override; + + //! 
The number of ordering clauses the functions share + static idx_t GetSharedOrders(const vector &lhs, const vector &rhs); + idx_t GetSharedOrders(const BoundWindowExpression &other) const; + + bool PartitionsAreEquivalent(const BoundWindowExpression &other) const; + bool KeysAreCompatible(const BoundWindowExpression &other) const; + bool Equals(const BaseExpression &other) const override; + + unique_ptr Copy() const override; + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &deserializer); + +private: + //! Remove LEAD/LAG offset/default + vector> SerializedChildren(Serializer &serializer) const; + unique_ptr SerializedOffset(Serializer &serializer) const; + unique_ptr SerializedDefault(Serializer &serializer) const; + +private: //! The bound aggregate function unique_ptr aggregate; //! The bound window function @@ -62,35 +195,6 @@ class BoundWindowExpression : public Expression { //! Statistics belonging to the other expressions (start, end, offset, default) vector> expr_stats; - -public: - bool IsWindow() const override { - return true; - } - bool IsFoldable() const override { - return false; - } - - string ToString() const override; - - //! The number of ordering clauses the functions share - static idx_t GetSharedOrders(const vector &lhs, const vector &rhs); - idx_t GetSharedOrders(const BoundWindowExpression &other) const; - - bool PartitionsAreEquivalent(const BoundWindowExpression &other) const; - bool KeysAreCompatible(const BoundWindowExpression &other) const; - bool Equals(const BaseExpression &other) const override; - - unique_ptr Copy() const override; - - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); - -private: - //! 
Remove LEAD/LAG offset/default - vector> SerializedChildren(Serializer &serializer) const; - unique_ptr SerializedOffset(Serializer &serializer) const; - unique_ptr SerializedDefault(Serializer &serializer) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp index 7d24fadde..15a6914d7 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp @@ -34,14 +34,17 @@ class ScalarFunctionCatalogEntry; class AggregateFunctionCatalogEntry; class WindowFunctionCatalogEntry; class ScalarMacroCatalogEntry; +class ScalarMacroFunction; class CatalogEntry; class SimpleFunction; +class HavingBinder; +class ColumnAliasBinder; struct DummyBinding; struct SelectBindState; struct BoundColumnReferenceInfo { - string name; + Identifier name; optional_idx query_location; }; @@ -72,7 +75,7 @@ class ExpressionBinder { friend class StackChecker; public: - ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder = false); + ExpressionBinder(Binder &binder, ClientContext &context); virtual ~ExpressionBinder(); virtual bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, @@ -80,10 +83,14 @@ class ExpressionBinder { return false; } - virtual bool DoesColumnAliasExist(const ColumnRefExpression &colref) { + virtual bool IsLateralBinder() const { return false; } + Binder &GetBinder() const { + return binder; + } + // Returns true if the ColumnRef could be an alias reference (unqualified or qualified with table name "alias") static bool IsPotentialAlias(const ColumnRefExpression &colref); @@ -105,33 +112,27 @@ class ExpressionBinder { const vector &GetBoundColumns() { return bound_columns; } + void TruncateBoundColumns(idx_t count) { + if (bound_columns.size() > count) { + bound_columns.resize(count); + } + } void 
SetCatalogLookupCallback(catalog_entry_callback_t callback); ErrorData Bind(unique_ptr &expr, idx_t depth, bool root_expression = false); //! Returns the STRUCT_EXTRACT operator expression - unique_ptr CreateStructExtract(unique_ptr base, const string &field_name); + unique_ptr CreateStructExtract(unique_ptr base, const Identifier &field_name); //! Returns a STRUCT_PACK function expression unique_ptr CreateStructPack(ColumnRefExpression &col_ref); - BindResult BindQualifiedColumnName(ColumnRefExpression &colref, const string &table_name); + BindResult BindQualifiedColumnName(ColumnRefExpression &colref, const Identifier &table_name); - //! Returns a qualified column reference from a column name - unique_ptr QualifyColumnName(const ParsedExpression &expr, const string &column_name, - ErrorData &error); - //! Returns a qualified column reference from a column reference with column_names.size() > 2 - unique_ptr QualifyColumnNameWithManyDots(ColumnRefExpression &col_ref, ErrorData &error); - //! Returns a qualified column reference from a column reference - virtual unique_ptr QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error); - //! Enables special-handling of lambda parameters by tracking them in the lambda_params vector - void QualifyColumnNamesInLambda(FunctionExpression &function, vector> &lambda_params); - //! Recursively qualifies the column references in the (children) of the expression. Passes on the - //! within_function_expression state from outer expressions, or sets it - void QualifyColumnNames(unique_ptr &expr, vector> &lambda_params, - const bool within_function_expression = false); //! 
Entry point for qualifying the column references of the expression - static void QualifyColumnNames(Binder &binder, unique_ptr &expr); + static void QualifyColumnNames(Binder &binder, unique_ptr &expr, + optional_ptr alias_binder = nullptr); static void QualifyColumnNames(ExpressionBinder &binder, unique_ptr &expr); + static void QualifyColumnNames(HavingBinder &having_binder, unique_ptr &expr); static bool PushCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, CollationType type = CollationType::ALL_COLLATIONS); @@ -154,15 +155,17 @@ class ExpressionBinder { //! FIXME: Generalise this for extensibility. //! Recursively replaces macro parameters with the provided input parameters. - void ReplaceMacroParameters(unique_ptr &expr, vector> &lambda_params); + void ReplaceMacroParameters(unique_ptr &expr, vector &lambda_params); //! Enables special-handling of lambda parameters during macro replacement by tracking them in the lambda_params //! vector. - void ReplaceMacroParametersInLambda(FunctionExpression &function, vector> &lambda_params); + void ReplaceMacroParametersInLambda(FunctionExpression &function, vector &lambda_params); static LogicalType GetExpressionReturnType(const Expression &expr); + virtual unique_ptr QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error); + //! Returns true if the function name is an alias for the UNNEST function - static bool IsUnnestFunction(const string &function_name); + static bool IsUnnestFunction(const Identifier &function_name); private: //! 
Current stack depth @@ -204,7 +207,7 @@ class ExpressionBinder { const optional_ptr bind_lambda_context, const vector &function_child_types); - virtual unique_ptr GetSQLValueFunction(const string &column_name); + unique_ptr GetSQLValueFunction(const Identifier &column_name); LogicalType ResolveOperatorType(OperatorExpression &op, vector> &children); LogicalType ResolveCoalesceType(OperatorExpression &op, vector> &children); @@ -212,7 +215,7 @@ class ExpressionBinder { BindResult BindUnsupportedExpression(ParsedExpression &expr, idx_t depth, const string &message); - optional_ptr BindAndQualifyFunction(FunctionExpression &function, bool allow_throw); + CatalogEntry &BindFunction(FunctionExpression &function); protected: virtual BindResult BindGroupingFunction(OperatorExpression &op, idx_t depth); @@ -223,25 +226,25 @@ class ExpressionBinder { virtual BindResult BindUnnest(FunctionExpression &expr, idx_t depth, bool root_expression); virtual BindResult BindMacro(FunctionExpression &expr, ScalarMacroCatalogEntry ¯o, idx_t depth, unique_ptr &expr_ptr); + void FindAggregateExprs(unique_ptr &expr, vector>> &exprs); + void UnfoldWindowMacroExpression(unique_ptr &expr, ScalarMacroFunction ¯o_def); void UnfoldMacroExpression(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, unique_ptr &expr, idx_t depth); virtual string UnsupportedAggregateMessage(); virtual string UnsupportedWindowMessage(); virtual string UnsupportedUnnestMessage(); - optional_ptr GetCatalogEntry(const string &catalog, const string &schema, + optional_ptr GetCatalogEntry(const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found); Binder &binder; ClientContext &context; - optional_ptr stored_binder; vector bound_columns; + bool inside_try = false; BindResult TryBindLambdaOrJson(FunctionExpression &function, idx_t depth, CatalogEntry &func, const LambdaSyntaxType syntax_type); - unique_ptr 
QualifyColumnNameWithManyDotsInternal(ColumnRefExpression &col_ref, ErrorData &error, - idx_t &struct_extract_start); virtual void ThrowIfUnnestInLambda(const ColumnBinding &column_binding); }; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp deleted file mode 100644 index 84075d9f4..000000000 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/aggregate_binder.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/expression_binder/aggregate_binder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/expression_binder.hpp" - -namespace duckdb { - -//! The AggregateBinder is responsible for binding aggregate statements extracted from a SELECT clause (by the -//! SelectBinder) -class AggregateBinder : public ExpressionBinder { - friend class SelectBinder; - -public: - AggregateBinder(Binder &binder, ClientContext &context); - -protected: - BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, - bool root_expression = false) override; - - string UnsupportedAggregateMessage() override; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp index 4c551e47e..06dd83125 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/base_select_binder.hpp @@ -18,17 +18,13 @@ class WindowExpression; class BoundSelectNode; -struct BoundGroupInformation { - parsed_expression_map_t map; - case_insensitive_map_t alias_map; - unordered_map collated_groups; -}; - //! 
The BaseSelectBinder is the base binder of the SELECT, HAVING and QUALIFY binders. It can bind aggregates and window //! functions. class BaseSelectBinder : public ExpressionBinder { + friend class ColumnQualifier; + public: - BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); + BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node); bool BoundAggregates() { return bound_aggregate; @@ -44,13 +40,6 @@ class BaseSelectBinder : public ExpressionBinder { BindResult BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, idx_t depth) override; - bool inside_window; - bool bound_aggregate = false; - - BoundSelectNode &node; - BoundGroupInformation &info; - -protected: BindResult BindGroupingFunction(OperatorExpression &op, idx_t depth) override; //! Binds a WINDOW expression and returns the result. @@ -59,6 +48,14 @@ class BaseSelectBinder : public ExpressionBinder { ProjectionIndex TryBindGroup(ParsedExpression &expr); BindResult BindGroup(ParsedExpression &expr, idx_t depth, ProjectionIndex group_index); + +protected: + bool inside_window = false; + bool inside_aggregate = false; + bool bound_aggregate = false; + bool inside_aggregate_filter = false; + + BoundSelectNode &node; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp index 51bd9c29a..6dbb86506 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/check_binder.hpp @@ -17,10 +17,10 @@ namespace duckdb { //! 
The CHECK binder is responsible for binding an expression within a CHECK constraint class CheckBinder : public ExpressionBinder { public: - CheckBinder(Binder &binder, ClientContext &context, string table, const ColumnList &columns, + CheckBinder(Binder &binder, ClientContext &context, Identifier table, const ColumnList &columns, physical_index_set_t &bound_columns); - string table; + Identifier table; const ColumnList &columns; physical_index_set_t &bound_columns; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp index 5e068ce76..550ac0cb9 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/column_alias_binder.hpp @@ -21,6 +21,7 @@ class ColumnAliasBinder { public: explicit ColumnAliasBinder(SelectBindState &bind_state); + unique_ptr ResolveAlias(ColumnRefExpression &colref); bool BindAlias(ExpressionBinder &enclosing_binder, unique_ptr &expr_ptr, idx_t depth, bool root_expression, BindResult &result); // Check if the column reference is an SELECT item alias. diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp index 9c8efc0b6..faccf16da 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/group_binder.hpp @@ -19,30 +19,22 @@ struct SelectBindState; //! 
The GROUP binder is responsible for binding expressions in the GROUP BY clause class GroupBinder : public ExpressionBinder { public: - GroupBinder(Binder &binder, ClientContext &context, SelectNode &node, TableIndex group_index, - SelectBindState &bind_state, case_insensitive_map_t &group_alias_map); + GroupBinder(Binder &binder, ClientContext &context, TableIndex group_index, SelectBindState &bind_state); - //! The unbound root expression - unique_ptr unbound_expression; - //! The group index currently being bound - ProjectionIndex bind_index; +public: + static void ReplaceSelectRef(SelectNode &node, SelectBindState &bind_state, ProjectionIndex index, + unique_ptr &expr_ptr); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) override; string UnsupportedAggregateMessage() override; - BindResult BindSelectRef(idx_t entry); - BindResult BindColumnRef(ColumnRefExpression &expr, unique_ptr &expr_ptr); - BindResult BindConstant(ConstantExpression &expr); - bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, unique_ptr &expr_ptr) override; - SelectNode &node; +private: SelectBindState &bind_state; - case_insensitive_map_t &group_alias_map; - unordered_set used_aliases; TableIndex group_index; }; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp index 70ef60c9a..471ff1c72 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/having_binder.hpp @@ -17,20 +17,17 @@ namespace duckdb { //! The HAVING binder is responsible for binding an expression within the HAVING clause of a SQL statement. 
class HavingBinder : public BaseSelectBinder { public: - HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, - AggregateHandling aggregate_handling); + HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, AggregateHandling aggregate_handling); protected: BindResult BindLambdaReference(LambdaRefExpression &expr, idx_t depth); BindResult BindWindowExpression(WindowExpression &expr, idx_t depth) override; BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) override; - unique_ptr QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error) override; - - bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; +public: + ColumnAliasBinder column_alias_binder; private: - ColumnAliasBinder column_alias_binder; AggregateHandling aggregate_handling; }; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp index f8972d13e..0abb4cd20 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/index_binder.hpp @@ -30,7 +30,7 @@ class IndexBinder : public ExpressionBinder { TableCatalogEntry &table_entry, unique_ptr plan, unique_ptr alter_table_info); - static void InitCreateIndexInfo(LogicalGet &get, CreateIndexInfo &info, const string &schema); + static void InitCreateIndexInfo(LogicalGet &get, CreateIndexInfo &info, const Identifier &schema); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp index 55f046cd7..ae11f8207 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp +++ 
b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp @@ -24,6 +24,10 @@ class LateralBinder : public ExpressionBinder { return !correlated_columns.empty(); } + bool IsLateralBinder() const override { + return true; + } + static void ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &info); protected: diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp index 3416cb719..417605e6f 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp @@ -32,6 +32,7 @@ class ProjectionBinder : public ExpressionBinder { TableIndex proj_index; vector> &proj_expressions; string clause; + bool in_child_projection = false; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp index 23e51ea2b..c9aa54e44 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/qualify_binder.hpp @@ -17,9 +17,7 @@ struct SelectBindState; //! 
The QUALIFY binder is responsible for binding an expression within the QUALIFY clause of a SQL statement class QualifyBinder : public BaseSelectBinder { public: - QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); - - bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; + QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node); protected: BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) override; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp index 6860ce1aa..90505f959 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" @@ -19,11 +20,16 @@ namespace duckdb { //! Bind state during a SelectNode struct SelectBindState { // Mapping of (alias -> index) and a mapping of (Expression -> index) for the SELECT list - case_insensitive_map_t alias_map; + identifier_map_t alias_map; parsed_expression_map_t projection_map; //! The original unparsed expressions. This is exported after binding, because the binding might change the //! expressions (e.g. 
when a * clause is present) vector> original_expressions; + vector> unbound_groups; + parsed_expression_map_t group_map; + identifier_map_t group_alias_map; + unordered_map collated_groups; + unordered_set used_group_aliases; public: unique_ptr BindAlias(idx_t index); diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp index c070a3325..a14e5f4c2 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_binder.hpp @@ -15,18 +15,15 @@ namespace duckdb { //! The SELECT binder is responsible for binding an expression within the SELECT clause of a SQL statement class SelectBinder : public BaseSelectBinder { public: - SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info); + SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node); bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, unique_ptr &expr_ptr) override; - bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; protected: void ThrowIfUnnestInLambda(const ColumnBinding &column_binding) override; BindResult BindUnnest(FunctionExpression &function, idx_t depth, bool root_expression) override; - unique_ptr GetSQLValueFunction(const string &column_name) override; - protected: idx_t unnest_level = 0; }; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp deleted file mode 100644 index 1a04c0349..000000000 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/try_operator_binder.hpp +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// 
duckdb/planner/expression_binder/try_operator_binder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/expression_binder.hpp" - -namespace duckdb { - -//! This binder is used for the TRY expression -class TryOperatorBinder : public ExpressionBinder { - friend class SelectBinder; - -public: - TryOperatorBinder(Binder &binder, ClientContext &context); - - bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, - unique_ptr &expr_ptr) override; - - bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; - -protected: - BindResult BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, idx_t depth) override; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp index f6e8db403..a949e134b 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/where_binder.hpp @@ -28,8 +28,6 @@ class WhereBinder : public ExpressionBinder { bool TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, BindResult &result, unique_ptr &expr_ptr) override; - bool DoesColumnAliasExist(const ColumnRefExpression &colref) override; - private: BindResult BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression); diff --git a/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp index b1a65d886..bdd01b0fd 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/bloom_filter.hpp @@ -8,81 +8,35 @@ #pragma once +#include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/serializer/deserializer.hpp" 
#include "duckdb/common/serializer/serializer.hpp" #include "duckdb/planner/table_filter.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/storage/buffer_manager.hpp" namespace duckdb { -struct JoinFilterTableFilterState; - -class BloomFilter { -public: - BloomFilter() = default; - void Initialize(ClientContext &context_p, idx_t number_of_rows); - - void InsertHashes(const Vector &hashes_v, idx_t count) const; - idx_t LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, idx_t count) const; - - void InsertOne(hash_t hash) const; - bool LookupOne(hash_t hash) const; - - bool IsInitialized() const { - return initialized; - } +class BloomFilter; +//! DEPRECATED - only preserved for backwards-compatible expression conversion +class LegacyBFTableFilter final : public TableFilter { private: - idx_t num_sectors; - uint64_t bitmask; // num_sectors - 1 -> used to get the sector offset - - bool initialized = false; - AllocatedData buf_; - uint64_t *bf; -}; - -class BFTableFilter final : public TableFilter { -private: - BloomFilter &filter; + optional_ptr filter; bool filters_null_values; string key_column_name; LogicalType key_type; public: - static constexpr auto TYPE = TableFilterType::BLOOM_FILTER; + static constexpr auto TYPE = TableFilterType::LEGACY_BLOOM_FILTER; public: - explicit BFTableFilter(BloomFilter &filter_p, const bool filters_null_values_p, const string &key_column_name_p, - const LogicalType &key_type_p) + explicit LegacyBFTableFilter(optional_ptr filter_p, const bool filters_null_values_p, + const string &key_column_name_p, const LogicalType &key_type_p) : TableFilter(TYPE), filter(filter_p), filters_null_values(filters_null_values_p), key_column_name(key_column_name_p), key_type(key_type_p) { } - //! If the join condition is e.g. "A = B", the bf will filter null values. - //! 
If the condition is "A is B" the filter will let nulls pass - bool FiltersNullValues() const { - return filters_null_values; - } - - const LogicalType &GetKeyType() const { - return key_type; - } - - string ToString(const string &column_name) const override; - - // Filters by first hashing and then probing the bloom filter. The &sel will hold - // the remaining tuples, &approved_tuple_count will hold the approved count. - idx_t Filter(Vector &keys_v, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const; - bool FilterValue(const Value &value) const; - - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - private: - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; diff --git a/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp index 61a9aff51..1d22d2c93 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/conjunction_filter.hpp @@ -13,50 +13,40 @@ namespace duckdb { -class ConjunctionFilter : public TableFilter { +//! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacyConjunctionFilter : public TableFilter { public: - explicit ConjunctionFilter(TableFilterType filter_type_p) : TableFilter(filter_type_p) { + explicit LegacyConjunctionFilter(TableFilterType filter_type_p) : TableFilter(filter_type_p) { } - ~ConjunctionFilter() override { + ~LegacyConjunctionFilter() override { } //! The filters of this conjunction vector> child_filters; - -public: - bool Equals(const TableFilter &other) const override { - return TableFilter::Equals(other); - } }; -class ConjunctionOrFilter : public ConjunctionFilter { +//! 
DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacyConjunctionOrFilter : public LegacyConjunctionFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::CONJUNCTION_OR; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_CONJUNCTION_OR; public: - ConjunctionOrFilter(); - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; + LegacyConjunctionOrFilter(); unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); }; -class ConjunctionAndFilter : public ConjunctionFilter { +//! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacyConjunctionAndFilter : public LegacyConjunctionFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::CONJUNCTION_AND; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_CONJUNCTION_AND; public: - ConjunctionAndFilter(); + LegacyConjunctionAndFilter(); public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp index 7d56ac7ea..31101b340 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp +++ 
b/src/duckdb/src/include/duckdb/planner/filter/constant_filter.hpp @@ -15,12 +15,12 @@ namespace duckdb { //! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion -class ConstantFilter : public TableFilter { +class LegacyConstantFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::CONSTANT_COMPARISON; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_CONSTANT_COMPARISON; public: - ConstantFilter(ExpressionType comparison_type, Value constant); + LegacyConstantFilter(ExpressionType comparison_type, Value constant); //! The comparison type (e.g. COMPARE_EQUAL, COMPARE_GREATERTHAN, COMPARE_LESSTHAN, ...) ExpressionType comparison_type; @@ -28,10 +28,6 @@ class ConstantFilter : public TableFilter { Value constant; public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/dynamic_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/dynamic_filter.hpp index ca4199a6b..274ed332d 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/dynamic_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/dynamic_filter.hpp @@ -17,38 +17,21 @@ namespace duckdb { class Expression; +struct DynamicFilterData; -struct DynamicFilterData { +//! 
DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacyDynamicFilter : public TableFilter { public: - DynamicFilterData(ExpressionType comparison_type, Value constant); - - mutex lock; - ExpressionType comparison_type; - Value constant; - atomic initialized = {false}; - - unique_ptr ToExpression(const Expression &column) const; - - void SetValue(Value val); - void Reset(); -}; - -class DynamicFilter : public TableFilter { -public: - static constexpr const TableFilterType TYPE = TableFilterType::DYNAMIC_FILTER; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_DYNAMIC_FILTER; public: - DynamicFilter(); - explicit DynamicFilter(shared_ptr filter_data); + LegacyDynamicFilter(); + explicit LegacyDynamicFilter(shared_ptr filter_data); //! The shared, dynamic filter data shared_ptr filter_data; public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp index dd415f770..29d96feea 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp @@ -16,8 +16,10 @@ namespace duckdb { class ExpressionExecutor; +struct DynamicFilterData; class BoundFunctionExpression; +class ClientContext; class ExpressionFilter : public TableFilter { public: @@ -46,12 +48,29 @@ class ExpressionFilter : public TableFilter { ExpressionType expression_type); //! 
Enhanced CheckStatistics that recognizes standard expression patterns - static FilterPropagateResult CheckExpressionStatistics(const Expression &expr, BaseStatistics &stats); + static FilterPropagateResult CheckExpressionStatistics(const Expression &expr, const BaseStatistics &stats); + static FilterPropagateResult CheckExpressionStatistics(optional_ptr context_p, + const Expression &expr, const BaseStatistics &stats); + //! Check if an expression tree contains an internal function with the given name + static bool ContainsInternalFunction(const Expression &expr, const string &func_name); + //! Check if an expression tree is entirely optional filter semantics + static bool IsOptionalExpression(const Expression &expr); + //! Check if the root of an expression tree is an optional filter wrapper + static bool IsRootOptionalExpression(const Expression &expr); + //! Check if a table filter tree is entirely optional filter semantics + static bool IsOptionalFilter(const TableFilter &filter); + //! Check if the root of a table filter tree is an optional filter wrapper + static bool IsRootOptionalFilter(const TableFilter &filter); + //! If this is an optional/selectivity-optional wrapper around a root dynamic filter, + //! return the shared dynamic filter state. 
+ static shared_ptr GetRootOptionalDynamicFilterData(const TableFilter &filter); - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; + FilterPropagateResult CheckStatistics(const BaseStatistics &stats) const; + FilterPropagateResult CheckStatistics(ClientContext &context, const BaseStatistics &stats) const; + string ToString(const string &column_name) const; + string DebugToString() const; + bool Equals(const ExpressionFilter &other) const; + unique_ptr Copy() const; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); @@ -59,6 +78,8 @@ class ExpressionFilter : public TableFilter { ExpressionType replace_type = ExpressionType::BOUND_REF); private: + //! Produce human-readable ToString for internal tablefilter functions + static string InternalFunctionToString(const BoundFunctionExpression &func_expr, const string &column_name); //! Recursively convert expression to friendly string, handling internal functions static string ExpressionToFriendlyString(const Expression &expression, const string &column_name); }; diff --git a/src/duckdb/src/include/duckdb/planner/filter/in_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/in_filter.hpp index e637bbea4..c4e4bc609 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/in_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/in_filter.hpp @@ -14,20 +14,16 @@ namespace duckdb { //! 
DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion -class InFilter : public TableFilter { +class LegacyInFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::IN_FILTER; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_IN_FILTER; public: - explicit InFilter(vector values); + explicit LegacyInFilter(vector values); vector values; public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp index 43b17e6d7..1bdad4d8c 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/null_filter.hpp @@ -13,36 +13,28 @@ namespace duckdb { //! 
DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion -class IsNullFilter : public TableFilter { +class LegacyIsNullFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::IS_NULL; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_IS_NULL; public: - IsNullFilter(); + LegacyIsNullFilter(); public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); }; //! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion -class IsNotNullFilter : public TableFilter { +class LegacyIsNotNullFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::IS_NOT_NULL; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_IS_NOT_NULL; public: - IsNotNullFilter(); + LegacyIsNotNullFilter(); public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp index 17c6db2e8..9ddaf3ec7 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/optional_filter.hpp @@ -14,38 
+14,21 @@ namespace duckdb { -class OptionalFilter : public TableFilter { +//! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacyOptionalFilter : public TableFilter { public: - static constexpr auto TYPE = TableFilterType::OPTIONAL_FILTER; + static constexpr auto TYPE = TableFilterType::LEGACY_OPTIONAL_FILTER; public: - explicit OptionalFilter(unique_ptr filter = nullptr); + explicit LegacyOptionalFilter(unique_ptr filter = nullptr); //! Optional child filters. unique_ptr child_filter; public: - string ToString(const string &column_name) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); - - bool IsOnlyForZoneMapFiltering() const override { - return true; - } - - virtual void FiltersNullValues(const LogicalType &type, bool &filters_nulls, bool &filters_valid_values, - TableFilterState &filter_state) const { - } - - virtual unique_ptr InitializeState(ClientContext &context) const { - return make_uniq(); - } - - virtual idx_t FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, - TableFilterState &filter_state, idx_t scan_count, idx_t &approved_tuple_count) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp index b31cd6716..bc9b60f2a 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp @@ -11,35 +11,21 @@ #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/planner/table_filter.hpp" -#include 
"duckdb/planner/table_filter_state.hpp" namespace duckdb { class PerfectHashJoinExecutor; -class PerfectHashJoinFilter final : public TableFilter { +//! DEPRECATED - only preserved for backwards-compatible expression conversion +class LegacyPerfectHashJoinFilter final : public TableFilter { public: - static constexpr auto TYPE = TableFilterType::PERFECT_HASH_JOIN_FILTER; + static constexpr auto TYPE = TableFilterType::LEGACY_PERFECT_HASH_JOIN_FILTER; public: - PerfectHashJoinFilter(optional_ptr perfect_join_executor, - const string &key_column_name, const LogicalType &key_type_p); - -public: - const LogicalType &GetKeyType() const { - return key_type; - } - - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - - idx_t Filter(Vector &keys, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const; - bool FilterValue(const Value &value) const; + LegacyPerfectHashJoinFilter(optional_ptr perfect_join_executor, + const string &key_column_name, const LogicalType &key_type_p); private: - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; diff --git a/src/duckdb/src/include/duckdb/planner/filter/prefix_range_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/prefix_range_filter.hpp index bf32e217a..5bd169287 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/prefix_range_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/prefix_range_filter.hpp @@ -8,49 +8,17 @@ #pragma once -#include "duckdb/common/enums/filter_propagate_result.hpp" #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/uhugeint.hpp" -#include 
"duckdb/common/unique_ptr.hpp" #include "duckdb/planner/table_filter.hpp" namespace duckdb { -struct JoinFilterTableFilterState; +class PrefixRangeFilter; -class PrefixRangeFilter { -public: - struct BuildState { - virtual ~BuildState() = default; - template - - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - }; - - virtual ~PrefixRangeFilter() = default; - virtual void Initialize(ClientContext &context, idx_t number_of_rows, Value min, Value max) = 0; - virtual unique_ptr InitializeBuildState(ClientContext &context) const = 0; - virtual void InsertKeys(Vector &keys, idx_t count, BuildState &state) const = 0; - virtual void MergeBuildState(BuildState &state) = 0; - virtual idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const = 0; - virtual FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const = 0; - virtual bool IsInitialized() const = 0; - static unique_ptr CreatePrefixRangeFilter(const LogicalType &key_type); - static bool TryComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result); -}; - -class PrefixRangeTableFilter final : public TableFilter { +//! 
DEPRECATED - only preserved for backwards-compatible expression conversion +class LegacyPrefixRangeTableFilter final : public TableFilter { private: optional_ptr filter; @@ -58,29 +26,13 @@ class PrefixRangeTableFilter final : public TableFilter { LogicalType key_type; public: - static constexpr auto TYPE = TableFilterType::PREFIX_RANGE_FILTER; - static bool SupportedType(const LogicalType &type); - static unique_ptr CreateBitmap(const LogicalType &type); + static constexpr auto TYPE = TableFilterType::LEGACY_PREFIX_RANGE_FILTER; public: - explicit PrefixRangeTableFilter(optional_ptr filter_p, const string &key_column_name_p, - const LogicalType &key_type_p); - - const LogicalType &GetKeyType() const { - return key_type; - } - - string ToString(const string &column_name) const override; - - idx_t Filter(Vector &keys, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const; - bool FilterValue(const Value &value) const; - - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; + explicit LegacyPrefixRangeTableFilter(optional_ptr filter_p, const string &key_column_name_p, + const LogicalType &key_type_p); private: - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; diff --git a/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp index 666de870e..f23ccb030 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/selectivity_optional_filter.hpp @@ -12,64 +12,19 @@ namespace duckdb { -struct SelectivityOptionalFilterState final : public TableFilterState { - enum class FilterStatus { ACTIVE, PAUSED_DUE_TO_HIGH_SELECTIVITY }; - - struct SelectivityStats { - 
SelectivityStats(idx_t n_vectors_to_check, float selectivity_threshold); - - void Update(idx_t accepted, idx_t processed); - bool IsActive() const; - double GetSelectivity() const; - - //! Configuration - const idx_t n_vectors_to_check; - const float selectivity_threshold; - - //! For computing selectivity stats - idx_t tuples_accepted; - idx_t tuples_processed; - idx_t vectors_processed; - - //! Whether currently paused - FilterStatus status; - - //! For increasing pause if filter is not selective enough - idx_t pause_multiplier; - }; - - unique_ptr child_state; - SelectivityStats stats; - - explicit SelectivityOptionalFilterState(unique_ptr child_state, const idx_t n_vectors_to_check, - const float selectivity_threshold) - : child_state(std::move(child_state)), stats(n_vectors_to_check, selectivity_threshold) { - } -}; - -enum class SelectivityOptionalFilterType : uint8_t { MIN_MAX, BF, PHJ, PRF }; - -class SelectivityOptionalFilter final : public OptionalFilter { +//! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion +class LegacySelectivityOptionalFilter final : public LegacyOptionalFilter { public: float selectivity_threshold; idx_t n_vectors_to_check; - SelectivityOptionalFilter(unique_ptr filter, SelectivityOptionalFilterType type); - SelectivityOptionalFilter(unique_ptr filter, float selectivity_threshold, idx_t n_vectors_to_check); + LegacySelectivityOptionalFilter(unique_ptr filter, SelectivityOptionalFilterType type); + LegacySelectivityOptionalFilter(unique_ptr filter, float selectivity_threshold, + idx_t n_vectors_to_check); public: - unique_ptr Copy() const override; - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; + unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); - void FiltersNullValues(const LogicalType &type, bool &filters_nulls, bool 
&filters_valid_values, - TableFilterState &filter_state) const override; - unique_ptr InitializeState(ClientContext &context) const override; - idx_t FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, - TableFilterState &filter_state, idx_t scan_count, idx_t &approved_tuple_count) const override; - - bool IsOnlyForZoneMapFiltering() const override { - return false; - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/struct_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/struct_filter.hpp index 34c51adc4..df9d05834 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/struct_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/struct_filter.hpp @@ -13,27 +13,23 @@ namespace duckdb { //! DEPRECATED - only preserved for backwards-compatible deserialization and expression conversion -class StructFilter : public TableFilter { +class LegacyStructFilter : public TableFilter { public: - static constexpr const TableFilterType TYPE = TableFilterType::STRUCT_EXTRACT; + static constexpr const TableFilterType TYPE = TableFilterType::LEGACY_STRUCT_EXTRACT; public: - StructFilter(idx_t child_idx, string child_name, unique_ptr child_filter); + LegacyStructFilter(idx_t child_idx, Identifier child_name, unique_ptr child_filter); //! The field index to filter on idx_t child_idx; //! The field name to filter on - string child_name; + Identifier child_name; //! 
The child filter unique_ptr child_filter; public: - FilterPropagateResult CheckStatistics(BaseStatistics &stats) const override; - string ToString(const string &column_name) const override; - bool Equals(const TableFilter &other) const override; - unique_ptr Copy() const override; unique_ptr ToExpression(const Expression &column) const override; void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/filter/table_filter_function_helpers.hpp b/src/duckdb/src/include/duckdb/planner/filter/table_filter_function_helpers.hpp new file mode 100644 index 000000000..ed1cc23e5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/table_filter_function_helpers.hpp @@ -0,0 +1,165 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_function_helpers.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/filter/table_filter_functions.hpp" + +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/planner/table_filter_state.hpp" + +namespace duckdb { + +struct SelectivityTrackingLocalState : public FunctionLocalState { + SelectivityTrackingLocalState(idx_t n_vectors_to_check, float selectivity_threshold) + : stats(n_vectors_to_check, selectivity_threshold) { + } + + void Update(idx_t accepted, idx_t processed) { + stats.Update(accepted, processed); + } + + bool IsActive() const { + return stats.IsActive(); + } + + SelectivityOptionalFilterState::SelectivityStats stats; +}; + +inline unique_ptr InitSelectivityTrackingLocalState(idx_t n_vectors_to_check, + float selectivity_threshold) { + if (n_vectors_to_check == 0) { + return nullptr; + } + return make_uniq(n_vectors_to_check, 
selectivity_threshold); +} + +inline idx_t SetAllTrueSelection(idx_t count, optional_ptr true_sel, + optional_ptr false_sel) { + if (true_sel) { + for (idx_t i = 0; i < count; i++) { + true_sel->set_index(i, i); + } + } + return count; +} + +inline idx_t SetAllTrueSelection(idx_t count, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + if (!sel) { + return SetAllTrueSelection(count, true_sel, false_sel); + } + if (true_sel) { + for (idx_t i = 0; i < count; i++) { + true_sel->set_index(i, sel->get_index(i)); + } + } + return count; +} + +inline idx_t FillSelectionInversion(idx_t count, const SelectionVector &true_sel, idx_t true_count, + optional_ptr false_sel) { + if (!false_sel) { + return count - true_count; + } + return SelectionVector::Inverted(true_sel, *false_sel, true_count, count); +} + +inline idx_t TranslateSelection(idx_t count, optional_ptr input_sel, + const SelectionVector &local_true_sel, idx_t local_true_count, + optional_ptr true_sel, optional_ptr false_sel) { + if (local_true_count == count) { + return SetAllTrueSelection(count, input_sel, true_sel, false_sel); + } + if (!input_sel) { + if (true_sel && true_sel.get() != &local_true_sel) { + for (idx_t i = 0; i < local_true_count; i++) { + true_sel->set_index(i, local_true_sel.get_index(i)); + } + } + if (false_sel) { + FillSelectionInversion(count, local_true_sel, local_true_count, false_sel); + } + return local_true_count; + } + if (true_sel) { + for (idx_t i = 0; i < local_true_count; i++) { + true_sel->set_index(i, input_sel->get_index(local_true_sel.get_index(i))); + } + } + if (false_sel) { + idx_t false_count = 0; + idx_t true_offset = 0; + for (idx_t i = 0; i < count; i++) { + if (true_offset < local_true_count && local_true_sel.get_index(true_offset) == i) { + true_offset++; + continue; + } + false_sel->set_index(false_count++, input_sel->get_index(i)); + } + } + return local_true_count; +} + +inline void SetConstantBooleanResult(Vector &result, bool value) { + 
result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::GetData(result)[0] = value; +} + +inline void SelectionToBooleanResult(idx_t count, const SelectionVector &sel, idx_t sel_count, Vector &result) { + if (count == 0 || sel_count == 0) { + SetConstantBooleanResult(result, false); + return; + } + if (sel_count == count) { + SetConstantBooleanResult(result, true); + return; + } + result.SetVectorType(VectorType::FLAT_VECTOR); + FlatVector::ValidityMutable(result).SetAllValid(count); + auto result_data = FlatVector::GetDataMutable(result); + for (idx_t i = 0; i < count; i++) { + result_data[i] = false; + } + for (idx_t i = 0; i < sel_count; i++) { + result_data[sel.get_index(i)] = true; + } +} + +inline void SetAllTrue(DataChunk &args, Vector &result) { + SetConstantBooleanResult(result, true); +} + +template +inline void ExecuteWithSelectivityTracking(DataChunk &args, Vector &result, TRACKING_STATE *tracking_state, + EXECUTOR &&execute) { + if (tracking_state && !tracking_state->IsActive()) { + SetAllTrue(args, result); + tracking_state->Update(0, 0); + return; + } + auto approved_count = execute(); + if (tracking_state) { + tracking_state->Update(approved_count, args.size()); + } +} + +void TableFilterFunctionSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function); +unique_ptr TableFilterFunctionDeserialize(Deserializer &deserializer, BoundScalarFunction &function); + +inline string FormatOptionalFilterString(const string &child_filter_string) { + if (child_filter_string.empty()) { + return "optional"; + } + return "optional: " + child_filter_string; +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp b/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp new file mode 100644 index 000000000..299ebac71 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp @@ -0,0 +1,297 @@ 
+//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_functions.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/function/scalar/tablefilter_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/planner/table_filter_state.hpp" + +namespace duckdb { + +class BaseStatistics; +class Expression; +class PerfectHashJoinExecutor; +class PrefixRangeFilter; +struct DynamicFilterData; + +struct SelectivityOptionalFilterState final : public TableFilterState { + enum class FilterStatus { ACTIVE, PAUSED_DUE_TO_HIGH_SELECTIVITY }; + + struct SelectivityStats { + SelectivityStats(idx_t n_vectors_to_check, float selectivity_threshold); + + void Update(idx_t accepted, idx_t processed); + bool IsActive() const; + double GetSelectivity() const; + + //! Configuration + const idx_t n_vectors_to_check; + const float selectivity_threshold; + + //! For computing selectivity stats + idx_t tuples_accepted; + idx_t tuples_processed; + idx_t vectors_processed; + + //! Whether currently paused + FilterStatus status; + + //! 
For increasing pause if filter is not selective enough + idx_t pause_multiplier; + }; + + unique_ptr child_state; + SelectivityStats stats; + + explicit SelectivityOptionalFilterState(unique_ptr child_state, const idx_t n_vectors_to_check, + const float selectivity_threshold) + : child_state(std::move(child_state)), stats(n_vectors_to_check, selectivity_threshold) { + } +}; + +enum class SelectivityOptionalFilterType : uint8_t { MIN_MAX, BF, PHJ, PRF }; + +void GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType type, float &selectivity_threshold, + idx_t &n_vectors_to_check); + +unique_ptr CreateOptionalFilterExpression(unique_ptr child_expr, + const LogicalType &target_type); +unique_ptr CreateSelectivityOptionalFilterExpression(unique_ptr child_expr, + const LogicalType &target_type, + float selectivity_threshold, idx_t n_vectors_to_check); +unique_ptr CreateDynamicFilterExpression(shared_ptr filter_data, + const LogicalType &target_type); + +//! Bind function that prevents user access to internal tablefilter functions +struct TableFilterFunctions { + static unique_ptr Bind(BindScalarFunctionInput &input); + static bool IsTableFilterFunction(const Identifier &name); + static bool IsTableFilterFunction(const BoundScalarFunction &function) { + return IsTableFilterFunction(function.GetName()); + } +}; + +//! Runtime bloom-filter state used by join pushdown and internal tablefilter functions. 
+class BloomFilter { +public: + BloomFilter() = default; + void Initialize(ClientContext &context_p, idx_t number_of_rows); + + void InsertHashes(const Vector &hashes_v) const; + idx_t LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, idx_t count) const; + + void InsertOne(hash_t hash) const; + bool LookupOne(hash_t hash) const; + void Merge(const BloomFilter &other); + void Reset(); + + bool IsInitialized() const { + return initialized; + } + +private: + idx_t num_sectors; + uint64_t bitmask; // num_sectors - 1 -> used to get the sector offset + + bool initialized = false; + AllocatedData buf_; + uint64_t *bf; +}; + +//! FunctionData for bloom filter internal function +struct BloomFilterFunctionData : public FunctionData { + BloomFilterFunctionData(optional_ptr filter_p, bool filters_null_values_p, + const string &key_column_name_p, const LogicalType &key_type_p, + float selectivity_threshold_p, idx_t n_vectors_to_check_p); + + optional_ptr filter; + bool filters_null_values; + string key_column_name; + LogicalType key_type; + float selectivity_threshold; + idx_t n_vectors_to_check; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! FunctionData for perfect hash join internal function +struct PerfectHashJoinFunctionData : public FunctionData { + PerfectHashJoinFunctionData(optional_ptr executor_p, const string &key_column_name_p, + float selectivity_threshold_p, idx_t n_vectors_to_check_p); + + optional_ptr executor; + string key_column_name; + float selectivity_threshold; + idx_t n_vectors_to_check; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! Runtime prefix-range filter state used by join pushdown and internal tablefilter functions. 
+class PrefixRangeFilter { +public: + struct BuildState { + virtual ~BuildState() = default; + template + + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + }; + + virtual ~PrefixRangeFilter() = default; + virtual void Initialize(ClientContext &context, idx_t number_of_rows, Value min, Value max) = 0; + virtual unique_ptr InitializeBuildState(ClientContext &context) const = 0; + virtual void InsertKeys(Vector &keys, BuildState &state) const = 0; + virtual void MergeBuildState(BuildState &state) = 0; + virtual idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const = 0; + virtual FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const = 0; + virtual bool IsInitialized() const = 0; + static bool SupportedType(const LogicalType &type); + static unique_ptr CreatePrefixRangeFilter(const LogicalType &key_type); + static bool TryComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result); +}; + +//! FunctionData for prefix range internal function +struct PrefixRangeFunctionData : public FunctionData { + PrefixRangeFunctionData(optional_ptr filter_p, const string &key_column_name_p, + const LogicalType &key_type_p, float selectivity_threshold_p, idx_t n_vectors_to_check_p); + + optional_ptr filter; + string key_column_name; + LogicalType key_type; + float selectivity_threshold; + idx_t n_vectors_to_check; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! Runtime dynamic-filter state shared by internal tablefilter functions. 
+struct DynamicFilterData { +public: + DynamicFilterData(ExpressionType comparison_type, Value constant); + + mutex lock; + ExpressionType comparison_type; + Value constant; + atomic initialized = {false}; + + unique_ptr ToExpression(const Expression &column) const; + + void SetValue(Value val); + void Reset(); + static bool CompareValue(ExpressionType comparison_type, const Value &constant, const Value &value); + static FilterPropagateResult CheckStatistics(const BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant); +}; + +//! FunctionData for dynamic filter internal function +struct DynamicFilterFunctionData : public FunctionData { + explicit DynamicFilterFunctionData(shared_ptr filter_data_p); + + shared_ptr filter_data; + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! FunctionData for optional filter internal function (always returns TRUE, used for statistics only) +struct OptionalFilterFunctionData : public FunctionData { + //! The child filter expression, used for CheckStatistics via FilterPruneCallback + unique_ptr child_filter_expr; + + explicit OptionalFilterFunctionData(unique_ptr child_filter_expr_p); + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! FunctionData for selectivity-optional filter internal function +struct SelectivityOptionalFilterFunctionData : public FunctionData { + //! The child filter expression, executed only while it remains selective enough + unique_ptr child_filter_expr; + float selectivity_threshold; + idx_t n_vectors_to_check; + + SelectivityOptionalFilterFunctionData(unique_ptr child_filter_expr_p, float selectivity_threshold_p, + idx_t n_vectors_to_check_p); + + unique_ptr Copy() const override; + bool Equals(const FunctionData &other) const override; +}; + +//! 
Factory for bloom filter internal function +struct BloomFilterScalarFun : public TableFilterBloomFilterFun { + using TableFilterBloomFilterFun::GetFunction; + static constexpr const char *NAME = TableFilterBloomFilterFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &column_name, const string &key_column_name); +}; + +//! Factory for perfect hash join internal function +struct PerfectHashJoinScalarFun : public TableFilterPerfectHashJoinFun { + using TableFilterPerfectHashJoinFun::GetFunction; + static constexpr const char *NAME = TableFilterPerfectHashJoinFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &column_name, const string &key_column_name); +}; + +//! Factory for prefix range internal function +struct PrefixRangeScalarFun : public TableFilterPrefixRangeFun { + using TableFilterPrefixRangeFun::GetFunction; + static constexpr const char *NAME = TableFilterPrefixRangeFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &column_name, const string &key_column_name); +}; + +//! Factory for dynamic filter internal function +struct DynamicFilterScalarFun : public TableFilterDynamicFun { + using TableFilterDynamicFun::GetFunction; + static constexpr const char *NAME = TableFilterDynamicFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &column_name, bool has_filter_data); +}; + +//! 
Factory for optional filter internal function (always returns TRUE) +struct OptionalFilterScalarFun : public TableFilterOptionalFun { + using TableFilterOptionalFun::GetFunction; + static constexpr const char *NAME = TableFilterOptionalFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &child_filter_string); +}; + +//! Factory for selectivity-optional filter internal function +struct SelectivityOptionalFilterScalarFun : public TableFilterSelectivityOptionalFun { + using TableFilterSelectivityOptionalFun::GetFunction; + static constexpr const char *NAME = TableFilterSelectivityOptionalFun::Name; + static ScalarFunction GetFunction(const LogicalType &input_type); + static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); + static string ToString(const string &child_filter_string); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp index d36091978..2b935252a 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_copy_to_file.hpp @@ -15,6 +15,7 @@ #include "duckdb/function/copy_function.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/common/enums/preserve_order.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" namespace duckdb { @@ -27,6 +28,8 @@ class LogicalCopyToFile : public LogicalOperator { : LogicalOperator(LogicalOperatorType::LOGICAL_COPY_TO_FILE), function(std::move(function)), bind_data(std::move(bind_data)), copy_info(std::move(copy_info)) { } + +public: CopyFunction function; unique_ptr bind_data; unique_ptr copy_info; @@ -50,7 +53,9 @@ class LogicalCopyToFile : public LogicalOperator { bool hive_file_pattern = true; 
PreserveOrderType preserve_order = PreserveOrderType::AUTOMATIC; vector partition_columns; - vector names; + vector order_columns; + + vector names; vector expected_types; public: @@ -60,8 +65,8 @@ class LogicalCopyToFile : public LogicalOperator { static unique_ptr Deserialize(Deserializer &deserializer); static vector GetTypesWithoutPartitions(const vector &col_types, const vector &part_cols, bool write_part_cols); - static vector GetNamesWithoutPartitions(const vector &col_names, const vector &part_cols, - bool write_part_cols); + static vector GetNamesWithoutPartitions(const vector &col_names, + const vector &part_cols, bool write_part_cols); protected: void ResolveTypes() override { diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp index 1b7832585..b469d6214 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp @@ -23,8 +23,8 @@ class LogicalCTE : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INVALID; public: - explicit LogicalCTE(string ctename_p, TableIndex table_index, idx_t column_count, unique_ptr top, - unique_ptr bottom, + explicit LogicalCTE(Identifier ctename_p, TableIndex table_index, idx_t column_count, + unique_ptr top, unique_ptr bottom, LogicalOperatorType logical_type = LogicalOperatorType::LOGICAL_INVALID) : LogicalOperator(logical_type), ctename(std::move(ctename_p)), table_index(table_index), column_count(column_count) { @@ -32,7 +32,7 @@ class LogicalCTE : public LogicalOperator { children.push_back(std::move(bottom)); } - string ctename; + Identifier ctename; TableIndex table_index; idx_t column_count; CorrelatedColumns correlated_columns; diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp index 
94c889924..d77a064b1 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cteref.hpp @@ -19,7 +19,7 @@ class LogicalCTERef : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_CTE_REF; public: - LogicalCTERef(TableIndex table_index, TableIndex cte_index, vector types, vector colnames, + LogicalCTERef(TableIndex table_index, TableIndex cte_index, vector types, vector colnames, bool is_recurring = false) : LogicalOperator(LogicalOperatorType::LOGICAL_CTE_REF), table_index(table_index), cte_index(cte_index), correlated_columns(0), is_recurring(is_recurring) { @@ -28,7 +28,7 @@ class LogicalCTERef : public LogicalOperator { bound_columns = std::move(colnames); } - vector bound_columns; + vector bound_columns; //! The table index in the current bind context TableIndex table_index; //! CTE index diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp index afa2ef4b4..5d866094f 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp @@ -23,7 +23,7 @@ class LogicalGet : public LogicalOperator { public: LogicalGet(TableIndex table_index, TableFunction function, unique_ptr bind_data, - vector returned_types, vector returned_names, + vector returned_types, vector returned_names, virtual_column_map_t virtual_columns = virtual_column_map_t()); //! The table index in the current bind context @@ -35,7 +35,7 @@ class LogicalGet : public LogicalOperator { //! The types of ALL columns that can be returned by the table function vector returned_types; //! The names of ALL columns that can be returned by the table function - vector names; + vector names; //! A mapping of column index -> type/name for all virtual columns virtual_column_map_t virtual_columns; //! 
Columns that are used outside the scan @@ -49,7 +49,7 @@ class LogicalGet : public LogicalOperator { //! The set of named input table types for the table-in table-out function vector input_table_types; //! The set of named input table names for the table-in table-out function - vector input_table_names; + vector input_table_names; //! For a table-in-out function, the set of projected input columns vector projected_input; //! Currently stores File Filters (as strings) applied by hive partitioning/complex filter pushdown and sample rate @@ -62,6 +62,8 @@ class LogicalGet : public LogicalOperator { optional_idx ordinality_idx; //! Row group order options (if set) unique_ptr row_group_order_options; + //! Partition indices to scan, empty means scan all + vector scan_partition_indices; string GetName() const override; InsertionOrderPreservingMap ParamsToString() const override; @@ -72,7 +74,7 @@ class LogicalGet : public LogicalOperator { column_t GetAnyColumn() const; const LogicalType &GetColumnType(const ColumnIndex &column_index) const; - const string &GetColumnName(const ColumnIndex &column_index) const; + const Identifier &GetColumnName(const ColumnIndex &column_index) const; public: void SetColumnIds(vector &&column_ids); @@ -85,6 +87,7 @@ class LogicalGet : public LogicalOperator { idx_t EstimateCardinality(ClientContext &context) override; bool TryGetStorageIndex(const ColumnIndex &column_index, StorageIndex &out_index) const; void SetScanOrder(unique_ptr options); + void SetPartitionsToScan(vector partition_indices); vector GetTableIndex() const override; //! 
Skips the serialization check in VerifyPlan diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp index b6fed1a89..1a5eb1773 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_materialized_cte.hpp @@ -21,7 +21,7 @@ class LogicalMaterializedCTE : public LogicalCTE { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_MATERIALIZED_CTE; public: - LogicalMaterializedCTE(string ctename_p, TableIndex table_index, idx_t column_count, + LogicalMaterializedCTE(Identifier ctename_p, TableIndex table_index, idx_t column_count, unique_ptr cte, unique_ptr child, CTEMaterialize materialize) : LogicalCTE(std::move(ctename_p), table_index, column_count, std::move(cte), std::move(child), diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp index 4a1abb40e..5d54d0c8b 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_prepare.hpp @@ -22,7 +22,8 @@ class LogicalPrepare : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_PREPARE; public: - LogicalPrepare(string name_p, shared_ptr prepared, unique_ptr logical_plan) + LogicalPrepare(Identifier name_p, shared_ptr prepared, + unique_ptr logical_plan) : LogicalOperator(LogicalOperatorType::LOGICAL_PREPARE), name(std::move(name_p)), prepared(std::move(prepared)) { if (logical_plan) { @@ -30,7 +31,7 @@ class LogicalPrepare : public LogicalOperator { } } - string name; + Identifier name; shared_ptr prepared; public: diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp 
index c2c7f94c4..c6e3b631e 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_recursive_cte.hpp @@ -21,7 +21,7 @@ class LogicalRecursiveCTE : public LogicalCTE { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_RECURSIVE_CTE; public: - LogicalRecursiveCTE(string ctename_p, TableIndex table_index, idx_t column_count, bool union_all, + LogicalRecursiveCTE(Identifier ctename_p, TableIndex table_index, idx_t column_count, bool union_all, vector> key_targets, unique_ptr top, unique_ptr bottom); diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp index 6131b6f88..1944000c4 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_reset.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include "duckdb/common/enums/set_scope.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" #include "duckdb/planner/logical_operator.hpp" @@ -20,11 +21,11 @@ class LogicalReset : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_RESET; public: - LogicalReset(std::string name_p, SetScope scope_p) + LogicalReset(Identifier name_p, SetScope scope_p) : LogicalOperator(LogicalOperatorType::LOGICAL_RESET), name(std::move(name_p)), scope(scope_p) { } - std::string name; + Identifier name; SetScope scope; public: diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp index ff3aa5098..3500f6331 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_set.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/identifier.hpp" #include 
"duckdb/common/enums/set_scope.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" #include "duckdb/planner/logical_operator.hpp" @@ -20,12 +21,12 @@ class LogicalSet : public LogicalOperator { static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_SET; public: - LogicalSet(std::string name_p, Value value_p, SetScope scope_p) + LogicalSet(Identifier name_p, Value value_p, SetScope scope_p) : LogicalOperator(LogicalOperatorType::LOGICAL_SET), name(std::move(name_p)), value(std::move(value_p)), scope(scope_p) { } - std::string name; + Identifier name; Value value; SetScope scope; diff --git a/src/duckdb/src/include/duckdb/planner/planner.hpp b/src/duckdb/src/include/duckdb/planner/planner.hpp index 985567839..380f53961 100644 --- a/src/duckdb/src/include/duckdb/planner/planner.hpp +++ b/src/duckdb/src/include/duckdb/planner/planner.hpp @@ -27,9 +27,9 @@ class Planner { public: unique_ptr plan; - vector names; + vector names; vector types; - case_insensitive_map_t parameter_data; + identifier_map_t parameter_data; shared_ptr binder; ClientContext &context; diff --git a/src/duckdb/src/include/duckdb/planner/planner_extension.hpp b/src/duckdb/src/include/duckdb/planner/planner_extension.hpp index 669ddbde6..4b8662147 100644 --- a/src/duckdb/src/include/duckdb/planner/planner_extension.hpp +++ b/src/duckdb/src/include/duckdb/planner/planner_extension.hpp @@ -23,6 +23,8 @@ struct PlannerExtensionInfo { } }; +enum class GetSQLValueFunctionReturnType { CONTINUE_BINDING, FINISH_BINDING }; + struct PlannerExtensionInput { ClientContext &context; Binder &binder; @@ -32,6 +34,10 @@ struct PlannerExtensionInput { //! 
The post_bind function runs after binding succeeds, allowing modification of the bound statement typedef void (*post_bind_function_t)(PlannerExtensionInput &input, BoundStatement &statement); +typedef GetSQLValueFunctionReturnType (*get_sql_value_function_t)(PlannerExtensionInput &input, + const string &column_name, + unique_ptr &result); + class PlannerExtension { public: //! The post-bind function of the planner extension. @@ -40,6 +46,9 @@ class PlannerExtension { //! allowing modification of the plan and result types. post_bind_function_t post_bind_function = nullptr; + //! Override that allows registering a different callback for SQL value functions + get_sql_value_function_t get_sql_value_function = nullptr; + //! Additional planner info passed to the functions shared_ptr planner_info; diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp index c0b4e79cd..a6a3be4fa 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -42,16 +42,12 @@ class BoundSelectNode : public BoundQueryNode { vector> select_list; //! The FROM clause BoundStatement from_table; - //! The WHERE clause - unique_ptr where_clause; //! list of groups BoundGroupByNode groups; //! HAVING clause unique_ptr having; //! QUALIFY clause unique_ptr qualify; - //! SAMPLE clause - unique_ptr sample_options; //! 
The amount of columns in the final result idx_t column_count; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp b/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp new file mode 100644 index 000000000..a97b7fea0 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/delim_join_cte_rewriter.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +//! Rewrites fully decorrelated DelimJoins into materialized CTEs. +class DelimJoinCTERewriter { +public: + static void Rewrite(Binder &binder, unique_ptr &plan); + +private: + explicit DelimJoinCTERewriter(Binder &binder); + + void Rewrite(unique_ptr &plan); + void RewriteDelimJoinsToCTEs(unique_ptr &plan, LogicalOperator &rewrite_root, + bool null_rejecting_filter_above = false); + void MaterializeDelimJoinAsCTE(unique_ptr &plan, LogicalOperator &rewrite_root, + bool null_rejecting_filter_above); + +private: + Binder &binder; + bool cte_deliminator_enabled; + vector generated_dedup_cte_indexes; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/table_binding.hpp b/src/duckdb/src/include/duckdb/planner/table_binding.hpp index 5c0bea2b5..f1614b580 100644 --- a/src/duckdb/src/include/duckdb/planner/table_binding.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_binding.hpp @@ -34,29 +34,29 @@ enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY, CTE }; //! A Binding represents a binding to a table, table-producing function or subquery with a specified table index. 
struct Binding { - Binding(BindingType binding_type, BindingAlias alias, vector types, vector names, + Binding(BindingType binding_type, BindingAlias alias, vector types, vector names, TableIndex index); virtual ~Binding() = default; public: - bool TryGetBindingIndex(const string &column_name, column_t &column_index); - column_t GetBindingIndex(const string &column_name); - bool HasMatchingBinding(const string &column_name); - virtual ErrorData ColumnNotFoundError(const string &column_name) const; + bool TryGetBindingIndex(const Identifier &column_name, column_t &column_index); + column_t GetBindingIndex(const Identifier &column_name); + bool HasMatchingBinding(const Identifier &column_name); + virtual ErrorData ColumnNotFoundError(const Identifier &column_name) const; virtual BindResult Bind(ColumnRefExpression &colref, idx_t depth); virtual optional_ptr GetStandardEntry(); - string GetAlias() const; + const Identifier &GetAlias() const; BindingType GetBindingType(); const BindingAlias &GetBindingAlias(); TableIndex GetIndex(); const vector &GetColumnTypes(); - const vector &GetColumnNames(); + const vector &GetColumnNames(); idx_t GetColumnCount(); void SetColumnType(idx_t col_idx, LogicalType type); - static BindingAlias GetAlias(const string &explicit_alias, const StandardEntry &entry); - static BindingAlias GetAlias(const string &explicit_alias, optional_ptr entry); + static BindingAlias GetAlias(const Identifier &explicit_alias, const StandardEntry &entry); + static BindingAlias GetAlias(const Identifier &explicit_alias, optional_ptr entry); public: template @@ -88,9 +88,9 @@ struct Binding { //! The types of the bound columns vector types; //! Column names of the subquery - vector names; + vector names; //! 
Name -> index for the names - case_insensitive_map_t name_map; + identifier_map_t name_map; }; struct EntryBinding : public Binding { @@ -98,7 +98,7 @@ struct EntryBinding : public Binding { static constexpr const BindingType TYPE = BindingType::CATALOG_ENTRY; public: - EntryBinding(const string &alias, vector types, vector names, TableIndex index, + EntryBinding(const Identifier &alias, vector types, vector names, TableIndex index, StandardEntry &entry); StandardEntry &entry; @@ -113,7 +113,7 @@ struct TableBinding : public Binding { static constexpr const BindingType TYPE = BindingType::TABLE; public: - TableBinding(const string &alias, vector types, vector names, + TableBinding(const Identifier &alias, vector types, vector names, vector &bound_column_ids, optional_ptr entry, TableIndex index, virtual_column_map_t virtual_columns); @@ -125,10 +125,10 @@ struct TableBinding : public Binding { virtual_column_map_t virtual_columns; public: - unique_ptr ExpandGeneratedColumn(const string &column_name); + unique_ptr ExpandGeneratedColumn(const Identifier &column_name); BindResult Bind(ColumnRefExpression &colref, idx_t depth) override; optional_ptr GetStandardEntry() override; - ErrorData ColumnNotFoundError(const string &column_name) const override; + ErrorData ColumnNotFoundError(const Identifier &column_name) const override; // These are columns that are present in the name_map, appearing in the order that they're bound const vector &GetBoundColumnIds() const; @@ -145,12 +145,12 @@ struct DummyBinding : public Binding { static constexpr const char *DUMMY_NAME = "0_macro_parameters"; public: - DummyBinding(vector types, vector names, string dummy_name); + DummyBinding(vector types, vector names, string dummy_name); //! Arguments (for macros) vector> *arguments; //! The name of the dummy binding - string dummy_name; + Identifier dummy_name; public: //! 
Binding macros @@ -166,16 +166,16 @@ enum class CTEType { CAN_BE_REFERENCED, CANNOT_BE_REFERENCED }; struct CTEBinding; struct CTEBindState { - CTEBindState(Binder &parent_binder, QueryNode &cte_def, const vector &aliases); + CTEBindState(Binder &parent_binder, QueryNode &cte_def, const vector &aliases); ~CTEBindState(); Binder &parent_binder; QueryNode &cte_def; - const vector &aliases; + const vector &aliases; idx_t active_binder_count; shared_ptr query_binder; BoundStatement query; - vector names; + vector names; vector types; public: @@ -188,7 +188,7 @@ struct CTEBinding : public Binding { static constexpr const BindingType TYPE = BindingType::CTE; public: - CTEBinding(BindingAlias alias, vector types, vector names, TableIndex index, CTEType type); + CTEBinding(BindingAlias alias, vector types, vector names, TableIndex index, CTEType type); CTEBinding(BindingAlias alias, shared_ptr bind_state, TableIndex index); public: diff --git a/src/duckdb/src/include/duckdb/planner/table_filter.hpp b/src/duckdb/src/include/duckdb/planner/table_filter.hpp index 22910b2d0..dffb5e577 100644 --- a/src/duckdb/src/include/duckdb/planner/table_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_filter.hpp @@ -20,19 +20,19 @@ class PhysicalOperator; class PhysicalTableScan; enum class TableFilterType : uint8_t { - CONSTANT_COMPARISON = 0, // constant comparison (e.g. 
=C, >C, >=C, C, >=C, Copy() const = 0; - virtual bool Equals(const TableFilter &other) const { - return filter_type == other.filter_type; - } virtual unique_ptr ToExpression(const Expression &column) const = 0; - - virtual bool IsOnlyForZoneMapFiltering() const { - return false; - } - virtual void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp b/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp index 007351ce0..95f7385e6 100644 --- a/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_filter_state.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/planner/table_filter.hpp" -#include "duckdb/common/types/selection_vector.hpp" #include "duckdb/execution/expression_executor.hpp" namespace duckdb { @@ -35,16 +34,6 @@ struct TableFilterState { } }; -struct ConjunctionAndFilterState : public TableFilterState { -public: - vector> child_states; -}; - -struct ConjunctionOrFilterState : public TableFilterState { -public: - vector> child_states; -}; - struct ExpressionFilterState : public TableFilterState { public: ExpressionFilterState(ClientContext &context, const Expression &expression); @@ -57,18 +46,4 @@ struct ExpressionFilterState : public TableFilterState { unique_ptr executor; }; -struct JoinFilterTableFilterState final : public TableFilterState { -public: - explicit JoinFilterTableFilterState(const LogicalType &key_logical_type); - -public: - void PrepareSlicedKeys(Vector &keys_v, SelectionVector &sel, idx_t approved_tuple_count); - -public: - idx_t current_capacity; - Vector hashes_v; - Vector keys_sliced_v; - SelectionVector probe_sel; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/block_allocator.hpp b/src/duckdb/src/include/duckdb/storage/block_allocator.hpp index 608e61694..bb95b18b7 100644 --- 
a/src/duckdb/src/include/duckdb/storage/block_allocator.hpp +++ b/src/duckdb/src/include/duckdb/storage/block_allocator.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/hugeint.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/shared_ptr.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/common/unique_ptr.hpp" @@ -87,6 +88,9 @@ class BlockAllocator { unsafe_unique_ptr untouched; //! Touched by block IDs unsafe_unique_ptr touched; + + //! Token used to indicate whether current BlockAllocator is alive. + shared_ptr> alive_token; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/block_manager.hpp b/src/duckdb/src/include/duckdb/storage/block_manager.hpp index 371535a44..9b15d9597 100644 --- a/src/duckdb/src/include/duckdb/storage/block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/block_manager.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/query_context.hpp" #include "duckdb/storage/block.hpp" #include "duckdb/storage/storage_info.hpp" @@ -20,7 +21,6 @@ namespace duckdb { class BlockHandle; class BufferHandle; class BufferManager; -class ClientContext; class DatabaseInstance; class MetadataManager; diff --git a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp index d611ece60..93ad8dfcc 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp @@ -34,6 +34,7 @@ struct BufferEvictionNode { bool CanUnload(BlockMemory &memory); shared_ptr TryGetBlockMemory(); + bool IsDeadNode(optional_idx debug_sleep_micros = optional_idx()); }; //! The BufferPool is in charge of handling memory management for one or more databases. 
It defines memory limits diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp index 21b491091..b7a629c8d 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp @@ -31,7 +31,7 @@ class TableDataWriter { void WriteTableData(Serializer &metadata_serializer); virtual void WriteUnchangedTable(MetaBlockPointer pointer, const vector &metadata_pointers, - idx_t total_rows) = 0; + idx_t total_rows, idx_t next_row_id) = 0; virtual void FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) = 0; virtual unique_ptr GetRowGroupWriter(RowGroup &row_group) = 0; @@ -55,6 +55,9 @@ class TableDataWriter { bool RequireLegacyStartRow() const { return require_legacy_start_row; } + bool CanPersistRowIdGaps() const { + return can_persist_rowid_gaps; + } void SetRowIdsChanged() { row_ids_changed = true; } @@ -65,6 +68,7 @@ class TableDataWriter { AttachedDatabase &GetAttached(); DatabaseInstance &GetDatabase(); unique_ptr CreateTaskExecutor(); + optional_ptr TryGetClientContext() const; protected: DuckTableEntry &table; @@ -75,6 +79,7 @@ class TableDataWriter { optional_idx row_group_count; bool rebuild_indexes = false; bool require_legacy_start_row = false; + bool can_persist_rowid_gaps = false; atomic row_ids_changed {false}; }; @@ -85,7 +90,7 @@ class SingleFileTableDataWriter : public TableDataWriter { public: void WriteUnchangedTable(MetaBlockPointer pointer, const vector &metadata_pointers, - idx_t total_rows) override; + idx_t total_rows, idx_t next_row_id) override; void FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) override; unique_ptr GetRowGroupWriter(RowGroup &row_group) override; @@ -100,6 +105,7 @@ class 
SingleFileTableDataWriter : public TableDataWriter { //! The root pointer, if we are re-using metadata of the table MetaBlockPointer existing_pointer; optional_idx existing_rows; + optional_idx existing_next_row_id; vector existing_pointers; }; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp index c36b8d6bc..693c7848b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp @@ -25,7 +25,7 @@ struct AlpAnalyzeState : public AnalyzeState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; - explicit AlpAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), compression_data() { + explicit AlpAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager), compression_data() { } idx_t total_bytes_used = 0; @@ -36,7 +36,7 @@ struct AlpAnalyzeState : public AnalyzeState { vector> rowgroup_sample; vector> complete_vectors_sampled; alp::AlpCompressionData compression_data; - idx_t storage_version = 0; + StorageVersion storage_version = StorageVersion::INVALID; public: // Returns the required space to hyphotetically store the compressed segment @@ -66,8 +66,7 @@ struct AlpAnalyzeState : public AnalyzeState { template unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - auto state = make_uniq>(info); + auto state = make_uniq>(col_data.GetBlockManager()); state->storage_version = col_data.GetStorageManager().GetStorageVersion(); return unique_ptr(std::move(state)); } @@ -76,13 +75,14 @@ unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) * ALP Analyze step only pushes the needed samples to estimate the compression size in the finalize step */ template -bool AlpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool AlpAnalyze(AnalyzeState &state, const 
Vector &input) { if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { return false; } auto &analyze_state = state.Cast>(); + const auto count = input.size(); bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); analyze_state.vectors_count += 1; @@ -158,7 +158,9 @@ idx_t AlpFinalAnalyze(AnalyzeState &state) { analyze_state.compression_data); const idx_t uncompressed_size = AlpConstants::EXPONENT_SIZE + sizeof(T) * vector_to_compress.size(); const idx_t compressed_size = analyze_state.compression_data.RequiredSpace(); - const bool should_compress = compressed_size < uncompressed_size || analyze_state.storage_version < 7; + const bool should_compress = + compressed_size < uncompressed_size || + StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, analyze_state.storage_version); const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp index b0762f272..93532ac24 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_compress.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/operator/subtract.hpp" #include "duckdb/common/types/null_value.hpp" #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/main/config.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/compression/alp/algorithm/alp.hpp" @@ -19,28 +20,24 @@ #include "duckdb/storage/compression/patas/patas.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" #include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" namespace duckdb { 
template -struct AlpCompressionState : public CompressionState { +struct AlpCompressionState : public StandardCompressionState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; AlpCompressionState(ColumnDataCheckpointData &checkpoint_data, AlpAnalyzeState *analyze_state) - : CompressionState(analyze_state->info), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ALP)) { + : StandardCompressionState(checkpoint_data, CompressionType::COMPRESSION_ALP) { CreateEmptySegment(); //! Combinations found on the analyze step are needed for compression compression_data.best_k_combinations = analyze_state->compression_data.best_k_combinations; } - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - + StatsWriter stats_writer; idx_t vector_idx = 0; idx_t nulls_idx = 0; idx_t vectors_flushed = 0; @@ -76,15 +73,7 @@ struct AlpCompressionState : public CompressionState { } void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); - - auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); - handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); // The pointer to the start of the compressed data. 
data_ptr = handle.GetDataMutable() + current_segment->GetBlockOffset() + AlpConstants::HEADER_SIZE; @@ -103,7 +92,8 @@ struct AlpCompressionState : public CompressionState { const idx_t compressed_size = compression_data.RequiredSpace(); const auto storage_version = checkpoint_data.GetStorageManager().GetStorageVersion(); - const bool should_compress = compressed_size < uncompressed_size || storage_version < 7; + const bool should_compress = compressed_size < uncompressed_size || + StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, storage_version); const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; @@ -114,12 +104,12 @@ struct AlpCompressionState : public CompressionState { } if (nulls_idx) { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } if (vector_idx != nulls_idx) { //! At least there is one valid value in the vector - current_segment->stats.statistics.SetHasNoNullFast(); + stats_writer.SetHasValid(); for (idx_t i = 0; i < vector_idx; i++) { - current_segment->stats.statistics.UpdateNumericStats(input_vector[i]); + stats_writer.UpdateMinMax(input_vector[i]); } } current_segment->count += vector_idx; @@ -205,7 +195,6 @@ struct AlpCompressionState : public CompressionState { } void FlushSegment() { - auto &checkpoint_state = checkpoint_data.GetCheckpointState(); auto dataptr = handle.GetDataMutable(); idx_t metadata_offset = AlignValue(UsedSpace()); @@ -238,7 +227,7 @@ struct AlpCompressionState : public CompressionState { // Store the offset to the end of metadata (to be used as a backwards pointer in decoding) Store(NumericCast(total_segment_size), dataptr); - checkpoint_state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); + FlushCurrentSegment(stats_writer, total_segment_size); data_bytes_used = 0; vectors_flushed = 0; } @@ -298,11 +287,11 @@ unique_ptr AlpInitCompression(ColumnDataCheckpointData &checkp } template -void AlpCompress(CompressionState 
&state_p, Vector &scan_vector, idx_t count) { +void AlpCompress(CompressionState &state_p, const Vector &scan_vector) { auto &state = (AlpCompressionState &)state_p; UnifiedVectorFormat vdata; scan_vector.ToUnifiedFormat(vdata); - state.Append(vdata, count); + state.Append(vdata, scan_vector.size()); } template diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp index f8f2916f6..b082d5810 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp @@ -63,8 +63,8 @@ struct AlpScanState : public SegmentScanState { using EXACT_TYPE = typename FloatingToExact::TYPE; explicit AlpScanState(ColumnSegment &segment) : segment(segment), count(segment.count) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + handle = buffer_manager.Pin(segment.GetBlockHandle()); // ScanStates never exceed the boundaries of a Segment, // but are not guaranteed to start at the beginning of the Block segment_data = handle.GetDataMutable() + segment.GetBlockOffset(); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp index 5f8d91d54..730cbf203 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp @@ -26,7 +26,7 @@ struct AlpRDAnalyzeState : public AnalyzeState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; - explicit AlpRDAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), compression_data() { + explicit AlpRDAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager), compression_data() { } idx_t 
vectors_count = 0; @@ -38,15 +38,14 @@ struct AlpRDAnalyzeState : public AnalyzeState { template unique_ptr AlpRDInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq>(info); + return make_uniq>(col_data.GetBlockManager()); } /* * ALPRD Analyze step only pushes the needed samples to estimate the compression size in the finalize step */ template -bool AlpRDAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool AlpRDAnalyze(AnalyzeState &state, const Vector &input) { if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { return false; } @@ -54,6 +53,7 @@ bool AlpRDAnalyze(AnalyzeState &state, Vector &input, idx_t count) { using EXACT_TYPE = typename FloatingToExact::TYPE; auto &analyze_state = state.Cast>(); + const auto count = input.size(); bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); analyze_state.vectors_count += 1; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp index 50e920bbe..6ed07ae63 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_compress.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/helper.hpp" #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/main/config.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/compression/alp/alp_constants.hpp" @@ -23,13 +24,12 @@ namespace duckdb { template -struct AlpRDCompressionState : public CompressionState { +struct AlpRDCompressionState : public StandardCompressionState { public: using EXACT_TYPE = typename FloatingToExact::TYPE; 
AlpRDCompressionState(ColumnDataCheckpointData &checkpoint_data, AlpRDAnalyzeState *analyze_state) - : CompressionState(analyze_state->info), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ALPRD)) { + : StandardCompressionState(checkpoint_data, CompressionType::COMPRESSION_ALPRD) { //! State variables from the analyze step that are needed for compression compression_data.left_parts_dict_map = std::move(analyze_state->compression_data.left_parts_dict_map); compression_data.left_bit_width = analyze_state->compression_data.left_bit_width; @@ -43,11 +43,7 @@ struct AlpRDCompressionState : public CompressionState { CreateEmptySegment(); } - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - + StatsWriter stats_writer; idx_t vector_idx = 0; idx_t nulls_idx = 0; idx_t vectors_flushed = 0; @@ -85,15 +81,7 @@ struct AlpRDCompressionState : public CompressionState { } void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); - - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); // The pointer to the start of the compressed data. 
data_ptr = handle.GetDataMutable() + current_segment->GetBlockOffset() + AlpRDConstants::HEADER_SIZE + @@ -114,7 +102,8 @@ struct AlpRDCompressionState : public CompressionState { const idx_t compressed_size = compression_data.RequiredSpace(); const auto storage_version = checkpoint_data.GetStorageManager().GetStorageVersion(); - const bool should_compress = compressed_size < uncompressed_size || storage_version < 7; + const bool should_compress = compressed_size < uncompressed_size || + StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, storage_version); const idx_t vector_size = should_compress ? compressed_size : uncompressed_size; @@ -124,13 +113,13 @@ struct AlpRDCompressionState : public CompressionState { CreateEmptySegment(); } if (nulls_idx) { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } if (vector_idx != nulls_idx) { //! At least there is one valid value in the vector - current_segment->stats.statistics.SetHasNoNullFast(); + stats_writer.SetHasValid(); for (idx_t i = 0; i < vector_idx; i++) { T floating_point_value = Load(const_data_ptr_cast(&input_vector[i])); - current_segment->stats.statistics.UpdateNumericStats(floating_point_value); + stats_writer.UpdateMinMax(floating_point_value); } } current_segment->count += vector_idx; @@ -203,7 +192,6 @@ struct AlpRDCompressionState : public CompressionState { } void FlushSegment() { - auto &checkpoint_state = checkpoint_data.GetCheckpointState(); auto dataptr = handle.GetDataMutable(); idx_t metadata_offset = AlignValue(UsedSpace()); @@ -252,7 +240,7 @@ struct AlpRDCompressionState : public CompressionState { // Store the Dictionary memcpy((void *)dataptr, (void *)compression_data.left_parts_dict, actual_dictionary_size_bytes); - checkpoint_state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); + FlushCurrentSegment(stats_writer, total_segment_size); data_bytes_used = 0; vectors_flushed = 0; } @@ -310,11 +298,11 @@ unique_ptr 
AlpRDInitCompression(ColumnDataCheckpointData &chec } template -void AlpRDCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void AlpRDCompress(CompressionState &state_p, const Vector &scan_vector) { auto &state = (AlpRDCompressionState &)state_p; UnifiedVectorFormat vdata; scan_vector.ToUnifiedFormat(vdata); - state.Append(vdata, count); + state.Append(vdata, scan_vector.size()); } template diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp index 39b90c822..a911e1068 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp @@ -68,9 +68,9 @@ struct AlpRDScanState : public SegmentScanState { using EXACT_TYPE = typename FloatingToExact::TYPE; explicit AlpRDScanState(ColumnSegment &segment) : segment(segment), count(segment.count) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(segment.GetBlockHandle()); // ScanStates never exceed the boundaries of a Segment, // but are not guaranteed to start at the beginning of the Block segment_data = handle.GetDataMutable() + segment.GetBlockOffset(); diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp index 7a38c065e..a31eac7ad 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/flag_buffer.hpp @@ -8,7 +8,6 @@ #pragma once -#include "duckdb.h" #include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" #ifdef DEBUG #include "duckdb/common/vector.hpp" diff --git 
a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp index 4976e2ca8..f67d23058 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_analyze.hpp @@ -25,7 +25,7 @@ unique_ptr ChimpInitAnalyze(ColumnData &col_data, PhysicalType typ } template -bool ChimpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool ChimpAnalyze(AnalyzeState &state, const Vector &input) { throw InternalException("Chimp has been deprecated, can no longer be used to compress data"); return false; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp index 833cc5567..0bdd80056 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_compress.hpp @@ -30,7 +30,7 @@ unique_ptr ChimpInitCompression(ColumnDataCheckpointData &chec } template -void ChimpCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void ChimpCompress(CompressionState &state_p, const Vector &scan_vector) { throw InternalException("Chimp has been deprecated, can no longer be used to compress data"); } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp index 7535482ee..97fbaf9b8 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp @@ -133,9 +133,9 @@ struct ChimpScanState : public SegmentScanState { using CHIMP_TYPE = typename ChimpType::TYPE; explicit ChimpScanState(ColumnSegment &segment) : segment(segment), segment_count(segment.count) { - auto &buffer_manager = 
BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(segment.GetBlockHandle()); auto dataptr = handle.GetDataMutable(); // ScanStates never exceed the boundaries of a Segment, // but are not guaranteed to start at the beginning of the Block diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/analyze.hpp index 607d1a78c..69d2c1ea6 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/analyze.hpp @@ -12,10 +12,10 @@ namespace dict_fsst { //===--------------------------------------------------------------------===// struct DictFSSTAnalyzeState : public AnalyzeState { public: - explicit DictFSSTAnalyzeState(const CompressionInfo &info); + explicit DictFSSTAnalyzeState(BlockManager &block_manager); public: - bool Analyze(Vector &input, idx_t count); + bool Analyze(const Vector &input); idx_t FinalAnalyze(); public: diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp index 7a0d61c03..10ea4a4f5 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp @@ -6,6 +6,7 @@ #include "duckdb/storage/compression/dict_fsst/analyze.hpp" #include "duckdb/function/compression_function.hpp" #include "duckdb/common/bitpacking.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" namespace duckdb { @@ -33,7 +34,7 @@ struct EncodedInput { //===--------------------------------------------------------------------===// // Compress 
//===--------------------------------------------------------------------===// -struct DictFSSTCompressionState : public CompressionState { +struct DictFSSTCompressionState : public StandardCompressionState { public: DictFSSTCompressionState(ColumnDataCheckpointData &checkpoint_data_p, unique_ptr &&state); ~DictFSSTCompressionState() override; @@ -49,16 +50,12 @@ struct DictFSSTCompressionState : public CompressionState { bool CompressInternal(UnifiedVectorFormat &vector_format, const string_t &str, bool is_null, EncodedInput &encoded_input, const idx_t i, idx_t count, bool fail_on_no_space); - void Compress(Vector &scan_vector, idx_t count); + void Compress(const Vector &scan_vector); void FinalizeCompress(); void Flush(bool final); public: - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - // State regarding current segment - unique_ptr current_segment; - BufferHandle current_handle; + StatsWriter stats_writer; //! Offset at which to write the next dictionary string idx_t dictionary_offset = 0; diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp index 14565e476..62821e65b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/analyze.hpp @@ -9,41 +9,30 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Analyze //===--------------------------------------------------------------------===// -struct DictionaryAnalyzeState : public DictionaryCompressionState { +struct DictionaryAnalyzeState : public AnalyzeState { public: - explicit DictionaryAnalyzeState(const CompressionInfo &info); + explicit DictionaryAnalyzeState(BlockManager &block_manager); public: - bool LookupString(string_t str) override; - void AddNewString(string_t str) override; - void AddLastLookup() override; - 
void AddNull() override; - bool CalculateSpaceRequirements(bool new_string, idx_t string_size) override; - void Flush(bool final = false) override; - void Verify() override; + bool LookupString(string_t str); + void AddNewString(string_t str); + void AddLastLookup(); + void AddNull(); + bool CalculateSpaceRequirements(bool new_string, idx_t string_size); + void Flush(bool final = false); + void Verify(); void UpdateMaxUniqueCount(); public: idx_t segment_count; idx_t current_tuple_count; idx_t current_unique_count; - idx_t max_unique_count_across_segments = - 0; // Is used to allocate the dictionary optimally later on at the InitCompression step + //! Used to allocate the dictionary optimally later on at the InitCompression step + idx_t max_unique_count_across_segments = 0; idx_t current_dict_size; StringHeap heap; string_set_t current_set; bitpacking_width_t current_width; bitpacking_width_t next_width; }; - -struct DictionaryCompressionAnalyzeState : public AnalyzeState { -public: - explicit DictionaryCompressionAnalyzeState(const CompressionInfo &info) - : AnalyzeState(info), analyze_state(make_uniq(info)) { - } - -public: - unique_ptr analyze_state; -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/common.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/common.hpp index 79bd094bf..1845e504e 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/common.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/common.hpp @@ -29,32 +29,50 @@ struct DictionaryCompression { static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer container); -}; -//! Abstract class managing the compression state for size analysis or compression. 
-class DictionaryCompressionState : public CompressionState { -public: - explicit DictionaryCompressionState(const CompressionInfo &info); - ~DictionaryCompressionState() override; + template + static bool UpdateState(T &state, const Vector &scan_vector) { + state.Verify(); -public: - bool UpdateState(Vector &scan_vector, idx_t count); - -protected: - // Should verify the State - virtual void Verify() = 0; - // Performs a lookup of str, storing the result internally - virtual bool LookupString(string_t str) = 0; - // Add the most recently looked up str to compression state - virtual void AddLastLookup() = 0; - // Add string to the state that is known to not be seen yet - virtual void AddNewString(string_t str) = 0; - // Add a null value to the compression state - virtual void AddNull() = 0; - // Needs to be called before adding a value. Will return false if a flush is required first. - virtual bool CalculateSpaceRequirements(bool new_string, idx_t string_size) = 0; - // Flush the segment to disk if compressing or reset the counters if analyzing - virtual void Flush(bool final = false) = 0; + for (auto entry : scan_vector.Values()) { + idx_t string_size = 0; + bool new_string = false; + auto row_is_valid = entry.IsValid(); + + if (row_is_valid) { + auto &str = entry.GetValue(); + string_size = str.GetSize(); + if (string_size >= StringUncompressed::GetStringBlockLimit(state.info.GetBlockSize())) { + // Big strings not implemented for dictionary compression + return false; + } + new_string = !state.LookupString(str); + } + + bool fits = state.CalculateSpaceRequirements(new_string, string_size); + if (!fits) { + state.Flush(); + new_string = true; + + fits = state.CalculateSpaceRequirements(new_string, string_size); + if (!fits) { + throw InternalException("Dictionary compression could not write to new segment"); + } + } + + if (!row_is_valid) { + state.AddNull(); + } else if (new_string) { + state.AddNewString(entry.GetValue()); + } else { + state.AddLastLookup(); + 
} + + state.Verify(); + } + + return true; + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp index 11852a66d..61592a8cf 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dictionary/compression.hpp @@ -4,6 +4,7 @@ #include "duckdb/common/typedefs.hpp" #include "duckdb/storage/compression/dictionary/common.hpp" #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" namespace duckdb { @@ -21,31 +22,27 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Compress //===--------------------------------------------------------------------===// -struct DictionaryCompressionCompressState : public DictionaryCompressionState { +struct DictionaryCompressionCompressState : public StandardCompressionState { public: - DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info, + DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, idx_t max_unique_count_across_all_segments); public: void CreateEmptySegment(); - void Verify() override; - bool LookupString(string_t str) override; - void AddNewString(string_t str) override; - void AddNull() override; - void AddLastLookup() override; - bool CalculateSpaceRequirements(bool new_string, idx_t string_size) override; - void Flush(bool final = false) override; + void Verify(); + bool LookupString(string_t str); + void AddNewString(string_t str); + void AddNull(); + void AddLastLookup(); + bool CalculateSpaceRequirements(bool new_string, idx_t string_size); + void Flush(bool final = false); idx_t Finalize(); public: - ColumnDataCheckpointData 
&checkpoint_data; - const CompressionFunction &function; - // State regarding current segment - unique_ptr current_segment; - BufferHandle current_handle; StringDictionaryContainer current_dictionary; data_ptr_t current_end_ptr; + StatsWriter stats_writer; // Buffers and map for current segment PrimitiveDictionary current_string_map; diff --git a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp index 83c83c5ba..3d13ccad0 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp @@ -1,6 +1,7 @@ #pragma once #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" #include "duckdb/storage/table/column_data.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" @@ -11,17 +12,13 @@ class EmptyValidityCompression { public: struct EmptyValidityCompressionState : public CompressionState { public: - explicit EmptyValidityCompressionState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info) - : CompressionState(info), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_EMPTY)), - checkpoint_data(checkpoint_data) { + explicit EmptyValidityCompressionState(ColumnDataCheckpointData &checkpoint_data) + : CompressionState(checkpoint_data, CompressionType::COMPRESSION_EMPTY) { } ~EmptyValidityCompressionState() override { } public: - optional_ptr function; - ColumnDataCheckpointData &checkpoint_data; idx_t count = 0; idx_t non_nulls = 0; }; @@ -43,12 +40,13 @@ class EmptyValidityCompression { public: static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state_p) { - return make_uniq(checkpoint_data, state_p->info); + return make_uniq(checkpoint_data); } - static void Compress(CompressionState &state_p, Vector 
&scan_vector, idx_t count) { + static void Compress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast(); UnifiedVectorFormat format; scan_vector.ToUnifiedFormat(format); + const auto count = scan_vector.size(); state.non_nulls += format.validity.CountValid(count); state.count += count; } @@ -56,22 +54,20 @@ class EmptyValidityCompression { auto &state = state_p.Cast(); auto &checkpoint_data = state.checkpoint_data; - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto &info = state.info; - auto compressed_segment = ColumnSegment::CreateTransientSegment(db, *state.function, type, info.GetBlockSize(), - info.GetBlockManager()); + auto compressed_segment = state.CreateNewSegment(); compressed_segment->count = state.count; + + StatsWriter stats_writer; if (state.non_nulls != state.count) { - compressed_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } if (state.non_nulls != 0) { - compressed_segment->stats.statistics.SetHasNoNullFast(); + stats_writer.SetHasValid(); } + stats_writer.Merge(compressed_segment->GetStatsMutable()); auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); - auto handle = buffer_manager.Pin(compressed_segment->block); + auto handle = buffer_manager.Pin(compressed_segment->GetBlockHandle()); auto &checkpoint_state = checkpoint_data.GetCheckpointState(); checkpoint_state.FlushSegment(std::move(compressed_segment), std::move(handle), 0); diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp index c86d2538d..1801854fc 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_analyze.hpp @@ -25,7 +25,7 @@ unique_ptr PatasInitAnalyze(ColumnData &col_data, PhysicalType typ } template -bool PatasAnalyze(AnalyzeState 
&state, Vector &input, idx_t count) { +bool PatasAnalyze(AnalyzeState &state, const Vector &input) { throw InternalException("Patas has been deprecated, can no longer be used to compress data"); return false; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp index 9770a3d47..67abf7fc1 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_compress.hpp @@ -29,7 +29,7 @@ unique_ptr PatasInitCompression(ColumnDataCheckpointData &chec } template -void PatasCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void PatasCompress(CompressionState &state_p, const Vector &scan_vector) { throw InternalException("Patas has been deprecated, can no longer be used to compress data"); } diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index 10a1b7491..e5dee8903 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -88,9 +88,9 @@ struct PatasScanState : public SegmentScanState { using EXACT_TYPE = typename FloatingToExact::TYPE; explicit PatasScanState(ColumnSegment &segment) : segment(segment), count(segment.count) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(segment.GetBlockHandle()); // ScanStates never exceed the boundaries of a Segment, // but are not guaranteed to start at the beginning of the Block segment_data = handle.GetDataMutable() + segment.GetBlockOffset(); diff --git a/src/duckdb/src/include/duckdb/storage/compression/roaring/appender.hpp 
b/src/duckdb/src/include/duckdb/storage/compression/roaring/appender.hpp index 11af88125..368e7b79b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/roaring/appender.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/roaring/appender.hpp @@ -45,7 +45,8 @@ struct RoaringStateAppender { } } - static void AppendVector(STATE_TYPE &state, Vector &input, idx_t input_size) { + static void AppendVector(STATE_TYPE &state, const Vector &input) { + const idx_t input_size = input.size(); UnifiedVectorFormat unified; input.ToUnifiedFormat(unified); auto &validity = unified.validity; diff --git a/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp b/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp index ab249bedf..ff69c4b41 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/roaring/roaring.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/function/compression_function.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/storage/table/scan_state.hpp" namespace duckdb { @@ -206,7 +207,7 @@ struct BitmaskTableEntry { //===--------------------------------------------------------------------===// struct RoaringAnalyzeState : public AnalyzeState { public: - explicit RoaringAnalyzeState(const CompressionInfo &info); + explicit RoaringAnalyzeState(BlockManager &block_manager); public: // RoaringStateAppender interface: @@ -224,7 +225,7 @@ struct RoaringAnalyzeState : public AnalyzeState { void FlushSegment(); void FlushContainer(); template - void Analyze(Vector &input, idx_t count) { + void Analyze(const Vector &input) { static_assert(AlwaysFalse>::VALUE, "No specialization exists for this type"); } @@ -265,9 +266,9 @@ struct RoaringAnalyzeState : public AnalyzeState { vector container_metadata; }; template <> -void 
RoaringAnalyzeState::Analyze(Vector &input, idx_t count); +void RoaringAnalyzeState::Analyze(const Vector &input); template <> -void RoaringAnalyzeState::Analyze(Vector &input, idx_t count); +void RoaringAnalyzeState::Analyze(const Vector &input); //===--------------------------------------------------------------------===// // Compress @@ -327,7 +328,7 @@ struct ContainerCompressionState { append_func_t append_function; }; -struct RoaringCompressState : public CompressionState { +struct RoaringCompressState : public StandardCompressionState { public: explicit RoaringCompressState(ColumnDataCheckpointData &checkpoint_data, unique_ptr analyze_state_p); @@ -351,9 +352,9 @@ struct RoaringCompressState : public CompressionState { void Finalize(); void FlushContainer(); void NextContainer(); - void Compress(Vector &input, idx_t count); + void Compress(const Vector &input); template - void Compress(Vector &input, idx_t count) { + void Compress(const Vector &input) { static_assert(AlwaysFalse>::VALUE, "No specialization exists for this type"); } @@ -366,23 +367,22 @@ struct RoaringCompressState : public CompressionState { ContainerMetadataCollection metadata_collection; vector &container_metadata; - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - // Ptr to next free spot in segment; data_ptr_t data_ptr; // Ptr to next free spot for storing data_ptr_t metadata_ptr; //! The amount of values already compressed idx_t total_count = 0; + //! Stats writer that only tracks NULL/NOT NULL + StatsWriter stats_writer; + //! 
Stats writer (only used for boolean) + StatsWriter bool_stats_writer; }; template <> -void RoaringCompressState::Compress(Vector &input, idx_t count); +void RoaringCompressState::Compress(const Vector &input); template <> -void RoaringCompressState::Compress(Vector &input, idx_t count); +void RoaringCompressState::Compress(const Vector &input); //===--------------------------------------------------------------------===// // Scan @@ -634,20 +634,20 @@ struct RoaringScanState : public SegmentScanState { template static void BitPackBooleans(data_ptr_t dst, const bool *src, const idx_t count, - const ValidityMask *validity_mask = nullptr, BaseStatistics *statistics = nullptr) { + const ValidityMask *validity_mask = nullptr, StatsWriter *statistics = nullptr) { uint8_t byte = 0; int bit_pos = 0; uint8_t src_bit = false; if (ALL_VALID) { if (UPDATE_STATS) { - statistics->SetHasNoNullFast(); + statistics->SetHasValid(); } for (idx_t i = 0; i < count; i++) { src_bit = src[i]; if (UPDATE_STATS) { - statistics->UpdateNumericStats(src_bit); + statistics->Update(src_bit); } byte |= src_bit << bit_pos; bit_pos++; @@ -673,10 +673,10 @@ static void BitPackBooleans(data_ptr_t dst, const bool *src, const idx_t count, if (UPDATE_STATS) { if (valid) { - statistics->UpdateNumericStats(src_bit); - statistics->SetHasNoNullFast(); + statistics->Update(src_bit); + statistics->SetHasValid(); } else { - statistics->SetHasNullFast(); + statistics->SetHasNull(); } } diff --git a/src/duckdb/src/include/duckdb/storage/compression/standard_compression_state.hpp b/src/duckdb/src/include/duckdb/storage/compression/standard_compression_state.hpp new file mode 100644 index 000000000..9b9b777de --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/compression/standard_compression_state.hpp @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/compression/standard_compression_state.hpp +// +// 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" + +namespace duckdb { + +struct StandardCompressionState : public CompressionState { + explicit StandardCompressionState(ColumnDataCheckpointData &checkpoint_data, CompressionType compression_type) + : CompressionState(checkpoint_data, compression_type) { + } + ~StandardCompressionState() override; + + void CreateAndPinNewSegment(); + void FlushCurrentSegment(idx_t segment_size); + template + void FlushCurrentSegment(T &stats_writer, idx_t segment_size) { + // merge the stats into the segment stats + stats_writer.Merge(current_segment->GetStatsMutable()); + // flush the segment + FlushCurrentSegment(segment_size); + + // clear the stats writer for the next segment + stats_writer.Clear(); + } + +public: + unique_ptr current_segment; + BufferHandle handle; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp index 21560ba00..6cd69afe2 100644 --- a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -17,6 +17,8 @@ #include "duckdb/storage/block.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/storage_info.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" namespace duckdb { @@ -78,7 +80,13 @@ struct RowGroupPointer { bool has_metadata_blocks = false; //! Metadata blocks of the columns that are not mentioned in "data_pointers" //! This is often empty - but can be set for wide columns with a lot of metadata + //! When targeting 2.0 storage format, per_column_metadata_blocks is used instead vector extra_metadata_blocks; + //! 
Whether or not we have per-column metadata blocks + bool has_per_column_metadata_blocks = false; + //! Per-column metadata blocks beyond the start block + //! Each column entry contains the additional block IDs that the column's metadata spans (excluding the start block) + PerColumnMetadataBlocks per_column_metadata_blocks; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp index 66d5deb2a..7d22e8705 100644 --- a/src/duckdb/src/include/duckdb/storage/data_table.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" #include "duckdb/common/unique_ptr.hpp" #include "duckdb/storage/table/data_table_info.hpp" #include "duckdb/storage/table/persistent_table_data.hpp" @@ -45,6 +46,8 @@ struct ConstraintState; struct TableUpdateState; struct OptimisticWriteCollection; struct ColumnFetchState; +struct ColumnSegmentInfo; +struct ColumnSegmentInfoScanState; struct DataTableInfo; struct LocalAppendState; struct ParallelTableScanState; @@ -246,8 +249,16 @@ class DataTable : public enable_shared_from_this { idx_t ColumnCount() const; idx_t GetTotalRows() const; + idx_t GetNextRowId() const; + idx_t GetRowGroupCount() const; + idx_t GetRowGroupCountWithLocalStorage(ClientContext &context); - vector GetColumnSegmentInfo(const QueryContext &context); + vector + GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options = ColumnSegmentInfoScanOptions {}); + void InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state); + bool ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result); //! 
Scans the next chunk for the CREATE INDEX operator bool CreateIndexScan(TableScanState &state, DataChunk &result); @@ -264,6 +275,12 @@ class DataTable : public enable_shared_from_this { shared_ptr &GetDataTableInfo(); + //! Direct access to the row group collection. Intended for extensions that need to walk storage internals; + //! prefer the higher-level DataTable API for normal use. + const shared_ptr &GetRowGroupCollection() const { + return row_groups; + } + void BindIndexes(ClientContext &context); bool HasIndexes() const; bool HasUniqueIndexes() const; @@ -274,8 +291,8 @@ class DataTable : public enable_shared_from_this { void CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_t count); void Destroy(); - string GetTableName() const; - void SetTableName(string new_name); + Identifier GetTableName() const; + void SetTableName(Identifier new_name); TableStorageInfo GetStorageInfo(); diff --git a/src/duckdb/src/include/duckdb/storage/database_handle.hpp b/src/duckdb/src/include/duckdb/storage/database_handle.hpp new file mode 100644 index 000000000..aa28da48f --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/database_handle.hpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/database_handle.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/block_manager.hpp" +#include "duckdb/storage/storage_options.hpp" +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/memory_mapped_file.hpp" + +namespace duckdb { +struct StorageManagerOptions; + +enum class DatabaseOpenMode { OPEN_EXISTING_FILE, CREATE_NEW_FILE }; + +struct DatabaseHandle { + explicit DatabaseHandle(unique_ptr handle); + explicit DatabaseHandle(unique_ptr mmap_handle); + + static unique_ptr Open(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, DatabaseOpenMode 
open_mode); + bool OnDiskFile() const; + + void CheckMagicBytes(QueryContext context); + void Read(QueryContext context, FileBuffer &block, uint64_t location) const; + void Write(QueryContext context, FileBuffer &block, uint64_t location); + void Sync(); + + void Truncate(idx_t new_size); + void Trim(idx_t offset, idx_t length); + + FileHandle &GetFileHandle(); + +private: + //! Throws if a write at [required_size] would exceed the mmap reserve. No-op otherwise. + void EnsureMappedSize(idx_t required_size) const; + + static unique_ptr OpenMemoryMap(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, DatabaseOpenMode open_mode); + static unique_ptr OpenFile(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, DatabaseOpenMode open_mode); + +private: + //! The file handle + unique_ptr handle; + //! Memory-mapped view of the file in MAP mode. Mutually exclusive with `handle`. + unique_ptr mmap_handle; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/external_file_cache/caching_file_system.hpp b/src/duckdb/src/include/duckdb/storage/external_file_cache/caching_file_system.hpp index bea7ab5ea..9deafa37c 100644 --- a/src/duckdb/src/include/duckdb/storage/external_file_cache/caching_file_system.hpp +++ b/src/duckdb/src/include/duckdb/storage/external_file_cache/caching_file_system.hpp @@ -21,12 +21,14 @@ namespace duckdb { +class Allocator; class BufferHandle; class ClientContext; class DatabaseInstance; class FileOpenFlags; class FileSystem; struct FileHandle; +struct NetworkThroughputEstimate; class QueryContext; class CachingFileSystem; @@ -36,12 +38,14 @@ struct CachingFileHandle { public: DUCKDB_API CachingFileHandle(QueryContext context, CachingFileSystem &caching_file_system, const OpenFileInfo &path, - FileOpenFlags flags, optional_ptr opener, CachedFile &cached_file); + FileOpenFlags flags, optional_ptr opener); DUCKDB_API ~CachingFileHandle(); public: //! 
Get the underlying FileHandle DUCKDB_API FileHandle &GetFileHandle(); + //! Get the buffer-manager-backed Allocator. + DUCKDB_API Allocator &GetBufferAllocator() const; //! Read [nr_bytes] bytes at the requested [location]. //! Returns a buffer handle group that keeps the data pinned in memory. DUCKDB_API FileBufferHandleGroup Read(idx_t nr_bytes, idx_t location); @@ -56,9 +60,16 @@ struct CachingFileHandle { DUCKDB_API bool CanSeek(); DUCKDB_API bool IsRemoteFile() const; DUCKDB_API bool OnDiskFile(); + DUCKDB_API bool TryGetNetworkThroughput(NetworkThroughputEstimate &result); DUCKDB_API idx_t SeekPosition(); DUCKDB_API void Seek(idx_t location); +private: + //! Remove the 'force_full_download' option from the file handle if present, and return whether it was present + bool StripForceFullDownloadIfPresent(); + //! Refresh the cached file if the global cache state has changed. + shared_ptr EnsureCachedFileCurrent(); + private: QueryContext context; @@ -74,8 +85,8 @@ struct CachingFileHandle { optional_ptr opener; //! Cache validation mode for this file CacheValidationMode validate; - //! The associated CachedFile with cached blocks - CachedFile &cached_file; + //! Associated cached file. + shared_ptr cached_file; //! Used to ensure file handle and cached file metadata is only initialized once. 
annotated_mutex file_handle_mutex; diff --git a/src/duckdb/src/include/duckdb/storage/external_file_cache/external_file_cache.hpp b/src/duckdb/src/include/duckdb/storage/external_file_cache/external_file_cache.hpp index 59354525f..ac562780d 100644 --- a/src/duckdb/src/include/duckdb/storage/external_file_cache/external_file_cache.hpp +++ b/src/duckdb/src/include/duckdb/storage/external_file_cache/external_file_cache.hpp @@ -18,6 +18,7 @@ #include "duckdb/common/typedefs.hpp" #include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/unordered_map.hpp" +#include "duckdb/common/unordered_set.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/winapi.hpp" #include "duckdb/storage/buffer/temporary_file_information.hpp" @@ -38,13 +39,14 @@ class ExternalFileCache { //! Cached files struct CachedFile { public: - explicit CachedFile(string path_p); + CachedFile(string path_p, idx_t generation_p); public: //! Whether the CachedFile is still valid given the current modified/version tag bool IsValid(bool validate, const string &current_version_tag, timestamp_t current_last_modified); const string path; + const idx_t generation; mutable annotated_mutex map_lock; //! The block size used to index the current block map. Invalid if no blocks have been cached yet. @@ -69,7 +71,10 @@ class ExternalFileCache { bool IsEnabled() const; void SetEnabled(bool enable); + idx_t GetGeneration() const; vector GetCachedFileInformation() const; + //! Number of files tracked in the ObjectCache, exposed for testing. + idx_t GetCachedFileCount() const; //! Re-index to `current_block_size` if it differs from the cache block size. //! Return the blocks cached for the given range. @@ -77,25 +82,38 @@ class ExternalFileCache { idx_t first_block, idx_t num_blocks); BufferManager &GetBufferManager() const; - //! Gets the cached file, or creates it if is not yet present - CachedFile &GetOrCreateCachedFile(const string &path); + //! 
Gets the shared cached file for the given path, creating it if not yet present. + //! When caching is disabled, returns a transient CachedFile that is not tracked in the cached file map. + shared_ptr GetOrCreateCachedFile(const string &path); DUCKDB_API static bool IsValid(bool validate, const string &cached_version_tag, timestamp_t cached_last_modified, const string &current_version_tag, timestamp_t current_last_modified); private: + class ExternalFileCacheObjectCacheEntry; + //! Re-index blocks of a single cached file. void ReindexCachedFileCore(CachedFile &cached_file, idx_t file_size, idx_t old_block_size, idx_t new_block_size) DUCKDB_REQUIRES(cached_file.map_lock); + //! Registers a cached file path in the tracked set. + void InsertCachedFileKey(const string &path); + //! Removes a cached file path from the tracked set. + void EraseCachedFileKey(const string &path); + //! Delete the ObjectCache entries for the given cached file paths. + void DeleteObjectCacheEntries(const vector &paths); + //! The BufferManager used to cache files BufferManager &buffer_manager; //! Whether or not file caching is enabled atomic enable; - //! Mapping from file path to cached file with cached blocks - unordered_map> cached_files; - //! Lock for accessing the cached files - mutable mutex lock; + //! Generation counter, incremented whenever cache enablement changes. + atomic generation; + //! Paths of the cached files tracked in the ObjectCache. + //! Entries should only be inserted at `GetOrCreateCachedFile` and deleted at object cache entry deletion. + unordered_set cached_file_keys DUCKDB_GUARDED_BY(lock); + //! Lock for accessing cached_file_keys. 
+ mutable annotated_mutex lock; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/index.hpp b/src/duckdb/src/include/duckdb/storage/index.hpp index d14f22692..53004460c 100644 --- a/src/duckdb/src/include/duckdb/storage/index.hpp +++ b/src/duckdb/src/include/duckdb/storage/index.hpp @@ -52,7 +52,7 @@ class Index { virtual const string &GetIndexType() const = 0; //! The name of the index - virtual const string &GetIndexName() const = 0; + virtual const Identifier &GetIndexName() const = 0; //! The index constraint type virtual IndexConstraintType GetConstraintType() const = 0; diff --git a/src/duckdb/src/include/duckdb/storage/index_storage_info.hpp b/src/duckdb/src/include/duckdb/storage/index_storage_info.hpp index 838662787..704b6c809 100644 --- a/src/duckdb/src/include/duckdb/storage/index_storage_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/index_storage_info.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/storage/block.hpp" @@ -41,7 +42,7 @@ struct IndexBufferInfo { //! Index (de)serialization information. struct IndexStorageInfo { IndexStorageInfo() {}; - explicit IndexStorageInfo(const string &name) : name(name) {}; + explicit IndexStorageInfo(const Identifier &name) : name(name) {}; //! Disable copy constructor and copy assignment, this type's lifetime is explicitly managed. IndexStorageInfo(const IndexStorageInfo &) = delete; @@ -51,7 +52,7 @@ struct IndexStorageInfo { IndexStorageInfo &operator=(IndexStorageInfo &&) = default; //! The name. - string name; + Identifier name; //! The storage root. idx_t root; //! Any index specialization can provide additional key-Value settings via this map. 
diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp index 29d160260..65f5a2414 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -80,6 +80,7 @@ class MetadataManager { //! Flush all blocks to disk void Flush(); + bool BlockIsModified(const MetaBlockPointer &ptr); bool BlockHasBeenCleared(const MetaBlockPointer &ptr); void MarkBlocksAsModified(); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp index ce8d01b41..cb94abfb8 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp @@ -34,9 +34,6 @@ class MetadataReader : public ReadStream { MetadataManager &GetMetadataManager() { return manager; } - //! Gets a list of all remaining blocks to be read by this metadata reader - consumes all blocks - //! If "last_block" is specified, we stop when reaching that block - vector GetRemainingBlocks(MetaBlockPointer last_block = MetaBlockPointer()); private: data_ptr_t BasePtr(); diff --git a/src/duckdb/src/include/duckdb/storage/object_cache.hpp b/src/duckdb/src/include/duckdb/storage/object_cache.hpp index 5d4cb6ef9..08712fde6 100644 --- a/src/duckdb/src/include/duckdb/storage/object_cache.hpp +++ b/src/duckdb/src/include/duckdb/storage/object_cache.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/lru_cache.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/string.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/buffer/buffer_pool_reservation.hpp" @@ -149,6 +150,30 @@ class ObjectCache { lru_cache.Delete(key); } + //! Type-prefixed variants of the methods above. 
These namespace the caller-provided key with the entry's + //! ObjectType so that callers can pass a natural key (e.g. a file path) without having to build a unique + //! cache key themselves. + template + shared_ptr GetWithTypePrefix(const string &key) { + return Get(MakeCacheKey(key)); + } + + template + shared_ptr GetOrCreateWithTypePrefix(const string &key, ARGS &&... args) { + return GetOrCreate(MakeCacheKey(key), std::forward(args)...); + } + + template + void PutWithTypePrefix(const string &key, + shared_ptr value) { // NOLINT(performance-unnecessary-value-param) + Put(MakeCacheKey(key), std::move(value)); + } + + template + void DeleteWithTypePrefix(const string &key) { + Delete(MakeCacheKey(key)); + } + DUCKDB_API static ObjectCache &GetObjectCache(ClientContext &context); idx_t GetMaxMemory() const { @@ -173,6 +198,14 @@ class ObjectCache { return lru_cache.EvictToReduceAtLeast(target_bytes); } +private: + //! Build the internal cache key for a typed entry by namespacing the caller-provided key with the entry's + //! ObjectType. + template + static string MakeCacheKey(const string &key) { + return StringUtil::Format("%s-%s", T::ObjectType(), key); + } + private: mutable mutex lock_mutex; //! 
LRU cache for evictable entries diff --git a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp index 9d8e8499b..42566ec72 100644 --- a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp @@ -19,10 +19,13 @@ struct OptimisticWriteCollection { shared_ptr collection; set unflushed_row_groups; - idx_t complete_row_groups = 0; + set flushed_row_groups; + idx_t unflushed_data_size = 0; + idx_t prev_allocated_size = 0; vector> partial_block_managers; void MergeStorage(OptimisticWriteCollection &collection); + void FinalizeFlush(); }; enum class OptimisticWritePartialManagers { PER_COLUMN, GLOBAL }; @@ -38,7 +41,7 @@ class OptimisticDataWriter { CreateCollection(DataTable &storage, const vector &insert_types, OptimisticWritePartialManagers type = OptimisticWritePartialManagers::PER_COLUMN); //! Write a new row group to disk (if possible) - void WriteNewRowGroup(OptimisticWriteCollection &row_groups); + void WriteNewRowGroup(OptimisticWriteCollection &row_groups, idx_t flushed_row_group_idx); //! Write any unflushed row groups of a collection to disk void WriteUnflushedRowGroups(OptimisticWriteCollection &row_groups); //! 
Final flush of the optimistic writer - fully flushes the partial block manager diff --git a/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp index 643101c01..ee4c9a8c1 100644 --- a/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/segment/uncompressed.hpp @@ -16,7 +16,7 @@ class DatabaseInstance; struct UncompressedFunctions { static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state); - static void Compress(CompressionState &state_p, Vector &data, idx_t count); + static void Compress(CompressionState &state_p, const Vector &data); static void FinalizeCompress(CompressionState &state_p); static void EmptySkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_count) { } diff --git a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp index 59a7c3672..706be9b85 100644 --- a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp @@ -10,12 +10,15 @@ #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/block.hpp" +#include "duckdb/storage/storage_options.hpp" #include "duckdb/common/file_system.hpp" +#include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/common/unordered_set.hpp" #include "duckdb/common/set.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/main/config.hpp" #include "duckdb/common/encryption_functions.hpp" +#include "duckdb/storage/database_handle.hpp" namespace duckdb { @@ -44,11 +47,13 @@ struct EncryptionOptions { struct StorageManagerOptions { bool read_only = false; - bool use_direct_io = false; + FileIOMode io_mode = FileIOMode::BUFFERED_IO; + //! Reserve size for MMAP mode; empty uses the built-in default. 
+ optional_idx mmap_reserve_size; DebugInitialize debug_initialize = DebugInitialize::NO_INITIALIZE; optional_idx block_alloc_size; - optional_idx storage_version; - optional_idx version_number; + StorageVersion storage_version = StorageVersion::INVALID; + StorageVersion version_number = StorageVersion::INVALID; optional_idx block_header_size; //! Unique database identifier and optional encryption salt. data_t db_identifier[MainHeader::DB_IDENTIFIER_LEN]; @@ -65,7 +70,6 @@ class SingleFileBlockManager : public BlockManager { SingleFileBlockManager(AttachedDatabase &db_p, const string &path_p, const StorageManagerOptions &options_p); ~SingleFileBlockManager() override; - FileOpenFlags GetFileFlags(bool create_new) const; //! Creates a new database. void CreateNewDatabase(QueryContext context); //! Loads an existing database. We pass the provided block allocation size as a parameter @@ -131,13 +135,18 @@ class SingleFileBlockManager : public BlockManager { return iteration_count; } //! Return the version number of the file. - uint64_t GetVersionNumber() const; + StorageVersion GetVersionNumber() const; //! Return the database identifier. data_ptr_t GetDBIdentifier() { return options.db_identifier; } private: + //! Rewrites the main header + //! Usage should be avoided + //! Except for phasing-out the storage version + void RewriteMainHeader(QueryContext &context, + StorageVersion version_number = MainHeader::DEPRECATED_VERSION_NUMBER); //! Loads the free list of the file. void LoadFreeList(QueryContext context); @@ -188,8 +197,8 @@ class SingleFileBlockManager : public BlockManager { uint8_t active_header; //! The path where the file is stored string path; - //! The file handle - unique_ptr handle; + //! The database handle + unique_ptr handle; //! The buffer used to read/write to the headers FileBuffer header_buffer; //! 
The list of free blocks that can be written to currently diff --git a/src/duckdb/src/include/duckdb/storage/statistics/array_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/array_stats.hpp index 4466e2598..20660bd64 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/array_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/array_stats.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/storage/statistics/stats_merge_type.hpp" namespace duckdb { class BaseStatistics; @@ -34,9 +35,11 @@ struct ArrayStats { DUCKDB_API static child_list_t ToStruct(const BaseStatistics &stats); - DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp index 892d06ea8..a68c85d28 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp @@ -17,6 +17,7 @@ #include "duckdb/storage/statistics/string_stats.hpp" #include "duckdb/storage/statistics/geometry_stats.hpp" #include "duckdb/storage/statistics/variant_stats.hpp" +#include "duckdb/storage/statistics/stats_merge_type.hpp" namespace duckdb { struct SelectionVector; @@ -25,7 +26,6 @@ class Serializer; class Deserializer; class Vector; -struct 
UnifiedVectorFormat; enum class StatsInfo : uint8_t { CAN_HAVE_NULL_VALUES = 0, @@ -46,6 +46,21 @@ enum class StatisticsType : uint8_t { VARIANT_STATS }; +struct ExtraStatsData { + virtual ~ExtraStatsData() = default; + + template + TARGET &Cast() { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } + template + const TARGET &Cast() const { + DynamicCastCheck(this); + return reinterpret_cast(*this); + } +}; + class BaseStatistics { friend struct NumericStats; friend struct StringStats; @@ -105,7 +120,7 @@ class BaseStatistics { void SetHasNull(); void SetHasNoNull(); - void Merge(const BaseStatistics &other); + void Merge(const BaseStatistics &other, StatsMergeType merge_type = StatsMergeType::MERGE_STATS); void Copy(const BaseStatistics &other); @@ -118,8 +133,8 @@ class BaseStatistics { static BaseStatistics Deserialize(Deserializer &deserializer); //! Verify that a vector does not violate the statistics - void Verify(Vector &vector, const SelectionVector &sel, idx_t count, bool ignore_has_null = false) const; - void Verify(Vector &vector, idx_t count) const; + void Verify(const Vector &vector, const SelectionVector &sel, idx_t count, bool ignore_has_null = false) const; + void Verify(const Vector &vector, idx_t count) const; Value ToStruct() const; string ToString() const; @@ -166,6 +181,8 @@ class BaseStatistics { //! Variant stats data, for variant stats VariantStatsData variant_data; } stats_union; + //! Extra stats data, used for e.g. string data if required + unique_ptr extra_data; //! 
Child stats (for LIST and STRUCT) unsafe_unique_array child_stats; }; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp index 12f8d5672..412d538a6 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/column_statistics.hpp @@ -24,7 +24,7 @@ class ColumnStatistics { void Merge(ColumnStatistics &other); - void UpdateDistinctStatistics(Vector &v, idx_t count, Vector &hashes); + void UpdateDistinctStatistics(const Vector &v, idx_t count, Vector &hashes); BaseStatistics &Statistics(); diff --git a/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp index 57f7fb2bb..4efbb6ea6 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/distinct_statistics.hpp @@ -33,8 +33,8 @@ class DistinctStatistics { unique_ptr Copy() const; - void UpdateSample(Vector &new_data, idx_t count, Vector &hashes); - void Update(Vector &new_data, idx_t count, Vector &hashes); + void UpdateSample(const Vector &new_data, idx_t count, Vector &hashes); + void Update(const Vector &new_data, idx_t count, Vector &hashes); string ToString() const; idx_t GetCount() const; @@ -45,7 +45,7 @@ class DistinctStatistics { static unique_ptr Deserialize(Deserializer &deserializer); private: - void UpdateInternal(Vector &update, idx_t count, Vector &hashes); + void UpdateInternal(const Vector &update, idx_t count, Vector &hashes); private: //! 
For distinct statistics we sample the input to speed up insertions diff --git a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp index 25fdd4640..34d1be6f7 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp @@ -281,7 +281,8 @@ struct GeometryStats { DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); //! Check if a spatial predicate check with a constant could possibly be satisfied by rows given the statistics DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, diff --git a/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp index 31ec6abdc..ce23a2703 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/list_stats.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/storage/statistics/stats_merge_type.hpp" namespace duckdb { class BaseStatistics; @@ -32,9 +33,11 @@ struct ListStats { DUCKDB_API static child_list_t ToStruct(const BaseStatistics &stats); - DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics 
&other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp index c275dd004..c391f9405 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/numeric_stats.hpp @@ -90,7 +90,7 @@ struct NumericStats { UpdateValue(new_value, nstats.min.GetReferenceUnsafe(), nstats.max.GetReferenceUnsafe()); } - static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count); template static T GetMin(const BaseStatistics &stats) { @@ -111,7 +111,8 @@ struct NumericStats { static Value MinOrNull(const BaseStatistics &stats); static Value MaxOrNull(const BaseStatistics &stats); template - static void TemplatedVerify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + static void TemplatedVerify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/stats_merge_type.hpp b/src/duckdb/src/include/duckdb/storage/statistics/stats_merge_type.hpp new file mode 100644 index 000000000..ca71a1f18 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/stats_merge_type.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/stats_merge_type.hpp +// +// +//===----------------------------------------------------------------------===// + 
+#pragma once + +#include "duckdb/common/typedefs.hpp" + +namespace duckdb { + +//! How two statistics objects are being merged. +//! MERGE_STATS: both sides represent disjoint data sets — additive stats (e.g. total_string_length) are summed. +//! EXPAND_BOUNDS: one or both sides are approximate bounds — additive stats become unknown (has_X = false). +enum class StatsMergeType { + MERGE_STATS = 0, + EXPAND_BOUNDS = 1, +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/stats_writer.hpp b/src/duckdb/src/include/duckdb/storage/statistics/stats_writer.hpp new file mode 100644 index 000000000..7325d3bdd --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/stats_writer.hpp @@ -0,0 +1,203 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/stats_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "utf8proc_wrapper.hpp" +#include "duckdb/main/error_manager.hpp" + +namespace duckdb { + +struct BaseStatsWriter { + void SetHasNull() { + has_null = true; + } + void SetHasValid() { + has_valid = true; + } + + bool HasStats() const { + return has_null || has_valid; + } + + bool AnyValid() const { + return has_valid; + } + + void ClearBase() { + has_null = false; + has_valid = false; + } + + void MergeBase(BaseStatistics &other) const { + if (has_null) { + other.SetHasNullFast(); + } + if (has_valid) { + other.SetHasNoNullFast(); + } + } + +private: + bool has_null = false; + bool has_valid = false; +}; + +template +struct StatsWriter : public BaseStatsWriter { + explicit StatsWriter() { + Clear(); + } + 
+ inline void Clear() { + ClearBase(); + min = NumericLimits::Maximum(); + max = NumericLimits::Minimum(); + } + + void Update(T new_value) { + SetHasValid(); + UpdateMinMax(new_value); + } + + void UpdateMinMax(T new_value) { + min = LessThan::Operation(new_value, min) ? new_value : min; + max = GreaterThan::Operation(new_value, max) ? new_value : max; + } + + void Merge(BaseStatistics &target) const { + MergeBase(target); + if (AnyValid()) { + target.UpdateNumericStats(min); + target.UpdateNumericStats(max); + } + } + +private: + T min; + T max; +}; + +template <> +struct StatsWriter : public BaseStatsWriter { + explicit StatsWriter() { + Clear(); + } + + inline void Clear() { + ClearBase(); + } + + void Merge(BaseStatistics &target) const { + MergeBase(target); + } +}; + +template <> +struct StatsWriter : public BaseStatsWriter { + friend struct StringStats; + + explicit StatsWriter(const LogicalType &type) + : is_varchar(type.id() == LogicalTypeId::VARCHAR), is_geometry(type.id() == LogicalTypeId::GEOMETRY) { + Clear(); + } + + inline void Clear() { + ClearBase(); + if (is_geometry) { + geometry_stats.SetEmpty(); + } else { + is_set = false; + min_size = 0; + max_size = 0; + has_unicode = false; + max_string_length = 0; + min_string_length = StringStatsData::MAXIMUM_MIN_STRING_LENGTH; + total_string_length = 0; + } + } + + inline void Update(const string_t &value) { + SetHasValid(); + if (is_geometry) { + geometry_stats.Update(value); + return; + } + auto data = const_data_ptr_cast(value.GetData()); + auto size = value.GetSize(); + + auto copy_count = MinValue(size, StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE); + if (is_set) { + // compare to current min/max + auto min_cmp_count = MinValue(copy_count, min_size); + auto min_cmp = memcmp(data, min, min_cmp_count); + if (min_cmp < 0 || (min_cmp == 0 && size < min_size)) { + memcpy(min, data, copy_count); + min_size = size; + } + auto max_cmp_count = MinValue(copy_count, max_size); + int max_cmp = memcmp(data, 
max, max_cmp_count); + if (max_cmp > 0 || (max_cmp == 0 && size > max_size)) { + memcpy(max, data, copy_count); + max_size = size; + } + } else { + memcpy(min, data, copy_count); + memcpy(max, data, copy_count); + min_size = size; + max_size = size; + is_set = true; + } + if (size > max_string_length) { + max_string_length = UnsafeNumericCast(size); + } + if (size < min_string_length) { + min_string_length = UnsafeNumericCast(size); + } + total_string_length += size; + if (is_varchar && !has_unicode) { + auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(data), size); + if (unicode == UnicodeType::UTF8) { + has_unicode = true; + } else if (unicode == UnicodeType::INVALID) { + throw ErrorManager::InvalidUnicodeError(string(const_char_ptr_cast(data), size), + "segment statistics update"); + } + } + } + + void Merge(BaseStatistics &other) const { + MergeBase(other); + if (is_geometry) { + GeometryStats::GetDataUnsafe(other).Merge(geometry_stats); + } else { + StringStats::Merge(other, *this); + } + } + +private: + data_t min[StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE]; + data_t max[StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE]; + idx_t min_size; + idx_t max_size; + bool is_set; + bool has_unicode; + uint32_t max_string_length; + uint32_t min_string_length; + idx_t total_string_length; + bool is_varchar; + bool is_geometry; + GeometryStatsData geometry_stats; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp index 5e8985f9c..3aa3d8c03 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp @@ -13,26 +13,53 @@ #include "duckdb/common/enums/filter_propagate_result.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/array_ptr.hpp" +#include "duckdb/storage/statistics/stats_merge_type.hpp" namespace duckdb { class BaseStatistics; 
struct SelectionVector; class Vector; class Value; +template +struct StatsWriter; + +enum class StringStatsType { + NO_STATS, // no min/max stats available at all + EMPTY_STATS, // min/max is empty + EXACT_STATS, // min/max is exact + TRUNCATED_STATS // min/max is truncated, so the min/max values here might be wider than the values in the dataset +}; struct StringStatsData { - constexpr static uint32_t MAX_STRING_MINMAX_SIZE = 8; + constexpr static uint32_t CURRENT_MAX_STRING_MINMAX_SIZE = 12; + constexpr static uint32_t LEGACY_MAX_STRING_MINMAX_SIZE = 8; + constexpr static uint32_t MAXIMUM_MIN_STRING_LENGTH = (1 << 21U) - 1; - //! The minimum value of the segment, potentially truncated - data_t min[MAX_STRING_MINMAX_SIZE]; - //! The maximum value of the segment, potentially truncated - data_t max[MAX_STRING_MINMAX_SIZE]; + //! The minimum value of the segment, potentially truncated depending on min_type + string_t min; + //! The maximum value of the segment, potentially truncated depending on max_type + string_t max; //! Whether or not the column can contain unicode characters bool has_unicode; //! Whether or not the maximum string length is known bool has_max_string_length; //! The maximum string length in bytes uint32_t max_string_length; + //! Whether or not the maximum string length is known + bool has_min_string_length; + //! The minimum string length in bytes + uint32_t min_string_length; + //! WHether or not the total string length is known + bool has_total_string_length; + //! The total (summed) length of all strings + uint64_t total_string_length; + //! Min stats type + StringStatsType min_type; + //! Max stats type + StringStatsType max_type; + +public: + bool HasMinStringLength() const; }; struct StringStats { @@ -42,16 +69,38 @@ struct StringStats { DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); //! 
Returns true if the stats has both a min and max value defined DUCKDB_API static bool HasMinMax(const BaseStatistics &stats); + DUCKDB_API static bool HasMin(const BaseStatistics &stats); + DUCKDB_API static bool HasMax(const BaseStatistics &stats); //! Whether or not the statistics have a maximum string length defined DUCKDB_API static bool HasMaxStringLength(const BaseStatistics &stats); //! Returns the maximum string length, or throws an exception if !HasMaxStringLength() DUCKDB_API static uint32_t MaxStringLength(const BaseStatistics &stats); + //! Returns the minimum string length if it is known + DUCKDB_API static optional_idx MinStringLength(const BaseStatistics &stats); + //! Returns the total string length if known + DUCKDB_API static optional_idx TotalStringLength(const BaseStatistics &stats); //! Whether or not the strings can contain unicode DUCKDB_API static bool CanContainUnicode(const BaseStatistics &stats); - //! Returns the min value (up to a length of StringStatsData::MAX_STRING_MINMAX_SIZE) + //! Returns the min value DUCKDB_API static string Min(const BaseStatistics &stats); - //! Returns the max value (up to a length of StringStatsData::MAX_STRING_MINMAX_SIZE) + //! Returns the max value DUCKDB_API static string Max(const BaseStatistics &stats); + //! Returns a valid UTF-8 lower bound for the min value, or NULL if stats have no min + DUCKDB_API static Value TryGetValidMin(const BaseStatistics &stats); + //! Returns a valid UTF-8 upper bound for the max value, or NULL if no finite bound exists + DUCKDB_API static Value TryGetValidMax(const BaseStatistics &stats); + + //! Construct string stats from a constant + DUCKDB_API static void FromConstant(BaseStatistics &stats, string_t input); + + DUCKDB_API static void MergeInConstant(BaseStatistics &stats, string_t input); + + [[deprecated("StringStats::Update is deprecated. 
Use StringStats::MergeInConstant for non-performance-sensitive " + "code, or StatsWriter instead")]] DUCKDB_API static void + Update(BaseStatistics &stats, const string_t &value); + + DUCKDB_API static StringStatsType GetMinType(const BaseStatistics &stats); + DUCKDB_API static StringStatsType GetMaxType(const BaseStatistics &stats); //! Resets the max string length so HasMaxStringLength() is false DUCKDB_API static void ResetMaxStringLength(BaseStatistics &stats); @@ -67,19 +116,33 @@ struct StringStats { DUCKDB_API static FilterPropagateResult CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, array_ptr constants); - DUCKDB_API static FilterPropagateResult CheckZonemap(const_data_ptr_t min_data, idx_t min_len, - const_data_ptr_t max_data, idx_t max_len, - ExpressionType comparison_type, const string &value); + DUCKDB_API static FilterPropagateResult CheckZonemap(string_t min, StringStatsType min_type, string_t max, + StringStatsType max_type, ExpressionType comparison_type, + string_t constant); - DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); - DUCKDB_API static void SetMin(BaseStatistics &stats, const string_t &value); - DUCKDB_API static void SetMax(BaseStatistics &stats, const string_t &value); - DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + [[deprecated("StringStats::SetMin without specifying StringStatsType is deprecated - specify either " + "StringStatsType::TRUNCATED_STATS or StringStatsType::EXACT_STATS")]] DUCKDB_API static void + SetMin(BaseStatistics &stats, const string_t &value); + [[deprecated("StringStats::SetMax without specifying StringStatsType is deprecated - specify either " + "StringStatsType::TRUNCATED_STATS or StringStatsType::EXACT_STATS")]] DUCKDB_API static void + SetMax(BaseStatistics &stats, const string_t &value); + DUCKDB_API 
static void SetMin(BaseStatistics &stats, const string_t &value, StringStatsType type); + DUCKDB_API static void SetMax(BaseStatistics &stats, const string_t &value, StringStatsType type); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); + DUCKDB_API static void Merge(BaseStatistics &stats, const StatsWriter &other); + DUCKDB_API static void Merge(BaseStatistics &stats, const StringStatsData &other_data, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); + DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); private: static StringStatsData &GetDataUnsafe(BaseStatistics &stats); static const StringStatsData &GetDataUnsafe(const BaseStatistics &stats); + static string_t AssignString(BaseStatistics &stats, const string_t &input, bool is_min); + static void MergeStats(BaseStatistics &stats, string_t &target, StringStatsType &target_type, + const string_t &source, StringStatsType source_type, bool is_min); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp index 512eb0431..5e00d14d0 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/struct_stats.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/storage/statistics/stats_merge_type.hpp" namespace duckdb { class BaseStatistics; @@ -35,9 +36,11 @@ struct StructStats { DUCKDB_API static child_list_t ToStruct(const BaseStatistics &stats); - DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Merge(BaseStatistics &stats, 
const BaseStatistics &other, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); DUCKDB_API static unique_ptr PushdownExtract(const BaseStatistics &stats, const StorageIndex &index); diff --git a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp index 9c5502f8f..1b777fbba 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp @@ -91,7 +91,8 @@ struct VariantStats { DUCKDB_API static child_list_t ToStruct(const BaseStatistics &stats); DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); - DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + DUCKDB_API static void Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, + idx_t count); DUCKDB_API static void Copy(BaseStatistics &stats, const BaseStatistics &other); DUCKDB_API static unique_ptr PushdownExtract(const BaseStatistics &stats, const StorageIndex &index); diff --git a/src/duckdb/src/include/duckdb/storage/storage_info.hpp b/src/duckdb/src/include/duckdb/storage/storage_info.hpp index 6342eed63..2ea8bd8d5 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_info.hpp @@ -12,10 +12,14 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/string.hpp" #include "duckdb/common/vector_size.hpp" +#include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/types/string_type.hpp" +#include 
"duckdb/main/query_context.hpp" namespace duckdb { struct FileHandle; +class MemoryMappedFile; class QueryContext; //! The standard row group size @@ -68,15 +72,163 @@ struct Storage { static void VerifyBlockHeaderSize(const idx_t block_header_size); }; +// These sections are automatically generated by scripts/generate_storage_info.py +// Do not edit them manually, your changes will be overwritten +// clang-format off +// START OF ENUM VERSION INFO +enum class StorageVersion : uint64_t { + V0_0_4 = 1, + V0_1_0 = 1, + V0_1_1 = 1, + V0_1_2 = 1, + V0_1_3 = 1, + V0_1_4 = 1, + V0_1_5 = 1, + V0_1_6 = 1, + V0_1_7 = 1, + V0_1_8 = 1, + V0_1_9 = 1, + V0_2_0 = 1, + V0_2_1 = 1, + V0_2_2 = 4, + V0_2_3 = 6, + V0_2_4 = 11, + V0_2_5 = 13, + V0_2_6 = 15, + V0_2_7 = 17, + V0_2_8 = 18, + V0_2_9 = 21, + V0_3_0 = 25, + V0_3_1 = 27, + V0_3_2 = 31, + V0_3_3 = 33, + V0_3_4 = 33, + V0_3_5 = 33, + V0_4_0 = 33, + V0_5_0 = 38, + V0_5_1 = 38, + V0_6_0 = 39, + V0_6_1 = 39, + V0_7_0 = 43, + V0_7_1 = 43, + V0_8_0 = 51, + V0_8_1 = 51, + V0_9_0 = 64, + V0_9_1 = 64, + V0_9_2 = 64, + V0_10_0 = 64, + V0_10_1 = 64, + V0_10_2 = 64, + V0_10_3 = 64, + V1_0_0 = 64, + V1_1_0 = 64, + V1_1_1 = 64, + V1_1_2 = 64, + V1_1_3 = 64, + V1_2_0 = 65, + V1_2_1 = 65, + V1_2_2 = 65, + V1_3_0 = 66, + V1_3_1 = 66, + V1_3_2 = 66, + V1_4_0 = 67, + V1_4_1 = 67, + V1_4_2 = 67, + V1_4_3 = 67, + V1_4_4 = 67, + V1_5_0 = 68, + V1_5_1 = 68, + V1_5_2 = 68, + V1_5_3 = 68, + V2_0_0 = 69, + LATEST = 69, + DEPRECATED = 999, + INVALID = 0 +}; +// END OF ENUM VERSION INFO + +// START OF SER_ENUM VERSION INFO +enum class SerializationVersionDeprecated : uint64_t { + V0_10_0 = 1, + V0_10_1 = 1, + V0_10_2 = 1, + V0_10_3 = 2, + V1_0_0 = 2, + V1_1_0 = 3, + V1_1_1 = 3, + V1_1_2 = 3, + V1_1_3 = 3, + V1_2_0 = 4, + V1_2_1 = 4, + V1_2_2 = 4, + V1_3_0 = 5, + V1_3_1 = 5, + V1_3_2 = 5, + V1_4_0 = 6, + V1_4_1 = 6, + V1_4_2 = 6, + V1_4_3 = 6, + V1_4_4 = 6, + V1_5_0 = 7, + V1_5_1 = 7, + V1_5_2 = 7, + V1_5_3 = 7, + V2_0_0 = 8, + LATEST = 8, + INVALID = 
UINT64_MAX +}; +// END OF SER_ENUM VERSION INFO +// clang-format on + +struct StorageVersionInfo { + // When the default storage version has to be updated, do it here + static constexpr StorageVersion DEFAULT_STORAGE_VERSION_INFO = StorageVersion::V0_10_2; + + const char *version_name; + StorageVersion storage_version; + + static string GetStorageVersionString(const StorageVersion &version); + + static constexpr StorageVersion GetStorageVersionDefault() { + return DEFAULT_STORAGE_VERSION_INFO; + }; + + static constexpr idx_t GetStorageVersionDefaultInt() { + return static_cast(DEFAULT_STORAGE_VERSION_INFO); + }; + + static constexpr StorageVersion Invalid() { + return StorageVersion::INVALID; + }; + + static constexpr idx_t InvalidStorageVersionInt() { + return static_cast(StorageVersion::INVALID); + }; +}; + +struct SerializationVersionInfo { + const char *version_name; + SerializationVersionDeprecated storage_version; + + static constexpr uint64_t GetSerializationVersionValue(SerializationVersionDeprecated version) { + return static_cast(version); + } + + static constexpr uint64_t Invalid() { + return GetSerializationVersionValue(SerializationVersionDeprecated::INVALID); + }; +}; + //! 
The version number default, lower and upper bounds of the database storage format extern const uint64_t VERSION_NUMBER; extern const uint64_t VERSION_NUMBER_LOWER; extern const uint64_t VERSION_NUMBER_UPPER; -string GetDuckDBVersions(const idx_t version_number); -optional_idx GetStorageVersion(const char *version_string); -string GetStorageVersionName(const idx_t serialization_version, const bool add_suffix); -optional_idx GetSerializationVersion(const char *version_string); -vector GetSerializationCandidates(); +string GetDuckDBVersions(const StorageVersion version_number); +string GetStorageVersionNameInternal(const idx_t storage_version); +string GetStorageVersionName(const StorageVersion storage_version, const bool add_suffix); +idx_t GetSerializationVersionDeprecated(const char *version_string); +StorageVersion GetStorageVersion(const char *version_string); +vector GetStorageCandidates(); //! The MainHeader is the first header in the storage file. //! It is written only once for a database file. @@ -96,8 +248,10 @@ class MainHeader { //! The canary should be "DUCKKEY". static const char CANARY[]; - //! The (storage) version of the database. - uint64_t version_number; + //! The main header storage version is now deprecated + static constexpr StorageVersion DEPRECATED_VERSION_NUMBER = StorageVersion::DEPRECATED; + + idx_t version_number; //! The set of flags used by the database. uint64_t flags[FLAG_COUNT]; //! Encryption version @@ -117,6 +271,7 @@ class MainHeader { static constexpr uint64_t AES_TAG_LEN = 16; static void CheckMagicBytes(QueryContext context, FileHandle &handle); + static void CheckMagicBytes(MemoryMappedFile &handle); string LibraryGitDesc() { return string(char_ptr_cast(library_git_desc), MAX_VERSION_SIZE); @@ -225,11 +380,13 @@ struct DatabaseHeader { idx_t block_alloc_size = 0; //! The vector size of the database file idx_t vector_size = 0; - //! The serialization compatibility version - idx_t serialization_compatibility = 0; + //! 
The storage compatibility version + StorageVersion storage_compatibility = StorageVersion::INVALID; void Write(WriteStream &ser); static DatabaseHeader Read(const MainHeader &header, ReadStream &source); + static void SetStorageVersionInDatabaseHeader(DatabaseHeader &header, StorageVersion main_version, + StorageVersion read_version); }; //! Detect mismatching constant values when compiling diff --git a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp index 61a232976..249388309 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp @@ -62,6 +62,10 @@ class StorageManager { //! or the file's block allocation size, if it is an existing database. void Initialize(QueryContext context); + static bool TargetAtLeastVersion(StorageVersion version, idx_t target_version); + static bool TargetAtLeastVersion(StorageVersion version, StorageVersion target_version); + static bool IsPriorToVersion(StorageVersion version, StorageVersion target_version); + DatabaseInstance &GetDatabase(); AttachedDatabase &GetAttached() const { return db; @@ -111,15 +115,22 @@ class StorageManager { virtual BlockManager &GetBlockManager() = 0; virtual void Destroy(); - void SetStorageVersion(idx_t version) { + void SetStorageVersion(const StorageVersion &version) { storage_version = version; } + StorageVersion GetLatestStorageVersion() const { + return StorageVersion::LATEST; + } + StorageVersion GetStorageVersion() const { + D_ASSERT(HasStorageVersion()); + return storage_version; + } bool HasStorageVersion() const { - return storage_version.IsValid(); + return storage_version != StorageVersion::INVALID; } - idx_t GetStorageVersion() const { + string_t GetDuckDBVersionString() const { D_ASSERT(HasStorageVersion()); - return storage_version.GetIndex(); + return StorageVersionInfo::GetStorageVersionString(storage_version); } bool CompressionIsEnabled() 
const { return storage_options.compress_in_memory == CompressInMemory::COMPRESS; @@ -166,8 +177,8 @@ class StorageManager { //! When loading a database, we do not yet set the wal-field. Therefore, GetWriteAheadLog must //! return nullptr when loading a database bool load_complete = false; - //! The serialization compatibility version when reading and writing from this database - optional_idx storage_version; + //! The storage compatibility version when reading and writing from this database + StorageVersion storage_version; //! Estimated size of changes for determining automatic checkpointing on in-memory databases and databases without a //! WAL. atomic wal_size; diff --git a/src/duckdb/src/include/duckdb/storage/storage_options.hpp b/src/duckdb/src/include/duckdb/storage/storage_options.hpp index dc68c4ad9..934d1f721 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_options.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_options.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/optional.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/common/encryption_state.hpp" @@ -17,18 +18,32 @@ namespace duckdb { enum class CompressInMemory { AUTOMATIC, COMPRESS, DO_NOT_COMPRESS }; +enum class FileIOMode : uint8_t { + //! Buffered pread/pwrite via FileHandle (default). + BUFFERED_IO, + //! Memory-map the file; reads/writes go through the mapped region. + MMAP, + //! Unbuffered I/O (O_DIRECT / FILE_FLAG_NO_BUFFERING / F_NOCACHE). + DIRECT_IO, +}; + struct StorageOptions { //! The allocation size of blocks for this attached database file (if any) optional_idx block_alloc_size; //! The row group size for this attached database (if any) optional_idx row_group_size; //! Target storage version (if any) - optional_idx storage_version; + StorageVersion storage_version = StorageVersion::INVALID; //! 
Block header size (only used for encryption) optional_idx block_header_size; CompressInMemory compress_in_memory = CompressInMemory::AUTOMATIC; + //! IO_MODE attach option; empty falls back to the default_io_mode setting. + optional io_mode; + //! MMAP_RESERVE_SIZE attach option; empty uses the StorageManager default. + optional_idx mmap_reserve_size; + //! Whether the database is encrypted bool encryption = false; //! Encryption algorithm diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp index d8ff5f1ef..161c90889 100644 --- a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -15,6 +15,7 @@ #include "duckdb/storage/segment/uncompressed.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" namespace duckdb { struct StringDictionaryContainer { @@ -59,7 +60,7 @@ struct UncompressedStringStorage { public: static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static bool StringAnalyze(AnalyzeState &state_p, const Vector &input); static idx_t StringFinalAnalyze(AnalyzeState &state_p); static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, @@ -74,25 +75,25 @@ struct UncompressedStringStorage { optional_ptr segment_state); static unique_ptr StringInitAppend(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); // This block was initialized in StringInitSegment - auto handle = buffer_manager.Pin(segment.block); + auto handle = 
buffer_manager.Pin(segment.GetBlockHandle()); return make_uniq(std::move(handle)); } - static idx_t StringAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, + static idx_t StringAppend(CompressionAppendState &append_state, ColumnSegment &segment, BaseStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count) { return StringAppendBase(append_state.handle, segment, stats, data, offset, count); } - static idx_t StringAppendBase(ColumnSegment &segment, SegmentStatistics &stats, UnifiedVectorFormat &data, + static idx_t StringAppendBase(ColumnSegment &segment, BaseStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); return StringAppendBase(handle, segment, stats, data, offset, count); } - static idx_t StringAppendBase(BufferHandle &handle, ColumnSegment &segment, SegmentStatistics &stats, + static idx_t StringAppendBase(BufferHandle &handle, ColumnSegment &segment, BaseStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count) { D_ASSERT(segment.GetBlockOffset() == 0); auto handle_ptr = handle.GetDataMutable(); @@ -103,18 +104,19 @@ struct UncompressedStringStorage { idx_t remaining_space = RemainingSpace(segment, handle); auto base_count = segment.count.load(); + StatsWriter stats_writer(stats.GetType()); for (idx_t i = 0; i < count; i++) { auto source_idx = data.sel->get_index(offset + i); auto target_idx = base_count + i; if (remaining_space < sizeof(int32_t)) { // string index does not fit in the block at all - segment.count += i; - return i; + count = i; + break; } remaining_space -= sizeof(int32_t); const bool is_null = !data.validity.RowIsValid(source_idx); if (is_null) { - stats.statistics.SetHasNullFast(); + 
stats_writer.SetHasNull(); // null value is stored as a copy of the last value, this is done to be able to efficiently do the // string_length calculation if (target_idx > 0) { @@ -143,12 +145,12 @@ struct UncompressedStringStorage { } if (DUCKDB_UNLIKELY(required_space > remaining_space)) { // no space remaining: return how many tuples we ended up writing - segment.count += i; - return i; + count = i; + break; } // we have space: write the string - UpdateStringStats(stats, source_data[source_idx]); + stats_writer.Update(source_data[source_idx]); if (DUCKDB_UNLIKELY(use_overflow_block)) { // write to overflow blocks @@ -188,6 +190,7 @@ struct UncompressedStringStorage { GetDictionary(segment, handle).Verify(segment.GetBlockSize()); #endif } + stats_writer.Merge(stats); segment.count += count; return count; } @@ -197,8 +200,8 @@ struct UncompressedStringStorage { return; } // we need to decrement the dictionary size by all of the strings we are erasing - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); auto handle_ptr = handle.GetDataMutable(); auto result_data = reinterpret_cast(handle_ptr + DICTIONARY_HEADER_SIZE); auto dictionary_size = reinterpret_cast(handle_ptr); @@ -217,18 +220,9 @@ struct UncompressedStringStorage { *dictionary_size = new_dictionary_size; } - static idx_t FinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats); + static idx_t FinalizeAppend(ColumnSegment &segment, BaseStatistics &stats); public: - static inline void UpdateStringStats(SegmentStatistics &stats, const string_t &new_value) { - stats.statistics.SetHasNoNullFast(); - if (stats.statistics.GetStatsType() == StatisticsType::GEOMETRY_STATS) { - GeometryStats::Update(stats.statistics, new_value); - } else { - StringStats::Update(stats.statistics, new_value); - } - } - 
static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer dict); static StringDictionaryContainer GetDictionary(ColumnSegment &segment, BufferHandle &handle); static uint32_t GetDictionaryEnd(ColumnSegment &segment, BufferHandle &handle); diff --git a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp index 1e42a6bdf..f64d4ad04 100644 --- a/src/duckdb/src/include/duckdb/storage/table/append_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/append_state.hpp @@ -38,16 +38,39 @@ struct ColumnAppendState { unique_ptr lock; //! The compression append state unique_ptr append_state; + //! Stats for the append to the current segment + unique_ptr append_stats; + //! Stats for the full append to this column + unique_ptr full_append_stats; + +public: + void InitializeStats(const LogicalType &type); + void FlushSegmentStats(); + void FinalFlush(vector> &global_stats); }; -struct RowGroupAppendState { - explicit RowGroupAppendState(TableAppendState &parent_p) : parent(parent_p) { +struct ColumnDataFinalizeAppendState { + explicit ColumnDataFinalizeAppendState(BaseStatistics &table_stats) { + global_stats.emplace_back(table_stats); } + ColumnDataFinalizeAppendState(BaseStatistics &table_stats, BaseStatistics &column_data_stats) { + global_stats.emplace_back(table_stats); + global_stats.emplace_back(column_data_stats); + } + ColumnDataFinalizeAppendState(ColumnDataFinalizeAppendState &parent, LogicalTypeId type_transform, + optional_idx child_id = optional_idx()); + + vector> global_stats; +}; + +struct RowGroupAppendState { + explicit RowGroupAppendState(TableAppendState &parent_p); + ~RowGroupAppendState(); //! The parent append state TableAppendState &parent; //! The current row_group we are appending to - RowGroup *row_group; + optional_ptr> row_group; //! The column append states unsafe_unique_array states; //! 
Offset within the row_group @@ -76,7 +99,7 @@ struct TableAppendState { optional_ptr> start_row_group; //! The transaction data TransactionData transaction; - //! Table statistics + //! Table statistics gathered during the Append phase - flushed to the table in FinalizeAppend TableStatistics stats; //! Cached hash vector Vector hashes; diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index 9a293c94c..f65421b93 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -36,11 +36,13 @@ class ArrayColumnData : public ColumnData { void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -61,7 +63,8 @@ class ArrayColumnData : public 
ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, vector &result) override; + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) override; void Verify(RowGroup &parent) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index aa391c11d..e95c2171a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -35,6 +35,7 @@ struct TableScanOptions; struct TransactionData; struct PersistentColumnData; class ValidityColumnData; +struct ColumnDataFinalizeAppendState; using column_segment_vector_t = vector>; @@ -154,18 +155,25 @@ class ColumnData : public enable_shared_from_this { //! Initialize an appending phase for this column virtual void InitializeAppend(ColumnAppendState &state); //! Append a vector of type [type] to the end of the column - virtual void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count); - //! Append a vector of type [type] to the end of the column - void Append(ColumnAppendState &state, Vector &vector, idx_t count); - virtual void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count); + virtual void Append(ColumnAppendState &state, const Vector &vector, idx_t count); + virtual void AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count); + //! Finalize appending + virtual void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state); + void FinalizeAppend(optional_ptr table_stats, ColumnAppendState &state); + //! 
Finalize appending while holding stats_lock (for use by child column calls) + void FinalizeAppendLocked(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state); //! Revert a set of appends to the ColumnData virtual void RevertAppend(row_t new_count); //! Fetch the vector from the column data that belongs to this specific row virtual idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result); - //! Fetch a specific row id and append it to the vector - virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx); + //! Fetch a batch of row offsets and append them to the vector + virtual void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset); + //! Fetches a batch of row offsets for leaf columns. + void FetchRowsAtSegmentLevel(TransactionData transaction, ColumnFetchState &state, const idx_t *offsets, + const SelectionVector &sel, idx_t count, Vector &result, idx_t result_offset); virtual void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start); @@ -196,10 +204,11 @@ class ColumnData : public enable_shared_from_this { ReadStream &source, const LogicalType &type); virtual void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result); + vector &result, const ColumnSegmentInfoScanOptions &options); virtual void Verify(RowGroup &parent); - FilterPropagateResult CheckZonemap(const StorageIndex &index, TableFilter &filter); + FilterPropagateResult CheckZonemap(optional_ptr context, const StorageIndex &index, + TableFilter &filter); static shared_ptr CreateColumn(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, const 
LogicalType &type, @@ -213,7 +222,7 @@ class ColumnData : public enable_shared_from_this { protected: //! Append a transient segment - void AppendTransientSegment(SegmentLock &l, idx_t start_row); + void AppendTransientSegment(SegmentLock &l, idx_t start_row, optional_ptr prev_segment); void AppendSegment(SegmentLock &l, unique_ptr segment); void BeginScanVectorInternal(ColumnScanState &state); diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp index 95a892a00..655492839 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp @@ -68,7 +68,7 @@ class ColumnDataCheckpointer { void FinalizeCheckpoint(); private: - void ScanSegments(const std::function &callback); + void ScanSegments(const std::function &callback); vector DetectBestCompressionMethod(); void WriteToDisk(); void WritePersistentSegments(ColumnCheckpointState &state); diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp index dd836c6ec..603964509 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -39,10 +39,10 @@ enum class ColumnSegmentType : uint8_t { TRANSIENT, PERSISTENT }; class ColumnSegment : public SegmentBase { public: //! Construct a column segment. 
- ColumnSegment(DatabaseInstance &db, shared_ptr block, const LogicalType &type, - const ColumnSegmentType segment_type, const idx_t count, const CompressionFunction &function_p, - BaseStatistics statistics, const block_id_t block_id_p, const idx_t offset, - const idx_t segment_size_p, unique_ptr segment_state_p = nullptr); + ColumnSegment(DatabaseInstance &db, shared_ptr block, const ColumnSegmentType segment_type, + const idx_t count, const CompressionFunction &function_p, BaseStatistics statistics, + const block_id_t block_id_p, const idx_t offset, const idx_t segment_size_p, + unique_ptr segment_state_p = nullptr); //! Construct a column segment from another column segment. //! The other column segment becomes invalid (std::move). ColumnSegment(ColumnSegment &other); @@ -50,8 +50,8 @@ class ColumnSegment : public SegmentBase { public: static unique_ptr CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, - block_id_t id, idx_t offset, const LogicalType &type_p, - idx_t count, CompressionType compression_type, + block_id_t id, idx_t offset, idx_t count, + CompressionType compression_type, BaseStatistics statistics, unique_ptr segment_state); static unique_ptr CreateTransientSegment(DatabaseInstance &db, const CompressionFunction &function, @@ -105,17 +105,21 @@ class ColumnSegment : public SegmentBase { //! Gets a data pointer from a persistent column segment DataPointer GetDataPointer(idx_t row_start); - block_id_t GetBlockId() { + block_id_t GetBlockId() const { D_ASSERT(segment_type == ColumnSegmentType::PERSISTENT); return block_id; } + const LogicalType &GetType() const { + return stats.statistics.GetType(); + } + //! Returns the size of the underlying block of the segment. It is size is the size available for usage on a block. 
idx_t GetBlockSize() const { return block->GetBlockSize(); } - idx_t GetBlockOffset() { + idx_t GetBlockOffset() const { D_ASSERT(segment_type == ColumnSegmentType::PERSISTENT || offset == 0); return offset; } @@ -124,27 +128,43 @@ class ColumnSegment : public SegmentBase { return segment_state.get(); } + DatabaseInstance &GetDatabase() const { + return db; + } + + ColumnSegmentType GetSegmentType() const { + return segment_type; + } + + void SetSegmentType(ColumnSegmentType type) { + segment_type = type; + } + + shared_ptr &GetBlockHandle() { + return block; + } + void VisitBlockIds(BlockIdVisitor &visitor) const; + const BaseStatistics &GetStats() const { + return stats.statistics; + } + + BaseStatistics &GetStatsMutable() { + return stats.statistics; + } + private: void Scan(ColumnScanState &state, idx_t scan_count, Vector &result); void ScanPartial(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); -public: +private: //! The database instance DatabaseInstance &db; - //! The type stored in the column - LogicalType type; - //! The size of the type - idx_t type_size; //! The column segment type (transient or persistent) ColumnSegmentType segment_type; - //! The statistics for the segment - SegmentStatistics stats; //! The block that this segment relates to shared_ptr block; - -private: //! The compression function reference function; //! The block id that this segment relates to (persistent segment only) @@ -155,6 +175,8 @@ class ColumnSegment : public SegmentBase { idx_t segment_size; //! Storage associated with the compressed segment unique_ptr segment_state; + //! 
The statistics for the segment + SegmentStatistics stats; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp index 52ed5f130..a09837468 100644 --- a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp @@ -21,7 +21,8 @@ struct DataTableInfo { friend class DataTable; public: - DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, string table); + DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, Identifier schema, + Identifier table); //! Bind unknown indexes throwing an exception if binding fails. //! Only binds the specified index type, or all, if nullptr. @@ -42,7 +43,7 @@ struct DataTableInfo { return indexes; } //! Find and move out an IndexStorageInfo by name from the stored collection. - IndexStorageInfo ExtractIndexStorageInfo(const string &name); + IndexStorageInfo ExtractIndexStorageInfo(const Identifier &name); unique_ptr GetSharedLock() { return checkpoint_lock.GetSharedLock(); } @@ -50,9 +51,9 @@ struct DataTableInfo { optional_idx CheckpointRowGroupCount(const CheckpointOptions &options) const; void VerifyIndexBuffers(); - string GetSchemaName(); - string GetTableName(); - void SetTableName(string name); + Identifier GetSchemaName(); + Identifier GetTableName(); + void SetTableName(Identifier name); private: //! The database instance of the table @@ -62,9 +63,9 @@ struct DataTableInfo { //! Lock for modifying the name mutex name_lock; //! The schema of the table - string schema; + Identifier schema; //! The name of the table - string table; + Identifier table; //! The physical list of indexes of this table TableIndexList indexes; //! 
Index storage information of the indexes created by this table diff --git a/src/duckdb/src/include/duckdb/storage/table/geo_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/geo_column_data.hpp index 0112b3fba..8d7ef912d 100644 --- a/src/duckdb/src/include/duckdb/storage/table/geo_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/geo_column_data.hpp @@ -42,11 +42,13 @@ class GeoColumnData final : public ColumnData { void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -65,16 +67,17 @@ class GeoColumnData final : public ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) override; + vector &result, const 
ColumnSegmentInfoScanOptions &options) override; void Verify(RowGroup &parent) override; void VisitBlockIds(BlockIdVisitor &visitor) const override; private: - static void Specialize(Vector &source, Vector &target, idx_t count, GeometryStorageType storage_type); - static void Reassemble(Vector &source, Vector &target, idx_t count, GeometryStorageType storage_type, idx_t offset); - static void InterpretStats(BaseStatistics &source, BaseStatistics &target, GeometryType geom_type, + static void Specialize(const Vector &source, Vector &target, idx_t count, GeometryStorageType storage_type); + static void Reassemble(const Vector &source, Vector &target, idx_t count, GeometryStorageType storage_type, + idx_t offset); + static void InterpretStats(const BaseStatistics &source, BaseStatistics &target, GeometryType geom_type, VertexType vert_type); }; diff --git a/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp index c9bd5bae7..0c801c612 100644 --- a/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp @@ -50,7 +50,7 @@ class InMemoryTableDataWriter : public TableDataWriter { public: void WriteUnchangedTable(MetaBlockPointer pointer, const vector &metadata_pointers, - idx_t total_rows) override; + idx_t total_rows, idx_t next_row_id) override; void FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, RowGroupCollection &collection, Serializer &serializer) override; unique_ptr GetRowGroupWriter(RowGroup &row_group) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index 6742b6651..b56904a7f 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -34,11 +34,13 @@ class 
ListColumnData : public ColumnData { void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -59,7 +61,8 @@ class ListColumnData : public ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, vector &result) override; + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) override; void SetValidityData(shared_ptr validity_p); void SetChildData(shared_ptr child_column_p); diff --git a/src/duckdb/src/include/duckdb/storage/table/per_column_metadata_blocks.hpp b/src/duckdb/src/include/duckdb/storage/table/per_column_metadata_blocks.hpp new file mode 100644 index 000000000..b7fb99f16 --- /dev/null +++ 
b/src/duckdb/src/include/duckdb/storage/table/per_column_metadata_blocks.hpp @@ -0,0 +1,56 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/per_column_metadata_blocks.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/vector.hpp" + +namespace duckdb { + +class Serializer; +class Deserializer; + +struct PerColumnMetadataBlock { + bool is_column_index : 1; + idx_t index : 63; + + idx_t GetPacked(); + + static PerColumnMetadataBlock Unpack(idx_t packed); +}; + +class PerColumnMetadataBlocks { +public: + //! Get block IDs for specific columns (linear scan), returns one vector per requested column + vector> GetBlocksForColumns(const vector &columns) const; + + //! Add a column entry with its block IDs + void AddColumn(idx_t col_idx, const vector &blocks); + //! Remove a column entry and all its block IDs (linear scan) + void RemoveColumn(idx_t col_idx); + //! Merge two PerColumnMetadataBlocks sorted by column index with disjoint column sets + static PerColumnMetadataBlocks Merge(const PerColumnMetadataBlocks &a, const PerColumnMetadataBlocks &b); + + //! 
Iterate over all block IDs, passing (column_index, block_id) to the callback + template + void ForEachBlock(Func func) const { + idx_t current_col = 0; + for (auto &entry : data) { + if (entry.is_column_index) { + current_col = entry.index; + } else { + func(current_col, entry.index); + } + } + } + + vector data; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp index a8b91ccfc..74acb26df 100644 --- a/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp @@ -23,6 +23,7 @@ class PersistentTableData { vector read_metadata_pointers; TableStatistics table_stats; idx_t total_rows; + idx_t next_row_id; idx_t row_group_count; MetaBlockPointer block_pointer; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index 7f23ddeaa..384d5abf9 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -11,7 +11,9 @@ #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/statistics/segment_statistics.hpp" #include "duckdb/common/types/data_chunk.hpp" +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" #include "duckdb/common/enums/scan_options.hpp" +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/parser/column_list.hpp" #include "duckdb/storage/table/segment_base.hpp" @@ -54,6 +56,7 @@ class StorageCommitState; template struct SegmentNode; enum class ColumnDataType; +class ClientContext; struct RowGroupWriteInfo { RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, @@ -75,13 +78,17 @@ struct RowGroupWriteInfo { optional_ptr>> column_partial_block_managers; }; +enum class 
RowGroupWriteAction { + REUSE_EXISTING_ROW_GROUP_METADATA, + PARTIALLY_REUSE_COLUMN_METADATA, + FULLY_CHECKPOINT_ROW_GROUP +}; + struct RowGroupWriteData { shared_ptr result_row_group; vector> states; vector statistics; - vector keep_column_loaded; - bool reuse_existing_metadata_blocks = false; - vector existing_extra_metadata_blocks; + RowGroupWriteAction write_action = RowGroupWriteAction::FULLY_CHECKPOINT_ROW_GROUP; optional_idx write_count; }; @@ -110,11 +117,13 @@ class RowGroup : public SegmentBase { RowGroupCollection &GetCollection() const { return collection.get(); } - //! Returns the list of meta block pointers used by the columns - vector GetOrComputeExtraMetadataBlocks(bool force_compute = false); + //! Compute per-column metadata blocks by reading column metadata from disk + PerColumnMetadataBlocks ComputePerColumnMetadataBlocks() const; const vector &GetColumnStartPointers() const; + vector GetExtraMetadataBlockPointers() const; + BlockManager &GetBlockManager() const; DataTableInfo &GetTableInfo() const; @@ -122,7 +131,7 @@ class RowGroup : public SegmentBase { ExpressionExecutor &executor, CollectionScanState &scan_state, SegmentNode &node, DataChunk &scan_chunk); unique_ptr AddColumn(RowGroupCollection &collection, ColumnDefinition &new_column, - ExpressionExecutor &executor, Vector &intermediate); + ExpressionExecutor &executor); unique_ptr RemoveColumn(RowGroupCollection &collection, idx_t removed_column); //! Accumulates this row group's on-disk blocks into the drop state. @@ -140,7 +149,7 @@ class RowGroup : public SegmentBase { bool InitializeScanWithOffset(CollectionScanState &state, SegmentNode &node, idx_t vector_offset); //! Checks the given set of table filters against the row-group statistics. Returns false if the entire row group //! can be skipped. - bool CheckZonemap(ScanFilterInfo &filters); + bool CheckZonemap(optional_ptr context, ScanFilterInfo &filters); //! 
Checks the given set of table filters against the per-segment statistics. Returns false if any segments were //! skipped. bool CheckZonemapSegments(CollectionScanState &state); @@ -149,11 +158,14 @@ class RowGroup : public SegmentBase { idx_t GetSelVector(ScanOptions options, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); - //! For a specific row, returns true if it should be used for the transaction and false otherwise. - bool Fetch(TransactionData transaction, idx_t row); - //! Fetch a specific row from the row_group and insert it into the result at the specified index - void FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, - row_t row_id, DataChunk &result, idx_t result_idx); + //! Bulk visibility check. For each offset in [0, count), writes the input index into `visible_sel` if that row is + //! visible to the transaction. Returns the number of visible rows. + idx_t Fetch(TransactionData transaction, const idx_t *offsets, idx_t count, SelectionVector &visible_sel); + //! Bulk row fetch. For each `i` in [0, visible_count), fetches the row at `offsets[visible_sel.get_index(i)]` + //! and writes every requested column into `result.data[col_idx][result_offset + i]`. + void FetchRows(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, + const idx_t *offsets, const SelectionVector &visible_sel, idx_t visible_count, DataChunk &result, + idx_t result_offset); //! 
Append count rows to the version info void AppendVersionInfo(TransactionData transaction, idx_t count); @@ -184,8 +196,9 @@ class RowGroup : public SegmentBase { bool IsPersistent() const; PersistentRowGroupData SerializeRowGroupInfo(idx_t row_group_start) const; - void InitializeAppend(RowGroupAppendState &append_state); + static void InitializeAppend(SegmentNode &row_group, RowGroupAppendState &append_state); void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); + void FinalizeAppend(RowGroupAppendState &append_state); void Update(TransactionData transaction, DuckTableEntry &table_entry, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, const vector &column_ids, idx_t row_group_start); @@ -200,7 +213,8 @@ class RowGroup : public SegmentBase { unique_ptr GetStatistics(idx_t column_idx) const; unique_ptr GetStatistics(const StorageIndex &column_idx) const; - void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector &result); + void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector &result, + const ColumnSegmentInfoScanOptions &options = ColumnSegmentInfoScanOptions {}); static PartitionStatistics GetPartitionStats(SegmentNode &row_group); idx_t GetAllocationSize() const { @@ -215,7 +229,7 @@ class RowGroup : public SegmentBase { RowVersionManager &GetOrCreateVersionInfo(); // Serialization - static void Serialize(RowGroupPointer &pointer, Serializer &serializer); + static void Serialize(RowGroupPointer &pointer, Serializer &serializer, bool supports_per_column_writes); static RowGroupPointer Deserialize(Deserializer &deserializer); idx_t GetRowGroupSize() const; @@ -225,7 +239,12 @@ class RowGroup : public SegmentBase { vector CheckpointDeletes(RowGroupWriter &writer); + //! 
Direct accessors, fall outside of general use but can be useful to some extensions + ColumnData &GetRawColumnData(const StorageIndex &c) const; + ColumnData &GetRawColumnData(storage_t c) const; + private: + void InitializeAppendInternal(RowGroupAppendState &append_state); optional_ptr GetVersionInfo(); optional_ptr GetVersionInfoIfLoaded() const; shared_ptr GetOrCreateVersionInfoPtr(); @@ -241,8 +260,12 @@ class RowGroup : public SegmentBase { void SetCount(idx_t count); bool ColumnIsLoaded(storage_t c) const; void UnloadColumn(storage_t c); + bool HasUnchangedColumns() const; + static shared_ptr CheckpointColumn(const RowGroup &row_group, idx_t column_idx, RowGroupWriteInfo &info, + RowGroupWriteData &write_data); bool HasUnloadedDeletes() const; + unique_ptr CreateNewRowGroupCopy(RowGroupCollection &new_collection, idx_t new_column_count); private: mutable mutex row_group_lock; @@ -252,6 +275,8 @@ class RowGroup : public SegmentBase { vector deletes_pointers; bool has_metadata_blocks = false; vector extra_metadata_blocks; + bool has_per_column_metadata_blocks = false; + PerColumnMetadataBlocks per_column_metadata_blocks; atomic deletes_is_loaded; atomic allocation_size; //! 
The row id column data (mutable because `const` can lazy load) diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index aa20dbecd..a32141533 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -13,6 +13,7 @@ #include "duckdb/storage/statistics/column_statistics.hpp" #include "duckdb/storage/table/table_statistics.hpp" #include "duckdb/storage/storage_index.hpp" +#include "duckdb/common/enums/column_segment_info_scan_type.hpp" #include "duckdb/common/enums/index_removal_type.hpp" #include "duckdb/common/enums/row_group_append_mode.hpp" @@ -44,6 +45,15 @@ class DuckTableEntry; class RowGroupIterationHelper; class TableScanState; +//! Snapshot state used to iterate row groups without holding the row-group segment-tree +//! lock for the duration of the scan. Holding row_groups pins the snapshot alive; consistency +//! follows the same model as table scans. +struct ColumnSegmentInfoScanState { + shared_ptr row_groups; + optional_ptr> current_row_group; + ColumnSegmentInfoScanOptions options; +}; + class RowGroupCollection { public: RowGroupCollection(shared_ptr info, TableIOManager &io_manager, vector types, @@ -53,6 +63,7 @@ class RowGroupCollection { public: idx_t GetTotalRows() const; + idx_t GetNextRowId() const; idx_t GetRowGroupCount() const; Allocator &GetAllocator() const; @@ -96,8 +107,9 @@ class RowGroupCollection { //! Initialize an append with a variable number of rows. FinalizeAppend should not be called after appending is //! done. void InitializeAppend(TransactionData transaction, TableAppendState &state); - //! Appends to the row group collection. Returns true if a new row group has been created to append to - bool Append(DataChunk &chunk, TableAppendState &state); + //! Appends to the row group collection. 
Returns the finished row group index if a new row group has been appended + //! to + optional_idx Append(DataChunk &chunk, TableAppendState &state); //! FinalizeAppend flushes an append with a variable number of rows. void FinalizeAppend(TransactionData transaction, TableAppendState &state); void CommitAppend(transaction_t commit_id, idx_t row_start, idx_t count); @@ -135,7 +147,16 @@ class RowGroupCollection { void CommitDropTable(); vector GetPartitionStats() const; - vector GetColumnSegmentInfo(const QueryContext &context); + vector + GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options = ColumnSegmentInfoScanOptions {}) const; + //! Initialize an incremental scan over column segment info, pinning the current row groups for consistency. + void InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) const; + //! Append the next row group's column segment info to result. Returns false when no row groups remain. + bool ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) const; + bool SupportsPerColumnWrites() const; + bool SupportsPerColumnWrites(); const vector &GetTypes() const; shared_ptr AddColumn(ClientContext &context, ColumnDefinition &new_column, @@ -175,11 +196,13 @@ class RowGroupCollection { //! Returns the total amount of segments - use sparingly, as this forces all segments to be loaded idx_t GetSegmentCount(); + //! Get a ptr to the raw segment tree. This can be useful for some extensions to have directly exposed. + shared_ptr GetRowGroups() const; + private: optional_ptr> NextUpdateRowGroup(RowGroupSegmentTree &row_groups, row_t *ids, idx_t &pos, idx_t count) const; - shared_ptr GetRowGroups() const; void SetRowGroups(shared_ptr row_groups); private: @@ -189,6 +212,9 @@ class RowGroupCollection { const idx_t row_group_size; //! The number of rows in the table atomic total_rows; + //! Next rowid offset relative to the row group tree base rowid. 
+ //! For main table storage the base is 0, so this is also the absolute next rowid. + atomic next_row_id; //! The data table info shared_ptr info; //! The column types of the row group collection diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp index 665b9392e..a1229d437 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_reorderer.hpp @@ -50,7 +50,7 @@ struct OffsetPruningResult { class RowGroupReorderer { public: - explicit RowGroupReorderer(const RowGroupOrderOptions &options_p); + RowGroupReorderer(const RowGroupOrderOptions &options_p, TransactionData transaction_p); optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups); optional_ptr> GetNextRowGroup(SegmentNode &row_group); @@ -62,6 +62,7 @@ class RowGroupReorderer { private: const RowGroupOrderOptions options; + const TransactionData transaction; idx_t offset; bool initialized; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp index b5a2287f2..26562fc29 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp @@ -28,7 +28,7 @@ class RowGroupSegmentTree : public SegmentTree { } protected: - shared_ptr LoadSegment() const override; + optional> LoadSegment() const override; RowGroupCollection &collection; mutable idx_t current_row_group; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp index 2e0c1a1b0..e73bdd33f 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp @@ -32,16 +32,17 @@ class RowIdColumnData : 
public ColumnData { SelectionVector &sel, idx_t count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; - void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; void RevertAppend(row_t new_count) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, diff --git a/src/duckdb/src/include/duckdb/storage/table/row_number_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/row_number_column_data.hpp index 60811b3fd..6599a3fab 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_number_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_number_column_data.hpp @@ -30,16 +30,17 @@ class RowNumberColumnData : public ColumnData { SelectionVector &sel, idx_t count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t 
result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; - void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; void RevertAppend(row_t new_count) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, diff --git a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp index c995c0479..357973965 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_version_manager.hpp @@ -27,7 +27,8 @@ class RowVersionManager { idx_t GetRowCount(ScanOptions options, idx_t count); idx_t GetSelVector(ScanOptions options, idx_t vector_idx, SelectionVector &sel_vector, idx_t max_count); - bool Fetch(TransactionData transaction, idx_t row); + //! Bulk visibility check. Returns the number of visible rows. 
+ idx_t GetVisibleRows(TransactionData transaction, const idx_t *offsets, idx_t count, SelectionVector &visible_sel); void AppendVersionInfo(TransactionData transaction, idx_t count, idx_t row_group_start, idx_t row_group_end); void CommitAppend(transaction_t commit_id, idx_t row_group_start, idx_t count); diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp index fb030e725..9e6a34790 100644 --- a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/common.hpp" +#include "duckdb/common/unordered_set.hpp" #include "duckdb/storage/buffer/buffer_handle.hpp" #include "duckdb/storage/storage_lock.hpp" #include "duckdb/storage/table/row_group_reorderer.hpp" @@ -227,7 +228,8 @@ class ScanFilterInfo { class CollectionScanState { public: explicit CollectionScanState(TableScanState &parent_p); - + //! The query context for this scan + QueryContext context; //! The current row_group we are scanning optional_ptr> row_group; //! The vector index within the row_group @@ -250,11 +252,14 @@ class CollectionScanState { RandomEngine random; + //! The amount of tuples considered by a scan, before applying filters + idx_t rows_scanned = 0; + //! 
Optional state for custom row group ordering unique_ptr reorderer; public: - void Initialize(const QueryContext &context, const vector &types); + void Initialize(const QueryContext &context_p, const vector &types); const vector &GetColumnIds(); ScanFilterInfo &GetFilterInfo(); ScanSamplingInfo &GetSamplingInfo(); @@ -326,6 +331,7 @@ class TableScanState { struct ParallelCollectionScanState { ParallelCollectionScanState(); + void AssignRowGroup(optional_ptr> row_group); optional_ptr> GetRootSegment(RowGroupSegmentTree &row_groups) const; optional_ptr> GetNextRowGroup(RowGroupSegmentTree &row_groups, SegmentNode &row_group) const; @@ -343,6 +349,13 @@ struct ParallelCollectionScanState { //! Optional state for custom row group ordering unique_ptr reorderer; + //! Subset of partition indices to scan, if null, scan all + optional_ptr> partitions_to_scan; + + //! Whether this row group should be scanned + bool ShouldScanPartition(SegmentNode &row_group) const { + return !partitions_to_scan || partitions_to_scan->count(row_group.GetIndex()) > 0; + } }; struct ParallelTableScanState { diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp index 4c01fbea1..c4f888d9a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/common/constants.hpp" +#include "duckdb/common/optional.hpp" #include "duckdb/storage/storage_lock.hpp" #include "duckdb/storage/table/segment_lock.hpp" #include "duckdb/common/vector.hpp" @@ -83,6 +84,17 @@ struct SegmentNode { idx_t index; }; +template +struct LoadedSegment { + LoadedSegment(shared_ptr segment_p, idx_t row_start_p) : segment(std::move(segment_p)), row_start(row_start_p) { + } + + shared_ptr segment; + idx_t row_start; +}; + +enum class SegmentTreeVerifyMode : uint8_t { CONTIGUOUS, NON_OVERLAPPING }; + //! 
The SegmentTree maintains a list of all segments of a specific column in a table, and allows searching for a segment //! by row number // The const-ness of the SegmentTree is implemented in an odd manner due to the lazy loading @@ -207,6 +219,10 @@ class SegmentTree { auto l = Lock(); AppendSegment(l, std::move(segment)); } + void AppendSegment(shared_ptr segment, idx_t row_start) { + auto l = Lock(); + AppendSegment(l, std::move(segment), row_start); + } void AppendSegment(SegmentLock &l, shared_ptr segment) { LoadAllSegments(l); AppendSegmentInternal(l, std::move(segment)); @@ -260,44 +276,46 @@ class SegmentTree { return false; } idx_t lower = 0; - idx_t upper = nodes.size() - 1; - // binary search to find the node - while (lower <= upper) { - idx_t index = (lower + upper) / 2; - if (index >= nodes.size()) { - string segments; - for (auto &entry : nodes) { - segments += StringUtil::Format("Start %d Count %d", entry->GetRowStart(), entry->GetCount()); - } - throw InternalException("Segment tree index not found for row number %d\nSegments:%s", row_number, - segments); - } + idx_t upper = nodes.size(); + // binary search to find the node. Searches using half-open interval [lower, upper). + while (lower < upper) { + idx_t index = lower + (upper - lower) / 2; auto &entry = *nodes[index]; if (row_number < entry.GetRowStart()) { - upper = index - 1; + // This node and all nodes to the right are excluded. + upper = index; } else if (row_number >= entry.GetRowEnd()) { + // This node and all nodes to the left are excluded. 
lower = index + 1; } else { + // entry.GetRowStart() <= row_number < entry.GetRowEnd() result = index; return true; } } + D_ASSERT(lower == upper); + D_ASSERT(lower == 0 || row_number >= nodes[lower - 1]->GetRowEnd()); + D_ASSERT(upper == nodes.size() || row_number < nodes[upper]->GetRowStart()); return false; } - void Verify(SegmentLock &) const { + void Verify(SegmentLock &, SegmentTreeVerifyMode mode = SegmentTreeVerifyMode::CONTIGUOUS) const { #ifdef DEBUG - idx_t base_start = nodes.empty() ? 0 : nodes[0]->GetRowStart(); + idx_t current_rowid_end = nodes.empty() ? 0 : nodes[0]->GetRowStart(); for (idx_t i = 0; i < nodes.size(); i++) { - D_ASSERT(nodes[i]->GetRowStart() == base_start); - base_start += nodes[i]->GetCount(); + if (mode == SegmentTreeVerifyMode::NON_OVERLAPPING) { + D_ASSERT(nodes[i]->GetRowStart() >= current_rowid_end); + } else { + D_ASSERT(nodes[i]->GetRowStart() == current_rowid_end); + } + current_rowid_end = nodes[i]->GetRowStart() + nodes[i]->GetCount(); } #endif } - void Verify() { + void Verify(SegmentTreeVerifyMode mode = SegmentTreeVerifyMode::CONTIGUOUS) { #ifdef DEBUG auto l = Lock(); - Verify(l); + Verify(l, mode); #endif } @@ -325,8 +343,8 @@ class SegmentTree { mutable atomic finished_loading; //! 
Load the next segment - only used when lazily loading - virtual shared_ptr LoadSegment() const { - return nullptr; + virtual optional> LoadSegment() const { + return nullopt; } optional_ptr> GetRootSegmentInternal() const { @@ -443,7 +461,7 @@ class SegmentTree { } auto result = LoadSegment(); if (result) { - AppendSegmentInternal(l, std::move(result)); + AppendSegmentInternal(l, std::move(result->segment), result->row_start); return true; } return false; diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 67b9df5dd..337bf8b5c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -37,11 +37,13 @@ class StandardColumnData : public ColumnData { SelectionVector &sel, idx_t sel_count) override; void InitializeAppend(ColumnAppendState &state) override; - void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void 
UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -59,7 +61,8 @@ class StandardColumnData : public ColumnData { Vector &scan_vector) const override; void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, vector &result) override; + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) override; bool IsPersistent() override; bool HasAnyChanges() const override; @@ -69,6 +72,8 @@ class StandardColumnData : public ColumnData { void Verify(RowGroup &parent) override; void SetValidityData(shared_ptr validity); + //! Direct access to the validity column data. Intended for extensions that need to walk storage internals. + ValidityColumnData &GetValidityData(); protected: //! The validity column data diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index 6171db1cd..f87ec4596 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -50,11 +50,13 @@ class StructColumnData : public ColumnData { void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, 
ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -75,7 +77,8 @@ class StructColumnData : public ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, vector &result) override; + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) override; void Verify(RowGroup &parent) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp index 314aaf79c..dd7257d26 100644 --- a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp @@ -68,11 +68,11 @@ class TableIndexList { //! Adds an index entry to the list of index entries. void AddIndex(unique_ptr index); //! Removes an index entry from the list of index entries and release any storage the index owns. - void RemoveIndex(const string &name); + void RemoveIndex(const Identifier &name); //! Returns true, if the index name does not exist. bool NameIsUnique(const string &name); //! Returns an optional pointer to the index matching the name. - optional_ptr Find(const string &name); + optional_ptr Find(const Identifier &name); //! Binds unbound indexes possibly present after loading an extension. void Bind(ClientContext &context, DataTableInfo &table_info, const char *index_type = nullptr); //! Returns true, if there are no index entries. 
diff --git a/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp index 579f5a70f..7f9ac2b86 100644 --- a/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/table_statistics.hpp @@ -40,7 +40,8 @@ class TableStatistics { void MergeStats(TableStatistics &other); void MergeStats(idx_t i, BaseStatistics &stats); - void MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats); + void MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats, + StatsMergeType merge_type = StatsMergeType::MERGE_STATS); void SetStats(TableStatistics &other); void CopyStats(TableStatistics &other); @@ -58,7 +59,7 @@ class TableStatistics { void DestroyTableSample(TableStatisticsLock &lock) const; void AppendToTableSample(TableStatisticsLock &lock, unique_ptr sample); - bool Empty(); + bool Empty() const; unique_ptr GetLock(); diff --git a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp index b6629fdc8..2ad77eb09 100644 --- a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp @@ -18,6 +18,7 @@ namespace duckdb { class ColumnData; class DataTable; class DuckTableEntry; +struct SelectionVector; class Vector; struct UpdateInfo; struct UpdateNode; @@ -41,7 +42,8 @@ class UpdateSegment { void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); void Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update, row_t *ids, idx_t count, Vector &base_data, idx_t row_group_start); - void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); + void FetchRows(TransactionData transaction, const idx_t *offsets, const SelectionVector &sel, idx_t count, + Vector &result, idx_t result_offset); void 
RollbackUpdate(UpdateInfo &info); void CleanupUpdateInternal(const StorageLockKey &lock, UpdateInfo &info); @@ -77,8 +79,9 @@ class UpdateSegment { typedef void (*fetch_committed_function_t)(UpdateInfo &info, Vector &result); typedef void (*fetch_committed_range_function_t)(UpdateInfo &info, idx_t start, idx_t end, idx_t result_offset, Vector &result); - typedef void (*fetch_row_function_t)(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, - idx_t row_idx, Vector &result, idx_t result_idx); + typedef void (*fetch_rows_function_t)(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_offset, + idx_t count, idx_t vector_offset, Vector &result, idx_t result_offset); typedef void (*rollback_update_function_t)(UpdateInfo &base_info, UpdateInfo &rollback_info); typedef idx_t (*statistics_update_function_t)(UpdateSegment *segment, SegmentStatistics &stats, UnifiedVectorFormat &update, idx_t count, SelectionVector &sel); @@ -91,7 +94,7 @@ class UpdateSegment { fetch_update_function_t fetch_update_function; fetch_committed_function_t fetch_committed_function; fetch_committed_range_function_t fetch_committed_range; - fetch_row_function_t fetch_row_function; + fetch_rows_function_t fetch_rows_function; rollback_update_function_t rollback_update_function; statistics_update_function_t statistics_update_function; get_effective_updates_t get_effective_updates; @@ -101,6 +104,7 @@ class UpdateSegment { void InitializeUpdateInfo(idx_t vector_idx); void InitializeUpdateInfo(UpdateInfo &info, row_t *ids, const SelectionVector &sel, idx_t count, idx_t vector_index, idx_t vector_offset); + void ReallocateRootInfoIfNeeded(UpdateInfo ¤t_info, idx_t update_count, idx_t vector_index); }; struct UpdateNode { diff --git a/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp index d12b1c64e..027c665c3 
100644 --- a/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/validity_column_data.hpp @@ -23,7 +23,7 @@ class ValidityColumnData : public ColumnData { public: FilterPropagateResult CheckZonemap(ColumnScanState &state, TableFilter &filter) override; - void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; + void AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; unique_ptr CreateCheckpointState(const RowGroup &row_group, PartialBlockManager &partial_block_manager) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp index 83243b0b6..078b6a59c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp @@ -45,11 +45,13 @@ class VariantColumnData : public ColumnData { void Skip(ColumnScanState &state, idx_t count = STANDARD_VECTOR_SIZE) override; void InitializeAppend(ColumnAppendState &state) override; - void Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) override; + void Append(ColumnAppendState &state, const Vector &vector, idx_t count) override; + void FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) override; void RevertAppend(row_t new_count) override; idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; - void FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, row_t row_id, - Vector &result, idx_t result_idx) override; + void FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t count, Vector &result, + idx_t result_offset) override; void 
Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t row_group_start) override; void UpdateColumn(TransactionData transaction, DuckTableEntry &table_entry, const vector &column_path, @@ -70,11 +72,12 @@ class VariantColumnData : public ColumnData { void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, vector &result) override; + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) override; void Verify(RowGroup &parent) override; - static void ShredVariantData(Vector &input, Vector &output, idx_t count); + static void ShredVariantData(const Vector &input, Vector &output, idx_t count); void SetValidityData(shared_ptr validity_p); void SetChildData(vector> child_data); diff --git a/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp b/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp index e292ca28a..2651cfb32 100644 --- a/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp @@ -97,6 +97,8 @@ struct BlockIndexManager { idx_t GetMaxIndex() const; //! Whether there are free blocks available within the file bool HasFreeBlocks() const; + //! Get the count of blocks currently holding data + idx_t GetUsedBlockCount() const; private: //! 
Get/set max block index diff --git a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp index de92f3fcc..9fc5c6e10 100644 --- a/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp +++ b/src/duckdb/src/include/duckdb/storage/write_ahead_log.hpp @@ -99,7 +99,7 @@ class WriteAheadLog { void WriteCreateTrigger(const TriggerCatalogEntry &entry); void WriteDropTrigger(const TriggerCatalogEntry &entry); //! Sets the table used for subsequent insert/delete/update commands - void WriteSetTable(const string &schema, const string &table); + void WriteSetTable(const Identifier &schema, const Identifier &table); void WriteAlter(CatalogEntry &entry, const AlterInfo &info); diff --git a/src/duckdb/src/include/duckdb/transaction/commit_state.hpp b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp index 23c4506e2..34c1d6a2c 100644 --- a/src/duckdb/src/include/duckdb/transaction/commit_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/commit_state.hpp @@ -35,7 +35,7 @@ enum class CommitMode { COMMIT, REVERT_COMMIT }; //! An index that has been marked for removal from a table's index list once the commit chain succeeds. struct PendingIndexRemoval { reference indexes; - string name; + Identifier name; }; //! Accumulates block marks and index removals during commit so they can be applied together once the @@ -51,7 +51,7 @@ class CommitDropState { //! Register an index to be removed from a table's index list during FinalizeCommit. Index removal will drop in //! memory index data and also marks all blocks on disk as free blocks allowing for reclamation. Block marking for //! indexes is handled implicitly along destruction paths for index memory. - void RemoveIndex(TableIndexList &indexes, string name); + void RemoveIndex(TableIndexList &indexes, Identifier name); //! Finalize accumulated block marks and index removals. void FinalizeCommit(); //! True if no work has been queued. 
diff --git a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp index 9f4846947..1ae24e6f0 100644 --- a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp +++ b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp @@ -83,7 +83,7 @@ class LocalTableStorage : public enable_shared_from_this { public: void InitializeScan(CollectionScanState &state, optional_ptr table_filters = nullptr); //! Write a new row group to disk (if possible) - void WriteNewRowGroup(); + void WriteNewRowGroup(idx_t flushed_row_group_idx); void FlushBlocks(); void Rollback(); idx_t EstimatedSize(); diff --git a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp index 99416795e..2e6f56465 100644 --- a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp @@ -77,8 +77,8 @@ class MetaTransaction { const vector> &OpenedTransactions() const { return all_transactions; } - optional_ptr GetReferencedDatabase(const string &name); - shared_ptr GetReferencedDatabaseOwning(const string &name); + optional_ptr GetReferencedDatabase(const Identifier &name); + shared_ptr GetReferencedDatabaseOwning(const Identifier &name); AttachedDatabase &UseDatabase(shared_ptr &database); void DetachDatabase(AttachedDatabase &database); @@ -98,7 +98,7 @@ class MetaTransaction { //! The set of used (referenced) databases. reference_map_t> referenced_databases; //! Map of name -> database for databases that are in-use by this transaction. 
- case_insensitive_map_t> used_databases; + identifier_map_t> used_databases; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp index bb9016145..8a7ecfc77 100644 --- a/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp +++ b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp @@ -56,9 +56,7 @@ class TransactionContext { void ResetActiveQuery(); void SetActiveQuery(transaction_t query_number); - void SetInvalidationPolicy(TransactionInvalidationPolicy new_invalidation_policy) { - invalidation_policy = new_invalidation_policy; - }; + void SetInvalidationPolicy(TransactionInvalidationPolicy new_invalidation_policy); TransactionInvalidationPolicy GetInvalidationPolicy() { return invalidation_policy; }; diff --git a/src/duckdb/src/include/duckdb/transaction/update_info.hpp b/src/duckdb/src/include/duckdb/transaction/update_info.hpp index 725dc329c..c1d5e2fe0 100644 --- a/src/duckdb/src/include/duckdb/transaction/update_info.hpp +++ b/src/duckdb/src/include/duckdb/transaction/update_info.hpp @@ -93,11 +93,18 @@ struct UpdateInfo { bool HasPrev() const; bool HasNext() const; static UpdateInfo &Get(UndoBufferReference &entry); - //! Returns the total allocation size for an UpdateInfo entry, together with space for the tuple data + //! Returns the total allocation size for an UpdateInfo entry with max capacity (STANDARD_VECTOR_SIZE) static idx_t GetAllocSize(idx_t type_size); + //! Returns the total allocation size for an UpdateInfo entry with a specific capacity + static idx_t GetAllocSize(idx_t type_size, idx_t capacity); + //! Computes a compact capacity for a given count (rounds up with growth headroom) + static idx_t GetCompactCapacity(idx_t count); //! Initialize an UpdateInfo struct that has been allocated using GetAllocSize (i.e. 
has extra space after it) static void Initialize(UpdateInfo &info, DuckTableEntry &table_entry, transaction_t transaction_id, idx_t row_group_start); + //! Initialize with a specific capacity (for compact allocations) + static void Initialize(UpdateInfo &info, DuckTableEntry &table_entry, transaction_t transaction_id, + idx_t row_group_start, idx_t capacity); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index c1e3d3603..c99af185a 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -78,9 +78,9 @@ typedef struct { void (*duckdb_interrupt)(duckdb_connection connection); duckdb_query_progress_type (*duckdb_query_progress)(duckdb_connection connection); void (*duckdb_disconnect)(duckdb_connection *connection); - const char *(*duckdb_library_version)(); + const char *(*duckdb_library_version)(void); duckdb_state (*duckdb_create_config)(duckdb_config *out_config); - size_t (*duckdb_config_count)(); + size_t (*duckdb_config_count)(void); duckdb_state (*duckdb_get_config_flag)(size_t index, const char **out_name, const char **out_description); duckdb_state (*duckdb_set_config)(duckdb_config config, const char *name, const char *option); void (*duckdb_destroy_config)(duckdb_config *config); @@ -97,7 +97,7 @@ typedef struct { duckdb_result_type (*duckdb_result_return_type)(duckdb_result result); void *(*duckdb_malloc)(size_t size); void (*duckdb_free)(void *ptr); - idx_t (*duckdb_vector_size)(); + idx_t (*duckdb_vector_size)(void); bool (*duckdb_string_is_inlined)(duckdb_string_t string); uint32_t (*duckdb_string_t_length)(duckdb_string_t string); const char *(*duckdb_string_t_data)(duckdb_string_t *string); @@ -235,7 +235,7 @@ typedef struct { duckdb_value (*duckdb_get_map_key)(duckdb_value value, idx_t index); duckdb_value (*duckdb_get_map_value)(duckdb_value value, idx_t index); bool (*duckdb_is_null_value)(duckdb_value value); - duckdb_value 
(*duckdb_create_null_value)(); + duckdb_value (*duckdb_create_null_value)(void); idx_t (*duckdb_get_list_size)(duckdb_value value); duckdb_value (*duckdb_get_list_child)(duckdb_value value, idx_t index); duckdb_value (*duckdb_create_enum_value)(duckdb_logical_type type, uint64_t value); @@ -297,7 +297,7 @@ typedef struct { void (*duckdb_validity_set_row_validity)(uint64_t *validity, idx_t row, bool valid); void (*duckdb_validity_set_row_invalid)(uint64_t *validity, idx_t row); void (*duckdb_validity_set_row_valid)(uint64_t *validity, idx_t row); - duckdb_scalar_function (*duckdb_create_scalar_function)(); + duckdb_scalar_function (*duckdb_create_scalar_function)(void); void (*duckdb_destroy_scalar_function)(duckdb_scalar_function *scalar_function); void (*duckdb_scalar_function_set_name)(duckdb_scalar_function scalar_function, const char *name); void (*duckdb_scalar_function_set_varargs)(duckdb_scalar_function scalar_function, duckdb_logical_type type); @@ -316,7 +316,7 @@ typedef struct { void (*duckdb_destroy_scalar_function_set)(duckdb_scalar_function_set *scalar_function_set); duckdb_state (*duckdb_add_scalar_function_to_set)(duckdb_scalar_function_set set, duckdb_scalar_function function); duckdb_state (*duckdb_register_scalar_function_set)(duckdb_connection con, duckdb_scalar_function_set set); - duckdb_aggregate_function (*duckdb_create_aggregate_function)(); + duckdb_aggregate_function (*duckdb_create_aggregate_function)(void); void (*duckdb_destroy_aggregate_function)(duckdb_aggregate_function *aggregate_function); void (*duckdb_aggregate_function_set_name)(duckdb_aggregate_function aggregate_function, const char *name); void (*duckdb_aggregate_function_add_parameter)(duckdb_aggregate_function aggregate_function, @@ -343,7 +343,7 @@ typedef struct { duckdb_state (*duckdb_add_aggregate_function_to_set)(duckdb_aggregate_function_set set, duckdb_aggregate_function function); duckdb_state (*duckdb_register_aggregate_function_set)(duckdb_connection con, 
duckdb_aggregate_function_set set); - duckdb_table_function (*duckdb_create_table_function)(); + duckdb_table_function (*duckdb_create_table_function)(void); void (*duckdb_destroy_table_function)(duckdb_table_function *table_function); void (*duckdb_table_function_set_name)(duckdb_table_function table_function, const char *name); void (*duckdb_table_function_add_parameter)(duckdb_table_function table_function, duckdb_logical_type type); @@ -417,7 +417,7 @@ typedef struct { void (*duckdb_destroy_task_state)(duckdb_task_state state); bool (*duckdb_execution_is_finished)(duckdb_connection con); duckdb_data_chunk (*duckdb_fetch_chunk)(duckdb_result result); - duckdb_cast_function (*duckdb_create_cast_function)(); + duckdb_cast_function (*duckdb_create_cast_function)(void); void (*duckdb_cast_function_set_source_type)(duckdb_cast_function cast_function, duckdb_logical_type source_type); void (*duckdb_cast_function_set_target_type)(duckdb_cast_function cast_function, duckdb_logical_type target_type); void (*duckdb_cast_function_set_implicit_cast_cost)(duckdb_cast_function cast_function, int64_t cost); @@ -529,7 +529,7 @@ typedef struct { // Exposing the instance cache #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE - duckdb_instance_cache (*duckdb_create_instance_cache)(); + duckdb_instance_cache (*duckdb_create_instance_cache)(void); duckdb_state (*duckdb_get_or_create_from_cache)(duckdb_instance_cache instance_cache, const char *path, duckdb_database *out_database, duckdb_config config, char **out_error); @@ -576,7 +576,7 @@ typedef struct { // New configuration options functions #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE - duckdb_config_option (*duckdb_create_config_option)(); + duckdb_config_option (*duckdb_create_config_option)(void); void (*duckdb_destroy_config_option)(duckdb_config_option *option); void (*duckdb_config_option_set_name)(duckdb_config_option option, const char *name); void (*duckdb_config_option_set_type)(duckdb_config_option option, 
duckdb_logical_type type); @@ -591,7 +591,7 @@ typedef struct { // API to define custom copy functions #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE - duckdb_copy_function (*duckdb_create_copy_function)(); + duckdb_copy_function (*duckdb_create_copy_function)(void); void (*duckdb_copy_function_set_name)(duckdb_copy_function copy_function, const char *name); void (*duckdb_copy_function_set_extra_info)(duckdb_copy_function copy_function, void *extra_info, duckdb_delete_callback_t destructor); @@ -662,7 +662,7 @@ typedef struct { duckdb_state (*duckdb_file_system_open)(duckdb_file_system file_system, const char *path, duckdb_file_open_options options, duckdb_file_handle *out_file); duckdb_error_data (*duckdb_file_system_error_data)(duckdb_file_system file_system); - duckdb_file_open_options (*duckdb_create_file_open_options)(); + duckdb_file_open_options (*duckdb_create_file_open_options)(void); duckdb_state (*duckdb_file_open_options_set_flag)(duckdb_file_open_options options, duckdb_file_flag flag, bool value); void (*duckdb_destroy_file_open_options)(duckdb_file_open_options *options); @@ -684,7 +684,7 @@ typedef struct { // API to register a custom log storage. 
#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE - duckdb_log_storage (*duckdb_create_log_storage)(); + duckdb_log_storage (*duckdb_create_log_storage)(void); void (*duckdb_destroy_log_storage)(duckdb_log_storage *log_storage); void (*duckdb_log_storage_set_write_log_entry)(duckdb_log_storage log_storage, duckdb_logger_write_log_entry_t function); diff --git a/src/duckdb/src/logging/log_manager.cpp b/src/duckdb/src/logging/log_manager.cpp index cb4d06e6e..b9a41e980 100644 --- a/src/duckdb/src/logging/log_manager.cpp +++ b/src/duckdb/src/logging/log_manager.cpp @@ -273,6 +273,8 @@ void LogManager::RegisterDefaultLogTypes() { RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); RegisterLogType(make_uniq()); + RegisterLogType(make_uniq()); + RegisterLogType(make_uniq()); } } // namespace duckdb diff --git a/src/duckdb/src/logging/log_storage.cpp b/src/duckdb/src/logging/log_storage.cpp index 7f712cd02..ff780aa27 100644 --- a/src/duckdb/src/logging/log_storage.cpp +++ b/src/duckdb/src/logging/log_storage.cpp @@ -83,8 +83,8 @@ void LogStorage::Truncate() { } void LogStorage::UpdateConfig(DatabaseInstance &db, case_insensitive_map_t &config) { - if (config.size() > 1) { - throw InvalidInputException("LogStorage does not support passing configuration"); + if (!config.empty()) { + throw InvalidInputException("Log storage '%s' does not support passing configuration", GetStorageName()); } } @@ -141,7 +141,6 @@ void CSVLogStorage::ExecuteCast(LoggingTargetTable table, DataChunk &chunk) { for (idx_t i = 0; i < chunk.data.size(); i++) { VectorOperations::DefaultCast(chunk.data[i], cast_buffer.data[i], count, false); } - cast_buffer.SetCardinality(count); } void CSVLogStorage::ResetAllBuffers() { @@ -356,8 +355,11 @@ void FileLogStorage::Truncate() { } // Truncate the file writer file_writer->Truncate(0); - // Re-initialize the corresponding CSVWriter - GetWriter(it.first).Initialize(true); + auto &writer = GetWriter(it.first); + // Reset writer and header option, then 
re-initialize + writer.Reset(nullptr); + writer.options.dialect_options.header = CSVOption(true); + writer.Initialize(true); } } @@ -626,7 +628,7 @@ static void WriteLoggingContextsToChunk(DataChunk &chunk, const RegisteredLoggin FlatVector::ValidityMutable(chunk.data[col++]).SetInvalid(size); } - chunk.SetCardinality(size + 1); + chunk.SetChildCardinality(size + 1); } void BufferingLogStorage::WriteLogEntry(timestamp_t timestamp, LogLevel level, const string &log_type, @@ -669,7 +671,7 @@ void BufferingLogStorage::WriteLogEntry(timestamp_t timestamp, LogLevel level, c auto message_data = FlatVector::GetDataMutable(log_entries_buffer->data[col]); message_data[size] = StringVector::AddString(log_entries_buffer->data[col++], log_message); - log_entries_buffer->SetCardinality(size + 1); + log_entries_buffer->SetChildCardinality(size + 1); if (size + 1 >= buffer_limit) { if (normalize_contexts) { diff --git a/src/duckdb/src/logging/log_types.cpp b/src/duckdb/src/logging/log_types.cpp index 774701d9e..ccc32a63d 100644 --- a/src/duckdb/src/logging/log_types.cpp +++ b/src/duckdb/src/logging/log_types.cpp @@ -20,6 +20,8 @@ constexpr LogLevel PhysicalOperatorLogType::LEVEL; constexpr LogLevel MetricsLogType::LEVEL; constexpr LogLevel CheckpointLogType::LEVEL; constexpr LogLevel AdaptiveFilterLogType::LEVEL; +constexpr LogLevel ParquetPrefetchLogType::LEVEL; +constexpr LogLevel AsyncTaskScheduleLogType::LEVEL; //===--------------------------------------------------------------------===// // QueryLogType @@ -172,9 +174,9 @@ LogicalType MetricsLogType::GetLogType() { return LogicalType::STRUCT(child_list); } -string MetricsLogType::ConstructLogMessage(const MetricType &metric, const Value &value) { +string MetricsLogType::ConstructLogMessage(const string &metric, const Value &value) { child_list_t child_list = { - {"metric", EnumUtil::ToString(metric)}, + {"metric", metric}, {"value", value.ToString()}, }; return Value::STRUCT(std::move(child_list)).ToString(); @@ -200,9 
+202,9 @@ LogicalType CheckpointLogType::GetLogType() { string CheckpointLogType::CreateLog(const AttachedDatabase &db, DataTableInfo &table, const char *op_name, vector map_keys, vector map_values) { child_list_t child_list = { - {"database", db.name}, - {"schema", table.GetSchemaName()}, - {"table", table.GetTableName()}, + {"database", db.name.GetIdentifierName()}, + {"schema", table.GetSchemaName().GetIdentifierName()}, + {"table", table.GetTableName().GetIdentifierName()}, {"type", op_name}, {"info", Value::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR, std::move(map_keys), std::move(map_values))}, }; @@ -246,7 +248,7 @@ LogicalType TransactionLogType::GetLogType() { string TransactionLogType::ConstructLogMessage(const AttachedDatabase &db, const char *log_type, transaction_t transaction_id) { child_list_t child_list = { - {"database", db.name}, + {"database", db.name.GetIdentifierName()}, {"type", log_type}, {"transaction_id", transaction_id == MAX_TRANSACTION_ID ? Value() : Value::UBIGINT(transaction_id)}, }; @@ -284,4 +286,76 @@ string AdaptiveFilterLogType::ConstructLogMessage(const char *event, const strin return Value::STRUCT(std::move(child_list)).ToString(); } +//===--------------------------------------------------------------------===// +// ParquetPrefetchLogType +//===--------------------------------------------------------------------===// +ParquetPrefetchLogType::ParquetPrefetchLogType() : LogType(NAME, LEVEL, GetLogType()) { +} + +LogicalType ParquetPrefetchLogType::GetLogType() { + child_list_t child_list = { + {"file_path", LogicalType::VARCHAR}, + {"row_group_id", LogicalType::BIGINT}, + {"fully_filtered", LogicalType::BOOLEAN}, + {"strategy", LogicalType::VARCHAR}, + {"prefetch_groups", LogicalType::LIST(LogicalType::LIST(LogicalType::VARCHAR))}, + {"minimal_filters", LogicalType::LIST(LogicalType::VARCHAR)}, + {"accepted_column_gap", LogicalType::UBIGINT}, + }; + return LogicalType::STRUCT(child_list); +} + +string 
ParquetPrefetchLogType::ConstructLogMessage(const string &file_path, idx_t row_group_id, bool fully_filtered, + const char *strategy, const vector> &prefetch_groups, + const vector &minimal_filters, + uint64_t accepted_column_gap) { + vector outer; + outer.reserve(prefetch_groups.size()); + for (auto &group : prefetch_groups) { + vector inner; + inner.reserve(group.size()); + for (auto &name : group) { + inner.emplace_back(name); + } + outer.push_back(Value::LIST(LogicalType::VARCHAR, std::move(inner))); + } + vector minimal; + minimal.reserve(minimal_filters.size()); + for (auto &name : minimal_filters) { + minimal.emplace_back(name); + } + child_list_t child_list = { + {"file_path", Value(file_path)}, + {"row_group_id", Value::BIGINT(static_cast(row_group_id))}, + {"fully_filtered", Value::BOOLEAN(fully_filtered)}, + {"strategy", strategy ? Value(strategy) : Value(LogicalType::VARCHAR)}, + {"prefetch_groups", Value::LIST(LogicalType::LIST(LogicalType::VARCHAR), std::move(outer))}, + {"minimal_filters", Value::LIST(LogicalType::VARCHAR, std::move(minimal))}, + {"accepted_column_gap", Value::UBIGINT(accepted_column_gap)}, + }; + return Value::STRUCT(std::move(child_list)).ToString(); +} + +//===--------------------------------------------------------------------===// +// AsyncTaskScheduleLogType +//===--------------------------------------------------------------------===// +AsyncTaskScheduleLogType::AsyncTaskScheduleLogType() : LogType(NAME, LEVEL, GetLogType()) { +} + +LogicalType AsyncTaskScheduleLogType::GetLogType() { + child_list_t child_list = { + {"pool", LogicalType::VARCHAR}, + {"task_count", LogicalType::BIGINT}, + }; + return LogicalType::STRUCT(child_list); +} + +string AsyncTaskScheduleLogType::ConstructLogMessage(const string &pool, idx_t task_count) { + child_list_t child_list = { + {"pool", Value(pool)}, + {"task_count", Value::BIGINT(static_cast(task_count))}, + }; + return Value::STRUCT(std::move(child_list)).ToString(); +} + } // namespace 
duckdb diff --git a/src/duckdb/src/main/appender.cpp b/src/duckdb/src/main/appender.cpp index 3e57622e8..7019d4e95 100644 --- a/src/duckdb/src/main/appender.cpp +++ b/src/duckdb/src/main/appender.cpp @@ -24,6 +24,7 @@ #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/parser/query_node/update_query_node.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/parameter_expression.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" @@ -79,7 +80,7 @@ void BaseAppender::EndRow() { throw InvalidInputException("Call to EndRow before all columns have been appended to!"); } column = 0; - chunk.SetCardinality(chunk.size() + 1); + chunk.SetChildCardinality(chunk.size() + 1); if (ShouldFlushChunk()) { FlushChunk(); } @@ -368,7 +369,7 @@ void BaseAppender::AppendDataChunk(DataChunk &chunk_p) { auto size = chunk_p.size(); DataChunk cast_chunk; cast_chunk.Initialize(allocator, appender_types); - cast_chunk.SetCardinality(size); + cast_chunk.SetChildCardinality(size); for (idx_t i = 0; i < count; i++) { if (chunk_p.data[i].GetType() == appender_types[i]) { @@ -412,7 +413,14 @@ void BaseAppender::Flush() { return; } - FlushInternal(*collection); + try { + FlushInternal(*collection); + } catch (...) { + // Reset so the destructor does not re-attempt flushing the same data, + // which would hit assertions in the storage layer if it is inconsistent. 
+ collection->Reset(); + throw; + } collection->Reset(); column = 0; } @@ -425,7 +433,7 @@ void BaseAppender::AppendDefault(DataChunk &chunk, idx_t col, idx_t row) { throw NotImplementedException("AppendDefault is only supported when directly appending to a table"); } -void BaseAppender::AddColumn(const string &name) { +void BaseAppender::AddColumn(const Identifier &name) { throw NotImplementedException("AddColumn is only supported when directly appending to a table"); } @@ -433,8 +441,8 @@ void BaseAppender::ClearColumns() { throw NotImplementedException("ClearColumns is only supported when directly appending to a table"); } -unique_ptr BaseAppender::GetColumnDataTableRef(ColumnDataCollection &collection, const string &table_name, - const vector &expected_names) { +unique_ptr BaseAppender::GetColumnDataTableRef(ColumnDataCollection &collection, const Identifier &table_name, + const vector &expected_names) { auto column_data_ref = make_uniq(collection); column_data_ref->alias = table_name.empty() ? "appended_data" : table_name; ; @@ -451,7 +459,7 @@ CommonTableExpressionMap &GetCTEMap(SQLStatement &statement) { case StatementType::UPDATE_STATEMENT: return statement.Cast().node->cte_map; case StatementType::MERGE_INTO_STATEMENT: - return statement.Cast().cte_map; + return statement.Cast().node->cte_map; default: throw InvalidInputException( "Unsupported statement type for appender: expected INSERT, DELETE, UPDATE or MERGE INTO"); @@ -486,7 +494,7 @@ unique_ptr BaseAppender::ParseStatement(unique_ptr table // Add the appender data as a CTE to the CTE map of the statement. string alias = table_name.empty() ? 
"appended_data" : table_name; auto &cte_map = GetCTEMap(*parser.statements[0]); - cte_map.map.insert(alias, std::move(cte_info)); + cte_map.map.insert(Identifier(alias), std::move(cte_info)); return std::move(parser.statements[0]); } @@ -494,8 +502,8 @@ unique_ptr BaseAppender::ParseStatement(unique_ptr table //===--------------------------------------------------------------------===// // Table Appender //===--------------------------------------------------------------------===// -Appender::Appender(Connection &con, const string &database_name, const string &schema_name, const string &table_name, - const idx_t flush_memory_threshold_p) +Appender::Appender(Connection &con, const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name, const idx_t flush_memory_threshold_p) : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context) { flush_memory_threshold = (flush_memory_threshold_p == DConstants::INVALID_INDEX) ? optional_idx::Invalid() @@ -556,30 +564,30 @@ Appender::Appender(Connection &con, const string &database_name, const string &s collection = make_uniq(allocator, GetActiveTypes()); } -Appender::Appender(Connection &con, const string &schema_name, const string &table_name, +Appender::Appender(Connection &con, const Identifier &schema_name, const Identifier &table_name, const idx_t flush_memory_threshold_p) - : Appender(con, INVALID_CATALOG, schema_name, table_name, flush_memory_threshold_p) { + : Appender(con, Identifier::InvalidCatalog(), schema_name, table_name, flush_memory_threshold_p) { } -Appender::Appender(Connection &con, const string &table_name, const idx_t flush_memory_threshold_p) - : Appender(con, INVALID_CATALOG, DEFAULT_SCHEMA, table_name, flush_memory_threshold_p) { +Appender::Appender(Connection &con, const Identifier &table_name, const idx_t flush_memory_threshold_p) + : Appender(con, Identifier::InvalidCatalog(), Identifier::DefaultSchema(), table_name, 
flush_memory_threshold_p) { } Appender::~Appender() { Destructor(); } -vector Appender::GetExpectedNames() { - vector expected_names; +vector Appender::GetExpectedNames() { + vector expected_names; for (idx_t i = 0; i < column_ids.size(); i++) { auto &col_name = description->columns[column_ids[i].index].Name(); - expected_names.push_back(col_name); + expected_names.emplace_back(col_name); } return expected_names; } -string Appender::ConstructQuery(TableDescription &description_p, const string &table_name, - const vector &expected_names) { +string Appender::ConstructQuery(TableDescription &description_p, const Identifier &table_name, + const vector &expected_names) { string query = "INSERT INTO "; if (!description_p.database.empty()) { query += StringUtil::Format("%s.", SQLIdentifier(description_p.database)); @@ -609,12 +617,12 @@ void Appender::FlushInternal(ColumnDataCollection &collection) { throw InvalidInputException("Appender: Attempting to flush data to a closed connection"); } - string table_name = "__duckdb_internal_appended_data"; + Identifier table_name("__duckdb_internal_appended_data"); auto expected_names = GetExpectedNames(); auto query = ConstructQuery(*description, table_name, expected_names); auto table_ref = GetColumnDataTableRef(collection, table_name, expected_names); - auto stmt = ParseStatement(std::move(table_ref), query, table_name); + auto stmt = ParseStatement(std::move(table_ref), query, table_name.GetIdentifierName()); context_ref->Append(std::move(stmt)); } @@ -647,7 +655,7 @@ Value Appender::GetDefaultValue(idx_t column) { return it->second; } -void Appender::AddColumn(const string &name) { +void Appender::AddColumn(const Identifier &name) { Flush(); auto exists = false; @@ -694,8 +702,8 @@ void Appender::ClearColumns() { //===--------------------------------------------------------------------===// // Query Appender //===--------------------------------------------------------------------===// -QueryAppender::QueryAppender(Connection 
&con, string query_p, vector types_p, vector names_p, - string table_name_p, const idx_t flush_memory_threshold_p) +QueryAppender::QueryAppender(Connection &con, string query_p, vector types_p, vector names_p, + Identifier table_name_p, const idx_t flush_memory_threshold_p) : BaseAppender(Allocator::DefaultAllocator(), AppenderType::LOGICAL), context(con.context), query(std::move(query_p)), names(std::move(names_p)), table_name(std::move(table_name_p)) { types = std::move(types_p); @@ -715,7 +723,7 @@ void QueryAppender::FlushInternal(ColumnDataCollection &collection) { throw InvalidInputException("Attempting to flush query appender data on a closed connection"); } auto table_ref = GetColumnDataTableRef(collection, table_name, names); - auto parsed_statement = ParseStatement(std::move(table_ref), query, table_name); + auto parsed_statement = ParseStatement(std::move(table_ref), query, table_name.GetIdentifierName()); context_ref->Append(std::move(parsed_statement)); } diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp index d5ade6c69..6c6455bdc 100644 --- a/src/duckdb/src/main/attached_database.cpp +++ b/src/duckdb/src/main/attached_database.cpp @@ -16,11 +16,12 @@ #include "duckdb/storage/block_allocator.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/metadata/metadata_manager.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { StoredDatabasePath::StoredDatabasePath(DatabaseManager &db_manager, DatabaseFilePathManager &manager, string path_p, - const string &name) + const Identifier &name) : db_manager(db_manager), manager(manager), path(std::move(path_p)) { } @@ -72,8 +73,10 @@ AttachOptions::AttachOptions(const unordered_map &attach_options, } if (entry.first == "type") { - // Extract the database type. - db_type = StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR)); + // Extract the database type. 
Normalize case so that + // `TYPE sqlite` and `TYPE 'SQLite'` are equivalent. + // `TYPE sqlite` and `TYPE 'sqlite3'` are NOT equivalent, aliasing to be applied on comparison + db_type = StringUtil::Lower(StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR))); continue; } @@ -89,6 +92,18 @@ AttachOptions::AttachOptions(const unordered_map &attach_options, } continue; } + + if (entry.first == "vacuum_rebuild_indexes") { + const auto threshold = UBigIntValue::Get(entry.second.DefaultCastAs(LogicalType::UBIGINT)); + try { + vacuum_rebuild_indexes_threshold = threshold; + } catch (InternalException &e) { + throw InvalidInputException("Invalid setting for vacuum_rebuild_indexes: %d (valid range is 0 - %d)", + threshold, + UBigIntValue::Get(Value::MaximumValue(LogicalType::UBIGINT)) - 1); + } + continue; + } options.emplace(entry.first, entry.second); } } @@ -98,7 +113,7 @@ AttachOptions::AttachOptions(const unordered_map &attach_options, //===--------------------------------------------------------------------===// AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType type) : CatalogEntry(CatalogType::DATABASE_ENTRY, - type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG, 0), + Identifier(type == AttachedDatabaseType::SYSTEM_DATABASE ? SYSTEM_CATALOG : TEMP_CATALOG), 0), db(db), type(type), close_lock(make_shared_ptr()) { // This database does not have storage, or uses temporary_objects for in-memory storage. 
D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); @@ -114,7 +129,7 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType ty internal = true; } -AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, string name_p, string file_path_p, +AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, Identifier name_p, string file_path_p, AttachOptions &options) : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), parent_catalog(&catalog_p), close_lock(make_shared_ptr()) { @@ -125,6 +140,7 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, str } recovery_mode = options.recovery_mode; visibility = options.visibility; + vacuum_rebuild_threshold = options.vacuum_rebuild_indexes_threshold; // We create the storage after the catalog to guarantee we allow extensions to instantiate the DuckCatalog. catalog = make_uniq(*this); @@ -136,7 +152,7 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, str } AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, StorageExtension &storage_extension_p, - ClientContext &context, string name_p, AttachInfo &info, AttachOptions &options) + ClientContext &context, Identifier name_p, AttachInfo &info, AttachOptions &options) : CatalogEntry(CatalogType::DATABASE_ENTRY, catalog_p, std::move(name_p)), db(db), parent_catalog(&catalog_p), storage_extension(&storage_extension_p), close_lock(make_shared_ptr()) { if (options.access_mode == AccessMode::READ_ONLY) { @@ -146,9 +162,10 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, Sto } recovery_mode = options.recovery_mode; visibility = options.visibility; + vacuum_rebuild_threshold = options.vacuum_rebuild_indexes_threshold; optional_ptr storage_info = storage_extension->storage_info.get(); - catalog = storage_extension->attach(storage_info, 
context, *this, name, info, options); + catalog = storage_extension->attach(storage_info, context, *this, name.GetIdentifierName(), info, options); stored_database_path = std::move(options.stored_database_path); if (!catalog) { throw InternalException("AttachedDatabase - attach function did not return a catalog"); @@ -185,10 +202,17 @@ bool AttachedDatabase::IsReadOnly() const { return type == AttachedDatabaseType::READ_ONLY_DATABASE; } -bool AttachedDatabase::NameIsReserved(const string &name) { +bool AttachedDatabase::NameIsReserved(const Identifier &name) { return name == DEFAULT_SCHEMA || name == TEMP_CATALOG || name == SYSTEM_CATALOG; } +idx_t AttachedDatabase::GetVacuumRebuildIndexThreshold() const { + if (vacuum_rebuild_threshold.IsValid()) { + return vacuum_rebuild_threshold.GetIndex(); + } + return Settings::Get(db); +} + string AttachedDatabase::StoredPath() const { if (stored_database_path) { return stored_database_path->path; @@ -202,15 +226,15 @@ static string RemoveQueryParams(const string &name) { return vec[0]; } -string AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { +Identifier AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { if (dbpath.empty() || dbpath == IN_MEMORY_PATH) { - return "memory"; + return Identifier("memory"); } auto name = RemoveQueryParams(fs.ExtractBaseName(dbpath)); - if (NameIsReserved(name)) { + if (NameIsReserved(Identifier(name))) { name += "_db"; } - return name; + return Identifier(name); } void AttachedDatabase::InvokeCloseIfLastReference(shared_ptr &attached_db, ClientContext &context) { diff --git a/src/duckdb/src/main/capi/aggregate_function-c.cpp b/src/duckdb/src/main/capi/aggregate_function-c.cpp index bd354e88c..d3a3f7145 100644 --- a/src/duckdb/src/main/capi/aggregate_function-c.cpp +++ b/src/duckdb/src/main/capi/aggregate_function-c.cpp @@ -96,7 +96,6 @@ void CAPIAggregateUpdate(Vector inputs[], AggregateInputData &aggr_input_data, i inputs[c].Flatten(); 
chunk.data.emplace_back(Vector::Ref(inputs[c])); } - chunk.SetCardinality(count); auto &bind_data = aggr_input_data.bind_data->Cast(); auto state_data = FlatVector::GetDataMutableUnsafe(state); @@ -149,10 +148,10 @@ void CAPIAggregateDestructor(Vector &state, AggregateInputData &aggr_input_data, using duckdb::GetCAggregateFunction; duckdb_aggregate_function duckdb_create_aggregate_function() { - auto function = new duckdb::AggregateFunction("", {}, duckdb::LogicalType::INVALID, duckdb::CAPIAggregateStateSize, - duckdb::CAPIAggregateStateInit, duckdb::CAPIAggregateUpdate, - duckdb::CAPIAggregateCombine, duckdb::CAPIAggregateFinalize, nullptr, - duckdb::CAPIAggregateBind); + auto function = new duckdb::AggregateFunction( + "", {}, duckdb::LogicalType::INVALID, duckdb::CAPIAggregateStateSize, duckdb::CAPIAggregateStateInit, + duckdb::CAPIAggregateUpdate, duckdb::CAPIAggregateCombine, duckdb::CAPIAggregateFinalize, + duckdb::FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, duckdb::CAPIAggregateBind); try { function->SetExtraFunctionInfo(); return reinterpret_cast(function); @@ -175,7 +174,7 @@ void duckdb_aggregate_function_set_name(duckdb_aggregate_function function, cons return; } auto &aggregate_function = GetCAggregateFunction(function); - aggregate_function.name = name; + aggregate_function.name = duckdb::Identifier(name); } void duckdb_aggregate_function_add_parameter(duckdb_aggregate_function function, duckdb_logical_type type) { @@ -227,7 +226,7 @@ duckdb_state duckdb_register_aggregate_function(duckdb_connection connection, du } auto &aggregate_function = GetCAggregateFunction(function); - duckdb::AggregateFunctionSet set(aggregate_function.name); + duckdb::AggregateFunctionSet set {aggregate_function.name}; set.AddFunction(aggregate_function); return duckdb_register_aggregate_function_set(connection, reinterpret_cast(&set)); } diff --git a/src/duckdb/src/main/capi/appender-c.cpp b/src/duckdb/src/main/capi/appender-c.cpp index 959db5098..1068fa124 100644 
--- a/src/duckdb/src/main/capi/appender-c.cpp +++ b/src/duckdb/src/main/capi/appender-c.cpp @@ -57,8 +57,8 @@ duckdb_state duckdb_appender_create_query(duckdb_connection connection, const ch return DuckDBError; } duckdb::vector types; - duckdb::vector column_names; - duckdb::string table_name; + duckdb::vector column_names; + duckdb::Identifier table_name; for (idx_t c = 0; c < column_count; ++c) { if (!types_p[c]) { return DuckDBError; @@ -66,7 +66,7 @@ duckdb_state duckdb_appender_create_query(duckdb_connection connection, const ch types.push_back(*reinterpret_cast(types_p[c])); } if (table_name_p) { - table_name = table_name_p; + table_name = duckdb::Identifier(table_name_p); } if (column_names_p) { for (idx_t c = 0; c < column_count; ++c) { @@ -99,6 +99,11 @@ duckdb_state duckdb_appender_destroy(duckdb_appender *appender) { auto state = duckdb_appender_close(*appender); auto wrapper = reinterpret_cast(*appender); if (wrapper) { + // If flush/close previously failed, preserve that error even if close just succeeded with empty collection. + // (Non-fatal column-level errors do not set flush_failed, so those are not propagated here.) 
+ if (wrapper->flush_failed) { + state = DuckDBError; + } delete wrapper; } *appender = nullptr; @@ -315,7 +320,14 @@ duckdb_state duckdb_append_blob(duckdb_appender appender, const void *data, idx_ } duckdb_state duckdb_appender_flush(duckdb_appender appender_p) { - return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Flush(); }); + auto state = duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Flush(); }); + if (state == DuckDBError) { + auto wrapper = reinterpret_cast(appender_p); + if (wrapper) { + wrapper->flush_failed = true; + } + } + return state; } duckdb_state duckdb_appender_clear(duckdb_appender appender_p) { @@ -323,7 +335,14 @@ duckdb_state duckdb_appender_clear(duckdb_appender appender_p) { } duckdb_state duckdb_appender_close(duckdb_appender appender_p) { - return duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Close(); }); + auto state = duckdb_appender_run_function(appender_p, [&](BaseAppender &appender) { appender.Close(); }); + if (state == DuckDBError) { + auto wrapper = reinterpret_cast(appender_p); + if (wrapper) { + wrapper->flush_failed = true; + } + } + return state; } idx_t duckdb_appender_column_count(duckdb_appender appender) { diff --git a/src/duckdb/src/main/capi/arrow-c.cpp b/src/duckdb/src/main/capi/arrow-c.cpp index cf8037e5c..93565561d 100644 --- a/src/duckdb/src/main/capi/arrow-c.cpp +++ b/src/duckdb/src/main/capi/arrow-c.cpp @@ -106,7 +106,7 @@ duckdb_error_data duckdb_data_chunk_from_arrow(duckdb_connection connection, str dchunk->Initialize(duckdb::Allocator::DefaultAllocator(), types, duckdb::NumericCast(arrow_array->length)); auto &arrow_types = arrow_table->GetColumns(); - dchunk->SetCardinality(duckdb::NumericCast(arrow_array->length)); + dchunk->SetChildCardinality(duckdb::NumericCast(arrow_array->length)); for (idx_t i = 0; i < dchunk->ColumnCount(); i++) { auto &parent_array = *arrow_array; auto &array = 
parent_array.children[i]; @@ -422,7 +422,7 @@ duckdb_state Ingest(duckdb_connection connection, const char *table_name, struct ->TableFunction("arrow_scan", {duckdb::Value::POINTER((uintptr_t)input), duckdb::Value::POINTER((uintptr_t)FactoryGetNext), duckdb::Value::POINTER((uintptr_t)FactoryGetSchema)}) - ->CreateView(table_name, true, false); + ->CreateView(duckdb::Identifier(table_name), true, false); } catch (...) { // LCOV_EXCL_START // Tried covering this in tests, but it proved harder than expected. At the time of writing: // - Passing any name to `CreateView` worked without throwing an exception diff --git a/src/duckdb/src/main/capi/catalog-c.cpp b/src/duckdb/src/main/capi/catalog-c.cpp index 59421e909..646ff6764 100644 --- a/src/duckdb/src/main/capi/catalog-c.cpp +++ b/src/duckdb/src/main/capi/catalog-c.cpp @@ -87,7 +87,7 @@ duckdb_catalog duckdb_client_context_get_catalog(duckdb_client_context context, return nullptr; } - auto catalog_ptr = duckdb::Catalog::GetCatalogEntry(context_ref.context, name); + auto catalog_ptr = duckdb::Catalog::GetCatalogEntry(context_ref.context, duckdb::Identifier(name)); if (!catalog_ptr) { return nullptr; @@ -125,8 +125,9 @@ duckdb_catalog_entry duckdb_catalog_get_entry(duckdb_catalog catalog, duckdb_cli auto &catalog_ref = *reinterpret_cast(catalog); auto &context_ref = *reinterpret_cast(context); - auto entry = catalog_ref.catalog.GetEntry(context_ref.context, duckdb::CatalogTypeFromC(entry_type), schema_name, - entry_name, duckdb::OnEntryNotFound::RETURN_NULL); + auto entry = catalog_ref.catalog.GetEntry(context_ref.context, duckdb::CatalogTypeFromC(entry_type), + duckdb::Identifier(schema_name), duckdb::Identifier(entry_name), + duckdb::OnEntryNotFound::RETURN_NULL); if (!entry) { return nullptr; diff --git a/src/duckdb/src/main/capi/copy_function-c.cpp b/src/duckdb/src/main/capi/copy_function-c.cpp index c4bcc7b07..4be983952 100644 --- a/src/duckdb/src/main/capi/copy_function-c.cpp +++ 
b/src/duckdb/src/main/capi/copy_function-c.cpp @@ -99,7 +99,7 @@ void duckdb_copy_function_set_name(duckdb_copy_function copy_function, const cha return; } auto ©_function_ref = *reinterpret_cast(copy_function); - copy_function_ref.name = name; + copy_function_ref.name = duckdb::Identifier(name); } void duckdb_destroy_copy_function(duckdb_copy_function *copy_function) { @@ -151,7 +151,7 @@ struct CCopyToBindInfo : FunctionData { struct CCopyFunctionToInternalBindInfo { CCopyFunctionToInternalBindInfo(ClientContext &context, CopyFunctionBindInput &input, - const vector &sql_types, const vector &names, + const vector &sql_types, const vector &names, const CCopyFunctionInfo &function_info) : context(context), input(input), sql_types(sql_types), names(names), function_info(function_info), success(true) { @@ -160,7 +160,7 @@ struct CCopyFunctionToInternalBindInfo { ClientContext &context; CopyFunctionBindInput &input; const vector &sql_types; - const vector &names; + const vector &names; const CCopyFunctionInfo &function_info; bool success; string error; @@ -170,8 +170,8 @@ struct CCopyFunctionToInternalBindInfo { duckdb_delete_callback_t delete_callback = nullptr; }; -unique_ptr CCopyToBind(ClientContext &context, CopyFunctionBindInput &input, const vector &names, - const vector &sql_types) { +unique_ptr CCopyToBind(ClientContext &context, CopyFunctionBindInput &input, + const vector &names, const vector &sql_types) { auto &info = input.function_info->Cast(); auto result = make_uniq(); @@ -654,7 +654,7 @@ unique_ptr CCopyFromBind(ClientContext &context, CopyFromFunctionB // Turn all options into named parameters for (auto opt : info.info.options) { - auto param_it = info.tf.named_parameters.find(opt.first); + auto param_it = info.tf.named_parameters.find(Identifier(opt.first)); if (param_it == info.tf.named_parameters.end()) { // Option not found in the table function's named parameters throw BinderException("'%s' is not a supported option for copy function '%s'", 
opt.first.c_str(), @@ -686,7 +686,7 @@ unique_ptr CCopyFromBind(ClientContext &context, CopyFromFunctionB } // Assign the option as a named parameter - named_parameters[opt.first] = param_value; + named_parameters[Identifier(opt.first)] = param_value; } // Also pass file path as a regular parameter diff --git a/src/duckdb/src/main/capi/duckdb_value-c.cpp b/src/duckdb/src/main/capi/duckdb_value-c.cpp index f128c529a..b14eab4d2 100644 --- a/src/duckdb/src/main/capi/duckdb_value-c.cpp +++ b/src/duckdb/src/main/capi/duckdb_value-c.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/types/uuid.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/common/types/bignum.hpp" +#include "duckdb/common/types/decimal.hpp" #include "duckdb/main/capi/capi_internal.hpp" using duckdb::LogicalTypeId; @@ -139,6 +140,9 @@ duckdb_bignum duckdb_get_bignum(duckdb_value val) { return {data, size, is_negative}; } duckdb_value duckdb_create_decimal(duckdb_decimal input) { + if (!duckdb::Decimal::IsValidWidthScale(input.width, input.scale)) { + return nullptr; + } duckdb::hugeint_t hugeint(input.value.upper, input.value.lower); int64_t int64; if (duckdb::Hugeint::TryCast(hugeint, int64)) { diff --git a/src/duckdb/src/main/capi/helper-c.cpp b/src/duckdb/src/main/capi/helper-c.cpp index b74e84d86..e9a8b4f93 100644 --- a/src/duckdb/src/main/capi/helper-c.cpp +++ b/src/duckdb/src/main/capi/helper-c.cpp @@ -202,6 +202,7 @@ idx_t GetCTypeSize(const duckdb_type type) { case DUCKDB_TYPE_SMALLINT: return sizeof(int16_t); case DUCKDB_TYPE_INTEGER: + case DUCKDB_TYPE_SQLNULL: return sizeof(int32_t); case DUCKDB_TYPE_BIGINT: return sizeof(int64_t); diff --git a/src/duckdb/src/main/capi/logical_types-c.cpp b/src/duckdb/src/main/capi/logical_types-c.cpp index f8419b8fc..d2d08ab5c 100644 --- a/src/duckdb/src/main/capi/logical_types-c.cpp +++ b/src/duckdb/src/main/capi/logical_types-c.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/type_visitor.hpp" #include "duckdb/common/helper.hpp" #include 
"duckdb/common/types/geometry_crs.hpp" +#include "duckdb/common/types/decimal.hpp" namespace duckdb { @@ -85,7 +86,7 @@ duckdb_logical_type duckdb_create_union_type(duckdb_logical_type *member_types_p duckdb::child_list_t members; for (idx_t i = 0; i < member_count; i++) { - members.push_back(make_pair(member_names[i], *member_types[i])); + members.emplace_back(make_pair(member_names[i], *member_types[i])); } duckdb::LogicalType *mtype = new duckdb::LogicalType(duckdb::LogicalType::UNION(members)); return reinterpret_cast(mtype); @@ -110,7 +111,7 @@ duckdb_logical_type duckdb_create_struct_type(duckdb_logical_type *member_types_ duckdb::child_list_t members; for (idx_t i = 0; i < member_count; i++) { - members.push_back(make_pair(member_names[i], *member_types[i])); + members.emplace_back(make_pair(member_names[i], *member_types[i])); } duckdb::LogicalType *mtype = new duckdb::LogicalType(duckdb::LogicalType::STRUCT(members)); return reinterpret_cast(mtype); @@ -155,7 +156,15 @@ duckdb_logical_type duckdb_create_map_type(duckdb_logical_type key_type, duckdb_ } duckdb_logical_type duckdb_create_decimal_type(uint8_t width, uint8_t scale) { - return reinterpret_cast(new duckdb::LogicalType(duckdb::LogicalType::DECIMAL(width, scale))); + if (!duckdb::Decimal::IsValidWidthScale(width, scale)) { + return nullptr; + } + try { + return reinterpret_cast( + new duckdb::LogicalType(duckdb::LogicalType::DECIMAL(width, scale))); + } catch (...) 
{ + return nullptr; + } } duckdb_type duckdb_get_type_id(duckdb_logical_type type) { diff --git a/src/duckdb/src/main/capi/prepared-c.cpp b/src/duckdb/src/main/capi/prepared-c.cpp index 2fff168a0..2f194401e 100644 --- a/src/duckdb/src/main/capi/prepared-c.cpp +++ b/src/duckdb/src/main/capi/prepared-c.cpp @@ -45,7 +45,7 @@ static void duckdb_prepare_param_index_to_name_map_internal(PreparedStatementWra auto &named_param_map = wrapper->statement->named_param_map; auto &cache = wrapper->param_index_to_name; for (auto &kv : named_param_map) { - cache[kv.second] = kv.first; + cache[kv.second] = kv.first.GetIdentifierName(); } } @@ -170,12 +170,12 @@ duckdb_logical_type duckdb_param_logical_type(duckdb_prepared_statement prepared LogicalType param_type; - if (wrapper->statement->data->TryGetType(identifier, param_type)) { + if (wrapper->statement->data->TryGetType(duckdb::Identifier(identifier), param_type)) { return reinterpret_cast(new LogicalType(param_type)); } // The value_map is gone after executing the prepared statement // See if this is the case and we still have a value registered for it - auto it = wrapper->values.find(identifier); + auto it = wrapper->values.find(duckdb::Identifier(identifier)); if (it != wrapper->values.end()) { return reinterpret_cast(new LogicalType(it->second.return_type)); } @@ -251,7 +251,7 @@ duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx return DuckDBError; } auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx); - wrapper->values[identifier] = duckdb::BoundParameterData(*value); + wrapper->values[duckdb::Identifier(identifier)] = duckdb::BoundParameterData(*value); return DuckDBSuccess; } @@ -266,7 +266,7 @@ duckdb_state duckdb_bind_parameter_index(duckdb_prepared_statement prepared_stat } auto name = std::string(name_p); for (auto &pair : wrapper->statement->named_param_map) { - if (duckdb::StringUtil::CIEquals(pair.first, name)) { + if (pair.first == name) { *param_idx_out 
= pair.second; return DuckDBSuccess; } diff --git a/src/duckdb/src/main/capi/profiling_info-c.cpp b/src/duckdb/src/main/capi/profiling_info-c.cpp index c8aa47f2f..1eafc51ed 100644 --- a/src/duckdb/src/main/capi/profiling_info-c.cpp +++ b/src/duckdb/src/main/capi/profiling_info-c.cpp @@ -1,68 +1,96 @@ #include "duckdb/main/capi/capi_internal.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/query_profiler.hpp" using duckdb::Connection; -using duckdb::DuckDB; -using duckdb::EnumUtil; -using duckdb::MetricType; -using duckdb::optional_ptr; -using duckdb::ProfilingNode; +using duckdb::QueryProfileResult; +using duckdb::QueryProfileResultKind; + +// Returns the total number of navigable children of a node: +// items across all LIST children (flattened) followed by direct OBJECT children. +static idx_t CountChildren(const QueryProfileResult &node) { + idx_t count = 0; + for (auto &child : node.children) { + if (child->kind == QueryProfileResultKind::LIST) { + count += child->children.size(); + } else if (child->kind == QueryProfileResultKind::OBJECT) { + count++; + } + } + return count; +} + +// Returns the i-th navigable child of a node. +// LIST items come first (flattened across all LIST children), then direct OBJECT children. +static QueryProfileResult *GetChild(const QueryProfileResult &node, idx_t index) { + // LIST items first (preserves backward-compatible indices for operator nodes) + for (auto &child : node.children) { + if (child->kind != QueryProfileResultKind::LIST) { + continue; + } + if (index < child->children.size()) { + return child->children[index].get(); + } + index -= child->children.size(); + } + // Then direct OBJECT children (e.g. 
metric-group sub-objects like "query", "system") + for (auto &child : node.children) { + if (child->kind != QueryProfileResultKind::OBJECT) { + continue; + } + if (index == 0) { + return child.get(); + } + index--; + } + return nullptr; +} duckdb_profiling_info duckdb_get_profiling_info(duckdb_connection connection) { if (!connection) { return nullptr; } - Connection *conn = reinterpret_cast(connection); - optional_ptr profiling_node; + auto *conn = reinterpret_cast(connection); try { - profiling_node = conn->GetProfilingTree(); - } catch (std::exception &ex) { - return nullptr; - } - - if (!profiling_node) { + auto &profiler = duckdb::QueryProfiler::Get(*conn->context); + if (!profiler.IsEnabled() || !profiler.HasRoot()) { + return nullptr; + } + return reinterpret_cast(&profiler.GetResult()); + } catch (...) { return nullptr; } - return reinterpret_cast(profiling_node.get()); } duckdb_value duckdb_profiling_info_get_value(duckdb_profiling_info info, const char *key) { - if (!info) { + if (!info || !key) { return nullptr; } - auto &node = *reinterpret_cast(info); - auto &profiling_info = node.GetProfilingInfo(); - auto key_enum = EnumUtil::FromString(duckdb::StringUtil::Upper(key)); - if (!profiling_info.Enabled(profiling_info.settings, key_enum)) { - return nullptr; + auto &node = *reinterpret_cast(info); + duckdb::string key_str(key); + for (auto &child : node.children) { + if (child->kind != QueryProfileResultKind::VALUE) { + continue; + } + if (duckdb::StringUtil::CIEquals(child->key, key_str)) { + return reinterpret_cast(new duckdb::Value(child->value)); + } } - - auto str = profiling_info.GetMetricAsString(key_enum); - return duckdb_create_varchar_length(str.c_str(), strlen(str.c_str())); + return nullptr; } duckdb_value duckdb_profiling_info_get_metrics(duckdb_profiling_info info) { if (!info) { return nullptr; } - - auto &node = *reinterpret_cast(info); - auto &profiling_info = node.GetProfilingInfo(); - + auto &node = *reinterpret_cast(info); 
duckdb::InsertionOrderPreservingMap metrics_map; - for (const auto &metric : profiling_info.metrics) { - auto key = EnumUtil::ToString(metric.first); - if (!profiling_info.Enabled(profiling_info.settings, metric.first)) { - continue; - } - - if (key == EnumUtil::ToString(MetricType::OPERATOR_TYPE)) { - auto type = duckdb::PhysicalOperatorType(metric.second.GetValue()); - metrics_map[key] = EnumUtil::ToString(type); - } else { - metrics_map[key] = metric.second.ToString(); + for (auto &child : node.children) { + if (child->kind == QueryProfileResultKind::VALUE) { + metrics_map.insert(child->key, child->value.ToString()); } } - auto map = duckdb::Value::MAP(metrics_map); return reinterpret_cast(new duckdb::Value(map)); } @@ -71,19 +99,14 @@ idx_t duckdb_profiling_info_get_child_count(duckdb_profiling_info info) { if (!info) { return 0; } - auto &node = *reinterpret_cast(info); - return node.GetChildCount(); + auto &node = *reinterpret_cast(info); + return CountChildren(node); } duckdb_profiling_info duckdb_profiling_info_get_child(duckdb_profiling_info info, idx_t index) { if (!info) { return nullptr; } - auto &node = *reinterpret_cast(info); - if (index >= node.GetChildCount()) { - return nullptr; - } - - ProfilingNode *profiling_info_ptr = node.GetChild(index).get(); - return reinterpret_cast(profiling_info_ptr); + auto &node = *reinterpret_cast(info); + return reinterpret_cast(GetChild(node, index)); } diff --git a/src/duckdb/src/main/capi/replacement_scan-c.cpp b/src/duckdb/src/main/capi/replacement_scan-c.cpp index 5f23a3169..ae529862a 100644 --- a/src/duckdb/src/main/capi/replacement_scan-c.cpp +++ b/src/duckdb/src/main/capi/replacement_scan-c.cpp @@ -47,7 +47,8 @@ unique_ptr duckdb_capi_replacement_callback(ClientContext &context, Re for (auto ¶m : info.parameters) { children.push_back(make_uniq(std::move(param))); } - table_function->function = make_uniq(info.function_name, std::move(children)); + table_function->function = + 
make_uniq(duckdb::Identifier(info.function_name), std::move(children)); return std::move(table_function); } diff --git a/src/duckdb/src/main/capi/scalar_function-c.cpp b/src/duckdb/src/main/capi/scalar_function-c.cpp index 064a5db61..07011aaee 100644 --- a/src/duckdb/src/main/capi/scalar_function-c.cpp +++ b/src/duckdb/src/main/capi/scalar_function-c.cpp @@ -177,7 +177,7 @@ unique_ptr CScalarFunctionBind(BindScalarFunctionInput &input) { unique_ptr CScalarFunctionInit(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { - auto &function = expr.function; + auto &function = expr.Function(); auto &info = function.GetExtraFunctionInfo().Cast(); D_ASSERT(info.function); @@ -196,7 +196,7 @@ unique_ptr CScalarFunctionInit(ExpressionState &state, const void CAPIScalarFunction(DataChunk &input, ExpressionState &state, Vector &result) { auto &function = state.expr.Cast(); - auto &bind_info = function.bind_info; + auto &bind_info = function.BindInfo(); auto &c_bind_info = bind_info->Cast(); auto &c_local_state = ExecuteFunctionState::GetFunctionState(state)->Cast(); @@ -242,7 +242,7 @@ void duckdb_scalar_function_set_name(duckdb_scalar_function function, const char return; } auto &scalar_function = GetCScalarFunction(function); - scalar_function.name = name; + scalar_function.name = duckdb::Identifier(name); } void duckdb_scalar_function_set_varargs(duckdb_scalar_function function, duckdb_logical_type type) { @@ -469,7 +469,7 @@ duckdb_state duckdb_register_scalar_function(duckdb_connection connection, duckd return DuckDBError; } auto &scalar_function = GetCScalarFunction(function); - duckdb::ScalarFunctionSet set(scalar_function.name); + duckdb::ScalarFunctionSet set {scalar_function.name}; set.AddFunction(scalar_function); return duckdb_register_scalar_function_set(connection, reinterpret_cast(&set)); } diff --git a/src/duckdb/src/main/capi/table_function-c.cpp b/src/duckdb/src/main/capi/table_function-c.cpp index a878d99a6..06ff0970b 
100644 --- a/src/duckdb/src/main/capi/table_function-c.cpp +++ b/src/duckdb/src/main/capi/table_function-c.cpp @@ -196,7 +196,7 @@ void duckdb_table_function_set_name(duckdb_table_function function, const char * return; } auto &tf = GetCTableFunction(function); - tf.name = name; + tf.name = duckdb::Identifier(name); } void duckdb_table_function_add_parameter(duckdb_table_function function, duckdb_logical_type type) { diff --git a/src/duckdb/src/main/chunk_scan_state/batched_data_collection.cpp b/src/duckdb/src/main/chunk_scan_state/batched_data_collection.cpp index 7d2bc36cf..d6aad8c9a 100644 --- a/src/duckdb/src/main/chunk_scan_state/batched_data_collection.cpp +++ b/src/duckdb/src/main/chunk_scan_state/batched_data_collection.cpp @@ -18,7 +18,7 @@ BatchCollectionChunkScanState::~BatchCollectionChunkScanState() { void BatchCollectionChunkScanState::InternalLoad(ErrorData &error) { if (state.range.begin == state.range.end) { // Signal empty chunk to break out of the loop - current_chunk->SetCardinality(0); + current_chunk->SetChildCardinality(0); return; } offset = 0; diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index c872139be..5029ef641 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -27,6 +27,9 @@ #include "duckdb/main/stream_query_result.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/parser/expression/parameter_expression.hpp" #include "duckdb/parser/parsed_data/create_function_info.hpp" #include "duckdb/parser/parser.hpp" @@ -218,6 +221,47 @@ unique_ptr ClientContext::LockContext() { return make_uniq(context_lock); } +void ClientContext::ConnectToCatalog(const shared_ptr &target) { + D_ASSERT(target); + // 
Pre-flight: Supports(RemoteCapability::CONNECT) is the capability declaration; catalogs that + // return true MUST implement RemoteExecute(string). Validation runs before mutation so a throw + // leaves the client unbound. + if (!target->GetCatalog().Supports(RemoteCapability::CONNECT)) { + throw InvalidInputException("Database \"%s\" does not support CONNECT", target->GetName()); + } + connected_to_database = target; + is_connected = true; +} + +void ClientContext::DisconnectFromCatalog() { + connected_to_database.reset(); + is_connected = false; +} + +shared_ptr ClientContext::TryGetConnectedCatalog() const { + if (!is_connected) { + return nullptr; + } + return connected_to_database.lock(); +} + +//! True if `type` is a CONNECT control statement that must execute against LOCAL even while a +//! CONNECT binding is active (i.e. the chokepoint must let it fall through, not rewrite it). +static bool IsConnectControlStatement(StatementType type) { + return type == StatementType::CONNECT_STATEMENT || type == StatementType::DISCONNECT_STATEMENT; +} + +//! Wrap a TableRef returned from Catalog::RemoteExecute into `SELECT * FROM ` for the chokepoint. 
+static unique_ptr WrapAsSelect(unique_ptr from_ref) { + auto select_node = make_uniq(); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = std::move(from_ref); + + auto select_stmt = make_uniq(); + select_stmt->node = std::move(select_node); + return std::move(select_stmt); +} + void ClientContext::Destroy() { auto lock = LockContext(); if (transaction.HasActiveTransaction()) { @@ -426,7 +470,6 @@ shared_ptr ClientContext::CreatePreparedStatementInternal auto &profiler = QueryProfiler::Get(*this); profiler.StartQuery(query, IsExplainAnalyze(statement.get())); - profiler.StartPhase(MetricType::PLANNER); Planner logical_planner(*this); if (parameters.parameters) { auto ¶meter_values = *parameters.parameters; @@ -435,9 +478,11 @@ shared_ptr ClientContext::CreatePreparedStatementInternal } } - logical_planner.CreatePlan(std::move(statement)); - D_ASSERT(logical_planner.plan || !logical_planner.properties.bound_all_parameters); - profiler.EndPhase(); + { + auto planner_timer = profiler.StartTimer(); + logical_planner.CreatePlan(std::move(statement)); + D_ASSERT(logical_planner.plan || !logical_planner.properties.bound_all_parameters); + } auto logical_plan = std::move(logical_planner.plan); // extract the result column names from the plan @@ -463,7 +508,7 @@ shared_ptr ClientContext::CreatePreparedStatementInternal } } - bool optimize = config.enable_optimizer; + bool optimize = Settings::Get(*this); if (Settings::Get(*this)) { // verify disable optimizer - disable EXCEPT for explain, otherwise every single EXPLAIN query breaks if (logical_plan->type != LogicalOperatorType::LOGICAL_EXPLAIN) { @@ -471,22 +516,23 @@ shared_ptr ClientContext::CreatePreparedStatementInternal } } if (optimize && logical_plan->RequireOptimizer()) { - profiler.StartPhase(MetricType::ALL_OPTIMIZERS); - Optimizer optimizer(*logical_planner.binder, *this); - logical_plan = optimizer.Optimize(std::move(logical_plan)); - D_ASSERT(logical_plan); - profiler.EndPhase(); - 
+ { + auto optimizer_timer = profiler.StartTimer(); + Optimizer optimizer(*logical_planner.binder, *this); + logical_plan = optimizer.Optimize(std::move(logical_plan)); + D_ASSERT(logical_plan); + } #ifdef DEBUG logical_plan->Verify(*this); #endif } // Convert the logical query plan into a physical query plan. - profiler.StartPhase(MetricType::PHYSICAL_PLANNER); - PhysicalPlanGenerator physical_planner(*this); - result->physical_plan = physical_planner.Plan(std::move(logical_plan)); - profiler.EndPhase(); + { + auto physical_timer = profiler.StartTimer(); + PhysicalPlanGenerator physical_planner(*this); + result->physical_plan = physical_planner.Plan(std::move(logical_plan)); + } D_ASSERT(result->physical_plan); return result; } @@ -544,7 +590,7 @@ QueryProgress ClientContext::GetQueryProgress() { } void BindPreparedStatementParameters(PreparedStatementData &statement, const PendingQueryParameters ¶meters) { - case_insensitive_map_t owned_values; + identifier_map_t owned_values; if (parameters.parameters) { auto ¶ms = *parameters.parameters; for (auto &val : params) { @@ -676,6 +722,7 @@ void ClientContext::WaitForTask(ClientContextLock &lock, BaseQueryResult &result bool ClientContext::ErrorInvalidatesTransaction(ExceptionType type) { switch (transaction.GetInvalidationPolicy()) { + case TransactionInvalidationPolicy::STANDARD_POLICY: case TransactionInvalidationPolicy::ALL_ERRORS_INVALIDATE_TRANSACTION: return true; default: @@ -741,7 +788,7 @@ vector> ClientContext::ParseStatementsInternal(ClientCo Parser parser(GetParserOptions()); auto &profiler = QueryProfiler::Get(*this); profiler.StartQuery(query); - profiler.StartPhase(MetricType::PARSER); + auto parser_timer = profiler.StartTimer(); parser.ParseQuery(query); StatementPreprocessor preprocessor(*this); @@ -783,10 +830,8 @@ unique_ptr ClientContext::ExtractPlan(const string &query) { plan = std::move(planner.plan); - if (config.enable_optimizer) { - Optimizer optimizer(*planner.binder, *this); - plan = 
optimizer.Optimize(std::move(plan)); - } + Optimizer optimizer(*planner.binder, *this); + plan = optimizer.Optimize(std::move(plan)); ColumnBindingResolver resolver; resolver.Verify(*this, *plan); @@ -879,7 +924,7 @@ unique_ptr ClientContext::Execute(const string &query, shared_ptr

ClientContext::Execute(const string &query, shared_ptr &prepared, - case_insensitive_map_t &values, + identifier_map_t &values, QueryParameters query_parameters) { PendingQueryParameters parameters; parameters.parameters = &values; @@ -941,6 +986,46 @@ unique_ptr ClientContext::PendingStatementOrPreparedStatemen unique_ptr ClientContext::PendingStatementOrPreparedStatement( ClientContextLock &lock, const string &query, unique_ptr statement, shared_ptr &prepared, const PendingQueryParameters ¶meters) { + // CONNECT chokepoint: when connected, non-control SQL is rewritten in place and falls through to + // the normal pipeline. No recursion — the rewrite goes through PendingStatementInternal, not back here. + if (is_connected) { + bool is_control = false; + if (statement && IsConnectControlStatement(statement->type)) { + is_control = true; + } + if (!is_control) { + if (!statement) { + // Prepared-statement execution path (statement is null, prepared is set). Parameterized + // prepared statements would need parameter substitution we don't do in v0 — reject. + // No-param prepared statements have a fully-resolved `query` already, route them. + if (!prepared || prepared->properties.parameter_count > 0) { + return ErrorResult( + ErrorData(InvalidInputException( + "Parameterized prepared statements cannot be executed while CONNECT-ed; " + "DISCONNECT first, or run the SQL as a fresh statement to route through " + "the CONNECT binding")), + query); + } + // Clear prepared so the rest of the function takes the fresh-statement path on the + // rewritten SELECT below. + prepared.reset(); + } + auto live = connected_to_database.lock(); + if (!live) { + // Target was detached elsewhere; user must explicitly DISCONNECT to clear is_connected. + return ErrorResult( + ErrorData(InvalidInputException( + "The connected database has been detached out from under this connection. 
Issue " + "DISCONNECT to clear the connection before running further SQL.")), + query); + } + // Dispatch via the catalog — Supports(CONNECT) was validated at CONNECT time, so RemoteExecute + // is contracted to be implemented. Wrap the returned TableRef into a SelectStatement. + auto remote_ref = live->GetCatalog().RemoteExecute(*this, query); + statement = WrapAsSelect(std::move(remote_ref)); + // statement is now SELECT * FROM ; fall through. + } + } unique_ptr pending; try { BeginQueryInternal(lock, query); @@ -1087,18 +1172,18 @@ vector> ClientContext::ParseStatements(ClientContextLoc } unique_ptr ClientContext::PendingQuery(const string &query, QueryParameters parameters) { - case_insensitive_map_t empty_param_list; + identifier_map_t empty_param_list; return PendingQuery(query, empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(unique_ptr statement, QueryParameters parameters) { - case_insensitive_map_t empty_param_list; + identifier_map_t empty_param_list; return PendingQuery(std::move(statement), empty_param_list, parameters); } unique_ptr ClientContext::PendingQuery(const string &query, - case_insensitive_map_t &values, + identifier_map_t &values, QueryParameters parameters) { PendingQueryParameters params; params.parameters = values; @@ -1128,7 +1213,7 @@ unique_ptr ClientContext::PendingQuery(const string &query, } unique_ptr ClientContext::PendingQuery(unique_ptr statement, - case_insensitive_map_t &values, + identifier_map_t &values, QueryParameters parameters) { auto lock = LockContext(); auto query = statement->query; @@ -1191,7 +1276,7 @@ void ClientContext::InterruptCheck() const { if (query_deadline.IsValid() && ++timeout_check_counter % TIMEOUT_CHECK_INTERVAL == 0) { auto now = NumericCast(duration_cast(steady_clock::now().time_since_epoch()).count()); if (now >= query_deadline.GetIndex()) { - throw InterruptException(); + throw InterruptException("Query exceeded maximum execution time"); } } } @@ -1216,8 +1301,8 @@ void 
ClientContext::DisableProfiling() { void ClientContext::RegisterFunction(CreateFunctionInfo &info) { RunFunctionInTransaction([&]() { - auto existing_function = Catalog::GetEntry(*this, INVALID_CATALOG, info.schema, - info.name, OnEntryNotFound::RETURN_NULL); + auto existing_function = Catalog::GetEntry( + *this, Identifier::InvalidCatalog(), info.schema, info.name, OnEntryNotFound::RETURN_NULL); if (existing_function) { auto &new_info = info.Cast(); if (new_info.functions.MergeFunctionSet(existing_function->functions)) { @@ -1274,8 +1359,8 @@ void ClientContext::RunFunctionInTransaction(const std::function &fu RunFunctionInTransactionInternal(*lock, fun, requires_valid_transaction); } -unique_ptr ClientContext::TableInfo(const string &database_name, const string &schema_name, - const string &table_name) { +unique_ptr ClientContext::TableInfo(const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name) { unique_ptr result; RunFunctionInTransaction([&]() { // Obtain the table from the catalog. 
@@ -1295,8 +1380,8 @@ unique_ptr ClientContext::TableInfo(const string &database_nam return result; } -unique_ptr ClientContext::TableInfo(const string &schema_name, const string &table_name) { - return TableInfo(INVALID_CATALOG, schema_name, table_name); +unique_ptr ClientContext::TableInfo(const Identifier &schema_name, const Identifier &table_name) { + return TableInfo(Identifier::InvalidCatalog(), schema_name, table_name); } void ClientContext::Append(unique_ptr stmt) { @@ -1307,11 +1392,11 @@ void ClientContext::Append(unique_ptr stmt) { } void ClientContext::Append(TableDescription &description, ColumnDataCollection &collection) { - string table_name = "__duckdb_internal_appended_data"; - vector expected_names; + Identifier table_name("__duckdb_internal_appended_data"); + vector expected_names; auto query = Appender::ConstructQuery(description, table_name, expected_names); auto table_ref = BaseAppender::GetColumnDataTableRef(collection, table_name, expected_names); - auto stmt = BaseAppender::ParseStatement(std::move(table_ref), query, table_name); + auto stmt = BaseAppender::ParseStatement(std::move(table_ref), query, table_name.GetIdentifierName()); Append(std::move(stmt)); } diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp index 2ea4ead18..c79656fd3 100644 --- a/src/duckdb/src/main/client_data.cpp +++ b/src/duckdb/src/main/client_data.cpp @@ -207,10 +207,7 @@ class ClientBufferManager : public BufferManager { private: void TrackMemoryAllocation(idx_t size) const { if (size > 0) { - auto &profiler = QueryProfiler::Get(context); - // Track allocations even if profiler isn't running yet - they'll be included when the query starts - // AddToCounter already checks IsEnabled(), so we don't need to check here - profiler.AddToCounter(MetricType::TOTAL_MEMORY_ALLOCATED, size); + QueryProfiler::Get(context).TrackTotalMemoryAllocated(size); } } diff --git a/src/duckdb/src/main/client_verify.cpp 
b/src/duckdb/src/main/client_verify.cpp index 2022cf335..5944c5ab2 100644 --- a/src/duckdb/src/main/client_verify.cpp +++ b/src/duckdb/src/main/client_verify.cpp @@ -28,7 +28,7 @@ struct PreparedStatementVerification { void ConvertConstants(QueryNode &node); void ConvertConstants(unique_ptr &child); - case_insensitive_map_t> values; + identifier_map_t> values; }; void PreparedStatementVerification::ConvertConstants(QueryNode &node) { @@ -43,8 +43,8 @@ void PreparedStatementVerification::ConvertConstants(unique_ptrClearAlias(); // check if the value already exists idx_t index = values.size(); - auto identifier = std::to_string(index + 1); - const auto predicate = [&](const std::pair> &pair) { + auto identifier = Identifier(std::to_string(index + 1)); + const auto predicate = [&](const std::pair> &pair) { return pair.second->Equals(*expr.get()); }; auto result = std::find_if(values.begin(), values.end(), predicate); @@ -57,7 +57,7 @@ void PreparedStatementVerification::ConvertConstants(unique_ptr(); - parameter->identifier = identifier; + parameter->IdentifierMutable() = identifier; parameter->SetAlias(alias); expr = std::move(parameter); return; @@ -112,7 +112,7 @@ void ClientContext::StatementVerification(ClientContextLock &lock, const string Allocator allocator; MemoryStream stream(allocator); SerializationOptions options; - options.serialization_compatibility = SerializationCompatibility::FromString("latest"); + options.storage_compatibility = StorageCompatibility::FromString("latest"); optional_ptr to_serialize_stmt; optional_ptr to_serialize_node; switch (statement->type) { @@ -201,7 +201,7 @@ void ClientContext::StatementVerification(ClientContextLock &lock, const string // create the PREPARE and EXECUTE statements string name = "__duckdb_verification_prepared_statement_" + UUID::ToString(UUID::GenerateRandomUUID()); auto prepare = make_uniq(); - prepare->name = name; + prepare->name = Identifier(name); prepare->statement = std::move(prepare_base); // execute 
the PREPARE @@ -225,7 +225,7 @@ void ClientContext::StatementVerification(ClientContextLock &lock, const string // create and return the EXECUTE statement // i.e. EXECUTE p('hello', 42) auto execute = make_uniq(); - execute->name = name; + execute->name = Identifier(name); execute->named_values = std::move(prep_verifier.values); statement = std::move(execute); @@ -244,6 +244,14 @@ void ClientContext::StatementVerification(ClientContextLock &lock, const string } auto explain_q = "EXPLAIN " + query; auto explain_stmt = make_uniq(statement->Copy()); + // Disable the profiler during the verification EXPLAIN to prevent it from consuming the profiler context + // (which would lose parser timing captured before StatementVerification was called) and from + // overwriting the profiling output file with the EXPLAIN's profiling data. + auto &client_config = ClientConfig::GetConfig(*this); + bool saved_profiler = client_config.enable_profiler; + ScopedConfigSetting suppress_profiling( + client_config, [](ClientConfig &config) { config.enable_profiler = false; }, + [saved_profiler](ClientConfig &config) { config.enable_profiler = saved_profiler; }); auto explain_result = RunStatementInternal(lock, explain_q, std::move(explain_stmt), query_parameters); if (explain_result->HasError()) { explain_result->ThrowError(); diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 4c57957a7..e5761b0f8 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -22,7 +22,9 @@ namespace duckdb { +#ifdef DEBUG bool DBConfigOptions::debug_print_bindings = false; +#endif DebugVerificationMode DBConfigOptions::global_verification_mode = DebugVerificationMode::NONE; #define DUCKDB_SETTING(_PARAM) \ @@ -68,7 +70,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(AccessModeSetting), DUCKDB_SETTING_CALLBACK(AllocatorBackgroundThreadsSetting), DUCKDB_GLOBAL(AllocatorBulkDeallocationFlushThresholdSetting), - 
DUCKDB_GLOBAL(AllocatorFlushThresholdSetting), + DUCKDB_SETTING_CALLBACK(AllocatorFlushThresholdSetting), DUCKDB_SETTING_CALLBACK(AllowCommunityExtensionsSetting), DUCKDB_SETTING(AllowExtensionsMetadataMismatchSetting), DUCKDB_SETTING_CALLBACK(AllowParserOverrideExtensionSetting), @@ -83,6 +85,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(ArrowOutputListViewSetting), DUCKDB_SETTING_CALLBACK(ArrowOutputVersionSetting), DUCKDB_SETTING(AsofLoopJoinThresholdSetting), + DUCKDB_GLOBAL(AsyncThreadsSetting), DUCKDB_SETTING(AutoCheckpointSkipWalThresholdSetting), DUCKDB_SETTING(AutoinstallExtensionRepositorySetting), DUCKDB_SETTING(AutoinstallKnownExtensionsSetting), @@ -107,6 +110,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(DebugSkipCheckpointOnCommitSetting), DUCKDB_GLOBAL(DebugVerificationModeSetting), DUCKDB_SETTING(DebugVerificationProjectionSetting), + DUCKDB_SETTING(DebugVerifyAggregateStateExportSetting), DUCKDB_SETTING(DebugVerifyBlocksSetting), DUCKDB_SETTING(DebugVerifyColumnBindingsSetting), DUCKDB_SETTING(DebugVerifySerializerSetting), @@ -116,9 +120,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(DebugWindowModeSetting), DUCKDB_SETTING_CALLBACK(DefaultBlockSizeSetting), DUCKDB_SETTING_CALLBACK(DefaultCollationSetting), + DUCKDB_SETTING_CALLBACK(DefaultIoModeSetting), DUCKDB_SETTING_CALLBACK(DefaultNullOrderSetting), DUCKDB_SETTING_CALLBACK(DefaultOrderSetting), DUCKDB_GLOBAL(DefaultSecretStorageSetting), + DUCKDB_SETTING_CALLBACK(DefaultTransactionInvalidationPolicySetting), + DUCKDB_SETTING(DelimJoinAsCteSetting), DUCKDB_SETTING_CALLBACK(DeprecatedUsingKeySyntaxSetting), DUCKDB_SETTING_CALLBACK(DisableDatabaseInvalidationSetting), DUCKDB_SETTING(DisableTimestamptzCastsSetting), @@ -132,11 +139,11 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(EnableExternalAccessSetting), 
DUCKDB_SETTING_CALLBACK(EnableExternalFileCacheSetting), DUCKDB_SETTING(EnableFSSTVectorsSetting), - DUCKDB_LOCAL(EnableHTTPLoggingSetting), DUCKDB_SETTING(EnableHTTPMetadataCacheSetting), DUCKDB_GLOBAL(EnableLogging), DUCKDB_SETTING(EnableMacroDependenciesSetting), DUCKDB_SETTING(EnableObjectCacheSetting), + DUCKDB_SETTING(EnableOptimizerSetting), DUCKDB_LOCAL(EnableProfilingSetting), DUCKDB_LOCAL(EnableProgressBarSetting), DUCKDB_LOCAL(EnableProgressBarPrintSetting), @@ -152,13 +159,13 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(ExternalThreadsSetting), DUCKDB_SETTING(FileSearchPathSetting), DUCKDB_SETTING_CALLBACK(ForceBitpackingModeSetting), + DUCKDB_SETTING(ForceColumnMetadataReuseSetting), DUCKDB_SETTING_CALLBACK(ForceCompressionSetting), DUCKDB_GLOBAL(ForceMbedtlsUnsafeSetting), DUCKDB_GLOBAL(ForceVariantShredding), DUCKDB_SETTING(GeometryMinimumShreddingSize), DUCKDB_SETTING_CALLBACK(HomeDirectorySetting), - DUCKDB_LOCAL(HTTPLoggingOutputSetting), - DUCKDB_SETTING(HTTPProxySetting), + DUCKDB_GLOBAL(HTTPProxySetting), DUCKDB_SETTING(HTTPProxyPasswordSetting), DUCKDB_SETTING(HTTPProxyUsernameSetting), DUCKDB_SETTING(IeeeFloatingPointOpsSetting), @@ -166,9 +173,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(ImmediateTransactionModeSetting), DUCKDB_SETTING(IndexScanMaxCountSetting), DUCKDB_SETTING_CALLBACK(IndexScanPercentageSetting), + DUCKDB_SETTING_CALLBACK(InitialColumnSegmentSizeSetting), DUCKDB_SETTING(IntegerDivisionSetting), DUCKDB_SETTING_CALLBACK(LambdaSyntaxSetting), DUCKDB_SETTING(LateMaterializationMaxRowsSetting), + DUCKDB_SETTING(LegacyDisableNullTypeSetting), + DUCKDB_SETTING(LegacyMetricsFormatSetting), DUCKDB_SETTING(LockConfigurationSetting), DUCKDB_SETTING_CALLBACK(LogQueryPathSetting), DUCKDB_GLOBAL(LoggingLevel), @@ -206,12 +216,14 @@ static const ConfigurationOption internal_options[] = { DUCKDB_LOCAL(SchemaSetting), DUCKDB_LOCAL(SearchPathSetting), 
DUCKDB_GLOBAL(SecretDirectorySetting), + DUCKDB_GLOBAL(StandardVectorSizeSetting), DUCKDB_SETTING_CALLBACK(StorageBlockPrefetchSetting), DUCKDB_GLOBAL(StorageCompatibilityVersionSetting), DUCKDB_LOCAL(StreamingBufferSizeSetting), DUCKDB_GLOBAL(TempDirectorySetting), DUCKDB_SETTING_CALLBACK(TempFileEncryptionSetting), DUCKDB_GLOBAL(ThreadsSetting), + DUCKDB_LOCAL(TrackedMetricsSetting), DUCKDB_SETTING(UsernameSetting), DUCKDB_SETTING_CALLBACK(VacuumRebuildIndexesSetting), DUCKDB_SETTING_CALLBACK(ValidateExternalFileCacheSetting), @@ -219,17 +231,18 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(WalAutocheckpointEntriesSetting), DUCKDB_SETTING_CALLBACK(WarningsAsErrorsSetting), DUCKDB_SETTING(WriteBufferRowGroupCountSetting), + DUCKDB_GLOBAL(WriteBufferRowGroupMemoryLimitSetting), DUCKDB_SETTING(ZstdMinStringLengthSetting), FINAL_SETTING}; -static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("configure_metrics", 27), - DUCKDB_SETTING_ALIAS("custom_profiling_settings", 27), - DUCKDB_SETTING_ALIAS("memory_limit", 112), - DUCKDB_SETTING_ALIAS("null_order", 52), - DUCKDB_SETTING_ALIAS("profiling_output", 133), - DUCKDB_SETTING_ALIAS("user", 148), - DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 26), - DUCKDB_SETTING_ALIAS("worker_threads", 147), +static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("configure_metrics", 28), + DUCKDB_SETTING_ALIAS("custom_profiling_settings", 28), + DUCKDB_SETTING_ALIAS("memory_limit", 120), + DUCKDB_SETTING_ALIAS("null_order", 55), + DUCKDB_SETTING_ALIAS("profiling_output", 141), + DUCKDB_SETTING_ALIAS("user", 158), + DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 27), + DUCKDB_SETTING_ALIAS("worker_threads", 156), FINAL_ALIAS}; vector DBConfig::GetOptions() { @@ -576,15 +589,15 @@ void DBConfig::CheckLock(const String &name) { // not locked return; } - case_insensitive_set_t allowed_settings {"schema", "search_path"}; - if (allowed_settings.find(name.ToStdString()) != 
allowed_settings.end()) { + identifier_set_t allowed_settings {"schema", "search_path"}; + if (allowed_settings.find(Identifier(name.ToStdString())) != allowed_settings.end()) { // we are always allowed to change these settings return; } if (!options.allowed_configs.empty()) { auto option = GetOptionByName(name); auto canonical_name = option ? string(option->name) : name.ToStdString(); - if (options.allowed_configs.find(canonical_name) != options.allowed_configs.end()) { + if (options.allowed_configs.find(Identifier(canonical_name)) != options.allowed_configs.end()) { // settings that are allowed through allowed_configs return; } @@ -612,6 +625,15 @@ idx_t DBConfig::GetSystemMaxThreads(FileSystem &fs) { #endif } +idx_t DBConfig::GetSystemMaxAsyncThreads(FileSystem &fs) { +#ifdef DUCKDB_NO_THREADS + return 0; +#else + // via benchmark, it seems that ~2x system threads is the I/O-concurrency sweet spot for remote (S3) scans + return 2 * GetSystemMaxThreads(fs); +#endif +} + idx_t DBConfig::GetSystemAvailableMemory(FileSystem &fs) { // System memory detection auto memory = FileSystem::GetAvailableMemory(); @@ -657,7 +679,7 @@ idx_t DBConfig::ParseMemoryLimit(const string &arg) { if (!error.empty()) { if (error == "Memory cannot be negative") { - return NumericLimits::Maximum(); + result = DConstants::INVALID_INDEX; } else { throw ParserException(error); } @@ -825,8 +847,8 @@ void DBConfig::AddAllowedConfig(const string &config_name) { if (config_name.empty()) { throw InvalidInputException("Cannot provide an empty string for allowed_configs"); } - duckdb::case_insensitive_set_t always_disallowed_config {"allowed_configs", "lock_configuration"}; - if (always_disallowed_config.find(config_name) != always_disallowed_config.end()) { + duckdb::identifier_set_t always_disallowed_config {"allowed_configs", "lock_configuration"}; + if (always_disallowed_config.find(Identifier(config_name)) != always_disallowed_config.end()) { throw InvalidInputException("Cannot include '%s' in 
allowed_configs", config_name); } // Validate that the config name refers to a known setting (built-in or extension) @@ -839,14 +861,14 @@ void DBConfig::AddAllowedConfig(const string &config_name) { } ExtensionOption extension_option; if (TryGetExtensionOption(config_name, extension_option)) { - options.allowed_configs.insert(config_name); + options.allowed_configs.insert(Identifier(config_name)); return; } // Check if the setting belongs to a known extension (even if not yet loaded) auto extension_name = ExtensionHelper::FindExtensionInEntries(config_name, EXTENSION_SETTINGS); if (!extension_name.empty()) { // Accept the setting - the extension may be autoloaded later when the setting is used - options.allowed_configs.insert(config_name); + options.allowed_configs.insert(Identifier(config_name)); return; } throw InvalidInputException("Unknown configuration option '%s' in allowed_configs", config_name); @@ -909,7 +931,7 @@ bool DBConfig::CanAccessFile(const string &input_path, FileType type) { } SerializationOptions::SerializationOptions(AttachedDatabase &db) { - serialization_compatibility = SerializationCompatibility::FromDatabase(db); + storage_compatibility = StorageCompatibility::FromDatabase(db); } void DBConfig::SetHTTPUtil(const shared_ptr &new_http_util) { diff --git a/src/duckdb/src/main/connection.cpp b/src/duckdb/src/main/connection.cpp index e2d7675a9..29b856211 100644 --- a/src/duckdb/src/main/connection.cpp +++ b/src/duckdb/src/main/connection.cpp @@ -51,17 +51,6 @@ string Connection::GetProfilingInformation(ProfilerPrintFormat format) { return profiler.ToString(format); } -optional_ptr Connection::GetProfilingTree() { - auto &client_config = ClientConfig::GetConfig(*context); - auto enable_profiler = client_config.enable_profiler; - - if (!enable_profiler) { - throw Exception(ExceptionType::SETTINGS, "Profiling is not enabled for this connection"); - } - auto &profiler = QueryProfiler::Get(*context); - return profiler.GetRoot(); -} - void 
Connection::Interrupt() { context->Interrupt(); } @@ -124,22 +113,22 @@ unique_ptr Connection::PendingQuery(unique_ptr } unique_ptr Connection::PendingQuery(const string &query, - case_insensitive_map_t &named_values, + identifier_map_t &named_values, QueryParameters query_parameters) { return context->PendingQuery(query, named_values, query_parameters); } unique_ptr Connection::PendingQuery(unique_ptr statement, - case_insensitive_map_t &named_values, + identifier_map_t &named_values, QueryParameters query_parameters) { return context->PendingQuery(std::move(statement), named_values, query_parameters); } -static case_insensitive_map_t ConvertParamListToMap(vector ¶m_list) { - case_insensitive_map_t named_values; +static identifier_map_t ConvertParamListToMap(vector ¶m_list) { + identifier_map_t named_values; for (idx_t i = 0; i < param_list.size(); i++) { auto &val = param_list[i]; - named_values[std::to_string(i + 1)] = BoundParameterData(val); + named_values[Identifier(std::to_string(i + 1))] = BoundParameterData(val); } return named_values; } @@ -181,17 +170,17 @@ unique_ptr Connection::QueryParamsRecursive(const string &query, ve return pending->Execute(); } -unique_ptr Connection::TableInfo(const string &database_name, const string &schema_name, - const string &table_name) { +unique_ptr Connection::TableInfo(const Identifier &database_name, const Identifier &schema_name, + const Identifier &table_name) { return context->TableInfo(database_name, schema_name, table_name); } -unique_ptr Connection::TableInfo(const string &schema_name, const string &table_name) { - return TableInfo(INVALID_CATALOG, schema_name, table_name); +unique_ptr Connection::TableInfo(const Identifier &schema_name, const Identifier &table_name) { + return TableInfo(Identifier::InvalidCatalog(), schema_name, table_name); } -unique_ptr Connection::TableInfo(const string &table_name) { - return TableInfo(INVALID_CATALOG, DEFAULT_SCHEMA, table_name); +unique_ptr Connection::TableInfo(const 
Identifier &table_name) { + return TableInfo(Identifier::InvalidCatalog(), Identifier::DefaultSchema(), table_name); } vector> Connection::ExtractStatements(const string &query) { @@ -206,20 +195,21 @@ void Connection::Append(TableDescription &description, ColumnDataCollection &col context->Append(description, collection); } -shared_ptr Connection::Table(const string &table_name) { - return Table(DEFAULT_SCHEMA, table_name); +shared_ptr Connection::Table(const Identifier &table_name) { + return Table(Identifier::DefaultSchema(), table_name); } -shared_ptr Connection::Table(const string &schema_name, const string &table_name) { - auto table_info = TableInfo(INVALID_CATALOG, schema_name, table_name); +shared_ptr Connection::Table(const Identifier &schema_name, const Identifier &table_name) { + auto table_info = TableInfo(Identifier::InvalidCatalog(), schema_name, table_name); if (!table_info) { - throw CatalogException("Table %s does not exist!", ParseInfo::QualifierToString("", schema_name, table_name)); + throw CatalogException("Table %s does not exist!", + ParseInfo::QualifierToString(Identifier(), schema_name, table_name)); } return make_shared_ptr(context, std::move(table_info)); } -shared_ptr Connection::Table(const string &catalog_name, const string &schema_name, - const string &table_name) { +shared_ptr Connection::Table(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name) { unique_ptr table_info; do { table_info = TableInfo(catalog_name, schema_name, table_name); @@ -228,7 +218,7 @@ shared_ptr Connection::Table(const string &catalog_name, const string } if (catalog_name.empty() && !schema_name.empty()) { - table_info = TableInfo(schema_name, DEFAULT_SCHEMA, table_name); + table_info = TableInfo(schema_name, Identifier::DefaultSchema(), table_name); } } while (false); @@ -239,11 +229,11 @@ shared_ptr Connection::Table(const string &catalog_name, const string return make_shared_ptr(context, std::move(table_info)); } 
-shared_ptr Connection::View(const string &tname) { - return View(DEFAULT_SCHEMA, tname); +shared_ptr Connection::View(const Identifier &tname) { + return View(Identifier::DefaultSchema(), tname); } -shared_ptr Connection::View(const string &schema_name, const string &table_name) { +shared_ptr Connection::View(const Identifier &schema_name, const Identifier &table_name) { return make_shared_ptr(context, schema_name, table_name); } @@ -310,7 +300,7 @@ shared_ptr Connection::ReadCSV(const string &csv_file, const vector files {csv_file}; return make_shared_ptr(context, files, std::move(options)); diff --git a/src/duckdb/src/main/connection_manager.cpp b/src/duckdb/src/main/connection_manager.cpp index f97d373aa..4035d018d 100644 --- a/src/duckdb/src/main/connection_manager.cpp +++ b/src/duckdb/src/main/connection_manager.cpp @@ -38,17 +38,16 @@ void ConnectionManager::AssignConnectionId(Connection &connection) { vector> ConnectionManager::GetConnectionList() { lock_guard lock(connections_lock); vector> result; - for (auto &it : connections) { - auto connection = it.second.lock(); + for (auto it = connections.begin(); it != connections.end();) { + auto connection = it->second.lock(); if (!connection) { - connections.erase(it.first); - connection_count = connections.size(); - continue; + it = connections.erase(it); } else { result.push_back(std::move(connection)); + ++it; } } - + connection_count = connections.size(); return result; } diff --git a/src/duckdb/src/main/database.cpp b/src/duckdb/src/main/database.cpp index 9a4d6e609..b5daa611f 100644 --- a/src/duckdb/src/main/database.cpp +++ b/src/duckdb/src/main/database.cpp @@ -1,9 +1,11 @@ #include "duckdb/main/database.hpp" +#include "duckdb/main/metrics_manager.hpp" #include "duckdb/parser/peg/matcher.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/common/http_util.hpp" #include "duckdb/common/virtual_file_system.hpp" +#include "duckdb/common/local_file_system.hpp" #include 
"duckdb/execution/index/index_type_set.hpp" #include "duckdb/execution/operator/helper/physical_set.hpp" #include "duckdb/function/cast/cast_function_set.hpp" @@ -135,6 +137,10 @@ FileSystem &FileSystem::GetFileSystem(DatabaseInstance &db) { return db.GetFileSystem(); } +FileSystem &FileSystem::GetLocal(DatabaseInstance &db) { + return db.GetLocalFileSystem(); +} + FileSystem &FileSystem::Get(AttachedDatabase &db) { return FileSystem::GetFileSystem(db.GetDatabase()); } @@ -202,7 +208,7 @@ shared_ptr DatabaseInstance::CreateAttachedDatabase(ClientCont void DatabaseInstance::CreateMainDatabase() { AttachInfo info; - info.name = AttachedDatabase::ExtractDatabaseName(config.options.database_path, GetFileSystem()); + info.name = Identifier(AttachedDatabase::ExtractDatabaseName(config.options.database_path, GetFileSystem())); info.path = config.options.database_path; Connection con(*this); @@ -285,6 +291,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf create_api_v1 = CreateAPIv1Wrapper; db_file_system = make_uniq(*this); + local_db_file_system = make_uniq(*this); db_manager = make_uniq(*this); if (config.buffer_manager) { buffer_manager = config.buffer_manager; @@ -295,6 +302,8 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf log_manager = make_uniq(*this, LogConfig()); log_manager->Initialize(); + metrics_manager = make_uniq(); + bool enable_external_file_cache = Settings::Get(config); external_file_cache = make_uniq(*this, enable_external_file_cache); result_set_manager = make_uniq(*this); @@ -308,13 +317,13 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf // initialize the secret manager config.secret_manager->Initialize(*this); + // initialize the system catalog + db_manager->InitializeSystemCatalog(); + // resolve the type of the database we are opening auto &fs = FileSystem::GetFileSystem(*this); DBPathAndType::ResolveDatabaseType(fs, 
config.options.database_path, config.options.database_type); - // initialize the system catalog - db_manager->InitializeSystemCatalog(); - if (!config.options.database_type.empty() && !StringUtil::CIEquals(config.options.database_type, "duckdb")) { // if we are opening an extension database - load the extension if (!config.file_system) { @@ -322,7 +331,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf } auto storage_extension = StorageExtension::Find(config, config.options.database_type); if (!storage_extension) { - ExtensionHelper::LoadExternalExtension(*this, *config.file_system, config.options.database_type); + ExtensionHelper::LoadExternalExtension(*this, *config.file_system, {config.options.database_type}); } } @@ -334,6 +343,7 @@ void DatabaseInstance::Initialize(const char *database_path, DBConfig *user_conf // only increase thread count after storage init because we get races on catalog otherwise scheduler->SetThreads(config.options.maximum_threads, Settings::Get(config)); + scheduler->SetAsyncThreads(config.options.async_threads); scheduler->RelaunchThreads(); } @@ -390,6 +400,32 @@ FileSystem &DatabaseInstance::GetFileSystem() { return *db_file_system; } +FileSystem &DatabaseInstance::GetLocalFileSystem() { + return *local_db_file_system; +} + +static FileSystem &ResolveLocalFileSystem(DatabaseInstance &db, unique_ptr &owned) { + auto &vfs = static_cast(*db.config.file_system); + auto &default_fs = vfs.GetDefaultFileSystem(); + if (default_fs.IsLocalFileSystem()) { + return default_fs; + } + owned = make_uniq(); + return *owned; +} + +LocalDatabaseFileSystem::LocalDatabaseFileSystem(DatabaseInstance &db_p) + : db(db_p), local_fs(ResolveLocalFileSystem(db_p, owned_file_system)), database_opener(db_p) { +} + +FileSystem &LocalDatabaseFileSystem::GetFileSystem() const { + auto &vfs = static_cast(*db.config.file_system); + if (vfs.SubSystemIsDisabled(local_fs.GetName())) { + throw PermissionException("File system %s has 
been disabled by configuration", local_fs.GetName()); + } + return local_fs; +} + ExternalFileCache &DatabaseInstance::GetExternalFileCache() { return *external_file_cache; } @@ -440,6 +476,10 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path config.SetDefaultTempDirectory(); } + if (new_config.options.http_proxy.empty()) { + HTTPProxySetting::ResetGlobal(this, config); + } + if (config.options.access_mode == AccessMode::UNDEFINED) { config.options.access_mode = AccessMode::READ_WRITE; } @@ -466,6 +506,9 @@ void DatabaseInstance::Configure(DBConfig &new_config, const char *database_path if (new_config.options.maximum_threads == DConstants::INVALID_INDEX) { config.options.maximum_threads = config.GetSystemMaxThreads(*config.file_system); } + if (new_config.options.async_threads == DConstants::INVALID_INDEX) { + config.options.async_threads = config.GetSystemMaxAsyncThreads(*config.file_system); + } config.allocator = std::move(new_config.allocator); if (!config.allocator) { config.allocator = make_uniq(); @@ -594,6 +637,10 @@ LogManager &DatabaseInstance::GetLogManager() const { return *log_manager; } +MetricsManager &DatabaseInstance::GetMetricsManager() { + return *metrics_manager; +} + ValidChecker &ValidChecker::Get(DatabaseInstance &db) { return db.GetValidChecker(); } diff --git a/src/duckdb/src/main/database_file_path_manager.cpp b/src/duckdb/src/main/database_file_path_manager.cpp index 3381b4b21..ffe667bc6 100644 --- a/src/duckdb/src/main/database_file_path_manager.cpp +++ b/src/duckdb/src/main/database_file_path_manager.cpp @@ -6,8 +6,8 @@ namespace duckdb { -DatabasePathInfo::DatabasePathInfo(DatabaseManager &manager, string name_p, AccessMode access_mode) - : name(std::move(name_p)), access_mode(access_mode) { +DatabasePathInfo::DatabasePathInfo(DatabaseManager &manager, const Identifier &name_p, AccessMode access_mode) + : name(name_p.GetIdentifierName()), access_mode(access_mode) { attached_databases.insert(manager); } 
@@ -17,7 +17,8 @@ idx_t DatabaseFilePathManager::ApproxDatabaseCount() const { } InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(DatabaseManager &manager, const string &path, - const string &name, OnCreateConflict on_conflict, + const Identifier &name, + OnCreateConflict on_conflict, AttachOptions &options) { if (path.empty() || path == IN_MEMORY_PATH) { throw InternalException("DatabaseFilePathManager::InsertDatabasePath - cannot insert in-memory database"); diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp index b3f9dfe96..45dd1dcc2 100644 --- a/src/duckdb/src/main/database_manager.cpp +++ b/src/duckdb/src/main/database_manager.cpp @@ -4,6 +4,7 @@ #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/main/query_profiler.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/database_path_and_type.hpp" #include "duckdb/main/extension_helper.hpp" @@ -18,7 +19,7 @@ namespace duckdb { // Oids are started at 20000 to avoid colliding with Postgres builtin types, which end at 16383: // https://github.com/postgres/postgres/blob/db93988ab0e78396f2ed9e96c826ff988d12b9f2/src/include/access/transam.h#L156-L197 DatabaseManager::DatabaseManager(DatabaseInstance &db) - : db(db), next_oid(20000), current_query_number(1), current_transaction_id(0) { + : db(db), next_oid(20000), current_query_number(1), current_transaction_id(0), remote_catalog_count(0) { system = make_shared_ptr(db); auto &config = DBConfig::GetConfig(db); path_manager = config.path_manager; @@ -47,7 +48,7 @@ void DatabaseManager::FinalizeStartup() { } } -optional_ptr DatabaseManager::GetDatabase(ClientContext &context, const string &name) { +optional_ptr DatabaseManager::GetDatabase(ClientContext &context, const Identifier &name) { auto &meta_transaction = MetaTransaction::Get(context); // first check if we have a local reference to this 
database already auto database = meta_transaction.GetReferencedDatabase(name); @@ -57,7 +58,7 @@ optional_ptr DatabaseManager::GetDatabase(ClientContext &conte } lock_guard guard(databases_lock); shared_ptr db; - if (StringUtil::Lower(name) == TEMP_CATALOG) { + if (name == TEMP_CATALOG) { db = context.client_data->temporary_objects; } else { db = GetDatabaseInternal(guard, name); @@ -68,13 +69,13 @@ optional_ptr DatabaseManager::GetDatabase(ClientContext &conte return meta_transaction.UseDatabase(db); } -shared_ptr DatabaseManager::GetDatabase(const string &name) { +shared_ptr DatabaseManager::GetDatabase(const Identifier &name) { lock_guard guard(databases_lock); return GetDatabaseInternal(guard, name); } -shared_ptr DatabaseManager::GetDatabaseInternal(const lock_guard &, const string &name) { - if (StringUtil::Lower(name) == SYSTEM_CATALOG) { +shared_ptr DatabaseManager::GetDatabaseInternal(const lock_guard &, const Identifier &name) { + if (name == SYSTEM_CATALOG) { return system; } auto entry = databases.find(name); @@ -135,6 +136,16 @@ shared_ptr DatabaseManager::AttachDatabase(ClientContext &cont existing_db->GetCatalog().SetDefaultTable(options.default_table.schema, options.default_table.name); } if (info.on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { + // we require the vacuuming threshold for indexed tables to be the same as the already attached db + if (options.vacuum_rebuild_indexes_threshold.IsValid()) { + auto previous_setting = existing_db->GetVacuumRebuildIndexThreshold(); + auto new_setting = options.vacuum_rebuild_indexes_threshold.GetIndex(); + if (previous_setting != new_setting) { + throw BinderException("Cannot re-attach with a different vacuum_rebuild_indexes setting " + "(previous: %d, new: %d)", + previous_setting, new_setting); + } + } // allow custom catalogs to override this behavior if (!existing_db->GetCatalog().HasConflictingAttachOptions(info.path, options)) { return existing_db; @@ -147,14 +158,13 @@ shared_ptr 
DatabaseManager::AttachDatabase(ClientContext &cont if (requires_tracking_attaches) { // Start timing the ATTACH-delay step. - auto profiler = context.client_data->profiler->StartTimer(MetricType::WAITING_TO_ATTACH_LATENCY); - + auto timer = context.client_data->profiler->StartTimer(); + // Start trying to attach. while (InsertDatabasePath(info, options) == InsertDatabasePathResult::ALREADY_EXISTS) { // database with this name and path already exists // first check if it exists within this transaction auto &meta_transaction = MetaTransaction::Get(context); - auto existing_db = meta_transaction.GetReferencedDatabaseOwning(info.name); - if (existing_db) { + if (auto existing_db = meta_transaction.GetReferencedDatabaseOwning(info.name)) { // it does! return it return existing_db; } @@ -169,6 +179,8 @@ shared_ptr DatabaseManager::AttachDatabase(ClientContext &cont } context.InterruptCheck(); } + // Returning in the loop above will also end the timer, otherwise, do it explicitly here. + timer.EndTimer(); } auto &config = DBConfig::GetConfig(context); GetDatabaseType(context, info, config, options); @@ -178,7 +190,8 @@ shared_ptr DatabaseManager::AttachDatabase(ClientContext &cont options.stored_database_path.reset(); } if (AttachedDatabase::NameIsReserved(info.name)) { - throw BinderException("Attached database name \"%s\" cannot be used because it is a reserved name", info.name); + throw BinderException("Attached database name \"%s\" cannot be used because it is a reserved name", + info.name.GetIdentifierName()); } if (!extension.empty()) { if (!ExtensionHelper::TryAutoLoadExtension(context, extension)) { @@ -231,10 +244,16 @@ optional_ptr DatabaseManager::FinalizeAttach(ClientContext &co } auto &meta_transaction = MetaTransaction::Get(context); if (detached_db) { + if (detached_db->GetCatalog().Supports(RemoteCapability::IS_REMOTE)) { + --remote_catalog_count; + } meta_transaction.DetachDatabase(*detached_db); detached_db->OnDetach(context); detached_db.reset(); } + 
if (attached_db->GetCatalog().Supports(RemoteCapability::IS_REMOTE)) { + ++remote_catalog_count; + } auto &db_ref = meta_transaction.UseDatabase(attached_db); auto &transaction = DuckTransaction::Get(context, *system); auto &transaction_manager = DuckTransactionManager::Get(*system); @@ -242,17 +261,18 @@ optional_ptr DatabaseManager::FinalizeAttach(ClientContext &co return db_ref; } -void DatabaseManager::DetachDatabase(ClientContext &context, const string &name, OnEntryNotFound if_not_found) { - if (StringUtil::CIEquals(GetDefaultDatabase(context), name)) { +void DatabaseManager::DetachDatabase(ClientContext &context, const Identifier &name, OnEntryNotFound if_not_found) { + if (GetDefaultDatabase(context) == name) { throw BinderException("Cannot detach database \"%s\" because it is the default database. Select a different " "database using `USE` to allow detaching this database", - name); + name.GetIdentifierName()); } auto attached_db = DetachInternal(name); if (!attached_db) { if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw BinderException("Failed to detach database with name \"%s\": database not found", name); + throw BinderException("Failed to detach database with name \"%s\": database not found", + name.GetIdentifierName()); } return; } @@ -277,10 +297,11 @@ void DatabaseManager::Alter(ClientContext &context, AlterInfo &info) { } } -void DatabaseManager::RenameDatabase(ClientContext &context, const string &old_name, const string &new_name, +void DatabaseManager::RenameDatabase(ClientContext &context, const Identifier &old_name, const Identifier &new_name, OnEntryNotFound if_not_found) { if (AttachedDatabase::NameIsReserved(new_name)) { - throw BinderException("Database name \"%s\" cannot be used because it is a reserved name", new_name); + throw BinderException("Database name \"%s\" cannot be used because it is a reserved name", + new_name.GetIdentifierName()); } shared_ptr attached_db; @@ -289,7 +310,8 @@ void 
DatabaseManager::RenameDatabase(ClientContext &context, const string &old_n auto old_entry = databases.find(old_name); if (old_entry == databases.end()) { if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) { - throw BinderException("Failed to rename database \"%s\": database not found", old_name); + throw BinderException("Failed to rename database \"%s\": database not found", + old_name.GetIdentifierName()); } return; } @@ -297,7 +319,7 @@ void DatabaseManager::RenameDatabase(ClientContext &context, const string &old_n auto new_entry = databases.find(new_name); if (new_entry != databases.end()) { throw BinderException("Failed to rename database \"%s\" to \"%s\": database with new name already exists", - old_name, new_name); + old_name.GetIdentifierName(), new_name.GetIdentifierName()); } attached_db = old_entry->second; @@ -306,12 +328,12 @@ void DatabaseManager::RenameDatabase(ClientContext &context, const string &old_n databases[new_name] = attached_db; } - if (StringUtil::CIEquals(default_database, old_name)) { + if (old_name == default_database) { default_database = new_name; } } -shared_ptr DatabaseManager::DetachInternal(const string &name) { +shared_ptr DatabaseManager::DetachInternal(const Identifier &name) { shared_ptr attached_db; { lock_guard guard(databases_lock); @@ -322,6 +344,9 @@ shared_ptr DatabaseManager::DetachInternal(const string &name) attached_db = std::move(entry->second); databases.erase(entry); } + if (attached_db && attached_db->GetCatalog().Supports(RemoteCapability::IS_REMOTE)) { + --remote_catalog_count; + } return attached_db; } @@ -380,11 +405,11 @@ void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, // FIXME: Here it might be preferable to use an AutoLoadOrThrow kind of function // so that either there will be success or a message to throw, and load will be // attempted only once respecting the auto-loading options - ExtensionHelper::LoadExternalExtension(context, options.db_type); + 
ExtensionHelper::LoadExternalExtension(context, {options.db_type}); } } -const string &DatabaseManager::GetDefaultDatabase(ClientContext &context) { +Identifier DatabaseManager::GetDefaultDatabase(ClientContext &context) { auto &config = ClientData::Get(context); auto &default_entry = config.catalog_search_path->GetDefault(); if (IsInvalidCatalog(default_entry.catalog)) { @@ -399,7 +424,7 @@ const string &DatabaseManager::GetDefaultDatabase(ClientContext &context) { // LCOV_EXCL_START void DatabaseManager::SetDefaultDatabase(ClientContext &context, const string &new_value) { - auto db_entry = GetDatabase(context, new_value); + auto db_entry = GetDatabase(context, Identifier(new_value)); if (!db_entry) { throw InternalException("Database \"%s\" not found", new_value); @@ -409,7 +434,7 @@ void DatabaseManager::SetDefaultDatabase(ClientContext &context, const string &n throw InternalException("Cannot set the default database to a system database"); } - default_database = new_value; + default_database = Identifier(new_value); } // LCOV_EXCL_STOP diff --git a/src/duckdb/src/main/database_path_and_type.cpp b/src/duckdb/src/main/database_path_and_type.cpp index d3f8c0282..24d06e3c0 100644 --- a/src/duckdb/src/main/database_path_and_type.cpp +++ b/src/duckdb/src/main/database_path_and_type.cpp @@ -12,7 +12,10 @@ void DBPathAndType::ExtractExtensionPrefix(string &path, string &db_type) { if (!extension.empty()) { // path is prefixed with an extension - remove the first occurrence of it path = path.substr(extension.length() + 1); - db_type = ExtensionHelper::ApplyExtensionAlias(extension); + // Store the raw user prefix normalized to lowercase. The alias is + // applied only at lookup/comparison sites — symmetric with how the + // `TYPE 'xxx'` option preserves the user-supplied value. 
+ db_type = StringUtil::Lower(extension); } } diff --git a/src/duckdb/src/main/extension/extension_alias.cpp b/src/duckdb/src/main/extension/extension_alias.cpp index 9089eb269..e1b84546f 100644 --- a/src/duckdb/src/main/extension/extension_alias.cpp +++ b/src/duckdb/src/main/extension/extension_alias.cpp @@ -21,11 +21,12 @@ idx_t ExtensionHelper::ExtensionAliasCount() { return index; } -ExtensionAlias ExtensionHelper::GetExtensionAlias(idx_t index) { +ExtensionAlias ExtensionHelper::GetInternalExtensionAlias(idx_t index) { D_ASSERT(index < ExtensionAliasCount()); return internal_aliases[index]; } +// todo; we might want to update this string ExtensionHelper::ApplyExtensionAlias(const string &extension_name) { auto lname = StringUtil::Lower(extension_name); for (idx_t index = 0; internal_aliases[index].alias; index++) { diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp index 3b2cde48b..1ec4a3de0 100644 --- a/src/duckdb/src/main/extension/extension_helper.cpp +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -1,6 +1,8 @@ #include "duckdb/main/extension_helper.hpp" #include "duckdb/common/file_system.hpp" +#include "duckdb/common/local_file_system.hpp" +#include "duckdb/main/database_file_opener.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/common/serializer/buffered_file_reader.hpp" #include "duckdb/common/string_util.hpp" @@ -46,10 +48,6 @@ #define DUCKDB_EXTENSION_JSON_LINKED false #endif -#ifndef DUCKDB_EXTENSION_JEMALLOC_LINKED -#define DUCKDB_EXTENSION_JEMALLOC_LINKED false -#endif - #ifndef DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED #define DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED false #endif @@ -84,10 +82,6 @@ #include "json_extension.hpp" #endif -#if DUCKDB_EXTENSION_JEMALLOC_LINKED -#include "jemalloc_extension.hpp" -#endif - #if DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED #include "autocomplete_extension.hpp" #endif @@ -108,7 +102,6 @@ static const 
DefaultExtension internal_extensions[] = { {"tpcds", "Adds TPC-DS data generation and query support", DUCKDB_EXTENSION_TPCDS_LINKED}, {"httpfs", "Adds support for reading and writing files over a HTTP(S) connection", DUCKDB_EXTENSION_HTTPFS_LINKED}, {"json", "Adds support for JSON operations", DUCKDB_EXTENSION_JSON_LINKED}, - {"jemalloc", "Overwrites system allocator with JEMalloc", DUCKDB_EXTENSION_JEMALLOC_LINKED}, {"autocomplete", "Adds support for autocomplete in the shell", DUCKDB_EXTENSION_AUTOCOMPLETE_LINKED}, {"motherduck", "Enables motherduck integration with the system", false}, {"mysql_scanner", "Adds support for connecting to a MySQL database", false}, @@ -126,8 +119,11 @@ static const DefaultExtension internal_extensions[] = { {"fts", "Adds support for Full-Text Search Indexes", false}, {"ui", "Adds local UI for DuckDB", false}, {"ducklake", "Adds support for DuckLake, SQL as a Lakehouse Format", false}, + {"quack", "The DuckDB 'Quack' Client/Server Protocol", false}, {"vortex", "Adds support for reading and writing files using the Vortex file format", false}, {"lance", "Adds support for querying Lance datasets", false}, + {"avro", "Adds support for reading Avro files", false}, + {"unity_catalog", "Adds support for connecting to Unity Catalog", false}, {nullptr, nullptr, false}}; idx_t ExtensionHelper::DefaultExtensionCount() { @@ -219,7 +215,7 @@ bool ExtensionHelper::TryAutoLoadExtension(ClientContext &context, const string options.repository = autoinstall_repo; ExtensionHelper::InstallExtension(context, extension_name, options); } - ExtensionHelper::LoadExternalExtension(context, extension_name); + ExtensionHelper::LoadExternalExtension(context, {extension_name}); return true; } catch (...) 
{ return false; @@ -249,7 +245,7 @@ bool ExtensionHelper::TryAutoLoadExtension(DatabaseInstance &instance, const str ExtensionHelper::InstallExtension(instance, fs, extension_name, options); } if (Settings::Get(instance)) { - ExtensionHelper::LoadExternalExtension(instance, fs, extension_name); + ExtensionHelper::LoadExternalExtension(instance, fs, {extension_name}); return true; } return false; @@ -264,7 +260,7 @@ bool ExtensionHelper::TryAutoLoadAvailableExtension(DatabaseInstance &instance, } try { auto &fs = FileSystem::GetFileSystem(instance); - ExtensionHelper::LoadExternalExtension(instance, fs, extension_name); + ExtensionHelper::LoadExternalExtension(instance, fs, {extension_name}); return true; } catch (...) { return false; @@ -353,7 +349,7 @@ vector ExtensionHelper::UpdateExtensions(ClientContext &c DatabaseInstance &db = DatabaseInstance::GetDatabase(context); #ifndef WASM_LOADABLE_EXTENSIONS - case_insensitive_set_t seen_extensions; + identifier_set_t seen_extensions; // scan the install directory for installed extensions auto ext_directory = ExtensionHelper::ExtensionDirectory(db, fs); @@ -365,7 +361,7 @@ vector ExtensionHelper::UpdateExtensions(ClientContext &c auto extension_file_name = StringUtil::GetFileName(path); auto extension_name = StringUtil::Split(extension_file_name, ".")[0]; - seen_extensions.insert(extension_name); + seen_extensions.insert(Identifier(extension_name)); result.push_back(UpdateExtensionInternal(context, db, fs, fs.JoinPath(ext_directory, path), extension_name)); }); @@ -403,17 +399,17 @@ void ExtensionHelper::AutoLoadExtension(DatabaseInstance &db, const string &exte } auto &dbconfig = DBConfig::GetConfig(db); try { - auto fs = FileSystem::CreateLocal(); + auto &fs = FileSystem::GetLocal(db); #ifndef DUCKDB_WASM if (Settings::Get(db)) { auto repository_url = GetAutoInstallExtensionsRepository(dbconfig); auto autoinstall_repo = ExtensionRepository::GetRepositoryByUrl(repository_url); ExtensionInstallOptions options; 
options.repository = autoinstall_repo; - ExtensionHelper::InstallExtension(db, *fs, extension_name, options); + ExtensionHelper::InstallExtension(db, fs, extension_name, options); } #endif - ExtensionHelper::LoadExternalExtension(db, *fs, extension_name); + ExtensionHelper::LoadExternalExtension(db, fs, {extension_name}); DUCKDB_LOG_INFO(db, "Loaded extension '%s'", extension_name); } catch (std::exception &e) { ErrorData error(e); diff --git a/src/duckdb/src/main/extension/extension_install.cpp b/src/duckdb/src/main/extension/extension_install.cpp index fc5816da3..7dbd4cec0 100644 --- a/src/duckdb/src/main/extension/extension_install.cpp +++ b/src/duckdb/src/main/extension/extension_install.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/gzip_file_system.hpp" #include "duckdb/common/http_util.hpp" #include "duckdb/common/local_file_system.hpp" +#include "duckdb/main/database_file_opener.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/uuid.hpp" @@ -174,7 +175,7 @@ bool ExtensionHelper::CreateSuggestions(const string &extension_name, string &me candidates.emplace_back(ExtensionHelper::GetDefaultExtension(i).name); } for (idx_t ext_count = ExtensionHelper::ExtensionAliasCount(), i = 0; i < ext_count; i++) { - candidates.emplace_back(ExtensionHelper::GetExtensionAlias(i).alias); + candidates.emplace_back(ExtensionHelper::GetInternalExtensionAlias(i).alias); } auto closest_extensions = StringUtil::TopNJaroWinkler(candidates, lowercase_extension_name); message = StringUtil::CandidatesMessage(closest_extensions, "Candidate extensions"); @@ -411,11 +412,11 @@ static unique_ptr InstallFromHttpUrl(DatabaseInstance &db, optional_ptr context) { unique_ptr install_info; { - auto fs = FileSystem::CreateLocal(); - if (fs->FileExists(local_extension_path + ".info")) { + auto &fs = FileSystem::GetLocal(db); + if (fs.FileExists(local_extension_path + ".info")) { try { install_info = - 
ExtensionInstallInfo::TryReadInfoFile(*fs, local_extension_path + ".info", extension_name); + ExtensionInstallInfo::TryReadInfoFile(fs, local_extension_path + ".info", extension_name); } catch (...) { if (!options.force_install) { // We are going to rewrite the file anyhow, so this is fine @@ -527,10 +528,6 @@ static unique_ptr InstallFromRepository(DatabaseInstance & context); } -static bool IsHTTP(const string &path) { - return StringUtil::StartsWith(path, "http://") || !StringUtil::StartsWith(path, "https://"); -} - static void ThrowErrorOnMismatchingExtensionOrigin(FileSystem &fs, const string &local_extension_path, const string &extension_name, const string &extension, optional_ptr repository) { @@ -598,16 +595,16 @@ unique_ptr ExtensionHelper::InstallExtensionInternal(Datab } // Install extension from local, direct url - if (ExtensionHelper::IsFullPath(extension) && !IsHTTP(extension)) { - LocalFileSystem local_fs; + if (ExtensionHelper::IsFullPath(extension) && !FileSystem::IsRemoteFile(extension)) { + auto &local_fs = FileSystem::GetLocal(db); return DirectInstallExtension(db, local_fs, extension, temp_path, extension, local_extension_path, options, context); } // Install extension from local url based on a repository (Note that this will install it as a local file) - if (options.repository && !IsHTTP(options.repository->path)) { - LocalFileSystem local_fs; - return InstallFromRepository(db, fs, extension, extension_name, temp_path, local_extension_path, options, + if (options.repository && !FileSystem::IsRemoteFile(options.repository->path)) { + auto &local_fs = FileSystem::GetLocal(db); + return InstallFromRepository(db, local_fs, extension, extension_name, temp_path, local_extension_path, options, context); } diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp index f2f02429f..d1c8d8ee9 100644 --- a/src/duckdb/src/main/extension/extension_load.cpp +++ 
b/src/duckdb/src/main/extension/extension_load.cpp @@ -595,6 +595,7 @@ string ExtensionHelper::GetExtensionName(const string &original_name) { if (!IsFullPath(extension)) { return ExtensionHelper::ApplyExtensionAlias(extension); } + // split the name if it's a full path auto splits = StringUtil::Split(StringUtil::Replace(extension, "\\", "/"), '/'); if (splits.empty()) { return ExtensionHelper::ApplyExtensionAlias(extension); @@ -606,14 +607,34 @@ string ExtensionHelper::GetExtensionName(const string &original_name) { return ExtensionHelper::ApplyExtensionAlias(splits.front()); } -void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension) { +void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const ExtensionLoadOptions &options) { + // If this is a logical extension name (not an explicit path), prefer the + // statically linked implementation for built-in linked extensions only. + // This avoids loading a second copy from disk (ASan ODR violation) while + // keeping externally installed/autoloaded extensions on the normal path. 
+ auto logical_name = ExtensionHelper::GetExtensionName(options.extension_name); + if (!ExtensionHelper::IsFullPath(options.extension_name)) { + for (idx_t i = 0; i < ExtensionHelper::DefaultExtensionCount(); i++) { + auto default_extension = ExtensionHelper::GetDefaultExtension(i); + if (!default_extension.statically_loaded || logical_name != default_extension.name) { + continue; + } + DuckDB db_wrapper(db); + auto load_result = ExtensionHelper::LoadExtension(db_wrapper, logical_name); + if (load_result == ExtensionLoadResult::LOADED_EXTENSION) { + return; + } + break; + } + } + auto &manager = ExtensionManager::Get(db); - auto info = manager.BeginLoad(extension); + auto info = manager.BeginLoad(options); if (!info) { return; } try { - LoadExternalExtensionInternal(db, fs, extension, *info); + LoadExternalExtensionInternal(db, fs, options.extension_name, *info); } catch (std::exception &ex) { ErrorData error(ex); info->LoadFail(error); @@ -699,8 +720,8 @@ void ExtensionHelper::LoadExternalExtensionInternal(DatabaseInstance &db, FileSy #endif } -void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { - LoadExternalExtension(DatabaseInstance::GetDatabase(context), FileSystem::GetFileSystem(context), extension); +void ExtensionHelper::LoadExternalExtension(ClientContext &context, const ExtensionLoadOptions &options) { + LoadExternalExtension(DatabaseInstance::GetDatabase(context), FileSystem::GetFileSystem(context), options); } string ExtensionHelper::ExtractExtensionPrefixFromPath(const string &path) { diff --git a/src/duckdb/src/main/extension/extension_loader.cpp b/src/duckdb/src/main/extension/extension_loader.cpp index 5cf2a9585..d8f6ffc9f 100644 --- a/src/duckdb/src/main/extension/extension_loader.cpp +++ b/src/duckdb/src/main/extension/extension_loader.cpp @@ -2,6 +2,7 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" +#include 
"duckdb/parser/parsed_data/create_schema_info.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" @@ -14,91 +15,177 @@ #include "duckdb/parser/parsed_data/create_collation_info.hpp" #include "duckdb/main/extension_install_info.hpp" #include "duckdb/catalog/catalog.hpp" +#include "duckdb/main/client_data.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/secret/secret_manager.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/metrics_manager.hpp" + +#include "duckdb/main/extension_callback_manager.hpp" +#include "re2/re2.h" namespace duckdb { -ExtensionLoader::ExtensionLoader(ExtensionActiveLoad &load_info) - : db(load_info.db), extension_name(load_info.extension_name), extension_info(load_info.info) { +ExtensionLoader::ExtensionLoader(const ExtensionActiveLoad &load_info) + : db(load_info.db), extension_info(load_info.info) { + loader_info.extension_name = load_info.extension_name; + loader_info.extension_alias = load_info.alias; } -ExtensionLoader::ExtensionLoader(DatabaseInstance &db, const string &name) : db(db), extension_name(name) { +ExtensionLoader::ExtensionLoader(DatabaseInstance &db, const string &name) : db(db) { + loader_info.extension_name = Identifier(name); } -DatabaseInstance &ExtensionLoader::GetDatabaseInstance() { +DatabaseInstance &ExtensionLoader::GetDatabaseInstance() const { return db; } void ExtensionLoader::SetDescription(const string &description) { - extension_description = description; + loader_info.extension_description = description; +} + +void ExtensionLoader::UseDedicatedSchemaForExtension(const Identifier &extension_schema_name) { + CreateSchema(extension_schema_name); + UseDefaultSchema(extension_schema_name); + AddSchemaToSearchPath(extension_schema_name); +} + +void ExtensionLoader::UseDedicatedSchemaForExtension() { + auto registered_ext_name = 
GetRegisteredExtensionName(); + UseDedicatedSchemaForExtension(registered_ext_name); +} + +void ExtensionLoader::CreateSchema(const Identifier &name) const { + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + + CreateSchemaInfo info; + info.schema = name; + info.internal = true; + // TODO; we can give the user more control here + info.on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; + system_catalog.CreateSchema(data, info); +} + +void ExtensionLoader::UseDefaultSchema(const Identifier &name) { + if (loader_info.extension_schema != DEFAULT_SCHEMA && name != DEFAULT_SCHEMA && + loader_info.extension_schema != name) { + throw InvalidInputException("Cannot set extension schema to '%s', schema is already set to '%s'", name, + loader_info.extension_schema); + } + if (name == "pg_catalog") { + throw InvalidInputException("Cannot set default extension schema to '%s'", name); + } + if (name == DEFAULT_SCHEMA) { + loader_info.extension_schema = Identifier::DefaultSchema(); + return; + } + loader_info.extension_schema = name; +} + +void ExtensionLoader::AddSchemaToSearchPath(const Identifier &schema_name) const { + // adds an explicitly set extension schema to the search path + if (loader_info.extension_schema != schema_name || schema_name == DEFAULT_SCHEMA || + loader_info.extension_schema == DEFAULT_SCHEMA) { + throw InvalidInputException("Cannot add schema '%s' to search path, first set the extension schema explicitly " + "with UseDefaultSchema()", + schema_name); + } + + // check if schema already exists in the system catalog + auto &system_catalog = Catalog::GetSystemCatalog(db); + auto data = CatalogTransaction::GetSystemTransaction(db); + auto schema = system_catalog.GetSchema(data, schema_name, OnEntryNotFound::RETURN_NULL); + if (!schema) { + throw InvalidInputException("Cannot add schema '%s' to search path: schema does not exist. 
" + "Call CreateExtensionSchema() first.", + schema_name); + } + + // TODO: remove extension schema from search path if loading extension failed + ExtensionCallbackManager::Get(db).AddExtensionSchema(loader_info.extension_schema); +} + +void ExtensionLoader::RefreshSearchPath(ClientContext &context) { + ClientData::Get(context).catalog_search_path->RefreshSetPaths(); } void ExtensionLoader::FinalizeLoad() { // Set extension description, if provided - if (!extension_description.empty() && extension_info) { + if (!loader_info.extension_description.empty() && extension_info) { auto info = make_uniq(); - info->description = extension_description; + info->description = loader_info.extension_description; extension_info->load_info = std::move(info); } } void ExtensionLoader::RegisterFunction(ScalarFunction function) { - ScalarFunctionSet set(function.name); + ScalarFunctionSet set {function.name}; set.AddFunction(std::move(function)); RegisterFunction(std::move(set)); } void ExtensionLoader::RegisterFunction(ScalarFunctionSet function) { CreateScalarFunctionInfo info(std::move(function)); + info.schema = loader_info.extension_schema; info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; RegisterFunction(std::move(info)); } void ExtensionLoader::RegisterFunction(CreateScalarFunctionInfo function) { D_ASSERT(!function.functions.name.empty()); - function.extension_name = extension_name; + function.extension_name = GetRegisteredExtensionName(); + if (function.schema == DEFAULT_SCHEMA) { + function.schema = loader_info.extension_schema; + } auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateFunction(data, function); } void ExtensionLoader::RegisterFunction(AggregateFunction function) { - AggregateFunctionSet set(function.name); + AggregateFunctionSet set {function.name}; set.AddFunction(std::move(function)); RegisterFunction(std::move(set)); } void 
ExtensionLoader::RegisterFunction(AggregateFunctionSet function) { CreateAggregateFunctionInfo info(std::move(function)); + info.schema = loader_info.extension_schema; info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; RegisterFunction(std::move(info)); } void ExtensionLoader::RegisterFunction(CreateAggregateFunctionInfo function) { D_ASSERT(!function.functions.name.empty()); - function.extension_name = extension_name; + if (function.schema == DEFAULT_SCHEMA) { + function.schema = loader_info.extension_schema; + } + function.extension_name = GetRegisteredExtensionName(); auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateFunction(data, function); } void ExtensionLoader::RegisterFunction(WindowFunction function) { - WindowFunctionSet set(function.name); + WindowFunctionSet set {function.name}; set.AddFunction(std::move(function)); RegisterFunction(std::move(set)); } void ExtensionLoader::RegisterFunction(WindowFunctionSet function) { CreateWindowFunctionInfo info(std::move(function)); + info.schema = loader_info.extension_schema; info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; RegisterFunction(std::move(info)); } void ExtensionLoader::RegisterFunction(CreateWindowFunctionInfo function) { D_ASSERT(!function.functions.name.empty()); - function.extension_name = extension_name; + if (function.schema == DEFAULT_SCHEMA) { + function.schema = loader_info.extension_schema; + } + function.extension_name = GetRegisteredExtensionName(); auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateFunction(data, function); @@ -111,7 +198,7 @@ void ExtensionLoader::RegisterFunction(CreateSecretFunction function) { } void ExtensionLoader::RegisterFunction(TableFunction function) { - TableFunctionSet set(function.name); + TableFunctionSet set {function.name}; set.AddFunction(std::move(function)); 
RegisterFunction(std::move(set)); } @@ -119,13 +206,17 @@ void ExtensionLoader::RegisterFunction(TableFunction function) { void ExtensionLoader::RegisterFunction(TableFunctionSet function) { D_ASSERT(!function.name.empty()); CreateTableFunctionInfo info(std::move(function)); + info.schema = loader_info.extension_schema; info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; RegisterFunction(std::move(info)); } void ExtensionLoader::RegisterFunction(CreateTableFunctionInfo info) { D_ASSERT(!info.functions.name.empty()); - info.extension_name = extension_name; + info.extension_name = GetRegisteredExtensionName(); + if (info.schema == DEFAULT_SCHEMA) { + info.schema = loader_info.extension_schema; + } auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateFunction(data, info); @@ -133,16 +224,16 @@ void ExtensionLoader::RegisterFunction(CreateTableFunctionInfo info) { void ExtensionLoader::RegisterFunction(PragmaFunction function) { D_ASSERT(!function.name.empty()); - PragmaFunctionSet set(function.name); + PragmaFunctionSet set {function.name}; set.AddFunction(std::move(function)); RegisterFunction(std::move(set)); } void ExtensionLoader::RegisterFunction(PragmaFunctionSet function) { D_ASSERT(!function.name.empty()); - auto function_name = function.name; - CreatePragmaFunctionInfo info(std::move(function_name), std::move(function)); - info.extension_name = extension_name; + CreatePragmaFunctionInfo info(std::move(function)); + info.extension_name = GetRegisteredExtensionName(); + info.schema = loader_info.extension_schema; auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreatePragmaFunction(data, info); @@ -150,29 +241,37 @@ void ExtensionLoader::RegisterFunction(PragmaFunctionSet function) { void ExtensionLoader::RegisterFunction(CopyFunction function) { CreateCopyFunctionInfo info(std::move(function)); - 
info.extension_name = extension_name; + info.extension_name = GetRegisteredExtensionName(); + info.schema = loader_info.extension_schema; auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateCopyFunction(data, info); } void ExtensionLoader::RegisterFunction(CreateMacroInfo &info) { - info.extension_name = extension_name; + info.extension_name = GetRegisteredExtensionName(); + if (info.schema == DEFAULT_SCHEMA) { + info.schema = loader_info.extension_schema; + } auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateFunction(data, info); } void ExtensionLoader::RegisterCollation(CreateCollationInfo &info) { - info.extension_name = extension_name; + info.extension_name = GetRegisteredExtensionName(); auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); + if (info.schema == DEFAULT_SCHEMA) { + info.schema = loader_info.extension_schema; + } info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; system_catalog.CreateCollation(data, info); // Also register as a function for serialisation CreateScalarFunctionInfo finfo(info.function); - finfo.extension_name = extension_name; + finfo.extension_name = GetRegisteredExtensionName(); + finfo.schema = loader_info.extension_schema; finfo.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; system_catalog.CreateFunction(data, finfo); } @@ -205,19 +304,19 @@ void ExtensionLoader::AddFunctionOverload(TableFunctionSet functions) { // NOLIN } } -static optional_ptr TryGetEntry(DatabaseInstance &db, const string &name, CatalogType type) { +static optional_ptr TryGetEntry(DatabaseInstance &db, const Identifier &name, CatalogType type) { D_ASSERT(!name.empty()); auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); - auto &schema = system_catalog.GetSchema(data, 
DEFAULT_SCHEMA); + auto &schema = system_catalog.GetSchema(data, Identifier::DefaultSchema()); return schema.GetEntry(data, type, name); } -optional_ptr ExtensionLoader::TryGetFunction(const string &name) { +optional_ptr ExtensionLoader::TryGetFunction(const Identifier &name) { return TryGetEntry(db, name, CatalogType::SCALAR_FUNCTION_ENTRY); } -ScalarFunctionCatalogEntry &ExtensionLoader::GetFunction(const string &name) { +ScalarFunctionCatalogEntry &ExtensionLoader::GetFunction(const Identifier &name) { auto catalog_entry = TryGetFunction(name); if (!catalog_entry) { throw InvalidInputException("Function with name \"%s\" not found in ExtensionLoader::GetFunction", name); @@ -225,11 +324,11 @@ ScalarFunctionCatalogEntry &ExtensionLoader::GetFunction(const string &name) { return catalog_entry->Cast(); } -optional_ptr ExtensionLoader::TryGetTableFunction(const string &name) { +optional_ptr ExtensionLoader::TryGetTableFunction(const Identifier &name) { return TryGetEntry(db, name, CatalogType::TABLE_FUNCTION_ENTRY); } -TableFunctionCatalogEntry &ExtensionLoader::GetTableFunction(const string &name) { +TableFunctionCatalogEntry &ExtensionLoader::GetTableFunction(const Identifier &name) { auto catalog_entry = TryGetTableFunction(name); if (!catalog_entry) { throw InvalidInputException("Function with name \"%s\" not found in ExtensionLoader::GetTableFunction", name); @@ -242,7 +341,8 @@ void ExtensionLoader::RegisterType(string type_name, LogicalType type, bind_logi CreateTypeInfo info(std::move(type_name), std::move(type), bind_modifiers); info.temporary = true; info.internal = true; - info.extension_name = extension_name; + info.extension_name = GetRegisteredExtensionName(); + info.schema = loader_info.extension_schema; auto &system_catalog = Catalog::GetSystemCatalog(db); auto data = CatalogTransaction::GetSystemTransaction(db); system_catalog.CreateType(data, info); @@ -267,4 +367,8 @@ void ExtensionLoader::RegisterCastFunction(const LogicalType &source, const Logi 
casts.RegisterCastFunction(source, target, std::move(function), implicit_cast_cost); } +void ExtensionLoader::RegisterMetric(MetricInfo info) { + MetricsManager::Get(db).RegisterMetric(std::move(info)); +} + } // namespace duckdb diff --git a/src/duckdb/src/main/extension_callback_manager.cpp b/src/duckdb/src/main/extension_callback_manager.cpp index fc0423d02..ff606ad70 100644 --- a/src/duckdb/src/main/extension_callback_manager.cpp +++ b/src/duckdb/src/main/extension_callback_manager.cpp @@ -40,6 +40,14 @@ ExtensionCallbackManager::ExtensionCallbackManager() : callback_registry(make_sh ExtensionCallbackManager::~ExtensionCallbackManager() { } +void ExtensionCallbackManager::AddExtensionSchema(const Identifier &schema) { + extension_schemas.push_back(schema.GetIdentifierName()); +} + +vector ExtensionCallbackManager::GetExtensionSchemas() const { + return extension_schemas; +} + void ExtensionCallbackManager::Register(ParserExtension extension) { lock_guard guard(registry_lock); auto new_registry = make_shared_ptr(*callback_registry); diff --git a/src/duckdb/src/main/extension_manager.cpp b/src/duckdb/src/main/extension_manager.cpp index 24459fabf..b31487b51 100644 --- a/src/duckdb/src/main/extension_manager.cpp +++ b/src/duckdb/src/main/extension_manager.cpp @@ -9,25 +9,35 @@ namespace duckdb { ExtensionInfo::ExtensionInfo() : is_loaded(false) { } -ExtensionActiveLoad::ExtensionActiveLoad(DatabaseInstance &db, ExtensionInfo &info, string extension_name_p) - : db(db), load_lock(info.lock), info(info), extension_name(std::move(extension_name_p)) { -} - void ExtensionActiveLoad::FinishLoad(ExtensionInstallInfo &install_info) { info.is_loaded = true; info.install_info = make_uniq(install_info); + string final_extension_name = extension_name.GetIdentifierName(); + + if (!alias.empty()) { + final_extension_name = alias.GetIdentifierName(); + } for (auto &callback : ExtensionCallback::Iterate(db)) { - callback->OnExtensionLoaded(db, extension_name); + 
callback->OnExtensionLoaded(db, final_extension_name); } - DUCKDB_LOG_INFO(db, extension_name); + + if (!alias.empty()) { + DUCKDB_LOG_INFO(db, alias.GetIdentifierName()); + } + + DUCKDB_LOG_INFO(db, extension_name.GetIdentifierName()); } void ExtensionActiveLoad::LoadFail(const ErrorData &error) { for (auto &callback : ExtensionCallback::Iterate(db)) { - callback->OnExtensionLoadFail(db, extension_name, error); + callback->OnExtensionLoadFail(db, extension_name.GetIdentifierName(), error); } - DUCKDB_LOG_INFO(db, "Failed to load extension '%s': %s", extension_name, error.Message()); + if (!alias.empty()) { + load_lock.unlock(); + ExtensionManager::Get(db).RemoveExtensionInfo(alias.GetIdentifierName()); + } + DUCKDB_LOG_INFO(db, "Failed to load extension '%s': %s", extension_name.GetIdentifierName(), error.Message()); } ExtensionManager::ExtensionManager(DatabaseInstance &db) : db(db) { @@ -38,14 +48,19 @@ ExtensionManager &ExtensionManager::Get(DatabaseInstance &db) { } ExtensionManager &ExtensionManager::Get(ClientContext &context) { - return ExtensionManager::Get(DatabaseInstance::GetDatabase(context)); + return Get(DatabaseInstance::GetDatabase(context)); +} + +void ExtensionManager::RemoveExtensionInfo(const string &name) { + lock_guard guard(lock); + loaded_extensions_info.erase(Identifier(name)); } optional_ptr ExtensionManager::GetExtensionInfo(const string &name) { auto extension_name = ExtensionHelper::GetExtensionName(name); lock_guard guard(lock); - auto entry = loaded_extensions_info.find(extension_name); + auto entry = loaded_extensions_info.find(Identifier(extension_name)); if (entry == loaded_extensions_info.end()) { return nullptr; } @@ -57,7 +72,7 @@ vector ExtensionManager::GetExtensions() { vector result; for (auto &entry : loaded_extensions_info) { - result.push_back(entry.first); + result.push_back(entry.first.GetIdentifierName()); } return result; } @@ -70,18 +85,56 @@ bool ExtensionManager::ExtensionIsLoaded(const string &name) { return 
info->is_loaded; } -unique_ptr ExtensionManager::BeginLoad(const string &name) { - auto extension_name = ExtensionHelper::GetExtensionName(name); - +unique_ptr ExtensionManager::BeginLoad(const ExtensionLoadOptions &options) { unique_lock extension_list_lock(lock); + if (!options.alias.empty()) { + if (loaded_extensions_info.find(Identifier(options.alias.GetIdentifierName())) != + loaded_extensions_info.end()) { + auto &info = loaded_extensions_info[Identifier(options.alias.GetIdentifierName())]; + throw InvalidInputException("Alias '%s' already exists for extension '%s'", + options.alias.GetIdentifierName(), info->orig_ext_name); + } + } + + string extension_name; + extension_name = ExtensionHelper::GetExtensionName(options.extension_name); + string original_extension_name = extension_name; + optional_ptr info; - auto entry = loaded_extensions_info.find(extension_name); + auto entry = loaded_extensions_info.end(); + + if (options.alias.empty()) { + // no alias given, we first search for the original name + entry = loaded_extensions_info.find(Identifier(extension_name)); + } + if (entry == loaded_extensions_info.end()) { + // check if the extension is not already registered under an alias + // we don't have an entry yet - create one auto extension_info = make_uniq(); + extension_info->orig_ext_name = original_extension_name; info = extension_info.get(); + + if (!options.alias.empty()) { + // if aliasing is used, we need to register the alias instead of extension name + extension_name = options.alias.GetIdentifierName(); + extension_info->alias = extension_name; + info->alias = extension_name; + } else { + for (auto &loaded_extension : loaded_extensions_info) { + auto loaded_ext_info = loaded_extension.second.get(); + if (loaded_ext_info->orig_ext_name == original_extension_name) { + throw InvalidInputException("Extension '%s' already exists under alias '%s'. 
If you want to load " + "the extension twice, you can do this by specifying an alias.", + original_extension_name, loaded_ext_info->alias); + } + } + } + loaded_extensions_info.emplace(extension_name, std::move(extension_info)); + } else { // we already have an entry if (entry->second->is_loaded) { @@ -95,7 +148,7 @@ unique_ptr ExtensionManager::BeginLoad(const string &name) // we have an extension and we want to try to load it - instantiate the load // we instantiate the ExtensionActiveLoad which also grabs the lock for loading the specific extension - auto result = make_uniq(db, *info, extension_name); + auto result = make_uniq(db, *info, Identifier(original_extension_name), options.alias); // we now have a lock for loading the extension // HOWEVER - another thread might have finished loading in the meantime - double check to avoid a double load diff --git a/src/duckdb/src/main/gathered_metrics.cpp b/src/duckdb/src/main/gathered_metrics.cpp new file mode 100644 index 000000000..e0adc0831 --- /dev/null +++ b/src/duckdb/src/main/gathered_metrics.cpp @@ -0,0 +1,145 @@ +#include "duckdb/main/gathered_metrics.hpp" + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/scalar/string_common.hpp" +#include "duckdb/main/profiling_utils.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/logging/log_manager.hpp" + +#include "yyjson.hpp" + +using namespace duckdb_yyjson; // NOLINT + +namespace duckdb { + +static bool IsPrefixPattern(const string &pattern) { + if (pattern.empty() || pattern.back() != '*') { + return false; + } + for (idx_t i = 0; i + 1 < pattern.size(); i++) { + if (pattern[i] == '*' || pattern[i] == '?' 
|| pattern[i] == '[') { + return false; + } + } + return true; +} + +GatheredMetrics::GatheredMetrics(const vector &tracked_metrics) { + InitTrackedMetrics(tracked_metrics); + ResetMetrics(); +} + +void GatheredMetrics::InitTrackedMetrics(const vector &patterns) { + for (const auto &pattern : patterns) { + if (pattern == "*") { + track_all = true; + return; + } else if (!FileSystem::HasGlob(pattern)) { + tracked_exact.insert(pattern); + } else if (IsPrefixPattern(pattern)) { + tracked_prefixes.insert(pattern.substr(0, pattern.size() - 1)); + } else { + tracked_globs.push_back(pattern); + } + } +} + +void GatheredMetrics::ResetMetrics() { + metrics.clear(); +} + +bool GatheredMetrics::MetricIsTracked(const string &key) const { + if (track_all) { + return true; + } + // Exact match from tracked_metrics + if (tracked_exact.find(key) != tracked_exact.end()) { + return true; + } + // Prefix match: walk backwards from the largest prefix <= key + { + auto it = tracked_prefixes.upper_bound(key); + while (it != tracked_prefixes.begin()) { + --it; + const auto &prefix = *it; + if (prefix.size() <= key.size() && key.compare(0, prefix.size(), prefix) == 0) { + return true; + } + } + } + // Arbitrary glob patterns + for (const auto &glob : tracked_globs) { + if (Glob(key.c_str(), key.size(), glob.c_str(), glob.size())) { + return true; + } + } + return false; +} + +void GatheredMetrics::SetMetric(const string &key, Value new_value) { + if (!MetricIsTracked(key)) { + return; + } + metrics[key] = std::move(new_value); +} + +void GatheredMetrics::SetMetric(const string &key, idx_t value) { + if (!MetricIsTracked(key)) { + return; + } + metrics[key] = Value::UBIGINT(value); +} + +void GatheredMetrics::SetMetric(const string &key, double value) { + if (!MetricIsTracked(key)) { + return; + } + metrics[key] = Value::DOUBLE(value); +} + +void GatheredMetrics::SetMetric(const string &key, const string &value) { + if (!MetricIsTracked(key)) { + return; + } + metrics[key] = Value(value); 
+} + +void GatheredMetrics::WriteMetricsToLog(ClientContext &context) const { + auto &logger = Logger::Get(context); + if (logger.ShouldLog(MetricsLogType::NAME, MetricsLogType::LEVEL)) { + for (const auto &entry : metrics) { + logger.WriteLog(MetricsLogType::NAME, MetricsLogType::LEVEL, + MetricsLogType::ConstructLogMessage(entry.first, entry.second)); + } + } +} + +static QueryProfileResult &GetOrCreateObject(QueryProfileResult &parent, const string &key) { + for (auto &child : parent.children) { + if (child->kind == QueryProfileResultKind::OBJECT && child->key == key) { + return *child; + } + } + return parent.AddObject(key); +} + +static void AddMetricToResult(QueryProfileResult &obj, const string &key, const Value &value) { + auto dot_pos = key.find('.'); + if (dot_pos == string::npos) { + obj.AddValue(key, value); + } else { + auto prefix = key.substr(0, dot_pos); + auto suffix = key.substr(dot_pos + 1); + auto &sub_obj = GetOrCreateObject(obj, prefix); + AddMetricToResult(sub_obj, suffix, value); + } +} + +void GatheredMetrics::MetricsToProfileResult(QueryProfileResult &result) const { + for (auto &entry : metrics) { + AddMetricToResult(result, entry.first, entry.second); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/http/http_util.cpp b/src/duckdb/src/main/http/http_util.cpp index e4ef83db8..62bfe3ee8 100644 --- a/src/duckdb/src/main/http/http_util.cpp +++ b/src/duckdb/src/main/http/http_util.cpp @@ -117,6 +117,7 @@ bool HTTPResponse::ShouldRetry() const { case HTTPStatusCode::ImATeapot_418: case HTTPStatusCode::TooManyRequests_429: case HTTPStatusCode::InternalServerError_500: + case HTTPStatusCode::BadGateway_502: case HTTPStatusCode::ServiceUnavailable_503: case HTTPStatusCode::GatewayTimeout_504: return true; @@ -265,20 +266,14 @@ unique_ptr HTTPUtil::SendRequest(BaseRequest &request, unique_ptr< } try { - if (request.have_request_timing) { - request.request_start = Timestamp::GetCurrentTimestamp(); - } + request.request_start = 
Timestamp::GetCurrentTimestamp(); response = client->Request(request); } catch (...) { - if (request.have_request_timing) { - request.request_end = Timestamp::GetCurrentTimestamp(); - } + request.request_end = Timestamp::GetCurrentTimestamp(); LogRequest(request, nullptr); throw; } - if (request.have_request_timing) { - request.request_end = Timestamp::GetCurrentTimestamp(); - } + request.request_end = Timestamp::GetCurrentTimestamp(); LogRequest(request, response ? response.get() : nullptr); return response; }); @@ -486,7 +481,7 @@ HTTPUtil::RunRequestWithRetry(const std::function(void) void HTTPParams::Initialize(optional_ptr opener) { auto db = FileOpener::TryGetDatabase(opener); if (db) { - auto http_proxy_setting = Settings::Get(*db); + auto &http_proxy_setting = db->config.options.http_proxy; if (!http_proxy_setting.empty()) { idx_t port; string host; @@ -500,10 +495,7 @@ void HTTPParams::Initialize(optional_ptr opener) { auto client_context = FileOpener::TryGetClientContext(opener); if (client_context) { - auto &client_config = ClientConfig::GetConfig(*client_context); - if (client_config.enable_http_logging) { - logger = client_context->logger; - } + logger = client_context->logger; } } diff --git a/src/duckdb/src/main/metrics_manager.cpp b/src/duckdb/src/main/metrics_manager.cpp new file mode 100644 index 000000000..670a4c5bf --- /dev/null +++ b/src/duckdb/src/main/metrics_manager.cpp @@ -0,0 +1,132 @@ +#include "duckdb/main/metrics_manager.hpp" +#include "duckdb/main/metrics.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/common/enums/optimizer_type.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/enum_util.hpp" + +namespace duckdb { + +namespace { + +struct MetricDescriptor { + const char *name; + const char *metric_type; + const char *description; + const char *unit; +}; + +} // namespace + +// clang-format off +#define DUCKDB_METRIC(METRIC) \ + {METRIC::Name, METRIC::TypeStr, 
METRIC::Description, METRIC::Unit} +#define FINAL_METRIC \ + {nullptr, nullptr, nullptr, nullptr} + +static const MetricDescriptor internal_metrics[] = { + DUCKDB_METRIC(MetricQueryCPUTime), + DUCKDB_METRIC(MetricQuerySQL), + DUCKDB_METRIC(MetricQueryTotalIntermediateRows), + DUCKDB_METRIC(MetricQueryTotalIntermediateSizeBytes), + DUCKDB_METRIC(MetricQueryTotalRowGroupsScanned), + DUCKDB_METRIC(MetricQueryTotalRowGroupsToScan), + DUCKDB_METRIC(MetricQueryTotalRowsScanned), + DUCKDB_METRIC(MetricQueryTotalTime), + DUCKDB_METRIC(MetricSystemBlockedThreadTime), + DUCKDB_METRIC(MetricSystemPeakBufferMemory), + DUCKDB_METRIC(MetricSystemPeakTempDirSize), + DUCKDB_METRIC(MetricSystemTotalMemoryAllocated), + DUCKDB_METRIC(MetricIOTotalBytesRead), + DUCKDB_METRIC(MetricIOTotalBytesWritten), + DUCKDB_METRIC(MetricStorageAttachLoadStorageLatency), + DUCKDB_METRIC(MetricStorageAttachReplayWALLatency), + DUCKDB_METRIC(MetricStorageCheckpointLatency), + DUCKDB_METRIC(MetricStorageCommitLocalStorageLatency), + DUCKDB_METRIC(MetricStorageTotalVacuumTime), + DUCKDB_METRIC(MetricStorageWaitingToAttachLatency), + DUCKDB_METRIC(MetricStorageWALReplayEntryCount), + DUCKDB_METRIC(MetricStorageWriteToWALLatency), + DUCKDB_METRIC(MetricOperatorCPUTime), + DUCKDB_METRIC(MetricOperatorExtraInfo), + DUCKDB_METRIC(MetricOperatorIntermediateRows), + DUCKDB_METRIC(MetricOperatorIntermediateSizeBytes), + DUCKDB_METRIC(MetricOperatorRowGroupsScanned), + DUCKDB_METRIC(MetricOperatorRowsScanned), + DUCKDB_METRIC(MetricOperatorTiming), + DUCKDB_METRIC(MetricOperatorTotalRowGroupsToScan), + DUCKDB_METRIC(MetricOperatorType), + DUCKDB_METRIC(MetricOptimizerTotalTime), + DUCKDB_METRIC(MetricParserTotalTime), + DUCKDB_METRIC(MetricPlannerBindingTime), + DUCKDB_METRIC(MetricPlannerTotalTime), + DUCKDB_METRIC(MetricPhysicalPlannerColumnBinding), + DUCKDB_METRIC(MetricPhysicalPlannerCreatePlan), + DUCKDB_METRIC(MetricPhysicalPlannerResolveTypes), + DUCKDB_METRIC(MetricPhysicalPlannerTotalTime), + 
FINAL_METRIC, +}; +// clang-format on + +#undef DUCKDB_METRIC +#undef FINAL_METRIC + +MetricsManager &MetricsManager::Get(ClientContext &context) { + return DatabaseInstance::GetDatabase(context).GetMetricsManager(); +} + +MetricsManager &MetricsManager::Get(DatabaseInstance &db) { + return db.GetMetricsManager(); +} + +idx_t MetricsManager::GetMetricCount() const { + lock_guard guard(registered_metrics_lock); + idx_t count = sizeof(internal_metrics) / sizeof(MetricDescriptor) - 1; + count += static_cast(OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN) - + static_cast(OptimizerType::EXPRESSION_REWRITER) + 1; + count += registered_metrics.size(); + return count; +} + +vector MetricsManager::GetAllMetrics() const { + lock_guard guard(registered_metrics_lock); + vector result; + for (idx_t i = 0; internal_metrics[i].name; i++) { + MetricInfo info; + info.name = internal_metrics[i].name; + info.metric_type = internal_metrics[i].metric_type; + info.description = internal_metrics[i].description; + info.unit = internal_metrics[i].unit; + result.push_back(std::move(info)); + } + // Optimizer metrics are generated at runtime from the OptimizerType enum. + for (auto opt = static_cast(OptimizerType::EXPRESSION_REWRITER); + opt <= static_cast(OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN); opt++) { + auto opt_type = static_cast(opt); + string opt_name = StringUtil::Lower(EnumUtil::ToString(opt_type)); + MetricInfo info; + info.name = "optimizer." + opt_name; + info.metric_type = "double"; + info.description = "Time spent in the " + opt_name + " optimizer"; + info.unit = "seconds"; + result.push_back(std::move(info)); + } + // Extension-registered metrics. 
+ for (const auto &m : registered_metrics) { + result.push_back(m); + } + return result; +} + +void MetricsManager::RegisterMetric(MetricInfo info) { + lock_guard guard(registered_metrics_lock); + for (const auto &existing : registered_metrics) { + if (existing.name == info.name) { + return; // silently ignore duplicate + } + } + registered_metrics.push_back(std::move(info)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/main/pending_query_result.cpp b/src/duckdb/src/main/pending_query_result.cpp index 8c4d615c2..612413244 100644 --- a/src/duckdb/src/main/pending_query_result.cpp +++ b/src/duckdb/src/main/pending_query_result.cpp @@ -7,7 +7,7 @@ namespace duckdb { PendingQueryResult::PendingQueryResult(shared_ptr context_p, PreparedStatementData &statement, vector types_p, bool allow_stream_result) : BaseQueryResult(QueryResultType::PENDING_RESULT, statement.statement_type, statement.properties, - std::move(types_p), statement.names), + std::move(types_p), IdentifiersToStrings(statement.names)), context(std::move(context_p)), allow_stream_result(allow_stream_result) { } diff --git a/src/duckdb/src/main/prepared_statement.cpp b/src/duckdb/src/main/prepared_statement.cpp index b0de40a1c..fd33a3c08 100644 --- a/src/duckdb/src/main/prepared_statement.cpp +++ b/src/duckdb/src/main/prepared_statement.cpp @@ -6,7 +6,7 @@ namespace duckdb { PreparedStatement::PreparedStatement(shared_ptr context, shared_ptr data_p, - string query, case_insensitive_map_t named_param_map_p) + string query, identifier_map_t named_param_map_p) : context(std::move(context)), data(std::move(data_p)), query(std::move(query)), success(true), named_param_map(std::move(named_param_map_p)) { D_ASSERT(data || !success); @@ -51,7 +51,7 @@ const vector &PreparedStatement::GetTypes() { return data->types; } -const vector &PreparedStatement::GetNames() { +const vector &PreparedStatement::GetNames() { D_ASSERT(data); return data->names; } @@ -63,12 +63,12 @@ case_insensitive_map_t 
PreparedStatement::GetExpectedParameterTypes auto &identifier = it.first; D_ASSERT(data->value_map.count(identifier)); D_ASSERT(it.second); - expected_types[identifier] = it.second->GetValue().type(); + expected_types[identifier.GetIdentifierName()] = it.second->GetValue().type(); } return expected_types; } -unique_ptr PreparedStatement::Execute(case_insensitive_map_t &named_values, +unique_ptr PreparedStatement::Execute(identifier_map_t &named_values, bool allow_stream_result) { if (!success) { return make_uniq( @@ -93,23 +93,23 @@ unique_ptr PreparedStatement::Execute(case_insensitive_map_t PreparedStatement::Execute(vector &values, bool allow_stream_result) { - case_insensitive_map_t named_values; + identifier_map_t named_values; for (idx_t i = 0; i < values.size(); i++) { - named_values[std::to_string(i + 1)] = BoundParameterData(values[i]); + named_values[Identifier(std::to_string(i + 1))] = BoundParameterData(values[i]); } return Execute(named_values, allow_stream_result); } unique_ptr PreparedStatement::PendingQuery(vector &values, bool allow_stream_result) { - case_insensitive_map_t named_values; + identifier_map_t named_values; for (idx_t i = 0; i < values.size(); i++) { auto &val = values[i]; - named_values[std::to_string(i + 1)] = BoundParameterData(val); + named_values[Identifier(std::to_string(i + 1))] = BoundParameterData(val); } return PendingQuery(named_values, allow_stream_result); } -unique_ptr PreparedStatement::PendingQuery(case_insensitive_map_t &named_values, +unique_ptr PreparedStatement::PendingQuery(identifier_map_t &named_values, bool allow_stream_result) { if (!success) { auto exception = InvalidInputException("Attempting to execute an unsuccessfully prepared statement!"); diff --git a/src/duckdb/src/main/prepared_statement_data.cpp b/src/duckdb/src/main/prepared_statement_data.cpp index 7ef2edb41..da01d0f0a 100644 --- a/src/duckdb/src/main/prepared_statement_data.cpp +++ b/src/duckdb/src/main/prepared_statement_data.cpp @@ -22,7 +22,7 @@ 
void PreparedStatementData::CheckParameterCount(idx_t parameter_count) { } } -bool CheckCatalogIdentity(ClientContext &context, const string &catalog_name, +bool CheckCatalogIdentity(ClientContext &context, const Identifier &catalog_name, const StatementProperties::CatalogIdentity catalog_identity) { // some catalogs don't support catalog version, we can't check identity in that case if (!catalog_identity.catalog_version.IsValid()) { @@ -40,7 +40,7 @@ bool CheckCatalogIdentity(ClientContext &context, const string &catalog_name, } bool PreparedStatementData::RequireRebind(ClientContext &context, - optional_ptr> values) { + optional_ptr> values) { idx_t count = values ? values->size() : 0; CheckParameterCount(count); if (!unbound_statement) { @@ -78,15 +78,15 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, return false; } -void PreparedStatementData::Bind(case_insensitive_map_t values) { +void PreparedStatementData::Bind(identifier_map_t values) { // set parameters D_ASSERT(!unbound_statement || unbound_statement->named_param_map.size() == properties.parameter_count); CheckParameterCount(values.size()); // bind the required values for (auto &it : value_map) { - const string &identifier = it.first; - auto lookup = values.find(identifier); + const string &identifier = it.first.GetIdentifierName(); + auto lookup = values.find(it.first); if (lookup == values.end()) { throw BinderException("Could not find parameter with identifier %s", identifier); } @@ -101,7 +101,7 @@ void PreparedStatementData::Bind(case_insensitive_map_t valu } } -bool PreparedStatementData::TryGetType(const string &identifier, LogicalType &result) { +bool PreparedStatementData::TryGetType(const Identifier &identifier, LogicalType &result) { auto it = value_map.find(identifier); if (it == value_map.end()) { return false; @@ -114,7 +114,7 @@ bool PreparedStatementData::TryGetType(const string &identifier, LogicalType &re return true; } -LogicalType 
PreparedStatementData::GetType(const string &identifier) { +LogicalType PreparedStatementData::GetType(const Identifier &identifier) { LogicalType result; if (!TryGetType(identifier, result)) { throw BinderException("Could not find parameter identified with: %s", identifier); diff --git a/src/duckdb/src/main/profiling_info.cpp b/src/duckdb/src/main/profiling_info.cpp deleted file mode 100644 index ff709b7de..000000000 --- a/src/duckdb/src/main/profiling_info.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include "duckdb/main/profiling_info.hpp" - -#include "duckdb/common/enum_util.hpp" -#include "duckdb/main/profiling_utils.hpp" -#include "duckdb/logging/log_manager.hpp" - -#include "yyjson.hpp" - -using namespace duckdb_yyjson; // NOLINT - -namespace duckdb { - -ProfilingInfo::ProfilingInfo(const profiler_settings_t &n_settings, const idx_t depth) : settings(n_settings) { - // Expand. - if (depth > 0) { - settings.insert(MetricType::OPERATOR_NAME); - settings.insert(MetricType::OPERATOR_TYPE); - } - for (const auto &metric : settings) { - Expand(expanded_settings, metric); - } - - // Reduce. 
- if (depth == 0) { - auto op_metrics = MetricsUtils::GetOperatorMetrics(); - for (const auto metric : op_metrics) { - settings.erase(metric); - } - } else { - auto root_metrics = MetricsUtils::GetRootScopeMetrics(); - for (const auto metric : root_metrics) { - settings.erase(metric); - } - } - ResetMetrics(); -} - -void ProfilingInfo::ResetMetrics() { - metrics.clear(); - for (auto &metric : expanded_settings) { - if (MetricsUtils::IsOptimizerMetric(metric) || MetricsUtils::IsPhaseTimingMetric(metric)) { - metrics[metric] = Value::CreateValue(0.0); - continue; - } - - ProfilingUtils::SetMetricToDefault(metrics, metric); - } -} - -bool ProfilingInfo::Enabled(const profiler_settings_t &settings, const MetricType metric) { - if (settings.find(metric) != settings.end()) { - return true; - } - return false; -} - -void ProfilingInfo::Expand(profiler_settings_t &settings, const MetricType metric) { - settings.insert(metric); - - switch (metric) { - case MetricType::CPU_TIME: - settings.insert(MetricType::OPERATOR_TIMING); - return; - case MetricType::CUMULATIVE_CARDINALITY: - settings.insert(MetricType::OPERATOR_CARDINALITY); - return; - case MetricType::CUMULATIVE_ROWS_SCANNED: - settings.insert(MetricType::OPERATOR_ROWS_SCANNED); - return; - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricType::ALL_OPTIMIZERS: { - auto optimizer_metrics = MetricsUtils::GetOptimizerMetrics(); - for (const auto optimizer_metric : optimizer_metrics) { - settings.insert(optimizer_metric); - } - return; - } - default: - return; - } -} - -string ProfilingInfo::GetMetricAsString(const MetricType metric) const { - if (!Enabled(settings, metric)) { - throw InternalException("Metric %s not enabled", EnumUtil::ToString(metric)); - } - - // The metric cannot be NULL and must be initialized. 
- D_ASSERT(!metrics.at(metric).IsNull()); - if (metric == MetricType::OPERATOR_TYPE) { - const auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); - return EnumUtil::ToString(type); - } - return metrics.at(metric).ToString(); -} - -void ProfilingInfo::WriteMetricsToLog(ClientContext &context) { - auto &logger = Logger::Get(context); - if (logger.ShouldLog(MetricsLogType::NAME, MetricsLogType::LEVEL)) { - for (auto &metric : settings) { - logger.WriteLog(MetricsLogType::NAME, MetricsLogType::LEVEL, - MetricsLogType::ConstructLogMessage(metric, metrics[metric])); - } - } -} - -void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest) { - for (auto &metric : settings) { - auto metric_str = StringUtil::Lower(EnumUtil::ToString(metric)); - auto key_val = yyjson_mut_strcpy(doc, metric_str.c_str()); - auto key_ptr = yyjson_mut_get_str(key_val); - - if (metric == MetricType::EXTRA_INFO) { - auto extra_info_obj = yyjson_mut_obj(doc); - - auto extra_info = metrics.at(metric); - auto children = MapValue::GetChildren(extra_info); - for (auto &child : children) { - auto struct_children = StructValue::GetChildren(child); - auto key = struct_children[0].GetValue(); - auto value = struct_children[1].GetValue(); - - auto key_mut = unsafe_yyjson_mut_strncpy(doc, key.c_str(), key.size()); - auto value_mut = unsafe_yyjson_mut_strncpy(doc, value.c_str(), value.size()); - - auto splits = StringUtil::Split(value_mut, "\n"); - if (splits.size() > 1) { - auto list_items = yyjson_mut_arr(doc); - for (auto &split : splits) { - yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str()); - } - yyjson_mut_obj_add_val(doc, extra_info_obj, key_mut, list_items); - } else { - yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key_mut, value_mut); - } - } - yyjson_mut_obj_add_val(doc, dest, key_ptr, extra_info_obj); - continue; - } - - // The metric cannot be NULL, and should have been 0 initialized. 
- D_ASSERT(!metrics[metric].IsNull()); - - if (MetricsUtils::IsOptimizerMetric(metric) || MetricsUtils::IsPhaseTimingMetric(metric)) { - yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[metric].GetValue()); - continue; - } - - ProfilingUtils::MetricToJson(doc, dest, key_ptr, metrics, metric); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/main/profiling_utils.cpp b/src/duckdb/src/main/profiling_utils.cpp index 341ce3c2b..39aeafa18 100644 --- a/src/duckdb/src/main/profiling_utils.cpp +++ b/src/duckdb/src/main/profiling_utils.cpp @@ -1,291 +1,28 @@ -// This file is automatically generated by scripts/generate_metric_enums.py -// Do not edit this file manually, your changes will be overwritten - #include "duckdb/main/profiling_utils.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/main/profiling_node.hpp" -#include "duckdb/main/query_profiler.hpp" - -#include "yyjson.hpp" - -using namespace duckdb_yyjson; // NOLINT +#include "duckdb/common/string_util.hpp" namespace duckdb { -static string OperatorToString(const Value &val) { - const auto type = static_cast(val.GetValue()); - return EnumUtil::ToString(type); -} - -template -static void AggregateMetric(ProfilingNode &node, MetricType aggregated_metric, MetricType child_metric, const std::function &update_fun) { - auto &info = node.GetProfilingInfo(); - info.metrics[aggregated_metric] = info.metrics[child_metric]; - - for (idx_t i = 0; i < node.GetChildCount(); i++) { - auto child = node.GetChild(i); - AggregateMetric(*child, aggregated_metric, child_metric, update_fun); - - auto &child_info = child->GetProfilingInfo(); - auto value = child_info.GetMetricValue(aggregated_metric); - info.MetricUpdate(aggregated_metric, value, update_fun); - } -} - -template -static void GetCumulativeMetric(ProfilingNode &node, MetricType cumulative_metric, MetricType child_metric) { - AggregateMetric( - node, cumulative_metric, child_metric, - [](const METRIC_TYPE &old_value, const METRIC_TYPE &new_value) { 
return old_value + new_value; }); -} - -static Value GetCumulativeOptimizers(ProfilingNode &node) { - auto &metrics = node.GetProfilingInfo().metrics; - double count = 0; - for (auto &metric : metrics) { - if (MetricsUtils::IsOptimizerMetric(metric.first)) { - count += metric.second.GetValue(); - } +void QueryMetrics::FinalizeMetrics(GatheredMetrics &info) { + info.SetMetric(query_sql); + for (const auto &[key, ns] : string_timings) { + info.SetMetric(key, static_cast(ns) / 1e9); } - return Value::CreateValue(count); -} - -void ProfilingUtils::SetMetricToDefault(profiler_metrics_t &metrics, const MetricType &type) { - switch(type) { - case MetricType::ALL_OPTIMIZERS: - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - case MetricType::BLOCKED_THREAD_TIME: - case MetricType::CHECKPOINT_LATENCY: - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - case MetricType::CPU_TIME: - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricType::LATENCY: - case MetricType::OPERATOR_TIMING: - case MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER: - case MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - case MetricType::OPTIMIZER_COLUMN_LIFETIME: - case MetricType::OPTIMIZER_COMMON_AGGREGATE: - case MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - case MetricType::OPTIMIZER_COMMON_SUBPLAN: - case MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - case MetricType::OPTIMIZER_CTE_FILTER_PUSHER: - case MetricType::OPTIMIZER_CTE_INLINING: - case MetricType::OPTIMIZER_DELIMINATOR: - case MetricType::OPTIMIZER_DUPLICATE_GROUPS: - case MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP: - case MetricType::OPTIMIZER_EXPRESSION_REWRITER: - case MetricType::OPTIMIZER_EXTENSION: - case MetricType::OPTIMIZER_FILTER_PULLUP: - case MetricType::OPTIMIZER_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_IN_CLAUSE: - case MetricType::OPTIMIZER_JOIN_ELIMINATION: - case MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_JOIN_ORDER: - case 
MetricType::OPTIMIZER_LATE_MATERIALIZATION: - case MetricType::OPTIMIZER_LIMIT_PUSHDOWN: - case MetricType::OPTIMIZER_MATERIALIZED_CTE: - case MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION: - case MetricType::OPTIMIZER_PARTITIONED_EXECUTION: - case MetricType::OPTIMIZER_PROJECTION_PULLUP: - case MetricType::OPTIMIZER_REGEX_RANGE: - case MetricType::OPTIMIZER_REORDER_FILTER: - case MetricType::OPTIMIZER_ROW_GROUP_PRUNER: - case MetricType::OPTIMIZER_ROW_NUMBER_REWRITER: - case MetricType::OPTIMIZER_SAMPLING_PUSHDOWN: - case MetricType::OPTIMIZER_STATISTICS_PROPAGATION: - case MetricType::OPTIMIZER_TOP_N: - case MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: - case MetricType::OPTIMIZER_UNNEST_REWRITER: - case MetricType::OPTIMIZER_UNUSED_COLUMNS: - case MetricType::OPTIMIZER_WINDOW_SELF_JOIN: - case MetricType::PARSER: - case MetricType::PHYSICAL_PLANNER: - case MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: - case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricType::PLANNER: - case MetricType::PLANNER_BINDING: - case MetricType::WAITING_TO_ATTACH_LATENCY: - case MetricType::WRITE_TO_WAL_LATENCY: - metrics[type] = Value::CreateValue(0.0); - break; - case MetricType::CUMULATIVE_CARDINALITY: - case MetricType::CUMULATIVE_ROWS_SCANNED: - case MetricType::OPERATOR_CARDINALITY: - case MetricType::OPERATOR_ROWS_SCANNED: - case MetricType::RESULT_SET_SIZE: - case MetricType::ROWS_RETURNED: - case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricType::TOTAL_BYTES_READ: - case MetricType::TOTAL_BYTES_WRITTEN: - case MetricType::TOTAL_MEMORY_ALLOCATED: - case MetricType::WAL_REPLAY_ENTRY_COUNT: - metrics[type] = Value::CreateValue(0); - break; - case MetricType::EXTRA_INFO: - metrics[type] = Value::MAP(InsertionOrderPreservingMap()); - break; - case MetricType::OPERATOR_NAME: - case MetricType::QUERY_NAME: - metrics[type] = Value::CreateValue(""); - break; - case 
MetricType::OPERATOR_TYPE: - metrics[type] = Value::CreateValue(0); - break; - default: - throw InternalException("Unknown metric type %s", EnumUtil::ToString(type)); + for (const auto &[key, count] : string_counters) { + info.SetMetric(key, count); } + info.SetMetric(GetBytesRead()); + info.SetMetric(GetBytesWritten()); + info.SetMetric(blocked_thread_time); + info.SetMetric(system_peak_buffer_memory); + info.SetMetric(system_peak_temp_dir_size); + info.SetMetric(GetTotalMemoryAllocated()); } -void ProfilingUtils::MetricToJson(duckdb_yyjson::yyjson_mut_doc *doc, duckdb_yyjson::yyjson_mut_val *dest, const char *key_ptr, profiler_metrics_t &metrics, const MetricType &type) { - switch(type) { - case MetricType::ALL_OPTIMIZERS: - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - case MetricType::BLOCKED_THREAD_TIME: - case MetricType::CHECKPOINT_LATENCY: - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - case MetricType::CPU_TIME: - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricType::LATENCY: - case MetricType::OPERATOR_TIMING: - case MetricType::OPTIMIZER_AGGREGATE_FUNCTION_REWRITER: - case MetricType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - case MetricType::OPTIMIZER_COLUMN_LIFETIME: - case MetricType::OPTIMIZER_COMMON_AGGREGATE: - case MetricType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - case MetricType::OPTIMIZER_COMMON_SUBPLAN: - case MetricType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - case MetricType::OPTIMIZER_CTE_FILTER_PUSHER: - case MetricType::OPTIMIZER_CTE_INLINING: - case MetricType::OPTIMIZER_DELIMINATOR: - case MetricType::OPTIMIZER_DUPLICATE_GROUPS: - case MetricType::OPTIMIZER_EMPTY_RESULT_PULLUP: - case MetricType::OPTIMIZER_EXPRESSION_REWRITER: - case MetricType::OPTIMIZER_EXTENSION: - case MetricType::OPTIMIZER_FILTER_PULLUP: - case MetricType::OPTIMIZER_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_IN_CLAUSE: - case MetricType::OPTIMIZER_JOIN_ELIMINATION: - case 
MetricType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - case MetricType::OPTIMIZER_JOIN_ORDER: - case MetricType::OPTIMIZER_LATE_MATERIALIZATION: - case MetricType::OPTIMIZER_LIMIT_PUSHDOWN: - case MetricType::OPTIMIZER_MATERIALIZED_CTE: - case MetricType::OPTIMIZER_OUTER_JOIN_SIMPLIFICATION: - case MetricType::OPTIMIZER_PARTITIONED_EXECUTION: - case MetricType::OPTIMIZER_PROJECTION_PULLUP: - case MetricType::OPTIMIZER_REGEX_RANGE: - case MetricType::OPTIMIZER_REORDER_FILTER: - case MetricType::OPTIMIZER_ROW_GROUP_PRUNER: - case MetricType::OPTIMIZER_ROW_NUMBER_REWRITER: - case MetricType::OPTIMIZER_SAMPLING_PUSHDOWN: - case MetricType::OPTIMIZER_STATISTICS_PROPAGATION: - case MetricType::OPTIMIZER_TOP_N: - case MetricType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: - case MetricType::OPTIMIZER_UNNEST_REWRITER: - case MetricType::OPTIMIZER_UNUSED_COLUMNS: - case MetricType::OPTIMIZER_WINDOW_SELF_JOIN: - case MetricType::PARSER: - case MetricType::PHYSICAL_PLANNER: - case MetricType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricType::PHYSICAL_PLANNER_CREATE_PLAN: - case MetricType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricType::PLANNER: - case MetricType::PLANNER_BINDING: - case MetricType::WAITING_TO_ATTACH_LATENCY: - case MetricType::WRITE_TO_WAL_LATENCY: - yyjson_mut_obj_add_real(doc, dest, key_ptr, metrics[type].GetValue()); - break; - case MetricType::CUMULATIVE_CARDINALITY: - case MetricType::CUMULATIVE_ROWS_SCANNED: - case MetricType::OPERATOR_CARDINALITY: - case MetricType::OPERATOR_ROWS_SCANNED: - case MetricType::RESULT_SET_SIZE: - case MetricType::ROWS_RETURNED: - case MetricType::SYSTEM_PEAK_BUFFER_MEMORY: - case MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE: - case MetricType::TOTAL_BYTES_READ: - case MetricType::TOTAL_BYTES_WRITTEN: - case MetricType::TOTAL_MEMORY_ALLOCATED: - case MetricType::WAL_REPLAY_ENTRY_COUNT: - yyjson_mut_obj_add_uint(doc, dest, key_ptr, metrics[type].GetValue()); - break; - case MetricType::EXTRA_INFO: - break; - case MetricType::OPERATOR_NAME: - 
case MetricType::QUERY_NAME: - yyjson_mut_obj_add_strcpy(doc, dest, key_ptr, metrics[type].GetValue().c_str()); - break; - case MetricType::OPERATOR_TYPE: - yyjson_mut_obj_add_strcpy(doc, dest, key_ptr, OperatorToString(metrics[type]).c_str()); - break; - default: - throw InternalException("Unknown metric type %s", EnumUtil::ToString(type)); - } +QueryMetrics::QueryMetrics() : bytes_read(0), bytes_written(0), total_memory_allocated(0) { + Reset(); } -void ProfilingUtils::CollectMetrics(const MetricType &type, QueryMetrics &query_metrics, Value &metric, ProfilingNode &node, ProfilingInfo &child_info) { - switch(type) { - case MetricType::CPU_TIME: - GetCumulativeMetric(node, MetricType::CPU_TIME, MetricType::OPERATOR_TIMING); - break; - case MetricType::CUMULATIVE_CARDINALITY: - GetCumulativeMetric(node, MetricType::CUMULATIVE_CARDINALITY, MetricType::OPERATOR_CARDINALITY); - break; - case MetricType::CUMULATIVE_ROWS_SCANNED: - GetCumulativeMetric(node, MetricType::CUMULATIVE_ROWS_SCANNED, MetricType::OPERATOR_ROWS_SCANNED); - break; - case MetricType::ATTACH_LOAD_STORAGE_LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::ATTACH_LOAD_STORAGE_LATENCY)); - break; - case MetricType::ATTACH_REPLAY_WAL_LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::ATTACH_REPLAY_WAL_LATENCY)); - break; - case MetricType::CHECKPOINT_LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::CHECKPOINT_LATENCY)); - break; - case MetricType::COMMIT_LOCAL_STORAGE_LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::COMMIT_LOCAL_STORAGE_LATENCY)); - break; - case MetricType::LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::LATENCY)); - break; - case MetricType::WAITING_TO_ATTACH_LATENCY: - metric = Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::WAITING_TO_ATTACH_LATENCY)); - break; - case MetricType::WRITE_TO_WAL_LATENCY: - metric = 
Value::DOUBLE(query_metrics.GetMetricInSeconds(MetricType::WRITE_TO_WAL_LATENCY)); - break; - case MetricType::QUERY_NAME: - metric = query_metrics.query_name; - break; - case MetricType::TOTAL_BYTES_READ: - metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_READ)); - break; - case MetricType::TOTAL_BYTES_WRITTEN: - metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_WRITTEN)); - break; - case MetricType::TOTAL_MEMORY_ALLOCATED: - metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::TOTAL_MEMORY_ALLOCATED)); - break; - case MetricType::WAL_REPLAY_ENTRY_COUNT: - metric = Value::UBIGINT(query_metrics.GetMetricValue(MetricType::WAL_REPLAY_ENTRY_COUNT)); - break; - case MetricType::RESULT_SET_SIZE: - metric = child_info.metrics[MetricType::RESULT_SET_SIZE]; - break; - case MetricType::ROWS_RETURNED: - metric = child_info.metrics[MetricType::OPERATOR_CARDINALITY]; - break; - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - metric = GetCumulativeOptimizers(node); - break; - default: - return; - } -} +QueryMetrics::~QueryMetrics() = default; -} +} // namespace duckdb diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp index 78513b218..9a9cd5329 100644 --- a/src/duckdb/src/main/query_profiler.cpp +++ b/src/duckdb/src/main/query_profiler.cpp @@ -13,7 +13,8 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/main/profiling_utils.hpp" -#include "duckdb/main/profiling_info.hpp" +#include "duckdb/main/gathered_metrics.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "yyjson.hpp" #include "yyjson_utils.hpp" @@ -24,6 +25,52 @@ using namespace duckdb_yyjson; // NOLINT namespace duckdb { +void QueryProfileResult::AddValue(const string &k, Value val) { + D_ASSERT(kind == QueryProfileResultKind::OBJECT); + auto child = make_uniq(); + child->kind = QueryProfileResultKind::VALUE; + 
child->key = k; + child->value = std::move(val); + children.push_back(std::move(child)); +} + +QueryProfileResult &QueryProfileResult::AddObject(const string &k) { + D_ASSERT(kind == QueryProfileResultKind::OBJECT); + auto child = make_uniq(); + child->kind = QueryProfileResultKind::OBJECT; + child->key = k; + auto &ref = *child; + children.push_back(std::move(child)); + return ref; +} + +QueryProfileResult &QueryProfileResult::AddList(const string &k) { + D_ASSERT(kind == QueryProfileResultKind::OBJECT); + auto child = make_uniq(); + child->kind = QueryProfileResultKind::LIST; + child->key = k; + auto &ref = *child; + children.push_back(std::move(child)); + return ref; +} + +QueryProfileResult &QueryProfileResult::AppendObject() { + D_ASSERT(kind == QueryProfileResultKind::LIST); + auto child = make_uniq(); + child->kind = QueryProfileResultKind::OBJECT; + auto &ref = *child; + children.push_back(std::move(child)); + return ref; +} + +QueryProfileResult &QueryProfileResult::AppendList() { + auto child = make_uniq(); + child->kind = QueryProfileResultKind::LIST; + auto &ref = *child; + children.push_back(std::move(child)); + return ref; +} + QueryProfiler::QueryProfiler(ClientContext &context_p) : context(context_p), running(false), query_requires_profiling(false), is_explain_analyze(false), metrics_finalized(false) { @@ -83,7 +130,20 @@ ExplainFormat QueryProfiler::GetExplainFormat(ProfilerPrintFormat format) const } bool QueryProfiler::PrintOptimizerOutput() const { - return GetPrintFormat() == ProfilerPrintFormat::QUERY_TREE_OPTIMIZER || IsDetailedEnabled(); + if (GetPrintFormat() == ProfilerPrintFormat::QUERY_TREE_OPTIMIZER || IsDetailedEnabled()) { + return true; + } + if (metrics) { + return metrics->MetricIsTracked("optimizer.join_order"); + } + // Fall back to checking tracked_metrics patterns directly + auto &config = ClientConfig::GetConfig(context); + for (const auto &pattern : config.tracked_metrics) { + if (pattern == "*" || 
StringUtil::StartsWith(pattern, "optimizer")) { + return true; + } + } + return false; } string QueryProfiler::GetSaveLocation() const { @@ -97,25 +157,25 @@ QueryProfiler &QueryProfiler::Get(ClientContext &context) { void QueryProfiler::Start(const string &query) { Reset(); running = true; - query_metrics.query_name = query; - query_metrics.latency_timer = make_uniq(StartTimer(MetricType::LATENCY)); + query_metrics.query_sql = query; + query_metrics.latency_timer = make_uniq(StartTimer()); } void QueryProfiler::Reset() { tree_map.clear(); root = nullptr; - phase_timings.clear(); - phase_stack.clear(); + metrics.reset(); running = false; query_metrics.Reset(); + result_tree.reset(); metrics_finalized = false; } void QueryProfiler::StartQuery(const string &query, bool is_explain_analyze_p, bool start_at_optimizer) { lock_guard guard(lock); // Always reset byte counters at the start of each query so the progress bar shows per-query values - query_metrics.ResetMetric(MetricType::TOTAL_BYTES_READ); - query_metrics.ResetMetric(MetricType::TOTAL_BYTES_WRITTEN); + query_metrics.bytes_read = 0; + query_metrics.bytes_written = 0; if (is_explain_analyze_p) { StartExplainAnalyze(); } @@ -178,22 +238,6 @@ bool QueryProfiler::OperatorRequiresProfiling(const PhysicalOperatorType op_type } } -void QueryProfiler::Finalize(ProfilingNode &node) { - for (idx_t i = 0; i < node.GetChildCount(); i++) { - auto child = node.GetChild(i); - Finalize(*child); - - auto &info = node.GetProfilingInfo(); - auto type = PhysicalOperatorType(info.GetMetricValue(MetricType::OPERATOR_TYPE)); - if (type == PhysicalOperatorType::UNION && - info.Enabled(info.expanded_settings, MetricType::OPERATOR_CARDINALITY)) { - auto &child_info = child->GetProfilingInfo(); - auto value = child_info.metrics[MetricType::OPERATOR_CARDINALITY].GetValue(); - info.MetricSum(MetricType::OPERATOR_CARDINALITY, value); - } - } -} - void QueryProfiler::StartExplainAnalyze() { is_explain_analyze = true; } @@ -216,10 +260,10 @@ 
void QueryProfiler::EndQuery() { is_explain_analyze = false; - guard.unlock(); - // To log is inexpensive, whether to log or not depends on whether logging is active - ToLog(); + ToLogInternal(); + + guard.unlock(); if (emit_output) { string tree = ToString(); @@ -239,27 +283,48 @@ void QueryProfiler::FinalizeMetrics() { FinalizeMetricsInternal(); } -void QueryProfiler::AddToCounter(const MetricType type, const idx_t amount) { - // Always track bytes read/written so the progress bar can display them - if (type == MetricType::TOTAL_BYTES_READ || type == MetricType::TOTAL_BYTES_WRITTEN) { - query_metrics.UpdateMetric(type, amount); +void QueryProfiler::TrackBytesRead(const idx_t amount) { + query_metrics.UpdateBytesRead(amount); +} + +void QueryProfiler::TrackBytesWritten(const idx_t amount) { + query_metrics.UpdateBytesWritten(amount); +} + +void QueryProfiler::TrackTotalMemoryAllocated(const idx_t amount) { + query_metrics.UpdateTotalMemoryAllocated(amount); +} + +void QueryProfiler::AddToMetricCounter(const string &key, const idx_t amount) { + if (IsEnabled()) { + query_metrics.UpdateMetricCounter(key, amount); + } +} + +void QueryProfiler::SetMetric(const string &key, Value new_value) { + if (!IsEnabled()) { return; } - if (IsEnabled()) { - query_metrics.UpdateMetric(type, amount); + metrics->SetMetric(key, std::move(new_value)); +} + +bool QueryProfiler::MetricIsTracked(const string &key) const { + if (!IsEnabled()) { + return false; } + return metrics->MetricIsTracked(key); } idx_t QueryProfiler::GetBytesRead() const { - return query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_READ); + return query_metrics.GetBytesRead(); } idx_t QueryProfiler::GetBytesWritten() const { - return query_metrics.GetMetricValue(MetricType::TOTAL_BYTES_WRITTEN); + return query_metrics.GetBytesWritten(); } -ActiveTimer QueryProfiler::StartTimer(const MetricType type) { - return ActiveTimer(query_metrics, type, IsEnabled()); +MetricsTimer QueryProfiler::StartTimerInternal(const 
string &key) { + return MetricsTimer(query_metrics, key, IsEnabled()); } string QueryProfiler::ToString(ExplainFormat explain_format) const { @@ -284,7 +349,7 @@ string QueryProfiler::ToString(ProfilerPrintFormat format) const { lock_guard guard(lock); // checking the tree to ensure the query is really empty // the query string is empty when a logical plan is deserialized - if (query_metrics.query_name.empty() || !root) { + if (query_metrics.query_sql.empty() || !root) { return ""; } auto renderer = TreeRenderer::CreateRenderer(GetExplainFormat(format)); @@ -297,80 +362,68 @@ string QueryProfiler::ToString(ProfilerPrintFormat format) const { } } -void QueryProfiler::StartPhase(MetricType phase_metric) { - lock_guard guard(lock); - if (!IsEnabled() || !running) { - return; - } - - // start a new phase - phase_stack.push_back(phase_metric); - // restart the timer - phase_profiler.Start(); +OperatorProfiler::OperatorProfiler(ClientContext &context) : context(context) { + enabled = QueryProfiler::Get(context).IsEnabled(); } -void QueryProfiler::EndPhase() { - lock_guard guard(lock); - if (!IsEnabled() || !running) { +void OperatorProfiler::StartOperator(optional_ptr phys_op) { + if (!enabled) { return; } - D_ASSERT(!phase_stack.empty()); - - // end the timer - phase_profiler.End(); - // add the timing to all currently active phases - for (auto &phase : phase_stack) { - phase_timings[phase] += phase_profiler.Elapsed(); + if (active_operator) { + throw InternalException("OperatorProfiler: Attempting to call StartOperator while another operator is active"); } - // now remove the last added phase - phase_stack.pop_back(); + active_operator = phys_op; - if (!phase_stack.empty()) { - phase_profiler.Start(); + if (!OperatorMetricsIsInitialized(*active_operator)) { + // first time calling into this operator - fetch the info + auto &info = GetOperatorMetrics(*active_operator); + info.SetExtraInfo(active_operator->ParamsToString()); } -} 
-OperatorProfiler::OperatorProfiler(ClientContext &context) : context(context) { - enabled = QueryProfiler::Get(context).IsEnabled(); - auto &context_metrics = ClientConfig::GetConfig(context).profiler_settings; + // Start the timing of the current operator. + op.Start(); +} - // Expand. - for (const auto metric : context_metrics) { - settings.insert(metric); - ProfilingInfo::Expand(settings, metric); +void OperatorMetrics::GatherMetrics(ClientContext &context, double elapsed_time, optional_ptr chunk) { + time += elapsed_time; + if (chunk) { + elements_returned += chunk->size(); + intermediate_size_bytes += LossyNumericCast(chunk->GetDataSize()); } - - // Reduce. - auto root_metrics = MetricsUtils::GetRootScopeMetrics(); - for (const auto metric : root_metrics) { - settings.erase(metric); + auto &buffer_manager = BufferManager::GetBufferManager(context); + auto used_memory = buffer_manager.GetBufferPool().GetUsedMemory(false); + if (used_memory > system_peak_buffer_manager_memory) { + system_peak_buffer_manager_memory = used_memory; + } + auto used_swap = buffer_manager.GetUsedSwap(); + if (used_swap > system_peak_temp_directory_size) { + system_peak_temp_directory_size = used_swap; } } -void OperatorProfiler::StartOperator(optional_ptr phys_op) { - if (!enabled) { - return; +void OperatorMetrics::MergeInternal(const OperatorMetrics &other) { + time += other.time; + elements_returned += other.elements_returned; + intermediate_size_bytes += other.intermediate_size_bytes; + rows_scanned += other.rows_scanned; + row_groups_scanned += other.row_groups_scanned; + if (other.system_peak_buffer_manager_memory > system_peak_buffer_manager_memory) { + system_peak_buffer_manager_memory = other.system_peak_buffer_manager_memory; } - if (active_operator) { - throw InternalException("OperatorProfiler: Attempting to call StartOperator while another operator is active"); + if (other.system_peak_temp_directory_size > system_peak_temp_directory_size) { + 
system_peak_temp_directory_size = other.system_peak_temp_directory_size; } - active_operator = phys_op; +} - if (!settings.empty()) { - if (ProfilingInfo::Enabled(settings, MetricType::EXTRA_INFO)) { - if (!OperatorInfoIsInitialized(*active_operator)) { - // first time calling into this operator - fetch the info - auto &info = GetOperatorInfo(*active_operator); - auto params = active_operator->ParamsToString(); - info.extra_info = params; - } - } +void OperatorMetrics::Accumulate(const OperatorMetrics &other) { + MergeInternal(other); + total_row_groups_to_scan += other.total_row_groups_to_scan; +} - // Start the timing of the current operator. - if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_TIMING)) { - op.Start(); - } - } +void OperatorMetrics::Merge(const OperatorMetrics &other) { + MergeInternal(other); + total_row_groups_to_scan = MaxValue(total_row_groups_to_scan, other.total_row_groups_to_scan); } void OperatorProfiler::EndOperator(optional_ptr chunk) { @@ -381,28 +434,9 @@ void OperatorProfiler::EndOperator(optional_ptr chunk) { throw InternalException("OperatorProfiler: Attempting to call EndOperator while no operator is active"); } - if (!settings.empty()) { - auto &info = GetOperatorInfo(*active_operator); - if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_TIMING)) { - op.End(); - info.AddMetric(MetricType::OPERATOR_TIMING, op.Elapsed()); - } - if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_CARDINALITY) && chunk) { - info.AddMetric(MetricType::OPERATOR_CARDINALITY, chunk->size()); - } - if (ProfilingInfo::Enabled(settings, MetricType::RESULT_SET_SIZE) && chunk) { - auto result_set_size = chunk->GetDataSize(); - info.AddMetric(MetricType::RESULT_SET_SIZE, result_set_size); - } - if (ProfilingInfo::Enabled(settings, MetricType::SYSTEM_PEAK_BUFFER_MEMORY)) { - auto used_memory = BufferManager::GetBufferManager(context).GetBufferPool().GetUsedMemory(false); - info.AddMetric(MetricType::SYSTEM_PEAK_BUFFER_MEMORY, used_memory); 
- } - if (ProfilingInfo::Enabled(settings, MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { - auto used_swap = BufferManager::GetBufferManager(context).GetUsedSwap(); - info.AddMetric(MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, used_swap); - } - } + auto &info = GetOperatorMetrics(*active_operator); + op.End(); + info.GatherMetrics(context, op.Elapsed(), chunk); active_operator = nullptr; } @@ -413,68 +447,44 @@ void OperatorProfiler::FinishSource(GlobalSourceState &gstate, LocalSourceState if (!active_operator) { throw InternalException("OperatorProfiler: Attempting to call FinishSource while no operator is active"); } - if (!settings.empty()) { - if (ProfilingInfo::Enabled(settings, MetricType::EXTRA_INFO)) { - // we're emitting extra info - get the extra source info - auto &info = GetOperatorInfo(*active_operator); - auto extra_info = active_operator->ExtraSourceParams(gstate, lstate); - for (auto &new_info : extra_info) { - auto entry = info.extra_info.find(new_info.first); - if (entry != info.extra_info.end()) { - // entry exists - override - entry->second = std::move(new_info.second); - } else { - // entry does not exist yet - insert - info.extra_info.insert(std::move(new_info)); - } - } - } - if (ProfilingInfo::Enabled(settings, MetricType::OPERATOR_ROWS_SCANNED) && - active_operator.get()->type == PhysicalOperatorType::TABLE_SCAN) { - const auto &table_scan = active_operator->Cast(); - const auto rows_scanned = table_scan.GetRowsScanned(gstate, lstate); - auto &info = GetOperatorInfo(*active_operator); - if (rows_scanned.IsValid()) { - // Use exact value if available. - info.AddMetric(MetricType::OPERATOR_ROWS_SCANNED, rows_scanned.GetIndex()); - } else { - // Otherwise estimate as the cardinality of the table scan, if there is no exact value available. 
- auto &bind_data = table_scan.bind_data; - if (bind_data && table_scan.function.cardinality) { - auto cardinality = table_scan.function.cardinality(context, &(*bind_data)); - if (cardinality && cardinality->has_estimated_cardinality) { - info.AddMetric(MetricType::OPERATOR_ROWS_SCANNED, cardinality->estimated_cardinality); - } - } - } - } + FinishSource(*active_operator, gstate, lstate); +} + +void OperatorProfiler::FinishSource(const PhysicalOperator &phys_op, GlobalSourceState &gstate, + LocalSourceState &lstate) { + if (phys_op.type == PhysicalOperatorType::TABLE_SCAN) { + const auto &table_scan = phys_op.Cast(); + auto &scan_metrics = GetOperatorMetrics(phys_op); + table_scan.GetMetrics(context, gstate, lstate, scan_metrics); } } -bool OperatorProfiler::OperatorInfoIsInitialized(const PhysicalOperator &phys_op) { - auto entry = operator_infos.find(phys_op); - return entry != operator_infos.end(); +bool OperatorProfiler::OperatorMetricsIsInitialized(const PhysicalOperator &phys_op) { + auto entry = operator_metrics.find(phys_op); + return entry != operator_metrics.end(); } -OperatorInformation &OperatorProfiler::GetOperatorInfo(const PhysicalOperator &phys_op) { - auto entry = operator_infos.find(phys_op); - if (entry != operator_infos.end()) { +OperatorMetrics &OperatorProfiler::GetOperatorMetrics(const PhysicalOperator &phys_op) { + auto entry = operator_metrics.find(phys_op); + if (entry != operator_metrics.end()) { return entry->second; } // Add a new entry. 
- operator_infos[phys_op] = OperatorInformation(); - return operator_infos[phys_op]; + operator_metrics[phys_op] = OperatorMetrics(); + return operator_metrics[phys_op]; } void OperatorProfiler::Flush(const PhysicalOperator &phys_op) { - auto entry = operator_infos.find(phys_op); - if (entry == operator_infos.end()) { + auto entry = operator_metrics.find(phys_op); + if (entry == operator_metrics.end()) { return; } auto &info = entry->second; - info.name = phys_op.GetName(); + if (info.name.empty()) { + info.name = EnumUtil::ToString(phys_op.type); + } } void QueryProfiler::Flush(OperatorProfiler &profiler) { @@ -482,39 +492,28 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { if (!IsEnabled() || !running) { return; } - for (auto &node : profiler.operator_infos) { + for (auto &node : profiler.operator_metrics) { auto &op = node.first.get(); auto entry = tree_map.find(op); D_ASSERT(entry != tree_map.end()); auto &tree_node = entry->second.get(); - auto &info = tree_node.GetProfilingInfo(); - - if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_TIMING)) { - info.MetricSum(MetricType::OPERATOR_TIMING, node.second.time); + auto &info = tree_node.GetOperatorMetrics(); + info.Merge(node.second); + // Update extra_info from the per-thread metrics: these are set during execution (StartOperator), + // so they capture runtime values like dynamic filters that aren't known at plan-creation time. 
+ if (!node.second.GetExtraInfo().empty()) { + info.SetExtraInfo(node.second.GetExtraInfo()); } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_CARDINALITY)) { - info.MetricSum(MetricType::OPERATOR_CARDINALITY, node.second.elements_returned); - } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::OPERATOR_ROWS_SCANNED)) { - info.MetricSum(MetricType::OPERATOR_ROWS_SCANNED, node.second.rows_scanned); - } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::RESULT_SET_SIZE)) { - info.MetricSum(MetricType::RESULT_SET_SIZE, node.second.result_set_size); - } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::EXTRA_INFO)) { - info.metrics[MetricType::EXTRA_INFO] = Value::MAP(node.second.extra_info); - } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::SYSTEM_PEAK_BUFFER_MEMORY)) { - query_metrics.query_global_info.MetricMax(MetricType::SYSTEM_PEAK_BUFFER_MEMORY, - node.second.system_peak_buffer_manager_memory); + + if (node.second.system_peak_buffer_manager_memory > query_metrics.system_peak_buffer_memory) { + query_metrics.system_peak_buffer_memory = node.second.system_peak_buffer_manager_memory; } - if (ProfilingInfo::Enabled(profiler.settings, MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE)) { - query_metrics.query_global_info.MetricMax(MetricType::SYSTEM_PEAK_TEMP_DIR_SIZE, - node.second.system_peak_temp_directory_size); + if (node.second.system_peak_temp_directory_size > query_metrics.system_peak_temp_dir_size) { + query_metrics.system_peak_temp_dir_size = node.second.system_peak_temp_directory_size; } + node.second.ResetMetrics(); } - profiler.operator_infos.clear(); } void QueryProfiler::SetBlockedTime(const double &blocked_thread_time) { @@ -523,10 +522,7 @@ void QueryProfiler::SetBlockedTime(const double &blocked_thread_time) { return; } - auto &info = root->GetProfilingInfo(); - if (info.Enabled(info.expanded_settings, MetricType::BLOCKED_THREAD_TIME)) { - 
query_metrics.query_global_info.metrics[MetricType::BLOCKED_THREAD_TIME] = blocked_thread_time; - } + query_metrics.blocked_thread_time = blocked_thread_time; } string QueryProfiler::DrawPadded(const string &str, idx_t width) { @@ -588,7 +584,7 @@ void RenderPhaseTimings(std::ostream &ss, const pair &head, map< ss << "└────────────────────────────────────────────────┘\n"; } -void PrintPhaseTimingsToStream(std::ostream &ss, const ProfilingInfo &info, idx_t width) { +void PrintPhaseTimingsToStream(std::ostream &ss, const GatheredMetrics &info, idx_t width) { map optimizer_timings; map planner_timings; map parser_timings; @@ -599,44 +595,40 @@ void PrintPhaseTimingsToStream(std::ostream &ss, const ProfilingInfo &info, idx_ pair parser_head; pair physical_planner_head; - for (const auto &entry : info.metrics) { - if (MetricsUtils::IsOptimizerMetric(entry.first)) { - optimizer_timings[EnumUtil::ToString(entry.first).substr(10)] = entry.second.GetValue(); - } else if (MetricsUtils::IsPhaseTimingMetric(entry.first)) { - switch (entry.first) { - case MetricType::CUMULATIVE_OPTIMIZER_TIMING: - continue; - case MetricType::ALL_OPTIMIZERS: - optimizer_head = {"Optimizer", entry.second.GetValue()}; - break; - case MetricType::PHYSICAL_PLANNER: - physical_planner_head = {"Physical Planner", entry.second.GetValue()}; - break; - case MetricType::PLANNER: - planner_head = {"Planner", entry.second.GetValue()}; - break; - case MetricType::PARSER: - parser_head = {"Parser", entry.second.GetValue()}; - break; - default: - break; - } - - auto metric = EnumUtil::ToString(entry.first); - if (StringUtil::StartsWith(metric, "PHYSICAL_PLANNER") && entry.first != MetricType::PHYSICAL_PLANNER) { - physical_planner_timings[metric.substr(17)] = entry.second.GetValue(); - } else if (StringUtil::StartsWith(metric, "PLANNER") && entry.first != MetricType::PLANNER) { - planner_timings[metric.substr(8)] = entry.second.GetValue(); - } else if (StringUtil::StartsWith(metric, "PARSER")) { - 
parser_timings[metric] = entry.second.GetValue(); - } + for (const auto &entry : info.GetMetrics()) { + const auto &metric = entry.first; + // Check specific total_time metrics BEFORE group checks — MetricInGroup would otherwise match these first. + if (MetricsUtils::IsMetric(metric)) { + optimizer_head = {"Optimizer", entry.second.GetValue()}; + } else if (MetricsUtils::IsMetric(metric)) { + physical_planner_head = {"Physical Planner", entry.second.GetValue()}; + } else if (MetricsUtils::IsMetric(metric)) { + planner_head = {"Planner", entry.second.GetValue()}; + } else if (MetricsUtils::IsMetric(metric)) { + parser_head = {"Parser", entry.second.GetValue()}; + } else if (MetricsUtils::MetricInGroup(metric, "optimizer")) { + // "optimizer.expression_rewriter" -> display as "expression_rewriter" + optimizer_timings[metric.substr(10)] = entry.second.GetValue(); + } else if (MetricsUtils::MetricInGroup(metric, "physical_planner")) { + // "physical_planner.column_binding" -> display as "column_binding" + physical_planner_timings[metric.substr(17)] = entry.second.GetValue(); + } else if (MetricsUtils::IsMetric(metric)) { + planner_timings["binding_time"] = entry.second.GetValue(); } } - RenderPhaseTimings(ss, optimizer_head, optimizer_timings, width); - RenderPhaseTimings(ss, physical_planner_head, physical_planner_timings, width); - RenderPhaseTimings(ss, planner_head, planner_timings, width); - RenderPhaseTimings(ss, parser_head, parser_timings, width); + if (!optimizer_head.first.empty()) { + RenderPhaseTimings(ss, optimizer_head, optimizer_timings, width); + } + if (!physical_planner_head.first.empty()) { + RenderPhaseTimings(ss, physical_planner_head, physical_planner_timings, width); + } + if (!planner_head.first.empty()) { + RenderPhaseTimings(ss, planner_head, planner_timings, width); + } + if (!parser_head.first.empty()) { + RenderPhaseTimings(ss, parser_head, parser_timings, width); + } } void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { @@ 
-644,20 +636,19 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { bool show_query_name = false; if (root) { - auto &info = root->GetProfilingInfo(); - auto &settings = info.expanded_settings; - show_query_name = info.Enabled(settings, MetricType::QUERY_NAME); + auto &info = *metrics; + show_query_name = info.MetricIsTracked(); } ss << "┌─────────────────────────────────────┐\n"; ss << "│┌───────────────────────────────────┐│\n"; ss << "││ Query Profiling Information ││\n"; ss << "│└───────────────────────────────────┘│\n"; ss << "└─────────────────────────────────────┘\n"; - ss << (show_query_name ? StringUtil::Replace(query_metrics.query_name, "\n", " ") : "") + "\n"; + ss << (show_query_name ? StringUtil::Replace(query_metrics.query_sql, "\n", " ") : "") + "\n"; // checking the tree to ensure the query is really empty // the query string is empty when a logical plan is deserialized - if (query_metrics.query_name.empty() && !root) { + if (query_metrics.query_sql.empty() && !root) { return; } @@ -668,7 +659,7 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { constexpr idx_t TOTAL_BOX_WIDTH = 50; ss << "┌────────────────────────────────────────────────┐\n"; ss << "│┌──────────────────────────────────────────────┐│\n"; - string total_time = "Total Time: " + RenderTiming(query_metrics.GetMetricInSeconds(MetricType::LATENCY)); + string total_time = "Total Time: " + RenderTiming(query_metrics.GetStringMetricInSeconds("query.total_time")); ss << "││" + DrawPadded(total_time, TOTAL_BOX_WIDTH - 4) + "││\n"; ss << "│└──────────────────────────────────────────────┘│\n"; ss << "└────────────────────────────────────────────────┘\n"; @@ -676,7 +667,7 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { if (root) { // print phase timings if (PrintOptimizerOutput()) { - PrintPhaseTimingsToStream(ss, root->GetProfilingInfo(), TOTAL_BOX_WIDTH); + PrintPhaseTimingsToStream(ss, *metrics, TOTAL_BOX_WIDTH); } Render(*root, ss); } @@ 
-736,24 +727,108 @@ string QueryProfiler::JSONSanitize(const std::string &text) { return result; } -static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) { - auto result_obj = yyjson_mut_obj(doc); - auto &profiling_info = node.GetProfilingInfo(); - - if (profiling_info.Enabled(profiling_info.settings, MetricType::EXTRA_INFO)) { - profiling_info.metrics[MetricType::EXTRA_INFO] = - QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricType::EXTRA_INFO)); +profiler_metrics_t OperatorMetrics::GetMetrics(const GatheredMetrics &info) const { + profiler_metrics_t result; + if (info.MetricIsTracked()) { + result["type"] = Value(EnumUtil::ToString(operator_type)); + } + if (info.MetricIsTracked()) { + result["timing"] = Value::DOUBLE(time); + } + if (info.MetricIsTracked()) { + result["intermediate_rows"] = Value::UBIGINT(elements_returned); + } + if (info.MetricIsTracked()) { + result["intermediate_size_bytes"] = Value::UBIGINT(intermediate_size_bytes); + } + if (info.MetricIsTracked() && operator_type == PhysicalOperatorType::TABLE_SCAN) { + result["rows_scanned"] = Value::UBIGINT(rows_scanned); } + if (info.MetricIsTracked() && operator_type == PhysicalOperatorType::TABLE_SCAN) { + result["row_groups_scanned"] = Value::UBIGINT(row_groups_scanned); + } + if (info.MetricIsTracked() && + operator_type == PhysicalOperatorType::TABLE_SCAN) { + result["total_row_groups_to_scan"] = Value::UBIGINT(total_row_groups_to_scan); + } + if (info.MetricIsTracked()) { + result["extra_info"] = QueryProfiler::JSONSanitize(Value::MAP(extra_info)); + } + return result; +} - profiling_info.WriteMetricsToJSON(doc, result_obj); +static yyjson_mut_val *ValueToJSON(yyjson_mut_doc *doc, const Value &val) { + if (val.IsNull()) { + return yyjson_mut_null(doc); + } + auto &type = val.type(); + if (type.id() == LogicalTypeId::MAP) { + // MAP values (e.g. 
extra_info) become JSON objects; multiline string values become arrays + auto obj = yyjson_mut_obj(doc); + for (auto &child : MapValue::GetChildren(val)) { + auto kv = StructValue::GetChildren(child); + auto k = kv[0].GetValue(); + auto v = kv[1].GetValue(); + auto key_ptr = yyjson_mut_get_str(yyjson_mut_strcpy(doc, k.c_str())); + auto splits = StringUtil::Split(v, "\n"); + if (splits.size() > 1) { + auto arr = yyjson_mut_arr(doc); + for (auto &s : splits) { + yyjson_mut_arr_add_strcpy(doc, arr, s.c_str()); + } + yyjson_mut_obj_add_val(doc, obj, key_ptr, arr); + } else { + yyjson_mut_obj_add_strcpy(doc, obj, key_ptr, v.c_str()); + } + } + return obj; + } + if (type.IsIntegral()) { + return yyjson_mut_uint(doc, val.GetValue()); + } + if (type.IsNumeric()) { + return yyjson_mut_real(doc, val.GetValue()); + } + auto str = val.GetValue(); + return yyjson_mut_strncpy(doc, str.c_str(), str.size()); +} - auto children_list = yyjson_mut_arr(doc); - for (idx_t i = 0; i < node.GetChildCount(); i++) { - auto child = ToJSONRecursive(doc, *node.GetChild(i)); - yyjson_mut_arr_add_val(children_list, child); +static yyjson_mut_val *QueryProfileResultToJSON(yyjson_mut_doc *doc, const QueryProfileResult &node) { + switch (node.kind) { + case QueryProfileResultKind::VALUE: + return ValueToJSON(doc, node.value); + case QueryProfileResultKind::LIST: { + auto arr = yyjson_mut_arr(doc); + for (auto &child : node.children) { + yyjson_mut_arr_add_val(arr, QueryProfileResultToJSON(doc, *child)); + } + return arr; + } + case QueryProfileResultKind::OBJECT: { + auto obj = yyjson_mut_obj(doc); + // Sort children alphabetically by key for deterministic output + vector> sorted_children; + sorted_children.reserve(node.children.size()); + for (auto &child : node.children) { + sorted_children.push_back(*child); + } + std::sort(sorted_children.begin(), sorted_children.end(), + [](const QueryProfileResult &a, const QueryProfileResult &b) { + if (a.IsNested() != b.IsNested()) { + return !a.IsNested(); 
+ } + return a.key < b.key; + }); + for (const QueryProfileResult &child : sorted_children) { + D_ASSERT(!child.key.empty()); + auto key_ptr = yyjson_mut_get_str(yyjson_mut_strcpy(doc, child.key.c_str())); + yyjson_mut_obj_add_val(doc, obj, key_ptr, QueryProfileResultToJSON(doc, child)); + } + return obj; + } + default: + throw InternalException("Unknown QueryProfileResultKind"); } - yyjson_mut_obj_add_val(doc, result_obj, "children", children_list); - return result_obj; } static string StringifyAndFree(ConvertedJSONHolder &json_holder, yyjson_mut_val *object) { @@ -766,46 +841,185 @@ static string StringifyAndFree(ConvertedJSONHolder &json_holder, yyjson_mut_val return result; } +void QueryProfiler::ToLogInternal() const { + if (!root) { + return; + } + metrics->WriteMetricsToLog(context); +} + void QueryProfiler::ToLog() const { lock_guard guard(lock); + ToLogInternal(); +} - if (!root) { - // No root, not much to do - return; +static void OperatorToResultTree(const GatheredMetrics &settings, ProfilingNode &node, QueryProfileResult &result) { + auto operator_metrics = node.GetOperatorMetrics().GetMetrics(settings); + for (auto &entry : operator_metrics) { + result.AddValue(entry.first, std::move(entry.second)); + } + if (node.GetChildCount() > 0) { + auto &children_list = result.AddList("children"); + for (idx_t i = 0; i < node.GetChildCount(); i++) { + auto &child_result = children_list.AppendObject(); + OperatorToResultTree(settings, *node.GetChild(i), child_result); + } } +} + +struct LegacyCumulative { + double timing = 0; + uint64_t cardinality = 0; + uint64_t rows_scanned = 0; +}; - auto &settings = root->GetProfilingInfo(); +static LegacyCumulative LegacyOperatorToResultTree(const GatheredMetrics &info, ProfilingNode &node, + QueryProfileResult &result) { + auto operator_metrics = node.GetOperatorMetrics().GetMetrics(info); - settings.WriteMetricsToLog(context); + auto emit_as = [&](const string &old_key, const string &new_key) { + auto it = 
operator_metrics.find(old_key); + if (it != operator_metrics.end()) { + result.AddValue(new_key, it->second); + } + }; + + emit_as("type", "operator_type"); + emit_as("timing", "operator_timing"); + emit_as("rows_scanned", "operator_rows_scanned"); + emit_as("intermediate_rows", "operator_cardinality"); + emit_as("intermediate_size_bytes", "result_set_size"); + + auto it_extra = operator_metrics.find("extra_info"); + if (it_extra != operator_metrics.end()) { + result.AddValue("extra_info", it_extra->second); + } + result.AddValue("system_peak_buffer_memory", Value::UBIGINT(0)); + result.AddValue("system_peak_temp_dir_size", Value::UBIGINT(0)); + + LegacyCumulative cumulative; + auto timing_it = operator_metrics.find("timing"); + if (timing_it != operator_metrics.end()) { + cumulative.timing = timing_it->second.GetValue(); + } + auto card_it = operator_metrics.find("intermediate_rows"); + if (card_it != operator_metrics.end()) { + cumulative.cardinality = card_it->second.GetValue(); + } + auto rows_it = operator_metrics.find("rows_scanned"); + if (rows_it != operator_metrics.end()) { + cumulative.rows_scanned = rows_it->second.GetValue(); + } + + if (node.GetChildCount() > 0) { + auto &children_list = result.AddList("children"); + for (idx_t i = 0; i < node.GetChildCount(); i++) { + auto &child_result = children_list.AppendObject(); + auto child_cum = LegacyOperatorToResultTree(info, *node.GetChild(i), child_result); + cumulative.timing += child_cum.timing; + cumulative.cardinality += child_cum.cardinality; + cumulative.rows_scanned += child_cum.rows_scanned; + } + } + + result.AddValue("cpu_time", Value::DOUBLE(cumulative.timing)); + result.AddValue("cumulative_cardinality", Value::UBIGINT(cumulative.cardinality)); + result.AddValue("cumulative_rows_scanned", Value::UBIGINT(cumulative.rows_scanned)); + return cumulative; } -string QueryProfiler::ToJSON() const { - lock_guard guard(lock); - ConvertedJSONHolder json_holder; +unique_ptr 
QueryProfiler::ToLegacyResultTree() const { + auto result = make_uniq(); + if (!root) { + result->AddValue("result", Value(query_metrics.query_sql.empty() ? "empty" : "error")); + return result; + } - json_holder.doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(json_holder.doc); - yyjson_mut_doc_set_root(json_holder.doc, result_obj); + const auto &gathered = metrics->GetMetrics(); - if (query_metrics.query_name.empty() && !root) { - yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "empty"); - return StringifyAndFree(json_holder, result_obj); + auto emit = [&](const string &new_key, const string &old_key) { + auto it = gathered.find(old_key); + if (it != gathered.end()) { + result->AddValue(new_key, it->second); + } + }; + + emit("total_memory_allocated", "system.total_memory_allocated"); + emit("total_bytes_written", "io.total_bytes_written"); + emit("total_bytes_read", "io.total_bytes_read"); + emit("system_peak_temp_dir_size", "system.peak_temp_dir_size"); + emit("system_peak_buffer_memory", "system.peak_buffer_memory"); + + // rows_returned = root operator's elements_returned (rows sent to client) + { + auto root_op_metrics = root->GetOperatorMetrics().GetMetrics(*metrics); + auto it = root_op_metrics.find("intermediate_rows"); + if (it != root_op_metrics.end()) { + result->AddValue("rows_returned", it->second); + } + } + + emit("result_set_size", "query.total_intermediate_size_bytes"); + emit("latency", "query.total_time"); + emit("wal_replay_entry_count", "storage.wal_replay_entry_count"); + result->AddValue("extra_info", Value::MAP(InsertionOrderPreservingMap())); + emit("commit_local_storage_latency", "storage.commit_local_storage_latency"); + emit("attach_load_storage_latency", "storage.attach_load_storage_latency"); + emit("query_name", "query.sql"); + emit("cpu_time", "query.cpu_time"); + emit("checkpoint_latency", "storage.checkpoint_latency"); + emit("cumulative_cardinality", "query.total_intermediate_rows"); + 
emit("waiting_to_attach_latency", "storage.waiting_to_attach_latency"); + emit("write_to_wal_latency", "storage.write_to_wal_latency"); + emit("attach_replay_wal_latency", "storage.attach_replay_wal_latency"); + emit("blocked_thread_time", "system.blocked_thread_time"); + emit("cumulative_rows_scanned", "query.total_rows_scanned"); + emit("total_vacuum_time", "storage.total_vacuum_time"); + + auto &children_list = result->AddList("children"); + auto &root_node = children_list.AppendObject(); + LegacyOperatorToResultTree(*metrics, *root, root_node); + return result; +} + +unique_ptr QueryProfiler::ToResultTree() const { + if (Settings::Get(context)) { + return ToLegacyResultTree(); } + auto result = make_uniq(); if (!root) { - yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "error"); - return StringifyAndFree(json_holder, result_obj); + result->AddValue("result", Value(query_metrics.query_sql.empty() ? "empty" : "error")); + return result; + } + metrics->MetricsToProfileResult(*result); + if (metrics->AnyOperatorMetricTracked()) { + auto &op_list = result->AddList("operator"); + auto &op_node = op_list.AppendObject(); + OperatorToResultTree(*metrics, *root, op_node); } + return result; +} - auto &settings = root->GetProfilingInfo(); +QueryProfileResult &QueryProfiler::GetResult() { + lock_guard guard(lock); + if (!result_tree) { + result_tree = ToResultTree(); + } + return *result_tree; +} - settings.WriteMetricsToJSON(json_holder.doc, result_obj); +bool QueryProfiler::HasRoot() const { + return root != nullptr; +} - // recursively print the physical operator tree - auto children_list = yyjson_mut_arr(json_holder.doc); - yyjson_mut_obj_add_val(json_holder.doc, result_obj, "children", children_list); - auto child = ToJSONRecursive(json_holder.doc, *root->GetChild(0)); - yyjson_mut_arr_add_val(children_list, child); - return StringifyAndFree(json_holder, result_obj); +string QueryProfiler::ToJSON() const { + lock_guard guard(lock); + ConvertedJSONHolder 
json_holder; + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result = ToResultTree(); + auto root_val = QueryProfileResultToJSON(json_holder.doc, *result); + yyjson_mut_doc_set_root(json_holder.doc, root_val); + return StringifyAndFree(json_holder, root_val); } void QueryProfiler::WriteToFile(const char *path, string &info) const { @@ -816,50 +1030,24 @@ void QueryProfiler::WriteToFile(const char *path, string &info) const { file->Close(); } -profiler_settings_t EraseQueryRootSettings(profiler_settings_t settings) { - profiler_settings_t phase_timing_settings_to_erase; - - for (auto &setting : settings) { - if (MetricsUtils::IsOptimizerMetric(setting) || MetricsUtils::IsPhaseTimingMetric(setting) || - MetricsUtils::IsRootScopeMetric(setting)) { - phase_timing_settings_to_erase.insert(setting); - } - } - - for (auto &setting : phase_timing_settings_to_erase) { - settings.erase(setting); - } - - return settings; -} - -unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root_p, const profiler_settings_t &settings, - const idx_t depth) { +unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root_p, const idx_t depth) { if (OperatorRequiresProfiling(root_p.type)) { query_requires_profiling = true; } - unique_ptr node = make_uniq(); - auto &info = node->GetProfilingInfo(); - info = ProfilingInfo(settings, depth); - auto child_settings = settings; - if (depth == 0) { - child_settings = EraseQueryRootSettings(child_settings); - } + auto node = make_uniq(); + auto &info = node->GetOperatorMetrics(); node->depth = depth; - if (depth != 0) { - info.metrics[MetricType::OPERATOR_NAME] = root_p.GetName(); - info.MetricSum(MetricType::OPERATOR_TYPE, static_cast(root_p.type)); - } - if (info.Enabled(info.settings, MetricType::EXTRA_INFO)) { - info.metrics[MetricType::EXTRA_INFO] = Value::MAP(root_p.ParamsToString()); - } + info.name = EnumUtil::ToString(root_p.type); + info.operator_type = root_p.type; + auto params = root_p.ParamsToString(); + 
info.SetExtraInfo(std::move(params)); tree_map.insert(make_pair(reference(root_p), reference(*node))); auto children = root_p.GetChildren(); for (auto &child : children) { - auto child_node = CreateTree(child.get(), child_settings, depth + 1); + auto child_node = CreateTree(child.get(), depth + 1); node->AddChild(std::move(child_node)); } return node; @@ -912,15 +1100,15 @@ void QueryProfiler::Initialize(const PhysicalOperator &root_op) { return; } query_requires_profiling = false; - ClientConfig &config = ClientConfig::GetConfig(context); - root = CreateTree(root_op, config.profiler_settings, 0); + root = CreateTree(root_op, 0); if (!query_requires_profiling) { // query does not require profiling: disable profiling for this query running = false; tree_map.clear(); root = nullptr; - phase_timings.clear(); - phase_stack.clear(); + } else { + auto &client_config = ClientConfig::GetConfig(context); + metrics = make_uniq(client_config.tracked_metrics); } } @@ -938,45 +1126,34 @@ void QueryProfiler::Print() { Printer::Print(QueryTreeToString()); } -void QueryProfiler::MoveOptimizerPhasesToRoot() { - auto &root_info = root->GetProfilingInfo(); - auto &root_metrics = root_info.metrics; - - for (auto &entry : phase_timings) { - auto &phase = entry.first; - auto &timing = entry.second; - if (root_info.Enabled(root_info.expanded_settings, phase)) { - root_metrics[phase] = Value::CreateValue(timing); - } +static void MergeOperatorMeasurements(ProfilingNode &root, OperatorMetrics &result) { + // merge in this layer + result.Accumulate(root.GetOperatorMetrics()); + // recurse into children + for (idx_t i = 0; i < root.GetChildCount(); i++) { + auto child = root.GetChild(i); + MergeOperatorMeasurements(*child, result); } } void QueryProfiler::FinalizeMetricsInternal() { - if (metrics_finalized || !IsEnabled() || !root) { + if (metrics_finalized || !IsEnabled() || !metrics) { return; } - if (query_metrics.latency_timer) { query_metrics.latency_timer->EndTimer(); } - - auto &info 
= root->GetProfilingInfo(); - if (info.Enabled(info.expanded_settings, MetricType::OPERATOR_CARDINALITY)) { - Finalize(*root->GetChild(0)); - } - - auto &child_info = root->children[0]->GetProfilingInfo(); - const auto &settings = info.expanded_settings; - for (const auto &global_info_entry : query_metrics.query_global_info.metrics) { - info.metrics[global_info_entry.first] = global_info_entry.second; - } - - MoveOptimizerPhasesToRoot(); - for (auto &metric : info.metrics) { - if (info.Enabled(settings, metric.first)) { - ProfilingUtils::CollectMetrics(metric.first, query_metrics, metric.second, *root, child_info); - } - } + if (root) { + OperatorMetrics cumulative_metrics; + MergeOperatorMeasurements(*root, cumulative_metrics); + metrics->SetMetric(cumulative_metrics.time); + metrics->SetMetric(cumulative_metrics.elements_returned); + metrics->SetMetric(cumulative_metrics.rows_scanned); + metrics->SetMetric(cumulative_metrics.intermediate_size_bytes); + metrics->SetMetric(cumulative_metrics.row_groups_scanned); + metrics->SetMetric(cumulative_metrics.total_row_groups_to_scan); + } + query_metrics.FinalizeMetrics(*metrics); metrics_finalized = true; } diff --git a/src/duckdb/src/main/query_result.cpp b/src/duckdb/src/main/query_result.cpp index be5b9b2f1..719a69b60 100644 --- a/src/duckdb/src/main/query_result.cpp +++ b/src/duckdb/src/main/query_result.cpp @@ -69,25 +69,27 @@ QueryResult::~QueryResult() { } void QueryResult::DeduplicateColumns(vector &names) { - unordered_map name_map; + auto identifiers = StringsToIdentifiers(names); + DeduplicateColumns(identifiers); + names = IdentifiersToStrings(identifiers); +} + +void QueryResult::DeduplicateColumns(vector &names) { + identifier_map_t name_map; for (auto &column_name : names) { - // put it all lower_case - auto low_column_name = StringUtil::Lower(column_name); - if (name_map.find(low_column_name) == name_map.end()) { + if (name_map.find(column_name) == name_map.end()) { // Name does not exist yet - 
name_map[low_column_name]++; + name_map[column_name]++; } else { // Name already exists, we add _x where x is the repetition number - string new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); - auto new_column_name_low = StringUtil::Lower(new_column_name); - while (name_map.find(new_column_name_low) != name_map.end()) { + Identifier new_column_name(column_name + "_" + std::to_string(name_map[column_name])); + while (name_map.find(new_column_name) != name_map.end()) { // This name is already here due to a previous definition - name_map[low_column_name]++; - new_column_name = column_name + "_" + std::to_string(name_map[low_column_name]); - new_column_name_low = StringUtil::Lower(new_column_name); + name_map[column_name]++; + new_column_name = Identifier(column_name + "_" + std::to_string(name_map[column_name])); } column_name = new_column_name; - name_map[new_column_name_low]++; + name_map[new_column_name]++; } } } diff --git a/src/duckdb/src/main/relation.cpp b/src/duckdb/src/main/relation.cpp index 8efa2cd11..8f59982ea 100644 --- a/src/duckdb/src/main/relation.cpp +++ b/src/duckdb/src/main/relation.cpp @@ -39,7 +39,8 @@ shared_ptr Relation::Project(const string &expression, const string &a shared_ptr Relation::Project(const string &select_list, const vector &aliases) { auto expressions = Parser::ParseExpressionList(select_list, context->GetContext()->GetParserOptions()); - return make_shared_ptr(shared_from_this(), std::move(expressions), aliases); + return make_shared_ptr(shared_from_this(), std::move(expressions), + StringsToIdentifiers(aliases)); } shared_ptr Relation::Project(const vector &expressions) { @@ -49,7 +50,8 @@ shared_ptr Relation::Project(const vector &expressions) { shared_ptr Relation::Project(vector> expressions, const vector &aliases) { - return make_shared_ptr(shared_from_this(), std::move(expressions), aliases); + return make_shared_ptr(shared_from_this(), std::move(expressions), + StringsToIdentifiers(aliases)); 
} static vector> StringListToExpressionList(const ClientContext &context, @@ -70,7 +72,8 @@ static vector> StringListToExpressionList(const Cli shared_ptr Relation::Project(const vector &expressions, const vector &aliases) { auto result_list = StringListToExpressionList(*context->GetContext(), expressions); - return make_shared_ptr(shared_from_this(), std::move(result_list), aliases); + return make_shared_ptr(shared_from_this(), std::move(result_list), + StringsToIdentifiers(aliases)); } shared_ptr Relation::Filter(const string &expression) { @@ -138,7 +141,7 @@ shared_ptr Relation::Join(const shared_ptr &other, JoinRefType ref_type) { if (expression_list.size() > 1 || expression_list[0]->GetExpressionType() == ExpressionType::COLUMN_REF) { // multiple columns or single column ref: the condition is a USING list - vector using_columns; + vector using_columns; for (auto &expr : expression_list) { if (expr->GetExpressionType() != ExpressionType::COLUMN_REF) { throw ParserException("Expected a single expression as join condition"); @@ -147,7 +150,7 @@ shared_ptr Relation::Join(const shared_ptr &other, if (colref.IsQualified()) { throw ParserException("Expected unqualified column for column in USING clause"); } - using_columns.push_back(colref.column_names[0]); + using_columns.push_back(colref.ColumnNames()[0]); } return make_shared_ptr(shared_from_this(), other, std::move(using_columns), type, ref_type); } else { @@ -211,8 +214,8 @@ shared_ptr Relation::Aggregate(vector> ex return make_shared_ptr(shared_from_this(), std::move(expressions), std::move(groups)); } -string Relation::GetAlias() { - return alias; +Identifier Relation::GetAlias() { + return Identifier(alias); } unique_ptr Relation::GetTableRef() { @@ -240,24 +243,24 @@ BoundStatement Relation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } -shared_ptr Relation::InsertRel(const string &schema_name, const string &table_name) { - return InsertRel(INVALID_CATALOG, schema_name, table_name); +shared_ptr 
Relation::InsertRel(const Identifier &schema_name, const Identifier &table_name) { + return InsertRel(Identifier::InvalidCatalog(), schema_name, table_name); } -shared_ptr Relation::InsertRel(const string &catalog_name, const string &schema_name, - const string &table_name) { +shared_ptr Relation::InsertRel(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name) { return make_shared_ptr(shared_from_this(), catalog_name, schema_name, table_name); } -void Relation::Insert(const string &table_name) { - Insert(INVALID_SCHEMA, table_name); +void Relation::Insert(const Identifier &table_name) { + Insert(Identifier::InvalidSchema(), table_name); } -void Relation::Insert(const string &schema_name, const string &table_name) { - Insert(INVALID_CATALOG, schema_name, table_name); +void Relation::Insert(const Identifier &schema_name, const Identifier &table_name) { + Insert(Identifier::InvalidCatalog(), schema_name, table_name); } -void Relation::Insert(const string &catalog_name, const string &schema_name, const string &table_name) { +void Relation::Insert(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &table_name) { auto insert = InsertRel(catalog_name, schema_name, table_name); auto res = insert->Execute(); if (res->HasError()) { @@ -275,28 +278,28 @@ void Relation::Insert(vector>> &&expressions throw InvalidInputException("INSERT with expressions can only be used on base tables!"); } -shared_ptr Relation::CreateRel(const string &schema_name, const string &table_name, bool temporary, +shared_ptr Relation::CreateRel(const Identifier &schema_name, const Identifier &table_name, bool temporary, OnCreateConflict on_conflict) { - return CreateRel(INVALID_CATALOG, schema_name, table_name, temporary, on_conflict); + return CreateRel(Identifier::InvalidCatalog(), schema_name, table_name, temporary, on_conflict); } -shared_ptr Relation::CreateRel(const string &catalog_name, const string &schema_name, - const 
string &table_name, bool temporary, OnCreateConflict on_conflict) { +shared_ptr Relation::CreateRel(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name, bool temporary, OnCreateConflict on_conflict) { return make_shared_ptr(shared_from_this(), catalog_name, schema_name, table_name, temporary, on_conflict); } -void Relation::Create(const string &table_name, bool temporary, OnCreateConflict on_conflict) { - Create(INVALID_CATALOG, INVALID_SCHEMA, table_name, temporary, on_conflict); +void Relation::Create(const Identifier &table_name, bool temporary, OnCreateConflict on_conflict) { + Create(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), table_name, temporary, on_conflict); } -void Relation::Create(const string &schema_name, const string &table_name, bool temporary, +void Relation::Create(const Identifier &schema_name, const Identifier &table_name, bool temporary, OnCreateConflict on_conflict) { - Create(INVALID_CATALOG, schema_name, table_name, temporary, on_conflict); + Create(Identifier::InvalidCatalog(), schema_name, table_name, temporary, on_conflict); } -void Relation::Create(const string &catalog_name, const string &schema_name, const string &table_name, bool temporary, - OnCreateConflict on_conflict) { +void Relation::Create(const Identifier &catalog_name, const Identifier &schema_name, const Identifier &table_name, + bool temporary, OnCreateConflict on_conflict) { if (table_name.empty()) { throw ParserException("Empty table name not supported"); } @@ -337,11 +340,12 @@ void Relation::WriteParquet(const string &parquet_file, case_insensitive_map_t Relation::CreateView(const string &name, bool replace, bool temporary) { - return CreateView(INVALID_SCHEMA, name, replace, temporary); +shared_ptr Relation::CreateView(const Identifier &name, bool replace, bool temporary) { + return CreateView(Identifier::InvalidSchema(), name, replace, temporary); } -shared_ptr Relation::CreateView(const string &schema_name, 
const string &name, bool replace, bool temporary) { +shared_ptr Relation::CreateView(const Identifier &schema_name, const Identifier &name, bool replace, + bool temporary) { auto view = make_shared_ptr(shared_from_this(), schema_name, name, replace, temporary); auto res = view->Execute(); if (res->HasError()) { @@ -355,7 +359,7 @@ unique_ptr Relation::Query(const string &sql) const { return context->GetContext()->Query(sql, false); } -unique_ptr Relation::Query(const string &name, const string &sql) { +unique_ptr Relation::Query(const Identifier &name, const string &sql) { bool replace = true; bool temp = IsReadOnly(); CreateView(name, replace, temp); diff --git a/src/duckdb/src/main/relation/aggregate_relation.cpp b/src/duckdb/src/main/relation/aggregate_relation.cpp index 4a09d7497..340fb33f1 100644 --- a/src/duckdb/src/main/relation/aggregate_relation.cpp +++ b/src/duckdb/src/main/relation/aggregate_relation.cpp @@ -70,7 +70,7 @@ unique_ptr AggregateRelation::GetQueryNode() { return result; } -string AggregateRelation::GetAlias() { +Identifier AggregateRelation::GetAlias() { return child->GetAlias(); } diff --git a/src/duckdb/src/main/relation/create_table_relation.cpp b/src/duckdb/src/main/relation/create_table_relation.cpp index 2a08194c0..895cd8097 100644 --- a/src/duckdb/src/main/relation/create_table_relation.cpp +++ b/src/duckdb/src/main/relation/create_table_relation.cpp @@ -6,7 +6,7 @@ namespace duckdb { -CreateTableRelation::CreateTableRelation(shared_ptr child_p, string schema_name, string table_name, +CreateTableRelation::CreateTableRelation(shared_ptr child_p, Identifier schema_name, Identifier table_name, bool temporary_p, OnCreateConflict on_conflict) : Relation(child_p->context, RelationType::CREATE_TABLE_RELATION), child(std::move(child_p)), schema_name(std::move(schema_name)), table_name(std::move(table_name)), temporary(temporary_p), @@ -14,8 +14,8 @@ CreateTableRelation::CreateTableRelation(shared_ptr child_p, string sc 
TryBindRelation(columns); } -CreateTableRelation::CreateTableRelation(shared_ptr child_p, string catalog_name, string schema_name, - string table_name, bool temporary_p, OnCreateConflict on_conflict) +CreateTableRelation::CreateTableRelation(shared_ptr child_p, Identifier catalog_name, Identifier schema_name, + Identifier table_name, bool temporary_p, OnCreateConflict on_conflict) : Relation(child_p->context, RelationType::CREATE_TABLE_RELATION), child(std::move(child_p)), catalog_name(std::move(catalog_name)), schema_name(std::move(schema_name)), table_name(std::move(table_name)), temporary(temporary_p), on_conflict(on_conflict) { diff --git a/src/duckdb/src/main/relation/create_view_relation.cpp b/src/duckdb/src/main/relation/create_view_relation.cpp index 6f77f013f..3fb7fbfc6 100644 --- a/src/duckdb/src/main/relation/create_view_relation.cpp +++ b/src/duckdb/src/main/relation/create_view_relation.cpp @@ -5,14 +5,14 @@ namespace duckdb { -CreateViewRelation::CreateViewRelation(shared_ptr child_p, string view_name_p, bool replace_p, +CreateViewRelation::CreateViewRelation(shared_ptr child_p, Identifier view_name_p, bool replace_p, bool temporary_p) : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), view_name(std::move(view_name_p)), replace(replace_p), temporary(temporary_p) { TryBindRelation(columns); } -CreateViewRelation::CreateViewRelation(shared_ptr child_p, string schema_name_p, string view_name_p, +CreateViewRelation::CreateViewRelation(shared_ptr child_p, Identifier schema_name_p, Identifier view_name_p, bool replace_p, bool temporary_p) : Relation(child_p->context, RelationType::CREATE_VIEW_RELATION), child(std::move(child_p)), schema_name(std::move(schema_name_p)), view_name(std::move(view_name_p)), replace(replace_p), diff --git a/src/duckdb/src/main/relation/delete_relation.cpp b/src/duckdb/src/main/relation/delete_relation.cpp index b0b427f21..c7107a3ad 100644 --- 
a/src/duckdb/src/main/relation/delete_relation.cpp +++ b/src/duckdb/src/main/relation/delete_relation.cpp @@ -8,7 +8,7 @@ namespace duckdb { DeleteRelation::DeleteRelation(shared_ptr &context, unique_ptr condition_p, - string catalog_name_p, string schema_name_p, string table_name_p) + Identifier catalog_name_p, Identifier schema_name_p, Identifier table_name_p) : Relation(context, RelationType::DELETE_RELATION), condition(std::move(condition_p)), catalog_name(std::move(catalog_name_p)), schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)) { diff --git a/src/duckdb/src/main/relation/distinct_relation.cpp b/src/duckdb/src/main/relation/distinct_relation.cpp index 16bb53822..5f904b124 100644 --- a/src/duckdb/src/main/relation/distinct_relation.cpp +++ b/src/duckdb/src/main/relation/distinct_relation.cpp @@ -17,7 +17,7 @@ unique_ptr DistinctRelation::GetQueryNode() { return child_node; } -string DistinctRelation::GetAlias() { +Identifier DistinctRelation::GetAlias() { return child->GetAlias(); } diff --git a/src/duckdb/src/main/relation/filter_relation.cpp b/src/duckdb/src/main/relation/filter_relation.cpp index 738e83577..82fae0ce4 100644 --- a/src/duckdb/src/main/relation/filter_relation.cpp +++ b/src/duckdb/src/main/relation/filter_relation.cpp @@ -41,7 +41,7 @@ unique_ptr FilterRelation::GetQueryNode() { } } -string FilterRelation::GetAlias() { +Identifier FilterRelation::GetAlias() { return child->GetAlias(); } diff --git a/src/duckdb/src/main/relation/insert_relation.cpp b/src/duckdb/src/main/relation/insert_relation.cpp index 2bda1be79..c63e372e7 100644 --- a/src/duckdb/src/main/relation/insert_relation.cpp +++ b/src/duckdb/src/main/relation/insert_relation.cpp @@ -8,13 +8,14 @@ namespace duckdb { -InsertRelation::InsertRelation(shared_ptr child_p, string schema_name, string table_name) +InsertRelation::InsertRelation(shared_ptr child_p, Identifier schema_name, Identifier table_name) : Relation(child_p->context, 
RelationType::INSERT_RELATION), child(std::move(child_p)), schema_name(std::move(schema_name)), table_name(std::move(table_name)) { TryBindRelation(columns); } -InsertRelation::InsertRelation(shared_ptr child_p, string catalog_name, string schema_name, string table_name) +InsertRelation::InsertRelation(shared_ptr child_p, Identifier catalog_name, Identifier schema_name, + Identifier table_name) : Relation(child_p->context, RelationType::INSERT_RELATION), child(std::move(child_p)), catalog_name(std::move(catalog_name)), schema_name(std::move(schema_name)), table_name(std::move(table_name)) { TryBindRelation(columns); diff --git a/src/duckdb/src/main/relation/join_relation.cpp b/src/duckdb/src/main/relation/join_relation.cpp index b3701ea39..eb1f672fa 100644 --- a/src/duckdb/src/main/relation/join_relation.cpp +++ b/src/duckdb/src/main/relation/join_relation.cpp @@ -18,8 +18,8 @@ JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr rig TryBindRelation(columns); } -JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, vector using_columns_p, - JoinType type, JoinRefType join_ref_type) +JoinRelation::JoinRelation(shared_ptr left_p, shared_ptr right_p, + vector using_columns_p, JoinType type, JoinRefType join_ref_type) : Relation(left_p->context, RelationType::JOIN_RELATION), left(std::move(left_p)), right(std::move(right_p)), using_columns(std::move(using_columns_p)), join_type(type), join_ref_type(join_ref_type) { if (left->context->GetContext() != right->context->GetContext()) { diff --git a/src/duckdb/src/main/relation/limit_relation.cpp b/src/duckdb/src/main/relation/limit_relation.cpp index c3a87b0e1..0e5b26830 100644 --- a/src/duckdb/src/main/relation/limit_relation.cpp +++ b/src/duckdb/src/main/relation/limit_relation.cpp @@ -26,7 +26,7 @@ unique_ptr LimitRelation::GetQueryNode() { return child_node; } -string LimitRelation::GetAlias() { +Identifier LimitRelation::GetAlias() { return child->GetAlias(); } diff --git 
a/src/duckdb/src/main/relation/materialized_relation.cpp b/src/duckdb/src/main/relation/materialized_relation.cpp index 8e17168f0..47d551736 100644 --- a/src/duckdb/src/main/relation/materialized_relation.cpp +++ b/src/duckdb/src/main/relation/materialized_relation.cpp @@ -10,9 +10,9 @@ namespace duckdb { MaterializedRelation::MaterializedRelation(const shared_ptr &context, - unique_ptr &&collection_p, vector names, - string alias_p) - : Relation(context, RelationType::MATERIALIZED_RELATION), alias(std::move(alias_p)), + unique_ptr &&collection_p, vector names, + const Identifier &alias_p) + : Relation(context, RelationType::MATERIALIZED_RELATION), alias(alias_p.GetIdentifierName()), collection(std::move(collection_p)) { // create constant expressions for the values auto types = collection->Types(); @@ -37,14 +37,14 @@ unique_ptr MaterializedRelation::GetQueryNode() { unique_ptr MaterializedRelation::GetTableRef() { auto table_ref = make_uniq(collection); for (auto &col : columns) { - table_ref->expected_names.push_back(col.Name()); + table_ref->expected_names.emplace_back(col.Name()); } - table_ref->alias = GetAlias(); + table_ref->alias = Identifier(GetAlias()); return std::move(table_ref); } -string MaterializedRelation::GetAlias() { - return alias; +Identifier MaterializedRelation::GetAlias() { + return Identifier(alias); } const vector &MaterializedRelation::Columns() { diff --git a/src/duckdb/src/main/relation/order_relation.cpp b/src/duckdb/src/main/relation/order_relation.cpp index ac97d1c52..7a2237da9 100644 --- a/src/duckdb/src/main/relation/order_relation.cpp +++ b/src/duckdb/src/main/relation/order_relation.cpp @@ -25,7 +25,7 @@ unique_ptr OrderRelation::GetQueryNode() { return std::move(select); } -string OrderRelation::GetAlias() { +Identifier OrderRelation::GetAlias() { return child->GetAlias(); } diff --git a/src/duckdb/src/main/relation/projection_relation.cpp b/src/duckdb/src/main/relation/projection_relation.cpp index 50345abf3..c80518d4e 100644 
--- a/src/duckdb/src/main/relation/projection_relation.cpp +++ b/src/duckdb/src/main/relation/projection_relation.cpp @@ -7,7 +7,8 @@ namespace duckdb { ProjectionRelation::ProjectionRelation(shared_ptr child_p, - vector> parsed_expressions, vector aliases) + vector> parsed_expressions, + vector aliases) : Relation(child_p->context, RelationType::PROJECTION_RELATION), expressions(std::move(parsed_expressions)), child(std::move(child_p)) { if (!aliases.empty()) { @@ -47,7 +48,7 @@ unique_ptr ProjectionRelation::GetQueryNode() { return result; } -string ProjectionRelation::GetAlias() { +Identifier ProjectionRelation::GetAlias() { return child->GetAlias(); } diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp index c991aeae2..bcdd55dcf 100644 --- a/src/duckdb/src/main/relation/query_relation.cpp +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -100,8 +100,8 @@ BoundStatement QueryRelation::Bind(Binder &binder) { return result; } -string QueryRelation::GetAlias() { - return alias; +Identifier QueryRelation::GetAlias() { + return Identifier(alias); } const vector &QueryRelation::Columns() { diff --git a/src/duckdb/src/main/relation/read_csv_relation.cpp b/src/duckdb/src/main/relation/read_csv_relation.cpp index 7ba976526..3f998b126 100644 --- a/src/duckdb/src/main/relation/read_csv_relation.cpp +++ b/src/duckdb/src/main/relation/read_csv_relation.cpp @@ -48,7 +48,7 @@ CSVReaderOptions ReadCSVRelationBind(const shared_ptr &context, c if (file_options.union_by_name) { vector types; - vector names; + vector names; auto result = make_uniq(); auto csv_data = make_uniq(); result->interface = make_uniq(); @@ -74,7 +74,7 @@ CSVReaderOptions ReadCSVRelationBind(const shared_ptr &context, c } else { if (csv_options.auto_detect) { vector return_types; - vector names; + vector names; shared_ptr buffer_manager; CSVSchemaDiscovery::SchemaDiscovery(*context, buffer_manager, csv_options, file_options, return_types, 
names, multi_file_list); @@ -84,7 +84,7 @@ CSVReaderOptions ReadCSVRelationBind(const shared_ptr &context, c } else { for (idx_t i = 0; i < csv_options.sql_type_list.size(); i++) { D_ASSERT(csv_options.name_list.size() == csv_options.sql_type_list.size()); - columns.emplace_back(csv_options.name_list[i], csv_options.sql_type_list[i]); + columns.emplace_back(Identifier(csv_options.name_list[i]), csv_options.sql_type_list[i]); } } // After sniffing we can consider these set, so they are exported as named parameters @@ -120,7 +120,7 @@ ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const child_list_t column_names; for (idx_t i = 0; i < columns.size(); i++) { - column_names.push_back(make_pair(columns[i].Name(), Value(columns[i].Type().ToString()))); + column_names.emplace_back(make_pair(columns[i].Name(), Value(columns[i].Type().ToString()))); } if (!file_options.union_by_name) { @@ -131,8 +131,8 @@ ReadCSVRelation::ReadCSVRelation(const shared_ptr &context, const RemoveNamedParameterIfExists("dtypes"); } -string ReadCSVRelation::GetAlias() { - return alias; +Identifier ReadCSVRelation::GetAlias() { + return Identifier(alias); } } // namespace duckdb diff --git a/src/duckdb/src/main/relation/read_json_relation.cpp b/src/duckdb/src/main/relation/read_json_relation.cpp index 6fe7e4a7c..f05053af8 100644 --- a/src/duckdb/src/main/relation/read_json_relation.cpp +++ b/src/duckdb/src/main/relation/read_json_relation.cpp @@ -31,8 +31,8 @@ ReadJSONRelation::ReadJSONRelation(const shared_ptr &context, str ReadJSONRelation::~ReadJSONRelation() { } -string ReadJSONRelation::GetAlias() { - return alias; +Identifier ReadJSONRelation::GetAlias() { + return Identifier(alias); } } // namespace duckdb diff --git a/src/duckdb/src/main/relation/setop_relation.cpp b/src/duckdb/src/main/relation/setop_relation.cpp index 9a24fdd16..609828cfe 100644 --- a/src/duckdb/src/main/relation/setop_relation.cpp +++ b/src/duckdb/src/main/relation/setop_relation.cpp @@ -27,7 +27,7 @@ 
unique_ptr SetOpRelation::GetQueryNode() { return std::move(result); } -string SetOpRelation::GetAlias() { +Identifier SetOpRelation::GetAlias() { return left->GetAlias(); } diff --git a/src/duckdb/src/main/relation/table_function_relation.cpp b/src/duckdb/src/main/relation/table_function_relation.cpp index 5111752e2..2dbef45bf 100644 --- a/src/duckdb/src/main/relation/table_function_relation.cpp +++ b/src/duckdb/src/main/relation/table_function_relation.cpp @@ -14,12 +14,12 @@ namespace duckdb { void TableFunctionRelation::AddNamedParameter(const string &name, Value argument) { - named_parameters[name] = std::move(argument); + named_parameters[Identifier(name)] = std::move(argument); } void TableFunctionRelation::RemoveNamedParameterIfExists(const string &name) { - if (named_parameters.find(name) != named_parameters.end()) { - named_parameters.erase(name); + if (named_parameters.find(Identifier(name)) != named_parameters.end()) { + named_parameters.erase(Identifier(name)); } } @@ -72,9 +72,9 @@ unique_ptr TableFunctionRelation::GetTableRef() { vector> children; if (input_relation) { // input relation becomes first parameter if present, always auto subquery = make_uniq(); - subquery->subquery = make_uniq(); - subquery->subquery->node = input_relation->GetQueryNode(); - subquery->subquery_type = SubqueryType::SCALAR; + subquery->SubqueryMutable() = make_uniq(); + subquery->SubqueryMutable()->node = input_relation->GetQueryNode(); + subquery->GetSubqueryTypeMutable() = SubqueryType::SCALAR; children.push_back(std::move(subquery)); } for (auto ¶meter : parameters) { @@ -98,7 +98,7 @@ unique_ptr TableFunctionRelation::GetTableRef() { return std::move(table_function); } -string TableFunctionRelation::GetAlias() { +Identifier TableFunctionRelation::GetAlias() { return name; } diff --git a/src/duckdb/src/main/relation/table_relation.cpp b/src/duckdb/src/main/relation/table_relation.cpp index 78d5aaaa4..6ad221df7 100644 --- a/src/duckdb/src/main/relation/table_relation.cpp 
+++ b/src/duckdb/src/main/relation/table_relation.cpp @@ -34,7 +34,7 @@ unique_ptr TableRelation::GetTableRef() { return std::move(table_ref); } -string TableRelation::GetAlias() { +Identifier TableRelation::GetAlias() { return description->table; } @@ -61,7 +61,7 @@ static unique_ptr ParseCondition(ClientContext &context, const void TableRelation::Update(vector names, vector> &&update, unique_ptr condition) { - vector update_columns = std::move(names); + vector update_columns = StringsToIdentifiers(names); vector> expressions = std::move(update); auto update_relation = @@ -71,7 +71,7 @@ void TableRelation::Update(vector names, vector update_columns; + vector update_columns; vector> expressions; auto cond = ParseCondition(*context->GetContext(), condition); Parser::ParseUpdateList(update_list, update_columns, expressions, context->GetContext()->GetParserOptions()); diff --git a/src/duckdb/src/main/relation/update_relation.cpp b/src/duckdb/src/main/relation/update_relation.cpp index 9992ea107..8cb5c62f1 100644 --- a/src/duckdb/src/main/relation/update_relation.cpp +++ b/src/duckdb/src/main/relation/update_relation.cpp @@ -8,8 +8,8 @@ namespace duckdb { UpdateRelation::UpdateRelation(shared_ptr &context, unique_ptr condition_p, - string catalog_name_p, string schema_name_p, string table_name_p, - vector update_columns_p, vector> expressions_p) + Identifier catalog_name_p, Identifier schema_name_p, Identifier table_name_p, + vector update_columns_p, vector> expressions_p) : Relation(context, RelationType::UPDATE_RELATION), condition(std::move(condition_p)), catalog_name(std::move(catalog_name_p)), schema_name(std::move(schema_name_p)), table_name(std::move(table_name_p)), update_columns(std::move(update_columns_p)), diff --git a/src/duckdb/src/main/relation/value_relation.cpp b/src/duckdb/src/main/relation/value_relation.cpp index bd9adf991..8b17ee334 100644 --- a/src/duckdb/src/main/relation/value_relation.cpp +++ b/src/duckdb/src/main/relation/value_relation.cpp @@ 
-63,7 +63,7 @@ ValueRelation::ValueRelation(const shared_ptr &context, if (names_p.empty()) { auto &first_list = expressions_p[0]; for (auto &expr : first_list) { - names_p.push_back(expr->GetName()); + names_p.emplace_back(expr->GetName()); } } names = std::move(names_p); @@ -85,11 +85,11 @@ unique_ptr ValueRelation::GetTableRef() { if (columns.empty()) { // no columns yet: only set up names for (idx_t i = 0; i < names.size(); i++) { - table_ref->expected_names.push_back(names[i]); + table_ref->expected_names.push_back(Identifier(names[i])); } } else { for (idx_t i = 0; i < columns.size(); i++) { - table_ref->expected_names.push_back(columns[i].Name()); + table_ref->expected_names.emplace_back(columns[i].Name()); table_ref->expected_types.push_back(columns[i].Type()); D_ASSERT(names.size() == 0 || columns[i].Name() == names[i]); } @@ -103,12 +103,12 @@ unique_ptr ValueRelation::GetTableRef() { } table_ref->values.push_back(std::move(copied_list)); } - table_ref->alias = GetAlias(); + table_ref->alias = Identifier(GetAlias()); return std::move(table_ref); } -string ValueRelation::GetAlias() { - return alias; +Identifier ValueRelation::GetAlias() { + return Identifier(alias); } const vector &ValueRelation::Columns() { diff --git a/src/duckdb/src/main/relation/view_relation.cpp b/src/duckdb/src/main/relation/view_relation.cpp index 16c1bdc37..5fbbf3d38 100644 --- a/src/duckdb/src/main/relation/view_relation.cpp +++ b/src/duckdb/src/main/relation/view_relation.cpp @@ -7,13 +7,14 @@ namespace duckdb { -ViewRelation::ViewRelation(const shared_ptr &context, string schema_name_p, string view_name_p) +ViewRelation::ViewRelation(const shared_ptr &context, Identifier schema_name_p, Identifier view_name_p) : Relation(context, RelationType::VIEW_RELATION), schema_name(std::move(schema_name_p)), view_name(std::move(view_name_p)) { TryBindRelation(columns); } -ViewRelation::ViewRelation(const shared_ptr &context, string schema_name_p, string view_name_p) 
+ViewRelation::ViewRelation(const shared_ptr &context, Identifier schema_name_p, + Identifier view_name_p) : Relation(context, RelationType::VIEW_RELATION), schema_name(std::move(schema_name_p)), view_name(std::move(view_name_p)) { TryBindRelation(columns); @@ -22,7 +23,7 @@ ViewRelation::ViewRelation(const shared_ptr &context, st ViewRelation::ViewRelation(const shared_ptr &context, unique_ptr ref, const string &view_name) : Relation(context, RelationType::VIEW_RELATION), view_name(view_name), premade_tableref(std::move(ref)) { TryBindRelation(columns); - premade_tableref->alias = view_name; + premade_tableref->alias = Identifier(view_name); } unique_ptr ViewRelation::GetQueryNode() { @@ -42,7 +43,7 @@ unique_ptr ViewRelation::GetTableRef() { return std::move(table_ref); } -string ViewRelation::GetAlias() { +Identifier ViewRelation::GetAlias() { D_ASSERT(!view_name.empty()); return view_name; } diff --git a/src/duckdb/src/main/secret/default_secrets.cpp b/src/duckdb/src/main/secret/default_secrets.cpp index 631c1abe7..f0dce5f5d 100644 --- a/src/duckdb/src/main/secret/default_secrets.cpp +++ b/src/duckdb/src/main/secret/default_secrets.cpp @@ -88,6 +88,9 @@ unique_ptr CreateHTTPSecretFunctions::CreateHTTPSecretFromEnv(Client secret->TrySetValue("extra_http_headers", input); secret->TrySetValue("bearer_token", input); + //! Set redact keys + secret->redact_keys = {"http_proxy_password", "bearer_token"}; + return std::move(secret); } @@ -104,7 +107,7 @@ unique_ptr CreateHTTPSecretFunctions::CreateHTTPSecretFromConfig(Cli secret->TrySetValue("bearer_token", input); //! 
Set redact keys - secret->redact_keys = {"http_proxy_password"}; + secret->redact_keys = {"http_proxy_password", "bearer_token"}; return std::move(secret); } diff --git a/src/duckdb/src/main/secret/secret.cpp b/src/duckdb/src/main/secret/secret.cpp index b2bc9838a..83f01f0a3 100644 --- a/src/duckdb/src/main/secret/secret.cpp +++ b/src/duckdb/src/main/secret/secret.cpp @@ -61,7 +61,7 @@ string KeyValueSecret::ToString(SecretDisplayType mode) const { result = result.substr(0, result.size() - 1); result += ";"; for (auto it = secret_map.begin(); it != secret_map.end(); it++) { - result.append(it->first); + result.append(it->first.GetIdentifierName()); result.append("="); if (mode == SecretDisplayType::REDACTED && redact_keys.find(it->first) != redact_keys.end()) { result.append("redacted"); @@ -83,8 +83,8 @@ void KeyValueSecret::Serialize(Serializer &serializer) const { vector map_values; for (auto it = secret_map.begin(); it != secret_map.end(); it++) { child_list_t map_struct; - map_struct.push_back(make_pair("key", Value(it->first))); - map_struct.push_back(make_pair("value", Value(it->second))); + map_struct.emplace_back(make_pair("key", Value(it->first))); + map_struct.emplace_back(make_pair("value", Value(it->second))); map_values.push_back(Value::STRUCT(map_struct)); } @@ -101,7 +101,7 @@ void KeyValueSecret::Serialize(Serializer &serializer) const { serializer.WriteProperty(202, "redact_keys", list); } -Value KeyValueSecret::TryGetValue(const string &key, bool error_on_missing) const { +Value KeyValueSecret::TryGetValue(const Identifier &key, bool error_on_missing) const { auto lookup = secret_map.find(key); if (lookup == secret_map.end()) { if (error_on_missing) { @@ -185,7 +185,7 @@ KeyValueSecretReader::~KeyValueSecretReader() { } SettingLookupResult KeyValueSecretReader::TryGetSecretKey(const string &secret_key, Value &result) { - if (secret && secret->TryGetValue(secret_key, result)) { + if (secret && secret->TryGetValue(Identifier(secret_key), result)) { 
return SettingLookupResult(SettingScope::SECRET); } return SettingLookupResult(); @@ -193,7 +193,7 @@ SettingLookupResult KeyValueSecretReader::TryGetSecretKey(const string &secret_k SettingLookupResult KeyValueSecretReader::TryGetSecretKeyOrSetting(const string &secret_key, const string &setting_name, Value &result) { - if (secret && secret->TryGetValue(secret_key, result)) { + if (secret && secret->TryGetValue(Identifier(secret_key), result)) { return SettingLookupResult(SettingScope::SECRET); } if (context) { @@ -252,8 +252,8 @@ void KeyValueSecretReader::ThrowNotFoundError(const string &secret_key, const st secret_key, setting_name, secret->GetName()); } -bool CreateSecretFunctionSet::ProviderExists(const string &provider_name) { - return functions.find(provider_name) != functions.end(); +bool CreateSecretFunctionSet::ProviderExists(const Identifier &provider_name) { + return functions.find(Identifier(provider_name)) != functions.end(); } void CreateSecretFunctionSet::AddFunction(CreateSecretFunction &function, OnCreateConflict on_conflict) { @@ -273,7 +273,7 @@ void CreateSecretFunctionSet::AddFunction(CreateSecretFunction &function, OnCrea } CreateSecretFunction &CreateSecretFunctionSet::GetFunction(const string &provider) { - const auto &lookup = functions.find(provider); + const auto &lookup = functions.find(Identifier(provider)); if (lookup == functions.end()) { throw InternalException("Could not find Create Secret Function with provider %s"); diff --git a/src/duckdb/src/main/secret/secret_manager.cpp b/src/duckdb/src/main/secret/secret_manager.cpp index 7b132baf8..2d91019f7 100644 --- a/src/duckdb/src/main/secret/secret_manager.cpp +++ b/src/duckdb/src/main/secret/secret_manager.cpp @@ -3,6 +3,7 @@ #include "duckdb/catalog/catalog_entry.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/file_system.hpp" +#include "duckdb/main/database_file_opener.hpp" #include "duckdb/common/local_file_system.hpp" #include "duckdb/common/mutex.hpp" #include 
"duckdb/common/serializer/binary_deserializer.hpp" @@ -43,7 +44,7 @@ void SecretManager::Initialize(DatabaseInstance &db) { lock_guard lck(manager_lock); // Construct default path - LocalFileSystem fs; + auto &fs = FileSystem::GetLocal(db); config.default_secret_path = fs.GetHomeDirectory(); vector path_components = {".duckdb", "stored_secrets"}; for (auto &path_ele : path_components) { @@ -77,7 +78,7 @@ void SecretManager::LoadSecretStorage(unique_ptr storage) { } void SecretManager::LoadSecretStorageInternal(unique_ptr storage) { - if (secret_storages.find(storage->GetName()) != secret_storages.end()) { + if (secret_storages.find(Identifier(storage->GetName())) != secret_storages.end()) { throw InvalidConfigurationException("Secret Storage with name '%s' already registered!", storage->GetName()); } @@ -90,7 +91,7 @@ void SecretManager::LoadSecretStorageInternal(unique_ptr storage) } } - secret_storages[storage->GetName()] = std::move(storage); + secret_storages[Identifier(storage->GetName())] = std::move(storage); } // FIXME: use serialization scripts? 
@@ -107,7 +108,8 @@ unique_ptr SecretManager::DeserializeSecret(Deserializer &deserializ switch (serialization_type) { // This allows us to skip looking up the secret type for deserialization altogether case SecretSerializationType::KEY_VALUE_SECRET: - return KeyValueSecret::Deserialize(deserializer, {scope, type, provider, name}); + return KeyValueSecret::Deserialize( + deserializer, BaseSecret(scope, Identifier(type), Identifier(provider), Identifier(name))); // Continues below: we need to do a type lookup to find the secret deserialize method case SecretSerializationType::CUSTOM: break; @@ -118,7 +120,7 @@ unique_ptr SecretManager::DeserializeSecret(Deserializer &deserializ SecretType deserialized_type; if (!TryLookupTypeInternal(type, deserialized_type)) { - ThrowTypeNotFoundError(type, secret_path); + ThrowTypeNotFoundError(Identifier(type), secret_path); } if (!deserialized_type.deserializer) { @@ -126,7 +128,8 @@ unique_ptr SecretManager::DeserializeSecret(Deserializer &deserializ "Attempted to deserialize secret type '%s' which does not have a deserialization method", type); } - return deserialized_type.deserializer(deserializer, {scope, type, provider, name}); + return deserialized_type.deserializer(deserializer, + BaseSecret(scope, Identifier(type), Identifier(provider), Identifier(name))); } void SecretManager::RegisterSecretType(SecretType &type) { @@ -143,13 +146,14 @@ unique_ptr SecretManager::RegisterSecret(CatalogTransaction transac unique_ptr secret, OnCreateConflict on_conflict, SecretPersistType persist_type, const string &storage) { InitializeSecrets(transaction); - return RegisterSecretInternal(transaction, std::move(secret), on_conflict, persist_type, storage); + return RegisterSecretInternal(transaction, std::move(secret), on_conflict, persist_type, Identifier(storage)); } unique_ptr SecretManager::RegisterSecretInternal(CatalogTransaction transaction, unique_ptr secret, OnCreateConflict on_conflict, - SecretPersistType persist_type, const 
string &storage) { + SecretPersistType persist_type, + const Identifier &storage) { //! Ensure we only create secrets for known types; LookupTypeInternal(secret->GetType()); @@ -170,11 +174,11 @@ unique_ptr SecretManager::RegisterSecretInternal(CatalogTransaction resolved_storage = persist_type == SecretPersistType::PERSISTENT ? config.default_persistent_storage : TEMPORARY_STORAGE_NAME; } else { - resolved_storage = storage; + resolved_storage = storage.GetIdentifierName(); } //! Lookup which backend to store the secret in - auto backend = GetSecretStorage(resolved_storage); + auto backend = GetSecretStorage(Identifier(resolved_storage)); if (!backend) { if (!config.allow_persistent_secrets && (persist_type == SecretPersistType::PERSISTENT || storage == LOCAL_FILE_STORAGE_NAME)) { @@ -203,26 +207,27 @@ unique_ptr SecretManager::RegisterSecretInternal(CatalogTransaction return backend->StoreSecret(std::move(secret), on_conflict, &transaction); } -optional_ptr SecretManager::LookupFunctionInternal(const string &type, const string &provider) { +optional_ptr SecretManager::LookupFunctionInternal(const Identifier &type, + const Identifier &provider) { unique_lock lck(manager_lock); auto lookup = secret_functions.find(type); if (lookup != secret_functions.end()) { if (lookup->second.ProviderExists(provider)) { - return &lookup->second.GetFunction(provider); + return &lookup->second.GetFunction(provider.GetIdentifierName()); } } // Try autoloading lck.unlock(); - AutoloadExtensionForFunction(type, provider); + AutoloadExtensionForFunction(type.GetIdentifierName(), provider.GetIdentifierName()); lck.lock(); lookup = secret_functions.find(type); if (lookup != secret_functions.end()) { if (lookup->second.ProviderExists(provider)) { - return &lookup->second.GetFunction(provider); + return &lookup->second.GetFunction(provider.GetIdentifierName()); } } @@ -238,13 +243,13 @@ unique_ptr SecretManager::CreateSecret(ClientContext &context, cons auto function_input = input; if 
(function_input.provider.empty()) { auto secret_type = LookupTypeInternal(function_input.type); - function_input.provider = secret_type.default_provider; + function_input.provider = Identifier(secret_type.default_provider); } // Lookup function auto function_lookup = LookupFunctionInternal(function_input.type, function_input.provider); if (!function_lookup) { - ThrowProviderNotFoundError(input.type, input.provider); + ThrowProviderNotFoundError(input.type.GetIdentifierName(), input.provider.GetIdentifierName()); } // Call the function @@ -270,7 +275,7 @@ BoundStatement SecretManager::BindCreateSecret(CatalogTransaction transaction, C if (provider.empty()) { default_provider = true; auto secret_type = LookupTypeInternal(type); - provider = secret_type.default_provider; + provider = Identifier(secret_type.default_provider); } string default_string = default_provider ? "default " : ""; @@ -278,7 +283,7 @@ BoundStatement SecretManager::BindCreateSecret(CatalogTransaction transaction, C auto function = LookupFunctionInternal(type, provider); if (!function) { - ThrowProviderNotFoundError(info.type, info.provider, default_provider); + ThrowProviderNotFoundError(info.type.GetIdentifierName(), info.provider.GetIdentifierName(), default_provider); } auto bound_info = info; @@ -286,7 +291,7 @@ BoundStatement SecretManager::BindCreateSecret(CatalogTransaction transaction, C // We cast the passed parameters for (const auto ¶m : info.options) { - auto matched_param = function->named_parameters.find(param.first); + auto matched_param = function->named_parameters.find(Identifier(param.first)); if (matched_param == function->named_parameters.end()) { throw BinderException("Unknown parameter '%s' for secret type '%s' with %sprovider '%s'", param.first, type, default_string, provider); @@ -300,7 +305,7 @@ BoundStatement SecretManager::BindCreateSecret(CatalogTransaction transaction, C matched_param->second.ToString(), error_msg); } - bound_info.options[matched_param->first] = 
{cast_value}; + bound_info.options[Identifier(matched_param->first.GetIdentifierName()).GetIdentifierName()] = {cast_value}; } BoundStatement result; @@ -342,7 +347,7 @@ unique_ptr SecretManager::GetSecretByName(CatalogTransaction transa bool found = false; if (!storage.empty()) { - auto storage_lookup = GetSecretStorage(storage); + auto storage_lookup = GetSecretStorage(Identifier(storage)); if (!storage_lookup) { throw InvalidInputException("Unknown secret storage found: '%s'", storage); @@ -367,16 +372,16 @@ unique_ptr SecretManager::GetSecretByName(CatalogTransaction transa return result; } -void SecretManager::DropSecretByName(CatalogTransaction transaction, const string &name, +void SecretManager::DropSecretByName(CatalogTransaction transaction, const Identifier &name, OnEntryNotFound on_entry_not_found, SecretPersistType persist_type, - const string &storage) { + const Identifier &storage) { InitializeSecrets(transaction); vector> matches; // storage to drop from was specified directly if (!storage.empty()) { - auto storage_lookup = GetSecretStorage(storage); + auto storage_lookup = GetSecretStorage(Identifier(storage)); if (!storage_lookup) { throw InvalidInputException("Unknown storage type found for drop secret: '%s'", storage); } @@ -390,7 +395,7 @@ void SecretManager::DropSecretByName(CatalogTransaction transaction, const strin continue; } - auto lookup = storage_ref.get().GetSecretByName(name, &transaction); + auto lookup = storage_ref.get().GetSecretByName(name.GetIdentifierName(), &transaction); if (lookup) { matches.push_back(storage_ref.get()); } @@ -420,12 +425,12 @@ void SecretManager::DropSecretByName(CatalogTransaction transaction, const strin } // Do nothing on OnEntryNotFound::RETURN_NULL... 
} else { - matches[0].get().DropSecretByName(name, on_entry_not_found, &transaction); + matches[0].get().DropSecretByName(Identifier(name), on_entry_not_found, &transaction); } } SecretType SecretManager::LookupType(const string &type) { - return LookupTypeInternal(type); + return LookupTypeInternal(Identifier(type)); } void SecretManager::RegisterSecretTypeInternal(SecretType &type) { @@ -438,7 +443,7 @@ void SecretManager::RegisterSecretTypeInternal(SecretType &type) { bool SecretManager::TryLookupTypeInternal(const string &type, SecretType &type_out) { unique_lock lck(manager_lock); - auto lookup = secret_types.find(type); + auto lookup = secret_types.find(Identifier(type)); if (lookup != secret_types.end()) { type_out = lookup->second; return true; @@ -449,7 +454,7 @@ bool SecretManager::TryLookupTypeInternal(const string &type, SecretType &type_o AutoloadExtensionForType(type); lck.lock(); - lookup = secret_types.find(type); + lookup = secret_types.find(Identifier(type)); if (lookup != secret_types.end()) { type_out = lookup->second; return true; @@ -458,23 +463,23 @@ bool SecretManager::TryLookupTypeInternal(const string &type, SecretType &type_o return false; } -SecretType SecretManager::LookupTypeInternal(const string &type) { +SecretType SecretManager::LookupTypeInternal(const Identifier &type) { SecretType return_value; - if (!TryLookupTypeInternal(type, return_value)) { + if (!TryLookupTypeInternal(type.GetIdentifierName(), return_value)) { ThrowTypeNotFoundError(type); } return return_value; } void SecretManager::RegisterSecretFunctionInternal(CreateSecretFunction function, OnCreateConflict on_conflict) { - auto lookup = secret_functions.find(function.secret_type); + auto lookup = secret_functions.find(Identifier(function.secret_type)); if (lookup != secret_functions.end()) { lookup->second.AddFunction(function, on_conflict); return; } CreateSecretFunctionSet new_set(function.secret_type); new_set.AddFunction(function, 
OnCreateConflict::ERROR_ON_CONFLICT); - secret_functions.insert({function.secret_type, new_set}); + secret_functions.insert(make_pair(Identifier(function.secret_type), new_set)); } vector SecretManager::AllSecrets(CatalogTransaction transaction) { @@ -576,8 +581,9 @@ void SecretManager::AutoloadExtensionForType(const string &type) { ExtensionHelper::TryAutoloadFromEntry(*db, StringUtil::Lower(type), EXTENSION_SECRET_TYPES); } -void SecretManager::ThrowTypeNotFoundError(const string &type, const string &secret_path) { - auto entry = ExtensionHelper::FindExtensionInEntries(StringUtil::Lower(type), EXTENSION_SECRET_TYPES); +void SecretManager::ThrowTypeNotFoundError(const Identifier &type, const string &secret_path) { + auto entry = + ExtensionHelper::FindExtensionInEntries(StringUtil::Lower(type.GetIdentifierName()), EXTENSION_SECRET_TYPES); string error_message; if (!entry.empty() && db) { @@ -621,7 +627,7 @@ void SecretManager::ThrowProviderNotFoundError(const string &type, const string throw InvalidInputException("Secret provider '%s' not found for type '%s'", provider, type); } -optional_ptr SecretManager::GetSecretStorage(const string &name) { +optional_ptr SecretManager::GetSecretStorage(const Identifier &name) { lock_guard lock(manager_lock); auto lookup = secret_storages.find(name); @@ -645,18 +651,18 @@ vector> SecretManager::GetSecretStorages() { } DefaultSecretGenerator::DefaultSecretGenerator(Catalog &catalog, SecretManager &secret_manager, - case_insensitive_set_t &persistent_secrets) + identifier_set_t &persistent_secrets) : DefaultGenerator(catalog), secret_manager(secret_manager), persistent_secrets(persistent_secrets) { } -unique_ptr DefaultSecretGenerator::CreateDefaultEntryInternal(const string &entry_name) { +unique_ptr DefaultSecretGenerator::CreateDefaultEntryInternal(const Identifier &entry_name) { lock_guard guard(lock); - auto secret_lu = persistent_secrets.find(entry_name); + auto secret_lu = persistent_secrets.find(Identifier(entry_name)); 
if (secret_lu == persistent_secrets.end()) { return nullptr; } - LocalFileSystem fs; + auto &fs = FileSystem::GetLocal(catalog.GetDatabase()); string base_secret_path = secret_manager.PersistentSecretPath(); string secret_path = fs.JoinPath(base_secret_path, entry_name + ".duckdb_secret"); @@ -712,20 +718,21 @@ unique_ptr DefaultSecretGenerator::CreateDefaultEntryInternal(cons } unique_ptr DefaultSecretGenerator::CreateDefaultEntry(CatalogTransaction transaction, - const string &entry_name) { + const Identifier &entry_name) { return CreateDefaultEntryInternal(entry_name); } -unique_ptr DefaultSecretGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { +unique_ptr DefaultSecretGenerator::CreateDefaultEntry(ClientContext &context, + const Identifier &entry_name) { return CreateDefaultEntryInternal(entry_name); } -vector DefaultSecretGenerator::GetDefaultEntries() { - vector ret; +vector DefaultSecretGenerator::GetDefaultEntries() { + vector ret; lock_guard guard(lock); for (const auto &res : persistent_secrets) { - ret.push_back(res); + ret.emplace_back(res.GetIdentifierName()); } return ret; @@ -738,8 +745,8 @@ SecretManager &SecretManager::Get(DatabaseInstance &db) { return *DBConfig::GetConfig(db).secret_manager; } -void SecretManager::DropSecretByName(ClientContext &context, const string &name, OnEntryNotFound on_entry_not_found, - SecretPersistType persist_type, const string &storage) { +void SecretManager::DropSecretByName(ClientContext &context, const Identifier &name, OnEntryNotFound on_entry_not_found, + SecretPersistType persist_type, const Identifier &storage) { auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); return DropSecretByName(transaction, name, on_entry_not_found, persist_type, storage); } diff --git a/src/duckdb/src/main/secret/secret_storage.cpp b/src/duckdb/src/main/secret/secret_storage.cpp index 717f4e81e..1173295ee 100644 --- a/src/duckdb/src/main/secret/secret_storage.cpp +++ 
b/src/duckdb/src/main/secret/secret_storage.cpp @@ -1,6 +1,8 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/local_file_system.hpp" +#include "duckdb/common/virtual_file_system.hpp" +#include "duckdb/main/database_file_opener.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/common/serializer/buffered_file_reader.hpp" @@ -85,9 +87,9 @@ vector CatalogSetSecretStorage::AllSecrets(optional_ptr transaction) { - auto entry = secrets->GetEntry(GetTransactionOrDefault(transaction), name); + auto entry = secrets->GetEntry(GetTransactionOrDefault(transaction), Identifier(name)); if (!entry) { if (on_entry_not_found == OnEntryNotFound::THROW_EXCEPTION) { string persist_string = persistent ? "persistent" : "temporary"; @@ -98,8 +100,8 @@ void CatalogSetSecretStorage::DropSecretByName(const string &name, OnEntryNotFou return; } - secrets->DropEntry(GetTransactionOrDefault(transaction), name, true, true); - RemoveSecret(name, on_entry_not_found); + secrets->DropEntry(GetTransactionOrDefault(transaction), Identifier(name), true, true); + RemoveSecret(name.GetIdentifierName(), on_entry_not_found); } SecretMatch CatalogSetSecretStorage::LookupSecret(const string &path, const string &type, @@ -108,7 +110,7 @@ SecretMatch CatalogSetSecretStorage::LookupSecret(const string &path, const stri const std::function callback = [&](CatalogEntry &entry) { auto &cast_entry = entry.Cast(); - if (StringUtil::CIEquals(cast_entry.secret->secret->GetType(), type)) { + if (cast_entry.secret->secret->GetType() == type) { best_match = SelectBestMatch(*cast_entry.secret, path, tie_break_offset, best_match); } }; @@ -123,7 +125,7 @@ SecretMatch CatalogSetSecretStorage::LookupSecret(const string &path, const stri unique_ptr CatalogSetSecretStorage::GetSecretByName(const string &name, optional_ptr transaction) { - auto res = secrets->GetEntry(GetTransactionOrDefault(transaction), 
name); + auto res = secrets->GetEntry(GetTransactionOrDefault(transaction), Identifier(name)); if (res) { auto &cast_entry = res->Cast(); @@ -140,16 +142,24 @@ LocalFileSecretStorage::LocalFileSecretStorage(SecretManager &manager, DatabaseI persistent = true; // Check existence of persistent secret dir - LocalFileSystem fs; - if (fs.DirectoryExists(secret_path)) { - fs.ListFiles(secret_path, [&](const string &fname, bool is_dir) { - string full_path = fs.JoinPath(secret_path, fname); - - if (StringUtil::EndsWith(full_path, ".duckdb_secret")) { - string secret_name = fname.substr(0, fname.size() - 14); // size of file ext - persistent_secrets.insert(secret_name); - } - }); + try { + auto &fs = FileSystem::GetLocal(db); + if (fs.DirectoryExists(secret_path)) { + fs.ListFiles(secret_path, [&](const string &fname, bool is_dir) { + string full_path = fs.JoinPath(secret_path, fname); + + if (StringUtil::EndsWith(full_path, ".duckdb_secret")) { + string secret_name = fname.substr(0, fname.size() - 14); // size of file ext + persistent_secrets.insert(Identifier(secret_name)); + } + }); + } + } catch (PermissionException &ex) { + // If LocalFileSystem is specifically disabled (not all external access), skip loading persistent secrets + auto &vfs = static_cast(*DBConfig::GetConfig(db).file_system); + if (!vfs.SubSystemIsDisabled("LocalFileSystem")) { + throw; + } } auto &catalog = Catalog::GetSystemCatalog(db); @@ -189,7 +199,7 @@ static void WriteSecretFileToDisk(FileSystem &fs, const string &path, const Base } void LocalFileSecretStorage::WriteSecret(const BaseSecret &secret, OnCreateConflict on_conflict) { - LocalFileSystem fs; + auto &fs = FileSystem::GetLocal(db); // We may need to create the secret dir here if the directory was not present during LocalFileSecretStorage // construction @@ -220,9 +230,9 @@ void LocalFileSecretStorage::WriteSecret(const BaseSecret &secret, OnCreateConfl } void LocalFileSecretStorage::RemoveSecret(const string &secret, OnEntryNotFound 
on_entry_not_found) { - LocalFileSystem fs; + auto &fs = FileSystem::GetLocal(db); string file = fs.JoinPath(secret_path, secret + ".duckdb_secret"); - persistent_secrets.erase(secret); + persistent_secrets.erase(Identifier(secret)); try { fs.RemoveFile(file); } catch (std::exception &ex) { diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index e10a5c4af..9438c93de 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -105,6 +105,20 @@ void DebugWindowModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) EnumUtil::FromString(StringValue::Get(parameter)); } +//===----------------------------------------------------------------------===// +// Default Io Mode +//===----------------------------------------------------------------------===// +void DefaultIoModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + +//===----------------------------------------------------------------------===// +// Default Transaction Invalidation Policy +//===----------------------------------------------------------------------===// +void DefaultTransactionInvalidationPolicySetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + //===----------------------------------------------------------------------===// // Deprecated Using Key Syntax //===----------------------------------------------------------------------===// @@ -149,6 +163,14 @@ void ForceBitpackingModeSetting::OnSet(SettingCallbackInfo &info, Value ¶met EnumUtil::FromString(StringValue::Get(parameter)); } +//===----------------------------------------------------------------------===// +// H T T P Proxy +//===----------------------------------------------------------------------===// +Value HTTPProxySetting::GetSetting(const ClientContext 
&context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.http_proxy); +} + //===----------------------------------------------------------------------===// // Lambda Syntax //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/main/settings/custom_settings.cpp b/src/duckdb/src/main/settings/custom_settings.cpp index 1e8632976..a4da981b8 100644 --- a/src/duckdb/src/main/settings/custom_settings.cpp +++ b/src/duckdb/src/main/settings/custom_settings.cpp @@ -16,6 +16,7 @@ #include "duckdb/common/enum_util.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/file_system.hpp" #include "duckdb/common/operator/double_cast_operator.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" @@ -78,9 +79,7 @@ bool AccessModeSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &config, cons // Allocator Background Threads //===----------------------------------------------------------------------===// void AllocatorBackgroundThreadsSetting::OnSet(SettingCallbackInfo &info, Value &input) { - if (info.db) { - TaskScheduler::GetScheduler(*info.db).SetAllocatorBackgroundThreads(input.GetValue()); - } + Allocator::SetBackgroundThreads(input.GetValue()); } //===----------------------------------------------------------------------===// @@ -128,23 +127,8 @@ Value DeltaOnlyVariantEncodingEnabledSetting::GetSetting(const ClientContext &co //===----------------------------------------------------------------------===// // Allocator Flush Threshold //===----------------------------------------------------------------------===// -void AllocatorFlushThresholdSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.allocator_flush_threshold = DBConfig::ParseMemoryLimit(input.ToString()); - if (db) { - 
TaskScheduler::GetScheduler(*db).SetAllocatorFlushThreshold(config.options.allocator_flush_threshold); - } -} - -void AllocatorFlushThresholdSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.allocator_flush_threshold = DBConfigOptions().allocator_flush_threshold; - if (db) { - TaskScheduler::GetScheduler(*db).SetAllocatorFlushThreshold(config.options.allocator_flush_threshold); - } -} - -Value AllocatorFlushThresholdSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::BytesToHumanReadableString(config.options.allocator_flush_threshold)); +void AllocatorFlushThresholdSetting::OnSet(SettingCallbackInfo &info, Value &input) { + StringUtil::ParseFormattedBytes(input.ToString()); } //===----------------------------------------------------------------------===// @@ -342,192 +326,20 @@ Value CheckpointThresholdSetting::GetSetting(const ClientContext &context) { //===----------------------------------------------------------------------===// // Configure Profiling //===----------------------------------------------------------------------===// -bool IsEnabledOptimizer(MetricType metric, const set &disabled_optimizers) { - auto matching_optimizer_type = MetricsUtils::GetOptimizerTypeByMetric(metric); - return matching_optimizer_type != OptimizerType::INVALID && - disabled_optimizers.find(matching_optimizer_type) == disabled_optimizers.end(); -} - -template -static profiler_settings_t ExtractSettings(ExtractFromType extract_from, const set &disabled_optimizers, - vector &invalid_settings) { - profiler_settings_t enabled_metrics; - - auto insert_if_enabled = [&](MetricType m) { - if (!MetricsUtils::IsOptimizerMetric(m) || IsEnabledOptimizer(m, disabled_optimizers)) { - enabled_metrics.insert(m); - } - }; - - extract_from([&](const std::string &metric) { - const auto upper = StringUtil::Upper(metric); - try { - insert_if_enabled(EnumUtil::FromString(upper)); - } catch 
(std::exception &) { - try { - auto group = EnumUtil::FromString(upper); - for (auto &converted_metric : MetricsUtils::GetMetricsByGroupType(group)) { - insert_if_enabled(converted_metric); - } - } catch (std::exception &) { - invalid_settings.push_back(metric); - } - } - }); - return enabled_metrics; -} - -void AddOptimizerMetrics(profiler_settings_t &settings, const set &disabled_optimizers) { - if (settings.find(MetricType::ALL_OPTIMIZERS) != settings.end()) { - auto optimizer_metrics = MetricsUtils::GetOptimizerMetrics(); - for (auto &metric : optimizer_metrics) { - if (IsEnabledOptimizer(metric, disabled_optimizers)) { - settings.insert(metric); - } - } - } -} - -void ExtractFromList(ClientConfig &config, profiler_settings_t &enabled_metrics, vector &invalid_settings, - const Value &input, const set &disabled_optimizers) { - config.profiler_settings_type = LogicalTypeId::LIST; - - enabled_metrics = ExtractSettings( - [&](const std::function &func) { - for (auto &val : ListValue::GetChildren(input)) { - func(val.GetValue()); - } - }, - disabled_optimizers, invalid_settings); -} - -void ExtractFromStruct(ClientConfig &config, profiler_settings_t &enabled_metrics, vector &invalid_settings, - const Value &input, const set &disabled_optimizers) { - config.profiler_settings_type = LogicalTypeId::STRUCT; - - enabled_metrics = ExtractSettings( - [&](const std::function &func) { - auto &children = StructValue::GetChildren(input); - for (idx_t i = 0; i < children.size(); i++) { - auto child_val = children[i]; - if ((child_val.type() == LogicalType::BOOLEAN && child_val.GetValue() == true) || - StringUtil::Lower(child_val.ToString()) == "true") { - func(StructType::GetChildName(input.type(), i)); - } - } - }, - disabled_optimizers, invalid_settings); -} - -void ExtractFromJSON(ClientConfig &config, profiler_settings_t &enabled_metrics, vector &invalid_settings, - const Value &input, const set &disabled_optimizers) { - config.profiler_settings_type = 
LogicalTypeId::VARCHAR; - - // JSON string: parse, then accept entries with value == "true" - std::unordered_map json; - try { - json = StringUtil::ParseJSONMap(input.ToString())->Flatten(); - } catch (std::exception &ex) { - throw IOException("Could not parse the custom profiler settings file due to incorrect JSON: \"%s\". Make " - "sure all the keys and values start with a quote. (error: %s)", - input.ToString(), ex.what()); - } - - enabled_metrics = ExtractSettings( - [&](const std::function &func) { - for (auto &entry : json) { - if (StringUtil::Lower(entry.second) == "true") { - func(entry.first); - } - } - }, - disabled_optimizers, invalid_settings); -} - -void ConstructInvalidSettingsAndThrow(const vector &invalid_settings) { - string invalid_settings_str; - for (auto &invalid_setting : invalid_settings) { - if (!invalid_settings_str.empty()) { - invalid_settings_str += ", "; - } - invalid_settings_str += invalid_setting; - } - throw IOException("Invalid custom profiler settings: \"%s\"", invalid_settings_str); -} - void ConfigureProfilingSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - - auto &db_config = DBConfig::GetConfig(context); - auto &disabled_optimizers = db_config.options.disabled_optimizers; - - vector invalid_settings; - profiler_settings_t enabled_metrics; - if (input.type() == LogicalType::LIST(LogicalType::VARCHAR)) { - ExtractFromList(config, enabled_metrics, invalid_settings, input, disabled_optimizers); - } else if (input.type().id() == LogicalTypeId::STRUCT) { - ExtractFromStruct(config, enabled_metrics, invalid_settings, input, disabled_optimizers); - } else if (input.type() == LogicalType::VARCHAR) { - ExtractFromJSON(config, enabled_metrics, invalid_settings, input, disabled_optimizers); - } else { - throw ParserException("Invalid custom profiler settings type \"%s\", expected LIST(VARCHAR) or JSON", - input.type().ToString()); - } - - if (!invalid_settings.empty()) { 
- ConstructInvalidSettingsAndThrow(invalid_settings); - } - - AddOptimizerMetrics(enabled_metrics, disabled_optimizers); - config.enable_profiler = true; - config.profiler_settings = enabled_metrics; + throw InvalidInputException( + "configure_profiling (and its aliases configure_metrics, custom_profiling_settings) is deprecated. " + "Use SET tracked_metrics = '...' instead. " + "For example: SET tracked_metrics = '*' to track all metrics, " + "or SET tracked_metrics = ['query.total_time', 'operator.*'] to track specific metrics."); } void ConfigureProfilingSetting::ResetLocal(ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - config.enable_profiler = ClientConfig().enable_profiler; - config.profiler_settings = MetricsUtils::GetDefaultMetrics(); - config.profiler_settings_type = LogicalTypeId::VARCHAR; + ClientConfig::GetConfig(context).tracked_metrics = ClientConfig().tracked_metrics; } Value ConfigureProfilingSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - - set enabled_settings; - for (auto &entry : config.profiler_settings) { - enabled_settings.insert(EnumUtil::ToString(entry)); - } - - switch (config.profiler_settings_type) { - case LogicalTypeId::VARCHAR: { - // i.e. 
JSON - string profiling_settings_str; - for (auto &entry : enabled_settings) { - if (!profiling_settings_str.empty()) { - profiling_settings_str += ", "; - } - profiling_settings_str += "\"" + entry + "\": \"true\""; - } - - return Value(StringUtil::Format("{%s}", profiling_settings_str)); - } - case LogicalTypeId::STRUCT: { - child_list_t children; - for (auto &entry : enabled_settings) { - children.emplace_back(entry, Value::BOOLEAN(true)); - } - return Value::STRUCT(std::move(children)); - } - case LogicalTypeId::LIST: { - vector children; - for (auto &entry : enabled_settings) { - children.emplace_back(entry); - } - return Value::LIST(std::move(children)); - } - default: - throw InternalException("Invalid custom profiler settings type"); - } + return TrackedMetricsSetting::GetSetting(context); } //===----------------------------------------------------------------------===// @@ -1070,18 +882,6 @@ void EnableProfilingSetting::SetLocal(ClientContext &context, const Value &input config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE; } else if (parameter == "query_tree_optimizer") { config.profiler_print_format = ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; - - // add optimizer settings to the profiler settings - auto optimizer_settings = MetricsUtils::GetOptimizerMetrics(); - for (auto &setting : optimizer_settings) { - config.profiler_settings.insert(setting); - } - - // add the phase timing settings to the profiler settings - auto phase_timing_settings = MetricsUtils::GetPhaseTimingMetrics(); - for (auto &setting : phase_timing_settings) { - config.profiler_settings.insert(setting); - } } else if (parameter == "no_output") { config.profiler_print_format = ProfilerPrintFormat::NO_OUTPUT; config.emit_profiler_output = false; @@ -1101,7 +901,6 @@ void EnableProfilingSetting::ResetLocal(ClientContext &context) { config.profiler_print_format = ClientConfig().profiler_print_format; config.enable_profiler = ClientConfig().enable_profiler; 
config.emit_profiler_output = ClientConfig().emit_profiler_output; - config.profiler_settings = ClientConfig().profiler_settings; } Value EnableProfilingSetting::GetSetting(const ClientContext &context) { @@ -1217,37 +1016,6 @@ void HomeDirectorySetting::OnSet(SettingCallbackInfo &info, Value &input) { } } -//===----------------------------------------------------------------------===// -// Enable H T T P Logging -//===----------------------------------------------------------------------===// -void EnableHTTPLoggingSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.enable_http_logging = input.GetValue(); - - // NOTE: this is a deprecated setting: we mimic the old behaviour by setting the log storage output to STDOUT and - // enabling logging for http only. Note that this behaviour is slightly wonky in that it sets all sorts of logging - // config - auto &log_manager = LogManager::Get(context); - if (config.enable_http_logging) { - log_manager.SetEnableLogging(true); - log_manager.SetLogLevel(HTTPLogType::LEVEL); - unordered_set enabled_log_types = {HTTPLogType::NAME}; - log_manager.SetEnabledLogTypes(enabled_log_types); - log_manager.SetLogStorage(*context.db, LogConfig::STDOUT_STORAGE_NAME); - } else { - log_manager.SetEnableLogging(false); - } -} - -void EnableHTTPLoggingSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).enable_http_logging = ClientConfig().enable_http_logging; -} - -Value EnableHTTPLoggingSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.enable_http_logging); -} - //===----------------------------------------------------------------------===// // Enable Mbedtls //===----------------------------------------------------------------------===// @@ -1287,21 +1055,14 @@ Value ForceMbedtlsUnsafeSetting::GetSetting(const ClientContext &context) { } 
//===----------------------------------------------------------------------===// -// H T T P Logging Output +// HTTP Proxy //===----------------------------------------------------------------------===// -void HTTPLoggingOutputSetting::SetLocal(ClientContext &context, const Value &input) { - throw NotImplementedException("This setting is deprecated and can no longer be used. Check out the DuckDB docs on " - "logging for more information"); +void HTTPProxySetting::SetGlobal(DatabaseInstance *, DBConfig &config, const Value &input) { + config.options.http_proxy = input.GetValue(); } -void HTTPLoggingOutputSetting::ResetLocal(ClientContext &context) { - throw NotImplementedException("This setting is deprecated and can no longer be used. Check out the DuckDB docs on " - "logging for more information"); -} - -Value HTTPLoggingOutputSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value(config.http_logging_output); +void HTTPProxySetting::ResetGlobal(DatabaseInstance *, DBConfig &config) { + config.options.http_proxy = FileSystem::GetEnvVariable("HTTP_PROXY"); } //===----------------------------------------------------------------------===// @@ -1314,6 +1075,20 @@ void IndexScanPercentageSetting::OnSet(SettingCallbackInfo &, Value &input) { } } +//===----------------------------------------------------------------------===// +// Initial Column Segment Size +//===----------------------------------------------------------------------===// +void InitialColumnSegmentSizeSetting::OnSet(SettingCallbackInfo &, Value &input) { + auto initial_column_segment_size = input.GetValue(); + if (initial_column_segment_size < DEFAULT_BLOCK_HEADER_STORAGE_SIZE) { + throw InvalidInputException( + "The initial column segment size must be at least 8 bytes (default block header storage size)"); + } + if (!IsPowerOfTwo(initial_column_segment_size)) { + throw InvalidInputException("The initial column segment size must be a power of 
two"); + } +} + //===----------------------------------------------------------------------===// // Log Query Path //===----------------------------------------------------------------------===// @@ -1498,34 +1273,19 @@ void ProfilingModeSetting::SetLocal(ClientContext &context, const Value &input) } else if (parameter == "detailed") { config.enable_profiler = true; config.enable_detailed_profiling = true; - - // add optimizer settings to the profiler settings - auto optimizer_settings = MetricsUtils::GetOptimizerMetrics(); - for (auto &setting : optimizer_settings) { - config.profiler_settings.insert(setting); - } - - // add the phase timing settings to the profiler settings - auto phase_timing_settings = MetricsUtils::GetPhaseTimingMetrics(); - for (auto &setting : phase_timing_settings) { - config.profiler_settings.insert(setting); - } } else if (parameter == "all") { config.enable_profiler = true; - auto all_metrics = MetricsUtils::GetAllMetrics(); - for (auto &metric : all_metrics) { - config.profiler_settings.insert(metric); - } } else { - throw ParserException("Unrecognized profiling mode \"%s\", supported formats: [standard, detailed]", parameter); + throw ParserException("Unrecognized profiling mode \"%s\", supported formats: [standard, detailed, all]", + parameter); } } void ProfilingModeSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).enable_profiler = ClientConfig().enable_profiler; - ClientConfig::GetConfig(context).enable_detailed_profiling = ClientConfig().enable_detailed_profiling; - ClientConfig::GetConfig(context).emit_profiler_output = ClientConfig().emit_profiler_output; - ClientConfig::GetConfig(context).profiler_settings = ClientConfig().profiler_settings; + auto &config = ClientConfig::GetConfig(context); + config.enable_profiler = ClientConfig().enable_profiler; + config.enable_detailed_profiling = ClientConfig().enable_detailed_profiling; + config.emit_profiler_output = ClientConfig().emit_profiler_output; } 
Value ProfilingModeSetting::GetSetting(const ClientContext &context) { @@ -1638,21 +1398,36 @@ Value SecretDirectorySetting::GetSetting(const ClientContext &context) { //===----------------------------------------------------------------------===// void StorageCompatibilityVersionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { auto version_string = input.GetValue(); - auto serialization_compatibility = SerializationCompatibility::FromString(version_string); - config.options.serialization_compatibility = serialization_compatibility; + auto storage_compatibility = StorageCompatibility::FromString(version_string); + config.options.storage_compatibility = storage_compatibility; } void StorageCompatibilityVersionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.serialization_compatibility = DBConfigOptions().serialization_compatibility; + config.options.storage_compatibility = DBConfigOptions().storage_compatibility; } Value StorageCompatibilityVersionSetting::GetSetting(const ClientContext &context) { auto &config = DBConfig::GetConfig(context); - auto &version_name = config.options.serialization_compatibility.duckdb_version; + auto &version_name = config.options.storage_compatibility.duckdb_version; return Value(version_name); } +//===----------------------------------------------------------------------===// +// Standard Vector Size +//===----------------------------------------------------------------------===// +void StandardVectorSizeSetting::SetGlobal(DatabaseInstance *, DBConfig &, const Value &) { + throw InvalidInputException("standard_vector_size is a read-only setting determined at compile time"); +} + +void StandardVectorSizeSetting::ResetGlobal(DatabaseInstance *, DBConfig &) { + throw InvalidInputException("standard_vector_size is a read-only setting determined at compile time"); +} + +Value StandardVectorSizeSetting::GetSetting(const ClientContext &) { + return 
Value::UBIGINT(DuckDB::StandardVectorSize()); +} + //===----------------------------------------------------------------------===// // Streaming Buffer Size //===----------------------------------------------------------------------===// @@ -1717,6 +1492,35 @@ void TempFileEncryptionSetting::OnSet(SettingCallbackInfo &info, Value &input) { } } +//===----------------------------------------------------------------------===// +// Tracked Metrics +//===----------------------------------------------------------------------===// +void TrackedMetricsSetting::SetLocal(ClientContext &context, const Value &input) { + auto &config = ClientConfig::GetConfig(context); + config.tracked_metrics.clear(); + if (input.type() == LogicalType::LIST(LogicalType::VARCHAR)) { + for (auto &child : ListValue::GetChildren(input)) { + config.tracked_metrics.push_back(child.GetValue()); + } + } else { + throw InvalidInputException("Invalid tracked_metrics type \"%s\", expected LIST(VARCHAR)", + input.type().ToString()); + } +} + +void TrackedMetricsSetting::ResetLocal(ClientContext &context) { + ClientConfig::GetConfig(context).tracked_metrics = ClientConfig().tracked_metrics; +} + +Value TrackedMetricsSetting::GetSetting(const ClientContext &context) { + auto &config = ClientConfig::GetConfig(context); + vector children; + for (const auto &pattern : config.tracked_metrics) { + children.emplace_back(pattern); + } + return Value::LIST(LogicalType::VARCHAR, std::move(children)); +} + //===----------------------------------------------------------------------===// // Threads //===----------------------------------------------------------------------===// @@ -1745,6 +1549,34 @@ Value ThreadsSetting::GetSetting(const ClientContext &context) { return Value::BIGINT(NumericCast(config.options.maximum_threads)); } +//===----------------------------------------------------------------------===// +// Async Threads +//===----------------------------------------------------------------------===// +void 
AsyncThreadsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_val = input.GetValue(); + if (new_val < 0) { + throw SyntaxException("Cannot have negative async_threads!"); + } + auto new_async_threads = NumericCast(new_val); + if (db) { + TaskScheduler::GetScheduler(*db).SetAsyncThreads(new_async_threads); + } + config.options.async_threads = new_async_threads; +} + +void AsyncThreadsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + idx_t new_async_threads = config.GetSystemMaxAsyncThreads(*config.file_system); + if (db) { + TaskScheduler::GetScheduler(*db).SetAsyncThreads(new_async_threads); + } + config.options.async_threads = new_async_threads; +} + +Value AsyncThreadsSetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value::BIGINT(NumericCast(config.options.async_threads)); +} + //===----------------------------------------------------------------------===// // Warnings As Errors //===----------------------------------------------------------------------===// @@ -1758,6 +1590,32 @@ void WarningsAsErrorsSetting::OnSet(SettingCallbackInfo &info, Value &input) { } } +//===----------------------------------------------------------------------===// +// Streaming Buffer Size +//===----------------------------------------------------------------------===// +void WriteBufferRowGroupMemoryLimitSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (input.IsNull() || input.ToString().empty()) { + config.options.write_buffer_row_group_memory_limit = optional_idx(); + } else { + config.options.write_buffer_row_group_memory_limit = DBConfig::ParseMemoryLimit(input.ToString()); + } +} + +void WriteBufferRowGroupMemoryLimitSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + config.options.write_buffer_row_group_memory_limit = optional_idx(); +} + +Value WriteBufferRowGroupMemoryLimitSetting::GetSetting(const ClientContext 
&context) { + auto &config = DBConfig::GetConfig(context); + idx_t bytes = 0; + if (config.options.write_buffer_row_group_memory_limit.IsValid()) { + bytes = config.options.write_buffer_row_group_memory_limit.GetIndex(); + } else { + bytes = config.options.maximum_memory / 5 / (config.options.maximum_threads + 1); + } + return Value(StringUtil::BytesToHumanReadableString(bytes)); +} + void CurrentTransactionInvalidationPolicySetting::OnSet(SettingCallbackInfo &info, Value &input) { if (!info.context) { throw InvalidInputException( diff --git a/src/duckdb/src/main/settings/settings.cpp b/src/duckdb/src/main/settings/settings.cpp index 3ab1a7d4a..321555919 100644 --- a/src/duckdb/src/main/settings/settings.cpp +++ b/src/duckdb/src/main/settings/settings.cpp @@ -17,4 +17,26 @@ bool Settings::TryGetSettingInternal(const DatabaseInstance &db, idx_t setting_i return TryGetSettingInternal(DBConfig::GetConfig(db), setting_index, result); } +void Settings::SetSettingInternal(DatabaseInstance &db, idx_t setting_index, SetScope scope, Value target_value) { + SetSettingInternal(DBConfig::GetConfig(db), setting_index, scope, std::move(target_value)); +} + +void Settings::SetSettingInternal(DBConfig &config, idx_t setting_index, SetScope scope, Value target_value) { + if (scope != SetScope::GLOBAL) { + throw InternalException( + "Attempting to set setting with index %d with non-global scope, but no ClientContext was provided", + setting_index); + } + config.SetOption(setting_index, std::move(target_value)); +} + +void Settings::SetSettingInternal(ClientContext &context, idx_t setting_index, SetScope scope, Value target_value) { + if (scope == SetScope::GLOBAL) { + SetSettingInternal(DBConfig::GetConfig(context), setting_index, scope, std::move(target_value)); + return; + } + auto &client_config = ClientConfig::GetConfig(context); + client_config.user_settings.SetUserSetting(setting_index, std::move(target_value)); +} + } // namespace duckdb diff --git 
a/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp b/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp index b6e2e6925..536494c66 100644 --- a/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp +++ b/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp @@ -3,6 +3,7 @@ #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/function/function_binder.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/optimizer/matcher/expression_matcher.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/binder.hpp" @@ -58,7 +59,14 @@ class AvgRewriteRule : public AggregateRewriteRule { public: bool ShouldSkip(const LogicalAggregate &aggr) const override { - return !aggr.grouping_functions.empty() || aggr.grouping_sets.size() > 1; + // AVG -> SUM/COUNT is correct under any grouping (both ignore NULLs + // identically), so ROLLUP/CUBE/GROUPING SETS are safe to rewrite. + // Decomposing here is also the prerequisite for partial-aggregate + // pushdown to fire on AVG queries. The remaining guard is for + // grouping_functions: their per-row grouping-set bitmasks reference + // the original AVG column directly, and the rewrite has no way to + // translate those references. 
+ return !aggr.grouping_functions.empty(); } unique_ptr Rewrite(unique_ptr &expr, vector> &bindings, @@ -67,10 +75,11 @@ class AvgRewriteRule : public AggregateRewriteRule { FunctionBinder function_binder(optimizer.context); // Move the child out of AVG(x) - auto avg_child = std::move(bindings[0].get().Cast().children[0]); + auto avg_child = std::move(bindings[0].get().Cast().GetChildrenMutable()[0]); // Replace AVG(x) with SUM(x) - auto &sum_entry = catalog.GetEntry(optimizer.context, DEFAULT_SCHEMA, "sum"); + auto &sum_entry = + catalog.GetEntry(optimizer.context, Identifier::DefaultSchema(), "sum"); const auto &sum_fun = sum_entry.functions.GetFunctionByArguments(optimizer.context, {avg_child->GetReturnType()}); vector> args; @@ -123,12 +132,12 @@ class SumRewriteRule : public AggregateRewriteRule { vector> &additional_expressions) override { auto &sum = bindings[0].get().Cast(); auto &addition = bindings[1].get().Cast(); - idx_t const_idx = addition.children[0]->GetExpressionType() == ExpressionType::VALUE_CONSTANT ? 0 : 1; - auto const_expr = std::move(addition.children[const_idx]); - auto main_expr = std::move(addition.children[1 - const_idx]); + idx_t const_idx = addition.GetChildren()[0]->GetExpressionType() == ExpressionType::VALUE_CONSTANT ? 
0 : 1; + auto const_expr = std::move(addition.GetChildrenMutable()[const_idx]); + auto main_expr = std::move(addition.GetChildrenMutable()[1 - const_idx]); // Turn SUM(x + C) into SUM(x) - sum.children[0] = main_expr->Copy(); + sum.GetChildrenMutable()[0] = main_expr->Copy(); additional_expressions.push_back(std::move(const_expr)); return main_expr; @@ -173,20 +182,20 @@ class OrderedAggregateMatcher : public ExpressionMatcher { return false; } auto &expr = expr_p.Cast(); - if (!FunctionMatcher::Match(function, expr.function.GetName())) { + if (!FunctionMatcher::Match(function, expr.Function().GetName())) { return false; } - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + if (!SetMatcher::Match(matchers, expr.GetChildrenMutable(), bindings, policy)) { return false; } - if (!expr.order_bys && !order_bys.empty()) { + if (!expr.GetOrderBys() && !order_bys.empty()) { return false; } - if (order_bys.size() != expr.order_bys->orders.size()) { + if (order_bys.size() != expr.GetOrderBys()->orders.size()) { return false; } for (idx_t i = 0; i < order_bys.size(); ++i) { - if (!order_bys[i]->Match(*expr.order_bys->orders[i].expression, bindings)) { + if (!order_bys[i]->Match(*expr.GetOrderBys()->orders[i].expression, bindings)) { return false; } } @@ -232,7 +241,7 @@ unique_ptr ListRewriteRule::Rewrite(unique_ptr &expr, ve vector> &additional_expressions) { auto &aggr = bindings[0].get().Cast(); - auto &order_bys = aggr.order_bys; + auto &order_bys = aggr.GetOrderBysMutable(); auto &order_by = order_bys->orders[0]; auto sense = make_uniq(EnumUtil::ToChars(order_by.type)); @@ -295,9 +304,9 @@ class AggregateFunctionRewriterInternal : public LogicalOperatorVisitor { } unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *) override { - const auto entry = aggregate_map.find(expr.binding); + const auto entry = aggregate_map.find(expr.Binding()); if (entry != aggregate_map.end()) { - expr.binding = entry->second; + expr.BindingMutable() = 
entry->second; } return nullptr; } @@ -322,7 +331,7 @@ class AggregateFunctionRewriterInternal : public LogicalOperatorVisitor { // Add COUNT([x]) to the aggregate list FunctionBinder function_binder(optimizer.context); - const auto count_fun = CountFunctionBase::GetFunction(); + const auto count_fun = count_arg ? CountFunctionBase::GetFunction() : CountStarFun::GetFunction(); vector> count_args; if (count_arg) { count_args.push_back(std::move(count_arg)); diff --git a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp index 15910474a..d2cb6e05a 100644 --- a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp +++ b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp @@ -4,12 +4,16 @@ #include "duckdb/common/type_visitor.hpp" #include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/execution/ht_entry.hpp" +#include "duckdb/execution/operator/join/join_filter_pushdown.hpp" #include "duckdb/planner/operator/logical_any_join.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_cte.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_join.hpp" #include "duckdb/planner/operator/logical_order.hpp" #include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/planner/operator/logical_projection.hpp" @@ -17,6 +21,11 @@ namespace duckdb { +struct JoinFilterBuildSideHeuristics { + static constexpr idx_t MIN_FILTER_TARGET_CARDINALITY = 1000000; + static constexpr idx_t MAX_BUILD_TO_TARGET_RATIO = 64; +}; + static void GetRowidBindings(LogicalOperator &op, vector &bindings) { if (op.type == LogicalOperatorType::LOGICAL_GET) { auto &get = op.Cast(); @@ -98,6 
+107,51 @@ static inline idx_t ComputeOverlappingBindings(const vector &hays return result; } +static idx_t MaxDynamicFilterTargetCardinality(LogicalOperator &op, const Expression &expr) { + JoinFilterPushdownColumn column; + if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(expr, column)) { + return 0; + } + + vector columns {std::move(column)}; + vector targets; + JoinFilterPushdownOptimizer::GetPushdownFilterTargets(op, std::move(columns), targets); + + // Dynamic filters are applied at the scan target found here. Estimate the work that the + // filter can avoid at that point, even if later filters reduce the cardinality at the join. + idx_t result = 0; + for (auto &target : targets) { + auto &get = target.get; + result = MaxValue(result, get.has_estimated_cardinality ? get.estimated_cardinality : idx_t(0)); + } + return result; +} + +static double DynamicFilterBuildBonus(LogicalComparisonJoin &join, const idx_t probe_idx, const idx_t build_idx, + const idx_t build_cardinality) { + if (!JoinFilterPushdownOptimizer::IsFiltering(join.children[build_idx])) { + return 1; + } + + idx_t max_target_cardinality = 0; + for (auto &cond : join.conditions) { + if (!cond.IsComparison() || cond.GetComparisonType() != ExpressionType::COMPARE_EQUAL) { + continue; + } + auto &probe_expr = probe_idx == 0 ? 
cond.GetLHS() : cond.GetRHS(); + max_target_cardinality = + MaxValue(max_target_cardinality, MaxDynamicFilterTargetCardinality(*join.children[probe_idx], probe_expr)); + } + if (max_target_cardinality < JoinFilterBuildSideHeuristics::MIN_FILTER_TARGET_CARDINALITY) { + return 1; + } + if (build_cardinality > 0 && + build_cardinality > max_target_cardinality / JoinFilterBuildSideHeuristics::MAX_BUILD_TO_TARGET_RATIO) { + return 1; + } + return static_cast(max_target_cardinality) / static_cast(MaxValue(build_cardinality, 1)); +} + BuildSize BuildProbeSideOptimizer::GetBuildSizes(const LogicalOperator &op, const idx_t lhs_cardinality, const idx_t rhs_cardinality) { BuildSize ret; @@ -164,6 +218,15 @@ idx_t BuildProbeSideOptimizer::ChildHasJoins(LogicalOperator &op) { } bool BuildProbeSideOptimizer::TryFlipJoinChildren(LogicalOperator &op) const { + auto recursive_preference = GetRecursiveProbeSidePreference(op); + if (recursive_preference != RecursiveProbeSidePreference::NONE) { + if (recursive_preference == RecursiveProbeSidePreference::SWAP) { + FlipChildren(op); + return true; + } + return false; + } + auto &left_child = *op.children[0]; auto &right_child = *op.children[1]; const auto lhs_cardinality = left_child.has_estimated_cardinality ? 
left_child.estimated_cardinality @@ -177,6 +240,17 @@ bool BuildProbeSideOptimizer::TryFlipJoinChildren(LogicalOperator &op) const { bool swap = false; + if (op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + auto &join = op.Cast(); + if (JoinFilterPushdownUtil::JoinTypeIsSupported(join.join_type)) { + right_side_build_cost /= DynamicFilterBuildBonus(join, 0, 1, rhs_cardinality); + } + if (HasInverseJoinType(join.join_type) && + JoinFilterPushdownUtil::JoinTypeIsSupported(InverseJoinType(join.join_type))) { + left_side_build_cost /= DynamicFilterBuildBonus(join, 1, 0, lhs_cardinality); + } + } + idx_t left_child_joins = ChildHasJoins(*op.children[0]); idx_t right_child_joins = ChildHasJoins(*op.children[1]); // if the right child is a table scan, and the left child has joins, we should prefer the left child @@ -218,7 +292,94 @@ bool BuildProbeSideOptimizer::TryFlipJoinChildren(LogicalOperator &op) const { return swap; } +bool BuildProbeSideOptimizer::ContainsActiveRecursiveReference(const LogicalOperator &op) const { + if (active_recursive_cte_indexes.empty()) { + return false; + } + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cte_ref = op.Cast(); + for (auto &active_cte_index : active_recursive_cte_indexes) { + if (cte_ref.cte_index == active_cte_index) { + return true; + } + } + } + for (auto &child : op.children) { + if (ContainsActiveRecursiveReference(*child)) { + return true; + } + } + return false; +} + +bool BuildProbeSideOptimizer::ContainsCorrelationSensitiveOperators(const LogicalOperator &op) const { + switch (op.type) { + case LogicalOperatorType::LOGICAL_DELIM_GET: + case LogicalOperatorType::LOGICAL_DELIM_JOIN: + case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + return true; + default: + break; + } + for (auto &child : op.children) { + if (ContainsCorrelationSensitiveOperators(*child)) { + return true; + } + } + return false; +} + +RecursiveProbeSidePreference 
BuildProbeSideOptimizer::GetRecursiveProbeSidePreference(const LogicalOperator &op) const { + if (active_recursive_cte_indexes.empty() || op.type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return RecursiveProbeSidePreference::NONE; + } + + auto &join = op.Cast(); + idx_t has_range = 0; + bool prefer_range_joins = Settings::Get(context); + if (!join.HasEquality(has_range) || prefer_range_joins) { + return RecursiveProbeSidePreference::NONE; + } + + auto left_depends_on_recursive_cte = ContainsActiveRecursiveReference(*op.children[0]); + auto right_depends_on_recursive_cte = ContainsActiveRecursiveReference(*op.children[1]); + if (left_depends_on_recursive_cte == right_depends_on_recursive_cte) { + return RecursiveProbeSidePreference::NONE; + } + + const auto ¤t_build_side = *op.children[1]; + auto current_orientation_can_reuse = left_depends_on_recursive_cte && !PropagatesBuildSide(join.join_type) && + !ContainsCorrelationSensitiveOperators(current_build_side); + if (current_orientation_can_reuse) { + return RecursiveProbeSidePreference::KEEP; + } + + if (!HasInverseJoinType(join.join_type)) { + return RecursiveProbeSidePreference::NONE; + } + + const auto &swapped_build_side = *op.children[0]; + auto swapped_orientation_can_reuse = right_depends_on_recursive_cte && + !PropagatesBuildSide(InverseJoinType(join.join_type)) && + !ContainsCorrelationSensitiveOperators(swapped_build_side); + if (swapped_orientation_can_reuse) { + return RecursiveProbeSidePreference::SWAP; + } + + return RecursiveProbeSidePreference::NONE; +} + void BuildProbeSideOptimizer::VisitOperator(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_RECURSIVE_CTE) { + VisitOperator(*op.children[0]); + active_recursive_cte_indexes.push_back(op.Cast().table_index); + VisitOperator(*op.children[1]); + active_recursive_cte_indexes.pop_back(); + return; + } + + VisitOperatorChildren(op); + // then the currentoperator switch (op.type) { case 
LogicalOperatorType::LOGICAL_DELIM_JOIN: { @@ -273,8 +434,6 @@ void BuildProbeSideOptimizer::VisitOperator(LogicalOperator &op) { default: break; } - - VisitOperatorChildren(op); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/column_binding_replacer.cpp b/src/duckdb/src/optimizer/column_binding_replacer.cpp index 926bf3497..85f83f2ee 100644 --- a/src/duckdb/src/optimizer/column_binding_replacer.cpp +++ b/src/duckdb/src/optimizer/column_binding_replacer.cpp @@ -28,8 +28,8 @@ void ColumnBindingReplacer::VisitExpression(unique_ptr *expression) if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { auto &bound_column_ref = expr->Cast(); for (const auto &replace_binding : replacement_bindings) { - if (bound_column_ref.binding == replace_binding.old_binding) { - bound_column_ref.binding = replace_binding.new_binding; + if (bound_column_ref.Binding() == replace_binding.old_binding) { + bound_column_ref.BindingMutable() = replace_binding.new_binding; if (replace_binding.replace_type) { bound_column_ref.SetReturnType(replace_binding.new_type); } diff --git a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp index abbc0370e..66a13a9a8 100644 --- a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp +++ b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp @@ -52,7 +52,7 @@ void ColumnLifetimeAnalyzer::StandardVisitOperator(LogicalOperator &op) { void ColumnLifetimeAnalyzer::ExtractColumnBindings(const Expression &expr, vector &bindings) { ExpressionIterator::VisitExpression( - expr, [&](const BoundColumnRefExpression &bound_ref) { bindings.push_back(bound_ref.binding); }); + expr, [&](const BoundColumnRefExpression &bound_ref) { bindings.push_back(bound_ref.Binding()); }); } void ColumnLifetimeAnalyzer::VisitOperator(LogicalOperator &op) { @@ -268,7 +268,7 @@ void ColumnLifetimeAnalyzer::AddVerificationProjection(unique_ptr 
ColumnLifetimeAnalyzer::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { - column_references.insert(expr.binding); + column_references.insert(expr.Binding()); return nullptr; } diff --git a/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp index c10be177d..d351ccb6a 100644 --- a/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp +++ b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp @@ -37,9 +37,9 @@ void CommonAggregateOptimizer::VisitOperator(LogicalOperator &op) { unique_ptr CommonAggregateOptimizer::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { // check if this column ref points to an aggregate that was remapped; if it does we remap it - auto entry = aggregate_map.find(expr.binding); + auto entry = aggregate_map.find(expr.Binding()); if (entry != aggregate_map.end()) { - expr.binding = entry->second; + expr.BindingMutable() = entry->second; } return nullptr; } diff --git a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp index eb92be6eb..3b73f6843 100644 --- a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp +++ b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp @@ -317,10 +317,8 @@ class PlanSignatureTableIndexMap { if (op.type == LogicalOperatorType::LOGICAL_GET) { auto &get = op.Cast(); for (auto &entry : get.table_filters) { - if (entry.Filter().filter_type != TableFilterType::EXPRESSION_FILTER) { - continue; - } - auto &expression_filter = entry.Filter().Cast(); + auto &expression_filter = + ExpressionFilter::GetExpressionFilter(entry.Filter(), "CommonSubplanOptimizer::ConvertExpressions"); ConvertExpression(*expression_filter.expr, info_idx, can_materialize); } } @@ -334,17 +332,17 @@ class PlanSignatureTableIndexMap { // Replace column binding if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { - auto &column_binding = expr.Cast().binding; + auto 
&col_ref = expr.Cast(); const auto lookup_idx = TYPE == ConversionType::TO_CANONICAL - ? column_binding.table_index - : restore_original_table_index.at(column_binding.table_index); + ? col_ref.Binding().table_index + : restore_original_table_index.at(col_ref.Binding().table_index); auto &table_map = table_index_map.at(lookup_idx); if (!table_map.Empty()) { // Replace column index - column_binding.column_index = table_map.Get(column_binding.column_index); + col_ref.BindingMutable().column_index = table_map.Get(col_ref.Binding().column_index); } // Replace table index - column_binding.table_index = table_index_mapping.at(column_binding.table_index); + col_ref.BindingMutable().table_index = table_index_mapping.at(col_ref.Binding().table_index); } // Replace default fields @@ -356,7 +354,7 @@ class PlanSignatureTableIndexMap { break; case ConversionType::RESTORE_ORIGINAL: auto &info = expression_info[info_idx++]; - expr.SetAlias(std::move(info.first)); + expr.SetAlias(Identifier(std::move(info.first))); expr.SetQueryLocation(info.second); break; } @@ -793,7 +791,7 @@ class CommonSubplanFinder { // Get types and names const auto &primary_subplan = subplan_info.subplans[0]; const auto &types = primary_subplan.op.get()->types; - vector col_names; + vector col_names; for (idx_t i = 0; i < types.size(); i++) { col_names.emplace_back(StringUtil::Format("%s_col_%llu", cte_name, i + 1)); } @@ -869,7 +867,7 @@ class CommonSubplanFinder { auto materialized_projection = make_uniq(optimizer.binder.GenerateTableIndex(), std::move(materialized_select_list)); materialized_projection->children.emplace_back(std::move(materialized_subplan)); - auto cte = make_uniq(cte_name, cte_index, materialized_column_count, + auto cte = make_uniq(Identifier(cte_name), cte_index, materialized_column_count, std::move(materialized_projection), std::move(remainder), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); for (idx_t subplan_idx = 0; subplan_idx < subplan_info.subplans.size(); subplan_idx++) { diff 
--git a/src/duckdb/src/optimizer/compressed_materialization.cpp b/src/duckdb/src/optimizer/compressed_materialization.cpp index 7199df1ab..c8be97512 100644 --- a/src/duckdb/src/optimizer/compressed_materialization.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization.cpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/scalar/compressed_materialization_utils.hpp" #include "duckdb/function/scalar/nested_functions.hpp" @@ -8,6 +9,7 @@ #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/optimizer/topn_optimizer.hpp" #include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" @@ -15,6 +17,9 @@ namespace duckdb { +//===--------------------------------------------------------------------===// +// Utility +//===--------------------------------------------------------------------===// CMChildInfo::CMChildInfo(LogicalOperator &op, const column_binding_set_t &referenced_bindings) : bindings_before(op.GetColumnBindings()), types(op.types), can_compress(bindings_before.size(), true) { for (const auto &binding : referenced_bindings) { @@ -27,7 +32,7 @@ CMChildInfo::CMChildInfo(LogicalOperator &op, const column_binding_set_t &refere } CMBindingInfo::CMBindingInfo(ColumnBinding binding_p, const LogicalType &type_p) - : binding(binding_p), type(type_p), needs_decompression(false) { + : binding(binding_p), type(type_p), materialization_type(CompressedMaterializationType::INVALID) { } CompressedMaterializationInfo::CompressedMaterializationInfo(LogicalOperator &op, vector &&child_idxs_p, @@ -39,10 +44,278 @@ CompressedMaterializationInfo::CompressedMaterializationInfo(LogicalOperator 
&op } } -CompressExpression::CompressExpression(unique_ptr expression_p, unique_ptr stats_p) - : expression(std::move(expression_p)), stats(std::move(stats_p)) { +CompressExpression::CompressExpression(unique_ptr expression_p, unique_ptr stats_p, + CompressedMaterializationType materialization_type_p) + : expression(std::move(expression_p)), stats(std::move(stats_p)), materialization_type(materialization_type_p) { } +//===--------------------------------------------------------------------===// +// CMHelper hpp +//===--------------------------------------------------------------------===// +struct CMHelper { + static bool TypeRequiresRestore(CompressedMaterializationType materialization_type); + + static unique_ptr CreateProjection(const Optimizer &optimizer, const LogicalOperator &source, + vector> projections); + + static void RemapBindingMap(CompressedMaterializationInfo &info, + const vector &replacement_bindings); + + static void AddCompressProjectionStats(statistics_map_t &statistics_map, const vector &bindings, + vector> &compress_exprs); + + static Value GetIntegralRangeValue(ClientContext &context, const LogicalType &type, const BaseStatistics &stats); + static LogicalType GetIntegralOffsetType(uint64_t range); + static bool GetIntegralOffsetCompressInfo(ClientContext &context, const LogicalType &type, + const BaseStatistics &stats, LogicalType &offset_type, Value &min, + Value &range_value); + static LogicalType GetSameWidthIntegralType(const LogicalType &type, bool use_signed); + static bool ValuePreservingCastFits(const Value &value, const LogicalType &source_type, + const LogicalType &target_type); + static LogicalType GetIntegralCastType(const LogicalType &source_type, const LogicalType &offset_type, + const BaseStatistics &stats); + static unique_ptr CreateIntegralCastStats(const LogicalType &target_type, + const BaseStatistics &stats); + static unique_ptr CreateIntegralCastCompress(ClientContext &context, + unique_ptr input, + const LogicalType 
&target_type, + const BaseStatistics &stats); + static unique_ptr CreateIntegralFunctionCompress(unique_ptr input, + const LogicalType &source_type, + const LogicalType &target_type, + const Value &min, const Value &range_value, + const BaseStatistics &stats); + + static unique_ptr CreateStringCompressStats(const BaseStatistics &stats, LogicalType &target_type, + const uint32_t max_string_length); + static bool GetStringCompressInfo(const BaseStatistics &stats, LogicalType &target_type, + uint32_t &max_string_length); + static unique_ptr CreateStringFunctionCompress(unique_ptr input, + const LogicalType &target_type, + unique_ptr compress_stats); +}; + +//===--------------------------------------------------------------------===// +// CMHelper cpp +//===--------------------------------------------------------------------===// +bool CMHelper::TypeRequiresRestore(CompressedMaterializationType materialization_type) { + return materialization_type != CompressedMaterializationType::INVALID; +} + +unique_ptr CMHelper::CreateProjection(const Optimizer &optimizer, const LogicalOperator &source, + vector> projections) { + const auto table_index = optimizer.binder.GenerateTableIndex(); + auto projection = make_uniq(table_index, std::move(projections)); + if (source.has_estimated_cardinality) { + projection->SetEstimatedCardinality(source.estimated_cardinality); + } + return projection; +} + +void CMHelper::RemapBindingMap(CompressedMaterializationInfo &info, + const vector &replacement_bindings) { + auto &binding_map = info.binding_map; + for (const auto &replacement_binding : replacement_bindings) { + auto it = binding_map.find(replacement_binding.old_binding); + if (it == binding_map.end()) { + continue; + } + auto &binding_info = it->second; + if (binding_info.binding == replacement_binding.old_binding) { + binding_info.binding = replacement_binding.new_binding; + } + + if (it->first == replacement_binding.old_binding) { + auto binding_info_local = 
std::move(binding_info); + binding_map.erase(it); + binding_map.emplace(replacement_binding.new_binding, std::move(binding_info_local)); + } + } +} + +void CMHelper::AddCompressProjectionStats(statistics_map_t &statistics_map, const vector &bindings, + vector> &compress_exprs) { + for (idx_t col_idx = 0; col_idx < bindings.size(); col_idx++) { + const auto &binding = bindings[col_idx]; + auto &stats = compress_exprs[col_idx]->stats; + statistics_map.emplace(binding, std::move(stats)); + } +} + +Value CMHelper::GetIntegralRangeValue(ClientContext &context, const LogicalType &type, const BaseStatistics &stats) { + auto min = NumericStats::Min(stats); + auto max = NumericStats::Max(stats); + if (max < min) { + return Value::UHUGEINT(NumericLimits::Maximum()); + } + + vector> arguments; + arguments.emplace_back(make_uniq(max)); + arguments.emplace_back(make_uniq(min)); + + auto sub = SubtractFunction::GetFunction(type, type).Bind(context, std::move(arguments)); + + Value result; + if (ExpressionExecutor::TryEvaluateScalar(context, *sub, result)) { + return result; + } + // Couldn't evaluate: Return max uhugeint as range so GetIntegralCompress will return nullptr + return Value::UHUGEINT(NumericLimits::Maximum()); +} + +LogicalType CMHelper::GetIntegralOffsetType(const uint64_t range) { + if (range <= NumericLimits().Maximum()) { + return LogicalType::UTINYINT; + } + if (range <= NumericLimits().Maximum()) { + return LogicalType::USMALLINT; + } + if (range <= NumericLimits().Maximum()) { + return LogicalType::UINTEGER; + } + D_ASSERT(range <= NumericLimits().Maximum()); + return LogicalType::UBIGINT; +} + +bool CMHelper::GetIntegralOffsetCompressInfo(ClientContext &context, const LogicalType &type, + const BaseStatistics &stats, LogicalType &offset_type, Value &min, + Value &range_value) { + if (!stats.CanHaveNoNull()) { + // All NULL + offset_type = LogicalType::UTINYINT; + range_value = Value::UTINYINT(0); + min = Value(type); + return true; + } + if 
(!NumericStats::HasMinMax(stats)) { + return false; + } + + // Get range and cast to UBIGINT (might fail for HUGEINT, in which case we just return) + range_value = GetIntegralRangeValue(context, type, stats); + if (!range_value.DefaultTryCastAs(LogicalType::UBIGINT)) { + return false; + } + offset_type = GetIntegralOffsetType(UBigIntValue::Get(range_value)); + min = NumericStats::Min(stats); + return true; +} + +LogicalType CMHelper::GetSameWidthIntegralType(const LogicalType &type, const bool use_signed) { + switch (GetTypeIdSize(type.InternalType())) { + case 1: + return use_signed ? LogicalType::TINYINT : LogicalType::UTINYINT; + case 2: + return use_signed ? LogicalType::SMALLINT : LogicalType::USMALLINT; + case 4: + return use_signed ? LogicalType::INTEGER : LogicalType::UINTEGER; + case 8: + return use_signed ? LogicalType::BIGINT : LogicalType::UBIGINT; + default: + return LogicalType::INVALID; + } +} + +bool CMHelper::ValuePreservingCastFits(const Value &value, const LogicalType &source_type, + const LogicalType &target_type) { + Value cast_value; + Value roundtrip_value; + try { + if (!value.DefaultTryCastAs(target_type, cast_value, nullptr, true)) { + return false; + } + if (!cast_value.DefaultTryCastAs(source_type, roundtrip_value, nullptr, true)) { + return false; + } + } catch (ConversionException &) { + return false; + } + return value == roundtrip_value; +} + +LogicalType CMHelper::GetIntegralCastType(const LogicalType &source_type, const LogicalType &offset_type, + const BaseStatistics &stats) { + if (GetTypeIdSize(source_type.InternalType()) <= GetTypeIdSize(offset_type.InternalType())) { + return LogicalType::INVALID; + } + + const auto preferred_signed = source_type.IsSigned(); + auto preferred_type = GetSameWidthIntegralType(offset_type, preferred_signed); + if (!stats.CanHaveNoNull()) { + return preferred_type; + } + if (!NumericStats::HasMinMax(stats)) { + return LogicalType::INVALID; + } + + auto alternate_type = 
GetSameWidthIntegralType(offset_type, !preferred_signed); + const auto min = NumericStats::Min(stats); + const auto max = NumericStats::Max(stats); + if (preferred_type.IsValid() && ValuePreservingCastFits(min, source_type, preferred_type) && + ValuePreservingCastFits(max, source_type, preferred_type)) { + return preferred_type; + } + if (alternate_type.IsValid() && ValuePreservingCastFits(min, source_type, alternate_type) && + ValuePreservingCastFits(max, source_type, alternate_type)) { + return alternate_type; + } + return LogicalType::INVALID; +} + +unique_ptr CMHelper::CreateIntegralCastStats(const LogicalType &target_type, + const BaseStatistics &stats) { + auto compress_stats = BaseStatistics::CreateEmpty(target_type); + compress_stats.CopyBase(stats); + if (NumericStats::HasMinMax(stats)) { + Value cast_min; + Value cast_max; + const auto min_success = NumericStats::Min(stats).DefaultTryCastAs(target_type, cast_min, nullptr, true); + const auto max_success = NumericStats::Max(stats).DefaultTryCastAs(target_type, cast_max, nullptr, true); + if (!min_success || !max_success) { + throw InternalException("Casting failure in CMHelper::CreateIntegralCastStats"); + } + NumericStats::SetMin(compress_stats, cast_min); + NumericStats::SetMax(compress_stats, cast_max); + } + return compress_stats.ToUnique(); +} + +unique_ptr CMHelper::CreateIntegralCastCompress(ClientContext &context, + unique_ptr input, + const LogicalType &target_type, + const BaseStatistics &stats) { + auto compress_expr = BoundCastExpression::AddCastToType(context, std::move(input), target_type); + auto compress_stats = CreateIntegralCastStats(target_type, stats); + return make_uniq(std::move(compress_expr), std::move(compress_stats), + CompressedMaterializationType::CAST); +} + +unique_ptr CMHelper::CreateIntegralFunctionCompress(unique_ptr input, + const LogicalType &source_type, + const LogicalType &target_type, + const Value &min, const Value &range_value, + const BaseStatistics &stats) { + 
auto compress_function = CMIntegralCompressFun::GetFunction(source_type, target_type); + vector> arguments; + arguments.emplace_back(std::move(input)); + arguments.emplace_back(make_uniq(min)); + + BoundScalarFunction bound_function(compress_function); + bound_function.SetReturnType(target_type); + auto compress_expr = make_uniq(std::move(bound_function), std::move(arguments), nullptr); + + auto compress_stats = BaseStatistics::CreateEmpty(target_type); + compress_stats.CopyBase(stats); + NumericStats::SetMin(compress_stats, Value(0).DefaultCastAs(target_type)); + NumericStats::SetMax(compress_stats, range_value.DefaultCastAs(target_type)); + + return make_uniq(std::move(compress_expr), compress_stats.ToUnique(), + CompressedMaterializationType::FUNCTION); +} + +//===--------------------------------------------------------------------===// +// CompressedMaterialization +//===--------------------------------------------------------------------===// CompressedMaterialization::CompressedMaterialization(Optimizer &optimizer_p, LogicalOperator &root_p, statistics_map_t &statistics_map_p) : optimizer(optimizer_p), context(optimizer.context), root(&root_p), statistics_map(statistics_map_p) { @@ -51,11 +324,11 @@ CompressedMaterialization::CompressedMaterialization(Optimizer &optimizer_p, Log void CompressedMaterialization::GetReferencedBindings(const Expression &root_expr, column_binding_set_t &referenced_bindings) { ExpressionIterator::VisitExpression( - root_expr, [&](const BoundColumnRefExpression &col_ref) { referenced_bindings.insert(col_ref.binding); }); + root_expr, [&](const BoundColumnRefExpression &col_ref) { referenced_bindings.insert(col_ref.Binding()); }); } void CompressedMaterialization::UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, - bool needs_decompression) { + CompressedMaterializationType materialization_type) { auto &binding_map = info.binding_map; auto binding_it = binding_map.find(binding); if (binding_it == 
binding_map.end()) { @@ -63,10 +336,12 @@ void CompressedMaterialization::UpdateBindingInfo(CompressedMaterializationInfo } auto &binding_info = binding_it->second; - binding_info.needs_decompression = needs_decompression; - auto stats_it = statistics_map.find(binding); - if (stats_it != statistics_map.end()) { - binding_info.stats = statistics_map[binding]->ToUnique(); + binding_info.materialization_type = materialization_type; + if (!binding_info.stats) { + auto stats_it = statistics_map.find(binding); + if (stats_it != statistics_map.end() && stats_it->second) { + binding_info.stats = stats_it->second->ToUnique(); + } } } @@ -135,23 +410,23 @@ bool CompressedMaterialization::TryCompressChild(CompressedMaterializationInfo & const auto &child_type = child_info.types[child_i]; const auto &can_compress = child_info.can_compress[child_i]; auto compress_expr = GetCompressExpression(child_binding, child_type, can_compress); - bool compressed = false; - if (compress_expr) { // We compressed, mark the outgoing binding in need of decompression - compress_exprs.emplace_back(std::move(compress_expr)); - compressed = true; - } else { // We did not compress, just push a colref + if (!compress_expr) { + auto it = statistics_map.find(child_binding); + auto colref_stats = it != statistics_map.end() ? it->second->ToUnique() : nullptr; auto colref_expr = make_uniq(child_type, child_binding); - auto it = statistics_map.find(colref_expr->binding); - unique_ptr colref_stats = it != statistics_map.end() ? 
it->second->ToUnique() : nullptr; - compress_exprs.emplace_back(make_uniq(std::move(colref_expr), std::move(colref_stats))); + compress_expr = make_uniq(std::move(colref_expr), std::move(colref_stats), + CompressedMaterializationType::INVALID); } - UpdateBindingInfo(info, child_binding, compressed); - compressed_anything = compressed_anything || compressed; + const auto materialization_type = compress_expr->materialization_type; + compress_exprs.emplace_back(std::move(compress_expr)); + UpdateBindingInfo(info, child_binding, materialization_type); + compressed_anything = compressed_anything || CMHelper::TypeRequiresRestore(materialization_type); } if (!compressed_anything) { // If we compressed anything non-generically, we still need to decompress for (const auto &entry : info.binding_map) { - compressed_anything = compressed_anything || entry.second.needs_decompression; + compressed_anything = + compressed_anything || CMHelper::TypeRequiresRestore(entry.second.materialization_type); } } return compressed_anything; @@ -166,11 +441,7 @@ void CompressedMaterialization::CreateCompressProjection(unique_ptrexpression)); } - const auto table_index = optimizer.binder.GenerateTableIndex(); - auto compress_projection = make_uniq(table_index, std::move(projections)); - if (child_op->has_estimated_cardinality) { - compress_projection->SetEstimatedCardinality(child_op->estimated_cardinality); - } + auto compress_projection = CMHelper::CreateProjection(optimizer, *child_op, std::move(projections)); compress_projection->ResolveOperatorTypes(); compress_projection->children.emplace_back(std::move(child_op)); @@ -200,29 +471,24 @@ void CompressedMaterialization::CreateCompressProjection(unique_ptrsecond; - if (binding_info.binding == replacement_binding.old_binding) { - binding_info.binding = replacement_binding.new_binding; - } - - if (it->first == replacement_binding.old_binding) { - auto binding_info_local = std::move(binding_info); - binding_map.erase(it); - 
binding_map.emplace(replacement_binding.new_binding, std::move(binding_info_local)); - } - } + CMHelper::RemapBindingMap(info, replacement_bindings); // Add projection stats to statistics map - for (idx_t col_idx = 0; col_idx < child_info.bindings_after.size(); col_idx++) { - const auto &binding = child_info.bindings_after[col_idx]; - auto &stats = compress_exprs[col_idx]->stats; - statistics_map.emplace(binding, std::move(stats)); + CMHelper::AddCompressProjectionStats(statistics_map, child_info.bindings_after, compress_exprs); +} + +unique_ptr CompressedMaterialization::CreateRestoreExpression(unique_ptr input, + const CMBindingInfo &binding_info, + const BaseStatistics &stats) { + switch (binding_info.materialization_type) { + case CompressedMaterializationType::INVALID: + return input; + case CompressedMaterializationType::FUNCTION: + return GetDecompressExpression(std::move(input), binding_info.type, stats); + case CompressedMaterializationType::CAST: + return BoundCastExpression::AddCastToType(context, std::move(input), binding_info.type); + default: + throw InternalException("Invalid compressed materialization type"); } } @@ -246,8 +512,8 @@ void CompressedMaterialization::CreateDecompressProjection(unique_ptr(table_index, std::move(decompress_exprs)); - if (op->has_estimated_cardinality) { - decompress_projection->SetEstimatedCardinality(op->estimated_cardinality); - } + auto decompress_projection = CMHelper::CreateProjection(optimizer, *op, std::move(decompress_exprs)); decompress_projection->children.emplace_back(std::move(op)); op = std::move(decompress_projection); @@ -311,39 +573,21 @@ unique_ptr CompressedMaterialization::GetCompressExpression( unique_ptr CompressedMaterialization::GetCompressExpression(unique_ptr input, const BaseStatistics &stats) { const auto &type = input->GetReturnType(); + if (type.IsAggregateState()) { + return nullptr; + } if (type != stats.GetType()) { // LCOV_EXCL_START return nullptr; } // LCOV_EXCL_STOP if 
(type.IsIntegral()) { return GetIntegralCompress(std::move(input), stats); - } else if (type.id() == LogicalTypeId::VARCHAR) { + } + if (type.id() == LogicalTypeId::VARCHAR) { return GetStringCompress(std::move(input), stats); } return nullptr; } -static Value GetIntegralRangeValue(ClientContext &context, const LogicalType &type, const BaseStatistics &stats) { - auto min = NumericStats::Min(stats); - auto max = NumericStats::Max(stats); - if (max < min) { - return Value::UHUGEINT(NumericLimits::Maximum()); - } - - vector> arguments; - arguments.emplace_back(make_uniq(max)); - arguments.emplace_back(make_uniq(min)); - - auto sub = SubtractFunction::GetFunction(type, type).Bind(context, std::move(arguments)); - - Value result; - if (ExpressionExecutor::TryEvaluateScalar(context, *sub, result)) { - return result; - } else { - // Couldn't evaluate: Return max uhugeint as range so GetIntegralCompress will return nullptr - return Value::UHUGEINT(NumericLimits::Maximum()); - } -} - unique_ptr CompressedMaterialization::GetIntegralCompress(unique_ptr input, const BaseStatistics &stats) { const auto &type = input->GetReturnType(); @@ -354,33 +598,7 @@ unique_ptr CompressedMaterialization::GetIntegralCompress(un LogicalType cast_type; Value range_value; Value min; - if (!stats.CanHaveNoNull()) { - // All NULL - cast_type = LogicalType::UTINYINT; - range_value = Value::UTINYINT(0); - min = Value(input->GetReturnType()); - } else if (NumericStats::HasMinMax(stats)) { - // Get range and cast to UBIGINT (might fail for HUGEINT, in which case we just return) - range_value = GetIntegralRangeValue(context, type, stats); - if (!range_value.DefaultTryCastAs(LogicalType::UBIGINT)) { - return nullptr; - } - - // Get the smallest type that the range can fit into - const auto range = UBigIntValue::Get(range_value); - if (range <= NumericLimits().Maximum()) { - cast_type = LogicalType::UTINYINT; - } else if (range <= NumericLimits().Maximum()) { - cast_type = LogicalType::USMALLINT; - } 
else if (range <= NumericLimits().Maximum()) { - cast_type = LogicalType::UINTEGER; - } else { - D_ASSERT(range <= NumericLimits().Maximum()); - cast_type = LogicalType::UBIGINT; - } - - min = NumericStats::Min(stats); - } else { + if (!CMHelper::GetIntegralOffsetCompressInfo(context, type, stats, cast_type, min, range_value)) { // We don't have enough stats to do anything return nullptr; } @@ -391,86 +609,96 @@ unique_ptr CompressedMaterialization::GetIntegralCompress(un } D_ASSERT(GetTypeIdSize(cast_type.InternalType()) < GetTypeIdSize(type.InternalType())); - // Compressing will yield a benefit - auto compress_function = CMIntegralCompressFun::GetFunction(type, cast_type); - vector> arguments; - arguments.emplace_back(std::move(input)); - arguments.emplace_back(make_uniq(min)); + const auto value_preserving_cast_type = CMHelper::GetIntegralCastType(type, cast_type, stats); + if (value_preserving_cast_type.IsValid()) { + return CMHelper::CreateIntegralCastCompress(context, std::move(input), value_preserving_cast_type, stats); + } - BoundScalarFunction bound_function(compress_function); - bound_function.SetReturnType(cast_type); - auto compress_expr = make_uniq(std::move(bound_function), std::move(arguments), nullptr); + return CMHelper::CreateIntegralFunctionCompress(std::move(input), type, cast_type, min, range_value, stats); +} - auto compress_stats = BaseStatistics::CreateEmpty(cast_type); +unique_ptr CMHelper::CreateStringCompressStats(const BaseStatistics &stats, LogicalType &target_type, + const uint32_t max_string_length) { + auto compress_stats = BaseStatistics::CreateEmpty(target_type); compress_stats.CopyBase(stats); - NumericStats::SetMin(compress_stats, Value(0).DefaultCastAs(cast_type)); - NumericStats::SetMax(compress_stats, range_value.DefaultCastAs(cast_type)); + if (target_type.id() != LogicalTypeId::USMALLINT || !StringStats::HasMinMax(stats)) { + return compress_stats.ToUnique(); + } + + auto min_string = StringStats::Min(stats); + auto 
max_string = StringStats::Max(stats); - return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); + uint8_t min_numeric = 0; + if (max_string_length != 0 && !min_string.empty()) { + min_numeric = *reinterpret_cast(min_string.c_str()); + } + uint8_t max_numeric = 0; + if (max_string_length != 0 && !max_string.empty()) { + max_numeric = *reinterpret_cast(max_string.c_str()); + } + + Value min_val = Value::USMALLINT(min_numeric); + Value max_val = Value::USMALLINT(max_numeric + 1); + if (max_numeric < NumericLimits::Maximum()) { + target_type = LogicalType::UTINYINT; + compress_stats = BaseStatistics::CreateEmpty(target_type); + compress_stats.CopyBase(stats); + min_val = Value::UTINYINT(min_numeric); + max_val = Value::UTINYINT(max_numeric + 1); + } + + NumericStats::SetMin(compress_stats, min_val); + NumericStats::SetMax(compress_stats, max_val); + return compress_stats.ToUnique(); } -unique_ptr CompressedMaterialization::GetStringCompress(unique_ptr input, - const BaseStatistics &stats) { - LogicalType cast_type = LogicalType::INVALID; - uint32_t max_string_length; +bool CMHelper::GetStringCompressInfo(const BaseStatistics &stats, LogicalType &target_type, + uint32_t &max_string_length) { if (!stats.CanHaveNoNull()) { // All NULL - cast_type = LogicalType::UTINYINT; + target_type = LogicalType::UTINYINT; max_string_length = 0; - } else if (StringStats::HasMaxStringLength(stats)) { - max_string_length = StringStats::MaxStringLength(stats); - for (const auto &compressed_type : CMUtils::StringTypes()) { - if (max_string_length < GetTypeIdSize(compressed_type.InternalType())) { - cast_type = compressed_type; - break; - } - } - if (cast_type == LogicalType::INVALID) { - return nullptr; - } - } else { - // We don't have enough stats to do anything - return nullptr; + return true; + } + if (!StringStats::HasMaxStringLength(stats)) { + return false; } - auto compress_stats = BaseStatistics::CreateEmpty(cast_type); - compress_stats.CopyBase(stats); - if 
(cast_type.id() == LogicalTypeId::USMALLINT) { - auto min_string = StringStats::Min(stats); - auto max_string = StringStats::Max(stats); - - uint8_t min_numeric = 0; - if (max_string_length != 0 && !min_string.empty()) { - min_numeric = *reinterpret_cast(min_string.c_str()); - } - uint8_t max_numeric = 0; - if (max_string_length != 0 && !max_string.empty()) { - max_numeric = *reinterpret_cast(max_string.c_str()); - } - - Value min_val = Value::USMALLINT(min_numeric); - Value max_val = Value::USMALLINT(max_numeric + 1); - if (max_numeric < NumericLimits::Maximum()) { - cast_type = LogicalType::UTINYINT; - compress_stats = BaseStatistics::CreateEmpty(cast_type); - compress_stats.CopyBase(stats); - min_val = Value::UTINYINT(min_numeric); - max_val = Value::UTINYINT(max_numeric + 1); + max_string_length = StringStats::MaxStringLength(stats); + for (const auto &compressed_type : CMUtils::StringTypes()) { + if (max_string_length < GetTypeIdSize(compressed_type.InternalType())) { + target_type = compressed_type; + return true; } - - NumericStats::SetMin(compress_stats, min_val); - NumericStats::SetMax(compress_stats, max_val); } + return false; +} - auto compress_function = CMStringCompressFun::GetFunction(cast_type); +unique_ptr CMHelper::CreateStringFunctionCompress(unique_ptr input, + const LogicalType &target_type, + unique_ptr compress_stats) { + auto compress_function = CMStringCompressFun::GetFunction(target_type); vector> arguments; arguments.emplace_back(std::move(input)); BoundScalarFunction bound_function(compress_function); - bound_function.SetReturnType(cast_type); + bound_function.SetReturnType(target_type); auto compress_expr = make_uniq(std::move(bound_function), std::move(arguments), nullptr); - return make_uniq(std::move(compress_expr), compress_stats.ToUnique()); + return make_uniq(std::move(compress_expr), std::move(compress_stats), + CompressedMaterializationType::FUNCTION); +} + +unique_ptr CompressedMaterialization::GetStringCompress(unique_ptr 
input, + const BaseStatistics &stats) { + LogicalType cast_type = LogicalType::INVALID; + uint32_t max_string_length = 0; + if (!CMHelper::GetStringCompressInfo(stats, cast_type, max_string_length)) { + // We don't have enough stats to do anything + return nullptr; + } + auto compress_stats = CMHelper::CreateStringCompressStats(stats, cast_type, max_string_length); + return CMHelper::CreateStringFunctionCompress(std::move(input), cast_type, std::move(compress_stats)); } unique_ptr CompressedMaterialization::GetDecompressExpression(unique_ptr input, @@ -479,11 +707,11 @@ unique_ptr CompressedMaterialization::GetDecompressExpression(unique const auto &type = result_type; if (TypeIsIntegral(type.InternalType())) { return GetIntegralDecompress(std::move(input), result_type, stats); - } else if (type.id() == LogicalTypeId::VARCHAR) { + } + if (type.id() == LogicalTypeId::VARCHAR) { return GetStringDecompress(std::move(input), result_type, stats); - } else { - throw InternalException("Type other than integral/string marked for decompression!"); } + throw InternalException("Type other than integral/string marked for decompression!"); } unique_ptr CompressedMaterialization::GetIntegralDecompress(unique_ptr input, diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp index 83a4be728..4de2f4af8 100644 --- a/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp @@ -7,9 +7,6 @@ namespace duckdb { void CompressedMaterialization::CompressAggregate(unique_ptr &op) { auto &aggregate = op->Cast(); - if (aggregate.grouping_sets.size() > 1) { - return; // FIXME: we should be able to compress here but for some reason the NULL statistics ain't right - } auto &groups = aggregate.groups; column_binding_set_t group_binding_set; for (const auto &group : groups) { @@ -17,10 +14,10 
@@ void CompressedMaterialization::CompressAggregate(unique_ptr &o continue; } auto &colref = group->Cast(); - if (group_binding_set.find(colref.binding) != group_binding_set.end()) { + if (group_binding_set.find(colref.Binding()) != group_binding_set.end()) { return; // Duplicate group - don't compress } - group_binding_set.insert(colref.binding); + group_binding_set.insert(colref.Binding()); } auto &group_stats = aggregate.group_stats; @@ -35,14 +32,14 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o // But we can try to compress the expression directly column_binding_set_t referenced_bindings; vector group_bindings(groups.size(), ColumnBinding()); - vector needs_decompression(groups.size(), false); + vector materialization_types(groups.size(), CompressedMaterializationType::INVALID); vector> stored_group_stats; stored_group_stats.resize(groups.size()); for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { auto &group_expr = *groups[group_idx]; if (group_expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &colref = group_expr.Cast(); - group_bindings[group_idx] = colref.binding; + group_bindings[group_idx] = colref.Binding(); continue; // Will be compressed generically } @@ -57,7 +54,7 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o // Try to compress, if successful, replace the expression auto compress_expr = GetCompressExpression(group_expr.Copy(), *group_stats[group_idx]); if (compress_expr) { - needs_decompression[group_idx] = true; + materialization_types[group_idx] = compress_expr->materialization_type; stored_group_stats[group_idx] = std::move(group_stats[group_idx]); groups[group_idx] = std::move(compress_expr->expression); group_stats[group_idx] = std::move(compress_expr->stats); @@ -69,14 +66,14 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o const auto &expr = *aggregate.expressions[expr_idx]; D_ASSERT(expr.GetExpressionType() == ExpressionType::BOUND_AGGREGATE); 
const auto &aggr_expr = expr.Cast(); - for (const auto &child : aggr_expr.children) { + for (const auto &child : aggr_expr.GetChildren()) { GetReferencedBindings(*child, referenced_bindings); } - if (aggr_expr.filter) { - GetReferencedBindings(*aggr_expr.filter, referenced_bindings); + if (aggr_expr.GetFilter()) { + GetReferencedBindings(*aggr_expr.GetFilter(), referenced_bindings); } - if (aggr_expr.order_bys) { - for (const auto &order : aggr_expr.order_bys->orders) { + if (aggr_expr.GetOrderBys()) { + for (const auto &order : aggr_expr.GetOrderBys()->orders) { const auto &order_expr = *order.expression; if (order_expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { GetReferencedBindings(order_expr, referenced_bindings); @@ -94,8 +91,11 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { // Aggregate changes bindings as it has a table idx CMBindingInfo binding_info(bindings_out[group_idx], types[group_idx]); - binding_info.needs_decompression = needs_decompression[group_idx]; - if (needs_decompression[group_idx]) { + if (!aggregate.grouping_sets.empty() && group_stats[group_idx]) { + binding_info.stats = group_stats[group_idx]->ToUnique(); + } + binding_info.materialization_type = materialization_types[group_idx]; + if (materialization_types[group_idx] != CompressedMaterializationType::INVALID) { // Compressed non-generically auto entry = info.binding_map.emplace(bindings_out[group_idx], std::move(binding_info)); entry.first->second.stats = std::move(stored_group_stats[group_idx]); @@ -133,7 +133,7 @@ void CompressedMaterialization::UpdateAggregateStats(unique_ptr if (colref.GetReturnType() == group_stats[group_idx]->GetType()) { continue; } - auto it = statistics_map.find(colref.binding); + auto it = statistics_map.find(colref.Binding()); if (it != statistics_map.end() && it->second) { group_stats[group_idx] = it->second->ToUnique(); } diff --git 
a/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp index b60f8bf3d..9570bd716 100644 --- a/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp @@ -75,13 +75,13 @@ void CompressedMaterialization::CompressComparisonJoin(unique_ptr(); auto &rhs_colref = condition.GetRHS().Cast(); - bool lhs_referenced = referenced_bindings.count(lhs_colref.binding) > 0; - bool rhs_referenced = referenced_bindings.count(rhs_colref.binding) > 0; + bool lhs_referenced = referenced_bindings.count(lhs_colref.Binding()) > 0; + bool rhs_referenced = referenced_bindings.count(rhs_colref.Binding()) > 0; if (!lhs_referenced && !rhs_referenced) { // Both are bound column refs, see if both can be compressed generically to the same type - auto lhs_it = statistics_map.find(lhs_colref.binding); - auto rhs_it = statistics_map.find(rhs_colref.binding); + auto lhs_it = statistics_map.find(lhs_colref.Binding()); + auto rhs_it = statistics_map.find(rhs_colref.Binding()); if (lhs_it != statistics_map.end() && rhs_it != statistics_map.end() && lhs_it->second && rhs_it->second) { // For joins we need to compress both using the same statistics, otherwise comparisons don't @@ -96,7 +96,7 @@ void CompressedMaterialization::CompressComparisonJoin(unique_ptrsecond->Merge(merged_stats); rhs_it->second->Merge(merged_stats); - probe_compress_bindings.insert(lhs_colref.binding); + probe_compress_bindings.insert(lhs_colref.Binding()); continue; } } @@ -155,8 +155,8 @@ void CompressedMaterialization::UpdateComparisonJoinStats(unique_ptr(); auto &rhs_colref = condition.GetRHS().Cast(); - auto lhs_it = statistics_map.find(lhs_colref.binding); - auto rhs_it = statistics_map.find(rhs_colref.binding); + auto lhs_it = statistics_map.find(lhs_colref.Binding()); + auto rhs_it = 
statistics_map.find(rhs_colref.Binding()); if (lhs_it != statistics_map.end() && lhs_it->second) { condition.SetLeftStats(lhs_it->second->ToUnique()); } diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp index 2b098fda6..38e68372b 100644 --- a/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_order.cpp @@ -55,7 +55,7 @@ void CompressedMaterialization::UpdateOrderStats(unique_ptr &op continue; } auto &colref = order_expression.Cast(); - auto it = statistics_map.find(colref.binding); + auto it = statistics_map.find(colref.Binding()); if (it != statistics_map.end() && it->second) { bound_order.stats = it->second->ToUnique(); } diff --git a/src/duckdb/src/optimizer/cse_optimizer.cpp b/src/duckdb/src/optimizer/cse_optimizer.cpp index 122e54e34..53664815b 100644 --- a/src/duckdb/src/optimizer/cse_optimizer.cpp +++ b/src/duckdb/src/optimizer/cse_optimizer.cpp @@ -97,17 +97,17 @@ void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { auto &bound_column_ref = expr.Cast(); // bound column ref, check if this one has already been recorded in the expression list - auto column_entry = state.column_map.find(bound_column_ref.binding); + auto column_entry = state.column_map.find(bound_column_ref.Binding()); if (column_entry == state.column_map.end()) { // not there yet: push the expression auto new_col_ref = make_uniq( - bound_column_ref.GetAlias(), bound_column_ref.GetReturnType(), bound_column_ref.binding); + bound_column_ref.GetAlias(), bound_column_ref.GetReturnType(), bound_column_ref.Binding()); auto new_column_index = ColumnBinding::PushExpression(state.expressions, std::move(new_col_ref)); - state.column_map[bound_column_ref.binding] = new_column_index; - bound_column_ref.binding = 
ColumnBinding(state.projection_index, new_column_index); + state.column_map[bound_column_ref.Binding()] = new_column_index; + bound_column_ref.BindingMutable() = ColumnBinding(state.projection_index, new_column_index); } else { // else: just update the column binding! - bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); + bound_column_ref.BindingMutable() = ColumnBinding(state.projection_index, column_entry->second); } return; } diff --git a/src/duckdb/src/optimizer/cte_inlining.cpp b/src/duckdb/src/optimizer/cte_inlining.cpp index 089430a7d..b3dba3f65 100644 --- a/src/duckdb/src/optimizer/cte_inlining.cpp +++ b/src/duckdb/src/optimizer/cte_inlining.cpp @@ -87,6 +87,51 @@ bool CTEInlining::EndsInAggregateOrDistinct(const LogicalOperator &op) { return false; } +static bool EndsInDummyScan(const LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DUMMY_SCAN || op.type == LogicalOperatorType::LOGICAL_EMPTY_RESULT || + op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + return true; + } + if (op.children.size() != 1) { + return false; + } + for (auto &child : op.children) { + if (EndsInDummyScan(*child)) { + return true; + } + } + return false; +} + +static bool ContainsDelimGet(const LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { + return true; + } + for (auto &child : op.children) { + if (ContainsDelimGet(*child)) { + return true; + } + } + return false; +} + +static bool HasCTEReferenceBelowDelimJoin(const LogicalOperator &op, TableIndex cte_index, + bool below_delim_join = false) { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + if (cteref.cte_index == cte_index) { + return below_delim_join; + } + } + auto child_below_delim_join = below_delim_join || op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN; + for (auto &child : op.children) { + if (HasCTEReferenceBelowDelimJoin(*child, cte_index, child_below_delim_join)) { + return true; + } + 
} + return false; +} + void CTEInlining::TryInlining(unique_ptr &op) { if (op->type == LogicalOperatorType::LOGICAL_PREPARE) { // we are in a prepare statement, if we have to copy an operator during inlining, @@ -120,6 +165,12 @@ void CTEInlining::TryInlining(unique_ptr &op) { // With ref_count>1 and requires_copy, the DML would execute once per copy. return; } + if (ContainsDelimGet(*cte.children[0]) && HasCTEReferenceBelowDelimJoin(*op->children[1], cte.table_index)) { + // Inlining a CTE that already contains a DELIM_GET stays safe while all matching CTE scans remain outside + // DELIM_JOIN subtrees, but once a scan is nested below another DELIM_JOIN the inlined DELIM_GETs can attach + // to the wrong duplicate-elimination source. + return; + } if (cte.materialize == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { // This CTE is always materialized, we cannot inline it return; @@ -157,10 +208,14 @@ void CTEInlining::TryInlining(unique_ptr &op) { return; } + bool is_cheap_to_inline = op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT || + op->children[0]->type == LogicalOperatorType::LOGICAL_CTE_REF || + EndsInDummyScan(*op->children[0]); + // Check how many base table references the CTE has auto base_table_references = CountBaseTableReferences(*op->children[0]); - if (base_table_references > 2 && base_table_references * ref_count > 10) { + if (!is_cheap_to_inline && base_table_references > 2 && base_table_references * ref_count > 10) { return; } @@ -169,7 +224,7 @@ void CTEInlining::TryInlining(unique_ptr &op) { // even if only a part of the CTE result is needed. // Therefore, we check if the CTE Scans are below the LIMIT or TOP_N operator // and if so, we try to inline the CTE definition. 
- if (ContainsLimit(*op->children[1]) || op->children[0]->type == LogicalOperatorType::LOGICAL_EMPTY_RESULT) { + if (is_cheap_to_inline || ContainsLimit(*op->children[1])) { // this CTE is referenced multiple times and has a limit, we want to inline it bool success = Inline(op->children[1], *op, true); if (success) { @@ -244,7 +299,7 @@ void PreventInlining::VisitExpression(unique_ptr *expression) { if (expr->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { auto &bound_function = expr->Cast(); // if we encounter the ErrorFun function, we still want to inline - if (bound_function.function.GetName() == "error") { + if (bound_function.Function().GetName() == "error") { return; } diff --git a/src/duckdb/src/optimizer/deliminator.cpp b/src/duckdb/src/optimizer/deliminator.cpp index 1b2cfb535..c4a5d43bc 100644 --- a/src/duckdb/src/optimizer/deliminator.cpp +++ b/src/duckdb/src/optimizer/deliminator.cpp @@ -119,11 +119,7 @@ bool Deliminator::HasSelection(const LogicalOperator &op) { case LogicalOperatorType::LOGICAL_GET: { auto &get = op.Cast(); for (const auto &entry : get.table_filters) { - auto &filter = entry.Filter(); - if (filter.filter_type != TableFilterType::EXPRESSION_FILTER) { - return true; - } - auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "Deliminator::HasSelection"); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(entry.Filter(), "Deliminator::HasSelection"); auto &expr = *expr_filter.expr; if (expr.GetExpressionClass() != ExpressionClass::BOUND_OPERATOR || expr.GetExpressionType() != ExpressionType::OPERATOR_IS_NOT_NULL) { @@ -226,7 +222,7 @@ bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, cons } auto &delim_colref = delim_side.Cast(); auto &other_colref = other_side.Cast(); - replacement_bindings.emplace_back(delim_colref.binding, other_colref.binding); + replacement_bindings.emplace_back(delim_colref.Binding(), other_colref.Binding()); // Only add IS NOT NULL filter for regular 
equality/inequality comparisons // Do NOT add for DISTINCT FROM variants, as they handle NULL correctly @@ -234,7 +230,7 @@ bool Deliminator::RemoveJoinWithDelimGet(LogicalComparisonJoin &delim_join, cons cond.GetComparisonType() != ExpressionType::COMPARE_DISTINCT_FROM) { auto is_not_null_expr = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); - is_not_null_expr->children.push_back(other_side.Copy()); + is_not_null_expr->GetChildrenMutable().push_back(other_side.Copy()); filter_expressions.push_back(std::move(is_not_null_expr)); } } @@ -280,7 +276,7 @@ bool FindAndReplaceBindings(vector &traced_bindings, const vector } auto &colref = expressions[current_idx]->Cast(); - binding = colref.binding; + binding = colref.Binding(); } return true; } @@ -315,7 +311,7 @@ bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_ return false; } auto &colref = cond.GetRHS().Cast(); - traced_bindings.emplace_back(colref.binding); + traced_bindings.emplace_back(colref.Binding()); } // Now we trace down the bindings to the join (for now, we only trace it through a few operators) @@ -352,7 +348,7 @@ bool Deliminator::RemoveInequalityJoinWithDelimGet(LogicalComparisonJoin &delim_ for (auto &join_condition : join_conditions) { auto &delim_side = delim_idx == 0 ? 
join_condition.GetLHS() : join_condition.GetRHS(); auto &colref = delim_side.Cast(); - if (colref.binding == traced_binding) { + if (colref.Binding() == traced_binding) { if (!join_condition.IsComparison()) { continue; } @@ -417,7 +413,7 @@ void Deliminator::TrySwitchSingleToLeft(LogicalComparisonJoin &delim_join) { return; } auto &colref = cond.GetRHS().Cast(); - join_bindings.emplace_back(colref.binding); + join_bindings.emplace_back(colref.Binding()); } // Now try to find an aggr in the RHS such that the join_column_bindings is a superset of the groups diff --git a/src/duckdb/src/optimizer/expression_heuristics.cpp b/src/duckdb/src/optimizer/expression_heuristics.cpp index ac91f0fec..9c6e196c3 100644 --- a/src/duckdb/src/optimizer/expression_heuristics.cpp +++ b/src/duckdb/src/optimizer/expression_heuristics.cpp @@ -4,9 +4,8 @@ #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/list.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { @@ -30,7 +29,7 @@ void ExpressionHeuristics::VisitOperator(LogicalOperator &op) { unique_ptr ExpressionHeuristics::VisitReplace(BoundConjunctionExpression &expr, unique_ptr *expr_ptr) { - ReorderExpressions(expr.children); + ReorderExpressions(expr.GetChildrenMutable()); return nullptr; } @@ -79,11 +78,11 @@ idx_t ExpressionHeuristics::BetweenExpressionCost(const BoundFunctionExpression idx_t ExpressionHeuristics::ExpressionCost(const BoundCaseExpression &expr) { // CASE WHEN check THEN result_if_true ELSE result_if_false END idx_t case_cost = 0; - for (auto &case_check : expr.case_checks) { + for (auto &case_check : expr.CaseChecks()) { case_cost += Cost(*case_check.then_expr); case_cost += Cost(*case_check.when_expr); } - case_cost += Cost(*expr.else_expr); + case_cost += 
Cost(expr.Else()); return case_cost; } @@ -101,7 +100,7 @@ idx_t ExpressionHeuristics::ExpressionCost(const BoundCastExpression &expr) { cast_cost = 5; } } - return Cost(*expr.child) + cast_cost; + return Cost(expr.Child()) + cast_cost; } idx_t ExpressionHeuristics::ComparisonExpressionCost(const BoundFunctionExpression &expr) { @@ -115,7 +114,7 @@ idx_t ExpressionHeuristics::ComparisonExpressionCost(const BoundFunctionExpressi idx_t ExpressionHeuristics::ExpressionCost(const BoundConjunctionExpression &expr) { // CONJUNCTION_AND, CONJUNCTION_OR idx_t cost = 5; - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { cost += Cost(*child); } return cost; @@ -128,19 +127,18 @@ idx_t ExpressionHeuristics::ExpressionCost(const BoundFunctionExpression &expr) if (BoundComparisonExpression::IsComparison(expr)) { return ComparisonExpressionCost(expr); } - unordered_map function_costs = { - {"+", 5}, {"-", 5}, {"&", 5}, {"#", 5}, - {">>", 5}, {"<<", 5}, {"abs", 5}, {"*", 10}, - {"%", 10}, {"/", 15}, {"date_part", 20}, {"year", 20}, - {"round", 100}, {"~~", 200}, {"!~~", 200}, {"regexp_matches", 200}, - {"||", 200}}; + identifier_map_t function_costs = {{"+", 5}, {"-", 5}, {"&", 5}, {"#", 5}, + {">>", 5}, {"<<", 5}, {"abs", 5}, {"*", 10}, + {"%", 10}, {"/", 15}, {"date_part", 20}, {"year", 20}, + {"round", 100}, {"~~", 200}, {"!~~", 200}, {"regexp_matches", 200}, + {"||", 200}}; idx_t cost_children = 0; - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { cost_children += Cost(*child); } - auto cost_function = function_costs.find(expr.function.GetName()); + auto cost_function = function_costs.find(expr.Function().GetName()); if (cost_function != function_costs.end()) { return cost_children + cost_function->second; } else { @@ -150,7 +148,7 @@ idx_t ExpressionHeuristics::ExpressionCost(const BoundFunctionExpression &expr) idx_t ExpressionHeuristics::ExpressionCost(const BoundOperatorExpression &expr, ExpressionType 
expr_type) { idx_t sum = 0; - for (auto &child : expr.children) { + for (auto &child : expr.GetChildren()) { sum += Cost(*child); } @@ -159,7 +157,7 @@ idx_t ExpressionHeuristics::ExpressionCost(const BoundOperatorExpression &expr, return sum + 5; } else if (expr_type == ExpressionType::COMPARE_IN || expr_type == ExpressionType::COMPARE_NOT_IN) { // COMPARE_IN, COMPARE_NOT_IN - return sum + (expr.children.size() - 1) * 100; + return sum + (expr.GetChildren().size() - 1) * 100; } else if (expr_type == ExpressionType::OPERATOR_NOT) { // OPERATOR_NOT return sum + 10; // TODO: evaluate via measured runtimes @@ -229,33 +227,13 @@ idx_t ExpressionHeuristics::Cost(const Expression &expr) { } idx_t ExpressionHeuristics::Cost(const TableFilter &filter) { - if (filter.filter_type == TableFilterType::EXPRESSION_FILTER) { - auto &expr_filter = filter.Cast(); - return Cost(*expr_filter.expr); - } - switch (filter.filter_type) { - case TableFilterType::DYNAMIC_FILTER: - case TableFilterType::OPTIONAL_FILTER: + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ExpressionHeuristics::Cost"); + auto &expr = *expr_filter.expr; + if (ExpressionFilter::ContainsInternalFunction(expr, DynamicFilterScalarFun::NAME) || + ExpressionFilter::IsOptionalExpression(expr)) { return 0; - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_and = filter.Cast(); - idx_t cost = 5; - for (auto &child_filter : conjunction_and.child_filters) { - cost += Cost(*child_filter); - } - return cost; - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - idx_t cost = 5; - for (auto &child_filter : conjunction_and.child_filters) { - cost += Cost(*child_filter); - } - return cost; - } - default: - return 1000; } + return Cost(expr); } vector ExpressionHeuristics::GetInitialOrder(const TableFilterSet &table_filters) { diff --git a/src/duckdb/src/optimizer/expression_rewriter.cpp b/src/duckdb/src/optimizer/expression_rewriter.cpp index 8fa566c82..09a89cbb4 
100644 --- a/src/duckdb/src/optimizer/expression_rewriter.cpp +++ b/src/duckdb/src/optimizer/expression_rewriter.cpp @@ -10,8 +10,14 @@ namespace duckdb { -unique_ptr ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector> &rules, - unique_ptr expr, bool &changes_made, bool is_root) { +struct RewriteFrame { + reference> expr; + bool is_root; + bool children_visited; +}; + +static unique_ptr ApplyRule(LogicalOperator &op, const vector> &rules, + unique_ptr expr, bool &changes_made, bool is_root) { for (auto &rule : rules) { vector> bindings; if (rule.get().root->Match(*expr, bindings)) { @@ -22,25 +28,54 @@ unique_ptr ExpressionRewriter::ApplyRules(LogicalOperator &op, const if (result) { changes_made = true; // the base node changed: the rule applied changes - // rerun on the new node if (!alias.empty()) { result->SetAlias(std::move(alias)); } - return ExpressionRewriter::ApplyRules(op, rules, std::move(result), changes_made); + return result; } else if (rule_made_change) { changes_made = true; - // the base node didn't change, but changes were made, rerun return expr; } // else nothing changed, continue to the next rule continue; } } - // no changes could be made to this node - // recursively run on the children of this node - ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - child = ExpressionRewriter::ApplyRules(op, rules, std::move(child), changes_made); - }); + return expr; +} + +static void CollectChildren(Expression &expr, vector>> &children) { + children.clear(); + ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &child) { children.push_back(child); }); +} + +unique_ptr ExpressionRewriter::ApplyRules(LogicalOperator &op, const vector> &rules, + unique_ptr expr, bool &changes_made, bool is_root) { + vector stack; + vector>> children; + stack.push_back({expr, is_root, false}); + + while (!stack.empty()) { + auto &frame = stack.back(); + auto ¤t_expr = frame.expr.get(); + D_ASSERT(current_expr); + if 
(!frame.children_visited) { + frame.children_visited = true; + // Rewrite the subtree bottom-up so parent rules see already-simplified children. + CollectChildren(*current_expr, children); + for (idx_t i = children.size(); i > 0; i--) { + stack.push_back({children[i - 1], false, false}); + } + continue; + } + bool node_made_change = false; + current_expr = ApplyRule(op, rules, std::move(current_expr), node_made_change, frame.is_root); + if (node_made_change) { + changes_made = true; + frame.children_visited = false; + continue; + } + stack.pop_back(); + } return expr; } @@ -90,11 +125,8 @@ void ExpressionRewriter::VisitOperator(LogicalOperator &op) { } void ExpressionRewriter::VisitExpression(unique_ptr *expression) { - bool changes_made; - do { - changes_made = false; - *expression = ExpressionRewriter::ApplyRules(*op, to_apply_rules, std::move(*expression), changes_made, true); - } while (changes_made); + bool changes_made = false; + *expression = ExpressionRewriter::ApplyRules(*op, to_apply_rules, std::move(*expression), changes_made, true); } ClientContext &Rule::GetContext() const { diff --git a/src/duckdb/src/optimizer/filter_combiner.cpp b/src/duckdb/src/optimizer/filter_combiner.cpp index 60bfba2f1..4ea09adb6 100644 --- a/src/duckdb/src/optimizer/filter_combiner.cpp +++ b/src/duckdb/src/optimizer/filter_combiner.cpp @@ -15,8 +15,7 @@ #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/subtract.hpp" @@ -25,6 +24,7 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include 
"utf8proc_wrapper.hpp" +#include "duckdb/optimizer/in_clause_rewriter.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" namespace duckdb { @@ -111,6 +111,42 @@ FilterResult FilterCombiner::AddFilter(unique_ptr expr) { return result; } +void FilterCombiner::GenerateEquivalentFilters(const Expression &filter, + const std::function filter)> &callback) { + if (filter.IsVolatile()) { + return; + } + // collect every column reference that belongs to an equivalence set + vector> candidate_columns; + ExpressionIterator::VisitExpression(filter, [&](const BoundColumnRefExpression &col) { + auto entry = equivalence_set_map.find(col); + if (entry == equivalence_set_map.end()) { + return; + } + for (auto &existing : candidate_columns) { + if (existing.get().Equals(col)) { + return; + } + } + candidate_columns.push_back(col); + }); + // for each such column, generate an equivalent filter with the columns swapped + for (auto &col_ref : candidate_columns) { + auto &col = col_ref.get(); + auto set_id = equivalence_set_map.find(col)->second; + for (auto &item : equivalence_map[set_id]) { + auto copy = filter.Copy(); + ExpressionIterator::VisitExpressionMutable( + copy, [&](BoundColumnRefExpression &cref, unique_ptr &child) { + if (cref.Equals(col)) { + child = item.get().Copy(); + } + }); + callback(std::move(copy)); + } + } +} + void FilterCombiner::GenerateFilters(const std::function filter)> &callback) { // first loop over the remaining filters for (auto &filter : remaining_filters) { @@ -149,6 +185,10 @@ void FilterCombiner::GenerateFilters(const std::function(info.constant); auto comparison = BoundComparisonExpression::Create(info.comparison_type, entries[i].get().Copy(), std::move(constant)); + // column refs are already covered above + if (entries[i].get().GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { + GenerateEquivalentFilters(*comparison, callback); + } callback(std::move(comparison)); } } @@ -197,13 +237,13 @@ static bool 
TryGetProjectionIndex(const Expression &expr, ProjectionIndex &resul switch (expr.GetExpressionType()) { case ExpressionType::BOUND_COLUMN_REF: { auto &ref = expr.Cast(); - result = ref.binding.column_index; + result = ref.Binding().column_index; return true; } case ExpressionType::BOUND_FUNCTION: { auto &func = expr.Cast(); - if (func.function.GetName() == "struct_extract" || func.function.GetName() == "struct_extract_at") { - auto &child_expr = func.children[0]; + if (func.Function().GetName() == "struct_extract" || func.Function().GetName() == "struct_extract_at") { + auto &child_expr = func.GetChildren()[0]; return TryGetProjectionIndex(*child_expr, result); } return false; @@ -228,6 +268,11 @@ static unique_ptr CreateComparisonExpression(const Expression &expr, return BoundComparisonExpression::Create(comparison_type, std::move(left), std::move(right)); } +static unique_ptr CreateOptionalExpressionFilter(unique_ptr child_expr, + const LogicalType &col_type) { + return make_uniq(CreateOptionalFilterExpression(std::move(child_expr), col_type)); +} + bool FilterCombiner::ContainsNull(vector &in_list) { for (idx_t i = 0; i < in_list.size(); i++) { if (in_list[i].IsNull()) { @@ -280,35 +325,6 @@ static bool SupportedFilterComparison(ExpressionType expression_type) { } } -bool FilterCombiner::FindNextLegalUTF8(string &prefix_string) { - // find the start of the last codepoint - idx_t last_codepoint_start; - for (last_codepoint_start = prefix_string.size(); last_codepoint_start > 0; last_codepoint_start--) { - if (IsCharacter(prefix_string[last_codepoint_start - 1])) { - break; - } - } - if (last_codepoint_start == 0) { - throw InvalidInputException("Invalid UTF8 found in string \"%s\"", prefix_string); - } - last_codepoint_start--; - int codepoint_size; - auto codepoint = Utf8Proc::UTF8ToCodepoint(prefix_string.c_str() + last_codepoint_start, codepoint_size) + 1; - if (codepoint >= 0xD800 && codepoint <= 0xDFFF) { - // next codepoint falls within surrogate range 
increment to next valid character - codepoint = 0xE000; - } - char next_codepoint_text[4]; - int next_codepoint_size; - if (!Utf8Proc::CodepointToUtf8(codepoint, next_codepoint_size, next_codepoint_text)) { - // invalid codepoint - return false; - } - auto s = static_cast(next_codepoint_size); - prefix_string = prefix_string.substr(0, last_codepoint_start) + string(next_codepoint_text, s); - return true; -} - static bool TypeSupportsConstantFilter(const LogicalType &type) { if (TypeIsNumeric(type.InternalType())) { return true; @@ -405,29 +421,29 @@ FilterPushdownResult FilterCombiner::TryPushdownPrefixFilter(TableFilterSet &tab return FilterPushdownResult::NO_PUSHDOWN; } auto &func = expr.Cast(); - if (func.function.GetName() != "prefix") { + if (func.Function().GetName() != "prefix") { return FilterPushdownResult::NO_PUSHDOWN; } - if (func.children[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || - func.children[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + if (func.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || + func.GetChildren()[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { // we need prefix(col, 'literal') in order to push this down return FilterPushdownResult::NO_PUSHDOWN; } - auto &column_ref = func.children[0]->Cast(); - auto &constant_value_expr = func.children[1]->Cast(); - auto prefix_string = StringValue::Get(constant_value_expr.value); + auto &column_ref = func.GetChildren()[0]->Cast(); + auto &constant_value_expr = func.GetChildren()[1]->Cast(); + auto prefix_string = StringValue::Get(constant_value_expr.GetValue()); if (prefix_string.empty()) { // empty prefix - skip return FilterPushdownResult::NO_PUSHDOWN; } - auto filter_idx = column_ref.binding.column_index; + auto filter_idx = column_ref.Binding().column_index; //! 
Replace prefix with a set of comparisons - auto lower_bound = CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, + auto lower_bound = CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix_string)); table_filters.PushFilter(filter_idx, make_uniq(std::move(lower_bound))); - if (FilterCombiner::FindNextLegalUTF8(prefix_string)) { + if (Utf8Proc::FindNextLegalUTF8(prefix_string)) { auto upper_bound = - CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_LESSTHAN, Value(prefix_string)); + CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_LESSTHAN, Value(prefix_string)); table_filters.PushFilter(filter_idx, make_uniq(std::move(upper_bound))); return FilterPushdownResult::PUSHED_DOWN_FULLY; } @@ -441,29 +457,29 @@ FilterPushdownResult FilterCombiner::TryPushdownLikeFilter(TableFilterSet &table return FilterPushdownResult::NO_PUSHDOWN; } auto &func = expr.Cast(); - if (func.function.GetName() != "~~") { + if (func.Function().GetName() != "~~") { return FilterPushdownResult::NO_PUSHDOWN; } - if (func.children[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || - func.children[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + if (func.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || + func.GetChildren()[1]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { // we need col LIKE 'literal' in order to generate extra filters return FilterPushdownResult::NO_PUSHDOWN; } //! This is a like function. - auto &column_ref = func.children[0]->Cast(); - auto &constant_value_expr = func.children[1]->Cast(); - auto proj_index = column_ref.binding.column_index; + auto &column_ref = func.GetChildren()[0]->Cast(); + auto &constant_value_expr = func.GetChildren()[1]->Cast(); + auto proj_index = column_ref.Binding().column_index; // constant value expr can sometimes be null. 
if so, push is not null filter, which will // make the filter unsatisfiable and return no results. - if (constant_value_expr.value.IsNull()) { - auto is_not_null = ExpressionFilter::CreateNullCheckExpression(CreateFilterTargetExpression(*func.children[0]), - ExpressionType::OPERATOR_IS_NOT_NULL); + if (constant_value_expr.GetValue().IsNull()) { + auto is_not_null = ExpressionFilter::CreateNullCheckExpression( + CreateFilterTargetExpression(*func.GetChildren()[0]), ExpressionType::OPERATOR_IS_NOT_NULL); table_filters.PushFilter(proj_index, make_uniq(std::move(is_not_null))); return FilterPushdownResult::PUSHED_DOWN_FULLY; } - auto &like_string = StringValue::Get(constant_value_expr.value); + auto &like_string = StringValue::Get(constant_value_expr.GetValue()); if (like_string[0] == '%' || like_string[0] == '_') { //! If the like starts with a special character we have no fixed prefix so nothing to pushdown return FilterPushdownResult::NO_PUSHDOWN; @@ -479,7 +495,8 @@ FilterPushdownResult FilterCombiner::TryPushdownLikeFilter(TableFilterSet &table } if (equality) { //! If the LIKE has no special characters we can turn it into an equality and push that down - auto equal_filter = CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_EQUAL, Value(prefix)); + auto equal_filter = + CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_EQUAL, Value(prefix)); table_filters.PushFilter(proj_index, make_uniq(std::move(equal_filter))); return FilterPushdownResult::PUSHED_DOWN_FULLY; } @@ -487,9 +504,10 @@ FilterPushdownResult FilterCombiner::TryPushdownLikeFilter(TableFilterSet &table //! 
We have a prefix - we can push down the prefix using a bound (x >= PREFIX AND x <= prefix + 1) // Note that we still need to execute the LIKE filter auto lower_bound = - CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix)); + CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, Value(prefix)); prefix[prefix.size() - 1]++; - auto upper_bound = CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_LESSTHAN, Value(prefix)); + auto upper_bound = + CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_LESSTHAN, Value(prefix)); table_filters.PushFilter(proj_index, make_uniq(std::move(lower_bound))); table_filters.PushFilter(proj_index, make_uniq(std::move(upper_bound))); return FilterPushdownResult::PUSHED_DOWN_PARTIALLY; @@ -501,23 +519,23 @@ FilterPushdownResult FilterCombiner::TryPushdownInFilter(TableFilterSet &table_f return FilterPushdownResult::NO_PUSHDOWN; } auto &func = expr.Cast(); - D_ASSERT(func.children.size() > 1); - if (func.children[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { + D_ASSERT(func.GetChildren().size() > 1); + if (func.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { // we need col IN (...) to be able to push this down return FilterPushdownResult::NO_PUSHDOWN; } - auto &column_ref = func.children[0]->Cast(); - auto proj_index = column_ref.binding.column_index; + auto &column_ref = func.GetChildren()[0]->Cast(); + auto proj_index = column_ref.Binding().column_index; //! 
check if all children are const expr bool children_constant = true; - for (size_t i {1}; i < func.children.size(); i++) { - if (func.children[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { + for (size_t i {1}; i < func.GetChildren().size(); i++) { + if (func.GetChildren()[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { children_constant = false; break; } - auto &const_value_expr = func.children[i]->Cast(); - if (const_value_expr.value.IsNull()) { + auto &const_value_expr = func.GetChildren()[i]->Cast(); + if (const_value_expr.GetValue().IsNull()) { // cannot simplify NULL values children_constant = false; break; @@ -527,13 +545,13 @@ FilterPushdownResult FilterCombiner::TryPushdownInFilter(TableFilterSet &table_f // all children must be constant return FilterPushdownResult::NO_PUSHDOWN; } - auto &fst_const_value_expr = func.children[1]->Cast(); - auto &type = fst_const_value_expr.value.type(); + auto &fst_const_value_expr = func.GetChildren()[1]->Cast(); + auto &type = fst_const_value_expr.GetValue().type(); - if (func.children.size() == 2 && TypeSupportsConstantFilter(type)) { + if (func.GetChildren().size() == 2 && TypeSupportsConstantFilter(type)) { // col IN (literal) is equivalent to an equality comparison - push that down - auto bound_eq_comparison = - CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_EQUAL, fst_const_value_expr.value); + auto bound_eq_comparison = CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_EQUAL, + fst_const_value_expr.GetValue()); table_filters.PushFilter(proj_index, make_uniq(std::move(bound_eq_comparison))); return FilterPushdownResult::PUSHED_DOWN_FULLY; } @@ -541,17 +559,17 @@ FilterPushdownResult FilterCombiner::TryPushdownInFilter(TableFilterSet &table_f //! Check if values are consecutive, if yes transform them to >= <= (only for integers) // e.g. 
if we have x IN (1, 2, 3, 4, 5) we transform this into x >= 1 AND x <= 5 vector in_list; - for (idx_t i = 1; i < func.children.size(); i++) { - auto &const_value_expr = func.children[i]->Cast(); - D_ASSERT(!const_value_expr.value.IsNull()); - in_list.push_back(const_value_expr.value); + for (idx_t i = 1; i < func.GetChildren().size(); i++) { + auto &const_value_expr = func.GetChildren()[i]->Cast(); + D_ASSERT(!const_value_expr.GetValue().IsNull()); + in_list.push_back(const_value_expr.GetValue()); } if (type.IsIntegral() && IsDenseRange(in_list)) { // dense range! turn this into x >= min AND x <= max // IsDenseRange sorts in_list, so the front element is the min and the back element is the max - auto lower_bound = CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, - std::move(in_list.front())); - auto upper_bound = CreateComparisonExpression(*func.children[0], ExpressionType::COMPARE_LESSTHANOREQUALTO, + auto lower_bound = CreateComparisonExpression( + *func.GetChildren()[0], ExpressionType::COMPARE_GREATERTHANOREQUALTO, std::move(in_list.front())); + auto upper_bound = CreateComparisonExpression(*func.GetChildren()[0], ExpressionType::COMPARE_LESSTHANOREQUALTO, std::move(in_list.back())); table_filters.PushFilter(proj_index, make_uniq(std::move(lower_bound))); table_filters.PushFilter(proj_index, make_uniq(std::move(upper_bound))); @@ -559,9 +577,8 @@ FilterPushdownResult FilterCombiner::TryPushdownInFilter(TableFilterSet &table_f } // if this is not a dense range we can push an optional filter for zone-map pruning auto in_expr = - ExpressionFilter::CreateInExpression(CreateFilterTargetExpression(*func.children[0]), std::move(in_list)); - auto optional_filter = make_uniq(make_uniq(std::move(in_expr))); - table_filters.PushFilter(proj_index, std::move(optional_filter)); + ExpressionFilter::CreateInExpression(CreateFilterTargetExpression(*func.GetChildren()[0]), std::move(in_list)); + table_filters.PushFilter(proj_index, 
CreateOptionalExpressionFilter(std::move(in_expr), type)); return FilterPushdownResult::PUSHED_DOWN_PARTIALLY; } @@ -574,13 +591,14 @@ FilterPushdownResult FilterCombiner::TryPushdownOrClause(TableFilterSet &table_f if (conj.GetExpressionType() != ExpressionType::CONJUNCTION_OR) { return FilterPushdownResult::NO_PUSHDOWN; } - auto conj_filter = make_uniq(); - if (conj.children.empty()) { + auto conj_filter = make_uniq(ExpressionType::CONJUNCTION_OR); + if (conj.GetChildren().empty()) { return FilterPushdownResult::NO_PUSHDOWN; } ProjectionIndex proj_id; - for (idx_t i = 0; i < conj.children.size(); i++) { - auto &child = conj.children[i]; + LogicalType col_type = LogicalType::INVALID; + for (idx_t i = 0; i < conj.GetChildren().size(); i++) { + auto &child = conj.GetChildren()[i]; if (!BoundComparisonExpression::IsComparison(*child)) { return FilterPushdownResult::NO_PUSHDOWN; } @@ -604,24 +622,25 @@ FilterPushdownResult FilterCombiner::TryPushdownOrClause(TableFilterSet &table_f return FilterPushdownResult::NO_PUSHDOWN; } if (!proj_id.IsValid()) { - proj_id = column_ref->binding.column_index; - } else if (proj_id != column_ref->binding.column_index) { + proj_id = column_ref->Binding().column_index; + col_type = column_ref->GetReturnType(); + } else if (proj_id != column_ref->Binding().column_index) { return FilterPushdownResult::NO_PUSHDOWN; } auto comparison_type = invert ? 
FlipComparisonExpression(comp.GetExpressionType()) : comp.GetExpressionType(); - if (const_val->value.IsNull()) { + if (const_val->GetValue().IsNull()) { switch (comparison_type) { case ExpressionType::COMPARE_DISTINCT_FROM: { auto null_expr = ExpressionFilter::CreateNullCheckExpression(CreateFilterTargetExpression(*column_ref), ExpressionType::OPERATOR_IS_NOT_NULL); - conj_filter->child_filters.push_back(make_uniq(std::move(null_expr))); + conj_filter->GetChildrenMutable().push_back(std::move(null_expr)); break; } case ExpressionType::COMPARE_NOT_DISTINCT_FROM: { auto null_expr = ExpressionFilter::CreateNullCheckExpression(CreateFilterTargetExpression(*column_ref), ExpressionType::OPERATOR_IS_NULL); - conj_filter->child_filters.push_back(make_uniq(std::move(null_expr))); + conj_filter->GetChildrenMutable().push_back(std::move(null_expr)); break; } default: @@ -630,14 +649,11 @@ FilterPushdownResult FilterCombiner::TryPushdownOrClause(TableFilterSet &table_f break; } } else { - auto expr_filter = - make_uniq(CreateComparisonExpression(*column_ref, comparison_type, const_val->value)); - conj_filter->child_filters.push_back(std::move(expr_filter)); + conj_filter->GetChildrenMutable().push_back( + CreateComparisonExpression(*column_ref, comparison_type, const_val->GetValue())); } } - auto optional_filter = make_uniq(); - optional_filter->child_filter = std::move(conj_filter); - table_filters.PushFilter(proj_id, std::move(optional_filter)); + table_filters.PushFilter(proj_id, CreateOptionalExpressionFilter(std::move(conj_filter), col_type)); return FilterPushdownResult::PUSHED_DOWN_PARTIALLY; } @@ -822,7 +838,7 @@ FilterPushdownResult FilterCombiner::TryPushdownTemporalCastFilter(TableFilterSe // the child of the cast must resolve to a column ref ProjectionIndex proj_index; - if (!TryGetProjectionIndex(*cast_expr.child, proj_index)) { + if (!TryGetProjectionIndex(cast_expr.Child(), proj_index)) { return FilterPushdownResult::NO_PUSHDOWN; } @@ -840,11 +856,8 @@ 
FilterPushdownResult FilterCombiner::TryPushdownTemporalCastFilter(TableFilterSe } auto push_optional = [&](ExpressionType filter_type, Value filter_val) { - auto opt_filter = make_uniq(); - auto expr_filter = make_uniq( - CreateComparisonExpression(*cast_expr.child, filter_type, std::move(filter_val))); - opt_filter->child_filter = std::move(expr_filter); - table_filters.PushFilter(proj_index, std::move(opt_filter)); + auto filter_expr = CreateComparisonExpression(cast_expr.Child(), filter_type, std::move(filter_val)); + table_filters.PushFilter(proj_index, CreateOptionalExpressionFilter(std::move(filter_expr), source_type)); }; // push relaxed filter(s) as OptionalFilter @@ -1162,27 +1175,27 @@ FilterResult FilterCombiner::AddTransitiveFilters(BoundFunctionExpression &compa break; } auto &bound_cast_expr = right_node.get().Cast(); - if (bound_cast_expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + if (bound_cast_expr.Child().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { break; } - auto &col_ref = bound_cast_expr.child->Cast(); + auto &col_ref = bound_cast_expr.Child().Cast(); for (auto &stored_exp : stored_expressions) { const_reference expr = stored_exp.first; if (expr.get().GetExpressionType() == ExpressionType::OPERATOR_CAST) { - expr = *(right_node.get().Cast().child); + expr = right_node.get().Cast().Child(); } if (expr.get().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { continue; } auto &st_col_ref = expr.get().Cast(); - if (st_col_ref.binding != col_ref.binding) { + if (st_col_ref.Binding() != col_ref.Binding()) { continue; } if (bound_cast_expr.GetReturnType() != stored_exp.second->GetReturnType()) { continue; } - bound_cast_expr.child = stored_exp.second->Copy(); - right_node = GetNode(*bound_cast_expr.child); + bound_cast_expr.ChildMutable() = stored_exp.second->Copy(); + right_node = GetNode(*bound_cast_expr.ChildMutable()); break; } } while (false); diff --git 
a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp index e75e4676b..925fb0775 100644 --- a/src/duckdb/src/optimizer/filter_pushdown.cpp +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -53,7 +53,7 @@ void FilterPushdown::CheckMarkToSemi(LogicalOperator &op, unordered_set(); - new_table_bindings.insert(col_ref.binding.table_index); + new_table_bindings.insert(col_ref.Binding().table_index); } }); table_bindings = new_table_bindings; @@ -71,7 +71,7 @@ void FilterPushdown::CheckMarkToSemi(LogicalOperator &op, unordered_set(); - bindings_to_keep.push_back(col_ref.binding); + bindings_to_keep.push_back(col_ref.Binding()); } }); } @@ -79,7 +79,7 @@ void FilterPushdown::CheckMarkToSemi(LogicalOperator &op, unordered_set(); - bindings_to_keep.push_back(col_ref.binding); + bindings_to_keep.push_back(col_ref.Binding()); } }); } diff --git a/src/duckdb/src/optimizer/in_clause_rewriter.cpp b/src/duckdb/src/optimizer/in_clause_rewriter.cpp index 28a19df69..61980200f 100644 --- a/src/duckdb/src/optimizer/in_clause_rewriter.cpp +++ b/src/duckdb/src/optimizer/in_clause_rewriter.cpp @@ -38,35 +38,35 @@ unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &e return nullptr; } D_ASSERT(root); - auto in_type = expr.children[0]->GetReturnType(); + auto in_type = expr.GetChildrenMutable()[0]->GetReturnType(); bool is_regular_in = expr.GetExpressionType() == ExpressionType::COMPARE_IN; bool all_scalar = true; // IN clause with many children: try to generate a mark join that replaces this IN expression // we can only do this if the expressions in the expression list are scalar - for (idx_t i = 1; i < expr.children.size(); i++) { - if (!expr.children[i]->IsFoldable()) { + for (idx_t i = 1; i < expr.GetChildrenMutable().size(); i++) { + if (!expr.GetChildrenMutable()[i]->IsFoldable()) { // non-scalar expression all_scalar = false; } } - if (expr.children.size() == 2) { + if (expr.GetChildrenMutable().size() == 2) { // only one 
child // IN: turn into X = 1 // NOT IN: turn into X <> 1 - return BoundComparisonExpression::Create(is_regular_in ? ExpressionType::COMPARE_EQUAL - : ExpressionType::COMPARE_NOTEQUAL, - std::move(expr.children[0]), std::move(expr.children[1])); + return BoundComparisonExpression::Create( + is_regular_in ? ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL, + std::move(expr.GetChildrenMutable()[0]), std::move(expr.GetChildrenMutable()[1])); } - if (expr.children.size() < 6 || !all_scalar) { + if (expr.GetChildrenMutable().size() < IN_CLAUSE_REWRITE_THRESHOLD || !all_scalar) { // low amount of children or not all scalar // IN: turn into (X = 1 OR X = 2 OR X = 3...) // NOT IN: turn into (X <> 1 AND X <> 2 AND X <> 3 ...) auto conjunction = make_uniq(is_regular_in ? ExpressionType::CONJUNCTION_OR : ExpressionType::CONJUNCTION_AND); - for (idx_t i = 1; i < expr.children.size(); i++) { - conjunction->children.push_back(BoundComparisonExpression::Create( + for (idx_t i = 1; i < expr.GetChildrenMutable().size(); i++) { + conjunction->GetChildrenMutable().push_back(BoundComparisonExpression::Create( is_regular_in ? 
ExpressionType::COMPARE_EQUAL : ExpressionType::COMPARE_NOTEQUAL, - expr.children[0]->Copy(), std::move(expr.children[i]))); + expr.GetChildrenMutable()[0]->Copy(), std::move(expr.GetChildrenMutable()[i]))); } return std::move(conjunction); } @@ -80,16 +80,15 @@ unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &e DataChunk chunk; chunk.Initialize(context, types); - for (idx_t i = 1; i < expr.children.size(); i++) { + for (idx_t i = 1; i < expr.GetChildrenMutable().size(); i++) { // resolve this expression to a constant Value value; - if (!ExpressionExecutor::TryEvaluateScalar(context, *expr.children[i], value)) { + if (!ExpressionExecutor::TryEvaluateScalar(context, *expr.GetChildrenMutable()[i], value)) { // error while evaluating scalar return nullptr; } - chunk.SetCardinality(chunk.size() + 1); chunk.data[0].Append(value); - if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.children.size()) { + if (chunk.size() == STANDARD_VECTOR_SIZE || i + 1 == expr.GetChildrenMutable().size()) { // chunk full: append to chunk collection collection->Append(append_state, chunk); chunk.Reset(); @@ -106,7 +105,7 @@ unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &e join->AddChild(std::move(root)); join->AddChild(std::move(chunk_scan)); // create the JOIN condition - JoinCondition cond(std::move(expr.children[0]), + JoinCondition cond(std::move(expr.GetChildrenMutable()[0]), make_uniq(in_type, ColumnBinding(chunk_index, ProjectionIndex(0))), ExpressionType::COMPARE_EQUAL); join->conditions.push_back(std::move(cond)); @@ -131,7 +130,7 @@ unique_ptr InClauseRewriter::VisitReplace(BoundOperatorExpression &e if (!is_regular_in) { // NOT IN: invert auto invert = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); - invert->children.push_back(std::move(result)); + invert->GetChildrenMutable().push_back(std::move(result)); result = std::move(invert); } return result; diff --git a/src/duckdb/src/optimizer/join_elimination.cpp 
b/src/duckdb/src/optimizer/join_elimination.cpp index a47b98e05..02c01ceca 100644 --- a/src/duckdb/src/optimizer/join_elimination.cpp +++ b/src/duckdb/src/optimizer/join_elimination.cpp @@ -46,7 +46,7 @@ void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptrGetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { break; } - auto table_idx = distinct.distinct_targets[0]->Cast().binding.table_index; + auto table_idx = distinct.distinct_targets[0]->Cast().Binding().table_index; bool can_add = true; for (auto &target : distinct.distinct_targets) { if (target->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { @@ -54,8 +54,8 @@ void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptrCast(); - distinct_group.insert(col_ref.binding); - D_ASSERT(table_idx == col_ref.binding.table_index); + distinct_group.insert(col_ref.Binding()); + D_ASSERT(table_idx == col_ref.Binding().table_index); } if (can_add) { pipe_info.distinct_groups[table_idx] = std::move(distinct_group); @@ -94,7 +94,7 @@ void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptr().binding.table_index; + auto ref_id = start_expr.Cast().Binding().table_index; for (auto &col : it->second) { auto &expression = projection.GetExpression(col); if (expression.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { @@ -103,11 +103,11 @@ void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptr(); - if (ref_id != col_ref.binding.table_index) { + if (ref_id != col_ref.Binding().table_index) { could_add = false; break; } - new_distinct_group.insert(col_ref.binding); + new_distinct_group.insert(col_ref.Binding()); } if (could_add) { pipe_info.distinct_groups[ref_id] = std::move(new_distinct_group); @@ -147,20 +147,20 @@ void JoinElimination::OptimizeChildren(LogicalOperator &op, optional_ptrGetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &col_ref = expression->Cast(); - auto distinct_group_it = 
pipe_info.distinct_groups.find(col_ref.binding.table_index); + auto distinct_group_it = pipe_info.distinct_groups.find(col_ref.Binding().table_index); if (distinct_group_it == pipe_info.distinct_groups.end()) { continue; } - if (ref_table_columns.find(col_ref.binding.table_index) == ref_table_columns.end()) { + if (ref_table_columns.find(col_ref.Binding().table_index) == ref_table_columns.end()) { auto ref = DistinctGroupRef(); for (auto &col : distinct_group_it->second) { ref.ref_column_ids.insert(col.column_index); } - ref_table_columns[col_ref.binding.table_index] = ref; + ref_table_columns[col_ref.Binding().table_index] = ref; } - ref_table_columns[col_ref.binding.table_index].distinct_group.insert( + ref_table_columns[col_ref.Binding().table_index].distinct_group.insert( ColumnBinding(projection.table_index, ProjectionIndex(proj_idx))); - ref_table_columns[col_ref.binding.table_index].ref_column_ids.erase(col_ref.binding.column_index); + ref_table_columns[col_ref.Binding().table_index].ref_column_ids.erase(col_ref.Binding().column_index); } } for (auto &refs : ref_table_columns) { @@ -270,8 +270,8 @@ unique_ptr JoinElimination::TryEliminateJoin() { is_output_unique = false; break; } - auto inner_binding = inner_idx == 0 ? condition.GetLHS().Cast().binding - : condition.GetRHS().Cast().binding; + auto inner_binding = inner_idx == 0 ? 
condition.GetLHS().Cast().Binding() + : condition.GetRHS().Cast().Binding(); col_bindings.push_back(inner_binding); } if (is_output_unique && !ContainDistinctGroup(col_bindings)) { @@ -310,7 +310,7 @@ bool JoinElimination::ContainDistinctGroup(vector &column_binding } unique_ptr JoinElimination::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { - pipe_info.ref_table_ids.insert(expr.binding.table_index); + pipe_info.ref_table_ids.insert(expr.Binding().table_index); return nullptr; } diff --git a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp index dae36b699..820fb35e7 100644 --- a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp @@ -20,7 +20,7 @@ namespace duckdb { JoinFilterPushdownOptimizer::JoinFilterPushdownOptimizer(Optimizer &optimizer) : optimizer(optimizer) { } -bool PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColumn &filter) { +bool JoinFilterPushdownUtil::PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColumn &filter) { if (expr.GetReturnType().IsNested()) { // nested columns are not supported for pushdown return false; @@ -29,17 +29,20 @@ bool PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColu // interval is not supported for pushdown return false; } + if (filter.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION && !filter.runtime_filter_type.IsValid()) { + filter.runtime_filter_type = expr.GetReturnType(); + } switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_COLUMN_REF: { // column-ref - pass through the new column binding auto &colref = expr.Cast(); - filter.probe_column_index = colref.binding; + filter.probe_column_index = colref.Binding(); return true; } case ExpressionClass::BOUND_CAST: { // We allow pushing through integral down/upcasts, as long as source/target are (u)bigint or smaller 
const auto &bound_cast = expr.Cast(); - const auto &src = bound_cast.child->GetReturnType(); + const auto &src = bound_cast.Child().GetReturnType(); const auto &tgt = bound_cast.GetReturnType(); if (!src.IsIntegral() || !tgt.IsIntegral()) { return false; @@ -48,22 +51,56 @@ bool PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColu GetTypeIdSize(tgt.InternalType()) > GetTypeIdSize(PhysicalType::INT64)) { return false; // Only do this for (u)bigint and smaller } - return PushdownJoinFilterExpression(*bound_cast.child, filter); + if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(bound_cast.Child(), filter)) { + return false; + } + const bool widening_signed_cast = + src.IsSigned() == tgt.IsSigned() && GetTypeIdSize(tgt.InternalType()) >= GetTypeIdSize(src.InternalType()); + const bool widening_unsigned_to_signed_cast = + !src.IsSigned() && tgt.IsSigned() && GetTypeIdSize(tgt.InternalType()) > GetTypeIdSize(src.InternalType()); + if (widening_signed_cast || widening_unsigned_to_signed_cast) { + filter.runtime_filter_type = expr.GetReturnType(); + } else { + filter.mode = JoinFilterPushdownMode::STORAGE_ONLY; + filter.runtime_filter_type = LogicalType::INVALID; + } + return true; } default: return false; } } +bool JoinFilterPushdownUtil::JoinTypeIsSupported(JoinType join_type) { + switch (join_type) { + case JoinType::MARK: + case JoinType::SINGLE: + case JoinType::LEFT: + case JoinType::OUTER: + case JoinType::ANTI: + case JoinType::RIGHT_ANTI: + // cannot generate join filters for these join types + // mark/single - cannot change cardinality of probe side + // left/outer always need to include every row from probe side + // FIXME: anti/right_anti - we could do this, but need to invert the join conditions + return false; + default: + return true; + } +} + void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, vector columns, vector &targets) { auto &probe_child = op; switch (probe_child.type) { case 
LogicalOperatorType::LOGICAL_LIMIT: + case LogicalOperatorType::LOGICAL_TOP_N: + // LIMIT/TOP_N determines which rows are part of the probe side before the join. + // Pushing a join filter below it can change which rows survive the limit/offset. + break; case LogicalOperatorType::LOGICAL_FILTER: case LogicalOperatorType::LOGICAL_ORDER_BY: - case LogicalOperatorType::LOGICAL_TOP_N: case LogicalOperatorType::LOGICAL_DISTINCT: case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: @@ -101,7 +138,7 @@ void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, auto child_bindings = child->GetColumnBindings(); child_columns.reserve(columns.size()); for (auto &child_column : columns) { - JoinFilterPushdownColumn new_col; + auto new_col = child_column; new_col.probe_column_index = child_bindings[child_column.probe_column_index.column_index]; child_columns.push_back(new_col); } @@ -150,7 +187,7 @@ void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, return; } auto &expr = proj.GetExpression(filter.probe_column_index); - if (!PushdownJoinFilterExpression(expr, filter)) { + if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(expr, filter)) { // cannot push through this expression - bail-out return; } @@ -167,7 +204,7 @@ void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, return; } auto &expr = aggr.GetExpression(filter.probe_column_index); - if (!PushdownJoinFilterExpression(expr, filter)) { + if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(expr, filter)) { // cannot push through this expression - bail-out return; } @@ -204,20 +241,8 @@ bool JoinFilterPushdownOptimizer::IsFiltering(const unique_ptr } void JoinFilterPushdownOptimizer::GenerateJoinFilters(LogicalComparisonJoin &join) { - switch (join.join_type) { - case JoinType::MARK: - case JoinType::SINGLE: - case JoinType::LEFT: - case JoinType::OUTER: - case JoinType::ANTI: - case 
JoinType::RIGHT_ANTI: - // cannot generate join filters for these join types - // mark/single - cannot change cardinality of probe side - // left/outer always need to include every row from probe side - // FIXME: anti/right_anti - we could do this, but need to invert the join conditions + if (!JoinFilterPushdownUtil::JoinTypeIsSupported(join.join_type)) { return; - default: - break; } // re-order conditions here - otherwise this will happen later on and invalidate the indexes we generate PhysicalComparisonJoin::ReorderConditions(join.conditions); @@ -244,7 +269,7 @@ void JoinFilterPushdownOptimizer::GenerateJoinFilters(LogicalComparisonJoin &joi } JoinFilterPushdownColumn pushdown_col; - if (!PushdownJoinFilterExpression(cond.GetLHS(), pushdown_col)) { + if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(cond.GetLHS(), pushdown_col)) { continue; } @@ -294,7 +319,7 @@ void JoinFilterPushdownOptimizer::GenerateJoinFilters(LogicalComparisonJoin &joi aggr_children.push_back(join.conditions[join_condition].GetRHS().Copy()); auto aggr_expr = function_binder.BindAggregateFunction(aggr, std::move(aggr_children), nullptr, AggregateType::NON_DISTINCT); - if (aggr_expr->children.size() != 1) { + if (aggr_expr->GetChildren().size() != 1) { // min/max with collation - not supported return; } diff --git a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp index 9ad6cc628..051a80d7f 100644 --- a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp +++ b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp @@ -1,11 +1,12 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/limits.hpp" +#include "duckdb/common/optional_idx.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/function/table/table_scan.hpp" #include "duckdb/optimizer/join_order/join_node.hpp" #include 
"duckdb/optimizer/join_order/query_graph_manager.hpp" -#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/storage/data_table.hpp" @@ -14,32 +15,254 @@ namespace duckdb { +struct DenomInfo { +public: + DenomInfo(JoinRelationSet &numerator_relations, double denominator) + : numerator_relations(numerator_relations), denominator(denominator) { + } + +public: + JoinRelationSet &numerator_relations; + double denominator; +}; + +struct DomainEstimate { +public: + DomainEstimate(); + +public: + idx_t GetDistinctCount() const; + bool HasReliableDistinctCount() const; + idx_t GetReliableDistinctCount() const; + void Update(const DistinctCount &distinct_count); + +private: + optional_idx reliable_distinct_count; + optional_idx min_max_distinct_count; + idx_t fallback_distinct_count; +}; + +static bool IsReliableDistinctCount(DistinctCountSource source) { + return source == DistinctCountSource::HLL || source == DistinctCountSource::EXACT; +} + +static void UpdateMaxDistinctCount(optional_idx &target, idx_t distinct_count) { + if (target.IsValid()) { + target = MaxValue(target.GetIndex(), distinct_count); + return; + } + target = distinct_count; +} + +DomainEstimate::DomainEstimate() : fallback_distinct_count(NumericLimits::Maximum()) { +} + +idx_t DomainEstimate::GetDistinctCount() const { + if (reliable_distinct_count.IsValid()) { + return reliable_distinct_count.GetIndex(); + } + if (min_max_distinct_count.IsValid()) { + return min_max_distinct_count.GetIndex(); + } + return fallback_distinct_count; +} + +bool DomainEstimate::HasReliableDistinctCount() const { + return reliable_distinct_count.IsValid(); +} + +idx_t DomainEstimate::GetReliableDistinctCount() const { + D_ASSERT(reliable_distinct_count.IsValid()); + return reliable_distinct_count.GetIndex(); +} + 
+void DomainEstimate::Update(const DistinctCount &distinct_count) { + if (IsReliableDistinctCount(distinct_count.source)) { + UpdateMaxDistinctCount(reliable_distinct_count, distinct_count.distinct_count); + } else if (distinct_count.source == DistinctCountSource::MIN_MAX) { + UpdateMaxDistinctCount(min_max_distinct_count, distinct_count.distinct_count); + } else { + fallback_distinct_count = MinValue(distinct_count.distinct_count, fallback_distinct_count); + } +} + +struct RelationsSetToStats { +public: + explicit RelationsSetToStats(const column_binding_set_t &column_binding_set) + : equivalent_relations(column_binding_set) { + } + +public: + //! Column binding sets that are equivalent in a join plan. + column_binding_set_t equivalent_relations; + DomainEstimate domain_estimate; + vector> predicates; + vector column_names; +}; + +struct FilterInfoWithTotalDomains { + FilterInfoWithTotalDomains(JoinPredicate &predicate, RelationsSetToStats &relation_set_to_stats) + : predicate(predicate), domain_estimate(relation_set_to_stats.domain_estimate) { + } + +public: + double GetDistinctCount() const; + bool HasReliableDistinctCount() const; + idx_t GetReliableDistinctCount() const; + ExpressionType GetComparisonType(); + bool IsInnerEquality(); + FilterInfo &GetFilter() const; + JoinPredicate &GetPredicate() const; + +public: + reference predicate; + reference domain_estimate; +}; + +struct Subgraph2Denominator { +public: + Subgraph2Denominator() : relations(nullptr), numerator_relations(nullptr), denom(1) { + } + +public: + optional_ptr relations; + optional_ptr numerator_relations; + double denom; +}; + +struct LeftJoinDenomInfo { +public: + LeftJoinDenomInfo(JoinRelationSet &numerator_relations, double denominator) + : numerator_relations(numerator_relations), denominator(denominator) { + } + +public: + JoinRelationSet &numerator_relations; + double denominator; +}; + +struct CompositeJoinPairStats { +public: + //! 
The row-count cap is only plausible when the candidate key cardinality is within the same order of magnitude + //! as an observed single-column domain. Otherwise broad fact-to-fact joins can look like key lookups. + static constexpr double MAX_CARDINALITY_TO_DISTINCT_RATIO = 8; + + void RegisterDistinctCount(double distinct_count) { + if (!has_distinct_count) { + first_distinct_count = distinct_count; + has_distinct_count = true; + } + if (distinct_count > max_distinct_count) { + max_distinct_count = distinct_count; + } + } + + bool CanApplyCap(double cap) const { + return has_distinct_count && max_distinct_count > 0 && + cap <= max_distinct_count * MAX_CARDINALITY_TO_DISTINCT_RATIO; + } + +public: + double first_distinct_count = 0; + double max_distinct_count = 0; + bool has_distinct_count = false; +}; + +struct CardinalityHelper { +public: + CardinalityHelper() : cardinality_before_filters(0) { + } + explicit CardinalityHelper(double cardinality_before_filters) + : cardinality_before_filters(cardinality_before_filters) { + } + +public: + // Must be a double. Otherwise we can lose significance between different join orders. 
+ double cardinality_before_filters; + vector table_names_joined; + vector column_names; +}; + +struct CardinalityEstimatorState { + vector relation_set_stats; + reference_map_t relation_set_2_cardinality; + vector relation_stats; + vector> or_filters; +}; + +CardinalityEstimator::CardinalityEstimator(JoinRelationSetManager &set_manager, + const JoinPredicateModel &predicate_model) + : state(make_uniq()), set_manager(set_manager), predicate_model(predicate_model) { +} + +CardinalityEstimator::~CardinalityEstimator() { +} + +double FilterInfoWithTotalDomains::GetDistinctCount() const { + return static_cast(domain_estimate.get().GetDistinctCount()); +} + +bool FilterInfoWithTotalDomains::HasReliableDistinctCount() const { + return domain_estimate.get().HasReliableDistinctCount(); +} + +idx_t FilterInfoWithTotalDomains::GetReliableDistinctCount() const { + return domain_estimate.get().GetReliableDistinctCount(); +} + +ExpressionType FilterInfoWithTotalDomains::GetComparisonType() { + return GetPredicate().GetComparisonType(); +} + +bool FilterInfoWithTotalDomains::IsInnerEquality() { + return GetPredicate().IsEquivalencePredicate(); +} + +FilterInfo &FilterInfoWithTotalDomains::GetFilter() const { + return GetPredicate().GetFilter(); +} + +JoinPredicate &FilterInfoWithTotalDomains::GetPredicate() const { + return predicate.get(); +} + // The filter was made on top of a logical sample or other projection, // but no specific columns are referenced. See issue 4978 number 4. 
-bool CardinalityEstimator::EmptyFilter(FilterInfo &filter_info) { +bool CardinalityEstimator::EmptyFilter(const FilterInfo &filter_info) { if (!filter_info.left_set && !filter_info.right_set) { return true; } return false; } -void CardinalityEstimator::AddRelationStats(FilterInfo &filter_info) { - D_ASSERT(filter_info.set.get().count >= 1); - for (const RelationsSetToStats &r2tdom : relation_set_stats) { +void CardinalityEstimator::AddRelationStats(const FilterInfo &filter_info) { + D_ASSERT(!filter_info.set.get().Empty()); + // Use whichever binding is valid: prefer left_binding, fall back to right_binding. + // left_binding may be INVALID_INDEX when left_set is empty (e.g. residual predicates + // or single-column filters where only the right side references a relation). + auto binding = filter_info.left_binding; + if (!binding.table_index.IsValid()) { + binding = filter_info.right_binding; + } + if (!binding.table_index.IsValid()) { + // No valid binding (EmptyFilter), nothing to record + return; + } + for (const RelationsSetToStats &r2tdom : state->relation_set_stats) { auto &i_set = r2tdom.equivalent_relations; - if (i_set.find(filter_info.left_binding) != i_set.end()) { + if (i_set.find(binding) != i_set.end()) { // found an equivalent filter return; } } - auto key = ColumnBinding(filter_info.left_binding.table_index, filter_info.left_binding.column_index); + auto key = ColumnBinding(binding.table_index, binding.column_index); RelationsSetToStats new_r2tdom(column_binding_set_t({key})); - relation_set_stats.emplace_back(new_r2tdom); + state->relation_set_stats.emplace_back(new_r2tdom); } -bool CardinalityEstimator::SingleColumnFilter(duckdb::FilterInfo &filter_info) { +bool CardinalityEstimator::SingleColumnFilter(const FilterInfo &filter_info) { if (filter_info.left_set && filter_info.right_set && filter_info.set.get().count > 1) { // Both set and are from different relations return false; @@ -53,120 +276,88 @@ bool 
CardinalityEstimator::SingleColumnFilter(duckdb::FilterInfo &filter_info) { return true; } -vector CardinalityEstimator::DetermineMatchingEquivalentSets(optional_ptr filter_info) { - vector matching_equivalent_sets; - idx_t equivalent_relation_index = 0; - - for (const RelationsSetToStats &r2tdom : relation_set_stats) { - auto &i_set = r2tdom.equivalent_relations; - if (i_set.find(filter_info->left_binding) != i_set.end()) { - matching_equivalent_sets.push_back(equivalent_relation_index); - } else if (i_set.find(filter_info->right_binding) != i_set.end()) { - // don't add both left and right to the matching_equivalent_sets - // since both left and right get added to that index anyway. - matching_equivalent_sets.push_back(equivalent_relation_index); - } - equivalent_relation_index++; - } - return matching_equivalent_sets; +static bool IsJoinOrFilter(const FilterInfo &filter) { + return filter.filter && filter.filter->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION && + filter.filter->GetExpressionType() == ExpressionType::CONJUNCTION_OR && filter.set.get().count >= 2; } -void CardinalityEstimator::AddToEquivalenceSets(optional_ptr filter_info, - vector matching_equivalent_sets) { - D_ASSERT(matching_equivalent_sets.size() <= 2); - if (matching_equivalent_sets.size() > 1) { - // an equivalence relation is connecting two sets of equivalence relations - // so push all relations from the second set into the first. Later we will delete - // the second set. 
- for (ColumnBinding i : relation_set_stats.at(matching_equivalent_sets[1]).equivalent_relations) { - relation_set_stats.at(matching_equivalent_sets[0]).equivalent_relations.insert(i); - } - for (auto &column_name : relation_set_stats.at(matching_equivalent_sets[1]).column_names) { - relation_set_stats.at(matching_equivalent_sets[0]).column_names.push_back(column_name); - } - relation_set_stats.at(matching_equivalent_sets[1]).equivalent_relations.clear(); - relation_set_stats.at(matching_equivalent_sets[1]).column_names.clear(); - relation_set_stats.at(matching_equivalent_sets[0]).filters.push_back(filter_info); - // add all values of one set to the other, delete the empty one - } else if (matching_equivalent_sets.size() == 1) { - auto &tdom_i = relation_set_stats.at(matching_equivalent_sets.at(0)); - tdom_i.equivalent_relations.insert(filter_info->left_binding); - tdom_i.equivalent_relations.insert(filter_info->right_binding); - tdom_i.filters.push_back(filter_info); - } else if (matching_equivalent_sets.empty()) { - column_binding_set_t tmp; - tmp.insert(filter_info->left_binding); - tmp.insert(filter_info->right_binding); - relation_set_stats.emplace_back(tmp); - relation_set_stats.back().filters.push_back(filter_info); - } -} - -void CardinalityEstimator::InitEquivalentRelations(const vector> &filter_infos) { - // For each filter, we fill keep track of the index of the equivalent relation set - // the left and right relation needs to be added to. 
- for (auto &filter : filter_infos) { - if (filter->filter && filter->filter->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION && - filter->filter->GetExpressionType() == ExpressionType::CONJUNCTION_OR && filter->set.get().count >= 2) { - or_filters.push_back(filter.get()); +void CardinalityEstimator::InitEquivalentRelations() { + state->relation_set_stats.clear(); + state->or_filters.clear(); + + for (auto predicate_ref : predicate_model.GetPredicates()) { + auto &filter = predicate_ref.get().GetFilter(); + if (IsJoinOrFilter(filter)) { + state->or_filters.push_back(filter); continue; } - if (SingleColumnFilter(*filter)) { + if (SingleColumnFilter(filter)) { // Filter on one relation, (i.e. string or range filter on a column). // Grab the first relation and add it to the equivalence_relations - AddRelationStats(*filter); - continue; - } else if (EmptyFilter(*filter)) { + AddRelationStats(filter); + } + } + + for (auto &equality_class : predicate_model.GetEqualityClasses()) { + if (equality_class.columns.empty()) { continue; } - D_ASSERT(filter->left_set->count >= 1); - D_ASSERT(filter->right_set->count >= 1); + state->relation_set_stats.emplace_back(equality_class.columns); + for (auto &edge : equality_class.edges) { + state->relation_set_stats.back().predicates.push_back(edge.predicate); + } + } - auto matching_equivalent_sets = DetermineMatchingEquivalentSets(filter.get()); - AddToEquivalenceSets(filter.get(), matching_equivalent_sets); + for (auto predicate_ref : predicate_model.GetSelectivityPredicates()) { + auto &predicate = predicate_ref.get(); + D_ASSERT(predicate.HasValidJoinEndpoints()); + D_ASSERT(predicate.HasAnyValidStatsBinding()); + + column_binding_set_t domain_group; + if (predicate.GetStatsBinding(true).table_index.IsValid()) { + domain_group.insert(predicate.GetStatsBinding(true)); + } + if (predicate.GetStatsBinding(false).table_index.IsValid()) { + domain_group.insert(predicate.GetStatsBinding(false)); + } + 
D_ASSERT(!domain_group.empty()); + state->relation_set_stats.emplace_back(domain_group); + state->relation_set_stats.back().predicates.push_back(predicate); } RemoveEmptyTotalDomains(); } void CardinalityEstimator::RemoveEmptyTotalDomains() { auto remove_start = - std::remove_if(relation_set_stats.begin(), relation_set_stats.end(), + std::remove_if(state->relation_set_stats.begin(), state->relation_set_stats.end(), [](RelationsSetToStats &r_2_tdom) { return r_2_tdom.equivalent_relations.empty(); }); - relation_set_stats.erase(remove_start, relation_set_stats.end()); + state->relation_set_stats.erase(remove_start, state->relation_set_stats.end()); } double CardinalityEstimator::GetNumerator(JoinRelationSet &set) { double numerator = 1; for (idx_t i = 0; i < set.count; i++) { auto &single_node_set = set_manager.GetJoinRelation(set.relations[i]); - auto card_helper = relation_set_2_cardinality[single_node_set.ToString()]; + auto entry = state->relation_set_2_cardinality.find(single_node_set); + D_ASSERT(entry != state->relation_set_2_cardinality.end()); + if (entry == state->relation_set_2_cardinality.end()) { + continue; + } + auto &card_helper = entry->second; numerator *= card_helper.cardinality_before_filters == 0 ? 
1 : card_helper.cardinality_before_filters; } return numerator; } -bool EdgeConnects(FilterInfoWithTotalDomains &edge, Subgraph2Denominator &subgraph) { - if (edge.filter_info->left_set) { - if (JoinRelationSet::IsSubset(*subgraph.relations, *edge.filter_info->left_set)) { - // cool - return true; - } - } - if (edge.filter_info->right_set) { - if (JoinRelationSet::IsSubset(*subgraph.relations, *edge.filter_info->right_set)) { - return true; - } - } - return false; -} - vector GetEdges(vector &relations_to_tdom, JoinRelationSet &requested_set) { vector res; for (auto &relation_2_tdom : relations_to_tdom) { - for (auto &filter : relation_2_tdom.filters) { - if (JoinRelationSet::IsSubset(requested_set, filter->set)) { - FilterInfoWithTotalDomains new_edge(filter, relation_2_tdom); + for (auto predicate_ref : relation_2_tdom.predicates) { + auto &predicate = predicate_ref.get(); + auto &filter = predicate.GetFilter(); + if (JoinRelationSet::IsSubset(requested_set, filter.set.get()) && filter.left_set != filter.right_set) { + FilterInfoWithTotalDomains new_edge(predicate, relation_2_tdom); res.push_back(new_edge); } } @@ -174,7 +365,33 @@ vector GetEdges(vector &relatio return res; } -vector SubgraphsConnectedByEdge(FilterInfoWithTotalDomains &edge, vector &subgraphs) { +static optional_ptr GetEdgeEndpoint(FilterInfoWithTotalDomains &edge, + JoinRelationSetManager &set_manager, bool left) { + auto &filter = edge.GetFilter(); + if (filter.filter && BoundComparisonExpression::IsComparison(*filter.filter)) { + auto &binding = edge.GetPredicate().GetStatsBinding(left); + if (binding.table_index.IsValid()) { + return &set_manager.GetJoinRelation(RelationIndex(binding.table_index.index)); + } + } + return left ? 
edge.GetPredicate().GetLeftSetOptional() : edge.GetPredicate().GetRightSetOptional(); +} + +bool EdgeConnects(FilterInfoWithTotalDomains &edge, Subgraph2Denominator &subgraph, + JoinRelationSetManager &set_manager) { + auto left_endpoint = GetEdgeEndpoint(edge, set_manager, true); + if (left_endpoint && JoinRelationSet::IsSubset(*subgraph.relations, *left_endpoint)) { + return true; + } + auto right_endpoint = GetEdgeEndpoint(edge, set_manager, false); + if (right_endpoint && JoinRelationSet::IsSubset(*subgraph.relations, *right_endpoint)) { + return true; + } + return false; +} + +vector SubgraphsConnectedByEdge(FilterInfoWithTotalDomains &edge, vector &subgraphs, + JoinRelationSetManager &set_manager) { vector res; if (subgraphs.empty()) { return res; @@ -184,7 +401,8 @@ vector SubgraphsConnectedByEdge(FilterInfoWithTotalDomains &edge, vector< for (idx_t outer = 0; outer != subgraphs.size(); outer++) { // check if the edge connects two subgraphs. for (idx_t inner = outer + 1; inner != subgraphs.size(); inner++) { - if (EdgeConnects(edge, subgraphs.at(outer)) && EdgeConnects(edge, subgraphs.at(inner))) { + if (EdgeConnects(edge, subgraphs.at(outer), set_manager) && + EdgeConnects(edge, subgraphs.at(inner), set_manager)) { // order is important because we will delete the inner subgraph later res.push_back(outer); res.push_back(inner); @@ -193,7 +411,7 @@ vector SubgraphsConnectedByEdge(FilterInfoWithTotalDomains &edge, vector< } // if the edge does not connect two subgraphs, see if the edge connects with just outer // merge subgraph.at(outer) with the RelationSet(s) that edge connects - if (EdgeConnects(edge, subgraphs.at(outer))) { + if (EdgeConnects(edge, subgraphs.at(outer), set_manager)) { res.push_back(outer); return res; } @@ -205,11 +423,15 @@ vector SubgraphsConnectedByEdge(FilterInfoWithTotalDomains &edge, vector< JoinRelationSet &CardinalityEstimator::UpdateNumeratorRelations(Subgraph2Denominator left, Subgraph2Denominator right, 
FilterInfoWithTotalDomains &filter) { - switch (filter.filter_info->join_type) { + auto &predicate = filter.GetPredicate(); + switch (predicate.GetJoinType()) { + case JoinType::LEFT: { + return CalculateLeftJoinDenomInfo(left, right, filter).numerator_relations; + } case JoinType::SEMI: case JoinType::ANTI: { - if (JoinRelationSet::IsSubset(*left.relations, *filter.filter_info->left_set) && - JoinRelationSet::IsSubset(*right.relations, *filter.filter_info->right_set)) { + if (JoinRelationSet::IsSubset(*left.relations, predicate.GetLeftSet()) && + JoinRelationSet::IsSubset(*right.relations, predicate.GetRightSet())) { return *left.numerator_relations; } return *right.numerator_relations; @@ -220,204 +442,462 @@ JoinRelationSet &CardinalityEstimator::UpdateNumeratorRelations(Subgraph2Denomin } } -// Given two relations, here is where we considers the filter(s) that join them. -// This could use some work when it comes to join conditions that are not equality join conditions +// Apply the denominator multiplier for a given comparison type and effective distinct count. +static double ApplyComparisonRatio(double base_denom, ExpressionType comparison_type, double effective_d) { + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return base_denom * effective_d; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_DISTINCT_FROM: + // Assume this blows up, but use the tdom to bound it a bit. + return base_denom * pow(effective_d, 2.0 / 3.0); + default: + return base_denom; + } +} + +static double GetEffectiveDenom(double denom) { + return denom <= 0 ? 
1 : denom; +} + +double CardinalityEstimator::CalculateInnerJoinDenom(double base_denom, FilterInfoWithTotalDomains &filter) { + auto effective_d = filter.GetDistinctCount(); + auto comparison_type = filter.GetComparisonType(); + if (comparison_type == ExpressionType::INVALID) { + return base_denom * effective_d; + } + return ApplyComparisonRatio(base_denom, comparison_type, effective_d); +} + +LeftJoinDenomInfo CardinalityEstimator::CalculateLeftJoinDenomInfo(Subgraph2Denominator &left, + Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter) { + auto &predicate = filter.GetPredicate(); + D_ASSERT(left.relations && right.relations); + D_ASSERT(left.numerator_relations && right.numerator_relations); + auto left_is_preserved = JoinRelationSet::IsSubset(*left.relations, predicate.GetLeftSet()) && + JoinRelationSet::IsSubset(*right.relations, predicate.GetRightSet()); + auto &preserved_numerator = left_is_preserved ? *left.numerator_relations : *right.numerator_relations; + auto preserved_denom = GetEffectiveDenom(left_is_preserved ? 
left.denom : right.denom); + + auto &inner_numerator = set_manager.Union(*left.numerator_relations, *right.numerator_relations); + auto inner_denom = CalculateInnerJoinDenom(GetEffectiveDenom(left.denom) * GetEffectiveDenom(right.denom), filter); + if (inner_denom <= 0) { + return LeftJoinDenomInfo(preserved_numerator, preserved_denom); + } + auto inner_cardinality = GetNumerator(inner_numerator) / inner_denom; + auto preserved_cardinality = GetNumerator(preserved_numerator) / preserved_denom; + if (inner_cardinality > preserved_cardinality) { + return LeftJoinDenomInfo(inner_numerator, inner_denom); + } + return LeftJoinDenomInfo(preserved_numerator, preserved_denom); +} + +double CardinalityEstimator::CalculateLeftJoinDenom(Subgraph2Denominator &left, Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter) { + return CalculateLeftJoinDenomInfo(left, right, filter).denominator; +} + +double CardinalityEstimator::CalculateSemiAntiJoinDenom(double base_denom, Subgraph2Denominator &left, + Subgraph2Denominator &right, + FilterInfoWithTotalDomains &filter) { + auto &predicate = filter.GetPredicate(); + if (JoinRelationSet::IsSubset(*left.relations, predicate.GetLeftSet()) && + JoinRelationSet::IsSubset(*right.relations, predicate.GetRightSet())) { + return left.denom * DEFAULT_SEMI_ANTI_SELECTIVITY; + } + return right.denom * DEFAULT_SEMI_ANTI_SELECTIVITY; +} + +// Given two subgraphs, compute the updated denominator for the join between them. 
double CardinalityEstimator::CalculateUpdatedDenom(Subgraph2Denominator left, Subgraph2Denominator right, FilterInfoWithTotalDomains &filter) { - double new_denom = left.denom * right.denom; - switch (filter.filter_info->join_type) { - case JoinType::INNER: { - // Collect comparison types - ExpressionType comparison_type = ExpressionType::INVALID; - ExpressionIterator::EnumerateExpression(filter.filter_info->filter, [&](Expression &expr) { - if (BoundComparisonExpression::IsComparison(expr)) { - comparison_type = expr.GetExpressionType(); - } - }); - if (comparison_type == ExpressionType::INVALID) { - new_denom *= filter.has_distinct_count_hll ? static_cast(filter.distinct_count_hll) - : static_cast(filter.distinct_count_no_hll); - // no comparison is taking place, so the denominator is just the product of the left and right - return new_denom; + double base_denom = left.denom * right.denom; + switch (filter.GetPredicate().GetJoinType()) { + case JoinType::LEFT: + return CalculateLeftJoinDenom(left, right, filter); + case JoinType::INNER: + return CalculateInnerJoinDenom(base_denom, filter); + case JoinType::SEMI: + case JoinType::ANTI: + return CalculateSemiAntiJoinDenom(base_denom, left, right, filter); + default: + // Cross product: no join condition reduces the denominator. 
+ return base_denom; + } +} + +static bool GetEqualityEdgeRelations(FilterInfoWithTotalDomains &edge, RelationIndex &left, RelationIndex &right) { + if (!edge.IsInnerEquality()) { + return false; + } + auto &left_binding = edge.GetPredicate().GetEqualityBinding(true); + auto &right_binding = edge.GetPredicate().GetEqualityBinding(false); + if (!left_binding.table_index.IsValid() || !right_binding.table_index.IsValid()) { + return false; + } + left = RelationIndex(left_binding.table_index.index); + right = RelationIndex(right_binding.table_index.index); + return left != right; +} + +static optional_ptr GetEqualityJoinPair(FilterInfoWithTotalDomains &edge, + JoinRelationSetManager &set_manager) { + RelationIndex left; + RelationIndex right; + if (!GetEqualityEdgeRelations(edge, left, right)) { + return nullptr; + } + auto &left_set = set_manager.GetJoinRelation(left); + auto &right_set = set_manager.GetJoinRelation(right); + return set_manager.Union(left_set, right_set); +} + +double CardinalityEstimator::GetJoinPairCap(JoinRelationSet &join_pair) { + D_ASSERT(join_pair.count == 2); + double cap = NumericLimits::Maximum(); + for (idx_t relation_index = 0; relation_index < join_pair.count; relation_index++) { + auto &single_relation = set_manager.GetJoinRelation(join_pair.relations[relation_index]); + auto cardinality = GetNumerator(single_relation); + if (cardinality <= 0) { + return 0; } - // extra_ratio helps represents how many tuples will be filtered out if the comparison evaluates to - // false. set to 1 to assume cross product. - double extra_ratio = 1; - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - // extra ratio stays 1 - extra_ratio = filter.has_distinct_count_hll ? 
static_cast(filter.distinct_count_hll) - : static_cast(filter.distinct_count_no_hll); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_DISTINCT_FROM: - // Assume this blows up, but use the tdom to bound it a bit - extra_ratio = filter.has_distinct_count_hll ? static_cast(filter.distinct_count_hll) - : static_cast(filter.distinct_count_no_hll); - extra_ratio = pow(extra_ratio, 2.0 / 3.0); - break; - default: - break; + cap = MinValue(cap, cardinality); + } + return cap == NumericLimits::Maximum() ? 0 : cap; +} + +bool CardinalityEstimator::ApplyJoinPairCap(double &target_denom, JoinRelationSet &join_pair, + reference_map_t &join_pair_stats, + reference_set_t &capped_join_pairs) { + if (capped_join_pairs.find(join_pair) != capped_join_pairs.end()) { + return false; + } + const auto it = join_pair_stats.find(join_pair); + if (it == join_pair_stats.end()) { + return false; + } + auto &stats = it->second; + auto cap = GetJoinPairCap(join_pair); + // The cap is an FK/PK heuristic. Avoid applying it to broad fact-to-fact joins where the base row count is much + // larger than any observed single-column join domain; in that case row count is a poor proxy for composite NDV. + if (!stats.CanApplyCap(cap)) { + return false; + } + auto &first_d = stats.first_distinct_count; + if (cap <= 0 || first_d <= 0 || first_d > cap) { + return false; + } + if (cap > 0 && first_d < cap && first_d > 0) { + // Raise weak same-pair composite evidence to the FK/PK denominator floor. 
+ target_denom = target_denom / first_d * cap; + first_d = cap; + } + capped_join_pairs.insert(join_pair); + return true; +} + +bool CardinalityEstimator::ApplyJoinIncrement(double &target_denom, FilterInfoWithTotalDomains &edge, + reference_map_t &join_pair_stats, + reference_set_t &capped_join_pairs, + JoinRelationSet &scope, optional_ptr join_pair) { + if (edge.IsInnerEquality()) { + auto target_pair = join_pair; + if (!target_pair) { + target_pair = GetEqualityJoinPair(edge, set_manager); + } + if (!target_pair) { + return false; + } + if (!JoinRelationSet::IsSubset(scope, *target_pair) || + !predicate_model.HasDirectCompositeEquality(*target_pair)) { + return false; } - new_denom *= extra_ratio; - return new_denom; + if (capped_join_pairs.find(*target_pair) != capped_join_pairs.end()) { + // The composite-key cap already accounts for same-pair equality predicates. + return true; + } + return ApplyJoinPairCap(target_denom, *target_pair, join_pair_stats, capped_join_pairs); } - case JoinType::SEMI: - case JoinType::ANTI: { - if (JoinRelationSet::IsSubset(*left.relations, *filter.filter_info->left_set) && - JoinRelationSet::IsSubset(*right.relations, *filter.filter_info->right_set)) { - new_denom = left.denom * CardinalityEstimator::DEFAULT_SEMI_ANTI_SELECTIVITY; - return new_denom; + + if (edge.GetPredicate().GetJoinType() == JoinType::LEFT) { + // LEFT joins preserve the LHS cardinality in this estimator. Additional LEFT conditions + // should not reduce the output or add unused-edge penalties. 
+ return true; + } + + return false; +} + +bool CardinalityEstimator::ApplyCompositeJoinPairCaps( + double &target_denom, JoinRelationSet &scope, + reference_map_t &join_pair_stats, + reference_set_t &capped_join_pairs) { + bool applied = false; + for (auto &entry : join_pair_stats) { + auto &join_pair = entry.first.get(); + if (capped_join_pairs.find(join_pair) != capped_join_pairs.end()) { + continue; + } + if (!JoinRelationSet::IsSubset(scope, join_pair)) { + continue; } - new_denom = right.denom * CardinalityEstimator::DEFAULT_SEMI_ANTI_SELECTIVITY; - return new_denom; + if (!predicate_model.HasDirectCompositeEquality(join_pair)) { + continue; + } + applied = ApplyJoinPairCap(target_denom, join_pair, join_pair_stats, capped_join_pairs) || applied; + } + return applied; +} + +enum class DenominatorEdgeKind : uint8_t { + INNER_EQUIVALENCE, + LEFT_JOIN, + SEMI_ANTI_JOIN, + NON_EQUALITY_SELECTIVITY, + REDUNDANT_TRANSITIVE_EQUALITY, + OTHER +}; + +static bool EquivalenceGroupApplied(FilterInfoWithTotalDomains &edge, const unordered_set &applied_groups) { + auto equality_class_index = edge.GetPredicate().GetEqualityClassIndex(); + return equality_class_index.IsValid() && applied_groups.count(equality_class_index.GetIndex()) > 0; +} + +static DenominatorEdgeKind ClassifyDenominatorEdge(FilterInfoWithTotalDomains &edge, + const unordered_set &applied_groups) { + if (EquivalenceGroupApplied(edge, applied_groups)) { + return DenominatorEdgeKind::REDUNDANT_TRANSITIVE_EQUALITY; } + switch (edge.GetPredicate().GetPredicateClass()) { + case JoinPredicateClass::INNER_EQUALITY: + return DenominatorEdgeKind::INNER_EQUIVALENCE; + case JoinPredicateClass::INNER_NON_EQUALITY: + return DenominatorEdgeKind::NON_EQUALITY_SELECTIVITY; + case JoinPredicateClass::LEFT_JOIN: + return DenominatorEdgeKind::LEFT_JOIN; + case JoinPredicateClass::SEMI_ANTI_JOIN: + return DenominatorEdgeKind::SEMI_ANTI_JOIN; default: - // cross product - return new_denom; + return DenominatorEdgeKind::OTHER; } 
} -DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { - vector subgraphs; - - // Finding the denominator is tricky. You need to go through the tdoms in decreasing order - // Then loop through all filters in the equivalence set of the tdom to see if both the - // left and right relations are in the new set, if so you can use that filter. - // You must also make sure that the filters all relations in the given set, so we use subgraphs - // that should eventually merge into one connected graph that joins all the relations - // TODO: Implement a method to cache subgraphs so you don't have to build them up every - // time the cardinality of a new set is requested +static bool CanIncrementExistingJoin(DenominatorEdgeKind kind) { + return kind == DenominatorEdgeKind::INNER_EQUIVALENCE || kind == DenominatorEdgeKind::LEFT_JOIN; +} - // relations_to_tdoms has already been sorted by largest to smallest total domain - // then we look through the filters for the relations_to_tdoms, - // and we start to choose the filters that join relations in the set. +static void RegisterEqualityJoinPair(FilterInfoWithTotalDomains &edge, JoinRelationSetManager &set_manager, + reference_map_t &join_pair_stats) { + if (!edge.IsInnerEquality()) { + return; + } + auto join_pair = GetEqualityJoinPair(edge, set_manager); + if (!join_pair) { + return; + } + join_pair_stats[*join_pair].RegisterDistinctCount(edge.GetDistinctCount()); +} - // edges are guaranteed to be in order of largest tdom to smallest tdom. 
+struct DenominatorState { + vector subgraphs; + unordered_set applied_equivalence_groups; + reference_map_t join_pair_stats; + reference_set_t capped_join_pairs; unordered_set unused_edge_tdoms; - auto edges = GetEdges(relation_set_stats, set); - for (auto &edge : edges) { - if (subgraphs.size() == 1 && subgraphs.at(0).relations->ToString() == set.ToString()) { - // the first subgraph has connected all the desired relations, just skip the rest of the edges - if (edge.has_distinct_count_hll) { - unused_edge_tdoms.insert(edge.distinct_count_hll); - } - continue; - } +}; - auto subgraph_connections = SubgraphsConnectedByEdge(edge, subgraphs); - if (subgraph_connections.empty()) { - // create a subgraph out of left and right, then merge right into left and add left to subgraphs. - // this helps cover a case where there are no subgraphs yet, and the only join filter is a SEMI JOIN - auto left_subgraph = Subgraph2Denominator(); - auto right_subgraph = Subgraph2Denominator(); - left_subgraph.relations = edge.filter_info->left_set; - left_subgraph.numerator_relations = edge.filter_info->left_set; - right_subgraph.relations = edge.filter_info->right_set; - right_subgraph.numerator_relations = edge.filter_info->right_set; - left_subgraph.numerator_relations = &UpdateNumeratorRelations(left_subgraph, right_subgraph, edge); - left_subgraph.relations = edge.filter_info->set.get(); - left_subgraph.denom = CalculateUpdatedDenom(left_subgraph, right_subgraph, edge); - subgraphs.push_back(left_subgraph); - } else if (subgraph_connections.size() == 1) { - auto left_subgraph = &subgraphs.at(subgraph_connections.at(0)); - auto right_subgraph = Subgraph2Denominator(); - right_subgraph.relations = edge.filter_info->right_set; - right_subgraph.numerator_relations = edge.filter_info->right_set; - if (JoinRelationSet::IsSubset(*left_subgraph->relations, *right_subgraph.relations)) { - right_subgraph.relations = edge.filter_info->left_set; - right_subgraph.numerator_relations = 
edge.filter_info->left_set; - } +static bool DenominatorSubgraphComplete(DenominatorState &state, JoinRelationSet &set) { + return state.subgraphs.size() == 1 && state.subgraphs.at(0).relations && + RefersToSameObject(*state.subgraphs.at(0).relations, set); +} - if (JoinRelationSet::IsSubset(*left_subgraph->relations, *edge.filter_info->left_set) && - JoinRelationSet::IsSubset(*left_subgraph->relations, *edge.filter_info->right_set)) { - // here we have an edge that connects the same subgraph to the same subgraph. Just continue. no need to - // update the denom - continue; - } - left_subgraph->numerator_relations = &UpdateNumeratorRelations(*left_subgraph, right_subgraph, edge); - left_subgraph->relations = &set_manager.Union(*left_subgraph->relations, *right_subgraph.relations); - left_subgraph->denom = CalculateUpdatedDenom(*left_subgraph, right_subgraph, edge); - } else if (subgraph_connections.size() == 2) { - // The two subgraphs in the subgraph_connections can be merged by this edge. 
- D_ASSERT(subgraph_connections.at(0) < subgraph_connections.at(1)); - auto subgraph_to_merge_into = &subgraphs.at(subgraph_connections.at(0)); - auto subgraph_to_delete = &subgraphs.at(subgraph_connections.at(1)); - subgraph_to_merge_into->relations = - &set_manager.Union(*subgraph_to_merge_into->relations, *subgraph_to_delete->relations); - subgraph_to_merge_into->numerator_relations = - &UpdateNumeratorRelations(*subgraph_to_merge_into, *subgraph_to_delete, edge); - subgraph_to_merge_into->denom = CalculateUpdatedDenom(*subgraph_to_merge_into, *subgraph_to_delete, edge); - subgraph_to_delete->relations = nullptr; - auto remove_start = std::remove_if(subgraphs.begin(), subgraphs.end(), - [](Subgraph2Denominator &s) { return !s.relations; }); - subgraphs.erase(remove_start, subgraphs.end()); +void CardinalityEstimator::ProcessDenominatorEdge(FilterInfoWithTotalDomains &edge, JoinRelationSet &requested_set, + DenominatorState &state) { + auto edge_kind = ClassifyDenominatorEdge(edge, state.applied_equivalence_groups); + if (DenominatorSubgraphComplete(state, requested_set)) { + auto &complete_subgraph = state.subgraphs.at(0); + ApplyCompositeJoinPairCaps(complete_subgraph.denom, *complete_subgraph.relations, state.join_pair_stats, + state.capped_join_pairs); + if (edge_kind == DenominatorEdgeKind::REDUNDANT_TRANSITIVE_EQUALITY) { + return; } + if (CanIncrementExistingJoin(edge_kind) && + ApplyJoinIncrement(complete_subgraph.denom, edge, state.join_pair_stats, state.capped_join_pairs, + *complete_subgraph.relations)) { + return; + } + if (edge.HasReliableDistinctCount()) { + state.unused_edge_tdoms.insert(edge.GetReliableDistinctCount()); + } + return; + } + + auto equality_class_index = edge.GetPredicate().GetEqualityClassIndex(); + if (equality_class_index.IsValid()) { + state.applied_equivalence_groups.insert(equality_class_index.GetIndex()); + } + RegisterEqualityJoinPair(edge, set_manager, state.join_pair_stats); + + auto edge_left_set = GetEdgeEndpoint(edge, 
set_manager, true); + auto edge_right_set = GetEdgeEndpoint(edge, set_manager, false); + if (!edge_left_set || !edge_right_set) { + return; + } + + auto subgraph_connections = SubgraphsConnectedByEdge(edge, state.subgraphs, set_manager); + if (subgraph_connections.empty()) { + CreateDenominatorSubgraph(edge, *edge_left_set, *edge_right_set, state); + } else if (subgraph_connections.size() == 1) { + ExtendDenominatorSubgraph(subgraph_connections.at(0), edge, *edge_left_set, *edge_right_set, + CanIncrementExistingJoin(edge_kind), state); + } else if (subgraph_connections.size() == 2) { + MergeDenominatorSubgraphs(subgraph_connections, edge, state); + } +} + +void CardinalityEstimator::CreateDenominatorSubgraph(FilterInfoWithTotalDomains &edge, JoinRelationSet &edge_left_set, + JoinRelationSet &edge_right_set, DenominatorState &state) { + auto left_subgraph = Subgraph2Denominator(); + auto right_subgraph = Subgraph2Denominator(); + left_subgraph.relations = &edge_left_set; + left_subgraph.numerator_relations = &edge_left_set; + right_subgraph.relations = &edge_right_set; + right_subgraph.numerator_relations = &edge_right_set; + auto &numerator_relations = UpdateNumeratorRelations(left_subgraph, right_subgraph, edge); + auto denom = CalculateUpdatedDenom(left_subgraph, right_subgraph, edge); + left_subgraph.numerator_relations = &numerator_relations; + left_subgraph.relations = &set_manager.Union(edge_left_set, edge_right_set); + left_subgraph.denom = denom; + ApplyCompositeJoinPairCaps(left_subgraph.denom, *left_subgraph.relations, state.join_pair_stats, + state.capped_join_pairs); + state.subgraphs.push_back(left_subgraph); +} + +void CardinalityEstimator::ExtendDenominatorSubgraph(idx_t subgraph_index, FilterInfoWithTotalDomains &edge, + JoinRelationSet &edge_left_set, JoinRelationSet &edge_right_set, + bool can_increment_existing_join, DenominatorState &state) { + auto left_subgraph = &state.subgraphs.at(subgraph_index); + auto right_subgraph = 
Subgraph2Denominator(); + right_subgraph.relations = &edge_right_set; + right_subgraph.numerator_relations = &edge_right_set; + if (JoinRelationSet::IsSubset(*left_subgraph->relations, *right_subgraph.relations)) { + right_subgraph.relations = &edge_left_set; + right_subgraph.numerator_relations = &edge_left_set; } - // Slight penalty to cardinality for unused edges - auto denom_multiplier = 1.0 + static_cast(unused_edge_tdoms.size()); - - // It's possible cross-products were added and are not present in the filters in the relation_2_tdom - // structures. When that's the case, merge all remaining subgraphs as if they are connected by a cross product - if (subgraphs.size() > 1) { - auto final_subgraph = subgraphs.at(0); - for (auto merge_with = subgraphs.begin() + 1; merge_with != subgraphs.end(); merge_with++) { - D_ASSERT(final_subgraph.relations && merge_with->relations); - final_subgraph.relations = &set_manager.Union(*final_subgraph.relations, *merge_with->relations); - D_ASSERT(final_subgraph.numerator_relations && merge_with->numerator_relations); - final_subgraph.numerator_relations = - &set_manager.Union(*final_subgraph.numerator_relations, *merge_with->numerator_relations); - final_subgraph.denom *= merge_with->denom; + if (JoinRelationSet::IsSubset(*left_subgraph->relations, edge_left_set) && + JoinRelationSet::IsSubset(*left_subgraph->relations, edge_right_set)) { + if (can_increment_existing_join) { + ApplyJoinIncrement(left_subgraph->denom, edge, state.join_pair_stats, state.capped_join_pairs, + *left_subgraph->relations); } + ApplyCompositeJoinPairCaps(left_subgraph->denom, *left_subgraph->relations, state.join_pair_stats, + state.capped_join_pairs); + return; } - if (!subgraphs.empty()) { - // Some relations are connected by cross products and will not end up in a subgraph - // Check and make sure all relations were considered, if not, they are connected to the graph by cross products - auto &returning_subgraph = subgraphs.at(0); - if 
(returning_subgraph.relations->count != set.count) { - for (idx_t rel_index = 0; rel_index < set.count; rel_index++) { - auto relation_id = set.relations[rel_index]; - auto &rel = set_manager.GetJoinRelation(relation_id); - if (!JoinRelationSet::IsSubset(*returning_subgraph.relations, rel)) { - returning_subgraph.numerator_relations = - &set_manager.Union(*returning_subgraph.numerator_relations, rel); - returning_subgraph.relations = &set_manager.Union(*returning_subgraph.relations, rel); - } - } + + auto &numerator_relations = UpdateNumeratorRelations(*left_subgraph, right_subgraph, edge); + auto denom = CalculateUpdatedDenom(*left_subgraph, right_subgraph, edge); + left_subgraph->numerator_relations = &numerator_relations; + left_subgraph->relations = &set_manager.Union(*left_subgraph->relations, *right_subgraph.relations); + left_subgraph->denom = denom; + ApplyCompositeJoinPairCaps(left_subgraph->denom, *left_subgraph->relations, state.join_pair_stats, + state.capped_join_pairs); +} + +void CardinalityEstimator::MergeDenominatorSubgraphs(const vector &subgraph_connections, + FilterInfoWithTotalDomains &edge, DenominatorState &state) { + D_ASSERT(subgraph_connections.size() == 2); + D_ASSERT(subgraph_connections.at(0) < subgraph_connections.at(1)); + auto subgraph_to_merge_into = &state.subgraphs.at(subgraph_connections.at(0)); + auto subgraph_to_delete = &state.subgraphs.at(subgraph_connections.at(1)); + auto &numerator_relations = UpdateNumeratorRelations(*subgraph_to_merge_into, *subgraph_to_delete, edge); + auto denom = CalculateUpdatedDenom(*subgraph_to_merge_into, *subgraph_to_delete, edge); + subgraph_to_merge_into->relations = + &set_manager.Union(*subgraph_to_merge_into->relations, *subgraph_to_delete->relations); + subgraph_to_merge_into->numerator_relations = &numerator_relations; + subgraph_to_merge_into->denom = denom; + ApplyCompositeJoinPairCaps(subgraph_to_merge_into->denom, *subgraph_to_merge_into->relations, state.join_pair_stats, + 
state.capped_join_pairs); + subgraph_to_delete->relations = nullptr; + auto remove_start = std::remove_if(state.subgraphs.begin(), state.subgraphs.end(), + [](Subgraph2Denominator &s) { return !s.relations; }); + state.subgraphs.erase(remove_start, state.subgraphs.end()); +} + +void CardinalityEstimator::MergeDisconnectedDenominatorSubgraphs(DenominatorState &state) { + if (state.subgraphs.size() <= 1) { + return; + } + auto final_subgraph = state.subgraphs.at(0); + for (auto merge_with = state.subgraphs.begin() + 1; merge_with != state.subgraphs.end(); merge_with++) { + D_ASSERT(final_subgraph.relations && merge_with->relations); + final_subgraph.relations = &set_manager.Union(*final_subgraph.relations, *merge_with->relations); + D_ASSERT(final_subgraph.numerator_relations && merge_with->numerator_relations); + final_subgraph.numerator_relations = + &set_manager.Union(*final_subgraph.numerator_relations, *merge_with->numerator_relations); + final_subgraph.denom *= merge_with->denom; + } + state.subgraphs.clear(); + state.subgraphs.push_back(final_subgraph); +} + +void CardinalityEstimator::AddCrossProductRelations(JoinRelationSet &set, DenominatorState &state) { + if (state.subgraphs.empty()) { + return; + } + auto &returning_subgraph = state.subgraphs.at(0); + if (returning_subgraph.relations->count == set.count) { + return; + } + for (idx_t rel_index = 0; rel_index < set.count; rel_index++) { + auto relation_id = set.relations[rel_index]; + auto &rel = set_manager.GetJoinRelation(relation_id); + if (!JoinRelationSet::IsSubset(*returning_subgraph.relations, rel)) { + returning_subgraph.numerator_relations = &set_manager.Union(*returning_subgraph.numerator_relations, rel); + returning_subgraph.relations = &set_manager.Union(*returning_subgraph.relations, rel); } } +} + +DenomInfo CardinalityEstimator::CreateDenominatorResult(JoinRelationSet &set, DenominatorState &state) { + auto denom_multiplier = 1.0 + static_cast(state.unused_edge_tdoms.size()); + 
MergeDisconnectedDenominatorSubgraphs(state); + AddCrossProductRelations(set, state); + + if (state.subgraphs.empty() || state.subgraphs.at(0).denom == 0) { + return DenomInfo(set, 1); + } + return DenomInfo(*state.subgraphs.at(0).numerator_relations, state.subgraphs.at(0).denom * denom_multiplier); +} - // can happen if a table has cardinality 0, a tdom is set to 0, or if a cross product is used. - if (subgraphs.empty() || subgraphs.at(0).denom == 0) { - // denominator is 1 and numerators are a cross product of cardinalities. - return DenomInfo(set, 1, 1); +DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { + DenominatorState state; + auto edges = GetEdges(this->state->relation_set_stats, set); + for (auto &edge : edges) { + ProcessDenominatorEdge(edge, set, state); } - return DenomInfo(*subgraphs.at(0).numerator_relations, 1, subgraphs.at(0).denom * denom_multiplier); + return CreateDenominatorResult(set, state); } -// Cardinality is calculated using logic found in -// https://blobs.duckdb.org/papers/tom-ebergen-msc-thesis-join-order-optimization-with-almost-no-statistics.pdf TL;DR -// Cardinality is estimated based on cardinality of base tables and the distinct counts of joined columns. If you have -// two tables A and B joined using A.x = B.y we assume that each tuple in A will match ~ B/(distinct(y)) tuples in B. -// The cardinality estimation then becomes (|A|x|B|) / max(distinct(x), distinct(y)). -// If there are extra joins, you can add the cardinality of the table to the numerator, and the -// distinct count of the join condition to the denominator. -// One benefit of this cardinality estimation formula is that it is associative and commutative, which means regardless -// of the order of the joins/join tree, the cardinality estimate will always be the same. The drawback of this current -// implementation, however, is that it only considers equality join conditions. 
Some modification have been made for -// comparison types like <, <=, >, >=, !=, but only a "penalty" was introduced, and the calculated cardinality is not -// based on stats (see CalculateUpdatedDenom()). +// Cardinality is calculated using logic based on +// https://blobs.duckdb.org/papers/tom-ebergen-msc-thesis-join-order-optimization-with-almost-no-statistics.pdf +// +// The estimator starts with base relation cardinalities and divides by a denominator assembled from predicate domain +// groups. INNER equality predicates use transitive equality classes, composite same-pair equalities can apply an FK/PK +// cap, disconnected predicate subgraphs are merged by cross product, and LEFT/SEMI/ANTI joins adjust the numerator side +// according to their output semantics. Non-equality predicates still use a heuristic total-domain penalty. template <> double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) { double result; - auto it = relation_set_2_cardinality.find(new_set.ToString()); - if (it != relation_set_2_cardinality.end()) { + auto it = state->relation_set_2_cardinality.find(new_set); + if (it != state->relation_set_2_cardinality.end()) { result = it->second.cardinality_before_filters; } else { // can happen if a table has cardinality 0, or a tdom is set to 0 @@ -426,13 +906,13 @@ double CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set // include cardinalities of relations on the RHS of a semi/anti join. 
auto numerator = GetNumerator(denom.numerator_relations); result = numerator / denom.denominator; - relation_set_2_cardinality[new_set.ToString()] = CardinalityHelper(result); + state->relation_set_2_cardinality[new_set] = CardinalityHelper(result); } return ApplyOrFilterSelectivities(new_set, result); } double CardinalityEstimator::ApplyOrFilterSelectivities(JoinRelationSet &new_set, double cardinality) const { - for (auto &filter : or_filters) { + for (auto &filter : state->or_filters) { if (JoinRelationSet::IsSubset(new_set, filter->set.get())) { cardinality *= RelationStatisticsHelper::DEFAULT_SELECTIVITY; } @@ -451,16 +931,7 @@ idx_t CardinalityEstimator::EstimateCardinalityWithSet(JoinRelationSet &new_set) } bool SortTdoms(const RelationsSetToStats &a, const RelationsSetToStats &b) { - if (a.has_distinct_count_hll && b.has_distinct_count_hll) { - return a.distinct_count_hll > b.distinct_count_hll; - } - if (a.has_distinct_count_hll) { - return a.distinct_count_hll > b.distinct_count_no_hll; - } - if (b.has_distinct_count_hll) { - return a.distinct_count_no_hll > b.distinct_count_hll; - } - return a.distinct_count_no_hll > b.distinct_count_no_hll; + return a.domain_estimate.GetDistinctCount() > b.domain_estimate.GetDistinctCount(); } void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { @@ -469,12 +940,12 @@ void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptrToString()] = card_helper; + state->relation_set_2_cardinality[*set] = card_helper; UpdateTotalDomains(set, stats); // sort relations from greatest tdom to lowest tdom. - std::sort(relation_set_stats.begin(), relation_set_stats.end(), SortTdoms); + std::sort(state->relation_set_stats.begin(), state->relation_set_stats.end(), SortTdoms); } void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, RelationStats &stats) { @@ -488,23 +959,13 @@ void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, //! 
the cardinality // Update the relation_to_tdom set with the estimated distinct count (or tdom) calculated above auto key = ColumnBinding(TableIndex(relation_id.index), ProjectionIndex(i)); - for (auto &relation_to_tdom : relation_set_stats) { - column_binding_set_t i_set = relation_to_tdom.equivalent_relations; + auto distinct_count = stats.column_distinct_count.at(i); + for (auto &relation_to_tdom : state->relation_set_stats) { + const auto &i_set = relation_to_tdom.equivalent_relations; if (i_set.find(key) == i_set.end()) { continue; } - auto distinct_count = stats.column_distinct_count.at(i); - if (distinct_count.from_hll && relation_to_tdom.has_distinct_count_hll) { - relation_to_tdom.distinct_count_hll = - MaxValue(relation_to_tdom.distinct_count_hll, distinct_count.distinct_count); - } else if (distinct_count.from_hll && !relation_to_tdom.has_distinct_count_hll) { - relation_to_tdom.has_distinct_count_hll = true; - relation_to_tdom.distinct_count_hll = distinct_count.distinct_count; - } else { - relation_to_tdom.distinct_count_no_hll = - MinValue(distinct_count.distinct_count, relation_to_tdom.distinct_count_no_hll); - } - break; + relation_to_tdom.domain_estimate.Update(distinct_count); } } } @@ -513,12 +974,12 @@ void CardinalityEstimator::UpdateTotalDomains(optional_ptr set, void CardinalityEstimator::AddRelationNamesToRelationStats(vector &stats) { #ifdef DEBUG - for (auto &total_domain : relation_set_stats) { + for (auto &total_domain : state->relation_set_stats) { for (auto &binding : total_domain.equivalent_relations) { D_ASSERT(binding.table_index.index < stats.size()); string column_name; if (binding.column_index < stats[binding.table_index.index].column_names.size()) { - column_name = stats[binding.table_index.index].column_names[binding.column_index]; + column_name = stats[binding.table_index.index].column_names[binding.column_index].GetIdentifierName(); } else { column_name = "[unknown]"; } @@ -529,14 +990,12 @@ void 
CardinalityEstimator::AddRelationNamesToRelationStats(vector } void CardinalityEstimator::PrintRelationStats() { - for (auto &total_domain : relation_set_stats) { + for (auto &total_domain : state->relation_set_stats) { string domain = "Following columns have the same distinct count: "; for (auto &column_name : total_domain.column_names) { domain += column_name + ", "; } - bool have_hll = total_domain.has_distinct_count_hll; - domain += "\n TOTAL DOMAIN = " + - to_string(have_hll ? total_domain.distinct_count_hll : total_domain.distinct_count_no_hll); + domain += "\n TOTAL DOMAIN = " + to_string(total_domain.domain_estimate.GetDistinctCount()); Printer::Print(domain); } } diff --git a/src/duckdb/src/optimizer/join_order/cost_model.cpp b/src/duckdb/src/optimizer/join_order/cost_model.cpp index ef1e23972..821a4edfa 100644 --- a/src/duckdb/src/optimizer/join_order/cost_model.cpp +++ b/src/duckdb/src/optimizer/join_order/cost_model.cpp @@ -2,18 +2,48 @@ #include "duckdb/optimizer/join_order/join_order_optimizer.hpp" #include "duckdb/optimizer/join_order/cost_model.hpp" +#include "duckdb/optimizer/join_order/query_graph_manager.hpp" + namespace duckdb { -CostModel::CostModel(QueryGraphManager &query_graph_manager) - : query_graph_manager(query_graph_manager), cardinality_estimator() { +CostModel::CostModel(QueryGraphManager &query_graph_manager, CardinalityEstimator &cardinality_estimator) + : query_graph_manager(query_graph_manager), cardinality_estimator(cardinality_estimator) { +} + +CardinalityEstimator &CostModel::GetCardinalityEstimator() { + return cardinality_estimator; +} + +static double GetLeftJoinInputCost(CardinalityEstimator &cardinality_estimator, + const vector> &possible_connections) { + double cost = 0; + reference_set_t seen_right_sides; + for (auto &connection : possible_connections) { + for (auto predicate_ref : connection.get().predicates) { + auto &predicate = predicate_ref.get(); + if (predicate.GetJoinType() != JoinType::LEFT) { + continue; + } 
+ D_ASSERT(predicate.GetRightSetOptional()); + if (!seen_right_sides.insert(predicate.GetRightSet()).second) { + continue; + } + cost += cardinality_estimator.EstimateCardinalityWithSet(predicate.GetRightSet()); + } + } + return cost; } -// Currently cost of a join only factors in the cardinalities. -// If join types and join algorithms are to be considered, they should be added here. -double CostModel::ComputeCost(DPJoinNode &left, DPJoinNode &right) { - auto &combination = query_graph_manager.set_manager.Union(left.set, right.set); +// Currently cost of a join mostly factors in the cardinalities. +// LEFT joins need an explicit RHS input component because their output cardinality preserves the LHS, +// which otherwise makes early LEFT joins over large RHS inputs look almost free. +double CostModel::ComputeCost(DPJoinNode &left, DPJoinNode &right, JoinRelationSet &combination, + const vector> &possible_connections) { auto join_card = cardinality_estimator.EstimateCardinalityWithSet(combination); auto join_cost = join_card; + if (query_graph_manager.GetPredicateModel().HasLeftJoinPredicates()) { + join_cost += GetLeftJoinInputCost(cardinality_estimator, possible_connections); + } return join_cost + left.cost + right.cost; } diff --git a/src/duckdb/src/optimizer/join_order/filter_info.cpp b/src/duckdb/src/optimizer/join_order/filter_info.cpp new file mode 100644 index 000000000..2a324f546 --- /dev/null +++ b/src/duckdb/src/optimizer/join_order/filter_info.cpp @@ -0,0 +1,22 @@ +#include "duckdb/optimizer/join_order/filter_info.hpp" + +#include "duckdb/planner/expression.hpp" + +namespace duckdb { + +FilterInfo::FilterInfo(unique_ptr filter, JoinRelationSet &set, idx_t filter_index, JoinType join_type) + : filter(std::move(filter)), set(set), filter_index(filter_index), join_type(join_type) { +} + +FilterInfo::~FilterInfo() { +} + +void FilterInfo::SetLeftSet(optional_ptr left_set_new) { + left_set = left_set_new; +} + +void FilterInfo::SetRightSet(optional_ptr 
right_set_new) { + right_set = right_set_new; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/join_node.cpp b/src/duckdb/src/optimizer/join_order/join_node.cpp index 031f56ce2..22042b06a 100644 --- a/src/duckdb/src/optimizer/join_order/join_node.cpp +++ b/src/duckdb/src/optimizer/join_order/join_node.cpp @@ -6,7 +6,8 @@ namespace duckdb { -DPJoinNode::DPJoinNode(JoinRelationSet &set) : set(set), info(nullptr), is_leaf(true), left_set(set), right_set(set) { +DPJoinNode::DPJoinNode(JoinRelationSet &set) + : set(set), info(nullptr), is_leaf(true), left_set(set), right_set(set), cost(0) { } DPJoinNode::DPJoinNode(JoinRelationSet &set, optional_ptr info, JoinRelationSet &left, diff --git a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp index d570c96d8..c66114984 100644 --- a/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_order/join_order_optimizer.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/common/pair.hpp" +#include "duckdb/optimizer/join_order/cardinality_estimator.hpp" #include "duckdb/optimizer/join_order/cost_model.hpp" #include "duckdb/optimizer/join_order/plan_enumerator.hpp" #include "duckdb/planner/expression/list.hpp" @@ -46,7 +47,9 @@ unique_ptr JoinOrderOptimizer::Optimize(unique_ptrGetExpressionType(); +} + +JoinPredicateClass JoinOrderUtil::ClassifyJoinPredicate(const FilterInfo &filter) { + switch (filter.join_type) { + case JoinType::INNER: { + const auto comparison_type = GetJoinPredicateComparisonType(filter); + if (comparison_type == ExpressionType::COMPARE_EQUAL || + comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + return JoinPredicateClass::INNER_EQUALITY; + } + return comparison_type == ExpressionType::INVALID ? 
JoinPredicateClass::OTHER + : JoinPredicateClass::INNER_NON_EQUALITY; + } + case JoinType::LEFT: + return JoinPredicateClass::LEFT_JOIN; + case JoinType::SEMI: + case JoinType::ANTI: + return JoinPredicateClass::SEMI_ANTI_JOIN; + default: + return JoinPredicateClass::OTHER; + } +} + +bool JoinOrderUtil::HasAnyValidJoinBinding(const FilterInfo &filter) { + return filter.left_binding.table_index.IsValid() || filter.right_binding.table_index.IsValid(); +} + +bool JoinOrderUtil::HasValidJoinEndpoints(const FilterInfo &filter) { + return filter.left_set && filter.right_set && !filter.left_set->Empty() && !filter.right_set->Empty() && + filter.left_set != filter.right_set; +} + +bool JoinOrderUtil::HasDisjointJoinEndpoints(const FilterInfo &filter) { + if (!HasValidJoinEndpoints(filter)) { + return false; + } + auto &left = *filter.left_set; + auto &right = *filter.right_set; + for (idx_t left_idx = 0; left_idx < left.count; left_idx++) { + for (idx_t right_idx = 0; right_idx < right.count; right_idx++) { + if (left.relations[left_idx] == right.relations[right_idx]) { + return false; + } + } + } + return true; +} + +RelationIndex JoinOrderUtil::GetBindingRelation(const ColumnBinding &binding) { + D_ASSERT(binding.table_index.IsValid()); + return RelationIndex(binding.table_index.index); +} + +JoinPredicate::JoinPredicate(idx_t index, FilterInfo &filter, JoinPredicateClass predicate_class, + ColumnBinding left_equality_binding, ColumnBinding right_equality_binding) + : index(index), filter(filter), predicate_class(predicate_class), + comparison_type(JoinOrderUtil::GetJoinPredicateComparisonType(filter)), + left_equality_binding(left_equality_binding), right_equality_binding(right_equality_binding) { +} + +idx_t JoinPredicate::GetIndex() const { + return index; +} + +FilterInfo &JoinPredicate::GetFilter() const { + return filter.get(); +} + +JoinPredicateClass JoinPredicate::GetPredicateClass() const { + return predicate_class; +} + +JoinType JoinPredicate::GetJoinType() 
const { + return filter.get().join_type; +} + +ExpressionType JoinPredicate::GetComparisonType() const { + return comparison_type; +} + +JoinRelationSet &JoinPredicate::GetSet() const { + return filter.get().set; +} + +JoinRelationSet &JoinPredicate::GetLeftSet() const { + D_ASSERT(filter.get().left_set); + return *filter.get().left_set; +} + +JoinRelationSet &JoinPredicate::GetRightSet() const { + D_ASSERT(filter.get().right_set); + return *filter.get().right_set; +} + +optional_ptr JoinPredicate::GetLeftSetOptional() const { + return filter.get().left_set; +} + +optional_ptr JoinPredicate::GetRightSetOptional() const { + return filter.get().right_set; +} + +const ColumnBinding &JoinPredicate::GetStatsBinding(bool left) const { + return left ? filter.get().left_binding : filter.get().right_binding; +} + +const ColumnBinding &JoinPredicate::GetEqualityBinding(bool left) const { + return left ? left_equality_binding : right_equality_binding; +} + +bool JoinPredicate::HasValidJoinEndpoints() const { + return JoinOrderUtil::HasValidJoinEndpoints(filter.get()); +} + +bool JoinPredicate::HasDisjointJoinEndpoints() const { + return JoinOrderUtil::HasDisjointJoinEndpoints(filter.get()); +} + +bool JoinPredicate::HasAnyValidStatsBinding() const { + return JoinOrderUtil::HasAnyValidJoinBinding(filter.get()); +} + +bool JoinPredicate::HasValidEqualityBindings() const { + return left_equality_binding.table_index.IsValid() && right_equality_binding.table_index.IsValid(); +} + +bool JoinPredicate::CanBuildEqualityClosure() const { + return predicate_class == JoinPredicateClass::INNER_EQUALITY && HasValidEqualityBindings() && + HasValidJoinEndpoints(); +} + +bool JoinPredicate::CanBuildSelectivityDomain() const { + if (!HasValidJoinEndpoints() || !HasAnyValidStatsBinding()) { + return false; + } + return predicate_class != JoinPredicateClass::INNER_EQUALITY || !CanBuildEqualityClosure(); +} + +bool JoinPredicate::IsEquivalencePredicate() const { + return 
CanBuildEqualityClosure(); +} + +void JoinPredicate::SetEqualityClassIndex(optional_idx equality_class_index) { + this->equality_class_index = equality_class_index; + filter.get().edge_equivalence_index = equality_class_index; +} + +optional_idx JoinPredicate::GetEqualityClassIndex() const { + return equality_class_index; +} + +JoinEqualityPredicateEdge::JoinEqualityPredicateEdge(JoinPredicate &predicate, RelationIndex left_relation, + RelationIndex right_relation, ColumnBinding left_binding, + ColumnBinding right_binding) + : predicate(predicate), left_relation(left_relation), right_relation(right_relation), left_binding(left_binding), + right_binding(right_binding) { +} + +bool RelationPairEqualitySummary::HasDirectCompositeEquality() const { + return direct_equality_class_indices.size() >= 2 && first_relation_bindings.size() >= 2 && + second_relation_bindings.size() >= 2; +} + +void JoinPredicateModel::Clear() { + predicates.clear(); + all_predicates.clear(); + equality_join_predicates.clear(); + selectivity_predicates.clear(); + graph_predicates.clear(); + has_left_join_predicates = false; + equality_classes.clear(); + equality_pairs.clear(); +} + +JoinPredicate &JoinPredicateModel::RegisterPredicate(FilterInfo &filter, JoinPredicateClass predicate_class, + ColumnBinding left_equality_binding, + ColumnBinding right_equality_binding) { + auto predicate = make_uniq(predicates.size(), filter, predicate_class, left_equality_binding, + right_equality_binding); + auto &result = *predicate; + predicates.push_back(std::move(predicate)); + all_predicates.push_back(result); + if (predicate_class == JoinPredicateClass::LEFT_JOIN) { + has_left_join_predicates = true; + } + if (result.HasDisjointJoinEndpoints()) { + graph_predicates.push_back(result); + } + if (result.CanBuildSelectivityDomain()) { + selectivity_predicates.push_back(result); + } + return result; +} + +void JoinPredicateModel::AddEqualityClass(JoinEqualityClass equality_class) { + 
D_ASSERT(equality_class.index == equality_classes.size()); + for (auto &edge : equality_class.edges) { + equality_join_predicates.push_back(edge.predicate); + } + equality_classes.push_back(std::move(equality_class)); +} + +bool JoinPredicateModel::ContainsClassIndex(const vector &class_indices, idx_t equality_class_index) { + return std::find(class_indices.begin(), class_indices.end(), equality_class_index) != class_indices.end(); +} + +void JoinPredicateModel::AddDirectEqualityPairClass(JoinRelationSet &pair, idx_t equality_class_index, + ColumnBinding first_binding, ColumnBinding second_binding) { + auto &summary = equality_pairs[pair]; + if (!ContainsClassIndex(summary.direct_equality_class_indices, equality_class_index)) { + summary.direct_equality_class_indices.push_back(equality_class_index); + } + summary.first_relation_bindings.insert(first_binding); + summary.second_relation_bindings.insert(second_binding); +} + +const vector> &JoinPredicateModel::GetPredicates() const { + return all_predicates; +} + +const vector> &JoinPredicateModel::GetEqualityJoinPredicates() const { + return equality_join_predicates; +} + +const vector> &JoinPredicateModel::GetSelectivityPredicates() const { + return selectivity_predicates; +} + +const vector> &JoinPredicateModel::GetGraphPredicates() const { + return graph_predicates; +} + +const vector &JoinPredicateModel::GetEqualityClasses() const { + return equality_classes; +} + +bool JoinPredicateModel::HasLeftJoinPredicates() const { + return has_left_join_predicates; +} + +bool JoinPredicateModel::HasDirectCompositeEquality(JoinRelationSet &pair) const { + auto entry = equality_pairs.find(pair); + if (entry == equality_pairs.end()) { + return false; + } + return entry->second.HasDirectCompositeEquality(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/join_order/join_relation_set.cpp b/src/duckdb/src/optimizer/join_order/join_relation_set.cpp index 5cfe32d41..b5d4a8b21 100644 --- 
a/src/duckdb/src/optimizer/join_order/join_relation_set.cpp +++ b/src/duckdb/src/optimizer/join_order/join_relation_set.cpp @@ -1,4 +1,4 @@ -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/to_string.hpp" @@ -9,6 +9,10 @@ namespace duckdb { using JoinRelationTreeNode = JoinRelationSetManager::JoinRelationTreeNode; +JoinRelationSet::JoinRelationSet(unsafe_unique_array relations, idx_t count) + : relations(std::move(relations)), count(count) { +} + // LCOV_EXCL_START string JoinRelationSet::ToString() const { string result = "["; @@ -19,6 +23,10 @@ string JoinRelationSet::ToString() const { } // LCOV_EXCL_STOP +bool JoinRelationSet::Empty() const { + return count == 0; +} + //! Returns true if sub is a subset of super bool JoinRelationSet::IsSubset(JoinRelationSet &super, JoinRelationSet &sub) { D_ASSERT(sub.count > 0); @@ -67,6 +75,14 @@ JoinRelationSet &JoinRelationSetManager::GetJoinRelation(RelationIndex index) { return GetJoinRelation(std::move(relations), count); } +JoinRelationSet &JoinRelationSetManager::GetEmptyJoinRelationSet() { + if (!empty_relation_set) { + const unordered_set empty_bindings = {}; + empty_relation_set = GetJoinRelation(empty_bindings); + } + return *empty_relation_set.get(); +} + JoinRelationSet &JoinRelationSetManager::GetJoinRelation(const unordered_set &bindings) { // create a sorted vector of the relations unsafe_unique_array relations = diff --git a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp index 608758f9f..012d6e0da 100644 --- a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp +++ b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp @@ -9,6 +9,11 @@ namespace duckdb { +PlanEnumerator::PlanEnumerator(QueryGraphManager &query_graph_manager, CostModel &cost_model, + const 
QueryGraphEdges &query_graph) + : query_graph(query_graph), query_graph_manager(query_graph_manager), cost_model(cost_model) { +} + static vector> AddSuperSets(const vector> ¤t, const vector &all_neighbors) { vector> ret; @@ -78,6 +83,7 @@ static vector> GetAllNeighborSets(vector> &PlanEnumerator:: return plans; } +const vector> &PlanEnumerator::GetConnections(JoinRelationSet &left, JoinRelationSet &right) { + auto &left_cache = connection_cache[left]; + auto entry = left_cache.find(right); + if (entry != left_cache.end()) { + return entry->second; + } + auto connections = query_graph.GetConnections(left, right); + auto inserted = left_cache.insert(make_pair(reference(right), std::move(connections))); + return inserted.first->second; +} + +static idx_t GetNeighborSetCacheKey(const vector &neighbors) { + idx_t key = 0; + for (auto &neighbor : neighbors) { + D_ASSERT(neighbor.index < sizeof(idx_t) * 8); + key |= idx_t(1) << neighbor.index; + } + return key; +} + +const vector> &PlanEnumerator::GetAllNeighborRelationSets(vector neighbors) { + auto key = GetNeighborSetCacheKey(neighbors); + auto entry = neighbor_set_cache.find(key); + if (entry != neighbor_set_cache.end()) { + return entry->second; + } + + auto all_subset = GetAllNeighborSets(std::move(neighbors)); + vector> relation_sets; + relation_sets.reserve(all_subset.size()); + for (const auto &rel_set : all_subset) { + relation_sets.push_back(query_graph_manager.set_manager.GetJoinRelation(rel_set)); + } + auto inserted = neighbor_set_cache.insert(make_pair(key, std::move(relation_sets))); + return inserted.first->second; +} + //! Create a new JoinTree node by joining together two previous JoinTree nodes unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, const vector> &possible_connections, DPJoinNode &left, DPJoinNode &right) { // FIXME: should consider different join algorithms, should we pick a join algorithm here as well? 
(probably) optional_ptr best_connection = possible_connections.back().get(); - // cross products are technically still connections, but the filter expression is a null_ptr + // cross products are technically still connections, but they have no predicates bool found_non_cross_product_connection = false; for (auto &connection : possible_connections) { - for (auto &filter : connection.get().filters) { - if (filter->join_type != JoinType::INVALID) { + for (auto predicate_ref : connection.get().predicates) { + if (predicate_ref.get().GetJoinType() != JoinType::INVALID) { best_connection = connection.get(); found_non_cross_product_connection = true; break; @@ -119,12 +162,13 @@ unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, } } auto join_type = JoinType::INVALID; - for (auto &filter_binding : best_connection->filters) { - if (!filter_binding->left_set || !filter_binding->right_set) { + for (auto predicate_ref : best_connection->predicates) { + auto &predicate = predicate_ref.get(); + if (!predicate.GetLeftSetOptional() || !predicate.GetRightSetOptional()) { continue; } - join_type = filter_binding->join_type; + join_type = predicate.GetJoinType(); // prefer joining on semi and anti joins as they have a higher chance of being more // selective if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { @@ -132,9 +176,9 @@ unique_ptr PlanEnumerator::CreateJoinTree(JoinRelationSet &set, } } // need the filter info from the Neighborhood info. 
- auto cost = cost_model.ComputeCost(left, right); + auto cost = cost_model.ComputeCost(left, right, set, possible_connections); auto result = make_uniq(set, best_connection, left.set, right.set, cost); - result->cardinality = cost_model.cardinality_estimator.EstimateCardinalityWithSet(set); + result->cardinality = cost_model.GetCardinalityEstimator().EstimateCardinalityWithSet(set); return result; } @@ -161,6 +205,21 @@ DPJoinNode &PlanEnumerator::EmitPair(JoinRelationSet &left, JoinRelationSet &rig plans[new_set] = std::move(new_plan); return *plans[new_set]; } + // Tiebreaker for equal-cost plans: needed for LEFT JOINs, + // because all orderings preserve the LHS cardinality and always tie. + // This tiebreaker causes less joins to be flipped from LEFT to RIGHT + // later by the BuildProbeSideOptimizer, and we strongly prefer LEFT + if (new_cost == old_cost) { + auto new_right_cardinality = right_plan->second->cardinality; + auto existing_right = plans.find(entry->second->right_set); + if (existing_right != plans.end()) { + auto old_right_cardinality = existing_right->second->cardinality; + if (new_right_cardinality > old_right_cardinality) { + plans[new_set] = std::move(new_plan); + return *plans[new_set]; + } + } + } // Create join node from the plan currently in the DP table. return *entry->second; } @@ -217,7 +276,7 @@ bool PlanEnumerator::EmitCSG(JoinRelationSet &node) { // since the GetNeighbors only returns the smallest element in a list, the entry might not be connected to // (only!) 
this neighbor, hence we have to do a connectedness check before we can emit it auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); - auto connections = query_graph.GetConnections(node, neighbor_relation); + auto &connections = GetConnections(node, neighbor_relation); if (!connections.empty()) { if (!TryEmitPair(node, neighbor_relation, connections)) { return false; @@ -241,18 +300,17 @@ bool PlanEnumerator::EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSe return true; } - auto all_subset = GetAllNeighborSets(neighbors); + auto &all_subset = GetAllNeighborRelationSets(neighbors); vector> union_sets; union_sets.reserve(all_subset.size()); - for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + for (auto &neighbor : all_subset) { // emit the combinations of this node and its neighbors - auto &combined_set = query_graph_manager.set_manager.Union(right, neighbor); + auto &combined_set = query_graph_manager.set_manager.Union(right, neighbor.get()); // If combined_set.count == right.count, This means we found a neighbor that has been present before // This means we didn't set exclusion_set correctly. 
D_ASSERT(combined_set.count > right.count); if (plans.find(combined_set) != plans.end()) { - auto connections = query_graph.GetConnections(left, combined_set); + auto &connections = GetConnections(left, combined_set); if (!connections.empty()) { if (!TryEmitPair(left, combined_set, connections)) { return false; @@ -284,13 +342,12 @@ bool PlanEnumerator::EnumerateCSGRecursive(JoinRelationSet &node, unordered_set< return true; } - auto all_subset = GetAllNeighborSets(neighbors); + auto &all_subset = GetAllNeighborRelationSets(neighbors); vector> union_sets; union_sets.reserve(all_subset.size()); - for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + for (auto &neighbor : all_subset) { // emit the combinations of this node and its neighbors - auto &new_set = query_graph_manager.set_manager.Union(node, neighbor); + auto &new_set = query_graph_manager.set_manager.Union(node, neighbor.get()); D_ASSERT(new_set.count > node.count); if (plans.find(new_set) != plans.end()) { if (!EmitCSG(new_set)) { @@ -353,32 +410,34 @@ void PlanEnumerator::SolveJoinOrderApproximately() { // long is needed to prevent clang-tidy complaints. (idx_t) cannot be added to an iterator position because it // is unsigned. 
idx_t best_left = 0, best_right = 0; - optional_ptr best_connection; + bool found_connection = false; + double best_cost = NumericLimits::Maximum(); for (idx_t i = 0; i < join_relations.size(); i++) { auto left = join_relations[i]; for (idx_t j = i + 1; j < join_relations.size(); j++) { auto right = join_relations[j]; // check if we can connect these two relations - auto connection = query_graph.GetConnections(left, right); + auto &connection = GetConnections(left, right); if (!connection.empty()) { // we can check the cost of this connection - auto node = EmitPair(left, right, connection); + auto &node = EmitPair(left, right, connection); // update the DP tree in case a plan created by the DP algorithm uses the node // that was potentially just updated by EmitPair. You will get a use-after-free // error if future plans rely on the old node that was just replaced. // if node in FullPath, then updateDP tree. - if (!best_connection || node.cost < best_connection->cost) { + if (!found_connection || node.cost < best_cost) { // best pair found so far - best_connection = &EmitPair(left, right, connection); + found_connection = true; + best_cost = node.cost; best_left = i; best_right = j; } } } } - if (!best_connection) { + if (!found_connection) { // could not find a connection, but we were not done with finding a completed plan // we have to add a cross product; we add it between the two smallest relations optional_ptr smallest_plans[2]; @@ -416,11 +475,12 @@ void PlanEnumerator::SolveJoinOrderApproximately() { auto &right = smallest_plans[1]->set; // create a cross product edge (i.e. 
edge with empty filter) between these two sets in the query graph query_graph_manager.CreateQueryGraphCrossProduct(left, right); + connection_cache.clear(); // now emit the pair and continue with the algorithm - auto connections = query_graph.GetConnections(left, right); + auto &connections = GetConnections(left, right); D_ASSERT(!connections.empty()); - best_connection = &EmitPair(left, right, connections); + EmitPair(left, right, connections); best_left = smallest_index[0]; best_right = smallest_index[1]; @@ -450,8 +510,8 @@ void PlanEnumerator::InitLeafPlans() { // first initialize equivalent relations based on the filters auto relation_stats = query_graph_manager.relation_manager.GetRelationStats(); - cost_model.cardinality_estimator.InitEquivalentRelations(query_graph_manager.GetFilterBindings()); - cost_model.cardinality_estimator.AddRelationNamesToRelationStats(relation_stats); + cost_model.GetCardinalityEstimator().InitEquivalentRelations(); + cost_model.GetCardinalityEstimator().AddRelationNamesToRelationStats(relation_stats); // then update the total domains based on the cardinalities of each relation. 
for (idx_t i = 0; i < relation_stats.size(); i++) { @@ -462,7 +522,7 @@ void PlanEnumerator::InitLeafPlans() { join_node->cardinality = stats.cardinality; D_ASSERT(join_node->set.count == 1); plans[relation_set] = std::move(join_node); - cost_model.cardinality_estimator.InitCardinalityEstimatorProps(&relation_set, stats); + cost_model.GetCardinalityEstimator().InitCardinalityEstimatorProps(&relation_set, stats); } } @@ -489,7 +549,7 @@ void PlanEnumerator::SolveJoinOrder() { auto final_plan = plans.find(total_relation); if (final_plan == plans.end()) { // could not find the final plan - // this should only happen in case the sets are actually disjunct + // this should only happen in case the sets are actually disjoint // in this case we need to generate cross product to connect the disjoint sets if (force_no_cross_product) { throw InvalidInputException( @@ -497,7 +557,7 @@ void PlanEnumerator::SolveJoinOrder() { } GenerateCrossProducts(); //! solve the join order again, returning the final plan - return SolveJoinOrder(); + SolveJoinOrder(); } } diff --git a/src/duckdb/src/optimizer/join_order/query_graph.cpp b/src/duckdb/src/optimizer/join_order/query_graph.cpp index 777adc7b9..5f43fb190 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph.cpp @@ -8,6 +8,9 @@ namespace duckdb { using QueryEdge = QueryGraphEdges::QueryEdge; +NeighborInfo::NeighborInfo(optional_ptr neighbor) : neighbor(neighbor) { +} + // LCOV_EXCL_START static string QueryEdgeToString(const QueryEdge *info, vector prefix) { string result = ""; @@ -53,26 +56,26 @@ optional_ptr QueryGraphEdges::GetQueryEdge(JoinRelationSet &left) { return info; } -void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr filter_info) { +void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr predicate) { D_ASSERT(left.count > 0 && right.count > 0); // find the EdgeInfo corresponding 
to the left set auto info = GetQueryEdge(left); // now insert the edge to the right relation, if it does not exist for (idx_t i = 0; i < info->neighbors.size(); i++) { if (info->neighbors[i]->neighbor == &right) { - if (filter_info) { - // neighbor already exists just add the filter, if we have any - info->neighbors[i]->filters.push_back(filter_info); + if (predicate) { + // neighbor already exists, just add the predicate if we have one + info->neighbors[i]->predicates.push_back(*predicate); } return; } } // neighbor does not exist, create it auto n = make_uniq(&right); - // if the edge represents a cross product, filter_info is null. The easiest way then to determine - // if an edge is for a cross product is if the filters are empty - if (info && filter_info) { - n->filters.push_back(filter_info); + // if the edge represents a cross product, predicate is null. The easiest way then to determine + // if an edge is for a cross product is if the predicates are empty + if (info && predicate) { + n->predicates.push_back(*predicate); } info->neighbors.push_back(std::move(n)); } diff --git a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp index b4f9a4975..13b70e8b4 100644 --- a/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/query_graph_manager.cpp @@ -2,7 +2,10 @@ #include "duckdb/common/assert.hpp" #include "duckdb/common/enums/join_type.hpp" -#include "duckdb/optimizer/join_order/join_relation.hpp" +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" @@ 
-12,12 +15,115 @@ namespace duckdb { -//! Returns true if A and B are disjoint, false otherwise -template -static bool Disjoint(const unordered_set &a, const unordered_set &b) { - return std::all_of(a.begin(), a.end(), [&b](typename std::unordered_set::const_reference entry) { - return b.find(entry) == b.end(); - }); +GenerateJoinRelation::GenerateJoinRelation(optional_ptr set, unique_ptr op_p) + : set(set), op(std::move(op_p)) { +} + +QueryGraphManager::QueryGraphManager(ClientContext &context) : context(context), relation_manager(context) { +} + +void QueryGraphManager::BuildPredicateModel() { + predicate_model.Clear(); + + column_binding_map_t binding_to_component; + vector parents; + + auto get_component = [&](const ColumnBinding &binding) -> idx_t { + auto entry = binding_to_component.find(binding); + if (entry != binding_to_component.end()) { + return entry->second; + } + auto component = parents.size(); + parents.push_back(component); + binding_to_component[binding] = component; + return component; + }; + + auto find_root = [&](idx_t component) -> idx_t { + while (parents[component] != component) { + parents[component] = parents[parents[component]]; + component = parents[component]; + } + return component; + }; + + // First union all normalized equality predicates. Do not assign edge ids in this pass, + // because a later predicate can merge two previously separate components. 
+ for (auto &filter : filters_and_bindings) { + filter->edge_equivalence_index = optional_idx(); + auto predicate_class = JoinOrderUtil::ClassifyJoinPredicate(*filter); + ColumnBinding left_equality_binding; + ColumnBinding right_equality_binding; + if (BoundComparisonExpression::IsComparison(*filter->filter)) { + auto &comparison = filter->filter->Cast(); + GetEquivalenceBinding(BoundComparisonExpression::Left(comparison), left_equality_binding); + GetEquivalenceBinding(BoundComparisonExpression::Right(comparison), right_equality_binding); + } + auto &predicate = + predicate_model.RegisterPredicate(*filter, predicate_class, left_equality_binding, right_equality_binding); + if (!predicate.CanBuildEqualityClosure()) { + continue; + } + + auto left_root = find_root(get_component(predicate.GetEqualityBinding(true))); + auto right_root = find_root(get_component(predicate.GetEqualityBinding(false))); + if (left_root != right_root) { + parents[MaxValue(left_root, right_root)] = MinValue(left_root, right_root); + } + } + + // Then assign stable final ids to every equality edge using the final roots. 
+ unordered_map root_to_equivalence_id; + vector>> equality_class_predicates; + for (auto predicate_ref : predicate_model.GetPredicates()) { + auto &predicate = predicate_ref.get(); + if (!predicate.CanBuildEqualityClosure()) { + continue; + } + auto root = find_root(binding_to_component[predicate.GetEqualityBinding(true)]); + auto entry = root_to_equivalence_id.find(root); + if (entry == root_to_equivalence_id.end()) { + entry = root_to_equivalence_id.insert(make_pair(root, root_to_equivalence_id.size())).first; + equality_class_predicates.emplace_back(); + } + predicate.SetEqualityClassIndex(optional_idx(entry->second)); + equality_class_predicates[entry->second].push_back(predicate); + } + + for (idx_t equality_class_index = 0; equality_class_index < equality_class_predicates.size(); + equality_class_index++) { + JoinEqualityClass equality_class; + equality_class.index = equality_class_index; + for (auto predicate_ref : equality_class_predicates[equality_class_index]) { + auto &predicate = predicate_ref.get(); + auto &left_binding = predicate.GetEqualityBinding(true); + auto &right_binding = predicate.GetEqualityBinding(false); + auto left_relation = JoinOrderUtil::GetBindingRelation(left_binding); + auto right_relation = JoinOrderUtil::GetBindingRelation(right_binding); + equality_class.columns.insert(left_binding); + equality_class.columns.insert(right_binding); + equality_class.relations.insert(left_relation); + equality_class.relations.insert(right_relation); + equality_class.edges.emplace_back(predicate, left_relation, right_relation, left_binding, right_binding); + } + predicate_model.AddEqualityClass(std::move(equality_class)); + } + + for (auto &equality_class : predicate_model.GetEqualityClasses()) { + for (auto &edge : equality_class.edges) { + auto &left = set_manager.GetJoinRelation(edge.left_relation); + auto &right = set_manager.GetJoinRelation(edge.right_relation); + auto &pair = set_manager.Union(left, right); + auto first_binding = 
edge.left_binding; + auto second_binding = edge.right_binding; + if (pair.relations[0] != edge.left_relation) { + D_ASSERT(pair.relations[0] == edge.right_relation); + first_binding = edge.right_binding; + second_binding = edge.left_binding; + } + predicate_model.AddDirectEqualityPairClass(pair, equality_class.index, first_binding, second_binding); + } + } } bool QueryGraphManager::Build(JoinOrderOptimizer &optimizer, LogicalOperator &op) { @@ -31,34 +137,56 @@ bool QueryGraphManager::Build(JoinOrderOptimizer &optimizer, LogicalOperator &op } // extract the edges of the hypergraph, creating a list of filters and their associated bindings. filters_and_bindings = relation_manager.ExtractEdges(op, filter_operators, set_manager); - // Create the query_graph hyper edges + // Populate left/right endpoints and stats bindings before building the predicate model. + BindFilterEndpoints(); + // Build the predicate model after BindFilterEndpoints so that left_binding/right_binding are populated. + BuildPredicateModel(); + // Create query_graph hyper edges from the normalized predicate model. 
CreateHyperGraphEdges(); return true; } +const JoinPredicateModel &QueryGraphManager::GetPredicateModel() const { + return predicate_model; +} + void QueryGraphManager::GetColumnBinding(const Expression &root_expr, ColumnBinding &binding) { ExpressionIterator::VisitExpression( root_expr, [&](const BoundColumnRefExpression &colref) { - D_ASSERT(colref.depth == 0); - D_ASSERT(colref.binding.table_index.IsValid()); + D_ASSERT(colref.Depth() == 0); + D_ASSERT(colref.Binding().table_index.IsValid()); // map the base table index to the relation index used by the JoinOrderOptimizer - D_ASSERT(relation_manager.relation_mapping.find(colref.binding.table_index) != + D_ASSERT(relation_manager.relation_mapping.find(colref.Binding().table_index) != relation_manager.relation_mapping.end()); - binding = ColumnBinding(TableIndex(relation_manager.relation_mapping[colref.binding.table_index].index), - colref.binding.column_index); + binding = ColumnBinding(TableIndex(relation_manager.relation_mapping[colref.Binding().table_index].index), + colref.Binding().column_index); }); } -const vector> &QueryGraphManager::GetFilterBindings() const { - return filters_and_bindings; -} - -void FilterInfo::SetLeftSet(optional_ptr left_set_new) { - left_set = left_set_new; -} - -void FilterInfo::SetRightSet(optional_ptr right_set_new) { - right_set = right_set_new; +void QueryGraphManager::GetEquivalenceBinding(const Expression &expression, ColumnBinding &binding) { + switch (expression.GetExpressionClass()) { + case ExpressionClass::BOUND_COLUMN_REF: { + auto &colref = expression.Cast(); + D_ASSERT(colref.Depth() == 0); + if (!colref.Binding().table_index.IsValid()) { + return; + } + auto entry = relation_manager.relation_mapping.find(colref.Binding().table_index); + D_ASSERT(entry != relation_manager.relation_mapping.end()); + binding = ColumnBinding(TableIndex(entry->second.index), colref.Binding().column_index); + return; + } + case ExpressionClass::BOUND_CAST: { + auto &cast = 
expression.Cast(); + if (cast.IsTryCast() || !BoundCastExpression::CastIsInvertible(cast.source_type(), cast.GetReturnType())) { + return; + } + GetEquivalenceBinding(cast.Child(), binding); + return; + } + default: + return; + } } static unique_ptr PushFilter(unique_ptr node, unique_ptr expr) { @@ -77,8 +205,11 @@ static unique_ptr PushFilter(unique_ptr node, return node; } -void QueryGraphManager::CreateHyperGraphEdges() { - // create potential edges from the comparisons +static bool RelationSetsEqual(JoinRelationSet &left, JoinRelationSet &right) { + return JoinRelationSet::IsSubset(left, right) && JoinRelationSet::IsSubset(right, left); +} + +void QueryGraphManager::BindFilterEndpoints() { for (auto &filter_info : filters_and_bindings) { auto &filter = filter_info->filter; // now check if it can be used as a join predicate @@ -101,14 +232,11 @@ void QueryGraphManager::CreateHyperGraphEdges() { if (!filter_info->right_set) { filter_info->right_set = &set_manager.GetJoinRelation(right_bindings); } - // we can only create a meaningful edge if the sets are not exactly the same - if (filter_info->left_set != filter_info->right_set) { - // check if the sets are disjoint - if (Disjoint(left_bindings, right_bindings)) { - // they are disjoint, we only need to create one set of edges in the join graph - query_graph.CreateEdge(*filter_info->left_set, *filter_info->right_set, filter_info); - query_graph.CreateEdge(*filter_info->right_set, *filter_info->left_set, filter_info); - } + D_ASSERT(filter_info->left_set && filter_info->right_set); + auto &condition_set = set_manager.Union(*filter_info->left_set, *filter_info->right_set); + if (!RelationSetsEqual(filter_info->set.get(), condition_set)) { + filter_info->SetLeftSet(nullptr); + filter_info->SetRightSet(nullptr); } } } else if (filter->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { @@ -123,7 +251,7 @@ void QueryGraphManager::CreateHyperGraphEdges() { D_ASSERT(filter_info->left_set); 
D_ASSERT(filter_info->right_set); D_ASSERT(filter_info->join_type == JoinType::SEMI || filter_info->join_type == JoinType::ANTI); - for (auto &child_comp : conjunction.children) { + for (auto &child_comp : conjunction.GetChildren()) { if (!BoundComparisonExpression::IsComparison(*child_comp)) { continue; } @@ -143,20 +271,21 @@ void QueryGraphManager::CreateHyperGraphEdges() { } } if (!left_bindings.empty() && !right_bindings.empty()) { - // we can only create a meaningful edge if the sets are not exactly the same - if (filter_info->left_set != filter_info->right_set) { - // check if the sets are disjoint - if (Disjoint(left_bindings, right_bindings)) { - // they are disjoint, we only need to create one set of edges in the join graph - query_graph.CreateEdge(*filter_info->left_set, *filter_info->right_set, filter_info); - query_graph.CreateEdge(*filter_info->right_set, *filter_info->left_set, filter_info); - } - } + D_ASSERT(filter_info->left_set && filter_info->right_set); } } } } +void QueryGraphManager::CreateHyperGraphEdges() { + for (auto predicate_ref : predicate_model.GetGraphPredicates()) { + auto &predicate = predicate_ref.get(); + D_ASSERT(predicate.GetLeftSetOptional() && predicate.GetRightSetOptional()); + query_graph.CreateEdge(predicate.GetLeftSet(), predicate.GetRightSet(), predicate); + query_graph.CreateEdge(predicate.GetRightSet(), predicate.GetLeftSet(), predicate); + } +} + static unique_ptr ExtractJoinRelation(unique_ptr &rel) { auto &children = rel->parent->children; for (idx_t i = 0; i < children.size(); i++) { @@ -240,6 +369,21 @@ static JoinCondition MaybeInvertConditions(unique_ptr condition, boo return JoinCondition(std::move(left), std::move(right), comp_type); } +static bool ShouldInvertJoinCondition(JoinRelationSetManager &set_manager, GenerateJoinRelation &left, + GenerateJoinRelation &right, FilterInfo &filter, bool fallback_invert) { + if (!filter.left_binding.table_index.IsValid()) { + return fallback_invert; + } + auto 
&condition_left_set = set_manager.GetJoinRelation(RelationIndex(filter.left_binding.table_index.index)); + if (JoinRelationSet::IsSubset(*right.set, condition_left_set)) { + return true; + } + if (JoinRelationSet::IsSubset(*left.set, condition_left_set)) { + return false; + } + return fallback_invert; +} + GenerateJoinRelation QueryGraphManager::GenerateJoins(vector> &extracted_relations, JoinRelationSet &set) { optional_ptr left_node; @@ -256,17 +400,22 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorleft_set); auto right = GenerateJoins(extracted_relations, node->right_set); - if (dp_entry->second->info->filters.empty()) { + if (dp_entry->second->info->predicates.empty()) { // no filters, create a cross product auto cardinality = left.op->estimated_cardinality * right.op->estimated_cardinality; result_operator = LogicalCrossProduct::Create(std::move(left.op), std::move(right.op)); result_operator->SetEstimatedCardinality(cardinality); } else { // we have filters, create a join node - auto chosen_filter = node->info->filters.at(0); - for (idx_t i = 0; i < node->info->filters.size(); i++) { - if (node->info->filters.at(i)->join_type == JoinType::INNER) { - chosen_filter = node->info->filters.at(i); + // Prefer non-INNER join types (LEFT/SEMI/ANTI) since WHERE clause filters default + // to INNER but should not override the actual join semantics of the edge. 
+ auto &chosen_predicates = node->info->predicates; + auto chosen_filter = &chosen_predicates.at(0).get().GetFilter(); + for (idx_t i = 0; i < chosen_predicates.size(); i++) { + auto &predicate = chosen_predicates.at(i).get(); + auto filter_join_type = predicate.GetJoinType(); + if (filter_join_type != JoinType::INNER) { + chosen_filter = &predicate.GetFilter(); break; } } @@ -277,8 +426,8 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorchildren.push_back(std::move(right.op)); // set the join conditions from the join node - for (auto &filter_ref : node->info->filters) { - auto f = filter_ref.get(); + for (auto predicate_ref : node->info->predicates) { + auto f = &predicate_ref.get().GetFilter(); // extract the filter from the operator it originally belonged to D_ASSERT(filters_and_bindings[f->filter_index]->filter); auto &filter_and_binding = filters_and_bindings.at(f->filter_index); @@ -289,23 +438,26 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorright_set) && JoinRelationSet::IsSubset(*right.set, *f->left_set))); - bool invert = !JoinRelationSet::IsSubset(*left.set, *f->left_set); - // If the left and right set are inverted AND it is a semi or anti join - // swap left and right children back. + bool invert_children = !JoinRelationSet::IsSubset(*left.set, *f->left_set); - if (invert && (f->join_type == JoinType::SEMI || f->join_type == JoinType::ANTI)) { + // If the left and right set are inverted for LEFT/SEMI/ANTI joins then swap them back + // and set invert = false. 
This is to preserve left/rightedness of relations + if (invert_children && (f->join_type == JoinType::LEFT || f->join_type == JoinType::SEMI || + f->join_type == JoinType::ANTI)) { std::swap(join->children[0], join->children[1]); - invert = false; + std::swap(left, right); + invert_children = false; } if (BoundComparisonExpression::IsComparison(*condition)) { + auto invert = ShouldInvertJoinCondition(set_manager, left, right, *f, invert_children); auto cond = MaybeInvertConditions(std::move(condition), invert); join->conditions.push_back(std::move(cond)); } else if (condition->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { auto &conjunction = condition->Cast(); - for (auto &child : conjunction.children) { + for (auto &child : conjunction.GetChildrenMutable()) { D_ASSERT(BoundComparisonExpression::IsComparison(*child)); - auto cond = MaybeInvertConditions(std::move(child), invert); + auto cond = MaybeInvertConditions(std::move(child), invert_children); join->conditions.push_back(std::move(cond)); } } @@ -326,7 +478,6 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorestimated_props = node.estimated_props->Copy(); result_operator->estimated_cardinality = node->cardinality; result_operator->has_estimated_cardinality = true; @@ -372,6 +523,11 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vector 0 && JoinRelationSet::IsSubset(*result_relation, info.set)) { auto &filter_and_binding = filters_and_bindings[info.filter_index]; auto filter = std::move(filter_and_binding->filter); diff --git a/src/duckdb/src/optimizer/join_order/relation_manager.cpp b/src/duckdb/src/optimizer/join_order/relation_manager.cpp index 112c71b0c..14384075c 100644 --- a/src/duckdb/src/optimizer/join_order/relation_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_manager.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/enums/logical_operator_type.hpp" #include "duckdb/common/string_util.hpp" #include 
"duckdb/optimizer/join_order/join_order_optimizer.hpp" +#include "duckdb/optimizer/join_order/join_relation_set.hpp" #include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/planner/expression/list.hpp" @@ -13,6 +14,17 @@ namespace duckdb { +SingleJoinRelation::SingleJoinRelation(LogicalOperator &op, optional_ptr parent) + : op(op), parent(parent) { +} + +SingleJoinRelation::SingleJoinRelation(LogicalOperator &op, optional_ptr parent, RelationStats stats) + : op(op), parent(parent), stats(std::move(stats)) { +} + +RelationManager::RelationManager(ClientContext &context) : context(context) { +} + const vector RelationManager::GetRelationStats() { vector ret; for (idx_t i = 0; i < relations.size(); i++) { @@ -76,14 +88,19 @@ void RelationManager::AddRelation(LogicalOperator &op, optional_ptr(); - // TODO: SEMI/ANTI joins with residual predicates are not supported - if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { - for (auto &cond : join.conditions) { - if (!cond.IsComparison()) { - return false; - } - } - } - switch (join.join_type) { case JoinType::INNER: case JoinType::SEMI: + case JoinType::LEFT: case JoinType::ANTI: for (auto &cond : join.conditions) { if (cond.IsComparison() && ExpressionContainsColumnRef(cond.GetLHS()) && @@ -176,6 +209,15 @@ static bool JoinIsReorderable(LogicalOperator &op) { return false; } +static bool RecursiveCTERefCanReorder(optional_ptr parent) { + if (!parent || parent->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return false; + } + auto &join = parent->Cast(); + idx_t range_count = 0; + return join.HasEquality(range_count); +} + static bool HasNonReorderableChild(LogicalOperator &op) { LogicalOperator *tmp = &op; while (tmp->children.size() == 1) { @@ -332,14 +374,14 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica // So we treat the right side of left join as its own relation so no 
relations // are pushed into the right side, or taken out of the right side. - if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { + if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI || join.join_type == JoinType::LEFT) { RelationStats child_stats; // optimize the child and copy the stats auto child_optimizer = optimizer.CreateChildOptimizer(); op->children[1] = child_optimizer.Optimize(std::move(op->children[1]), &child_stats); AddRelation(*op->children[1], op, child_stats); // remember that if a cross product needs to be forced, it cannot be forced - // across the children of a semi or anti join + // across the children of a semi/anti/left join no_cross_product_relations.insert(relations.size() - 1); auto right_child_bindings = op->children[1]->GetColumnBindings(); for (auto &bindings : right_child_bindings) { @@ -455,7 +497,11 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica auto is_recursive = optimizer.recursive_cte_indexes.find(cte_ref.cte_index); if (is_recursive != optimizer.recursive_cte_indexes.end()) { - return false; + // Only let recursive CTE references participate in join reordering when they sit directly + // under an equality comparison join. This still enables the reusable recursive hash-join + // case, while avoiding broader reorderings that let residual/range predicates drive the + // reconstructed join tree into expensive range-join plans. + return RecursiveCTERefCanReorder(parent); } return true; } @@ -529,26 +575,82 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica } } +void RelationManager::GetColumnBindingsFromExpression(const Expression &expression, + column_binding_set_t &column_bindings) { + if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + // Here you have a filter on a single column in a table. 
Return a binding for the column + // being filtered on so the filter estimator knows what HLL count to pull + auto &colref = expression.Cast(); + D_ASSERT(colref.Depth() == 0); + D_ASSERT(colref.Binding().table_index.IsValid()); + // only add column bindings that map to relations. + if (relation_mapping.find(colref.Binding().table_index) == relation_mapping.end()) { + return; + } + // map the base table index to the relation index used by the JoinOrderOptimizer + column_bindings.insert(ColumnBinding(TableIndex(relation_mapping[colref.Binding().table_index].index), + colref.Binding().column_index)); + } + + // TODO: handle inequality filters with functions. + ExpressionIterator::EnumerateChildren( + expression, [&](const Expression &expr) { GetColumnBindingsFromExpression(expr, column_bindings); }); +} + +void RelationManager::GetColumnBindingsFromOperator(LogicalOperator &op, column_binding_set_t &column_bindings) { + for (auto &binding : op.GetColumnBindings()) { + auto entry = relation_mapping.find(binding.table_index); + if (entry == relation_mapping.end()) { + continue; + } + column_bindings.insert(ColumnBinding(TableIndex(entry->second.index), binding.column_index)); + } +} + +static bool BindingsAreSubset(const column_binding_set_t &bindings, const column_binding_set_t &super) { + if (bindings.empty()) { + return false; + } + for (auto &binding : bindings) { + if (super.find(binding) == super.end()) { + return false; + } + } + return true; +} + +optional_ptr RelationManager::GetJoinRelations(column_binding_set_t &column_bindings, + JoinRelationSetManager &set_manager) { + optional_ptr ret = set_manager.GetEmptyJoinRelationSet(); + for (auto &binding : column_bindings) { + // binding.table_index stores a RelationIndex wrapped in a TableIndex + optional_ptr binding_set = + set_manager.GetJoinRelation(RelationIndex(binding.table_index.index)); + ret = set_manager.Union(*ret, *binding_set); + } + return *ret; +} + bool RelationManager::ExtractBindings(const 
Expression &expression, unordered_set &bindings) { if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &colref = expression.Cast(); - D_ASSERT(colref.depth == 0); - D_ASSERT(colref.binding.table_index.IsValid()); + D_ASSERT(colref.Depth() == 0); + D_ASSERT(colref.Binding().table_index.IsValid()); // map the base table index to the relation index used by the JoinOrderOptimizer if (expression.GetAlias() == "SUBQUERY" && - relation_mapping.find(colref.binding.table_index) == relation_mapping.end()) { + relation_mapping.find(colref.Binding().table_index) == relation_mapping.end()) { // most likely a BoundSubqueryExpression that was created from an uncorrelated subquery // Here we return true and don't fill the bindings, the expression can be reordered. // A filter will be created using this expression, and pushed back on top of the parent // operator during plan reconstruction return true; } - if (relation_mapping.find(colref.binding.table_index) != relation_mapping.end()) { - bindings.insert(relation_mapping[colref.binding.table_index]); + if (relation_mapping.find(colref.Binding().table_index) != relation_mapping.end()) { + bindings.insert(relation_mapping[colref.Binding().table_index]); } } if (expression.GetExpressionType() == ExpressionType::BOUND_REF) { - // bound expression + // bound expression, clear bindings as we can't use this bindings.clear(); return false; } @@ -576,67 +678,200 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op f_op.type == LogicalOperatorType::LOGICAL_ASOF_JOIN) { auto &join = f_op.Cast(); D_ASSERT(join.expressions.empty()); - if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { + switch (join.join_type) { + case JoinType::SEMI: + case JoinType::ANTI: { auto conjunction_expression = make_uniq(ExpressionType::CONJUNCTION_AND); - // create a conjunction expression for the semi join. + // create a conjunction expression for the semi/anti join. 
// It's possible multiple LHS relations have a condition in - // this semi join. Suppose we have ((A ⨝ B) ⋉ C). (example in test_4950.test) + // this join. Suppose we have ((A ⨝ B) ⋉ C). (example in test_4950.test) // If the semi join condition has A.x = C.y AND B.x = C.z then we need to prevent a reordering // that looks like ((A ⋉ C) ⨝ B)), since all columns from C will be lost after it joins with A, // and the condition B.x = C.z will no longer be possible. - // if we make a conjunction expressions and populate the left set and right set with all - // the relations from the conditions in the conjunction expression, we can prevent invalid - // reordering. + // if we make a conjunction expression and populate the left set and right set with all + // the relations from the conditions in the conjunction expression, we can prevent invalid reordering. + + // Non-comparison conditions (e.g. l.val > 1 stored as a residual in the ON clause) + // are NOT added to the conjunction to avoid contaminating the edge binding analysis. + // Instead they are stored as residual predicates below. 
for (auto &cond : join.conditions) { if (cond.IsComparison()) { auto comparison = BoundComparisonExpression::Create(cond.GetComparisonType(), cond.GetLHS().Copy(), cond.GetRHS().Copy()); - conjunction_expression->children.push_back(std::move(comparison)); + conjunction_expression->GetChildrenMutable().push_back(std::move(comparison)); + } + } + + if (!conjunction_expression->GetChildrenMutable().empty()) { + // create the filter info so all required LHS relations are present when reconstructing the join + optional_ptr left_set; + optional_ptr right_set; + for (auto &bound_expr : conjunction_expression->GetChildrenMutable()) { + unordered_set right_bindings, left_bindings; + auto &comp = bound_expr->Cast(); + auto &left = BoundComparisonExpression::Left(comp); + auto &right = BoundComparisonExpression::Right(comp); + ExtractBindings(right, right_bindings); + ExtractBindings(left, left_bindings); + + if (!left_set) { + left_set = set_manager.GetJoinRelation(left_bindings); + } else { + left_set = set_manager.Union(set_manager.GetJoinRelation(left_bindings), *left_set); + } + if (!right_set) { + right_set = set_manager.GetJoinRelation(right_bindings); + } else { + right_set = set_manager.Union(set_manager.GetJoinRelation(right_bindings), *right_set); + } } + auto &full_set = set_manager.Union(*left_set, *right_set); + D_ASSERT(left_set && left_set->count > 0); + D_ASSERT(right_set && right_set->count == 1); + D_ASSERT(full_set.count > 0); + + auto filter_info = make_uniq(std::move(conjunction_expression), full_set, + filters_and_bindings.size(), join.join_type); + filter_info->SetLeftSet(left_set); + filter_info->SetRightSet(right_set); + filters_and_bindings.push_back(std::move(filter_info)); } - // create the filter info so all required LHS relations are present when reconstructing the - // join - optional_ptr left_set; - optional_ptr right_set; - optional_ptr full_set; - // here we create a left_set that unions all relations from the left side of - // every 
expression and a right_set that unions all relations from the right side of a - // every expression (although this should always be 1). - for (auto &bound_expr : conjunction_expression->children) { - auto &comp = bound_expr->Cast(); - unordered_set right_bindings, left_bindings; - auto &left = BoundComparisonExpression::Left(comp); - auto &right = BoundComparisonExpression::Right(comp); - ExtractBindings(right, right_bindings); - ExtractBindings(left, left_bindings); - - if (!left_set) { - left_set = set_manager.GetJoinRelation(left_bindings); + // Store non-comparison ON-clause conditions as residual predicates. + unordered_set all_bindings; + for (auto &cond : join.conditions) { + if (!cond.IsComparison()) { + ExtractBindings(cond.GetJoinExpression(), all_bindings); } else { - left_set = set_manager.Union(set_manager.GetJoinRelation(left_bindings), *left_set); + ExtractBindings(cond.GetLHS(), all_bindings); + ExtractBindings(cond.GetRHS(), all_bindings); + } + } + auto full_set = &set_manager.GetJoinRelation(all_bindings); + if (!full_set->Empty()) { + for (auto &cond : join.conditions) { + if (!cond.IsComparison()) { + auto nc_expr = cond.GetJoinExpression().Copy(); + auto new_filter = make_uniq(std::move(nc_expr), *full_set, + filters_and_bindings.size(), join.join_type); + new_filter->from_residual_predicate = true; + filters_and_bindings.push_back(std::move(new_filter)); + } } - if (!right_set) { - right_set = set_manager.GetJoinRelation(right_bindings); + } + break; + } + case JoinType::LEFT: { + // For LEFT joins, create one FilterInfo per comparison condition. + // Each uses the full left/right relation sets from the comparison conditions + // to preserve join-ordering constraints, but individual column bindings so that + // the cardinality estimator gets an accurate distinct count for each condition. 
+ // With all per-condition distinct counts available, GetDenominator can multiply + // them together for a correct multi-condition LEFT join cardinality estimate. + + // Predicate expression sides are not reliable here: SQL can write a LEFT join condition as + // rhs_col = lhs_col. Normalize comparison sides so left_set is the preserved side and right_set + // is the nullable RHS side, without expanding left_set to every relation in the actual left child. + column_binding_set_t semantic_right_bindings; + GetColumnBindingsFromOperator(*join.children[1], semantic_right_bindings); + + column_binding_set_t full_left_bindings, full_right_bindings; + for (auto &cond : join.conditions) { + if (!cond.IsComparison()) { + continue; + } + column_binding_set_t cond_left_bindings, cond_right_bindings; + GetColumnBindingsFromExpression(cond.GetLHS(), cond_left_bindings); + GetColumnBindingsFromExpression(cond.GetRHS(), cond_right_bindings); + + bool lhs_is_nullable = BindingsAreSubset(cond_left_bindings, semantic_right_bindings); + bool rhs_is_nullable = BindingsAreSubset(cond_right_bindings, semantic_right_bindings); + auto &preserved_bindings = + lhs_is_nullable && !rhs_is_nullable ? cond_right_bindings : cond_left_bindings; + auto &nullable_bindings = + lhs_is_nullable && !rhs_is_nullable ? 
cond_left_bindings : cond_right_bindings; + for (auto &binding : preserved_bindings) { + full_left_bindings.insert(binding); + } + for (auto &binding : nullable_bindings) { + full_right_bindings.insert(binding); + } + } + + if (!full_left_bindings.empty() || !full_right_bindings.empty()) { + auto full_left_set = GetJoinRelations(full_left_bindings, set_manager); + auto full_right_set = GetJoinRelations(full_right_bindings, set_manager); + auto &full_set = set_manager.Union(*full_left_set, *full_right_set); + + for (auto &cond : join.conditions) { + if (!cond.IsComparison()) { + continue; + } + column_binding_set_t cond_left_bindings, cond_right_bindings; + GetColumnBindingsFromExpression(cond.GetLHS(), cond_left_bindings); + GetColumnBindingsFromExpression(cond.GetRHS(), cond_right_bindings); + + auto comparison_type = cond.GetComparisonType(); + bool lhs_is_nullable = BindingsAreSubset(cond_left_bindings, semantic_right_bindings); + bool rhs_is_nullable = BindingsAreSubset(cond_right_bindings, semantic_right_bindings); + unique_ptr left_expr; + unique_ptr right_expr; + if (lhs_is_nullable && !rhs_is_nullable) { + left_expr = cond.GetRHS().Copy(); + right_expr = cond.GetLHS().Copy(); + comparison_type = FlipComparisonExpression(comparison_type); + } else { + left_expr = cond.GetLHS().Copy(); + right_expr = cond.GetRHS().Copy(); + } + + auto comparison = BoundComparisonExpression::Create(comparison_type, std::move(left_expr), + std::move(right_expr)); + + auto filter_info = make_uniq(std::move(comparison), full_set, + filters_and_bindings.size(), join.join_type); + filter_info->SetLeftSet(full_left_set); + filter_info->SetRightSet(full_right_set); + filters_and_bindings.push_back(std::move(filter_info)); + } + + // Filters above a LEFT join that reference nullable-side bindings must be evaluated + // after the LEFT join has introduced NULL-extended rows. 
Pin them to the full LEFT + // join output, otherwise reconstruction can push null-aware predicates like + // "lhs IS DISTINCT FROM rhs" below the LEFT join and turn them into inner filters. + for (auto &filter : filters_and_bindings) { + if (filter->join_type == JoinType::LEFT) { + continue; + } + PinFilterAfterLeftJoin(*filter, *full_right_set, full_set, set_manager); + } + } + + // Store non-comparison ON-clause conditions as residual predicates. + unordered_set all_bindings; + for (auto &cond : join.conditions) { + if (cond.IsComparison()) { + ExtractBindings(cond.GetLHS(), all_bindings); + ExtractBindings(cond.GetRHS(), all_bindings); } else { - right_set = set_manager.Union(set_manager.GetJoinRelation(right_bindings), *right_set); + ExtractBindings(cond.GetJoinExpression(), all_bindings); + } + } + auto full_set = &set_manager.GetJoinRelation(all_bindings); + if (!full_set->Empty()) { + for (auto &cond : join.conditions) { + if (!cond.IsComparison()) { + auto nc_expr = cond.GetJoinExpression().Copy(); + auto new_filter = make_uniq(std::move(nc_expr), *full_set, + filters_and_bindings.size(), join.join_type); + new_filter->from_residual_predicate = true; + filters_and_bindings.push_back(std::move(new_filter)); + } } } - full_set = set_manager.Union(*left_set, *right_set); - D_ASSERT(left_set && left_set->count > 0); - D_ASSERT(right_set && right_set->count == 1); - D_ASSERT(full_set && full_set->count > 0); - - // now we push the conjunction expressions - // In QueryGraphManager::GenerateJoins we extract each condition again and create a standalone join - // condition. - auto filter_info = make_uniq(std::move(conjunction_expression), *full_set, - filters_and_bindings.size(), join.join_type); - filter_info->SetLeftSet(left_set); - filter_info->SetRightSet(right_set); - - filters_and_bindings.push_back(std::move(filter_info)); - } else { + break; + } + default: { // can extract every inner join condition individually. 
for (auto &cond : join.conditions) { unique_ptr expr; @@ -661,10 +896,13 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op filters_and_bindings.push_back(std::move(filter_info)); } } + break; + } } join.conditions.clear(); } else { + // handle filters from logical filters vector> leftover_expressions; for (auto &expression : f_op.expressions) { if (filter_set.find(*expression) == filter_set.end()) { @@ -673,7 +911,7 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op ExtractBindings(*expression, bindings); if (bindings.empty()) { // the filter is on a column that is not in our relational map. (example: limit_rownum) - // in this case we do not create a FilterInfo for it. (duckdb-internal/#1493)s + // in this case we do not create a FilterInfo for it. (duckdb-internal/#1493) leftover_expressions.push_back(std::move(expression)); continue; } diff --git a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp index 65d5da23b..675a4184f 100644 --- a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp @@ -1,20 +1,139 @@ #include "duckdb/optimizer/join_order/relation_statistics_helper.hpp" #include "duckdb/planner/expression/list.hpp" #include "duckdb/planner/operator/list.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/common/operator/add.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/types/hugeint.hpp" #include "duckdb/function/table/table_scan.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include 
"duckdb/storage/data_table.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" #include namespace duckdb { +bool ExpressionBinding::FoundExpression() const { + return expression; +} + +bool ExpressionBinding::FoundColumnRef() const { + if (!FoundExpression()) { + return false; + } + return expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF; +} + +RelationStats::RelationStats() : cardinality(1), filter_strength(1), stats_initialized(false) { +} + +DistinctCount::DistinctCount(idx_t distinct_count, DistinctCountSource source) + : distinct_count(distinct_count), source(source) { +} + +static idx_t CapMinMaxDistinctCount(uint64_t distinct_count, idx_t base_table_cardinality) { + if (base_table_cardinality == 0) { + return 0; + } + if (distinct_count == 0) { + return 0; + } + auto capped_distinct_count = MinValue(distinct_count, base_table_cardinality); + return capped_distinct_count == NumericLimits::Maximum() ? 0 : capped_distinct_count; +} + +static idx_t GetMinMaxSpanDistinctCount(uint64_t span, idx_t base_table_cardinality) { + uint64_t distinct_count; + if (!TryAddOperator::Operation(span, 1, distinct_count)) { + return 0; + } + return CapMinMaxDistinctCount(distinct_count, base_table_cardinality); +} + +template +static idx_t GetSignedMinMaxDistinctCount(const BaseStatistics &base_stats, idx_t base_table_cardinality) { + auto min_value = NumericStats::Min(base_stats).GetValueUnsafe(); + auto max_value = NumericStats::Max(base_stats).GetValueUnsafe(); + if (max_value < min_value) { + return 0; + } + hugeint_t span; + if (!TrySubtractOperator::Operation(hugeint_t(static_cast(max_value)), + hugeint_t(static_cast(min_value)), span)) { + return 0; + } + uint64_t unsigned_span; + if (!Hugeint::TryCast(span, unsigned_span)) { + return 0; + } + return GetMinMaxSpanDistinctCount(unsigned_span, base_table_cardinality); +} + +template +static idx_t GetUnsignedMinMaxDistinctCount(const BaseStatistics &base_stats, idx_t base_table_cardinality) { + 
auto min_value = NumericStats::Min(base_stats).GetValueUnsafe(); + auto max_value = NumericStats::Max(base_stats).GetValueUnsafe(); + T span; + if (!TrySubtractOperator::Operation(max_value, min_value, span)) { + return 0; + } + return GetMinMaxSpanDistinctCount(static_cast(span), base_table_cardinality); +} + +static idx_t GetBooleanMinMaxDistinctCount(const BaseStatistics &base_stats, idx_t base_table_cardinality) { + auto min_value = NumericStats::Min(base_stats).GetValueUnsafe(); + auto max_value = NumericStats::Max(base_stats).GetValueUnsafe(); + auto distinct_count = min_value == max_value ? 1 : 2; + return CapMinMaxDistinctCount(distinct_count, base_table_cardinality); +} + +static idx_t GetMinMaxDistinctCount(const BaseStatistics &base_stats, idx_t base_table_cardinality) { + if (base_table_cardinality == 0 || base_stats.GetStatsType() != StatisticsType::NUMERIC_STATS || + !NumericStats::HasMinMax(base_stats)) { + return 0; + } + + switch (base_stats.GetType().InternalType()) { + case PhysicalType::BOOL: + return GetBooleanMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::INT8: + return GetSignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::INT16: + return GetSignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::INT32: + return GetSignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::INT64: + return GetSignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::UINT8: + return GetUnsignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::UINT16: + return GetUnsignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::UINT32: + return GetUnsignedMinMaxDistinctCount(base_stats, base_table_cardinality); + case PhysicalType::UINT64: + return GetUnsignedMinMaxDistinctCount(base_stats, base_table_cardinality); + default: + return 0; + } +} + +static DistinctCount 
GetDistinctCountFromStats(BaseStatistics &base_stats, idx_t base_table_cardinality) { + auto distinct_count = base_stats.GetDistinctCount(); + if (distinct_count > 0) { + return DistinctCount(distinct_count, DistinctCountSource::HLL); + } + distinct_count = GetMinMaxDistinctCount(base_stats, base_table_cardinality); + if (distinct_count > 0) { + return DistinctCount(distinct_count, DistinctCountSource::MIN_MAX); + } + return DistinctCount(0, DistinctCountSource::CARDINALITY); +} + static ExpressionBinding GetChildColumnBinding(Expression &expr) { auto ret = ExpressionBinding(); switch (expr.GetExpressionClass()) { @@ -22,7 +141,7 @@ static ExpressionBinding GetChildColumnBinding(Expression &expr) { // TODO: Other expression classes that can have 0 children? auto &func = expr.Cast(); // no children some sort of gen_random_uuid() or equivalent. - if (func.children.empty()) { + if (func.GetChildren().empty()) { ret.expression = expr; ret.expression_is_constant = true; return ret; @@ -32,7 +151,7 @@ static ExpressionBinding GetChildColumnBinding(Expression &expr) { case ExpressionClass::BOUND_COLUMN_REF: { ret.expression = expr; auto &new_col_ref = expr.Cast(); - ret.child_binding = ColumnBinding(new_col_ref.binding.table_index, new_col_ref.binding.column_index); + ret.child_binding = ColumnBinding(new_col_ref.Binding().table_index, new_col_ref.Binding().column_index); return ret; } case ExpressionClass::BOUND_LAMBDA_REF: @@ -61,24 +180,26 @@ static ExpressionBinding GetChildColumnBinding(Expression &expr) { return ret; } -idx_t RelationStatisticsHelper::GetDistinctCount(LogicalGet &get, ClientContext &context, - const ColumnIndex &column_id) { - if (!get.function.statistics && !get.function.statistics_extended) { - return 0; +unique_ptr RelationStatisticsHelper::GetColumnStatistics(LogicalGet &get, ClientContext &context, + const ColumnIndex &column_id) { + if (!get.bind_data || (!get.function.statistics && !get.function.statistics_extended)) { + return nullptr; } - 
unique_ptr column_statistics; if (get.function.statistics_extended) { TableFunctionGetStatisticsInput input(get.bind_data.get(), column_id); - column_statistics = get.function.statistics_extended(context, input); - } else { - D_ASSERT(get.function.statistics); - column_statistics = get.function.statistics(context, get.bind_data.get(), column_id.GetPrimaryIndex()); + return get.function.statistics_extended(context, input); } + D_ASSERT(get.function.statistics); + return get.function.statistics(context, get.bind_data.get(), column_id.GetPrimaryIndex()); +} + +DistinctCount RelationStatisticsHelper::GetDistinctCount(LogicalGet &get, ClientContext &context, + const ColumnIndex &column_id, idx_t base_table_cardinality) { + auto column_statistics = GetColumnStatistics(get, context, column_id); if (!column_statistics) { - return 0; + return DistinctCount(0, DistinctCountSource::CARDINALITY); } - auto distinct_count = column_statistics->GetDistinctCount(); - return distinct_count; + return GetDistinctCountFromStats(*column_statistics, base_table_cardinality); } RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientContext &context) { @@ -90,48 +211,37 @@ RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientC auto catalog_table = get.GetTable(); auto name = string("some table"); if (catalog_table) { - name = catalog_table->name; - return_stats.table_name = name; + name = catalog_table->name.GetIdentifierName(); + return_stats.table_name = Identifier(name); } // first push back basic distinct counts for each column (if we have them). 
auto &column_ids = get.GetColumnIds(); for (idx_t i = 0; i < column_ids.size(); i++) { auto column_id = column_ids[i].GetPrimaryIndex(); - auto distinct_count = GetDistinctCount(get, context, column_ids[i]); - if (distinct_count > 0) { - auto column_distinct_count = DistinctCount({distinct_count, true}); - return_stats.column_distinct_count.push_back(column_distinct_count); - return_stats.column_names.push_back(name + "." + get.names.at(column_id)); + auto distinct_count = GetDistinctCount(get, context, column_ids[i], base_table_cardinality); + if (distinct_count.distinct_count > 0) { + return_stats.column_distinct_count.emplace_back(distinct_count.distinct_count, distinct_count.source); + return_stats.column_names.push_back(Identifier(name + "." + get.names.at(column_id))); } else { // treat the cardinality as the distinct count. // the cardinality estimator will update these distinct counts based // on the extra columns that are joined on. - auto column_distinct_count = DistinctCount({cardinality_after_filters, false}); - return_stats.column_distinct_count.push_back(column_distinct_count); + return_stats.column_distinct_count.emplace_back(cardinality_after_filters, + DistinctCountSource::CARDINALITY); auto column_name = string("column"); if (column_id < get.names.size()) { - column_name = get.names.at(column_id); + column_name = get.names.at(column_id).GetIdentifierName(); } - return_stats.column_names.push_back(get.GetName() + "." + column_name); + return_stats.column_names.push_back(Identifier(get.GetName() + "." 
+ column_name)); } } if (get.table_filters.HasFilters()) { - unique_ptr column_statistics; bool has_non_optional_filters = false; for (auto &entry : get.table_filters) { auto &column_index = get.GetColumnIndex(entry.GetIndex()); - if (get.bind_data && (get.function.statistics || get.function.statistics_extended)) { - if (get.function.statistics_extended) { - TableFunctionGetStatisticsInput input(get.bind_data.get(), column_index); - column_statistics = get.function.statistics_extended(context, input); - } else { - D_ASSERT(get.function.statistics); - column_statistics = - get.function.statistics(context, get.bind_data.get(), column_index.GetPrimaryIndex()); - } - } + auto column_statistics = GetColumnStatistics(get, context, column_index); if (column_statistics) { idx_t cardinality_with_filter = @@ -139,7 +249,7 @@ RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientC cardinality_after_filters = MinValue(cardinality_after_filters, cardinality_with_filter); } - if (entry.Filter().filter_type != TableFilterType::OPTIONAL_FILTER) { + if (!ExpressionFilter::IsOptionalFilter(entry.Filter())) { has_non_optional_filters = true; } } @@ -155,6 +265,7 @@ RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientC cardinality_after_filters = 0; } } + return_stats.cardinality = cardinality_after_filters; // update the estimated cardinality of the get as well. // This is not updated during plan reconstruction. 
@@ -167,13 +278,13 @@ RelationStats RelationStatisticsHelper::ExtractGetStats(LogicalGet &get, ClientC RelationStats RelationStatisticsHelper::ExtractDelimGetStats(LogicalDelimGet &delim_get, ClientContext &context) { RelationStats stats; - stats.table_name = delim_get.GetName(); + stats.table_name = Identifier(delim_get.GetName()); idx_t card = delim_get.EstimateCardinality(context); stats.cardinality = card; stats.stats_initialized = true; for (auto &binding : delim_get.GetColumnBindings()) { - stats.column_distinct_count.push_back(DistinctCount({1, false})); - stats.column_names.push_back("column" + to_string(binding.column_index)); + stats.column_distinct_count.emplace_back(1, DistinctCountSource::CARDINALITY); + stats.column_names.push_back(Identifier("column" + to_string(binding.column_index))); } return stats; } @@ -181,25 +292,26 @@ RelationStats RelationStatisticsHelper::ExtractDelimGetStats(LogicalDelimGet &de RelationStats RelationStatisticsHelper::ExtractProjectionStats(LogicalProjection &proj, RelationStats &child_stats) { auto proj_stats = RelationStats(); proj_stats.cardinality = child_stats.cardinality; - proj_stats.table_name = proj.GetName(); + proj_stats.table_name = Identifier(proj.GetName()); for (auto &expr : proj.expressions) { - proj_stats.column_names.push_back(expr->GetName()); + proj_stats.column_names.emplace_back(expr->GetName()); auto res = GetChildColumnBinding(*expr); D_ASSERT(res.FoundExpression()); if (res.expression_is_constant) { - proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); + proj_stats.column_distinct_count.emplace_back(1, DistinctCountSource::EXACT); } else { auto column_index = res.child_binding.column_index; if (column_index >= child_stats.column_distinct_count.size() && expr->ToString() == "count_star()") { // only one value for a count star - proj_stats.column_distinct_count.push_back(DistinctCount({1, true})); + proj_stats.column_distinct_count.emplace_back(1, DistinctCountSource::EXACT); } else 
{ // TODO: add this back in // D_ASSERT(column_index < stats.column_distinct_count.size()); if (column_index < child_stats.column_distinct_count.size()) { proj_stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(column_index)); } else { - proj_stats.column_distinct_count.push_back(DistinctCount({proj_stats.cardinality, false})); + proj_stats.column_distinct_count.emplace_back(proj_stats.cardinality, + DistinctCountSource::CARDINALITY); } } } @@ -213,7 +325,7 @@ RelationStats RelationStatisticsHelper::ExtractDummyScanStats(LogicalDummyScan & idx_t card = dummy_scan.EstimateCardinality(context); stats.cardinality = card; for (idx_t i = 0; i < dummy_scan.GetColumnBindings().size(); i++) { - stats.column_distinct_count.push_back(DistinctCount({card, false})); + stats.column_distinct_count.emplace_back(card, DistinctCountSource::CARDINALITY); stats.column_names.push_back("dummy_scan_column"); } stats.stats_initialized = true; @@ -238,7 +350,7 @@ RelationStats RelationStatisticsHelper::CombineStatsOfReorderableOperator(vector stats.column_distinct_count.push_back(child_stats.column_distinct_count.at(i)); stats.column_names.push_back(child_stats.column_names.at(i)); } - stats.table_name += "joined with " + child_stats.table_name; + stats.table_name = Identifier(stats.table_name + "joined with " + child_stats.table_name); max_card = MaxValue(max_card, child_stats.cardinality); } stats.stats_initialized = true; @@ -305,12 +417,12 @@ RelationStats RelationStatisticsHelper::CombineStatsOfNonReorderableOperator(Log ret.stats_initialized = true; ret.filter_strength = 1; - ret.table_name = string(); + ret.table_name = Identifier(string()); for (auto &stats : child_stats) { if (!ret.table_name.empty()) { - ret.table_name += " joined with "; + ret.table_name = Identifier(ret.table_name + " joined with "); } - ret.table_name += stats.table_name; + ret.table_name = Identifier(ret.table_name + stats.table_name); // MARK joins are nonreorderable. 
They won't return initialized stats // continue in this case. if (!stats.stats_initialized) { @@ -332,7 +444,7 @@ RelationStats RelationStatisticsHelper::ExtractExpressionGetStats(LogicalExpress idx_t card = expression_get.EstimateCardinality(context); stats.cardinality = card; for (idx_t i = 0; i < expression_get.GetColumnBindings().size(); i++) { - stats.column_distinct_count.push_back(DistinctCount({card, false})); + stats.column_distinct_count.emplace_back(card, DistinctCountSource::CARDINALITY); stats.column_names.push_back("expression_get_column"); } stats.stats_initialized = true; @@ -350,7 +462,7 @@ RelationStats RelationStatisticsHelper::ExtractWindowStats(LogicalWindow &window for (idx_t column_index = child_stats.column_distinct_count.size(); column_index < num_child_columns; column_index++) { - stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); + stats.column_distinct_count.emplace_back(child_stats.cardinality, DistinctCountSource::CARDINALITY); stats.column_names.push_back("window"); } return stats; @@ -370,7 +482,7 @@ RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate continue; } auto &bound_col = group.Cast(); - auto col_index = bound_col.binding.column_index; + auto col_index = bound_col.Binding().column_index; if (col_index >= child_stats.column_distinct_count.size()) { // it is possible the column index of the grouping_set is not in the child stats. 
// this can happen when delim joins are present, since delim scans are not currently @@ -425,11 +537,11 @@ RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate const auto &binding = aggr_column_bindings[column_index]; if (binding.table_index == aggr.group_index && column_index < distinct_counts.size()) { // Group column that we have the HLL of - stats.column_distinct_count.push_back( - DistinctCount({LossyNumericCast(distinct_counts[column_index]), true})); + stats.column_distinct_count.emplace_back(LossyNumericCast(distinct_counts[column_index]), + DistinctCountSource::HLL); } else { // Non-group column, or we don't have the HLL - stats.column_distinct_count.push_back(DistinctCount({child_stats.cardinality, false})); + stats.column_distinct_count.emplace_back(child_stats.cardinality, DistinctCountSource::CARDINALITY); } stats.column_names.push_back("aggregate"); } @@ -439,7 +551,7 @@ RelationStats RelationStatisticsHelper::ExtractAggregationStats(LogicalAggregate RelationStats RelationStatisticsHelper::ExtractEmptyResultStats(LogicalEmptyResult &empty) { RelationStats stats; for (idx_t i = 0; i < empty.GetColumnBindings().size(); i++) { - stats.column_distinct_count.push_back(DistinctCount({0, false})); + stats.column_distinct_count.emplace_back(0, DistinctCountSource::CARDINALITY); stats.column_names.push_back("empty_result_column"); } stats.stats_initialized = true; @@ -449,11 +561,11 @@ RelationStats RelationStatisticsHelper::ExtractEmptyResultStats(LogicalEmptyResu idx_t RelationStatisticsHelper::InspectTableFilter(idx_t cardinality, const TableFilter &filter, BaseStatistics &base_stats) { auto cardinality_after_filters = cardinality; - auto expr_filter = ExpressionFilter::FromTableFilter(filter, base_stats.GetType()); - auto &expr = *expr_filter->expr; + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "RelationStatisticsHelper::InspectTableFilter"); + auto &expr = *expr_filter.expr; if (expr.GetExpressionType() == 
ExpressionType::CONJUNCTION_AND) { auto &conj = expr.Cast(); - for (auto &child : conj.children) { + for (auto &child : conj.GetChildren()) { ExpressionFilter child_filter(child->Copy()); cardinality_after_filters = MinValue(cardinality_after_filters, InspectTableFilter(cardinality, child_filter, base_stats)); @@ -467,8 +579,8 @@ idx_t RelationStatisticsHelper::InspectTableFilter(idx_t cardinality, const Tabl if (comparison.GetExpressionType() != ExpressionType::COMPARE_EQUAL) { return cardinality_after_filters; } - auto column_count = base_stats.GetDistinctCount(); - // column_count = 0 when there is no column count (i.e parquet scans) + auto column_count = GetDistinctCountFromStats(base_stats, cardinality).distinct_count; + // column_count = 0 when there is no HLL and no usable min/max proxy. if (column_count > 0) { // we want the ceil of cardinality/column_count. We also want to avoid compiler errors cardinality_after_filters = (cardinality + column_count - 1) / column_count; diff --git a/src/duckdb/src/optimizer/late_materialization.cpp b/src/duckdb/src/optimizer/late_materialization.cpp index b4d031658..cc6c6cc81 100644 --- a/src/duckdb/src/optimizer/late_materialization.cpp +++ b/src/duckdb/src/optimizer/late_materialization.cpp @@ -122,7 +122,7 @@ void LateMaterialization::ReplaceTopLevelTableIndex(LogicalOperator &root, Table void LateMaterialization::ReplaceTableReferences(unique_ptr &root_expr, TableIndex new_table_index) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &bound_column_ref, unique_ptr &expr) { - bound_column_ref.binding.table_index = new_table_index; + bound_column_ref.BindingMutable().table_index = new_table_index; }); } @@ -134,7 +134,7 @@ unique_ptr LateMaterialization::GetExpression(LogicalOperator &op, P auto &column_id = get.GetColumnIndex(column_binding); auto column_name = get.GetColumnName(column_id); auto &column_type = get.GetColumnType(column_id); - auto expr = make_uniq(column_name, 
column_type, column_binding); + auto expr = make_uniq(Identifier(column_name), column_type, column_binding); return std::move(expr); } case LogicalOperatorType::LOGICAL_PROJECTION: { @@ -151,7 +151,7 @@ unique_ptr LateMaterialization::GetExpression(LogicalOperator &op, P void LateMaterialization::ReplaceExpressionReferences(LogicalOperator &next_op, unique_ptr &root_expr) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &bound_column_ref, unique_ptr &expr) { - expr = GetExpression(next_op, bound_column_ref.binding.column_index); + expr = GetExpression(next_op, bound_column_ref.Binding().column_index); }); } diff --git a/src/duckdb/src/optimizer/matcher/expression_matcher.cpp b/src/duckdb/src/optimizer/matcher/expression_matcher.cpp index b177aebb4..44a02246c 100644 --- a/src/duckdb/src/optimizer/matcher/expression_matcher.cpp +++ b/src/duckdb/src/optimizer/matcher/expression_matcher.cpp @@ -42,7 +42,7 @@ bool ComparisonExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { @@ -53,7 +53,7 @@ bool CastExpressionMatcher::Match(Expression &expr_p, vector(); - return matcher->Match(*expr.child, bindings); + return matcher->Match(*expr.ChildMutable(), bindings); } bool InClauseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { @@ -65,7 +65,35 @@ bool InClauseExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { + if (!ExpressionMatcher::Match(expr_p, bindings)) { + return false; + } + auto &expr = expr_p.Cast(); + if (expr.GetExpressionType() != ExpressionType::COMPARE_IN || + expr.GetExpressionType() == ExpressionType::COMPARE_NOT_IN) { + return false; + } + + auto &entries = expr.GetChildrenMutable(); + if (entries.size() < 2) { + return false; + } + + if (!probe_matcher->Match(*entries[0], bindings)) { + return false; + } + + for (idx_t i = 1; i < entries.size(); ++i) { + if (!child_matcher->Match(*entries[i], bindings)) { + return false; + } + } + + return true; } bool 
ConjunctionExpressionMatcher::Match(Expression &expr_p, vector> &bindings) { @@ -73,7 +101,7 @@ bool ConjunctionExpressionMatcher::Match(Expression &expr_p, vector(); - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + if (!SetMatcher::Match(matchers, expr.GetChildrenMutable(), bindings, policy)) { return false; } return true; @@ -84,10 +112,10 @@ bool FunctionExpressionMatcher::Match(Expression &expr_p, vector(); - if (!FunctionMatcher::Match(function, expr.function.GetName())) { + if (!FunctionMatcher::Match(function, expr.Function().GetName())) { return false; } - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + if (!SetMatcher::Match(matchers, expr.GetChildrenMutable(), bindings, policy)) { return false; } return true; @@ -98,14 +126,17 @@ bool AggregateExpressionMatcher::Match(Expression &expr_p, vector(); - if (!FunctionMatcher::Match(function, expr.function.GetName())) { + if (expr.StateExportMode() == AggregateStateExportMode::STATE_EXPORT) { + return false; + } + if (!FunctionMatcher::Match(function, expr.Function().GetName())) { return false; } // we should create matchers for these in the future - if (expr.filter || expr.order_bys || expr.aggr_type != AggregateType::NON_DISTINCT) { + if (expr.GetFilter() || expr.GetOrderBys() || expr.GetAggregateType() != AggregateType::NON_DISTINCT) { return false; } - if (!SetMatcher::Match(matchers, expr.children, bindings, policy)) { + if (!SetMatcher::Match(matchers, expr.GetChildrenMutable(), bindings, policy)) { return false; } return true; diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index 5e113f1f0..844ce38ae 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -43,10 +43,15 @@ #include "duckdb/optimizer/row_number_rewriter.hpp" #include "duckdb/optimizer/optimizer_extension.hpp" #include "duckdb/optimizer/outer_join_simplification.hpp" +#include 
"duckdb/optimizer/partial_aggregate_pushdown.hpp" #include "duckdb/optimizer/projection_pullup.hpp" +#include "duckdb/optimizer/rule/contains_to_in_clause.hpp" #include "duckdb/optimizer/rule/predicate_factoring.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/planner.hpp" +#include "duckdb/optimizer/remote_pushdown_optimizer.hpp" +#include "duckdb/main/database_manager.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -61,6 +66,7 @@ Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); @@ -75,6 +81,7 @@ Optimizer::Optimizer(Binder &binder, ClientContext &context) : context(context), rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); rewriter.rules.push_back(make_uniq(rewriter)); + rewriter.rules.push_back(make_uniq(rewriter)); #ifdef DEBUG for (auto &rule : rewriter.rules) { @@ -93,6 +100,10 @@ bool Optimizer::OptimizerDisabled(OptimizerType type) { } bool Optimizer::OptimizerDisabled(ClientContext &context_p, OptimizerType type) { + if (!Settings::Get(context_p)) { + // all optimizes are disabled + return true; + } auto &config = DBConfig::GetConfig(context_p); return config.options.disabled_optimizers.find(type) != config.options.disabled_optimizers.end(); } @@ -107,9 +118,10 @@ void Optimizer::RunOptimizer(OptimizerType type, const std::function &ca return; } auto &profiler = QueryProfiler::Get(context); - profiler.StartPhase(MetricsUtils::GetOptimizerMetricByType(type)); - callback(); - profiler.EndPhase(); + { + auto optimizer_timer = profiler.StartTimerInternal("optimizer." 
+ StringUtil::Lower(EnumUtil::ToString(type))); + callback(); + } if (plan) { Verify(*plan); } @@ -142,6 +154,19 @@ static bool CTEContainsDML(const LogicalOperator &op) { return false; } +void Optimizer::OptimizeStatement(unique_ptr &statement) { + if (!Settings::Get(context)) { + return; + } + if (DatabaseManager::Get(context).GetRemoteCatalogCount() > 0) { + // if we have any remote catalogs attached then pushdown into remote + RunOptimizer(OptimizerType::REMOTE_PUSHDOWN, [&]() { + RemotePushdownOptimizer optimizer(binder); + optimizer.Rewrite(statement); + }); + } +} + void Optimizer::RunBuiltInOptimizers() { switch (plan->type) { case LogicalOperatorType::LOGICAL_TRANSACTION: @@ -248,6 +273,13 @@ void Optimizer::RunBuiltInOptimizers() { plan = optimizer.Optimize(std::move(plan)); }); + // Pre-aggregate SUM/COUNT below joins when GROUP BY is on dimension columns + // and the fact side dominates cardinality — reduces probe-side work. + RunOptimizer(OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN, [&]() { + PartialAggregatePushdown partial_aggregate_pushdown(*this); + partial_aggregate_pushdown.VisitOperator(plan); + }); + RunOptimizer(OptimizerType::JOIN_ELIMINATION, [&]() { JoinElimination join_elimination; plan = join_elimination.Optimize(std::move(plan)); @@ -387,6 +419,9 @@ void Optimizer::RunBuiltInOptimizers() { } unique_ptr Optimizer::Optimize(unique_ptr plan_p) { + if (!Settings::Get(context)) { + return plan_p; + } Verify(*plan_p); this->plan = std::move(plan_p); @@ -416,13 +451,13 @@ unique_ptr Optimizer::Optimize(unique_ptr plan return std::move(plan); } -unique_ptr Optimizer::BindScalarFunction(const string &name, unique_ptr c1) { +unique_ptr Optimizer::BindScalarFunction(const Identifier &name, unique_ptr c1) { vector> children; children.push_back(std::move(c1)); return BindScalarFunction(name, std::move(children)); } -unique_ptr Optimizer::BindScalarFunction(const string &name, unique_ptr c1, +unique_ptr Optimizer::BindScalarFunction(const Identifier 
&name, unique_ptr c1, unique_ptr c2) { vector> children; children.push_back(std::move(c1)); @@ -430,7 +465,7 @@ unique_ptr Optimizer::BindScalarFunction(const string &name, unique_ return BindScalarFunction(name, std::move(children)); } -unique_ptr Optimizer::BindScalarFunction(const string &name, unique_ptr c1, +unique_ptr Optimizer::BindScalarFunction(const Identifier &name, unique_ptr c1, unique_ptr c2, unique_ptr c3) { vector> children; children.push_back(std::move(c1)); @@ -439,10 +474,10 @@ unique_ptr Optimizer::BindScalarFunction(const string &name, unique_ return BindScalarFunction(name, std::move(children)); } -unique_ptr Optimizer::BindScalarFunction(const string &name, vector> children) { +unique_ptr Optimizer::BindScalarFunction(const Identifier &name, vector> children) { FunctionBinder binder(context); ErrorData error; - auto expr = binder.BindScalarFunction(DEFAULT_SCHEMA, name, std::move(children), error); + auto expr = binder.BindScalarFunction(Identifier::DefaultSchema(), name, std::move(children), error); if (error.HasError()) { throw InternalException("Optimizer exception - failed to bind function %s: %s", name, error.Message()); } diff --git a/src/duckdb/src/optimizer/outer_join_simplification.cpp b/src/duckdb/src/optimizer/outer_join_simplification.cpp index 555731aad..e07921b55 100644 --- a/src/duckdb/src/optimizer/outer_join_simplification.cpp +++ b/src/duckdb/src/optimizer/outer_join_simplification.cpp @@ -16,7 +16,7 @@ void OuterJoinSimplification::HandleExpression(const Expression &expr) { return; } auto &colref = expr.Cast(); - null_filtered_columns.insert(colref.binding); + null_filtered_columns.insert(colref.Binding()); } void OuterJoinSimplification::VisitOperator(LogicalOperator &op) { @@ -96,7 +96,7 @@ void OuterJoinSimplification::VisitOperator(LogicalOperator &op) { if (null_filtered_columns.find(binding) == null_filtered_columns.end()) { continue; } - null_filtered_columns.insert(expr.Cast().binding); + 
null_filtered_columns.insert(expr.Cast().Binding()); } VisitOperatorChildren(op); return; @@ -109,7 +109,7 @@ void OuterJoinSimplification::VisitOperator(LogicalOperator &op) { if (expr->GetExpressionClass() == ExpressionClass::BOUND_OPERATOR && expr->GetExpressionType() == ExpressionType::OPERATOR_IS_NOT_NULL) { const auto &is_not_null = expr->Cast(); - HandleExpression(*is_not_null.children[0]); + HandleExpression(*is_not_null.GetChildren()[0]); } else if (BoundComparisonExpression::IsComparison(*expr)) { if (expr->GetExpressionType() == ExpressionType::COMPARE_DISTINCT_FROM || expr->GetExpressionType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { diff --git a/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp new file mode 100644 index 000000000..fb616240b --- /dev/null +++ b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp @@ -0,0 +1,487 @@ +#include "duckdb/optimizer/partial_aggregate_pushdown.hpp" + +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/function/scalar/generic_common.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +namespace duckdb { + +PartialAggregatePushdown::PartialAggregatePushdown(Optimizer &optimizer_p) : optimizer(optimizer_p) { +} + +struct PartialAggregatePushdownHeuristics { + static constexpr idx_t MIN_DIMENSION_GROUPS = 4; + static constexpr idx_t MIN_AGGREGATE_TO_DIMENSION_RATIO = 4; + static constexpr idx_t MAX_JOIN_SELECTIVITY_INV = 8; + static constexpr idx_t MAX_EXTRA_LOWER_GROUPS = 1; +}; + +struct 
PartialAggregatePushdownInfo { + idx_t aggregate_side; + idx_t dimension_side; + TableIndex lower_group_index; + TableIndex lower_aggregate_index; + idx_t join_key_count; + unordered_set side_bindings[2]; + vector lower_group_bindings; + column_binding_map_t lower_group_map; + column_binding_map_t lower_group_types; +}; + +static bool IsSubset(const unordered_set &bindings, const unordered_set &side_bindings) { + for (auto &binding : bindings) { + if (side_bindings.find(binding) == side_bindings.end()) { + return false; + } + } + return true; +} + +static unordered_set GetExpressionBindings(const Expression &expr) { + unordered_set bindings; + LogicalJoin::GetExpressionBindings(expr, bindings); + return bindings; +} + +static bool GetColumnBinding(const Expression &expr, ColumnBinding &binding) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { + return false; + } + binding = expr.Cast().Binding(); + return true; +} + +static bool IsSupportedAggregate(const BoundAggregateExpression &expr) { + if (expr.IsDistinct() || expr.GetFilter() || expr.GetOrderBys()) { + return false; + } + if (expr.GetChildren().size() != 1) { + return false; + } + if (!expr.Function().HasGetStateTypeCallback()) { + return false; + } + return true; +} + +static bool ContainsAggregateInput(const LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return true; + } + for (auto &child : op.children) { + if (ContainsAggregateInput(*child)) { + return true; + } + } + return false; +} + +static bool GetPushdownOperators(LogicalOperator &op, LogicalAggregate *&aggr, LogicalComparisonJoin *&join) { + if (op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY || op.children.size() != 1) { + return false; + } + aggr = &op.Cast(); + if (!aggr->grouping_functions.empty() || aggr->groups.empty() || aggr->expressions.empty()) { + return false; + } + auto &child = *op.children[0]; + if (child.type != 
LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return false; + } + join = &child.Cast(); + return join->join_type == JoinType::INNER && !join->HasProjectionMap() && join->children.size() == 2 && + !join->conditions.empty(); +} + +static bool GetExpressionSide(const Expression &expr, const PartialAggregatePushdownInfo &info, idx_t &side) { + auto bindings = GetExpressionBindings(expr); + if (bindings.empty()) { + return false; + } + if (IsSubset(bindings, info.side_bindings[0])) { + side = 0; + return true; + } + if (IsSubset(bindings, info.side_bindings[1])) { + side = 1; + return true; + } + return false; +} + +static bool FindAggregateSide(const LogicalAggregate &aggr, PartialAggregatePushdownInfo &info) { + optional_idx aggregate_side; + for (auto &expr : aggr.expressions) { + if (expr->GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE) { + return false; + } + auto &aggregate = expr->Cast(); + if (!IsSupportedAggregate(aggregate)) { + return false; + } + idx_t side; + if (!GetExpressionSide(*aggregate.GetChildren()[0], info, side)) { + return false; + } + if (!aggregate_side.IsValid()) { + aggregate_side = side; + } else if (aggregate_side.GetIndex() != side) { + return false; + } + } + if (!aggregate_side.IsValid()) { + return false; + } + info.aggregate_side = aggregate_side.GetIndex(); + info.dimension_side = 1 - info.aggregate_side; + return true; +} + +static bool PassesCardinalityHeuristic(const LogicalComparisonJoin &join, const PartialAggregatePushdownInfo &info) { + auto &aggregate_child = *join.children[info.aggregate_side]; + auto &dimension_child = *join.children[info.dimension_side]; + if (!aggregate_child.has_estimated_cardinality || !dimension_child.has_estimated_cardinality) { + return true; + } + if (aggregate_child.estimated_cardinality < + PartialAggregatePushdownHeuristics::MIN_AGGREGATE_TO_DIMENSION_RATIO * dimension_child.estimated_cardinality) { + return false; + } + if (join.has_estimated_cardinality && + 
join.estimated_cardinality * PartialAggregatePushdownHeuristics::MAX_JOIN_SELECTIVITY_INV < + aggregate_child.estimated_cardinality) { + return false; + } + return true; +} + +static bool GetJoinSideExpressions(JoinCondition &condition, const PartialAggregatePushdownInfo &info, + unique_ptr *&aggregate_expr, unique_ptr *&dimension_expr) { + if (!condition.IsComparison() || condition.GetComparisonType() != ExpressionType::COMPARE_EQUAL) { + return false; + } + auto left_bindings = GetExpressionBindings(condition.GetLHS()); + auto right_bindings = GetExpressionBindings(condition.GetRHS()); + if (left_bindings.empty() || right_bindings.empty()) { + return false; + } + if (IsSubset(left_bindings, info.side_bindings[info.aggregate_side]) && + IsSubset(right_bindings, info.side_bindings[info.dimension_side])) { + aggregate_expr = &condition.LeftReference(); + dimension_expr = &condition.RightReference(); + return true; + } + if (IsSubset(right_bindings, info.side_bindings[info.aggregate_side]) && + IsSubset(left_bindings, info.side_bindings[info.dimension_side])) { + aggregate_expr = &condition.RightReference(); + dimension_expr = &condition.LeftReference(); + return true; + } + return false; +} + +static bool ValidateJoinConditions(LogicalComparisonJoin &join, const PartialAggregatePushdownInfo &info) { + for (auto &condition : join.conditions) { + unique_ptr *aggregate_expr; + unique_ptr *dimension_expr; + if (!GetJoinSideExpressions(condition, info, aggregate_expr, dimension_expr)) { + return false; + } + ColumnBinding join_key; + if (!GetColumnBinding(**aggregate_expr, join_key)) { + return false; + } + } + return true; +} + +static bool HasWideDimensionGroups(const LogicalAggregate &aggr, const PartialAggregatePushdownInfo &info) { + idx_t dimension_group_count = 0; + for (auto &group : aggr.groups) { + ColumnBinding group_binding; + if (!GetColumnBinding(*group, group_binding)) { + return false; + } + idx_t side; + if (!GetExpressionSide(*group, info, side)) { + 
return false; + } + if (side == info.dimension_side) { + dimension_group_count++; + } + } + return dimension_group_count >= PartialAggregatePushdownHeuristics::MIN_DIMENSION_GROUPS; +} + +static bool AnalyzePushdown(LogicalAggregate &aggr, LogicalComparisonJoin &join, PartialAggregatePushdownInfo &info) { + LogicalJoin::GetTableReferences(*join.children[0], info.side_bindings[0]); + LogicalJoin::GetTableReferences(*join.children[1], info.side_bindings[1]); + if (!FindAggregateSide(aggr, info)) { + return false; + } + if (ContainsAggregateInput(*join.children[info.aggregate_side])) { + return false; + } + if (info.side_bindings[info.dimension_side].size() != 1) { + return false; + } + if (!PassesCardinalityHeuristic(join, info)) { + return false; + } + if (!ValidateJoinConditions(join, info)) { + return false; + } + return HasWideDimensionGroups(aggr, info); +} + +static void AddLowerGroup(PartialAggregatePushdownInfo &info, ColumnBinding binding, const LogicalType &type) { + if (info.lower_group_map.find(binding) != info.lower_group_map.end()) { + return; + } + auto lower_binding = ColumnBinding(info.lower_group_index, ProjectionIndex(info.lower_group_bindings.size())); + info.lower_group_bindings.push_back(binding); + info.lower_group_map[binding] = lower_binding; + info.lower_group_types[binding] = type; +} + +static void BuildLowerGroupMap(LogicalAggregate &aggr, LogicalComparisonJoin &join, + PartialAggregatePushdownInfo &info) { + for (auto &condition : join.conditions) { + unique_ptr *aggregate_expr; + unique_ptr *dimension_expr; + GetJoinSideExpressions(condition, info, aggregate_expr, dimension_expr); + ColumnBinding binding; + GetColumnBinding(**aggregate_expr, binding); + AddLowerGroup(info, binding, (*aggregate_expr)->GetReturnType()); + } + info.join_key_count = info.lower_group_bindings.size(); + for (auto &group : aggr.groups) { + auto &group_ref = group->Cast(); + if (info.lower_group_map.find(group_ref.Binding()) != info.lower_group_map.end()) { + 
continue; + } + idx_t side; + GetExpressionSide(*group, info, side); + if (side == info.aggregate_side) { + AddLowerGroup(info, group_ref.Binding(), group->GetReturnType()); + } + } +} + +static bool PassesLowerGroupHeuristic(const PartialAggregatePushdownInfo &info) { + return info.lower_group_bindings.size() <= + info.join_key_count + PartialAggregatePushdownHeuristics::MAX_EXTRA_LOWER_GROUPS; +} + +static bool BindPushdownAggregates(ClientContext &context, LogicalAggregate &aggr, TableIndex lower_aggregate_index, + vector> &lower_aggregates, + vector> &upper_aggregates) { + auto combine_function = CombineAggrFun::GetFunction(); + FunctionBinder function_binder(context); + + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + auto aggregate_copy = unique_ptr_cast(aggr.expressions[i]->Copy()); + auto lower_aggregate = ExportAggregateFunction::Bind(std::move(aggregate_copy)); + auto lower_type = lower_aggregate->GetReturnType(); + if (!lower_type.IsAggregateState()) { + return false; + } + + vector> arguments; + auto lower_binding = ColumnBinding(lower_aggregate_index, ProjectionIndex(i)); + arguments.push_back(make_uniq(lower_type, lower_binding)); + auto upper_aggregate = function_binder.BindAggregateFunction(combine_function, std::move(arguments)); + if (!upper_aggregate->GetReturnType().IsAggregateState()) { + return false; + } + lower_aggregates.push_back(std::move(lower_aggregate)); + upper_aggregates.push_back(std::move(upper_aggregate)); + } + return true; +} + +static vector> CreateLowerGroups(const PartialAggregatePushdownInfo &info) { + vector> lower_groups; + for (auto &binding : info.lower_group_bindings) { + auto type = info.lower_group_types.at(binding); + lower_groups.push_back(make_uniq(type, binding)); + } + return lower_groups; +} + +static unique_ptr CreateLowerAggregate(LogicalAggregate &aggr, LogicalComparisonJoin &join, + PartialAggregatePushdownInfo &info, + vector> lower_aggregates) { + auto lower_aggr = + 
make_uniq(info.lower_group_index, info.lower_aggregate_index, std::move(lower_aggregates)); + lower_aggr->groups = CreateLowerGroups(info); + lower_aggr->children.push_back(std::move(join.children[info.aggregate_side])); + lower_aggr->ResolveOperatorTypes(); + lower_aggr->estimated_cardinality = aggr.estimated_cardinality; + lower_aggr->has_estimated_cardinality = aggr.has_estimated_cardinality; + return lower_aggr; +} + +static unique_ptr CreateJoin(LogicalComparisonJoin &join, PartialAggregatePushdownInfo &info, + unique_ptr lower_aggr) { + auto new_join = make_uniq(JoinType::INNER); + if (info.aggregate_side == 0) { + new_join->children.push_back(std::move(lower_aggr)); + new_join->children.push_back(std::move(join.children[info.dimension_side])); + } else { + new_join->children.push_back(std::move(join.children[info.dimension_side])); + new_join->children.push_back(std::move(lower_aggr)); + } + + for (auto &condition : join.conditions) { + unique_ptr *aggregate_expr; + unique_ptr *dimension_expr; + GetJoinSideExpressions(condition, info, aggregate_expr, dimension_expr); + ColumnBinding join_key; + GetColumnBinding(**aggregate_expr, join_key); + auto lower_binding = info.lower_group_map[join_key]; + auto lower_type = new_join->children[info.aggregate_side]->types[lower_binding.column_index.GetIndex()]; + auto lower_expr = make_uniq(lower_type, lower_binding); + if (info.aggregate_side == 0) { + new_join->conditions.emplace_back(std::move(lower_expr), (*dimension_expr)->Copy(), + ExpressionType::COMPARE_EQUAL); + } else { + new_join->conditions.emplace_back((*dimension_expr)->Copy(), std::move(lower_expr), + ExpressionType::COMPARE_EQUAL); + } + } + new_join->ResolveOperatorTypes(); + new_join->estimated_cardinality = join.estimated_cardinality; + new_join->has_estimated_cardinality = join.has_estimated_cardinality; + return new_join; +} + +static vector> CreateUpperGroups(LogicalAggregate &aggr, LogicalComparisonJoin &new_join, + const 
PartialAggregatePushdownInfo &info) { + vector> upper_groups; + for (auto &group : aggr.groups) { + auto &group_ref = group->Cast(); + auto entry = info.lower_group_map.find(group_ref.Binding()); + if (entry == info.lower_group_map.end()) { + upper_groups.push_back(group->Copy()); + continue; + } + auto type = new_join.children[info.aggregate_side]->types[entry->second.column_index.GetIndex()]; + upper_groups.push_back(make_uniq(type, entry->second)); + } + return upper_groups; +} + +static unique_ptr CreateUpperAggregate(LogicalAggregate &aggr, + unique_ptr new_join, + const PartialAggregatePushdownInfo &info, + vector> upper_aggregates) { + auto upper_aggr = make_uniq(aggr.group_index, aggr.aggregate_index, std::move(upper_aggregates)); + upper_aggr->groups = CreateUpperGroups(aggr, *new_join, info); + // Preserve ROLLUP / CUBE / GROUPING SETS. CreateUpperGroups builds the + // upper groups in the same order as `aggr.groups`, so the indices stored + // in grouping_sets remain valid. Without this copy the upper aggregate + // collapses to a single set and silently over-aggregates. 
+ upper_aggr->grouping_sets = aggr.grouping_sets; + upper_aggr->grouping_functions = aggr.grouping_functions; + upper_aggr->children.push_back(std::move(new_join)); + upper_aggr->ResolveOperatorTypes(); + upper_aggr->estimated_cardinality = aggr.estimated_cardinality; + upper_aggr->has_estimated_cardinality = aggr.has_estimated_cardinality; + return upper_aggr; +} + +static unique_ptr CreateFinalProjection(Optimizer &optimizer, LogicalAggregate &aggr, + unique_ptr upper_aggr, + column_binding_map_t &replacement_map) { + const auto proj_index = optimizer.binder.GenerateTableIndex(); + const auto group_count = aggr.groups.size(); + vector> projection_expressions; + projection_expressions.reserve(group_count + aggr.expressions.size()); + + for (idx_t group_idx = 0; group_idx < group_count; group_idx++) { + auto group_binding = ColumnBinding(upper_aggr->group_index, ProjectionIndex(group_idx)); + replacement_map[group_binding] = ColumnBinding(proj_index, ProjectionIndex(group_idx)); + projection_expressions.push_back( + make_uniq(upper_aggr->types[group_idx], group_binding)); + } + + for (idx_t aggr_idx = 0; aggr_idx < aggr.expressions.size(); aggr_idx++) { + auto aggregate_binding = ColumnBinding(upper_aggr->aggregate_index, ProjectionIndex(aggr_idx)); + replacement_map[aggregate_binding] = ColumnBinding(proj_index, ProjectionIndex(group_count + aggr_idx)); + auto aggregate_type = upper_aggr->types[group_count + aggr_idx]; + auto aggregate_ref = make_uniq(aggregate_type, aggregate_binding); + auto final_expression = optimizer.BindScalarFunction("finalize", std::move(aggregate_ref)); + if (final_expression->GetReturnType() != aggr.expressions[aggr_idx]->GetReturnType()) { + return nullptr; + } + projection_expressions.push_back(std::move(final_expression)); + } + + auto projection = make_uniq(proj_index, std::move(projection_expressions)); + if (upper_aggr->has_estimated_cardinality) { + projection->SetEstimatedCardinality(upper_aggr->estimated_cardinality); + } + 
projection->children.push_back(std::move(upper_aggr)); + projection->ResolveOperatorTypes(); + return projection; +} + +void PartialAggregatePushdown::VisitOperator(unique_ptr &op) { + LogicalOperatorVisitor::VisitOperator(op); + TryPushdownAggregate(op); +} + +unique_ptr PartialAggregatePushdown::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + auto entry = replacement_map.find(expr.Binding()); + if (entry != replacement_map.end()) { + expr.BindingMutable() = entry->second; + } + return nullptr; +} + +bool PartialAggregatePushdown::TryPushdownAggregate(unique_ptr &op) { + LogicalAggregate *aggr; + LogicalComparisonJoin *join; + if (!GetPushdownOperators(*op, aggr, join)) { + return false; + } + + PartialAggregatePushdownInfo info; + if (!AnalyzePushdown(*aggr, *join, info)) { + return false; + } + info.lower_group_index = optimizer.binder.GenerateTableIndex(); + info.lower_aggregate_index = optimizer.binder.GenerateTableIndex(); + BuildLowerGroupMap(*aggr, *join, info); + if (!PassesLowerGroupHeuristic(info)) { + return false; + } + + vector> lower_aggregates; + vector> upper_aggregates; + if (!BindPushdownAggregates(optimizer.context, *aggr, info.lower_aggregate_index, lower_aggregates, + upper_aggregates)) { + return false; + } + + auto lower_aggr = CreateLowerAggregate(*aggr, *join, info, std::move(lower_aggregates)); + auto new_join = CreateJoin(*join, info, std::move(lower_aggr)); + auto upper_aggr = CreateUpperAggregate(*aggr, std::move(new_join), info, std::move(upper_aggregates)); + auto final_projection = CreateFinalProjection(optimizer, *aggr, std::move(upper_aggr), replacement_map); + if (!final_projection) { + return false; + } + op = std::move(final_projection); + return true; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/partitioned_execution.cpp b/src/duckdb/src/optimizer/partitioned_execution.cpp index a07e2ba1c..d05f91078 100644 --- a/src/duckdb/src/optimizer/partitioned_execution.cpp +++ 
b/src/duckdb/src/optimizer/partitioned_execution.cpp @@ -32,7 +32,7 @@ struct PartitionedExecutionConfig { struct PartitionedExecutionColumn { explicit PartitionedExecutionColumn(idx_t original_idx_p, Expression &expr, OrderType order_type_p = OrderType::ASCENDING) - : original_idx(original_idx_p), column_binding(expr.Cast().binding), + : original_idx(original_idx_p), column_binding(expr.Cast().Binding()), order_type(order_type_p) { } @@ -74,7 +74,7 @@ static bool PartitionedExecutionGetColumns(LogicalOperator &op, vectorGetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { columns.emplace_back(columns.size(), *partition); } @@ -115,7 +115,7 @@ static optional_ptr PartitionedExecutionTraceColumns(LogicalOperator for (auto it = columns.begin(); it != columns.end();) { auto &expr = *proj.expressions[it->column_binding.column_index.GetIndex()]; if (expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { - it->column_binding = expr.Cast().binding; + it->column_binding = expr.Cast().Binding(); it++; } else { it = PartitionedExecutionHandleColumnRemoval(op.type, columns, it); diff --git a/src/duckdb/src/optimizer/projection_pullup.cpp b/src/duckdb/src/optimizer/projection_pullup.cpp index 9c47804a9..fdea0831a 100644 --- a/src/duckdb/src/optimizer/projection_pullup.cpp +++ b/src/duckdb/src/optimizer/projection_pullup.cpp @@ -80,10 +80,19 @@ void ProjectionPullup::InsertProjectionBelowOp(unique_ptr &op, void ProjectionPullup::PullUpColrefProjection(unique_ptr &op, LogicalProjection &proj, vector &proj_bindings) { + // LOGICAL_DISTINCT sets `everything_referenced = true` in RemoveUnusedColumns + // for its subtree. The projection above it acts as a binding barrier; removing + // it lets upstream references point past DISTINCT and breaks column pruning + // down to READ_PARQUET. Repro: TPC-DS Q54 regresses 4-5x without this guard. 
+ if (proj.children[0]->type == LogicalOperatorType::LOGICAL_DISTINCT) { + ProjectionPullup next(optimizer, root); + next.Optimize(proj.children[0]); + return; + } ColumnBindingReplacer replacer; for (idx_t i = 0; i < proj.expressions.size(); i++) { auto &colref = proj.expressions[i]->Cast(); - replacer.replacement_bindings.emplace_back(proj_bindings[i], colref.binding); + replacer.replacement_bindings.emplace_back(proj_bindings[i], colref.Binding()); } replacer.stop_operator = proj.children[0]; @@ -131,7 +140,7 @@ void ProjectionPullup::PullUpNonColrefProjection(unique_ptr &op for (idx_t i = 0; i < proj.expressions.size(); i++) { if (proj.expressions[i]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &colref = proj.expressions[i]->Cast(); - replacer.replacement_bindings.emplace_back(proj_bindings[i], colref.binding); + replacer.replacement_bindings.emplace_back(proj_bindings[i], colref.Binding()); } } for (idx_t i = 0; i < pull_up_to_here; i++) { @@ -191,7 +200,7 @@ void ProjectionPullup::CanPullThrough(column_binding_map_tCast(); - auto entry = projection_map.find(colref.binding); + auto entry = projection_map.find(colref.Binding()); if (entry == projection_map.end()) { return; diff --git a/src/duckdb/src/optimizer/pullup/pullup_projection.cpp b/src/duckdb/src/optimizer/pullup/pullup_projection.cpp index 7d9e28088..c69e5535d 100644 --- a/src/duckdb/src/optimizer/pullup/pullup_projection.cpp +++ b/src/duckdb/src/optimizer/pullup/pullup_projection.cpp @@ -27,8 +27,8 @@ static void ReplaceExpressionBinding(vector> &proj_expres auto &proj_expr = *proj_expressions[proj_idx]; if (proj_expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { if (colref.Equals(proj_expr)) { - colref.binding.table_index = proj_table_idx; - colref.binding.column_index = ProjectionIndex(proj_idx); + colref.BindingMutable().table_index = proj_table_idx; + colref.BindingMutable().column_index = ProjectionIndex(proj_idx); found_proj_col = true; break; } @@ -37,8 +37,9 
@@ static void ReplaceExpressionBinding(vector> &proj_expres if (!found_proj_col) { // Project a new column auto new_colref = colref.Copy(); - colref.binding.table_index = proj_table_idx; - colref.binding.column_index = ColumnBinding::PushExpression(proj_expressions, std::move(new_colref)); + colref.BindingMutable().table_index = proj_table_idx; + colref.BindingMutable().column_index = + ColumnBinding::PushExpression(proj_expressions, std::move(new_colref)); } } ExpressionIterator::EnumerateChildren( diff --git a/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp b/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp index 00f642ca0..83515d4c9 100644 --- a/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp +++ b/src/duckdb/src/optimizer/pullup/pullup_set_operation.cpp @@ -8,9 +8,9 @@ namespace duckdb { static void ReplaceFilterTableIndex(Expression &expr, LogicalSetOperation &setop) { if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &colref = expr.Cast(); - D_ASSERT(colref.depth == 0); + D_ASSERT(colref.Depth() == 0); - colref.binding.table_index = setop.table_index; + colref.BindingMutable().table_index = setop.table_index; return; } ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { ReplaceFilterTableIndex(child, setop); }); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp index 79968c7a3..84affa038 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_aggregate.cpp @@ -12,16 +12,16 @@ using Filter = FilterPushdown::Filter; static unique_ptr ReplaceGroupBindings(LogicalAggregate &aggr, unique_ptr root_expr) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &colref, unique_ptr &expr) { - D_ASSERT(colref.depth == 0); + D_ASSERT(colref.Depth() == 0); // replace the binding with a copy to the expression at the referenced index - 
expr = aggr.GetExpression(colref.binding).Copy(); + expr = aggr.GetExpression(colref.Binding()).Copy(); }); return root_expr; } void FilterPushdown::ExtractFilterBindings(const Expression &expr, vector &bindings) { ExpressionIterator::VisitExpression( - expr, [&](const BoundColumnRefExpression &colref) { bindings.push_back(colref.binding); }); + expr, [&](const BoundColumnRefExpression &colref) { bindings.push_back(colref.Binding()); }); } unique_ptr FilterPushdown::PushdownAggregate(unique_ptr op) { diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp index f129bc850..8412dbbb0 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp @@ -1,6 +1,8 @@ #include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/in_clause_rewriter.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" @@ -77,10 +79,21 @@ unique_ptr FilterPushdown::PushdownGet(unique_ptr(); + if (!in_expr.GetChildren().empty() && + in_expr.GetChildren()[0]->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF && + in_expr.GetChildren().size() - 1 >= InClauseRewriter::IN_CLAUSE_REWRITE_THRESHOLD) { + continue; + } + } // Allow pushing down filters that can throw only if there is a single expression - // For now, do not push down single expressions with IN either. 
Later we can change InClauseRewriter to handle - // this case - if (expr.CanThrow() && (expr.GetExpressionType() == ExpressionType::COMPARE_IN || filters.size() > 1)) { + if (expr.CanThrow() && filters.size() > 1) { continue; } pushdown_result = combiner.TryPushdownGenericExpression(get, expr); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp index df18c8eba..e93d1ccf8 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp @@ -34,7 +34,7 @@ static unique_ptr ReplaceColRefWithNull(unique_ptr root_ unordered_set &right_bindings) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &bound_colref, unique_ptr &expr) { - if (right_bindings.find(bound_colref.binding.table_index) != right_bindings.end()) { + if (right_bindings.find(bound_colref.Binding().table_index) != right_bindings.end()) { // bound colref belongs to RHS // replace it with a constant NULL expr = make_uniq(Value(expr->GetReturnType())); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp index 6021695de..fc8c34cf3 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp @@ -53,7 +53,7 @@ unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptrfilter->GetExpressionType() == ExpressionType::OPERATOR_NOT) { auto &op_expr = filters[i]->filter->Cast(); - if (op_expr.children[0]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + if (op_expr.GetChildren()[0]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { // the filter is NOT(marker), check the join conditions bool all_null_values_are_equal = true; for (auto &cond : comp_join.conditions) { diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp 
b/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp index c285f86d4..c0f45f441 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_projection.cpp @@ -11,8 +11,8 @@ namespace duckdb { static bool IsVolatile(LogicalProjection &proj, const Expression &expr) { bool is_volatile = false; ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { - D_ASSERT(colref.depth == 0); - auto &proj_expr = proj.GetExpression(colref.binding); + D_ASSERT(colref.Depth() == 0); + auto &proj_expr = proj.GetExpression(colref.Binding()); if (proj_expr.IsVolatile()) { is_volatile = true; } @@ -23,9 +23,9 @@ static bool IsVolatile(LogicalProjection &proj, const Expression &expr) { static unique_ptr ReplaceProjectionBindings(LogicalProjection &proj, unique_ptr root_expr) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &colref, unique_ptr &expr) { - D_ASSERT(colref.depth == 0); + D_ASSERT(colref.Depth() == 0); // replace the binding with a copy to the expression at the referenced index - auto &proj_expr = proj.GetExpression(colref.binding); + auto &proj_expr = proj.GetExpression(colref.Binding()); auto copy = proj_expr.Copy(); if (!colref.GetAlias().empty()) { copy->SetAlias(colref.GetAlias()); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp index e56abdcd2..916275330 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_set_operation.cpp @@ -15,12 +15,12 @@ static void ReplaceSetOpBindings(vector &bindings, Filter &filter LogicalSetOperation &setop) { ExpressionIterator::VisitExpressionMutable( root_expr, [&](BoundColumnRefExpression &colref, unique_ptr &expr) { - D_ASSERT(colref.binding.table_index == setop.table_index); - D_ASSERT(colref.depth == 0); + D_ASSERT(colref.Binding().table_index == 
setop.table_index); + D_ASSERT(colref.Depth() == 0); // rewrite the binding by looking into the bound_tables list of the subquery - colref.binding = bindings[colref.binding.column_index]; - filter.bindings.insert(colref.binding.table_index); + colref.BindingMutable() = bindings[colref.Binding().column_index]; + filter.bindings.insert(colref.Binding().table_index); }); } diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_window.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_window.cpp index 5e8442532..97cb907e8 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_window.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_window.cpp @@ -43,7 +43,7 @@ unique_ptr FilterPushdown::PushdownWindow(unique_ptrCast(); - auto &partitions = window_expr.partitions; + auto &partitions = window_expr.Partitions(); if (partitions.empty()) { // If any window expression does not have partitions, we cannot push any filters. // all window expressions need to be partitioned by the same column @@ -57,7 +57,7 @@ unique_ptr FilterPushdown::PushdownWindow(unique_ptrCast(); - partition_bindings.insert(partition_col.binding); + partition_bindings.insert(partition_col.Binding()); break; } default: diff --git a/src/duckdb/src/optimizer/regex_range_filter.cpp b/src/duckdb/src/optimizer/regex_range_filter.cpp index ba8d10300..3e0ca04da 100644 --- a/src/duckdb/src/optimizer/regex_range_filter.cpp +++ b/src/duckdb/src/optimizer/regex_range_filter.cpp @@ -29,19 +29,19 @@ unique_ptr RegexRangeFilter::Rewrite(unique_ptrexpressions) { if (expr->GetExpressionType() == ExpressionType::BOUND_FUNCTION) { auto &func = expr->Cast(); - if (func.function.GetName() != "regexp_full_match" || func.children.size() != 2) { + if (func.Function().GetName() != "regexp_full_match" || func.GetChildren().size() != 2) { continue; } - auto &info = func.bind_info->Cast(); + auto &info = func.BindInfo()->Cast(); if (!info.range_success) { continue; } auto filter_left = BoundComparisonExpression::Create( - 
ExpressionType::COMPARE_GREATERTHANOREQUALTO, func.children[0]->Copy(), + ExpressionType::COMPARE_GREATERTHANOREQUALTO, func.GetChildren()[0]->Copy(), make_uniq(Value::BLOB_RAW(info.range_min))); - auto filter_right = - BoundComparisonExpression::Create(ExpressionType::COMPARE_LESSTHANOREQUALTO, func.children[0]->Copy(), - make_uniq(Value::BLOB_RAW(info.range_max))); + auto filter_right = BoundComparisonExpression::Create( + ExpressionType::COMPARE_LESSTHANOREQUALTO, func.GetChildren()[0]->Copy(), + make_uniq(Value::BLOB_RAW(info.range_max))); auto filter_expr = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(filter_left), std::move(filter_right)); diff --git a/src/duckdb/src/optimizer/remote_pushdown_optimizer.cpp b/src/duckdb/src/optimizer/remote_pushdown_optimizer.cpp new file mode 100644 index 000000000..ae0ab7a89 --- /dev/null +++ b/src/duckdb/src/optimizer/remote_pushdown_optimizer.cpp @@ -0,0 +1,1505 @@ +#include "duckdb/optimizer/remote_pushdown_optimizer.hpp" + +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/catalog/entry_lookup_info.hpp" +#include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/delete_query_node.hpp" +#include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/update_query_node.hpp" +#include 
"duckdb/parser/result_modifier.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/expressionlistref.hpp" +#include "duckdb/parser/tableref/joinref.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/type_expression.hpp" +#include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb/common/extra_type_info.hpp" +#include "duckdb/function/table_function.hpp" +#include "duckdb/parser/query_node/recursive_cte_node.hpp" +#include "duckdb/planner/expression_binder/constant_binder.hpp" +#include "duckdb/execution/expression_executor.hpp" + +namespace duckdb { + +CatalogPushdownResult::CatalogPushdownResult(CatalogReferenceType reference_type_p) : reference_type(reference_type_p) { +} + +CatalogPushdownResult CatalogPushdownResult::Unknown() { + return CatalogPushdownResult(CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE); +} + +CatalogPushdownResult CatalogPushdownResult::NoCatalogReference() { + return CatalogPushdownResult(CatalogReferenceType::NO_CATALOG_REFERENCED); +} + +CatalogPushdownResult CatalogPushdownResult::RemoteReference(Catalog &catalog) { + CatalogPushdownResult result(CatalogReferenceType::SINGLE_REMOTE_CATALOG); + result.catalog = catalog; + return result; +} + +RemotePushdownOptimizer::RemotePushdownOptimizer(Binder 
&binder) + : binder(binder), owned_pushdown_state(make_uniq()), pushdown_state(*owned_pushdown_state) { +} + +RemotePushdownOptimizer::RemotePushdownOptimizer(optional_ptr parent_p) + : binder(parent_p->binder), parent(parent_p), pushdown_state(parent->pushdown_state) { + // inherit table / column names from parent (for correlated subquery detection) + local_table_names = parent->local_table_names; +} + +void RemotePushdownOptimizer::FindRemoteCatalogsInSearchPath() { + if (pushdown_state.search_path_initialized) { + return; + } + pushdown_state.search_path_initialized = true; + auto &client_data = ClientData::Get(binder.context); + // iterate over all catalogs mentioned in the search path and check if they are remote + auto search_path = client_data.catalog_search_path->Get(); + // Deduplicate by catalog name. + identifier_set_t seen_remote_catalogs; + for (auto &entry : search_path) { + auto catalog_entry = Catalog::GetCatalogEntry(binder.context, entry.catalog); + if (!catalog_entry) { + continue; + } + if (!catalog_entry->Supports(RemoteCapability::EXECUTE_QUERY_NODE)) { + pushdown_state.local_catalogs_in_search_path.push_back(entry); + } else { + if (seen_remote_catalogs.insert(catalog_entry->GetName()).second) { + pushdown_state.remote_catalogs_in_search_path.push_back(*catalog_entry); + } + } + } +} + +CatalogPushdownResult RemotePushdownOptimizer::Merge(CatalogPushdownResult a, CatalogPushdownResult b) { + if (a.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED && + b.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + // both sides refer to no catalog - result is no catalog reference, but with unified references + auto result = CatalogPushdownResult::NoCatalogReference(); + result.used_expressions.insert(result.used_expressions.end(), a.used_expressions.begin(), + a.used_expressions.end()); + result.used_expressions.insert(result.used_expressions.end(), b.used_expressions.begin(), + b.used_expressions.end()); + 
result.used_table_constructs.insert(result.used_table_constructs.end(), a.used_table_constructs.begin(), + a.used_table_constructs.end()); + result.used_table_constructs.insert(result.used_table_constructs.end(), b.used_table_constructs.begin(), + b.used_table_constructs.end()); + result.used_nodes.insert(result.used_nodes.end(), a.used_nodes.begin(), a.used_nodes.end()); + result.used_nodes.insert(result.used_nodes.end(), b.used_nodes.begin(), b.used_nodes.end()); + return result; + } + if (a.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG && + b.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + // swap "a" and "b" so the merge happens below + return Merge(b, a); + } + if (a.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED && + b.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + // "a" refers to no catalog, "b" refers to a single remote catalog + // check if "b" supports all constructs referenced in "a" + auto &remote_catalog = *b.catalog; + for (auto &expr : a.used_expressions) { + if (!remote_catalog.SupportsPushdown(expr.get())) { + // pushdown not supported - result is UNKNOWN_CATALOG_REFERENCE + return CatalogPushdownResult::Unknown(); + } + } + for (auto &table_construct : a.used_table_constructs) { + if (!remote_catalog.SupportsPushdown(table_construct.get())) { + // pushdown not supported - result is UNKNOWN_CATALOG_REFERENCE + return CatalogPushdownResult::Unknown(); + } + } + for (auto &query_node : a.used_nodes) { + if (!remote_catalog.SupportsPushdown(query_node.get())) { + // pushdown not supported - result is UNKNOWN_CATALOG_REFERENCE + return CatalogPushdownResult::Unknown(); + } + } + return b; + } + if (a.reference_type == CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE || + b.reference_type == CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE) { + return CatalogPushdownResult::Unknown(); + } + // Both are SINGLE_REMOTE_CATALOG - only valid if they refer to the same catalog + if (a.catalog 
== b.catalog) { + return a; + } + return CatalogPushdownResult::Unknown(); +} + +void RemotePushdownOptimizer::Rewrite(unique_ptr &statement) { + CatalogPushdownResult result; + switch (statement->type) { + case StatementType::SELECT_STATEMENT: + result = Rewrite(*statement->Cast().node); + break; + case StatementType::INSERT_STATEMENT: + result = Rewrite(*statement->Cast().node); + break; + case StatementType::DELETE_STATEMENT: + result = Rewrite(*statement->Cast().node); + break; + case StatementType::UPDATE_STATEMENT: + result = Rewrite(*statement->Cast().node); + break; + case StatementType::EXPLAIN_STATEMENT: + Rewrite(statement->Cast().stmt); + return; + default: + return; + } + FinishPushdown(statement, result); +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(QueryNode &node) { + if (!cte_results.empty()) { + throw InternalException( + "RemotePushdownOptimizer already has CTEs defined - this means no child was created correctly"); + } + for (auto &cte_pair : node.cte_map.map) { + const Identifier &cte_name = cte_pair.first; + auto &cte_info = *cte_pair.second; + CatalogPushdownResult cte_result; + if (cte_info.query_node) { + RemotePushdownOptimizer child_optimizer(this); + cte_result = child_optimizer.Rewrite(*cte_info.query_node); + } else { + cte_result = CatalogPushdownResult::Unknown(); + } + for (auto &key : cte_info.key_targets) { + cte_result = Merge(cte_result, Rewrite(key)); + } + cte_results[cte_name] = cte_result; + } + CatalogPushdownResult result; + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = RewriteNode(node.Cast()); + break; + case QueryNodeType::INSERT_QUERY_NODE: + result = RewriteNode(node.Cast()); + break; + case QueryNodeType::DELETE_QUERY_NODE: + result = RewriteNode(node.Cast()); + break; + case QueryNodeType::UPDATE_QUERY_NODE: + result = RewriteNode(node.Cast()); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = RewriteNode(node.Cast()); + break; + case 
QueryNodeType::RECURSIVE_CTE_NODE: + result = RewriteNode(node.Cast()); + break; + default: + return CatalogPushdownResult::Unknown(); + } + // Merge results of all CTEs defined in this scope + // FIXME: this is only necessary because we push all CTEs, including unreferenced ones, to the result + // if we pruned unreferenced CTEs we could remove this + for (auto &cte_pair : node.cte_map.map) { + auto it = cte_results.find(cte_pair.first); + if (it != cte_results.end()) { + result = Merge(result, it->second); + } + } + if (result.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + if (!result.catalog->SupportsPushdown(node)) { + // bail - referenced catalog does not support pushing down this node type + result = CatalogPushdownResult::Unknown(); + } + } else if (result.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + result.used_nodes.push_back(node); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteNode(RecursiveCTENode &node) { + RemotePushdownOptimizer left_optimizer(this); + CatalogPushdownResult left_result = left_optimizer.Rewrite(*node.left); + + // for recursive CTEs - the right-hand side of the CTE can refer to the recursive CTE itself + // we use whatever the CatalogPushdownResult of the LHS was to count this reference + RemotePushdownOptimizer recursive_optimizer(this); + recursive_optimizer.cte_results[node.ctename] = left_result; + + RemotePushdownOptimizer right_optimizer(&recursive_optimizer); + CatalogPushdownResult right_result = right_optimizer.Rewrite(*node.right); + + auto result = Merge(left_result, right_result); + for (auto &key : node.key_targets) { + result = Merge(result, Rewrite(key)); + } + for (auto &modifier : node.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + // ORDER BY entries cannot be constant-folded - a bare integer literal is a + // positional 
reference there + result = Merge(result, Rewrite(order.expression, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + result = Merge(result, Rewrite(limit_mod.limit)); + } + if (limit_mod.offset) { + result = Merge(result, Rewrite(limit_mod.offset)); + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + // DISTINCT ON entries cannot be constant-folded - a bare integer literal is a + // positional reference there + result = Merge(result, Rewrite(expr, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + } + break; + } + default: + break; + } + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteNode(SelectNode &node) { + auto from_result = CatalogPushdownResult::NoCatalogReference(); + if (node.from_table) { + from_result = Rewrite(node.from_table); + } + + // Merge from_table result with all expressions to determine if the whole node can be pushed + CatalogPushdownResult result = from_result; + for (auto &expr : node.select_list) { + result = Merge(result, Rewrite(expr)); + } + if (node.where_clause) { + result = Merge(result, Rewrite(node.where_clause)); + } + for (auto &expr : node.groups.group_expressions) { + // GROUP BY entries cannot be constant-folded - a bare integer literal is a + // positional reference there + result = Merge(result, Rewrite(expr, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + } + if (node.having) { + result = Merge(result, Rewrite(node.having)); + } + if (node.qualify) { + result = Merge(result, Rewrite(node.qualify)); + } + for (auto &modifier : node.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + // ORDER BY entries cannot be constant-folded - a bare integer literal is a 
+ // positional reference there + result = Merge(result, Rewrite(order.expression, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + result = Merge(result, Rewrite(limit_mod.limit)); + } + if (limit_mod.offset) { + result = Merge(result, Rewrite(limit_mod.offset)); + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + // DISTINCT ON entries cannot be constant-folded - a bare integer literal is a + // positional reference there + result = Merge(result, Rewrite(expr, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + } + break; + } + default: + break; + } + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteNode(InsertQueryNode &node) { + // first bind the target table for the insert + BaseTableRef target_ref; + target_ref.catalog_name = node.catalog; + target_ref.schema_name = node.schema; + target_ref.table_name = node.table; + + RemotePushdownOptimizer target_optimizer(this); + auto result = target_optimizer.Rewrite(target_ref); + if (node.select_statement) { + RemotePushdownOptimizer select_optimizer(this); + auto select_result = select_optimizer.Rewrite(*node.select_statement->node); + result = Merge(result, select_result); + if (select_result.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + bool push_select_only = result.reference_type != CatalogReferenceType::SINGLE_REMOTE_CATALOG; + if (!push_select_only && !result.catalog->SupportsPushdown(node)) { + // the catalog cannot execute the INSERT itself remotely - push down only the SELECT part + push_select_only = true; + } + if (push_select_only) { + FinishPushdown(node.select_statement->node, select_result); + } + } + } + if (node.on_conflict_info) { + if (node.on_conflict_info->condition) { + auto condition_result = 
Rewrite(node.on_conflict_info->condition); + result = Merge(result, condition_result); + } + if (node.on_conflict_info->set_info) { + if (node.on_conflict_info->set_info->condition) { + auto condition_result = Rewrite(node.on_conflict_info->set_info->condition); + result = Merge(result, condition_result); + } + for (auto &expr : node.on_conflict_info->set_info->expressions) { + auto expr_result = Rewrite(expr); + result = Merge(result, expr_result); + } + } + } + for (auto &expr : node.returning_list) { + auto expr_result = Rewrite(expr); + result = Merge(result, expr_result); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteNode(DeleteQueryNode &node) { + auto result = Rewrite(node.table); + vector using_results; + for (auto &using_clause : node.using_clauses) { + auto using_result = Rewrite(using_clause); + using_results.push_back(using_result); + result = Merge(result, using_result); + } + + if (node.condition) { + auto condition_result = Rewrite(node.condition); + result = Merge(result, condition_result); + } + for (auto &expr : node.returning_list) { + auto expr_result = Rewrite(expr); + result = Merge(result, expr_result); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteNode(UpdateQueryNode &node) { + auto result = Rewrite(node.table); + auto from_result = CatalogPushdownResult::NoCatalogReference(); + if (node.from_table) { + from_result = Rewrite(node.from_table); + result = Merge(result, from_result); + } + + if (node.set_info) { + if (node.set_info->condition) { + auto condition_result = Rewrite(node.set_info->condition); + result = Merge(result, condition_result); + } + + for (auto &expr : node.set_info->expressions) { + auto expr_result = Rewrite(expr); + result = Merge(result, expr_result); + } + } + for (auto &expr : node.returning_list) { + auto expr_result = Rewrite(expr); + result = Merge(result, expr_result); + } + return result; +} + +CatalogPushdownResult 
RemotePushdownOptimizer::RewriteNode(SetOperationNode &node) { + // Rewrite each child independently so we can push down individual children if needed + vector child_results; + child_results.reserve(node.children.size()); + auto result = CatalogPushdownResult::NoCatalogReference(); + for (auto &child : node.children) { + RemotePushdownOptimizer child_optimizer(this); + auto child_result = child_optimizer.Rewrite(*child); + result = Merge(result, child_result); + child_results.push_back(child_result); + } + + // Check result modifiers (ORDER BY / LIMIT on the set operation itself) + bool has_expression_modifiers = false; + for (auto &modifier : node.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + // ORDER BY entries cannot be constant-folded - a bare integer literal is a + // positional reference there + result = Merge(result, Rewrite(order.expression, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + has_expression_modifiers = true; + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + result = Merge(result, Rewrite(limit_mod.limit)); + has_expression_modifiers = true; + } + if (limit_mod.offset) { + result = Merge(result, Rewrite(limit_mod.offset)); + has_expression_modifiers = true; + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + // DISTINCT ON entries cannot be constant-folded - a bare integer literal is a + // positional reference there + result = Merge(result, Rewrite(expr, ExpressionFoldingMode::FOLD_CHILDREN_ONLY)); + has_expression_modifiers = true; + } + break; + } + default: + break; + } + } + // If the whole set operation resolves to a single remote catalog, propagate upward. 
+ if (result.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + return result; + } + if (has_expression_modifiers) { + // if the set operation has any modifiers (e.g. ORDER BY ) then binding can go wrong if we do a pushdown + // into children, since we might have something like SELECT i + 1 FROM remote UNION ALL ... ORDER BY i + 1 + // this requires "peeking into" the child query to figure out that the expressions match + // for now just be safe and skip pushdown into individual queries in this scenario + return result; + } + for (idx_t i = 0; i < node.children.size(); i++) { + FinishPushdown(node.children[i], child_results[i]); + } + return result; +} + +void RemotePushdownOptimizer::TrackLocalTable(const TableRef &ref) { + switch (ref.type) { + case TableReferenceType::BASE_TABLE: + TrackLocalTable(ref.Cast()); + break; + case TableReferenceType::TABLE_FUNCTION: + TrackLocalTable(ref.Cast()); + break; + case TableReferenceType::SUBQUERY: + TrackLocalTable(ref.Cast()); + break; + default: + break; + } +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(unique_ptr &ref) { + CatalogPushdownResult result; + switch (ref->type) { + case TableReferenceType::BASE_TABLE: + result = Rewrite(ref->Cast()); + break; + case TableReferenceType::JOIN: + result = Rewrite(ref->Cast()); + break; + case TableReferenceType::SUBQUERY: + result = Rewrite(ref->Cast()); + break; + case TableReferenceType::EXPRESSION_LIST: + result = Rewrite(ref->Cast()); + break; + case TableReferenceType::TABLE_FUNCTION: + result = Rewrite(ref->Cast()); + break; + case TableReferenceType::EMPTY_FROM: + case TableReferenceType::COLUMN_DATA: + result = CatalogPushdownResult::NoCatalogReference(); + break; + default: + return CatalogPushdownResult::Unknown(); + } + if (result.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + // the table reference is fully remote - check if the remote catalog supports pushing it down + // (e.g. 
DuckDB-specific join types or TABLESAMPLE clauses cannot be sent to most remotes) + if (!result.catalog->SupportsPushdown(*ref)) { + TrackLocalTable(*ref); + return CatalogPushdownResult::Unknown(); + } + } else if (result.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + // record the table reference so a remote catalog can veto it during a later merge + result.used_table_constructs.push_back(*ref); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(ExpressionListRef &ref) { + auto result = CatalogPushdownResult::NoCatalogReference(); + for (auto &row : ref.values) { + for (auto &expr : row) { + result = Merge(result, Rewrite(expr)); + } + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(SubqueryRef &ref) { + RemotePushdownOptimizer child_binder(this); + auto result = child_binder.Rewrite(*ref.subquery->node); + if (result.reference_type == CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE) { + TrackLocalTable(ref); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::RewriteTableFunctionOnly(TableFunctionRef &ref) { + if (ref.function->GetExpressionClass() != ExpressionClass::FUNCTION) { + throw InternalException("RemotePushdownOptimizer: TableFunctionRef does not hold a function expression"); + } + auto &func_expr = ref.function->Cast(); + + // Figure out + Identifier catalog_name = func_expr.Catalog(); + Identifier schema_name = func_expr.Schema(); + Binder::BindSchemaOrCatalog(binder.context, catalog_name, schema_name); + + // If the function has an explicit catalog prefix, check if it's remote + if (!catalog_name.empty()) { + auto catalog = Catalog::GetCatalogEntry(binder.context, catalog_name); + if (catalog && catalog->Supports(RemoteCapability::EXECUTE_QUERY_NODE) && catalog->SupportsPushdown(ref)) { + // "catalog" is remote and we can pushdown this function + return CatalogPushdownResult::RemoteReference(*catalog); + } + // catalog was not found or catalog 
does not support pushdown - bail on pushdown for now + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + + // we have an unqualified table function + // this function can either live in a local / system catalog, or in a remote (if it is in the search path) + // check the search path + FindRemoteCatalogsInSearchPath(); + EntryLookupInfo func_lookup(CatalogType::TABLE_FUNCTION_ENTRY, func_expr.FunctionName()); + for (auto &local_entry : pushdown_state.local_catalogs_in_search_path) { + const Identifier &schema = schema_name.empty() ? local_entry.schema : schema_name; + auto entry = + Catalog::GetEntry(binder.context, local_entry.catalog, schema, func_lookup, OnEntryNotFound::RETURN_NULL); + if (entry && entry->type == CatalogType::TABLE_FUNCTION_ENTRY) { + auto &tf_entry = entry->Cast(); + bool is_set_returning = false; + for (auto &func : tf_entry.functions.functions) { + if (func.return_type == TableFunctionReturnType::SET_RETURNING_FUNCTION) { + is_set_returning = true; + break; + } + } + if (!is_set_returning) { + // TABLE_RETURNING_FUNCTION - blocks pushdown; track alias so correlated + // refs from nested lateral subqueries are detected + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + // SET_RETURNING_FUNCTION: neutral, recurse into args + // the generic TableRef dispatch records the function so a remote catalog can veto it + auto result = CatalogPushdownResult::NoCatalogReference(); + for (auto &arg : func_expr.GetArgumentsMutable()) { + result = Merge(result, Rewrite(arg.GetExpressionMutable())); + } + return result; + } + } + // we did not find the table function in a local catalog + if (pushdown_state.remote_catalogs_in_search_path.size() == 1) { + // if we have a single catalog in the remote search path - assume the function lives there + auto &remote_catalog = pushdown_state.remote_catalogs_in_search_path[0].get(); + if (remote_catalog.Supports(RemoteCapability::EXECUTE_QUERY_NODE) && 
remote_catalog.SupportsPushdown(ref)) { + // "catalog" is remote and we can pushdown this function + return CatalogPushdownResult::RemoteReference(remote_catalog); + } + } + // we couldn't find the function locally or remotely - skip pushing down + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(TableFunctionRef &ref) { + // rewrite the table function only + auto result = RewriteTableFunctionOnly(ref); + if (result.reference_type == CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE) { + // don't bother recursing - we can never pushdown + return result; + } + // recurse into the function arguments + auto &func_expr = ref.function->Cast(); + for (auto &arg : func_expr.GetArgumentsMutable()) { + result = Merge(result, Rewrite(arg.GetExpressionMutable())); + } + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(JoinRef &ref) { + auto left_result = Rewrite(ref.left); + + // the right side of a join can be correlated to the left side - use a child optimizer to track this + RemotePushdownOptimizer child_optimizer(this); + auto right_result = child_optimizer.Rewrite(ref.right); + + auto result = Merge(left_result, right_result); + // Also analyze the join condition - it may contain subqueries or local macro calls + // that affect whether the join can be pushed as a whole. 
+ if (ref.condition) { + result = Merge(result, Rewrite(ref.condition)); + } + return result; +} + +void RemotePushdownOptimizer::TrackLocalTable(const BaseTableRef &ref) { + if (!ref.alias.empty()) { + local_table_names.insert(ref.alias); + } else { + local_table_names.insert(ref.table_name); + } +} + +void RemotePushdownOptimizer::TrackLocalTable(const TableFunctionRef &ref) { + if (!ref.alias.empty()) { + local_table_names.insert(ref.alias); + } else { + local_table_names.insert(ref.function->Cast().FunctionName()); + } +} + +void RemotePushdownOptimizer::TrackLocalTable(const SubqueryRef &ref) { + if (!ref.alias.empty()) { + local_table_names.insert(ref.alias); + } else { + local_table_names.insert("unnamed_subquery"); + } +} + +bool RemotePushdownOptimizer::RefersToCTE(const Identifier &cte_name, CatalogPushdownResult &result) const { + auto entry = cte_results.find(cte_name); + if (entry != cte_results.end()) { + result = entry->second; + return true; + } + if (parent) { + return parent->RefersToCTE(cte_name, result); + } + return false; +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(BaseTableRef &ref) { + // Resolve schema_name-as-catalog ambiguity using the binder's own resolution logic + Identifier catalog_name = ref.catalog_name; + Identifier schema_name = ref.schema_name; + Binder::BindSchemaOrCatalog(binder.context, catalog_name, schema_name); + + // Case 0: check if this is a CTE reference (must have no explicit catalog/schema) + if (catalog_name.empty() && schema_name.empty()) { + CatalogPushdownResult pushdown_result; + if (RefersToCTE(ref.table_name, pushdown_result)) { + if (pushdown_result.reference_type == CatalogReferenceType::UNKNOWN_CATALOG_REFERENCE) { + // Local/unknown CTE - track as local for correlated subquery detection + TrackLocalTable(ref); + } + return pushdown_result; + } + } + + // Case 1: catalog is explicitly specified - check if it's a remote catalog + if (!catalog_name.empty()) { + auto catalog = 
Catalog::GetCatalogEntry(binder.context, catalog_name); + if (catalog && catalog->Supports(RemoteCapability::EXECUTE_QUERY_NODE)) { + // verify the table actually exists in the remote catalog - if it does not, fall back + // to the binder so it can report a proper error message + EntryLookupInfo table_lookup(CatalogType::TABLE_ENTRY, ref.table_name); + const auto &schema = schema_name.empty() ? Identifier(DEFAULT_SCHEMA) : schema_name; + auto entry = Catalog::GetEntry(binder.context, catalog->GetName(), schema, table_lookup, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + return CatalogPushdownResult::RemoteReference(*catalog); + } + // A local table always blocks pushdown of any query that contains it. + // Returning UNKNOWN (not NO_CATALOG) ensures Merge(SINGLE_REMOTE, UNKNOWN) = UNKNOWN + // rather than the otherwise-neutral SINGLE_REMOTE. + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + + // Case 2: no explicit catalog - lazily populate search path catalogs on first use + FindRemoteCatalogsInSearchPath(); + + EntryLookupInfo table_lookup(CatalogType::TABLE_ENTRY, ref.table_name); + + if (pushdown_state.remote_catalogs_in_search_path.size() != 1) { + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + + for (auto &local_entry : pushdown_state.local_catalogs_in_search_path) { + // If the ref specifies a schema, use it; otherwise use the search path schema + const auto &schema = schema_name.empty() ? local_entry.schema : schema_name; + auto entry = + Catalog::GetEntry(binder.context, local_entry.catalog, schema, table_lookup, OnEntryNotFound::RETURN_NULL); + if (entry) { + TrackLocalTable(ref); + // Same as Case 1: local table → UNKNOWN to prevent Merge from treating it as neutral. 
+ return CatalogPushdownResult::Unknown(); + } + } + + // Not found in any local catalog - push to the single remote catalog in the search path, + // but only if the table actually exists there (otherwise fall back to the binder for a proper error) + auto &remote_catalog = pushdown_state.remote_catalogs_in_search_path.front().get(); + const auto &schema = schema_name.empty() ? Identifier(DEFAULT_SCHEMA) : schema_name; + auto entry = + Catalog::GetEntry(binder.context, remote_catalog.GetName(), schema, table_lookup, OnEntryNotFound::RETURN_NULL); + if (!entry) { + TrackLocalTable(ref); + return CatalogPushdownResult::Unknown(); + } + return CatalogPushdownResult::RemoteReference(remote_catalog); +} + +bool RemotePushdownOptimizer::RefersToLocalTable(const ColumnRefExpression &col_ref) const { + // figuring out if a column refers to a local table is challenging without knowing all of the columns + // challenges are: + // (1) we might have a subquery (e.g. FROM (SELECT ...), (SELECT ...)) + // - when binding the second subquery we need to know the schema of the first subquery + // (2) we might have CTEs (e.g. FROM cte, (SELECT ...)) + // - when binding the subquery we need to know the schema of the CTE + // (3) we might have struct columns (e.g. 
FROM tbl, (SELECT struct_col.field) + // - we need to correctly deal with this scenario and figure out which struct col this column references + // this is effectively having to re-implement many components of binding but with some information missing + if (local_table_names.empty()) { + return false; + } + return true; +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const SubqueryExpression &subquery_expr) { + ExpressionPushdownResult state; + RemotePushdownOptimizer child_optimizer(this); + state.result = child_optimizer.Rewrite(*subquery_expr.Subquery()->node); + return state; +} + +CatalogPushdownResult RemotePushdownOptimizer::CheckCatalogQualification(const ParsedExpression &expr, + const Identifier &catalog_p, + const Identifier &schema_p) { + Identifier catalog_name = catalog_p; + Identifier schema_name = schema_p; + Binder::BindSchemaOrCatalog(binder.context, catalog_name, schema_name); + if (!catalog_name.empty()) { + auto catalog = Catalog::GetCatalogEntry(binder.context, catalog_name); + if (catalog && catalog->Supports(RemoteCapability::EXECUTE_QUERY_NODE)) { + // the generic expression dispatch verifies that the catalog supports pushing down this expression + return CatalogPushdownResult::RemoteReference(*catalog); + } + // Explicitly local-catalog: block pushdown. 
+ return CatalogPushdownResult::Unknown(); + } + return CatalogPushdownResult::NoCatalogReference(); +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const FunctionExpression &func) { + ExpressionPushdownResult state; + state.result = CheckCatalogQualification(func, func.Catalog(), func.Schema()); + // look up the function once - this determines both whether it can be constant-folded and + // whether it is a macro in a local catalog (which cannot be evaluated remotely) + EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, func.FunctionName()); + auto entry = + Catalog::GetEntry(binder.context, func.Catalog(), func.Schema(), function_lookup, OnEntryNotFound::RETURN_NULL); + if (!entry) { + return state; + } + // aggregate-style modifiers cannot be constant-folded + bool foldable_modifiers = !func.Filter() && !func.Distinct() && !func.ExportState() && + (!func.OrderBy() || func.OrderBy()->orders.empty()); + switch (entry->type) { + case CatalogType::MACRO_ENTRY: + // scalar macros can be folded - the stability of the expansion is checked after binding + if (foldable_modifiers) { + state.foldability = ExpressionFoldability::FOLDABLE; + } + if (!entry->ParentCatalog().Supports(RemoteCapability::EXECUTE_QUERY_NODE)) { + // macros in local catalogs cannot be evaluated remotely - if the macro is not folded + // away, pushdown is blocked + state.result = CatalogPushdownResult::Unknown(); + } + break; + case CatalogType::SCALAR_FUNCTION_ENTRY: { + // at least one overload must be non-volatile (the selected overload is verified after binding) + auto &scalar_entry = entry->Cast(); + for (auto &overload : scalar_entry.functions.functions) { + if (foldable_modifiers && overload.GetStability() != FunctionStability::VOLATILE) { + state.foldability = ExpressionFoldability::FOLDABLE; + break; + } + } + break; + } + default: + break; + } + return state; +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const 
WindowExpression &func) { + ExpressionPushdownResult state; + state.result = CheckCatalogQualification(func, func.Catalog(), func.Schema()); + return state; +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const TypeExpression &type_expr) { + ExpressionPushdownResult state; + state.result = CheckCatalogQualification(type_expr, type_expr.GetCatalog(), type_expr.GetSchema()); + return state; +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const ColumnRefExpression &col_ref) { + ExpressionPushdownResult state; + if (RefersToLocalTable(col_ref)) { + // column refers to local table - bail + state.result = CatalogPushdownResult::Unknown(); + } + return state; +} + +RemotePushdownOptimizer::ConstantFoldResult +RemotePushdownOptimizer::TryConstantFold(unique_ptr &expr) { + // bind a copy of the expression (binding modifies the expression in-place) + unique_ptr bound_expr; + try { + auto expr_copy = expr->Copy(); + auto fold_binder = Binder::CreateBinder(binder.context); + ConstantBinder constant_binder(*fold_binder, binder.context, "remote pushdown"); + bound_expr = constant_binder.Bind(expr_copy); + } catch (std::exception &) { + // the expression cannot be bound as a constant (e.g. no matching function overload) + return ConstantFoldResult::NOT_FOLDABLE; + } + if (!bound_expr || bound_expr->HasParameter() || bound_expr->HasSubquery()) { + return ConstantFoldResult::NOT_FOLDABLE; + } + if (bound_expr->IsVolatile()) { + // volatile functions (random(), ...) must be re-evaluated on every row + return ConstantFoldResult::NOT_FOLDABLE; + } + if (!bound_expr->IsConsistent()) { + // functions like now() must be re-evaluated (re-folded) when a prepared statement is re-executed + binder.SetAlwaysRequireRebind(); + } + Value fold_result; + if (!ExpressionExecutor::TryEvaluateScalar(binder.context, *bound_expr, fold_result)) { + // evaluating the expression raises an error (e.g. 
an out-of-range error) + return ConstantFoldResult::FOLD_ERROR; + } + auto folded = make_uniq(std::move(fold_result)); + // preserve the name DuckDB would generate for the original expression + folded->SetAlias(expr->GetAlias().empty() ? Identifier(expr->ToString()) : expr->GetAlias()); + folded->SetQueryLocation(expr->GetQueryLocation()); + expr = std::move(folded); + return ConstantFoldResult::FOLDED; +} + +ExpressionPushdownResult RemotePushdownOptimizer::AnalyzeExpression(const ParsedExpression &expr) { + switch (expr.GetExpressionClass()) { + case ExpressionClass::SUBQUERY: + return AnalyzeExpression(expr.Cast()); + case ExpressionClass::FUNCTION: + return AnalyzeExpression(expr.Cast()); + case ExpressionClass::WINDOW: + return AnalyzeExpression(expr.Cast()); + case ExpressionClass::TYPE: + return AnalyzeExpression(expr.Cast()); + case ExpressionClass::COLUMN_REF: + return AnalyzeExpression(expr.Cast()); + case ExpressionClass::CONSTANT: + case ExpressionClass::CAST: + case ExpressionClass::COMPARISON: + case ExpressionClass::BETWEEN: + case ExpressionClass::CONJUNCTION: + case ExpressionClass::OPERATOR: + case ExpressionClass::CASE: { + // deterministic expression types - foldable when all inputs are foldable + ExpressionPushdownResult state; + state.foldability = ExpressionFoldability::FOLDABLE; + return state; + } + default: + return ExpressionPushdownResult(); + } +} + +ExpressionPushdownResult RemotePushdownOptimizer::RewriteExpression(unique_ptr &expr, + ExpressionFoldingMode mode) { + // rewrite the children - foldable subtrees are folded at the last possible moment, so that + // every maximal foldable subtree is bound and evaluated exactly once + auto result = CatalogPushdownResult::NoCatalogReference(); + vector>> foldable_children; + bool all_children_foldable = true; + ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { + auto child_state = RewriteExpression(child, ExpressionFoldingMode::FOLD_EXPRESSION); + if 
(child_state.foldability == ExpressionFoldability::FOLDABLE) { + // defer - the subtree is folded when it turns out to be a maximal foldable subtree + foldable_children.push_back(child); + } else { + all_children_foldable = false; + result = Merge(std::move(result), std::move(child_state.result)); + } + }); + auto state = AnalyzeExpression(*expr); + if (mode == ExpressionFoldingMode::FOLD_EXPRESSION && all_children_foldable && + state.foldability == ExpressionFoldability::FOLDABLE) { + // this expression is itself foldable - defer folding to the parent + ExpressionPushdownResult folding_deferred; + folding_deferred.foldability = ExpressionFoldability::FOLDABLE; + return folding_deferred; + } + // this expression is not foldable - fold the (maximal) foldable children now + for (auto &child : foldable_children) { + result = Merge(std::move(result), FoldExpression(child.get())); + } + result = Merge(std::move(result), std::move(state.result)); + if (result.reference_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + // the expression is fully remote - check if the remote catalog supports pushing it down + if (!result.catalog->SupportsPushdown(*expr)) { + result = CatalogPushdownResult::Unknown(); + } + } else if (result.reference_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + // record the expression so a remote catalog can veto pushdown of any expression class + // (functions, comparisons, operators, star expressions, parameters, etc.) 
during a later merge + result.used_expressions.push_back(*expr); + } + state.result = std::move(result); + state.foldability = ExpressionFoldability::NOT_FOLDABLE; + return state; +} + +CatalogPushdownResult RemotePushdownOptimizer::FoldExpression(unique_ptr &expr) { + if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { + // replace the expression with its locally-evaluated result + switch (TryConstantFold(expr)) { + case ConstantFoldResult::FOLD_ERROR: + // evaluating is guaranteed to fail - keep the query local so the user sees DuckDB's error + return CatalogPushdownResult::Unknown(); + case ConstantFoldResult::NOT_FOLDABLE: + // binding did not succeed after all - process the expression without re-attempting the fold + return RewriteExpression(expr, ExpressionFoldingMode::FOLD_CHILDREN_ONLY).result; + case ConstantFoldResult::FOLDED: + break; + } + } + // record the constant so a remote catalog can verify that it is supported + auto result = CatalogPushdownResult::NoCatalogReference(); + result.used_expressions.push_back(*expr); + return result; +} + +CatalogPushdownResult RemotePushdownOptimizer::Rewrite(unique_ptr &expr, ExpressionFoldingMode mode) { + auto state = RewriteExpression(expr, mode); + if (state.foldability == ExpressionFoldability::FOLDABLE) { + // the entire expression is foldable - fold it at the root + return FoldExpression(expr); + } + return state.result; +} + +unique_ptr RemotePushdownOptimizer::CreateRemoteFunctionRef(CatalogPushdownResult &result, + unique_ptr node) { + return result.catalog->RemoteExecute(binder.context, std::move(node)); +} + +void RemotePushdownOptimizer::StripCatalogName(TableRef &ref, const Identifier &catalog_name) { + switch (ref.type) { + case TableReferenceType::BASE_TABLE: { + auto &base = ref.Cast(); + if (base.catalog_name == catalog_name) { + base.catalog_name = ""; + } else if (base.catalog_name.empty() && base.schema_name == catalog_name) { + // 2-part name (schema.table) where the schema is 
actually the catalog being pushed to + base.schema_name = ""; + } + break; + } + case TableReferenceType::JOIN: { + auto &join = ref.Cast(); + StripCatalogName(*join.left, catalog_name); + StripCatalogName(*join.right, catalog_name); + if (join.condition) { + StripCatalogName(*join.condition, catalog_name); + } + break; + } + case TableReferenceType::SUBQUERY: { + auto &sq = ref.Cast(); + StripCatalogName(*sq.subquery->node, catalog_name); + break; + } + case TableReferenceType::TABLE_FUNCTION: { + auto &tf = ref.Cast(); + if (tf.function) { + StripCatalogName(*tf.function, catalog_name); + } + break; + } + case TableReferenceType::EXPRESSION_LIST: { + auto &el = ref.Cast(); + for (auto &row : el.values) { + for (auto &expr : row) { + StripCatalogName(*expr, catalog_name); + } + } + break; + } + default: + break; + } +} + +void RemotePushdownOptimizer::StripCatalogName(ParsedExpression &expr, const Identifier &catalog_name) { + if (expr.GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &col_ref = expr.Cast(); + // Strip catalog prefix from qualified column references, normalising to exactly table.col (2 parts). + // Require at least 3 names: a 2-part ref like "rpc.field" is either table.col or struct-column.field — + // not catalog-qualified — so stripping would be wrong. 
+ // For 3-part catalog.table.col → table.col (one level stripped) + // For 4-part catalog.schema.table.col → table.col (catalog + schema stripped) + if (col_ref.ColumnNames().size() >= 3 && col_ref.ColumnNames()[0] == catalog_name) { + Identifier table_name = col_ref.ColumnNames()[col_ref.ColumnNames().size() - 2]; + Identifier col_name = col_ref.ColumnNames()[col_ref.ColumnNames().size() - 1]; + col_ref.ColumnNamesMutable() = {std::move(table_name), std::move(col_name)}; + } + return; + } + if (expr.GetExpressionClass() == ExpressionClass::SUBQUERY) { + auto &subq = expr.Cast(); + StripCatalogName(*subq.SubqueryMutable()->node, catalog_name); + if (subq.GetChild()) { + StripCatalogName(*subq.GetChildMutable(), catalog_name); + } + return; + } + // Strip catalog prefix from explicitly-qualified function/window/type calls. + // Also handle 2-part names (schema.func/type) where the schema is actually the remote catalog name + // (e.g. "rpc.my_func()" parsed as schema="rpc", catalog=""). + if (expr.GetExpressionClass() == ExpressionClass::FUNCTION) { + auto &func = expr.Cast(); + if (func.Catalog() == catalog_name) { + func.CatalogMutable() = ""; + } else if (func.Catalog().empty() && func.Schema() == catalog_name) { + func.SchemaMutable() = ""; + } + // Fall through to EnumerateChildren to also strip catalog refs inside arguments + } else if (expr.GetExpressionClass() == ExpressionClass::WINDOW) { + auto &win = expr.Cast(); + if (win.Catalog() == catalog_name) { + win.CatalogMutable() = ""; + } else if (win.Catalog().empty() && win.Schema() == catalog_name) { + win.SchemaMutable() = ""; + } + // Fall through to EnumerateChildren to strip catalog refs inside partitions/orders/children + } else if (expr.GetExpressionClass() == ExpressionClass::CAST) { + // CastExpression stores the cast target as a LogicalType, not an expression child — EnumerateChildren + // only visits the value being cast. 
For unbound (user-defined) types we must strip the catalog from the + // embedded TypeExpression and reconstruct the LogicalType::UNBOUND wrapper. + auto &cast_expr = expr.Cast(); + auto &target_type = cast_expr.TargetTypeMutable(); + if (target_type.id() == LogicalTypeId::UNBOUND) { + auto type_expr = UnboundType::GetTypeExpression(target_type)->Copy(); + StripCatalogName(*type_expr, catalog_name); + target_type = LogicalType::UNBOUND(std::move(type_expr)); + } + // Fall through to EnumerateChildren to strip catalog refs inside the cast argument + } else if (expr.GetExpressionClass() == ExpressionClass::TYPE) { + // TypeExpression (used as a type argument) may carry catalog/schema qualifiers. + auto &type_expr = expr.Cast(); + if (type_expr.GetCatalog() == catalog_name) { + type_expr.SetCatalog(""); + } else if (type_expr.GetCatalog().empty() && type_expr.GetSchema() == catalog_name) { + type_expr.SetSchema(""); + } + // Fall through to EnumerateChildren to strip catalog refs inside type parameters + } + ParsedExpressionIterator::EnumerateChildren( + expr, [&](ParsedExpression &child) { StripCatalogName(child, catalog_name); }); +} + +void RemotePushdownOptimizer::StripCatalogName(QueryNode &node, const Identifier &catalog_name) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: { + auto &select = node.Cast(); + // Strip within CTE definitions so the remote receives catalog-free SQL + for (auto &cte_pair : select.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + if (select.from_table) { + StripCatalogName(*select.from_table, catalog_name); + } + for (auto &expr : select.select_list) { + StripCatalogName(*expr, catalog_name); + } + if (select.where_clause) { + StripCatalogName(*select.where_clause, catalog_name); + } + for (auto &expr : select.groups.group_expressions) { + 
StripCatalogName(*expr, catalog_name); + } + if (select.having) { + StripCatalogName(*select.having, catalog_name); + } + if (select.qualify) { + StripCatalogName(*select.qualify, catalog_name); + } + for (auto &modifier : select.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + StripCatalogName(*order.expression, catalog_name); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + StripCatalogName(*limit_mod.limit, catalog_name); + } + if (limit_mod.offset) { + StripCatalogName(*limit_mod.offset, catalog_name); + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + StripCatalogName(*expr, catalog_name); + } + break; + } + default: + break; + } + } + break; + } + case QueryNodeType::INSERT_QUERY_NODE: { + auto &insert = node.Cast(); + for (auto &cte_pair : insert.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + // Strip from the target table's catalog/schema fields (these are what ToString() serializes) + if (insert.catalog == catalog_name) { + insert.catalog = ""; + } else if (insert.catalog.empty() && insert.schema == catalog_name) { + insert.schema = ""; + } + if (insert.select_statement) { + StripCatalogName(*insert.select_statement->node, catalog_name); + } + if (insert.on_conflict_info) { + if (insert.on_conflict_info->condition) { + StripCatalogName(*insert.on_conflict_info->condition, catalog_name); + } + if (insert.on_conflict_info->set_info) { + if (insert.on_conflict_info->set_info->condition) { + StripCatalogName(*insert.on_conflict_info->set_info->condition, catalog_name); + } + for (auto 
&expr : insert.on_conflict_info->set_info->expressions) { + StripCatalogName(*expr, catalog_name); + } + } + } + for (auto &expr : insert.returning_list) { + StripCatalogName(*expr, catalog_name); + } + break; + } + case QueryNodeType::DELETE_QUERY_NODE: { + auto &del = node.Cast(); + for (auto &cte_pair : del.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + if (del.table) { + StripCatalogName(*del.table, catalog_name); + } + if (del.condition) { + StripCatalogName(*del.condition, catalog_name); + } + for (auto &clause : del.using_clauses) { + StripCatalogName(*clause, catalog_name); + } + for (auto &expr : del.returning_list) { + StripCatalogName(*expr, catalog_name); + } + break; + } + case QueryNodeType::UPDATE_QUERY_NODE: { + auto &upd = node.Cast(); + for (auto &cte_pair : upd.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + if (upd.table) { + StripCatalogName(*upd.table, catalog_name); + } + if (upd.from_table) { + StripCatalogName(*upd.from_table, catalog_name); + } + if (upd.set_info) { + if (upd.set_info->condition) { + StripCatalogName(*upd.set_info->condition, catalog_name); + } + for (auto &expr : upd.set_info->expressions) { + StripCatalogName(*expr, catalog_name); + } + } + for (auto &expr : upd.returning_list) { + StripCatalogName(*expr, catalog_name); + } + break; + } + case QueryNodeType::SET_OPERATION_NODE: { + auto &setop = node.Cast(); + for (auto &cte_pair : setop.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + for (auto &child : 
setop.children) { + StripCatalogName(*child, catalog_name); + } + for (auto &modifier : setop.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + StripCatalogName(*order.expression, catalog_name); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + StripCatalogName(*limit_mod.limit, catalog_name); + } + if (limit_mod.offset) { + StripCatalogName(*limit_mod.offset, catalog_name); + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + StripCatalogName(*expr, catalog_name); + } + break; + } + default: + break; + } + } + break; + } + case QueryNodeType::RECURSIVE_CTE_NODE: { + auto &rec = node.Cast(); + for (auto &cte_pair : rec.cte_map.map) { + if (cte_pair.second->query_node) { + StripCatalogName(*cte_pair.second->query_node, catalog_name); + } + for (auto &key : cte_pair.second->key_targets) { + StripCatalogName(*key, catalog_name); + } + } + for (auto &key : rec.key_targets) { + StripCatalogName(*key, catalog_name); + } + for (auto &modifier : rec.modifiers) { + switch (modifier->type) { + case ResultModifierType::ORDER_MODIFIER: { + auto &order_mod = modifier->Cast(); + for (auto &order : order_mod.orders) { + StripCatalogName(*order.expression, catalog_name); + } + break; + } + case ResultModifierType::LIMIT_MODIFIER: { + auto &limit_mod = modifier->Cast(); + if (limit_mod.limit) { + StripCatalogName(*limit_mod.limit, catalog_name); + } + if (limit_mod.offset) { + StripCatalogName(*limit_mod.offset, catalog_name); + } + break; + } + case ResultModifierType::DISTINCT_MODIFIER: { + auto &distinct_mod = modifier->Cast(); + for (auto &expr : distinct_mod.distinct_on_targets) { + StripCatalogName(*expr, catalog_name); + } + break; + } + default: + break; + } + } + if 
(rec.left) { + StripCatalogName(*rec.left, catalog_name); + } + if (rec.right) { + StripCatalogName(*rec.right, catalog_name); + } + break; + } + default: + break; + } +} + +void RemotePushdownOptimizer::StripCatalogName(SQLStatement &statement, const Identifier &catalog_name) { + switch (statement.type) { + case StatementType::SELECT_STATEMENT: + StripCatalogName(*statement.Cast().node, catalog_name); + break; + case StatementType::INSERT_STATEMENT: + StripCatalogName(*statement.Cast().node, catalog_name); + break; + case StatementType::DELETE_STATEMENT: + StripCatalogName(*statement.Cast().node, catalog_name); + break; + case StatementType::UPDATE_STATEMENT: + StripCatalogName(*statement.Cast().node, catalog_name); + break; + default: + break; + } +} + +unique_ptr GetNodeFromStatement(SQLStatement &statement) { + switch (statement.type) { + case StatementType::SELECT_STATEMENT: + return std::move(statement.Cast().node); + case StatementType::INSERT_STATEMENT: + return std::move(statement.Cast().node); + case StatementType::DELETE_STATEMENT: + return std::move(statement.Cast().node); + case StatementType::UPDATE_STATEMENT: + return std::move(statement.Cast().node); + default: + return nullptr; + } +} + +void RemotePushdownOptimizer::FinishPushdown(unique_ptr &statement, CatalogPushdownResult result) { + if (result.reference_type != CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + return; + } + // Strip the catalog name so the remote server doesn't recursively re-push + StripCatalogName(*statement, result.catalog->GetName()); + auto node = GetNodeFromStatement(*statement); + if (!node) { + return; + } + + auto select_node = make_uniq(); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = CreateRemoteFunctionRef(result, std::move(node)); + auto select_stmt = make_uniq(); + select_stmt->node = std::move(select_node); + statement = std::move(select_stmt); +} + +void RemotePushdownOptimizer::FinishPushdown(unique_ptr &node, 
CatalogPushdownResult result) { + if (result.reference_type != CatalogReferenceType::SINGLE_REMOTE_CATALOG) { + return; + } + // FIXME: work-around for referencing a CTE in a parent + // if this query refers to a CTE in the parent we can't push down only this node + // as the parent CTE is lost. For now we just block all pushdown if the parent has a CTE. + // we could fix this in a better way by either (1) moving the parent CTE into this node + // or (2) tracking if we actually refer to a parent CTE + for (auto opt = this; opt; opt = opt->parent.get()) { + for (auto &cte_entry : opt->cte_results) { + auto ref_type = cte_entry.second.reference_type; + if (ref_type == CatalogReferenceType::SINGLE_REMOTE_CATALOG || + ref_type == CatalogReferenceType::NO_CATALOG_REFERENCED) { + return; + } + } + } + StripCatalogName(*node, result.catalog->GetName()); + auto select_node = make_uniq(); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = CreateRemoteFunctionRef(result, std::move(node)); + node = std::move(select_node); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/remove_duplicate_groups.cpp b/src/duckdb/src/optimizer/remove_duplicate_groups.cpp index 667d8db93..c0ab697ba 100644 --- a/src/duckdb/src/optimizer/remove_duplicate_groups.cpp +++ b/src/duckdb/src/optimizer/remove_duplicate_groups.cpp @@ -26,7 +26,7 @@ struct GroupRemoval { // Stops once a second distinct binding is seen. 
bool CollectColumnRefBindings(const Expression &expr, column_binding_set_t &result) { if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - result.insert(expr.Cast().binding); + result.insert(expr.Cast().Binding()); return result.size() <= 1; } bool ok = true; @@ -66,7 +66,7 @@ void CollectGroupRemovals(const LogicalAggregate &aggr, vector &re if (group->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { continue; } - const auto &binding = group->Cast().binding; + const auto &binding = group->Cast().Binding(); auto it = first_occurrence.find(binding); if (it == first_occurrence.end()) { first_occurrence.emplace(binding, group_idx); @@ -184,7 +184,7 @@ BuildDerivedProjection(LogicalAggregate &aggr, idx_t num_original_groups, idx_t continue; } auto base_post = group_binding_map.at(ColumnBinding(original_group_index, r.target_idx)); - const auto &base_source = aggr.groups[base_post.column_index]->Cast().binding; + const auto &base_source = aggr.groups[base_post.column_index]->Cast().Binding(); rebinder.replacement_bindings.clear(); rebinder.replacement_bindings.emplace_back(base_source, base_post); rebinder.VisitExpression(&r.derived_expr); diff --git a/src/duckdb/src/optimizer/remove_unused_columns.cpp b/src/duckdb/src/optimizer/remove_unused_columns.cpp index 6b5d3afe4..ef6168622 100644 --- a/src/duckdb/src/optimizer/remove_unused_columns.cpp +++ b/src/duckdb/src/optimizer/remove_unused_columns.cpp @@ -13,6 +13,7 @@ #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_distinct.hpp" @@ -85,8 +86,8 @@ idx_t BaseColumnPruner::ReplaceBinding(ColumnBinding current_binding, ColumnBind //! 
No pushdown extract, just rewrite the existing bindings for (auto &colref_p : col.bindings) { auto &colref = colref_p.get(); - D_ASSERT(colref.binding == current_binding); - colref.binding = new_binding; + D_ASSERT(colref.Binding() == current_binding); + colref.BindingMutable() = new_binding; } created_bindings = 1; } @@ -193,16 +194,16 @@ void RemoveUnusedColumns::VisitOperator(unique_ptr &op_ref) { auto &lhs_col = cond.GetLHS().Cast(); auto &rhs_col = cond.GetRHS().Cast(); // if there are any columns that refer to the RHS, - auto colrefs = column_references.find(rhs_col.binding); + auto colrefs = column_references.find(rhs_col.Binding()); if (colrefs == column_references.end()) { continue; } for (auto &entry : colrefs->second.bindings) { auto &colref = entry.get(); - colref.binding = lhs_col.binding; + colref.BindingMutable() = lhs_col.Binding(); AddBinding(colref); } - column_references.erase(rhs_col.binding); + column_references.erase(rhs_col.Binding()); } break; } @@ -554,7 +555,7 @@ void RemoveUnusedColumns::WritePushdownExtractColumns( new_expr.SetReturnType(return_type); auto column_index = callback(*entry, component.cast ? &(*component.cast)->GetReturnType() : nullptr); - new_expr.binding.column_index = column_index; + new_expr.BindingMutable().column_index = column_index; } } @@ -605,7 +606,7 @@ void RemoveUnusedColumns::RewriteExpressions(LogicalProjection &proj, idx_t expr } auto &expr = *proj.expressions[expression_idx++]; auto &colref = expr.Cast(); - auto original_binding = colref.binding; + auto original_binding = colref.Binding(); auto &column_type = expr.GetReturnType(); idx_t start = expressions.size(); //! Pushdown Extract is supported, emit a column for every field @@ -758,7 +759,8 @@ void RemoveUnusedColumns::RemoveColumnsFromLogicalGet(LogicalGet &get, unique_pt ColumnBinding filter_binding(get.table_index, filter_idx); auto column_ref = make_uniq(std::move(column_type), filter_binding); //! 
Convert the filter to an expression, so we can visit it - auto filter_expr = filter.ToExpression(*column_ref); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "RemoveUnusedColumns::VisitGet"); + auto filter_expr = expr_filter.ToExpression(*column_ref); if (filter_expr->IsScalar()) { filter_expr = std::move(column_ref); } @@ -941,9 +943,9 @@ bool BaseColumnPruner::HandleStructExtract(unique_ptr &expr_p, reference &path_ref, vector &expressions) { auto &function = expr_p->Cast(); - auto &child = function.children[0]; + auto &child = function.GetChildrenMutable()[0]; D_ASSERT(child->GetReturnType().id() == LogicalTypeId::STRUCT); - auto &bind_data = function.bind_info->Cast(); + auto &bind_data = function.BindInfo()->Cast(); // struct extract, check if left child is a bound column ref if (child->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { // column reference - check if it is a struct @@ -975,9 +977,9 @@ bool BaseColumnPruner::HandleVariantExtract(unique_ptr &expr_p, reference &path_ref, vector &expressions) { auto &function = expr_p->Cast(); - auto &child = function.children[0]; + auto &child = function.GetChildrenMutable()[0]; D_ASSERT(child->GetReturnType().id() == LogicalTypeId::VARIANT); - auto &bind_data = function.bind_info->Cast(); + auto &bind_data = function.BindInfo()->Cast(); if (bind_data.component.lookup_mode != VariantChildLookupMode::BY_KEY) { //! 
We don't push down variant extract on ARRAY values return false; @@ -1020,14 +1022,14 @@ bool BaseColumnPruner::HandleExtractRecursive(unique_ptr &expr_p, return false; } auto &function = expr.Cast(); - if (function.function.GetName() != "struct_extract_at" && function.function.GetName() != "struct_extract" && - function.function.GetName() != "array_extract" && function.function.GetName() != "variant_extract") { + if (function.Function().GetName() != "struct_extract_at" && function.Function().GetName() != "struct_extract" && + function.Function().GetName() != "array_extract" && function.Function().GetName() != "variant_extract") { return false; } - if (!function.bind_info) { + if (!function.BindInfo()) { return false; } - auto &child = function.children[0]; + auto &child = function.GetChildrenMutable()[0]; auto child_type = child->GetReturnType().id(); switch (child_type) { case LogicalTypeId::STRUCT: @@ -1100,13 +1102,13 @@ void BaseColumnPruner::MergeChildColumns(vector ¤t_child_colu } void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col, ColumnIndex child_column) { - auto entry = column_references.find(col.binding); + auto entry = column_references.find(col.Binding()); if (entry == column_references.end()) { // column not referenced yet - add a binding to it entirely ReferencedColumn column; column.bindings.push_back(col); column.child_columns.push_back(std::move(child_column)); - entry = column_references.emplace(make_pair(col.binding, std::move(column))).first; + entry = column_references.emplace(make_pair(col.Binding(), std::move(column))).first; } else { // column reference already exists - check add the binding auto &column = entry->second; @@ -1153,7 +1155,7 @@ void ReferencedColumn::AddPath(const ColumnIndex &path) { void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col, ColumnIndex child_column, const vector &parent) { AddBinding(col, child_column); - auto entry = column_references.find(col.binding); + auto entry = 
column_references.find(col.Binding()); if (entry == column_references.end()) { throw InternalException("ColumnBinding for the col was somehow not added by the previous step?"); } @@ -1169,10 +1171,10 @@ void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col, ColumnIndex chi } void BaseColumnPruner::AddBinding(BoundColumnRefExpression &col) { - auto entry = column_references.find(col.binding); + auto entry = column_references.find(col.Binding()); if (entry == column_references.end()) { // column not referenced yet - add a binding to it entirely - column_references[col.binding].bindings.push_back(col); + column_references[col.Binding()].bindings.push_back(col); } else { // column reference already exists - add the binding and clear any sub-references auto &column = entry->second; @@ -1187,11 +1189,11 @@ static bool TryGetCastChild(unique_ptr &expr, optional_ptrGetExpressionClass() == ExpressionClass::BOUND_CAST); auto &cast = expr->Cast(); - if (cast.try_cast) { + if (cast.IsTryCast()) { return false; } - child = cast.child; + child = cast.ChildMutable(); return true; } diff --git a/src/duckdb/src/optimizer/row_group_pruner.cpp b/src/duckdb/src/optimizer/row_group_pruner.cpp index 679fadfd2..b1de0f586 100644 --- a/src/duckdb/src/optimizer/row_group_pruner.cpp +++ b/src/duckdb/src/optimizer/row_group_pruner.cpp @@ -90,10 +90,9 @@ bool RowGroupPruner::TryOptimize(LogicalOperator &op) const { void RowGroupPruner::GetLimitAndOffset(const LogicalLimit &logical_limit, optional_idx &row_limit, optional_idx &row_offset) const { + // UNSET = no LIMIT = unbounded; leave row_limit invalid. 
if (logical_limit.limit_val.Type() == LimitNodeType::CONSTANT_VALUE) { row_limit = logical_limit.limit_val.GetConstantValue(); - } else if (logical_limit.limit_val.Type() == LimitNodeType::UNSET) { - row_limit = 0; } if (logical_limit.offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { @@ -133,7 +132,7 @@ optional_ptr RowGroupPruner::FindLogicalGet(const LogicalOrder &logi auto &colref = primary_order.expression->Cast(); JoinFilterPushdownColumn column; - column.probe_column_index = colref.binding; + column.probe_column_index = colref.Binding(); vector columns {std::move(column)}; vector pushdown_targets; JoinFilterPushdownOptimizer::GetPushdownFilterTargets(*logical_order.children[0], std::move(columns), diff --git a/src/duckdb/src/optimizer/row_number_rewriter.cpp b/src/duckdb/src/optimizer/row_number_rewriter.cpp index a5acd2516..0e3330df4 100644 --- a/src/duckdb/src/optimizer/row_number_rewriter.cpp +++ b/src/duckdb/src/optimizer/row_number_rewriter.cpp @@ -20,7 +20,7 @@ bool RowNumberRewriter::CanOptimize(LogicalOperator &op) { return false; } auto &window_expr = expression->Cast(); - if (!window_expr.partitions.empty() || !window_expr.orders.empty()) { + if (!window_expr.Partitions().empty() || !window_expr.OrderBy().empty()) { return false; } diff --git a/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp b/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp index 1cb0606c4..5e165eacb 100644 --- a/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/arithmetic_simplification.cpp @@ -15,7 +15,7 @@ ArithmeticSimplificationRule::ArithmeticSimplificationRule(ExpressionRewriter &r op->matchers.push_back(make_uniq()); op->policy = SetMatcher::Policy::SOME; // we only match on simple arithmetic expressions (+, -, *, /) - op->function = make_uniq(unordered_set {"+", "-", "*", "//"}); + op->function = make_uniq(identifier_set_t {"+", "-", "*", "//"}); // and only with numeric results op->type = 
make_uniq(); op->matchers[0]->type = make_uniq(); @@ -27,41 +27,41 @@ unique_ptr ArithmeticSimplificationRule::Apply(LogicalOperator &op, bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); auto &constant = bindings[1].get().Cast(); - idx_t constant_child = root.children[0].get() == &constant ? 0 : 1; - D_ASSERT(root.children.size() == 2); + idx_t constant_child = root.GetChildren()[0].get() == &constant ? 0 : 1; + D_ASSERT(root.GetChildren().size() == 2); (void)root; // any arithmetic operator involving NULL is always NULL - if (constant.value.IsNull()) { + if (constant.GetValue().IsNull()) { return make_uniq(Value(root.GetReturnType())); } - auto &func_name = root.function.GetName(); + auto &func_name = root.Function().GetName(); if (func_name == "+") { - if (constant.value == 0) { + if (constant.GetValue() == 0) { // addition with 0 // we can remove the entire operator and replace it with the non-constant child - return std::move(root.children[1 - constant_child]); + return std::move(root.GetChildrenMutable()[1 - constant_child]); } } else if (func_name == "-") { - if (constant_child == 1 && constant.value == 0) { + if (constant_child == 1 && constant.GetValue() == 0) { // subtraction by 0 // we can remove the entire operator and replace it with the non-constant child - return std::move(root.children[1 - constant_child]); + return std::move(root.GetChildrenMutable()[1 - constant_child]); } } else if (func_name == "*") { - if (constant.value == 1) { + if (constant.GetValue() == 1) { // multiply with 1, replace with non-constant child - return std::move(root.children[1 - constant_child]); - } else if (constant.value == 0) { + return std::move(root.GetChildrenMutable()[1 - constant_child]); + } else if (constant.GetValue() == 0) { // multiply by zero: replace with constant or null - return ExpressionRewriter::ConstantOrNull(std::move(root.children[1 - constant_child]), + return 
ExpressionRewriter::ConstantOrNull(std::move(root.GetChildrenMutable()[1 - constant_child]), Value::Numeric(root.GetReturnType(), 0)); } } else if (func_name == "//") { if (constant_child == 1) { - if (constant.value == 1) { + if (constant.GetValue() == 1) { // divide by 1, replace with non-constant child - return std::move(root.children[1 - constant_child]); - } else if (constant.value == 0) { + return std::move(root.GetChildrenMutable()[1 - constant_child]); + } else if (constant.GetValue() == 0) { // divide by 0, replace with NULL return make_uniq(Value(root.GetReturnType())); } diff --git a/src/duckdb/src/optimizer/rule/case_simplification.cpp b/src/duckdb/src/optimizer/rule/case_simplification.cpp index 2dbbecef5..4e2568a3c 100644 --- a/src/duckdb/src/optimizer/rule/case_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/case_simplification.cpp @@ -14,8 +14,8 @@ CaseSimplificationRule::CaseSimplificationRule(ExpressionRewriter &rewriter) : R unique_ptr CaseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); - for (idx_t i = 0; i < root.case_checks.size(); i++) { - auto &case_check = root.case_checks[i]; + for (idx_t i = 0; i < root.CaseChecksMutable().size(); i++) { + auto &case_check = root.CaseChecksMutable()[i]; if (case_check.when_expr->IsFoldable()) { // the WHEN check is a foldable expression // use an ExpressionExecutor to execute the expression @@ -25,21 +25,22 @@ unique_ptr CaseSimplificationRule::Apply(LogicalOperator &op, vector auto condition = constant_value.DefaultCastAs(LogicalType::BOOLEAN); if (condition.IsNull() || !BooleanValue::Get(condition)) { // the condition is always false: remove this case check - root.case_checks.erase_at(i); + root.CaseChecksMutable().erase_at(i); i--; } else { // the condition is always true // move the THEN clause to the ELSE of the case - root.else_expr = std::move(case_check.then_expr); + root.ElseMutable() = 
std::move(case_check.then_expr); // remove this case check and any case checks after this one - root.case_checks.erase(root.case_checks.begin() + NumericCast(i), root.case_checks.end()); + root.CaseChecksMutable().erase(root.CaseChecksMutable().begin() + NumericCast(i), + root.CaseChecksMutable().end()); break; } } } - if (root.case_checks.empty()) { + if (root.CaseChecksMutable().empty()) { // no case checks left: return the ELSE expression - return std::move(root.else_expr); + return std::move(root.ElseMutable()); } return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp index 57a180b04..d311b7e80 100644 --- a/src/duckdb/src/optimizer/rule/comparison_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/comparison_simplification.cpp @@ -1,6 +1,8 @@ #include "duckdb/planner/expression/list.hpp" #include "duckdb/optimizer/rule/comparison_simplification.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/types/timestamp.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/optimizer/expression_rewriter.hpp" @@ -8,6 +10,79 @@ namespace duckdb { +static bool DateTimestampComparisonIsInvertible(BoundFunctionExpression &expr, BoundCastExpression &cast_expression, + const Value &constant_value, Value &cast_constant, bool column_ref_left, + unique_ptr &replacement) { + if (Timestamp::GetTime(constant_value.GetValue()) == dtime_t(0)) { + return true; // it's midnight: no replacement needed + } + auto op = expr.GetExpressionType(); + + // Non-midnight TIMESTAMPs cannot equal DATE values. 
+ switch (op) { + case ExpressionType::COMPARE_EQUAL: + // d = T -> false, preserving NULL + replacement = + ExpressionRewriter::ConstantOrNull(std::move(cast_expression.ChildMutable()), Value::BOOLEAN(false)); + return true; + case ExpressionType::COMPARE_NOTEQUAL: + // d != T -> true, preserving NULL + replacement = + ExpressionRewriter::ConstantOrNull(std::move(cast_expression.ChildMutable()), Value::BOOLEAN(true)); + return true; + case ExpressionType::COMPARE_DISTINCT_FROM: + // d IS DISTINCT FROM T -> true + replacement = make_uniq(Value::BOOLEAN(true)); + return true; + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + // d IS NOT DISTINCT FROM T -> false + replacement = make_uniq(Value::BOOLEAN(false)); + return true; + default: + break; + } + + // The examples describe the column-left form; column-right comparisons invert the day bump. + bool add_one_day; + switch (op) { + case ExpressionType::COMPARE_LESSTHAN: + // d < T -> d < DATE '2024-06-16' -- needs +1 + add_one_day = column_ref_left; + break; + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + // d >= T -> d >= DATE '2024-06-16' -- needs +1 + add_one_day = column_ref_left; + break; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // d <= T -> d <= DATE '2024-06-15' -- no +1 + add_one_day = !column_ref_left; + break; + case ExpressionType::COMPARE_GREATERTHAN: + // d > T -> d > DATE '2024-06-15' -- no +1 + add_one_day = !column_ref_left; + break; + default: + return false; + } + if (add_one_day) { + cast_constant = Value::DATE(date_t(cast_constant.GetValue().days + 1)); + } + return true; +} + +static bool ConstantCastIsInvertible(BoundFunctionExpression &expr, BoundCastExpression &cast_expression, + const Value &constant_value, Value &cast_constant, const LogicalType &target_type, + bool column_ref_left, unique_ptr &replacement) { + if (cast_constant.IsNull() || BoundCastExpression::CastIsInvertible(cast_expression.GetReturnType(), target_type)) { + return true; + } + if (target_type.id() != 
LogicalTypeId::DATE || cast_expression.GetReturnType().id() != LogicalTypeId::TIMESTAMP) { + return false; + } + return DateTimestampComparisonIsInvertible(expr, cast_expression, constant_value, cast_constant, column_ref_left, + replacement); +} + ComparisonSimplificationRule::ComparisonSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { // match on a ComparisonExpression that has a ConstantExpression as a check auto op = make_uniq(); @@ -56,14 +131,17 @@ unique_ptr ComparisonSimplificationRule::Apply(LogicalOperator &op, } // Is the constant cast invertible? - if (!cast_constant.IsNull() && - !BoundCastExpression::CastIsInvertible(cast_expression.GetReturnType(), target_type)) { - // Cast is not invertible, so we do not rewrite this expression to ensure that the cast is executed + unique_ptr replacement; + if (!ConstantCastIsInvertible(expr, cast_expression, constant_value, cast_constant, target_type, + column_ref_left, replacement)) { return nullptr; } + if (replacement) { + return replacement; + } //! 
We can cast, now we change our column_ref_expression from an operator cast to a column reference - auto child_expression = std::move(cast_expression.child); + auto child_expression = std::move(cast_expression.ChildMutable()); auto new_constant_expr = make_uniq(cast_constant); if (column_ref_left) { left = std::move(child_expression); diff --git a/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp b/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp index f1d5838d0..ba11e6a02 100644 --- a/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/conjunction_simplification.cpp @@ -17,16 +17,16 @@ ConjunctionSimplificationRule::ConjunctionSimplificationRule(ExpressionRewriter unique_ptr ConjunctionSimplificationRule::RemoveExpression(BoundConjunctionExpression &conj, const Expression &expr) { - for (idx_t i = 0; i < conj.children.size(); i++) { - if (conj.children[i].get() == &expr) { + for (idx_t i = 0; i < conj.GetChildrenMutable().size(); i++) { + if (conj.GetChildrenMutable()[i].get() == &expr) { // erase the expression - conj.children.erase_at(i); + conj.GetChildrenMutable().erase_at(i); break; } } - if (conj.children.size() == 1) { + if (conj.GetChildrenMutable().size() == 1) { // one expression remaining: simply return that expression and erase the conjunction - return std::move(conj.children[0]); + return std::move(conj.GetChildrenMutable()[0]); } return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp index 596898ac2..4768b8a3b 100644 --- a/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp +++ b/src/duckdb/src/optimizer/rule/constant_order_normalization.cpp @@ -6,6 +6,10 @@ namespace duckdb { +static bool MultiplicationConstantBelongsAtEnd(const BoundFunctionExpression &expr) { + return expr.Function().GetName() == "*"; +} + class RecursiveFunctionExpressionMatcher : public 
ExpressionMatcher { public: explicit RecursiveFunctionExpressionMatcher(vector> func_matchers) @@ -34,7 +38,7 @@ class RecursiveFunctionExpressionMatcher : public ExpressionMatcher { vector> curr_bindings; if (func_matcher->Match(expr, curr_bindings)) { auto &func_expr = expr.Cast(); - for (auto &child : func_expr.children) { + for (auto &child : func_expr.GetChildren()) { RecursiveMatch(func_matcher, *(child.get()), bindings); } } else { @@ -80,6 +84,7 @@ unique_ptr ConstantOrderNormalizationRule::Apply(LogicalOperator &op vector> &bindings, bool &changes_made, bool is_root) { auto &root = bindings.back().get().Cast(); + const auto expression_count = bindings.size() - 1; // Put all constant expressions in front. vector> ordered_bindings; @@ -94,10 +99,25 @@ unique_ptr ConstantOrderNormalizationRule::Apply(LogicalOperator &op } } - if (ordered_bindings.size() <= 1 || last_constant_position == ordered_bindings.size() - 1) { + if (ordered_bindings.empty()) { return nullptr; } - ordered_bindings.insert(ordered_bindings.end(), remain_bindings.begin(), remain_bindings.end()); + + if (ordered_bindings.size() == 1) { + // Multiplication trees can lose their folded constant subtree in the bottom-up rewriter + // (e.g. 1 * 2 * x -> 2 * x). Normalize that remaining binary form to x * 2 so it + // matches the shape produced by the other multiplication variants. + if (!MultiplicationConstantBelongsAtEnd(root) || last_constant_position == expression_count - 1) { + return nullptr; + } + remain_bindings.push_back(ordered_bindings[0]); + ordered_bindings = std::move(remain_bindings); + } else { + if (last_constant_position == ordered_bindings.size() - 1) { + return nullptr; + } + ordered_bindings.insert(ordered_bindings.end(), remain_bindings.begin(), remain_bindings.end()); + } // Reconstruct the expression. 
FunctionBinder binder(rewriter.context); @@ -108,8 +128,8 @@ unique_ptr ConstantOrderNormalizationRule::Apply(LogicalOperator &op for (idx_t i = 1; i < ordered_bindings.size(); ++i) { // Right child. children.push_back(ordered_bindings[i].get().Copy()); - new_root = binder.BindScalarFunction(DEFAULT_SCHEMA, root.function.GetName(), std::move(children), error, - root.is_operator); + new_root = binder.BindScalarFunction(Identifier::DefaultSchema(), root.Function().GetName(), + std::move(children), error, root.IsOperator()); if (!new_root) { error.Throw(); } diff --git a/src/duckdb/src/optimizer/rule/contains_to_in_clause.cpp b/src/duckdb/src/optimizer/rule/contains_to_in_clause.cpp new file mode 100644 index 000000000..0693edc0b --- /dev/null +++ b/src/duckdb/src/optimizer/rule/contains_to_in_clause.cpp @@ -0,0 +1,48 @@ +#include "duckdb/optimizer/rule/contains_to_in_clause.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" + +namespace duckdb { + +ContainsToInClauseRule::ContainsToInClauseRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + auto func = make_uniq(); + func->function = make_uniq("contains"); + func->matchers.push_back(make_uniq()); + func->matchers.push_back(make_uniq()); + func->policy = SetMatcher::Policy::ORDERED; + root = std::move(func); +} + +unique_ptr ContainsToInClauseRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + auto &list_arg = expr.GetChildren()[0]; + auto &probe_arg = expr.GetChildren()[1]; + + if (probe_arg->IsFoldable()) { + return nullptr; + } + + Value list_val; + if (!ExpressionExecutor::TryEvaluateScalar(GetContext(), *list_arg, list_val)) { + return nullptr; + } + if 
(list_val.IsNull() || list_val.type().id() != LogicalTypeId::LIST || ListValue::GetChildren(list_val).empty()) { + return nullptr; + } + + auto in_expr = make_uniq(ExpressionType::COMPARE_IN, LogicalType::BOOLEAN); + in_expr->GetChildrenMutable().push_back(probe_arg->Copy()); + const auto &child_type = ListType::GetChildType(list_val.type()); + for (const auto &elem : ListValue::GetChildren(list_val)) { + Value v = elem.DefaultCastAs(child_type); + in_expr->GetChildrenMutable().push_back(make_uniq(std::move(v))); + } + changes_made = true; + return std::move(in_expr); +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/date_part_simplification.cpp b/src/duckdb/src/optimizer/rule/date_part_simplification.cpp index ac367e7ff..a363ff519 100644 --- a/src/duckdb/src/optimizer/rule/date_part_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/date_part_simplification.cpp @@ -24,7 +24,7 @@ unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, ve bool &changes_made, bool is_root) { auto &date_part = bindings[0].get().Cast(); auto &constant_expr = bindings[1].get().Cast(); - auto &constant = constant_expr.value; + const auto &constant = constant_expr.GetValue(); if (constant.IsNull()) { // NULL specifier: return constant NULL @@ -32,7 +32,7 @@ unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, ve } // otherwise check the specifier auto specifier = GetDatePartSpecifier(StringValue::Get(constant)); - string new_function_name; + Identifier new_function_name; switch (specifier) { case DatePartSpecifier::YEAR: new_function_name = "year"; @@ -90,11 +90,12 @@ unique_ptr DatePartSimplificationRule::Apply(LogicalOperator &op, ve } // found a replacement function: bind it vector> children; - children.push_back(std::move(date_part.children[1])); + children.push_back(std::move(date_part.GetChildrenMutable()[1])); ErrorData error; FunctionBinder binder(rewriter.context); - auto function = binder.BindScalarFunction(DEFAULT_SCHEMA, 
new_function_name, std::move(children), error, false); + auto function = + binder.BindScalarFunction(Identifier::DefaultSchema(), new_function_name, std::move(children), error, false); if (!function) { error.Throw(); } diff --git a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp index 13dad341c..a90a4593c 100644 --- a/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/date_trunc_simplification.cpp @@ -22,7 +22,8 @@ DateTruncSimplificationRule::DateTruncSimplificationRule(ExpressionRewriter &rew auto op = make_uniq(); auto lhs = make_uniq(); - lhs->function = make_uniq(unordered_set {"date_trunc", "datetrunc"}); + lhs->function = + make_uniq(identifier_set_t {Identifier("date_trunc"), Identifier("datetrunc")}); lhs->matchers.push_back(make_uniq()); lhs->matchers.push_back(make_uniq()); lhs->policy = SetMatcher::Policy::ORDERED; @@ -91,9 +92,9 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v // { // First check if we can just return `column IS NULL`. 
- if (rhs_comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM && rhs.value.IsNull()) { + if (rhs_comparison_type == ExpressionType::COMPARE_NOT_DISTINCT_FROM && rhs.GetValue().IsNull()) { auto op = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); - op->children.push_back(column_part.Copy()); + op->GetChildrenMutable().push_back(column_part.Copy()); return std::move(op); } else { if (!is_truncated) { @@ -122,7 +123,7 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v auto isnotnull = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); - isnotnull->children.push_back(column_part.Copy()); + isnotnull->GetChildrenMutable().push_back(column_part.Copy()); return make_uniq(ExpressionType::CONJUNCTION_AND, std::move(comp), std::move(isnotnull)); @@ -157,11 +158,11 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v // column IS NULL) // { - if (rhs_comparison_type == ExpressionType::COMPARE_DISTINCT_FROM && rhs.value.IsNull()) { + if (rhs_comparison_type == ExpressionType::COMPARE_DISTINCT_FROM && rhs.GetValue().IsNull()) { // Return 'column IS NOT NULL'. 
auto op = make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); - op->children.push_back(column_part.Copy()); + op->GetChildrenMutable().push_back(column_part.Copy()); return std::move(op); } else { if (!is_truncated) { @@ -190,7 +191,7 @@ unique_ptr DateTruncSimplificationRule::Apply(LogicalOperator &op, v auto isnull = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); - isnull->children.push_back(column_part.Copy()); + isnull->GetChildrenMutable().push_back(column_part.Copy()); return make_uniq(ExpressionType::CONJUNCTION_OR, std::move(comp), std::move(isnull)); @@ -347,7 +348,7 @@ unique_ptr DateTruncSimplificationRule::CreateTrunc(const BoundConst vector> args; args.emplace_back(date_part.Copy()); args.emplace_back(rhs.Copy()); - auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(args), error); + auto trunc = binder.BindScalarFunction(Identifier::DefaultSchema(), "date_trunc", std::move(args), error); // Ensure that the RHS type matches the column type. 
if (trunc->GetReturnType().id() != return_type.id()) { @@ -369,7 +370,7 @@ unique_ptr DateTruncSimplificationRule::CreateTrunc(const BoundConst unique_ptr DateTruncSimplificationRule::CreateTruncAdd(const BoundConstantExpression &date_part, const BoundConstantExpression &rhs, const LogicalType &return_type) { - DatePartSpecifier part = GetDatePartSpecifier(StringValue::Get(date_part.value)); + DatePartSpecifier part = GetDatePartSpecifier(StringValue::Get(date_part.GetValue())); const string interval_func_name = DatePartToFunc(part); // If the date part cannot be represented as an interval, then we cannot @@ -384,7 +385,8 @@ unique_ptr DateTruncSimplificationRule::CreateTruncAdd(const BoundCo vector> args1; auto constant_param = make_uniq(Value::INTEGER(1)); args1.emplace_back(std::move(constant_param)); - auto interval = binder.BindScalarFunction(DEFAULT_SCHEMA, interval_func_name, std::move(args1), error); + auto interval = + binder.BindScalarFunction(Identifier::DefaultSchema(), Identifier(interval_func_name), std::move(args1), error); if (!interval) { return nullptr; // Something wrong---just don't do the optimization. } @@ -392,12 +394,12 @@ unique_ptr DateTruncSimplificationRule::CreateTruncAdd(const BoundCo vector> args2; args2.emplace_back(rhs.Copy()); args2.emplace_back(std::move(interval)); - auto add = binder.BindScalarFunction(DEFAULT_SCHEMA, "+", std::move(args2), error); + auto add = binder.BindScalarFunction(Identifier::DefaultSchema(), "+", std::move(args2), error); vector> args3; args3.emplace_back(date_part.Copy()); args3.emplace_back(std::move(add)); - auto trunc = binder.BindScalarFunction(DEFAULT_SCHEMA, "date_trunc", std::move(args3), error); + auto trunc = binder.BindScalarFunction(Identifier::DefaultSchema(), "date_trunc", std::move(args3), error); // Ensure that the RHS type matches the column type. 
if (trunc->GetReturnType().id() != return_type.id()) { @@ -419,7 +421,7 @@ unique_ptr DateTruncSimplificationRule::CreateTruncAdd(const BoundCo bool DateTruncSimplificationRule::DateIsTruncated(const BoundConstantExpression &date_part, const BoundConstantExpression &rhs) { // If the rhs is null, then the date is "truncated" in the sense that date_trunc(..., NULL) is also NULL. - if (rhs.value.IsNull()) { + if (rhs.GetValue().IsNull()) { return true; } diff --git a/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp index c6031a01a..ec7b183e5 100644 --- a/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp +++ b/src/duckdb/src/optimizer/rule/distinct_aggregate_optimizer.cpp @@ -17,9 +17,9 @@ unique_ptr DistinctAggregateOptimizer::Apply(ClientContext &context, // no DISTINCT defined return nullptr; } - if (aggr.function.GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { + if (aggr.Function().GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { // not a distinct-sensitive aggregate but we have an DISTINCT modifier - remove it - aggr.aggr_type = AggregateType::NON_DISTINCT; + aggr.GetAggregateTypeMutable() = AggregateType::NON_DISTINCT; changes_made = true; return nullptr; } @@ -39,17 +39,17 @@ DistinctWindowedOptimizer::DistinctWindowedOptimizer(ExpressionRewriter &rewrite unique_ptr DistinctWindowedOptimizer::Apply(ClientContext &context, BoundWindowExpression &wexpr, bool &changes_made) { - if (!wexpr.distinct) { + if (!wexpr.Distinct()) { // no DISTINCT defined return nullptr; } - if (!wexpr.aggregate) { + if (!wexpr.AggregateFunction()) { // not an aggregate return nullptr; } - if (wexpr.aggregate->GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { + if (wexpr.AggregateFunction()->GetDistinctDependent() == AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT) { // not a distinct-sensitive 
aggregate but we have an DISTINCT modifier - remove it - wexpr.distinct = false; + wexpr.DistinctMutable() = false; changes_made = true; return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/distributivity.cpp b/src/duckdb/src/optimizer/rule/distributivity.cpp index f9fd52d1d..285e65f73 100644 --- a/src/duckdb/src/optimizer/rule/distributivity.cpp +++ b/src/duckdb/src/optimizer/rule/distributivity.cpp @@ -17,7 +17,7 @@ DistributivityRule::DistributivityRule(ExpressionRewriter &rewriter) : Rule(rewr void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &set) { if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { auto &and_expr = expr.Cast(); - for (auto &child : and_expr.children) { + for (auto &child : and_expr.GetChildrenMutable()) { set.insert(*child); } } else { @@ -27,20 +27,20 @@ void DistributivityRule::AddExpressionSet(Expression &expr, expression_set_t &se unique_ptr DistributivityRule::ExtractExpression(BoundConjunctionExpression &conj, idx_t idx, const Expression &expr) { - auto &child = conj.children[idx]; + auto &child = conj.GetChildrenMutable()[idx]; unique_ptr result; if (child->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { // AND, remove expression from the list auto &and_expr = child->Cast(); - for (idx_t i = 0; i < and_expr.children.size(); i++) { - if (and_expr.children[i]->Equals(expr)) { - result = std::move(and_expr.children[i]); - and_expr.children.erase_at(i); + for (idx_t i = 0; i < and_expr.GetChildrenMutable().size(); i++) { + if (and_expr.GetChildrenMutable()[i]->Equals(expr)) { + result = std::move(and_expr.GetChildrenMutable()[i]); + and_expr.GetChildrenMutable().erase_at(i); break; } } - if (and_expr.children.size() == 1) { - conj.children[idx] = std::move(and_expr.children[0]); + if (and_expr.GetChildrenMutable().size() == 1) { + conj.GetChildrenMutable()[idx] = std::move(and_expr.GetChildrenMutable()[0]); } } // not an AND node(e.g. 
(X AND B) OR X) or this is the last expr, @@ -48,7 +48,7 @@ unique_ptr DistributivityRule::ExtractExpression(BoundConjunctionExp if (!result) { D_ASSERT(child->Equals(expr)); result = std::move(child); - conj.children[idx] = nullptr; + conj.GetChildrenMutable()[idx] = nullptr; } D_ASSERT(result); return result; @@ -63,13 +63,13 @@ unique_ptr DistributivityRule::Apply(LogicalOperator &op, vector DistributivityRule::Apply(LogicalOperator &op, vector(ExpressionType::CONJUNCTION_AND); for (auto &expr : candidate_set) { - D_ASSERT(initial_or.children.size() > 0); + D_ASSERT(initial_or.GetChildrenMutable().size() > 0); // extract the expression from the first child of the OR auto result = ExtractExpression(initial_or, 0, expr.get()); // now for the subsequent expressions, simply remove the expression - for (idx_t i = 1; i < initial_or.children.size(); i++) { + for (idx_t i = 1; i < initial_or.GetChildrenMutable().size(); i++) { ExtractExpression(initial_or, i, *result); } // now we add the expression to the new root - new_root->children.push_back(std::move(result)); + new_root->GetChildrenMutable().push_back(std::move(result)); } // check if we completely erased one of the children of the OR @@ -105,30 +105,30 @@ unique_ptr DistributivityRule::Apply(LogicalOperator &op, vectorchildren.size() <= 1) { - return std::move(new_root->children[0]); + for (idx_t i = 0; i < initial_or.GetChildrenMutable().size(); i++) { + if (!initial_or.GetChildrenMutable()[i]) { + if (new_root->GetChildrenMutable().size() <= 1) { + return std::move(new_root->GetChildrenMutable()[0]); } else { return std::move(new_root); } } } // finally we need to add the remaining expressions in the OR to the new root - if (initial_or.children.size() == 1) { + if (initial_or.GetChildrenMutable().size() == 1) { // one child: skip the OR entirely and only add the single child - new_root->children.push_back(std::move(initial_or.children[0])); - } else if (initial_or.children.size() > 1) { + 
new_root->GetChildrenMutable().push_back(std::move(initial_or.GetChildrenMutable()[0])); + } else if (initial_or.GetChildrenMutable().size() > 1) { // multiple children still remain: push them into a new OR and add that to the new root auto new_or = make_uniq(ExpressionType::CONJUNCTION_OR); - for (auto &child : initial_or.children) { - new_or->children.push_back(std::move(child)); + for (auto &child : initial_or.GetChildrenMutable()) { + new_or->GetChildrenMutable().push_back(std::move(child)); } - new_root->children.push_back(std::move(new_or)); + new_root->GetChildrenMutable().push_back(std::move(new_or)); } // finally return the new root - if (new_root->children.size() == 1) { - return std::move(new_root->children[0]); + if (new_root->GetChildrenMutable().size() == 1) { + return std::move(new_root->GetChildrenMutable()[0]); } return std::move(new_root); } diff --git a/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp b/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp index 0938dffa0..52e6e6f03 100644 --- a/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp +++ b/src/duckdb/src/optimizer/rule/empty_needle_removal.cpp @@ -17,7 +17,7 @@ EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : R func->matchers.push_back(make_uniq()); func->policy = SetMatcher::Policy::SOME; - unordered_set functions = {"prefix", "contains", "suffix"}; + identifier_set_t functions = {"prefix", "contains", "suffix"}; func->function = make_uniq(functions); root = std::move(func); } @@ -25,7 +25,7 @@ EmptyNeedleRemovalRule::EmptyNeedleRemovalRule(ExpressionRewriter &rewriter) : R unique_ptr EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); - D_ASSERT(root.children.size() == 2); + D_ASSERT(root.GetChildren().size() == 2); auto &prefix_expr = bindings[2].get(); // the constant_expr is a scalar expression that we have to fold @@ -50,7 +50,7 @@ unique_ptr 
EmptyNeedleRemovalRule::Apply(LogicalOperator &op, vector // PREFIX(NULL, '') is NULL // so rewrite PREFIX(x, '') to TRUE_OR_NULL(x) if (needle_string.empty()) { - return ExpressionRewriter::ConstantOrNull(std::move(root.children[0]), Value::BOOLEAN(true)); + return ExpressionRewriter::ConstantOrNull(std::move(root.GetChildrenMutable()[0]), Value::BOOLEAN(true)); } return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/enum_comparison.cpp b/src/duckdb/src/optimizer/rule/enum_comparison.cpp index 50e8822fa..318c84936 100644 --- a/src/duckdb/src/optimizer/rule/enum_comparison.cpp +++ b/src/duckdb/src/optimizer/rule/enum_comparison.cpp @@ -52,7 +52,7 @@ unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vector(); auto &right_child = bindings[3].get().Cast(); - if (!AreMatchesPossible(left_child.child->GetReturnType(), right_child.child->GetReturnType())) { + if (!AreMatchesPossible(left_child.Child().GetReturnType(), right_child.Child().GetReturnType())) { vector> children; children.push_back(std::move(BoundComparisonExpression::LeftMutable(root))); children.push_back(std::move(BoundComparisonExpression::RightMutable(root))); @@ -63,10 +63,10 @@ unique_ptr EnumComparisonRule::Apply(LogicalOperator &op, vectorGetReturnType(), true); + auto cast_left_to_right = BoundCastExpression::AddDefaultCastToType(std::move(left_child.ChildMutable()), + right_child.Child().GetReturnType(), true); return BoundComparisonExpression::Create(root.GetExpressionType(), std::move(cast_left_to_right), - std::move(right_child.child)); + std::move(right_child.ChildMutable())); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp b/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp index b7eace180..d3281ae17 100644 --- a/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp +++ b/src/duckdb/src/optimizer/rule/equal_or_null_simplification.cpp @@ -47,7 +47,7 @@ static unique_ptr TryRewriteEqualOrIsNull(Expression 
&equal_expr, Ex auto &equal_cast = equal_expr.Cast(); auto &and_cast = and_expr.Cast(); - if (and_cast.children.size() != 2) { + if (and_cast.GetChildren().size() != 2) { return nullptr; } @@ -57,12 +57,12 @@ static unique_ptr TryRewriteEqualOrIsNull(Expression &equal_expr, Ex bool a_is_null_found = false; bool b_is_null_found = false; - for (const auto &item : and_cast.children) { + for (const auto &item : and_cast.GetChildren()) { auto &next_exp = *item; if (next_exp.GetExpressionType() == ExpressionType::OPERATOR_IS_NULL) { auto &next_exp_cast = next_exp.Cast(); - auto &child = *next_exp_cast.children[0]; + auto &child = *next_exp_cast.GetChildren()[0]; // Test for equality on both 'a' and 'b' expressions if (Expression::Equals(child, a_exp)) { @@ -94,12 +94,12 @@ unique_ptr EqualOrNullSimplification::Apply(LogicalOperator &op, vec const auto &or_exp_cast = or_exp.Cast(); - if (or_exp_cast.children.size() != 2) { + if (or_exp_cast.GetChildren().size() != 2) { return nullptr; } - auto &left_exp = *or_exp_cast.children[0]; - auto &right_exp = *or_exp_cast.children[1]; + auto &left_exp = *or_exp_cast.GetChildren()[0]; + auto &right_exp = *or_exp_cast.GetChildren()[1]; // Test for: a=b OR (a IS NULL AND b IS NULL) auto first_try = TryRewriteEqualOrIsNull(left_exp, right_exp); if (first_try) { diff --git a/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp b/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp index 6e98808ab..b3077539d 100644 --- a/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp +++ b/src/duckdb/src/optimizer/rule/in_clause_simplification_rule.cpp @@ -1,8 +1,10 @@ #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/optimizer/expression_rewriter.hpp" #include "duckdb/optimizer/rule/in_clause_simplification.hpp" #include "duckdb/planner/expression/list.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/common/string_map_set.hpp" namespace duckdb { @@ 
-16,11 +18,11 @@ InClauseSimplificationRule::InClauseSimplificationRule(ExpressionRewriter &rewri unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { auto &expr = bindings[0].get().Cast(); - if (expr.children[0]->GetExpressionClass() != ExpressionClass::BOUND_CAST) { + if (expr.GetChildrenMutable()[0]->GetExpressionClass() != ExpressionClass::BOUND_CAST) { return nullptr; } - auto &cast_expression = expr.children[0]->Cast(); - if (cast_expression.child->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { + auto &cast_expression = expr.GetChildrenMutable()[0]->Cast(); + if (cast_expression.Child().GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { return nullptr; } //! The goal here is to remove the cast from the probe expression @@ -33,12 +35,12 @@ unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, ve } vector> cast_list; //! First check if we can cast all children - for (size_t i = 1; i < expr.children.size(); i++) { - if (expr.children[i]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + for (size_t i = 1; i < expr.GetChildrenMutable().size(); i++) { + if (expr.GetChildrenMutable()[i]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { return nullptr; } - D_ASSERT(expr.children[i]->IsFoldable()); - auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.children[i]); + D_ASSERT(expr.GetChildrenMutable()[i]->IsFoldable()); + auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), *expr.GetChildrenMutable()[i]); if (!BoundCastExpression::CastIsInvertible(constant_value.type(), target_type)) { return nullptr; } @@ -51,13 +53,85 @@ unique_ptr InClauseSimplificationRule::Apply(LogicalOperator &op, ve } } //! 
We can cast, so we move the new constant - for (size_t i = 1; i < expr.children.size(); i++) { - expr.children[i] = std::move(cast_list[i - 1]); + for (size_t i = 1; i < expr.GetChildrenMutable().size(); i++) { + expr.GetChildrenMutable()[i] = std::move(cast_list[i - 1]); // expr->children[i] = std::move(new_constant_expr); } //! We can cast the full list, so we move the column - expr.children[0] = std::move(cast_expression.child); + expr.GetChildrenMutable()[0] = std::move(cast_expression.ChildMutable()); + return nullptr; +} + +InEnumSimplificationRule::InEnumSimplificationRule(ExpressionRewriter &rewriter) : Rule(rewriter) { + // match enum::VARCHAR IN (string literals) + auto op = make_uniq(); + + // Probe must be enum::VARCHAR + auto cast_matcher = make_uniq(); + cast_matcher->type = make_uniq(LogicalType::VARCHAR); + auto enum_matcher = make_uniq(); + enum_matcher->type = make_uniq(LogicalTypeId::ENUM); + cast_matcher->matcher = std::move(enum_matcher); + op->probe_matcher = std::move(cast_matcher); + + // Children must be constant strings + vector ids; + ids.emplace_back(LogicalTypeId::VARCHAR); + ids.emplace_back(LogicalTypeId::STRING_LITERAL); + ids.emplace_back(LogicalTypeId::SQLNULL); + op->child_matcher = make_uniq(ExpressionClass::BOUND_CONSTANT); + op->child_matcher->type = make_uniq(LogicalType::VARCHAR); + + root = std::move(op); +} + +unique_ptr InEnumSimplificationRule::Apply(LogicalOperator &op, vector> &bindings, + bool &changes_made, bool is_root) { + auto &expr = bindings[0].get().Cast(); + auto &children = expr.GetChildrenMutable(); + auto &cast_expr = children[0]->Cast(); + auto &enum_expr = cast_expr.Child(); + auto &enum_type = enum_expr.GetReturnType(); + + // Look up the domain for the original ENUM + string_map_t enum_domain; + auto strings = FlatVector::GetData(EnumType::GetValuesInsertOrder(enum_type)); + for (uint64_t i = 0; i < EnumType::GetSize(enum_type); ++i) { + enum_domain[strings[i]] = i; + } + + vector> in_children; + 
in_children.emplace_back(enum_expr.Copy()); + + for (idx_t i = 1; i < children.size(); ++i) { + auto &child = *children[i]; + auto &v = child.Cast().GetValue(); + if (v.IsNull()) { + // Keep NULLs + in_children.emplace_back(make_uniq(Value(enum_type))); + } else { + auto s = v.ToString(); + auto it = enum_domain.find(s); + if (it != enum_domain.end()) { + // The literal is in the ENUM domain, so translate it + in_children.emplace_back(make_uniq(Value::ENUM(it->second, enum_type))); + } + // Not in the domain, so ignore it. + } + } + + // If ALL the children are string constants being matched against an ENUM (NOT) IN, + // and SOME of them are in the domain + // then swap out the children for the valid ENUM values + if (in_children.size() > 1) { + children.swap(in_children); + } else { + // constant_or_null(not_in, child[0]) + const bool not_in = (expr.GetExpressionType() == ExpressionType::COMPARE_NOT_IN); + return rewriter.ConstantOrNull(in_children[0]->Copy(), Value::BOOLEAN(not_in)); + } + return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/join_dependent_filter.cpp b/src/duckdb/src/optimizer/rule/join_dependent_filter.cpp index 186eb1c4d..624e29b1a 100644 --- a/src/duckdb/src/optimizer/rule/join_dependent_filter.cpp +++ b/src/duckdb/src/optimizer/rule/join_dependent_filter.cpp @@ -17,7 +17,7 @@ JoinDependentFilterRule::JoinDependentFilterRule(ExpressionRewriter &rewriter) : static void GetTableIndices(const Expression &root_expr, unordered_set &table_idxs) { ExpressionIterator::VisitExpression( - root_expr, [&](const BoundColumnRefExpression &colref) { table_idxs.insert(colref.binding.table_index); }); + root_expr, [&](const BoundColumnRefExpression &colref) { table_idxs.insert(colref.Binding().table_index); }); } static inline bool ExpressionReferencesMultipleTables(const Expression &binding) { @@ -30,7 +30,7 @@ static inline void ExtractConjunctedExpressions(Expression &expression, unordered_map> &expressions) { if (expression.GetExpressionType() == 
ExpressionType::CONJUNCTION_AND) { auto &conjunction = expression.Cast(); - for (auto &child : conjunction.children) { + for (auto &child : conjunction.GetChildren()) { ExtractConjunctedExpressions(*child, expressions); } } else if (!expression.IsVolatile()) { @@ -85,7 +85,7 @@ unique_ptr JoinDependentFilterRule::Apply(LogicalOperator &op, vecto } // Must have at least one join-dependent AND expression - auto &children = conjunction.children; + const auto &children = conjunction.GetChildren(); bool eligible = false; for (const auto &child : children) { if (child->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION && @@ -109,7 +109,7 @@ unique_ptr JoinDependentFilterRule::Apply(LogicalOperator &op, vecto auto derived_filter = make_uniq(ExpressionType::CONJUNCTION_AND); for (auto &entry : conjuncted_expressions[0]) { auto derived_entry_filter = make_uniq(ExpressionType::CONJUNCTION_OR); - derived_entry_filter->children.push_back(entry.second->Copy()); + derived_entry_filter->GetChildrenMutable().push_back(entry.second->Copy()); bool found = true; for (idx_t conj_idx = 1; conj_idx < children.size(); conj_idx++) { @@ -119,29 +119,29 @@ unique_ptr JoinDependentFilterRule::Apply(LogicalOperator &op, vecto found = false; break; // Expression does not appear in every conjuncted expression, cannot derive any restriction } - derived_entry_filter->children.push_back(other_it->second->Copy()); + derived_entry_filter->GetChildrenMutable().push_back(other_it->second->Copy()); } if (!found) { continue; // Expression must show up in every entry } - derived_filter->children.push_back(std::move(derived_entry_filter)); + derived_filter->GetChildrenMutable().push_back(std::move(derived_entry_filter)); } - if (derived_filter->children.empty()) { + if (derived_filter->GetChildrenMutable().empty()) { return nullptr; // Could not derive filters that can be pushed down } // Add the derived expression to the original expression with an AND auto result = 
make_uniq_base(ExpressionType::CONJUNCTION_AND); auto &result_conjunction = result->Cast(); - result_conjunction.children.push_back(conjunction.Copy()); + result_conjunction.GetChildrenMutable().push_back(conjunction.Copy()); - if (derived_filter->children.size() == 1) { - result_conjunction.children.push_back(std::move(derived_filter->children[0])); + if (derived_filter->GetChildrenMutable().size() == 1) { + result_conjunction.GetChildrenMutable().push_back(std::move(derived_filter->GetChildrenMutable()[0])); } else { - result_conjunction.children.push_back(std::move(derived_filter)); + result_conjunction.GetChildrenMutable().push_back(std::move(derived_filter)); } return result; } diff --git a/src/duckdb/src/optimizer/rule/like_optimizations.cpp b/src/duckdb/src/optimizer/rule/like_optimizations.cpp index cf3581f49..da0504fe6 100644 --- a/src/duckdb/src/optimizer/rule/like_optimizations.cpp +++ b/src/duckdb/src/optimizer/rule/like_optimizations.cpp @@ -17,7 +17,7 @@ LikeOptimizationRule::LikeOptimizationRule(ExpressionRewriter &rewriter) : Rule( func->matchers.push_back(make_uniq()); func->policy = SetMatcher::Policy::ORDERED; // we match on LIKE ("~~") and NOT LIKE ("!~~") - func->function = make_uniq(unordered_set {"!~~", "~~"}); + func->function = make_uniq(identifier_set_t {Identifier("!~~"), Identifier("~~")}); root = std::move(func); } @@ -105,9 +105,9 @@ unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector(); auto &constant_expr = bindings[2].get().Cast(); - D_ASSERT(root.children.size() == 2); + D_ASSERT(root.GetChildren().size() == 2); - if (constant_expr.value.IsNull()) { + if (constant_expr.GetValue().IsNull()) { return make_uniq(Value(root.GetReturnType())); } @@ -120,12 +120,12 @@ unique_ptr LikeOptimizationRule::Apply(LogicalOperator &op, vector LikeOptimizationRule::Apply(LogicalOperator &op, vector LikeOptimizationRule::ApplyRule(BoundFunctionExpression &expr, const ScalarFunction &function, string pattern, bool is_not_like) const { 
// replace LIKE by an optimized function - auto result = function.Bind(GetContext(), std::move(expr.children)); + auto result = function.Bind(GetContext(), std::move(expr.GetChildrenMutable())); // removing "%" from the pattern pattern.erase(std::remove(pattern.begin(), pattern.end(), '%'), pattern.end()); - result->children[1] = make_uniq(Value(std::move(pattern))); + result->GetChildrenMutable()[1] = make_uniq(Value(std::move(pattern))); if (!is_not_like) { return std::move(result); } auto negation = make_uniq(ExpressionType::OPERATOR_NOT, LogicalType::BOOLEAN); - negation->children.push_back(std::move(result)); + negation->GetChildrenMutable().push_back(std::move(result)); return std::move(negation); } diff --git a/src/duckdb/src/optimizer/rule/list_comprehension_rewrite.cpp b/src/duckdb/src/optimizer/rule/list_comprehension_rewrite.cpp index 3147e4148..e0d1a2286 100644 --- a/src/duckdb/src/optimizer/rule/list_comprehension_rewrite.cpp +++ b/src/duckdb/src/optimizer/rule/list_comprehension_rewrite.cpp @@ -18,15 +18,15 @@ namespace duckdb { namespace { bool IsTargetListFunction(ClientContext &context, const BoundFunctionExpression &expr, const string &target_name) { - if (expr.function.GetName() == target_name) { - D_ASSERT(!expr.children.empty() && expr.children[0]->GetReturnType().id() == LogicalTypeId::LIST); + if (expr.Function().GetName() == target_name) { + D_ASSERT(!expr.GetChildren().empty() && expr.GetChildren()[0]->GetReturnType().id() == LogicalTypeId::LIST); return true; } // Compare function name with catalog to recognize aliases auto &catalog = Catalog::GetSystemCatalog(context); - auto entry = catalog.GetEntry(context, DEFAULT_SCHEMA, expr.function.GetName(), - OnEntryNotFound::RETURN_NULL); + auto entry = catalog.GetEntry(context, Identifier::DefaultSchema(), + expr.Function().GetName(), OnEntryNotFound::RETURN_NULL); if (!entry) { return false; } @@ -37,20 +37,20 @@ bool IsTargetListFunction(ClientContext &context, const BoundFunctionExpression } 
bool IsStructPack(const BoundFunctionExpression &expr) { - return expr.function.GetName() == "struct_pack"; + return expr.Function().GetName() == "struct_pack"; } optional_ptr UnwrapCasts(optional_ptr expr) { while (expr->GetExpressionClass() == ExpressionClass::BOUND_CAST) { auto &cast_expr = expr->Cast(); - expr = cast_expr.child.get(); + expr = cast_expr.ChildMutable().get(); } return expr; } optional_ptr FindStructPackChildByName(BoundFunctionExpression &struct_pack, const string &name) { D_ASSERT(struct_pack.GetReturnType().id() == LogicalTypeId::STRUCT); - for (auto &child : struct_pack.children) { + for (auto &child : struct_pack.GetChildren()) { if (child->GetAlias() == name) { return child.get(); } @@ -61,7 +61,7 @@ optional_ptr FindStructPackChildByName(BoundFunctionExpression &stru bool UsesIndexParameter(Expression &expr) { if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { auto &ref = expr.Cast(); - if (ref.index == 0) { + if (ref.Index() == 0) { return true; } } @@ -73,7 +73,7 @@ bool UsesIndexParameter(Expression &expr) { } if (child.GetExpressionClass() == ExpressionClass::BOUND_REF) { auto &ref = child.Cast(); - if (ref.index == 0) { + if (ref.Index() == 0) { uses_index = true; return; } @@ -87,8 +87,8 @@ void RemoveIndexInputSlot(unique_ptr &expr) { ExpressionIterator::VisitExpressionClassMutable(expr, ExpressionClass::BOUND_REF, [&](unique_ptr &child) { auto &ref = child->Cast(); - D_ASSERT(ref.index > 0); - ref.index--; + D_ASSERT(ref.Index() > 0); + ref.IndexMutable()--; }); } @@ -98,29 +98,29 @@ bool MatchesStructFieldProjection(Expression &expr, const string &field_name) { return false; } auto &extract_expr = base->Cast(); - if (extract_expr.function.GetName() != "struct_extract" || extract_expr.children.size() != 2) { + if (extract_expr.Function().GetName() != "struct_extract" || extract_expr.GetChildren().size() != 2) { return false; } - auto struct_arg = UnwrapCasts(*extract_expr.children[0]); + auto struct_arg = 
UnwrapCasts(*extract_expr.GetChildren()[0]); if (struct_arg->GetExpressionClass() != ExpressionClass::BOUND_REF) { return false; } auto &struct_ref = struct_arg->Cast(); - if (struct_ref.index != 0) { + if (struct_ref.Index() != 0) { return false; } - auto field_arg = UnwrapCasts(*extract_expr.children[1]); + auto field_arg = UnwrapCasts(*extract_expr.GetChildren()[1]); if (field_arg->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { return false; } auto &field_constant = field_arg->Cast(); - if (field_constant.value.IsNull() || field_constant.value.type() != LogicalTypeId::VARCHAR) { + if (field_constant.GetValue().IsNull() || field_constant.GetValue().type() != LogicalTypeId::VARCHAR) { return false; } - return StringValue::Get(field_constant.value) == field_name; + return StringValue::Get(field_constant.GetValue()) == field_name; } struct ListComprehensionMatch { @@ -135,9 +135,9 @@ struct ListComprehensionMatch { vector> CopyCapturedChildren(BoundFunctionExpression &inner_apply) { vector> captured_children; - captured_children.reserve(inner_apply.children.empty() ? 0 : inner_apply.children.size() - 1); - for (idx_t i = 1; i < inner_apply.children.size(); i++) { - captured_children.push_back(inner_apply.children[i]->Copy()); + captured_children.reserve(inner_apply.GetChildren().empty() ? 
0 : inner_apply.GetChildren().size() - 1); + for (idx_t i = 1; i < inner_apply.GetChildren().size(); i++) { + captured_children.push_back(inner_apply.GetChildren()[i]->Copy()); } return captured_children; } @@ -148,21 +148,21 @@ unique_ptr MatchListComprehensionRewrite(ClientContext & if (!IsTargetListFunction(context, root, "list_transform")) { return nullptr; } - if (root.children.empty()) { + if (root.GetChildren().empty()) { return nullptr; } - if (root.children[0]->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + if (root.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { return nullptr; } - auto &list_filter_expr = root.children[0]->Cast(); + auto &list_filter_expr = root.GetChildren()[0]->Cast(); if (!IsTargetListFunction(context, list_filter_expr, "list_filter")) { return nullptr; } - if (!list_filter_expr.bind_info || !root.bind_info) { + if (!list_filter_expr.BindInfo() || !root.BindInfo()) { return nullptr; } - auto &list_filter_bind = list_filter_expr.bind_info->Cast(); - auto &root_bind = root.bind_info->Cast(); + auto &list_filter_bind = list_filter_expr.BindInfo()->Cast(); + auto &root_bind = root.BindInfo()->Cast(); if (!list_filter_bind.lambda_expr || !root_bind.lambda_expr) { return nullptr; } @@ -170,17 +170,17 @@ unique_ptr MatchListComprehensionRewrite(ClientContext & !MatchesStructFieldProjection(*root_bind.lambda_expr, "result")) { return nullptr; } - if (list_filter_expr.children.empty()) { + if (list_filter_expr.GetChildren().empty()) { return nullptr; } - if (list_filter_expr.children[0]->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + if (list_filter_expr.GetChildren()[0]->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { return nullptr; } - auto &inner_apply = list_filter_expr.children[0]->Cast(); - if (!IsTargetListFunction(context, inner_apply, "list_transform") || !inner_apply.bind_info) { + auto &inner_apply = list_filter_expr.GetChildren()[0]->Cast(); + if 
(!IsTargetListFunction(context, inner_apply, "list_transform") || !inner_apply.BindInfo()) { return nullptr; } - auto &inner_bind = inner_apply.bind_info->Cast(); + auto &inner_bind = inner_apply.BindInfo()->Cast(); if (!inner_bind.lambda_expr) { return nullptr; } @@ -215,7 +215,7 @@ unique_ptr MatchListComprehensionRewrite(ClientContext & return match; } -unique_ptr BuildListComprehensionRewrite(ListComprehensionMatch &match) { +unique_ptr BuildListComprehensionRewrite(ClientContext &context, ListComprehensionMatch &match) { auto &inner_apply = *match.inner_apply; auto &list_filter_expr = *match.list_filter_expr; auto &root = *match.root; @@ -225,23 +225,28 @@ unique_ptr BuildListComprehensionRewrite(ListComprehensionMatch &mat // Build list_filter(list, lambda filter_expr) vector> filter_children; - filter_children.reserve(inner_apply.children.size()); - for (idx_t i = 0; i < inner_apply.children.size(); i++) { - filter_children.push_back(std::move(inner_apply.children[i])); + filter_children.reserve(inner_apply.GetChildren().size()); + for (idx_t i = 0; i < inner_apply.GetChildren().size(); i++) { + filter_children.push_back(std::move(inner_apply.GetChildrenMutable()[i])); } auto filter_return_type = filter_children[0]->GetReturnType(); - auto filter_bind_info = make_uniq(filter_return_type, filter_expr.Copy(), inner_bind.has_index); + auto filter_lambda = filter_expr.Copy(); + if (filter_lambda->GetReturnType() != LogicalType::BOOLEAN) { + filter_lambda = BoundCastExpression::AddCastToType(context, std::move(filter_lambda), LogicalType::BOOLEAN); + } + auto filter_bind_info = + make_uniq(filter_return_type, std::move(filter_lambda), inner_bind.has_index); - auto new_func = list_filter_expr.function; + auto new_func = list_filter_expr.Function(); new_func.SetReturnType(filter_return_type); auto new_filter = make_uniq(std::move(new_func), std::move(filter_children), - std::move(filter_bind_info), list_filter_expr.is_operator); + std::move(filter_bind_info), 
list_filter_expr.IsOperator()); // Build list_apply(list_filter(...), lambda result_expr) vector> apply_children; - apply_children.reserve(inner_apply.children.size()); + apply_children.reserve(inner_apply.GetChildren().size()); apply_children.push_back(std::move(new_filter)); for (auto &captured_child : match.captured_children) { apply_children.push_back(std::move(captured_child)); @@ -254,8 +259,8 @@ unique_ptr BuildListComprehensionRewrite(ListComprehensionMatch &mat RemoveIndexInputSlot(apply_lambda); } auto apply_bind_info = make_uniq(apply_return_type, std::move(apply_lambda)); - return make_uniq(root.function, std::move(apply_children), std::move(apply_bind_info), - root.is_operator); + return make_uniq(root.Function(), std::move(apply_children), std::move(apply_bind_info), + root.IsOperator()); } } // namespace @@ -271,7 +276,7 @@ unique_ptr ListComprehensionRewriteRule::Apply(LogicalOperator &, ve if (!match) { return nullptr; } - return BuildListComprehensionRewrite(*match); + return BuildListComprehensionRewrite(GetContext(), *match); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/rule/move_constants.cpp b/src/duckdb/src/optimizer/rule/move_constants.cpp index 575ef13af..e97356629 100644 --- a/src/duckdb/src/optimizer/rule/move_constants.cpp +++ b/src/duckdb/src/optimizer/rule/move_constants.cpp @@ -18,7 +18,8 @@ MoveConstantsRule::MoveConstantsRule(ExpressionRewriter &rewriter) : Rule(rewrit // we handle multiplication, addition and subtraction because those are "easy" // integer division makes the division case difficult // e.g. 
[x / 2 = 3] means [x = 6 OR x = 7] because of truncation -> no clean rewrite rules - arithmetic->function = make_uniq(unordered_set {"+", "-", "*"}); + arithmetic->function = + make_uniq(identifier_set_t {Identifier("+"), Identifier("-"), Identifier("*")}); // we match only on integral numeric types arithmetic->type = make_uniq(); auto child_constant_matcher = make_uniq(); @@ -39,8 +40,8 @@ unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector(); auto &inner_constant = bindings[3].get().Cast(); D_ASSERT(arithmetic.GetReturnType().IsIntegral()); - D_ASSERT(arithmetic.children[0]->GetReturnType().IsIntegral()); - if (inner_constant.value.IsNull() || outer_constant.value.IsNull()) { + D_ASSERT(arithmetic.GetChildren()[0]->GetReturnType().IsIntegral()); + if (inner_constant.GetValue().IsNull() || outer_constant.GetValue().IsNull()) { if (comparison.GetExpressionType() == ExpressionType::COMPARE_DISTINCT_FROM || comparison.GetExpressionType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { return nullptr; @@ -48,11 +49,11 @@ unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector(Value(comparison.GetReturnType())); } auto &constant_type = outer_constant.GetReturnType(); - hugeint_t outer_value = IntegralValue::Get(outer_constant.value); - hugeint_t inner_value = IntegralValue::Get(inner_constant.value); + hugeint_t outer_value = IntegralValue::Get(outer_constant.GetValue()); + hugeint_t inner_value = IntegralValue::Get(inner_constant.GetValue()); - idx_t arithmetic_child_index = arithmetic.children[0].get() == &inner_constant ? 1 : 0; - auto &op_type = arithmetic.function.GetName(); + idx_t arithmetic_child_index = arithmetic.GetChildrenMutable()[0].get() == &inner_constant ? 
1 : 0; + auto &op_type = arithmetic.Function().GetName(); if (op_type == "+") { // [x + 1 COMP 10] OR [1 + x COMP 10] // order does not matter in addition: @@ -68,10 +69,10 @@ unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector MoveConstantsRule::Apply(LogicalOperator &op, vector MoveConstantsRule::Apply(LogicalOperator &op, vector 2] BoundComparisonExpression::FlipType(comparison); @@ -132,8 +133,8 @@ unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector >= < <=, skip the simplification for now return nullptr; @@ -148,14 +149,14 @@ unique_ptr MoveConstantsRule::Apply(LogicalOperator &op, vector OrderedAggregateOptimizer::Apply(ClientContext &context, vector> &groups, optional_ptr> grouping_sets, bool &changes_made) { - if (!aggr.order_bys) { + if (!aggr.GetOrderBys()) { // no ORDER BYs defined return nullptr; } - if (aggr.function.GetOrderDependent() == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { + if (aggr.Function().GetOrderDependent() == AggregateOrderDependent::NOT_ORDER_DEPENDENT) { // not an order dependent aggregate but we have an ORDER BY clause - remove it - aggr.order_bys.reset(); + aggr.GetOrderBysMutable().reset(); changes_made = true; return nullptr; } // Remove unnecessary ORDER BY clauses and return if nothing remains - if (aggr.order_bys->Simplify(groups, grouping_sets)) { - aggr.order_bys.reset(); + if (aggr.GetOrderBysMutable()->Simplify(groups, grouping_sets)) { + aggr.GetOrderBysMutable().reset(); changes_made = true; return nullptr; } // Rewrite first/last/arbitrary/any_value to use arg_xxx[_null] and create_sort_key - const auto &aggr_name = aggr.function.GetName(); + const auto &aggr_name = aggr.Function().GetName(); string arg_xxx_name; if (aggr_name == "last") { arg_xxx_name = "arg_max_null"; @@ -53,25 +53,26 @@ unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, FunctionBinder binder(context); vector> sort_children; - for (auto &order : aggr.order_bys->orders) { + for (auto &order : 
aggr.GetOrderBysMutable()->orders) { sort_children.emplace_back(std::move(order.expression)); sort_children.emplace_back(make_uniq(Value(order.GetOrderModifier()))); } - aggr.order_bys.reset(); + aggr.GetOrderBysMutable().reset(); ErrorData error; - auto sort_key = binder.BindScalarFunction(DEFAULT_SCHEMA, "create_sort_key", std::move(sort_children), error); + auto sort_key = + binder.BindScalarFunction(Identifier::DefaultSchema(), "create_sort_key", std::move(sort_children), error); if (!sort_key) { error.Throw(); } - auto &children = aggr.children; + auto &children = aggr.GetChildrenMutable(); children.emplace_back(std::move(sort_key)); // Look up the arg_xxx_name function in the catalog QueryErrorContext error_context; - auto &func = Catalog::GetEntry(context, SYSTEM_CATALOG, DEFAULT_SCHEMA, arg_xxx_name, - error_context); + auto &func = Catalog::GetEntry( + context, Identifier::SystemCatalog(), Identifier::DefaultSchema(), Identifier(arg_xxx_name), error_context); D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); // bind the aggregate @@ -85,7 +86,7 @@ unique_ptr OrderedAggregateOptimizer::Apply(ClientContext &context, } // found a matching function! const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); - return binder.BindAggregateFunction(bound_function, std::move(children), std::move(aggr.filter), + return binder.BindAggregateFunction(bound_function, std::move(children), std::move(aggr.GetFilterMutable()), aggr.IsDistinct() ? 
AggregateType::DISTINCT : AggregateType::NON_DISTINCT); } @@ -97,6 +98,10 @@ unique_ptr OrderedAggregateOptimizer::Apply(LogicalOperator &op, vec if (op.type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { return nullptr; } + // don't rewrite state-export aggregates - the rewrite would lose the STATE_EXPORT mode + if (aggr.StateExportMode() == AggregateStateExportMode::STATE_EXPORT) { + return nullptr; + } return Apply(rewriter.context, aggr, op.Cast().groups, op.Cast().grouping_sets, changes_made); diff --git a/src/duckdb/src/optimizer/rule/predicate_factoring.cpp b/src/duckdb/src/optimizer/rule/predicate_factoring.cpp index 991cf8dc5..35ce8b9bd 100644 --- a/src/duckdb/src/optimizer/rule/predicate_factoring.cpp +++ b/src/duckdb/src/optimizer/rule/predicate_factoring.cpp @@ -23,7 +23,7 @@ static bool ExpressionIsDisjunction(const Expression &expr) { static void ExtractDisjunctedPredicates(Expression &expression, vector> &disjuncted_children) { if (ExpressionIsDisjunction(expression)) { auto &disjunction = expression.Cast(); - for (auto &child : disjunction.children) { + for (auto &child : disjunction.GetChildren()) { ExtractDisjunctedPredicates(*child, disjuncted_children); } } else { @@ -45,8 +45,8 @@ static bool GetSingleColumnBinding(const Expression &expr, ColumnBinding &column bool found_multiple = false; ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { if (!ColumnBindingIsvalid(column_binding)) { - column_binding = colref.binding; - } else if (column_binding != colref.binding) { + column_binding = colref.Binding(); + } else if (column_binding != colref.Binding()) { found_multiple = true; } }); @@ -60,7 +60,7 @@ static void ExtractDisjunctedPredicates(Expression &expression, static void ExtractConjunctedPredicates(BoundConjunctionExpression &conjunction, column_binding_map_t>> &binding_map) { D_ASSERT(conjunction.GetExpressionType() == ExpressionType::CONJUNCTION_AND); - for (auto &conjunction_child : 
conjunction.children) { + for (auto &conjunction_child : conjunction.GetChildren()) { if (conjunction_child->GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION && conjunction_child->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { ExtractConjunctedPredicates(conjunction_child->Cast(), binding_map); @@ -161,8 +161,8 @@ unique_ptr PredicateFactoringRule::Apply(LogicalOperator &op, vector column_filter = expr.get().Copy(); } else { auto new_disjunction = make_uniq(ExpressionType::CONJUNCTION_OR); - new_disjunction->children.push_back(std::move(column_filter)); - new_disjunction->children.push_back(expr.get().Copy()); + new_disjunction->GetChildrenMutable().push_back(std::move(column_filter)); + new_disjunction->GetChildrenMutable().push_back(expr.get().Copy()); column_filter = std::move(new_disjunction); } } @@ -172,16 +172,16 @@ unique_ptr PredicateFactoringRule::Apply(LogicalOperator &op, vector derived_filter = std::move(column_filter); } else { auto new_conjunction = make_uniq(ExpressionType::CONJUNCTION_AND); - new_conjunction->children.push_back(std::move(column_filter)); - new_conjunction->children.push_back(std::move(derived_filter)); + new_conjunction->GetChildrenMutable().push_back(std::move(column_filter)); + new_conjunction->GetChildrenMutable().push_back(std::move(derived_filter)); derived_filter = std::move(new_conjunction); } } // Now add the derived filter as an AND to the original expression auto result = make_uniq(ExpressionType::CONJUNCTION_AND); - result->children.push_back(bindings[0].get().Copy()); - result->children.push_back(std::move(derived_filter)); + result->GetChildrenMutable().push_back(bindings[0].get().Copy()); + result->GetChildrenMutable().push_back(std::move(derived_filter)); return std::move(result); } diff --git a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp index bf7e16bac..3b6d02ca4 100644 --- a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp 
+++ b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp @@ -153,15 +153,15 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); auto &constant_expr = bindings[2].get().Cast(); - D_ASSERT(root.children.size() == 2 || root.children.size() == 3); - auto regexp_bind_data = root.bind_info.get()->Cast(); + D_ASSERT(root.GetChildrenMutable().size() == 2 || root.GetChildrenMutable().size() == 3); + auto regexp_bind_data = root.BindInfo().get()->Cast(); auto constant_value = ExpressionExecutor::EvaluateScalar(GetContext(), constant_expr); D_ASSERT(constant_value.type() == constant_expr.GetReturnType()); duckdb_re2::RE2::Options parsed_options = regexp_bind_data.options; - if (constant_expr.value.IsNull()) { + if (constant_expr.GetValue().IsNull()) { return make_uniq(Value(root.GetReturnType())); } auto patt_str = StringValue::Get(constant_value); @@ -188,15 +188,15 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< } // if regexp had options, remove them so the new Contains Expression can be matched for other optimizers. 
- if (root.children.size() == 3) { - root.children.pop_back(); - D_ASSERT(root.children.size() == 2); + if (root.GetChildrenMutable().size() == 3) { + root.GetChildrenMutable().pop_back(); + D_ASSERT(root.GetChildrenMutable().size() == 2); } auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); - auto contains = GetStringContains().Bind(GetContext(), std::move(root.children)); + auto contains = GetStringContains().Bind(GetContext(), std::move(root.GetChildrenMutable())); - contains->children[1] = std::move(parameter); + contains->GetChildrenMutable()[1] = std::move(parameter); return std::move(contains); } else if (pattern.Regexp()->op() == duckdb_re2::kRegexpConcat) { @@ -210,18 +210,18 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< } // if regexp had options, remove them so the new Like Expression can be matched for other optimizers. - if (root.children.size() == 3) { - root.children.pop_back(); - D_ASSERT(root.children.size() == 2); + if (root.GetChildrenMutable().size() == 3) { + root.GetChildrenMutable().pop_back(); + D_ASSERT(root.GetChildrenMutable().size() == 2); } - auto like_expression = LikeFun::GetFunction().Bind(GetContext(), std::move(root.children)); + auto like_expression = LikeFun::GetFunction().Bind(GetContext(), std::move(root.GetChildrenMutable())); // Clear the bind info, as the LikeFun bind info is not valid for this new expression. 
- like_expression->bind_info.reset(); + like_expression->BindInfoMutable().reset(); auto parameter = make_uniq(Value(std::move(like_string.like_string))); - like_expression->children[1] = std::move(parameter); + like_expression->GetChildrenMutable()[1] = std::move(parameter); return std::move(like_expression); } @@ -250,10 +250,10 @@ static bool RegexpHasWholeTextAnchors(duckdb_re2::Regexp *regexp) { unique_ptr RegexpReplaceExtractRule::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { auto &root = bindings[0].get().Cast(); - if (!root.bind_info) { + if (!root.BindInfo()) { return nullptr; } - auto &bind_data = root.bind_info->Cast(); + auto &bind_data = root.BindInfo()->Cast(); if (!bind_data.constant_pattern || bind_data.global_replace) { return nullptr; } @@ -263,10 +263,10 @@ unique_ptr RegexpReplaceExtractRule::Apply(LogicalOperator &op, vect } const auto &replace_const = bindings[3].get().Cast(); - if (replace_const.value.IsNull() || replace_const.value.type().id() != LogicalTypeId::VARCHAR) { + if (replace_const.GetValue().IsNull() || replace_const.GetValue().type().id() != LogicalTypeId::VARCHAR) { return nullptr; } - const auto &replace_str = StringValue::Get(replace_const.value); + const auto &replace_str = StringValue::Get(replace_const.GetValue()); int32_t group_index; // Only a bare backreference can be represented by regexp_extract. Replacement literals or // composites such as `x\1`, `\1x`, and `\1\2` must stay on regexp_replace. 
@@ -282,23 +282,24 @@ unique_ptr RegexpReplaceExtractRule::Apply(LogicalOperator &op, vect } string options_str = "k"; - if (root.children.size() == 4) { - auto &options_const = root.children[3]->Cast(); - if (options_const.value.IsNull() || options_const.value.type().id() != LogicalTypeId::VARCHAR) { + if (root.GetChildrenMutable().size() == 4) { + auto &options_const = root.GetChildrenMutable()[3]->Cast(); + if (options_const.GetValue().IsNull() || options_const.GetValue().type().id() != LogicalTypeId::VARCHAR) { return nullptr; } - options_str = StringValue::Get(options_const.value) + options_str; + options_str = StringValue::Get(options_const.GetValue()) + options_str; } vector> extract_children; - extract_children.emplace_back(std::move(root.children[0])); + extract_children.emplace_back(std::move(root.GetChildrenMutable()[0])); extract_children.emplace_back(make_uniq(Value(pattern))); extract_children.emplace_back(make_uniq(Value::INTEGER(group_index))); extract_children.emplace_back(make_uniq(Value(std::move(options_str)))); FunctionBinder binder(rewriter.context); ErrorData error; - auto extract_expr = binder.BindScalarFunction(DEFAULT_SCHEMA, "regexp_extract", std::move(extract_children), error); + auto extract_expr = + binder.BindScalarFunction(Identifier::DefaultSchema(), "regexp_extract", std::move(extract_children), error); if (!extract_expr) { return nullptr; } diff --git a/src/duckdb/src/optimizer/rule/timestamp_comparison.cpp b/src/duckdb/src/optimizer/rule/timestamp_comparison.cpp index 32f886e56..ee9eda67e 100644 --- a/src/duckdb/src/optimizer/rule/timestamp_comparison.cpp +++ b/src/duckdb/src/optimizer/rule/timestamp_comparison.cpp @@ -17,10 +17,9 @@ namespace duckdb { TimeStampComparison::TimeStampComparison(ExpressionRewriter &rewriter) : Rule(rewriter), context(rewriter.context) { - // match on a ComparisonExpression that is an Equality and has a VARCHAR and ENUM as its children + // match on an equality comparison between a timestamp column 
cast to DATE and a foldable DATE constant auto op = make_uniq(); op->policy = SetMatcher::Policy::UNORDERED; - // Enum requires expression to be root op->expr_type = make_uniq(ExpressionType::COMPARE_EQUAL); // one side is timestamp cast to date @@ -31,34 +30,49 @@ TimeStampComparison::TimeStampComparison(ExpressionRewriter &rewriter) : Rule(re left->matcher->type = make_uniq(LogicalTypeId::TIMESTAMP); op->matchers.push_back(std::move(left)); - // other side is varchar to date? - auto right = make_uniq(); + // other side is any foldable DATE constant expression; bottom-up rewriting can fold CAST(VARCHAR AS DATE) + auto right = make_uniq(); right->type = make_uniq(LogicalTypeId::DATE); - right->matcher = make_uniq(); - right->matcher->expr_class = ExpressionClass::BOUND_CONSTANT; - right->matcher->type = make_uniq(LogicalTypeId::VARCHAR); op->matchers.push_back(std::move(right)); root = std::move(op); } -static void ExpressionIsConstant(const Expression &root_expr, bool &is_constant) { - ExpressionIterator::VisitExpression( - root_expr, [&](const BoundColumnRefExpression &column_ref) { is_constant = false; }); +static BoundCastExpression *GetTimestampCast(Expression &expr) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_CAST) { + return nullptr; + } + auto &cast_expr = expr.Cast(); + if (cast_expr.GetReturnType().id() != LogicalTypeId::DATE) { + return nullptr; + } + if (cast_expr.Child().GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF || + cast_expr.Child().GetReturnType().id() != LogicalTypeId::TIMESTAMP) { + return nullptr; + } + return &cast_expr; } unique_ptr TimeStampComparison::Apply(LogicalOperator &op, vector> &bindings, bool &changes_made, bool is_root) { - auto cast_constant = bindings[3].get().Copy(); - auto cast_columnref = bindings[2].get().Copy(); - auto is_constant = true; - ExpressionIsConstant(*cast_constant, is_constant); - if (!is_constant) { - // means the matchers are flipped, so we need to flip our bindings - // for some 
reason an extra binding is added in this case. - cast_constant = bindings[4].get().Copy(); - cast_columnref = bindings[3].get().Copy(); + auto &comparison = bindings[0].get().Cast(); + D_ASSERT(comparison.GetChildren().size() == 2); + + BoundCastExpression *cast_expr = nullptr; + Expression *constant_expr = nullptr; + for (auto &child : comparison.GetChildren()) { + if (auto timestamp_cast = GetTimestampCast(*child)) { + cast_expr = timestamp_cast; + } else if (child->IsFoldable() && child->GetReturnType().id() == LogicalTypeId::DATE) { + constant_expr = child.get(); + } + } + if (!cast_expr || !constant_expr) { + return nullptr; } + + auto cast_constant = constant_expr->Copy(); + auto cast_columnref = cast_expr->Child().Copy(); auto new_expr = make_uniq(ExpressionType::CONJUNCTION_AND); Value result; @@ -94,8 +108,8 @@ unique_ptr TimeStampComparison::Apply(LogicalOperator &op, vectorchildren.push_back(std::move(gt_eq_expr)); - new_expr->children.push_back(std::move(lt_eq_expr)); + new_expr->GetChildrenMutable().push_back(std::move(gt_eq_expr)); + new_expr->GetChildrenMutable().push_back(std::move(lt_eq_expr)); return std::move(new_expr); } return nullptr; diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp index 052046801..4416dbcc1 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_aggregate.cpp @@ -6,8 +6,8 @@ namespace duckdb { unique_ptr StatisticsPropagator::PropagateExpression(BoundAggregateExpression &aggr, unique_ptr &expr_ptr) { vector stats; - stats.reserve(aggr.children.size()); - for (auto &child : aggr.children) { + stats.reserve(aggr.GetChildren().size()); + for (auto &child : aggr.GetChildrenMutable()) { auto stat = PropagateExpression(child); if (!stat) { stats.push_back(BaseStatistics::CreateUnknown(child->GetReturnType())); @@ -15,11 +15,11 @@ 
unique_ptr StatisticsPropagator::PropagateExpression(BoundAggreg stats.push_back(stat->Copy()); } } - if (!aggr.function.GetCallbacks().HasStatisticsCallback()) { + if (!aggr.Function().GetCallbacks().HasStatisticsCallback()) { return nullptr; } - AggregateStatisticsInput input(aggr.bind_info.get(), stats, node_stats.get()); - return aggr.function.GetCallbacks().GetStatisticsCallback()(context, aggr, input); + AggregateStatisticsInput input(aggr.BindInfo(), stats, node_stats.get()); + return aggr.Function().GetCallbacks().GetStatisticsCallback()(context, aggr, input); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp index d74a05755..906cd7bd9 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_case.cpp @@ -6,8 +6,8 @@ namespace duckdb { unique_ptr StatisticsPropagator::PropagateExpression(BoundCaseExpression &bound_case, unique_ptr &expr_ptr) { // propagate in all the children - auto result_stats = PropagateExpression(bound_case.else_expr); - for (auto &case_check : bound_case.case_checks) { + auto result_stats = PropagateExpression(bound_case.ElseMutable()); + for (auto &case_check : bound_case.CaseChecksMutable()) { PropagateExpression(case_check.when_expr); auto then_stats = PropagateExpression(case_check.then_expr); if (!then_stats) { diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp index 8be2c8a6a..6d972f4c5 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp @@ -232,12 +232,12 @@ unique_ptr StatisticsPropagator::TryPropagateCast(const BaseStat unique_ptr StatisticsPropagator::PropagateExpression(BoundCastExpression &cast, unique_ptr &expr_ptr) { - auto 
child_stats = PropagateExpression(cast.child); + auto child_stats = PropagateExpression(cast.ChildMutable()); if (!child_stats) { return nullptr; } - auto result_stats = TryPropagateCast(*child_stats, cast.child->GetReturnType(), cast.GetReturnType()); - if (cast.try_cast && result_stats) { + auto result_stats = TryPropagateCast(*child_stats, cast.Child().GetReturnType(), cast.GetReturnType()); + if (cast.IsTryCast() && result_stats) { result_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); } return result_stats; diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp index 94ed2b0ab..23c07c189 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_columnref.cpp @@ -5,7 +5,7 @@ namespace duckdb { unique_ptr StatisticsPropagator::PropagateExpression(BoundColumnRefExpression &colref, unique_ptr &expr_ptr) { - auto stats = statistics_map.find(colref.binding); + auto stats = statistics_map.find(colref.Binding()); if (stats == statistics_map.end()) { return nullptr; } diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp index 31e415f19..557f1351a 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_conjunction.cpp @@ -10,8 +10,8 @@ namespace duckdb { unique_ptr StatisticsPropagator::PropagateExpression(BoundConjunctionExpression &expr, unique_ptr &expr_ptr) { auto is_and = expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND; - for (idx_t expr_idx = 0; expr_idx < expr.children.size(); expr_idx++) { - auto &child = expr.children[expr_idx]; + for (idx_t expr_idx = 0; expr_idx < expr.GetChildrenMutable().size(); expr_idx++) { + auto &child = expr.GetChildrenMutable()[expr_idx]; 
auto stats = PropagateExpression(child); if (!child->IsFoldable()) { continue; @@ -46,20 +46,20 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundConjun } } if (prune_child) { - expr.children.erase_at(expr_idx); + expr.GetChildrenMutable().erase_at(expr_idx); expr_idx--; continue; } expr_ptr = make_uniq(Value::BOOLEAN(constant_value)); return PropagateExpression(expr_ptr); } - if (expr.children.empty()) { + if (expr.GetChildrenMutable().empty()) { // if there are no children left, replace the conjunction with TRUE (for AND) or FALSE (for OR) expr_ptr = make_uniq(Value::BOOLEAN(is_and)); return PropagateExpression(expr_ptr); - } else if (expr.children.size() == 1) { + } else if (expr.GetChildrenMutable().size() == 1) { // if there is one child left, replace the conjunction with that one child - expr_ptr = std::move(expr.children[0]); + expr_ptr = std::move(expr.GetChildrenMutable()[0]); } return nullptr; } diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp index a87328107..4fe004cff 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_constant.cpp @@ -12,7 +12,7 @@ unique_ptr StatisticsPropagator::StatisticsFromValue(const Value unique_ptr StatisticsPropagator::PropagateExpression(BoundConstantExpression &constant, unique_ptr &expr_ptr) { - return StatisticsFromValue(constant.value); + return StatisticsFromValue(constant.GetValue()); } } // namespace duckdb diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp index 1d4b21f6d..a7b74cdf0 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_function.cpp @@ -16,8 +16,8 @@ static bool TryEvaluateAtConstants(ClientContext 
&context, const BoundFunctionEx for (auto &v : arg_values) { children.push_back(make_uniq(v)); } - auto bind_info_clone = func.bind_info ? func.bind_info->Copy() : nullptr; - BoundFunctionExpression clone(func.function, std::move(children), std::move(bind_info_clone), func.is_operator); + auto bind_info_clone = func.BindInfo() ? func.BindInfo()->Copy() : nullptr; + BoundFunctionExpression clone(func.Function(), std::move(children), std::move(bind_info_clone), func.IsOperator()); return ExpressionExecutor::TryEvaluateScalar(context, clone, result); } @@ -25,10 +25,10 @@ static bool TryEvaluateAtConstants(ClientContext &context, const BoundFunctionEx //! Decreasing args are swapped so f(lo_args) and f(hi_args) bracket the output range. static unique_ptr TryPropagateMonotoneBounds(ClientContext &context, BoundFunctionExpression &func, const vector &child_stats) { - if (!func.function.HasArgProperties() || func.children.empty()) { + if (!func.Function().HasArgProperties() || func.GetChildren().empty()) { return nullptr; } - if (func.function.GetStability() != FunctionStability::CONSISTENT) { + if (func.Function().GetStability() != FunctionStability::CONSISTENT) { return nullptr; } if (BaseStatistics::GetStatsType(func.GetReturnType()) != StatisticsType::NUMERIC_STATS || @@ -36,15 +36,15 @@ static unique_ptr TryPropagateMonotoneBounds(ClientContext &cont return nullptr; } - vector lo_args(func.children.size()); - vector hi_args(func.children.size()); + vector lo_args(func.GetChildren().size()); + vector hi_args(func.GetChildren().size()); // SPECIAL_HANDLING means the function may produce nulls on non-null inputs (e.g. try_cast). 
- bool output_can_have_null = (func.function.GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING); + bool output_can_have_null = (func.Function().GetNullHandling() != FunctionNullHandling::DEFAULT_NULL_HANDLING); - for (idx_t i = 0; i < func.children.size(); i++) { - auto &child = *func.children[i]; + for (idx_t i = 0; i < func.GetChildren().size(); i++) { + auto &child = *func.GetChildren()[i]; auto &cs = child_stats[i]; - auto &props = func.function.GetArgProperties(i); + auto &props = func.Function().GetArgProperties(i); if (child.IsFoldable()) { Value v; @@ -111,7 +111,7 @@ static unique_ptr TryPropagateMonotoneBounds(ClientContext &cont } if (out_hi < out_lo) { throw InternalException("Monotonic arg annotation violated for '%s': output min exceeds output max", - func.function.GetName()); + func.Function().GetName()); } auto result = NumericStats::CreateEmpty(func.GetReturnType()); @@ -134,18 +134,18 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundFuncti return PropagateComparison(func, expr_ptr); } vector stats; - stats.reserve(func.children.size()); - for (idx_t i = 0; i < func.children.size(); i++) { - auto stat = PropagateExpression(func.children[i]); + stats.reserve(func.GetChildrenMutable().size()); + for (idx_t i = 0; i < func.GetChildrenMutable().size(); i++) { + auto stat = PropagateExpression(func.GetChildrenMutable()[i]); if (!stat) { - stats.push_back(BaseStatistics::CreateUnknown(func.children[i]->GetReturnType())); + stats.push_back(BaseStatistics::CreateUnknown(func.GetChildrenMutable()[i]->GetReturnType())); } else { stats.push_back(stat->Copy()); } } - if (func.function.HasStatisticsCallback()) { - FunctionStatisticsInput input(func, func.bind_info.get(), stats, &expr_ptr); - return func.function.GetStatisticsCallback()(context, input); + if (func.Function().HasStatisticsCallback()) { + FunctionStatisticsInput input(func, func.BindInfo().get(), stats, &expr_ptr); + return 
func.Function().GetStatisticsCallback()(context, input); } return TryPropagateMonotoneBounds(context, func, stats); } diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp index 234459478..2600572ed 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_operator.cpp @@ -8,8 +8,8 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundOperat unique_ptr &expr_ptr) { bool all_have_stats = true; vector> child_stats; - child_stats.reserve(expr.children.size()); - for (auto &child : expr.children) { + child_stats.reserve(expr.GetChildrenMutable().size()); + for (auto &child : expr.GetChildrenMutable()) { auto stats = PropagateExpression(child); if (!stats) { all_have_stats = false; @@ -22,13 +22,13 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundOperat switch (expr.GetExpressionType()) { case ExpressionType::OPERATOR_COALESCE: // COALESCE, merge stats of all children - for (idx_t i = 0; i < expr.children.size(); i++) { + for (idx_t i = 0; i < expr.GetChildrenMutable().size(); i++) { D_ASSERT(child_stats[i]); if (!child_stats[i]->CanHaveNoNull()) { // this child is always NULL, we can remove it from the coalesce // UNLESS there is only one node remaining - if (expr.children.size() > 1) { - expr.children.erase_at(i); + if (expr.GetChildrenMutable().size() > 1) { + expr.GetChildrenMutable().erase_at(i); child_stats.erase_at(i); i--; } @@ -36,22 +36,23 @@ unique_ptr StatisticsPropagator::PropagateExpression(BoundOperat // coalesce child cannot have NULL entries // this is the last coalesce node that influences the result // we can erase any children after this node - if (i + 1 < expr.children.size()) { - expr.children.erase(expr.children.begin() + NumericCast(i + 1), expr.children.end()); + if (i + 1 < expr.GetChildrenMutable().size()) { + 
expr.GetChildrenMutable().erase(expr.GetChildrenMutable().begin() + NumericCast(i + 1), + expr.GetChildrenMutable().end()); child_stats.erase(child_stats.begin() + NumericCast(i + 1), child_stats.end()); } break; } } - D_ASSERT(!expr.children.empty()); - D_ASSERT(expr.children.size() == child_stats.size()); - if (expr.children.size() == 1) { + D_ASSERT(!expr.GetChildrenMutable().empty()); + D_ASSERT(expr.GetChildrenMutable().size() == child_stats.size()); + if (expr.GetChildrenMutable().size() == 1) { // coalesce of one entry: simply return that entry - expr_ptr = std::move(expr.children[0]); + expr_ptr = std::move(expr.GetChildrenMutable()[0]); } else { // coalesce of multiple entries // merge the stats - for (idx_t i = 1; i < expr.children.size(); i++) { + for (idx_t i = 1; i < expr.GetChildrenMutable().size(); i++) { child_stats[0]->Merge(*child_stats[i]); } } diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp index 0b8bfd18e..00ec0951a 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_aggregate.cpp @@ -10,7 +10,11 @@ #include "duckdb/function/function_binder.hpp" #include "duckdb/function/partition_stats.hpp" #include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/binder.hpp" #include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/planner/operator/logical_dummy_scan.hpp" #include "duckdb/planner/operator/logical_get.hpp" @@ -19,6 +23,7 @@ #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include 
"duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/statistics/string_stats.hpp" #include "duckdb/storage/storage_index.hpp" @@ -55,7 +60,7 @@ struct MaxValueComp : public ValueComparator { }; template -unique_ptr GetComparator(const string &fun_name) { +unique_ptr GetComparator(const Identifier &fun_name) { if (fun_name == "min") { return make_uniq>(); } @@ -63,7 +68,7 @@ unique_ptr GetComparator(const string &fun_name) { return make_uniq>(); } -unique_ptr GetComparator(const string &fun_name, const LogicalType &type) { +unique_ptr GetComparator(const Identifier &fun_name, const LogicalType &type) { if (type == LogicalType::VARCHAR) { return GetComparator(fun_name); } else if (type.IsNumeric() || type.IsTemporal()) { @@ -94,11 +99,29 @@ bool TryGetValueFromStats(const PartitionStatistics &stats, const StorageIndex & if (!StringStats::HasMinMax(*column_stats)) { return false; } + if (StringStats::GetMinType(*column_stats) != StringStatsType::EXACT_STATS || + StringStats::GetMaxType(*column_stats) != StringStatsType::EXACT_STATS) { + // string stats are not exact - we cannot use the value from the stats + return false; + } } result = comparator.GetVal(*column_stats); return true; } +bool GroupingSetCanIntroduceNull(const LogicalAggregate &aggr, idx_t group_idx) { + if (aggr.grouping_sets.empty()) { + return false; + } + const auto projection_idx = ProjectionIndex(group_idx); + for (const auto &grouping_set : aggr.grouping_sets) { + if (grouping_set.find(projection_idx) == grouping_set.end()) { + return true; + } + } + return false; +} + } // namespace void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_ptr &node_ptr) { @@ -118,18 +141,22 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p return; } auto &aggr_expr = aggr_ref->Cast(); - if 
(aggr_expr.filter) { + if (aggr_expr.GetFilter()) { // aggregate has a filter - bail return; } - const string &fun_name = aggr_expr.function.GetName(); + if (aggr_expr.StateExportMode() == AggregateStateExportMode::STATE_EXPORT) { + // aggregate is in state export mode - cannot replace with a constant + return; + } + auto &fun_name = aggr_expr.Function().GetName(); if (fun_name == "min" || fun_name == "max") { - if (aggr_expr.children.size() != 1 || - aggr_expr.children[0]->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + if (aggr_expr.GetChildren().size() != 1 || + aggr_expr.GetChildren()[0]->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { return; } - const auto &col_ref = aggr_expr.children[0]->Cast(); - min_max_bindings.push_back(col_ref.binding); + const auto &col_ref = aggr_expr.GetChildren()[0]->Cast(); + min_max_bindings.push_back(col_ref.Binding()); auto comparator = GetComparator(fun_name, col_ref.GetReturnType()); if (!comparator) { // Type has no min max statistics @@ -153,7 +180,7 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p if (expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { return; } - binding = expr.Cast().binding; + binding = expr.Cast().Binding(); } child_ref = *child_ref.get().children[0]; } @@ -193,6 +220,8 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p vector types; vector> agg_results; + bool need_to_scan = false; + vector scan_partition_indices; // we can keep execute eager aggregate if all partitions could be either filtered entirely or remained entirely if (get.table_filters.HasFilters()) { map> filter_storage_index_map; @@ -206,8 +235,9 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p } filter_storage_index_map.emplace(storage_index, filter); } - vector remaining_partition_stats; - for (auto &stats : partition_stats) { + vector precomputed_partition_stats; + for (idx_t partition_idx = 0; 
partition_idx < partition_stats.size(); partition_idx++) { + auto &stats = partition_stats[partition_idx]; if (!stats.partition_row_group) { return; } @@ -223,7 +253,9 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p if (!column_stats) { return; } - auto col_filter_result = filter.get().CheckStatistics(*column_stats); + auto &expr_filter = + ExpressionFilter::GetExpressionFilter(filter.get(), "AggregateStats::CheckPartitionFilters"); + auto col_filter_result = expr_filter.CheckStatistics(*column_stats); if (col_filter_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { // all data in this partition is filtered out, remove this partition entirely filter_result = FilterPropagateResult::FILTER_ALWAYS_FALSE; @@ -235,17 +267,21 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p } switch (filter_result) { case FilterPropagateResult::FILTER_ALWAYS_TRUE: - // all filters passed - this partition should keep execute eager aggregate - remaining_partition_stats.push_back(std::move(stats)); + precomputed_partition_stats.push_back(std::move(stats)); break; case FilterPropagateResult::FILTER_ALWAYS_FALSE: break; default: - // any filter that is not always true/false - bail - return; + need_to_scan = true; + scan_partition_indices.push_back(partition_idx); + break; } } - partition_stats = std::move(remaining_partition_stats); + if (precomputed_partition_stats.empty()) { + // no partitions can be pre-computed + return; + } + partition_stats = std::move(precomputed_partition_stats); } if (!min_max_bindings.empty()) { @@ -289,6 +325,66 @@ void StatisticsPropagator::TryExecuteAggregates(LogicalAggregate &aggr, unique_p } } + if (need_to_scan) { + // Partial precomputation: some partitions need scanning + // Insert a LogicalProjection above the aggregate that combines pre-computed constants with scan results + if (!get.function.set_partitions_to_scan) { + // scan does not support partition filtering - bail out + 
return; + } + + // Build projection expressions that merge pre-computed values with aggregate results + auto proj_index = optimizer.binder.GenerateTableIndex(); + vector> proj_expressions; + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + auto &aggr_expr = aggr.expressions[i]->Cast(); + auto &fun_name = aggr_expr.Function().GetName(); + + // Reference to the aggregate output column + auto agg_col_ref = make_uniq( + aggr_expr.GetReturnType(), ColumnBinding(aggr.aggregate_index, ProjectionIndex(i))); + + if (fun_name == "count_star") { + // pre_count + count_star_from_scan + auto &pre_count_expr = agg_results[i]; + auto add_expr = optimizer.BindScalarFunction("+", pre_count_expr->Copy(), std::move(agg_col_ref)); + add_expr->SetAlias(aggr.expressions[i]->GetAlias()); + proj_expressions.push_back(std::move(add_expr)); + } else if (fun_name == "min" || fun_name == "max") { + // For min: COALESCE(least(pre_min, agg_min), pre_min) + // For max: COALESCE(greatest(pre_max, agg_max), pre_max) + auto &pre_val_expr = agg_results[i]; + Identifier merge_func((fun_name == "min") ? 
"least" : "greatest"); + auto merged = optimizer.BindScalarFunction(merge_func, pre_val_expr->Copy(), std::move(agg_col_ref)); + auto coalesce = + make_uniq(ExpressionType::OPERATOR_COALESCE, aggr_expr.GetReturnType()); + coalesce->GetChildrenMutable().push_back(std::move(merged)); + coalesce->GetChildrenMutable().push_back(pre_val_expr->Copy()); + coalesce->SetAlias(aggr.expressions[i]->GetAlias()); + proj_expressions.push_back(std::move(coalesce)); + } + } + + // Tell the scan to only scan partitions whose aggregates were NOT pre-computed + get.SetPartitionsToScan(std::move(scan_partition_indices)); + + // Create LogicalProjection above the aggregate + auto projection = make_uniq(proj_index, std::move(proj_expressions)); + projection->children.push_back(std::move(node_ptr)); + + ColumnBindingReplacer replacer; + for (idx_t i = 0; i < aggr.expressions.size(); i++) { + auto old_binding = ColumnBinding(aggr.aggregate_index, ProjectionIndex(i)); + auto new_binding = ColumnBinding(proj_index, ProjectionIndex(i)); + replacer.replacement_bindings.emplace_back(old_binding, new_binding); + } + + replacer.stop_operator = projection.get(); + node_ptr = std::move(projection); + replacer.VisitOperator(*root); + return; + } + // Set column names for (idx_t expr_idx = 0; expr_idx < agg_results.size(); expr_idx++) { agg_results[expr_idx]->SetAlias(aggr.expressions[expr_idx]->GetAlias()); @@ -311,16 +407,13 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalAggr aggr.group_stats.resize(aggr.groups.size()); for (idx_t group_idx = 0; group_idx < aggr.groups.size(); group_idx++) { auto stats = PropagateExpression(aggr.groups[group_idx]); + if (stats && GroupingSetCanIntroduceNull(aggr, group_idx)) { + stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } aggr.group_stats[group_idx] = stats ? 
stats->ToUnique() : nullptr; if (!stats) { continue; } - if (aggr.grouping_sets.size() > 1) { - // aggregates with multiple grouping sets can introduce NULL values to certain groups - // FIXME: actually figure out WHICH groups can have null values introduced - stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); - continue; - } ColumnBinding group_binding(aggr.group_index, ProjectionIndex(group_idx)); statistics_map[group_binding] = std::move(stats); } @@ -341,7 +434,7 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalAggr vector> bindings; if (count_matcher->Match(*expr, bindings)) { auto &aggr_expr = expr->Cast(); - const auto child_stats = PropagateExpression(aggr_expr.children[0]); + const auto child_stats = PropagateExpression(aggr_expr.GetChildrenMutable()[0]); if (child_stats && !child_stats->CanHaveNull()) { expr = function_binder.BindAggregateFunction(count_fun, {}, nullptr, AggregateType::NON_DISTINCT); } @@ -367,14 +460,14 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalAggr break; } auto &aggr_expr = aggr_ref->Cast(); - for (const auto &child : aggr_expr.children) { + for (const auto &child : aggr_expr.GetChildren()) { if (child->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { // Bail if bound aggregate child is not a colref distinct_validity = TupleDataValidityType::CAN_HAVE_NULL_VALUES; break; } const auto &col_ref = child->Cast(); - auto it = statistics_map.find(col_ref.binding); + auto it = statistics_map.find(col_ref.Binding()); if (it == statistics_map.end() || !it->second || it->second->CanHaveNull()) { // Bail if no stats or if there can be a NULL distinct_validity = TupleDataValidityType::CAN_HAVE_NULL_VALUES; diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp index f61f665ed..6854237ec 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp +++ 
b/src/duckdb/src/optimizer/statistics/operator/propagate_filter.cpp @@ -18,7 +18,7 @@ static bool IsCompareDistinct(ExpressionType type) { bool StatisticsPropagator::ExpressionIsConstant(Expression &expr, const Value &val) { Value expr_value; if (expr.GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { - expr_value = expr.Cast().value; + expr_value = expr.Cast().GetValue(); } else if (expr.IsFoldable()) { if (!ExpressionExecutor::TryEvaluateScalar(context, expr, expr_value)) { return false; @@ -163,10 +163,10 @@ void StatisticsPropagator::UpdateFilterStatistics(const Expression &left, const // any column ref involved in a comparison will not be null after the comparison bool compare_distinct = IsCompareDistinct(comparison_type); if (!compare_distinct && left.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - SetStatisticsNotNull((left.Cast()).binding); + SetStatisticsNotNull((left.Cast()).Binding()); } if (!compare_distinct && right.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - SetStatisticsNotNull((right.Cast()).binding); + SetStatisticsNotNull((right.Cast()).Binding()); } // check if this is a comparison between a constant and a column ref optional_ptr constant; @@ -185,8 +185,8 @@ void StatisticsPropagator::UpdateFilterStatistics(const Expression &left, const // comparison between two column refs auto &left_column_ref = left.Cast(); auto &right_column_ref = right.Cast(); - auto lentry = statistics_map.find(left_column_ref.binding); - auto rentry = statistics_map.find(right_column_ref.binding); + auto lentry = statistics_map.find(left_column_ref.Binding()); + auto rentry = statistics_map.find(right_column_ref.Binding()); if (lentry == statistics_map.end() || rentry == statistics_map.end()) { return; } @@ -197,11 +197,11 @@ void StatisticsPropagator::UpdateFilterStatistics(const Expression &left, const } if (constant && columnref) { // comparison between columnref - auto entry = statistics_map.find(columnref->binding); + auto entry 
= statistics_map.find(columnref->Binding()); if (entry == statistics_map.end()) { return; } - UpdateFilterStatistics(*entry->second, comparison_type, constant->value); + UpdateFilterStatistics(*entry->second, comparison_type, constant->GetValue()); } } diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp index 6a406b828..327177741 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_get.cpp @@ -2,8 +2,6 @@ #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" @@ -21,10 +19,10 @@ static bool IsDirectFilterColumnRef(const Expression &expr) { expr.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF; } -static void GetColumnIndex(const unique_ptr &expr, idx_t &index, string &alias) { +static void GetColumnIndex(const unique_ptr &expr, idx_t &index, Identifier &alias) { if (expr->GetExpressionType() == ExpressionType::BOUND_REF) { auto &bound_ref = expr->Cast(); - index = bound_ref.index; + index = bound_ref.Index(); alias = bound_ref.GetAlias(); return; } @@ -34,15 +32,12 @@ static void GetColumnIndex(const unique_ptr &expr, idx_t &index, str FilterPropagateResult StatisticsPropagator::PropagateTableFilter(ColumnBinding stats_binding, BaseStatistics &stats, TableFilter &filter) { - if (filter.filter_type != TableFilterType::EXPRESSION_FILTER) { - return filter.CheckStatistics(stats); - } auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, 
"StatisticsPropagator::PropagateTableFilter"); // get physical storage index of the filter // since it is a table filter, every storage index is the same idx_t physical_index = DConstants::INVALID_INDEX; - string column_alias; + Identifier column_alias; GetColumnIndex(expr_filter.expr, physical_index, column_alias); D_ASSERT(physical_index != DConstants::INVALID_INDEX); @@ -71,19 +66,6 @@ FilterPropagateResult StatisticsPropagator::PropagateTableFilter(ColumnBinding s } void StatisticsPropagator::UpdateFilterStatistics(BaseStatistics &input, const TableFilter &filter) { - if (filter.filter_type != TableFilterType::EXPRESSION_FILTER) { - switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - for (auto &child_filter : conjunction_and.child_filters) { - UpdateFilterStatistics(input, *child_filter); - } - return; - } - default: - return; - } - } auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "StatisticsPropagator::UpdateFilterStatistics"); UpdateExpressionFilterStatistics(input, *expr_filter.expr); } @@ -102,15 +84,15 @@ void StatisticsPropagator::UpdateExpressionFilterStatistics(BaseStatistics &inpu auto &right = BoundComparisonExpression::Right(comp); if (IsDirectFilterColumnRef(left) && right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { auto &constant = right.Cast(); - if (constant.value.type().InternalType() == input.GetType().InternalType()) { - UpdateFilterStatistics(input, comp.GetExpressionType(), constant.value); + if (constant.GetValue().type().InternalType() == input.GetType().InternalType()) { + UpdateFilterStatistics(input, comp.GetExpressionType(), constant.GetValue()); } else if (!is_compare_distinct) { input.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); } } else if (left.GetExpressionType() == ExpressionType::VALUE_CONSTANT && IsDirectFilterColumnRef(right)) { auto &constant = left.Cast(); - if (constant.value.type().InternalType() == input.GetType().InternalType()) { 
- UpdateFilterStatistics(input, FlipComparisonExpression(comp.GetExpressionType()), constant.value); + if (constant.GetValue().type().InternalType() == input.GetType().InternalType()) { + UpdateFilterStatistics(input, FlipComparisonExpression(comp.GetExpressionType()), constant.GetValue()); } else if (!is_compare_distinct) { input.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); } @@ -120,7 +102,7 @@ void StatisticsPropagator::UpdateExpressionFilterStatistics(BaseStatistics &inpu case ExpressionClass::BOUND_CONJUNCTION: { auto &conj = expr.Cast(); if (conj.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { - for (auto &child : conj.children) { + for (auto &child : conj.GetChildren()) { UpdateExpressionFilterStatistics(input, *child); } } @@ -128,8 +110,8 @@ void StatisticsPropagator::UpdateExpressionFilterStatistics(BaseStatistics &inpu } case ExpressionClass::BOUND_OPERATOR: { auto &op = expr.Cast(); - if (expr.GetExpressionType() == ExpressionType::OPERATOR_IS_NOT_NULL && !op.children.empty() && - IsDirectFilterColumnRef(*op.children[0])) { + if (expr.GetExpressionType() == ExpressionType::OPERATOR_IS_NOT_NULL && !op.GetChildren().empty() && + IsDirectFilterColumnRef(*op.GetChildren()[0])) { input.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); } break; @@ -140,9 +122,6 @@ void StatisticsPropagator::UpdateExpressionFilterStatistics(BaseStatistics &inpu } static bool IsConstantOrNullFilter(const TableFilter &table_filter) { - if (table_filter.filter_type != TableFilterType::EXPRESSION_FILTER) { - return false; - } auto &expr_filter = ExpressionFilter::GetExpressionFilter(table_filter, "IsConstantOrNullFilter"); if (expr_filter.expr->GetExpressionType() != ExpressionType::BOUND_FUNCTION) { return false; @@ -155,11 +134,10 @@ static bool CanReplaceConstantOrNull(const TableFilter &table_filter) { if (!IsConstantOrNullFilter(table_filter)) { throw InternalException("CanReplaceConstantOrNull() called on unexpected Table Filter"); } - D_ASSERT(table_filter.filter_type == 
TableFilterType::EXPRESSION_FILTER); auto &expr_filter = ExpressionFilter::GetExpressionFilter(table_filter, "CanReplaceConstantOrNull"); auto &func = expr_filter.expr->Cast(); if (ConstantOrNull::IsConstantOrNull(func, Value::BOOLEAN(true))) { - for (auto child = ++func.children.begin(); child != func.children.end(); child++) { + for (auto child = ++func.GetChildren().begin(); child != func.GetChildren().end(); child++) { switch (child->get()->GetExpressionType()) { case ExpressionType::BOUND_REF: case ExpressionType::VALUE_CONSTANT: diff --git a/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp b/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp index 99b2ea947..4114ddd5b 100644 --- a/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp +++ b/src/duckdb/src/optimizer/statistics/operator/propagate_window.cpp @@ -12,26 +12,26 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalWind // then propagate to each of the order expressions for (auto &window_expr : window.expressions) { auto &over_expr = window_expr->Cast(); - for (auto &expr : over_expr.partitions) { - over_expr.partitions_stats.push_back(PropagateExpression(expr)); + for (auto &expr : over_expr.PartitionsMutable()) { + over_expr.PartitionsStatsMutable().push_back(PropagateExpression(expr)); } - for (auto &bound_order : over_expr.orders) { + for (auto &bound_order : over_expr.OrderByMutable()) { bound_order.stats = PropagateExpression(bound_order.expression); } - if (over_expr.start_expr) { - over_expr.expr_stats.push_back(PropagateExpression(over_expr.start_expr)); + if (over_expr.StartExpr()) { + over_expr.ExprStatsMutable().push_back(PropagateExpression(over_expr.StartExprMutable())); } else { - over_expr.expr_stats.push_back(nullptr); + over_expr.ExprStatsMutable().push_back(nullptr); } - if (over_expr.end_expr) { - over_expr.expr_stats.push_back(PropagateExpression(over_expr.end_expr)); + if (over_expr.EndExpr()) { + 
over_expr.ExprStatsMutable().push_back(PropagateExpression(over_expr.EndExprMutable())); } else { - over_expr.expr_stats.push_back(nullptr); + over_expr.ExprStatsMutable().push_back(nullptr); } - for (auto &bound_order : over_expr.arg_orders) { + for (auto &bound_order : over_expr.ArgOrdersMutable()) { bound_order.stats = PropagateExpression(bound_order.expression); } } diff --git a/src/duckdb/src/optimizer/topn_optimizer.cpp b/src/duckdb/src/optimizer/topn_optimizer.cpp index 4529c2e84..22815eef9 100644 --- a/src/duckdb/src/optimizer/topn_optimizer.cpp +++ b/src/duckdb/src/optimizer/topn_optimizer.cpp @@ -5,10 +5,9 @@ #include "duckdb/planner/operator/logical_limit.hpp" #include "duckdb/planner/operator/logical_order.hpp" #include "duckdb/planner/operator/logical_top_n.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/dynamic_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/execution/operator/join/join_filter_pushdown.hpp" #include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" @@ -83,7 +82,7 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { auto &colref = op.orders[0].expression->Cast(); vector columns; JoinFilterPushdownColumn column; - column.probe_column_index = colref.binding; + column.probe_column_index = colref.Binding(); columns.emplace_back(column); vector pushdown_targets; JoinFilterPushdownOptimizer::GetPushdownFilterTargets(*op.children[0], std::move(columns), pushdown_targets); @@ -116,20 +115,20 @@ void TopN::PushdownDynamicFilters(LogicalTopN &op) { auto col_binding = target.columns[0].probe_column_index; // create the actual dynamic filter - auto dynamic_filter = make_uniq(filter_data); - unique_ptr 
pushed_filter = std::move(dynamic_filter); + auto pushed_expr = CreateDynamicFilterExpression(filter_data, type); if (nulls_first) { - auto or_filter = make_uniq(); + auto or_filter = make_uniq(ExpressionType::CONJUNCTION_OR); auto is_null = ExpressionFilter::CreateNullCheckExpression( make_uniq(type, idx_t(0)), ExpressionType::OPERATOR_IS_NULL); - or_filter->child_filters.push_back(make_uniq(std::move(is_null))); - or_filter->child_filters.push_back(std::move(pushed_filter)); - pushed_filter = std::move(or_filter); + or_filter->GetChildrenMutable().push_back(std::move(is_null)); + or_filter->GetChildrenMutable().push_back(std::move(pushed_expr)); + pushed_expr = std::move(or_filter); } - auto optional_filter = make_uniq(std::move(pushed_filter)); // push the filter into the table scan - get.table_filters.PushFilter(col_binding.column_index, std::move(optional_filter)); + get.table_filters.PushFilter( + col_binding.column_index, + make_uniq(CreateOptionalFilterExpression(std::move(pushed_expr), type))); } } diff --git a/src/duckdb/src/optimizer/topn_window_elimination.cpp b/src/duckdb/src/optimizer/topn_window_elimination.cpp index b7fca4748..54005d119 100644 --- a/src/duckdb/src/optimizer/topn_window_elimination.cpp +++ b/src/duckdb/src/optimizer/topn_window_elimination.cpp @@ -159,13 +159,25 @@ string GetLHSRowIdColumnName(const unique_ptr &op, idx_t column D_ASSERT(op.get()->expressions.size() > column_id && op.get()->expressions[column_id]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF); const auto &colref = op.get()->expressions[column_id]->Cast(); - column_id = colref.binding.column_index; + column_id = colref.Binding().column_index; current_op = *op.get()->children[0]; } const auto &logical_get = current_op.get().Cast(); const auto column_index = logical_get.GetColumnIds()[column_id]; - return logical_get.GetColumnName(column_index); + return logical_get.GetColumnName(column_index).GetIdentifierName(); +} + +//! 
Late materialization recreates a payload by re-scanning the base table column and re-applying a type cast at the +//! topmost projection. That only reproduces the original value when the projection expression is a plain column +//! reference, optionally wrapped in casts. A value-transforming expression (e.g. parse_filename(col)) cannot be +//! rebuilt that way, so it must not be eligible for late materialization. +bool IsRecreatableByLateMaterialization(const Expression &expr) { + reference current = expr; + while (current.get().GetExpressionClass() == ExpressionClass::BOUND_CAST) { + current = current.get().Cast().Child(); + } + return current.get().GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF; } } // namespace @@ -291,7 +303,8 @@ TopNWindowElimination::CreateAggregateExpression(vector> fun_name += params.order_type == OrderType::ASCENDING ? "min" : "max"; fun_name += params.can_be_null && (requires_arg || change_to_arg) ? "_nulls_last" : ""; - auto &fun_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, fun_name); + auto &fun_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), Identifier(fun_name)); const auto &fun = fun_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(aggregate_params)); return function_binder.BindAggregateFunction(fun, std::move(aggregate_params)); } @@ -300,7 +313,7 @@ unique_ptr TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector> args, const TopNWindowEliminationParameters ¶ms) const { auto &window_expr = window.expressions[0]->Cast(); - D_ASSERT(window_expr.orders.size() == 1); + D_ASSERT(window_expr.OrderBy().size() == 1); vector> aggregate_params; aggregate_params.reserve(3); @@ -312,14 +325,15 @@ TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector(context, DEFAULT_SCHEMA, "struct_pack"); + auto &struct_pack_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), "struct_pack"); const auto &struct_pack_fun = 
struct_pack_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(args)); auto struct_pack_expr = function_binder.BindScalarFunction(struct_pack_fun, std::move(args)); aggregate_params.push_back(std::move(struct_pack_expr)); } - aggregate_params.push_back(std::move(window_expr.orders[0].expression)); + aggregate_params.push_back(std::move(window_expr.OrderByMutable()[0].expression)); if (params.limit > 1) { aggregate_params.push_back(std::move(make_uniq(Value::BIGINT(params.limit)))); } @@ -332,7 +346,7 @@ TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector(optimizer.binder.GenerateTableIndex(), optimizer.binder.GenerateTableIndex(), std::move(select_list)); aggregate->groupings_index = optimizer.binder.GenerateTableIndex(); - aggregate->groups = std::move(window_expr.partitions); + aggregate->groups = std::move(window_expr.PartitionsMutable()); aggregate->children.push_back(std::move(window.children[0])); aggregate->ResolveOperatorTypes(); @@ -342,7 +356,7 @@ TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vectorgroups[i]; if (group->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &column_ref = group->Cast(); - auto group_stats = stats->find(column_ref.binding); + auto group_stats = stats->find(column_ref.Binding()); if (group_stats == stats->end()) { continue; } @@ -360,7 +374,8 @@ TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate auto &catalog = Catalog::GetSystemCatalog(context); // array_length - auto &array_length_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "array_length"); + auto &array_length_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), "array_length"); vector> array_length_exprs; array_length_exprs.push_back(std::move(aggregate_column_ref)); array_length_exprs.push_back(make_uniq(1)); @@ -371,7 +386,7 @@ TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate // generate_series auto &generate_series_entry = - 
catalog.GetEntry(context, DEFAULT_SCHEMA, "generate_series"); + catalog.GetEntry(context, Identifier::DefaultSchema(), "generate_series"); vector> generate_series_exprs; generate_series_exprs.push_back(make_uniq(1)); @@ -385,7 +400,7 @@ TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate // unnest auto unnest_row_number_expr = make_uniq(LogicalType::BIGINT); unnest_row_number_expr->SetAlias("row_number"); - unnest_row_number_expr->child = std::move(bound_generate_series_fun); + unnest_row_number_expr->ChildMutable() = std::move(bound_generate_series_fun); return unique_ptr(std::move(unnest_row_number_expr)); } @@ -412,7 +427,7 @@ TopNWindowElimination::TryCreateUnnestOperator(unique_ptr op, vector> unnest_exprs; auto unnest_aggregate = make_uniq(ListType::GetChildType(aggregate_type)); - unnest_aggregate->child = aggregate_column_ref->Copy(); + unnest_aggregate->ChildMutable() = aggregate_column_ref->Copy(); unnest_exprs.push_back(std::move(unnest_aggregate)); if (params.include_row_number) { @@ -434,7 +449,7 @@ void TopNWindowElimination::AddStructExtractExprs( FunctionBinder function_binder(context); auto &catalog = Catalog::GetSystemCatalog(context); auto &struct_extract_entry = - catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_extract"); + catalog.GetEntry(context, Identifier::DefaultSchema(), "struct_extract"); const auto &struct_extract_fun = struct_extract_entry.functions.GetFunctionByArguments(context, {struct_type, LogicalType::VARCHAR}); @@ -528,14 +543,14 @@ bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { return false; } auto &filter_value = right.Cast(); - if (filter_value.value.type() != LogicalType::BIGINT) { + if (filter_value.GetValue().type() != LogicalType::BIGINT) { return false; } - if (filter_value.value.IsNull()) { + if (filter_value.GetValue().IsNull()) { return false; } - const auto bigint_value = filter_value.value.GetValue(); + const auto bigint_value = filter_value.GetValue().GetValue(); switch 
(comparison) { case ExpressionType::COMPARE_LESSTHANOREQUALTO: if (bigint_value < 1) { @@ -594,7 +609,7 @@ bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { return false; } const auto &first_window_expr = window.expressions[0]->Cast(); - for (auto &partition : first_window_expr.partitions) { + for (auto &partition : first_window_expr.Partitions()) { if (partition->GetExpressionClass() != ExpressionClass::BOUND_COLUMN_REF) { return false; } @@ -611,13 +626,14 @@ bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { } auto &window_expr = window.expressions[0]->Cast(); - if (window_expr.orders.size() != 1) { + if (window_expr.OrderBy().size() != 1) { return false; } - if (window_expr.orders[0].type != OrderType::DESCENDING && window_expr.orders[0].type != OrderType::ASCENDING) { + if (window_expr.OrderBy()[0].type != OrderType::DESCENDING && + window_expr.OrderBy()[0].type != OrderType::ASCENDING) { return false; } - if (window_expr.orders[0].null_order != OrderByNullType::NULLS_LAST) { + if (window_expr.OrderBy()[0].null_order != OrderByNullType::NULLS_LAST) { return false; } @@ -638,8 +654,8 @@ vector> TopNWindowElimination::GenerateAggregatePayload(c // Remember order of group columns to recreate that order in new bindings later column_binding_map_t group_bindings; - for (idx_t i = 0; i < window_expr.partitions.size(); i++) { - auto &expr = window_expr.partitions[i]; + for (idx_t i = 0; i < window_expr.PartitionsMutable().size(); i++) { + auto &expr = window_expr.PartitionsMutable()[i]; VisitExpression(&expr); group_bindings[column_references.begin()->first] = i; column_references.clear(); @@ -659,21 +675,21 @@ vector> TopNWindowElimination::GenerateAggregatePayload(c auto column_id = binding.ToString(); if (window.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION) { // The column index points to the correct column binding - aggregate_args.push_back( - make_uniq(column_id, window_child_types[binding.column_index], binding)); + 
aggregate_args.push_back(make_uniq( + Identifier(column_id), window_child_types[binding.column_index], binding)); } else { // The child operator could have multiple or no table indexes. Therefore, we must find the right type first const auto child_column_idx = static_cast(std::find(window_child_bindings.begin(), window_child_bindings.end(), binding) - window_child_bindings.begin()); - aggregate_args.push_back( - make_uniq(column_id, window_child_types[child_column_idx], binding)); + aggregate_args.push_back(make_uniq( + Identifier(column_id), window_child_types[child_column_idx], binding)); } } if (aggregate_args.size() == 1) { // If we only project the aggregate value itself, we do not need it as an arg - VisitExpression(&window_expr.orders[0].expression); + VisitExpression(&window_expr.OrderByMutable()[0].expression); if (column_references.size() != 1) { column_references.clear(); return aggregate_args; @@ -681,8 +697,8 @@ vector> TopNWindowElimination::GenerateAggregatePayload(c const auto aggregate_value_binding = column_references.begin()->first; column_references.clear(); - if (window_expr.orders[0].expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF && - aggregate_args[0]->Cast().binding == aggregate_value_binding) { + if (window_expr.OrderBy()[0].expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF && + aggregate_args[0]->Cast().Binding() == aggregate_value_binding) { return {}; } } @@ -843,16 +859,16 @@ TopNWindowElimination::ExtractOptimizerParameters(const LogicalWindow &window, c auto &filter_expr = filter.expressions[0]->Cast(); auto &limit_expr = BoundComparisonExpression::Right(filter_expr); - params.limit = limit_expr.Cast().value.GetValue(); + params.limit = limit_expr.Cast().GetValue().GetValue(); if (filter_expr.GetExpressionType() == ExpressionType::COMPARE_LESSTHAN) { --params.limit; } params.include_row_number = BindingsReferenceRowNumber(bindings, window); params.payload_type = aggregate_payload.size() > 1 ? 
TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; auto &window_expr = window.expressions[0]->Cast(); - params.order_type = window_expr.orders[0].type; + params.order_type = window_expr.OrderBy()[0].type; - VisitExpression(&window_expr.orders[0].expression); + VisitExpression(&window_expr.OrderByMutable()[0].expression); if (params.payload_type == TopNPayloadType::SINGLE_COLUMN && !aggregate_payload.empty()) { VisitExpression(&aggregate_payload[0]); } @@ -887,18 +903,18 @@ bool TopNWindowElimination::CanUseLateMaterialization(const LogicalWindow &windo vector &lhs_projections, vector> &stack) { auto &window_expr = window.expressions[0]->Cast(); - vector projections(window_expr.partitions.size() + args.size()); + vector projections(window_expr.Partitions().size() + args.size()); // Build a projection list for an LHS table scan to recreate the column order of an aggregate with struct packing - for (idx_t i = 0; i < window_expr.partitions.size(); i++) { - auto &partition = window_expr.partitions[i]; + for (idx_t i = 0; i < window_expr.PartitionsMutable().size(); i++) { + auto &partition = window_expr.PartitionsMutable()[i]; if (!ExtractSingleBinding(&partition, projections[i])) { return false; } } for (idx_t i = 0; i < args.size(); i++) { auto &arg = args[i]; - if (!ExtractSingleBinding(&arg, projections[window_expr.partitions.size() + i])) { + if (!ExtractSingleBinding(&arg, projections[window_expr.Partitions().size() + i])) { return false; } } @@ -919,6 +935,11 @@ bool TopNWindowElimination::CanUseLateMaterialization(const LogicalWindow &windo if (projection_idx >= projection.expressions.size()) { return false; } + if (!IsRecreatableByLateMaterialization(*projection.expressions[projection_idx])) { + // The projected value is derived through a value-transforming expression that the rowid semi-join + // cannot reconstruct. Fall back to carrying the computed value through the aggregate instead. 
+ return false; + } if (!ExtractSingleBinding(&projection.expressions[projection_idx], projections[i])) { return false; } @@ -1206,9 +1227,10 @@ unique_ptr TopNWindowElimination::ConstructJoin(unique_ptr( - alias, lhs->types[lhs_rowid_idx], ColumnBinding {lhs->GetTableIndex()[0], ProjectionIndex(lhs_rowid_idx)}); + Identifier(alias), lhs->types[lhs_rowid_idx], + ColumnBinding {lhs->GetTableIndex()[0], ProjectionIndex(lhs_rowid_idx)}); auto rhs_expr = - make_uniq(alias, rhs->types[aggregate_offset + i], + make_uniq(Identifier(alias), rhs->types[aggregate_offset + i], ColumnBinding {GetAggregateIdx(rhs), ProjectionIndex(rhs_rowid_idx)}); join->conditions.push_back( JoinCondition(std::move(lhs_expr), std::move(rhs_expr), ExpressionType::COMPARE_EQUAL)); diff --git a/src/duckdb/src/optimizer/unnest_rewriter.cpp b/src/duckdb/src/optimizer/unnest_rewriter.cpp index 8198da073..22451dd21 100644 --- a/src/duckdb/src/optimizer/unnest_rewriter.cpp +++ b/src/duckdb/src/optimizer/unnest_rewriter.cpp @@ -28,8 +28,8 @@ void UnnestRewriterPlanUpdater::VisitExpression(unique_ptr *expressi if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { auto &bound_column_ref = expr->Cast(); for (idx_t i = 0; i < replace_bindings.size(); i++) { - if (bound_column_ref.binding == replace_bindings[i].old_binding) { - bound_column_ref.binding = replace_bindings[i].new_binding; + if (bound_column_ref.Binding() == replace_bindings[i].old_binding) { + bound_column_ref.BindingMutable() = replace_bindings[i].new_binding; break; } } @@ -162,12 +162,12 @@ void UnnestRewriter::FindCandidates(unique_ptr &root, unique_pt } auto &bind_col = proj.expressions[col_bind.column_index]->Cast(); auto unnest_expr = make_uniq(unnest_get->types[i]); - unnest_expr->child = proj.expressions[col_bind.column_index]->Copy(); - bind_col.binding = ColumnBinding(unnest_get_index, bind_col.binding.column_index); + unnest_expr->ChildMutable() = proj.expressions[col_bind.column_index]->Copy(); + 
bind_col.BindingMutable() = ColumnBinding(unnest_get_index, bind_col.Binding().column_index); auto unnest_proj_idx = ColumnBinding::PushExpression(unnest->expressions, std::move(unnest_expr)); - ColumnBinding new_column_ref(bind_col.binding.table_index, unnest_proj_idx); + ColumnBinding new_column_ref(bind_col.Binding().table_index, unnest_proj_idx); auto unnest_ref = make_uniq(bind_col.GetAlias(), unnest_get->types[i], - new_column_ref, bind_col.depth); + new_column_ref, bind_col.Depth()); proj.expressions[col_bind.column_index] = std::move(unnest_ref); proj.types[col_bind.column_index] = unnest_get->types[i]; replacer.replacement_bindings.push_back(ReplacementBinding( @@ -376,7 +376,7 @@ void UnnestRewriter::GetDelimColumns(LogicalOperator &op) { auto &expr = *delim_join.duplicate_eliminated_columns[i]; D_ASSERT(expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF); auto &bound_colref_expr = expr.Cast(); - delim_columns.push_back(bound_colref_expr.binding); + delim_columns.push_back(bound_colref_expr.Binding()); } } diff --git a/src/duckdb/src/optimizer/window_self_join.cpp b/src/duckdb/src/optimizer/window_self_join.cpp index 1cc315722..087ec1313 100644 --- a/src/duckdb/src/optimizer/window_self_join.cpp +++ b/src/duckdb/src/optimizer/window_self_join.cpp @@ -8,49 +8,47 @@ #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/function/aggregate_state.hpp" -#include "duckdb/planner/logical_operator_visitor.hpp" -#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/logical_operator_deep_copy.hpp" namespace duckdb { static unique_ptr TranslateAggregate(const BoundWindowExpression &w_expr) { - auto agg_func = *w_expr.aggregate; + auto agg_func = *w_expr.AggregateFunction(); unique_ptr bind_info; - if (w_expr.bind_info) { - bind_info = w_expr.bind_info->Copy(); + if (w_expr.BindInfo()) { + bind_info = w_expr.BindInfo()->Copy(); } else { 
bind_info = nullptr; } vector> children; - for (auto &child : w_expr.children) { + for (auto &child : w_expr.GetChildren()) { auto child_copy = child->Copy(); children.push_back(std::move(child_copy)); } unique_ptr filter; - if (w_expr.filter_expr) { - filter = w_expr.filter_expr->Copy(); + if (w_expr.Filter()) { + filter = w_expr.Filter()->Copy(); } - auto aggr_type = w_expr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT; + auto aggr_type = w_expr.Distinct() ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT; auto result = make_uniq(std::move(agg_func), std::move(children), std::move(filter), std::move(bind_info), aggr_type); - if (!w_expr.arg_orders.empty()) { - result->order_bys = make_uniq(); - auto &orders = result->order_bys->orders; - for (auto &order : w_expr.arg_orders) { + if (!w_expr.ArgOrders().empty()) { + result->GetOrderBysMutable() = make_uniq(); + auto &orders = result->GetOrderBysMutable()->orders; + for (auto &order : w_expr.ArgOrders()) { auto order_copy = order.Copy(); orders.emplace_back(std::move(order_copy)); } - } else if (!w_expr.orders.empty()) { + } else if (!w_expr.OrderBy().empty()) { // If the frame was ordered, copy the frame ordering to the aggregate function - result->order_bys = make_uniq(); - auto &orders = result->order_bys->orders; - for (auto &order : w_expr.orders) { + result->GetOrderBysMutable() = make_uniq(); + auto &orders = result->GetOrderBysMutable()->orders; + for (auto &order : w_expr.OrderBy()) { auto order_copy = order.Copy(); orders.emplace_back(std::move(order_copy)); } @@ -79,12 +77,12 @@ bool WindowSelfJoinOptimizer::CanOptimize(const BoundWindowExpression &w_expr, // We can only accept ORDER BY clauses if the frame is the entire partition // In that case, we will have to move the ordering clauses into the aggregate. // ROWS framing is excluded because the frame depends on physical row position even without ORDER BY. 
- switch (w_expr.start) { + switch (w_expr.WindowStart()) { case WindowBoundary::UNBOUNDED_PRECEDING: break; case WindowBoundary::CURRENT_ROW_RANGE: case WindowBoundary::CURRENT_ROW_GROUPS: - if (!w_expr.orders.empty()) { + if (!w_expr.OrderBy().empty()) { return false; } break; @@ -92,22 +90,22 @@ bool WindowSelfJoinOptimizer::CanOptimize(const BoundWindowExpression &w_expr, return false; } - switch (w_expr.end) { + switch (w_expr.WindowEnd()) { case WindowBoundary::UNBOUNDED_FOLLOWING: break; case WindowBoundary::CURRENT_ROW_RANGE: case WindowBoundary::CURRENT_ROW_GROUPS: - if (!w_expr.orders.empty()) { + if (!w_expr.OrderBy().empty()) { return false; } break; default: return false; } - if (w_expr.partitions.empty()) { + if (w_expr.Partitions().empty()) { return false; } - if (w_expr.exclude_clause != WindowExcludeMode::NO_OTHER) { + if (w_expr.WindowExclude() != WindowExcludeMode::NO_OTHER) { return false; } if (!w_expr.PartitionsAreEquivalent(w_expr0)) { @@ -117,6 +115,24 @@ bool WindowSelfJoinOptimizer::CanOptimize(const BoundWindowExpression &w_expr, return true; } +bool WindowSelfJoinOptimizer::CanOptimize(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + case LogicalOperatorType::LOGICAL_FILTER: + if (!op.children.empty()) { + return CanOptimize(*op.children[0]); + } + return true; + default: + break; + } + return false; +} + unique_ptr WindowSelfJoinOptimizer::OptimizeInternal(unique_ptr op, ColumnBindingReplacer &replacer) { if (op->type == LogicalOperatorType::LOGICAL_WINDOW) { @@ -125,6 +141,10 @@ unique_ptr WindowSelfJoinOptimizer::OptimizeInternal(unique_ptr // Check recursively window.children[0] = OptimizeInternal(std::move(window.children[0]), replacer); + if (!CanOptimize(*window.children[0])) { + 
return op; + } + auto &w_expr0 = window.expressions[0]->Cast(); for (auto &expr : window.expressions) { auto &w_expr = expr->Cast(); @@ -132,7 +152,7 @@ unique_ptr WindowSelfJoinOptimizer::OptimizeInternal(unique_ptr return op; } } - auto &partitions = w_expr0.partitions; + auto &partitions = w_expr0.Partitions(); // --- Transformation --- // try to copy the LHS @@ -142,7 +162,7 @@ unique_ptr WindowSelfJoinOptimizer::OptimizeInternal(unique_ptr LogicalOperatorDeepCopy deep_copy(optimizer.binder, nullptr); try { copy_child = deep_copy.DeepCopy(window.children[0]); - } catch (NotImplementedException &ex) { + } catch (std::exception &ex) { // failed to copy the LHS - cannot run this optimizer return op; } diff --git a/src/duckdb/src/parallel/async_result.cpp b/src/duckdb/src/parallel/async_result.cpp index 150e791ef..a956fe75f 100644 --- a/src/duckdb/src/parallel/async_result.cpp +++ b/src/duckdb/src/parallel/async_result.cpp @@ -4,6 +4,8 @@ #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/execution/executor.hpp" #include "duckdb/execution/physical_table_scan_enum.hpp" +#include "duckdb/logging/log_type.hpp" +#include "duckdb/logging/logger.hpp" #ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE #include "duckdb/parallel/sleep_async_task.hpp" @@ -11,31 +13,43 @@ namespace duckdb { -struct Counter { - explicit Counter(idx_t size) : counter(size) { +struct AsyncBatchCompletion { + explicit AsyncBatchCompletion(idx_t size) : counter(size), callback_sent(false) { } + bool IterateAndCheckCounter() { D_ASSERT(counter.load() > 0); idx_t post_decreast = --counter; return (post_decreast == 0); } + bool MarkCallbackSent() { + bool expected = false; + return callback_sent.compare_exchange_strong(expected, true); + } + private: atomic counter; + atomic callback_sent; }; class AsyncExecutionTask : public ExecutorTask { + enum class CompletionSignal { BATCH_FINISHED, BATCH_ERRORED }; + public: AsyncExecutionTask(Executor &executor, unique_ptr &&async_task, InterruptState 
&interrupt_state, - shared_ptr counter) + shared_ptr completion) : ExecutorTask(executor, nullptr), async_task(std::move(async_task)), interrupt_state(interrupt_state), - counter(std::move(counter)) { + completion(std::move(completion)) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - async_task->Execute(); - if (counter->IterateAndCheckCounter()) { - interrupt_state.Callback(); + try { + async_task->Execute(); + } catch (...) { + SignalCompletion(CompletionSignal::BATCH_ERRORED); + throw; } + SignalCompletion(CompletionSignal::BATCH_FINISHED); return TaskExecutionResult::TASK_FINISHED; } @@ -44,9 +58,22 @@ class AsyncExecutionTask : public ExecutorTask { } private: + void SignalCompletion(CompletionSignal signal) { + auto finished = completion->IterateAndCheckCounter(); + if ((signal == CompletionSignal::BATCH_ERRORED || finished)) { + SendCallback(); + } + } + + void SendCallback() { + if (completion->MarkCallbackSent()) { + interrupt_state.Callback(); + } + } + unique_ptr async_task; InterruptState interrupt_state; - shared_ptr counter; + shared_ptr completion; }; AsyncResult::AsyncResult(SourceResultType t) : AsyncResult(GetAsyncResultType(t)) { @@ -58,8 +85,8 @@ AsyncResult::AsyncResult(AsyncResultType t) : result_type(t) { } } -AsyncResult::AsyncResult(vector> &&tasks) - : result_type(AsyncResultType::BLOCKED), async_tasks(std::move(tasks)) { +AsyncResult::AsyncResult(vector> &&tasks, TaskSchedulerType pool_type_p) + : result_type(AsyncResultType::BLOCKED), async_tasks(std::move(tasks)), pool_type(pool_type_p) { if (async_tasks.empty()) { throw InternalException("AsyncResult constructed from empty vector of tasks"); } @@ -80,6 +107,7 @@ AsyncResult &AsyncResult::operator=(duckdb::AsyncResultType t) { AsyncResult &AsyncResult::operator=(AsyncResult &&other) noexcept { result_type = other.result_type; async_tasks = std::move(other.async_tasks); + pool_type = other.pool_type; return *this; } @@ -92,11 +120,13 @@ void 
AsyncResult::ScheduleTasks(InterruptState &interrupt_state, Executor &execu throw InternalException("AsyncResult::ScheduleTasks called with no available tasks"); } - shared_ptr counter = make_shared_ptr(async_tasks.size()); + DUCKDB_LOG(executor.context, AsyncTaskScheduleLogType, EnumUtil::ToString(pool_type), async_tasks.size()); + + shared_ptr completion = make_shared_ptr(async_tasks.size()); for (auto &async_task : async_tasks) { - auto task = make_uniq(executor, std::move(async_task), interrupt_state, counter); - TaskScheduler::GetScheduler(executor.context).ScheduleTask(executor.GetToken(), std::move(task)); + auto task = make_uniq(executor, std::move(async_task), interrupt_state, completion); + TaskScheduler::GetScheduler(executor.context).ScheduleTask(executor.GetToken(), std::move(task), pool_type); } } diff --git a/src/duckdb/src/parallel/executor.cpp b/src/duckdb/src/parallel/executor.cpp index cd869c865..36eb31ca4 100644 --- a/src/duckdb/src/parallel/executor.cpp +++ b/src/duckdb/src/parallel/executor.cpp @@ -452,17 +452,20 @@ void Executor::CancelTasks() { events.clear(); } -void Executor::WorkOnTasks() { +bool Executor::WorkOnTasks() { auto &scheduler = TaskScheduler::GetScheduler(context); + bool did_work = false; shared_ptr task_from_producer; while (scheduler.GetTaskFromProducer(*producer, task_from_producer)) { + did_work = true; auto res = task_from_producer->Execute(TaskExecutionMode::PROCESS_ALL); if (res == TaskExecutionResult::TASK_BLOCKED) { task_from_producer->Deschedule(); } task_from_producer.reset(); } + return did_work; } void Executor::SignalTaskRescheduled(lock_guard &) { diff --git a/src/duckdb/src/parallel/meta_pipeline.cpp b/src/duckdb/src/parallel/meta_pipeline.cpp index 2a69fbfae..c61976721 100644 --- a/src/duckdb/src/parallel/meta_pipeline.cpp +++ b/src/duckdb/src/parallel/meta_pipeline.cpp @@ -216,6 +216,10 @@ optional_ptr MetaPipeline::GetFinishGroup(Pipeline &pipeline) const { return it == finish_map.end() ? 
nullptr : &it->second; } +const vector> &MetaPipeline::GetChildren() const { + return children; +} + Pipeline &MetaPipeline::CreateUnionPipeline(Pipeline ¤t, bool order_matters) { // create the union pipeline (batch index 0, should be set correctly afterwards) auto &union_pipeline = CreatePipeline(); diff --git a/src/duckdb/src/parallel/pipeline.cpp b/src/duckdb/src/parallel/pipeline.cpp index d48cc1241..ed251c0be 100644 --- a/src/duckdb/src/parallel/pipeline.cpp +++ b/src/duckdb/src/parallel/pipeline.cpp @@ -98,7 +98,7 @@ void Pipeline::ScheduleSequentialTask(shared_ptr &event) { event->SetTasks(std::move(tasks)); } -bool Pipeline::ScheduleParallel(shared_ptr &event) { +bool Pipeline::TryGetMaxThreads(idx_t &max_threads) { // check if the sink, source and all intermediate operators support parallelism if (!sink->ParallelSink()) { return false; @@ -106,21 +106,15 @@ bool Pipeline::ScheduleParallel(shared_ptr &event) { if (!source->ParallelSource()) { return false; } - auto max_threads = source_state->MaxThreads(); + max_threads = source_state->MaxThreads(); for (auto &op_ref : operators) { auto &op = op_ref.get(); if (!op.ParallelOperator()) { return false; } - max_threads = MinValue(max_threads, op.op_state->MaxThreads(max_threads)); - } - - auto partition_info = sink->RequiredPartitionInfo(); - if (partition_info.batch_index) { - if (!source->SupportsPartitioning(OperatorPartitionInfo::BatchIndex())) { - throw InternalException( - "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); + if (op.op_state) { + max_threads = MinValue(max_threads, op.op_state->MaxThreads(max_threads)); } } @@ -135,9 +129,39 @@ bool Pipeline::ScheduleParallel(shared_ptr &event) { if (max_threads > active_threads) { max_threads = active_threads; } + + return true; +} + +bool Pipeline::ScheduleParallel(shared_ptr &event) { + idx_t max_threads; + + if (!TryGetMaxThreads(max_threads)) { + return false; + } + + // Handle partition 
requirements specific to scheduling + auto partition_info = sink->RequiredPartitionInfo(); + if (partition_info.batch_index) { + if (!source->SupportsPartitioning(OperatorPartitionInfo::BatchIndex())) { + throw InternalException( + "Attempting to schedule a pipeline where the sink requires batch index but source does not support it"); + } + } + return LaunchScanTasks(event, max_threads); } +idx_t Pipeline::GetMaxThreads() { + idx_t max_threads; + + if (!TryGetMaxThreads(max_threads)) { + return 1; // Fallback for unsupported parallelism + } + + return max_threads; +} + bool Pipeline::IsOrderDependent() const { if (source) { auto source_order = source->SourceOrder(); @@ -204,6 +228,23 @@ void Pipeline::ResetSink() { } } +void Pipeline::ResetSinkForReschedule() { + if (!sink) { + return; + } + if (!sink->IsSink()) { + throw InternalException("Sink of pipeline does not have IsSink set"); + } + lock_guard guard(sink->lock); + auto &client = GetClientContext(); + auto allow_reuse = Settings::Get(client); + if (allow_reuse && sink->sink_state && sink->sink_state->SupportsReuse()) { + sink->sink_state->Reset(client); + return; + } + sink->sink_state = sink->GetGlobalSinkState(client); +} + void Pipeline::PrepareFinalize() { if (sink) { if (!sink->IsSink()) { @@ -232,6 +273,31 @@ void Pipeline::Reset() { initialized = true; } +void Pipeline::ResetForReschedule(bool reset_sink) { + if (reset_sink) { + ResetSinkForReschedule(); + } + auto &client = GetClientContext(); + auto allow_reuse = Settings::Get(client); + for (auto &op_ref : operators) { + auto &op = op_ref.get(); + lock_guard guard(op.lock); + if (allow_reuse && op.op_state && op.ResetGlobalOperatorState(client, *op.op_state)) { + continue; + } + op.op_state = op.GetGlobalOperatorState(client); + } + if (source && !source->IsSource()) { + throw InternalException("Source of pipeline does not have IsSource set"); + } + if (!allow_reuse || !source_state || !source_state->SupportsReuse()) { + source_state = 
source->GetGlobalSourceState(client); + } else { + source_state->Reset(client); + } + initialized = true; +} + void Pipeline::ResetSource(bool force) { if (source && !source->IsSource()) { throw InternalException("Source of pipeline does not have IsSource set"); diff --git a/src/duckdb/src/parallel/pipeline_executor.cpp b/src/duckdb/src/parallel/pipeline_executor.cpp index 12c883e41..eae56e896 100644 --- a/src/duckdb/src/parallel/pipeline_executor.cpp +++ b/src/duckdb/src/parallel/pipeline_executor.cpp @@ -4,6 +4,8 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/settings.hpp" + #ifdef DUCKDB_DEBUG_ASYNC_SINK_SOURCE #include #include @@ -50,6 +52,75 @@ PipelineExecutor::PipelineExecutor(ClientContext &context_p, Pipeline &pipeline_ InitializeChunk(final_chunk); } +void PipelineExecutor::Reset() { + D_ASSERT(pipeline.source_state); + auto allow_reuse = Settings::Get(context.client); + + // Reset execution flags + exhausted_pipeline = false; + finalized = false; + started_flushing = false; + done_flushing = false; + remaining_sink_chunk = false; + next_batch_blocked = false; + finished_processing_idx = -1; + should_flush_current_idx = true; + while (!in_process_operators.empty()) { + in_process_operators.pop(); + } + + // Recreate local sink state (destroyed by PushFinalize) + if (pipeline.sink) { + if (!allow_reuse || !local_sink_state || !local_sink_state->SupportsReuse()) { + local_sink_state = pipeline.sink->GetLocalSinkState(context); + } else { + local_sink_state->Reset(context, *pipeline.sink->sink_state); + } + required_partition_info = pipeline.sink->RequiredPartitionInfo(); + local_sink_state->partition_info = SourcePartitionInfo(); + if (required_partition_info.AnyRequired()) { + D_ASSERT(pipeline.source->SupportsPartitioning(OperatorPartitionInfo::BatchIndex())); + auto &partition_info = local_sink_state->partition_info; + D_ASSERT(!partition_info.batch_index.IsValid()); + partition_info.batch_index = 
pipeline.RegisterNewBatchIndex(); + partition_info.min_batch_index = partition_info.batch_index; + } + } + + // Recreate local source state (source data changed) + if (!allow_reuse || !local_source_state || !local_source_state->SupportsReuse()) { + local_source_state = pipeline.source->GetLocalSourceState(context, *pipeline.source_state); + } else { + local_source_state->Reset(context, *pipeline.source_state); + } + + // Recreate intermediate operator states (finalized in previous execution) + // Keep intermediate_chunks — reuse their allocated memory + for (idx_t i = 0; i < pipeline.operators.size(); i++) { + auto ¤t_operator = pipeline.operators[i].get(); + if (!allow_reuse || !intermediate_states[i] || !intermediate_states[i]->SupportsReuse()) { + intermediate_states[i] = current_operator.GetOperatorState(context); + } else { + intermediate_states[i]->Reset(); + } + + if (current_operator.IsSink() && current_operator.sink_state->state == SinkFinalizeType::NO_OUTPUT_POSSIBLE) { + FinishProcessing(); + } + } + + // Reset the final chunk data (keep allocation) + final_chunk.Reset(); +} + +void PipelineExecutor::PrepareForExecution() { + if (!has_executed) { + has_executed = true; + return; + } + Reset(); +} + bool PipelineExecutor::TryFlushCachingOperators(ExecutionBudget &chunk_budget) { if (!started_flushing) { // Remainder of this method assumes any in process operators are from flushing @@ -190,6 +261,7 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { D_ASSERT(pipeline.sink); auto &source_chunk = pipeline.operators.empty() ? 
final_chunk : *intermediate_chunks[0]; ExecutionBudget chunk_budget(max_chunks); + do { context.client.InterruptCheck(); @@ -231,6 +303,7 @@ PipelineExecuteResult PipelineExecutor::Execute(idx_t max_chunks) { return PipelineExecuteResult::INTERRUPTED; } if (source_result == SourceResultType::FINISHED) { + exhausted_source = true; exhausted_pipeline = true; } } @@ -379,12 +452,17 @@ PipelineExecuteResult PipelineExecutor::PushFinalize() { } finalized = true; + + // If source was not exhausted (e.g. LIMIT stopped the pipeline), collect exact metrics now + if (!source_profiling_finalized && local_source_state) { + context.thread.profiler.FinishSource(*pipeline.source, *pipeline.source_state, *local_source_state); + } + // flush all query profiler info for (idx_t i = 0; i < intermediate_states.size(); i++) { intermediate_states[i]->Finalize(pipeline.operators[i].get(), context); } pipeline.executor.Flush(thread); - local_sink_state.reset(); return PipelineExecuteResult::FINISHED; } @@ -478,7 +556,11 @@ OperatorResultType PipelineExecutor::Execute(DataChunk &input, DataChunk &result } void PipelineExecutor::SetTaskForInterrupts(weak_ptr current_task) { - interrupt_state = InterruptState(std::move(current_task)); + SetInterruptState(InterruptState(std::move(current_task))); +} + +void PipelineExecutor::SetInterruptState(InterruptState interrupt_state_p) { + interrupt_state = std::move(interrupt_state_p); } SourceResultType PipelineExecutor::GetData(DataChunk &chunk, OperatorSourceInput &input) { @@ -531,6 +613,7 @@ SourceResultType PipelineExecutor::FetchFromSource(DataChunk &result) { if (res == SourceResultType::FINISHED) { // final call into the source - finish source execution context.thread.profiler.FinishSource(*pipeline.source_state, *local_source_state); + source_profiling_finalized = true; } EndOperator(*pipeline.source, &result); diff --git a/src/duckdb/src/parallel/task_executor.cpp b/src/duckdb/src/parallel/task_executor.cpp index f7f995e93..11dca042d 100644 
--- a/src/duckdb/src/parallel/task_executor.cpp +++ b/src/duckdb/src/parallel/task_executor.cpp @@ -6,11 +6,12 @@ namespace duckdb { -TaskExecutor::TaskExecutor(TaskScheduler &scheduler) - : scheduler(scheduler), token(scheduler.CreateProducer()), completed_tasks(0), total_tasks(0) { +TaskExecutor::TaskExecutor(TaskScheduler &scheduler, TaskSchedulerType type_p) + : scheduler(scheduler), type(type_p), token(scheduler.CreateProducer()), completed_tasks(0), total_tasks(0) { } -TaskExecutor::TaskExecutor(ClientContext &context_p) : TaskExecutor(TaskScheduler::GetScheduler(context_p)) { +TaskExecutor::TaskExecutor(ClientContext &context_p, TaskSchedulerType type_p) + : TaskExecutor(TaskScheduler::GetScheduler(context_p), type_p) { context = context_p; } @@ -31,7 +32,7 @@ void TaskExecutor::ThrowError() { void TaskExecutor::ScheduleTask(unique_ptr task) { ++total_tasks; - scheduler.ScheduleTask(*token, std::move(task)); + scheduler.ScheduleTask(*token, std::move(task), type); } void TaskExecutor::FinishTask() { ++completed_tasks; diff --git a/src/duckdb/src/parallel/task_scheduler.cpp b/src/duckdb/src/parallel/task_scheduler.cpp index 29eb57ddf..1038ca382 100644 --- a/src/duckdb/src/parallel/task_scheduler.cpp +++ b/src/duckdb/src/parallel/task_scheduler.cpp @@ -1,20 +1,18 @@ #include "duckdb/parallel/task_scheduler.hpp" -#include "duckdb/common/chrono.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/parallel/task_scheduler_pool.hpp" +#include "duckdb/parallel/task_scheduler_queue.hpp" #include "duckdb/storage/block_allocator.hpp" #ifndef DUCKDB_NO_THREADS #include "concurrentqueue.h" #include "duckdb/common/thread.hpp" -#include "lightweightsemaphore.h" #include -#else -#include #endif #if defined(_WIN32) @@ -28,308 +26,180 @@ #endif namespace duckdb { -struct SchedulerThread { -#ifndef 
DUCKDB_NO_THREADS - explicit SchedulerThread(unique_ptr thread_p) : internal_thread(std::move(thread_p)) { - } - - unique_ptr internal_thread; -#endif -}; - -#ifndef DUCKDB_NO_THREADS -typedef duckdb_moodycamel::ConcurrentQueue> concurrent_queue_t; -typedef duckdb_moodycamel::LightweightSemaphore lightweight_semaphore_t; - -struct ConcurrentQueue { - ConcurrentQueue() : tasks_in_queue(0) { - } - - lightweight_semaphore_t semaphore; - - void Enqueue(ProducerToken &token, shared_ptr task); - void EnqueueBulk(ProducerToken &token, vector> &tasks); - bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); - bool Dequeue(shared_ptr &task); - idx_t GetTasksInQueue() const; - idx_t GetApproxSize() const; - idx_t GetProducerCount() const; - idx_t GetTaskCountForProducer(ProducerToken &token) const; - concurrent_queue_t &GetQueue() { - return q; - } -private: - concurrent_queue_t q; - atomic tasks_in_queue; -}; - -struct QueueProducerToken { - explicit QueueProducerToken(ConcurrentQueue &queue) : queue_token(queue.GetQueue()) { - } - - duckdb_moodycamel::ProducerToken queue_token; -}; - -void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { - lock_guard producer_lock(token.producer_lock); - task->token = token; - if (q.enqueue(token.token->queue_token, std::move(task))) { - ++tasks_in_queue; - semaphore.signal(); - } else { - throw InternalException("Could not schedule task!"); - } -} - -void ConcurrentQueue::EnqueueBulk(ProducerToken &token, vector> &tasks) { - typedef std::make_signed::type ssize_t; - lock_guard producer_lock(token.producer_lock); - for (auto &task : tasks) { - task->token = token; - } - if (q.enqueue_bulk(token.token->queue_token, std::make_move_iterator(tasks.begin()), tasks.size())) { - tasks_in_queue += tasks.size(); - semaphore.signal(NumericCast(tasks.size())); - } else { - throw InternalException("Could not schedule tasks!"); +TaskScheduler::TaskScheduler(DatabaseInstance &db) : db(db) { + for (uint8_t i = 0; i < 
TASK_SCHEDULER_TYPE_COUNT; i++) { + pools[i] = make_uniq(db, static_cast(i)); + queues[i] = make_uniq(static_cast(i)); } } -bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { - lock_guard producer_lock(token.producer_lock); - if (!q.try_dequeue_from_producer(token.token->queue_token, task)) { - return false; - } - --tasks_in_queue; - return true; -} - -bool ConcurrentQueue::Dequeue(shared_ptr &task) { - if (!q.try_dequeue(task)) { - return false; +TaskScheduler::~TaskScheduler() { +#ifndef DUCKDB_NO_THREADS + try { + for (auto &pool : pools) { + pool->RelaunchThreads(*this, true); + } + } catch (...) { + // nothing we can do in the destructor if this fails } - --tasks_in_queue; - return true; +#endif } -idx_t ConcurrentQueue::GetTasksInQueue() const { - return tasks_in_queue; -} -idx_t ConcurrentQueue::GetApproxSize() const { - return q.size_approx(); -} -idx_t ConcurrentQueue::GetProducerCount() const { - return q.size_producers_approx(); +TaskScheduler &TaskScheduler::GetScheduler(ClientContext &context) { + return GetScheduler(DatabaseInstance::GetDatabase(context)); } -idx_t ConcurrentQueue::GetTaskCountForProducer(ProducerToken &token) const { - lock_guard producer_lock(token.producer_lock); - return q.size_producer_approx(token.token->queue_token); +TaskScheduler &TaskScheduler::GetScheduler(DatabaseInstance &db) { + return db.GetScheduler(); } -#else -struct ConcurrentQueue { - reference_map_t>> q; - mutable mutex qlock; - - void Enqueue(ProducerToken &token, shared_ptr task); - void EnqueueBulk(ProducerToken &token, vector> &tasks); - bool DequeueFromProducer(ProducerToken &token, shared_ptr &task); - bool Dequeue(shared_ptr &task); - idx_t GetTasksInQueue() const; - idx_t GetApproxSize() const; - idx_t GetProducerCount() const; - idx_t GetTaskCountForProducer(ProducerToken &token) const; -}; - -void ConcurrentQueue::Enqueue(ProducerToken &token, shared_ptr task) { - lock_guard lock(qlock); - task->token = token; - 
q[std::ref(*token.token)].push(std::move(task)); -} - -void ConcurrentQueue::EnqueueBulk(ProducerToken &token, vector> &tasks) { - lock_guard lock(qlock); - for (auto &task : tasks) { - task->token = token; - q[std::ref(*token.token)].push(std::move(task)); - } +unique_ptr TaskScheduler::CreateProducer() { + return make_uniq(queues); } -bool ConcurrentQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { - lock_guard lock(qlock); - D_ASSERT(!q.empty()); - - const auto it = q.find(std::ref(*token.token)); - if (it == q.end() || it->second.empty()) { - return false; - } - - task = std::move(it->second.front()); - it->second.pop(); - - return true; +void TaskScheduler::ScheduleTask(ProducerToken &token, shared_ptr task) { + ScheduleTask(token, std::move(task), TaskSchedulerType::REGULAR); } -bool ConcurrentQueue::Dequeue(shared_ptr &task) { - throw InternalException("Global dequeue not supported for no threads queue"); +void TaskScheduler::ScheduleTasks(ProducerToken &producer, vector> &tasks) { + ScheduleTasks(producer, tasks, TaskSchedulerType::REGULAR); } -idx_t ConcurrentQueue::GetTasksInQueue() const { - lock_guard lock(qlock); - idx_t task_count = 0; - for (auto &producer : q) { - task_count += producer.second.size(); +bool TaskScheduler::GetTaskFromProducer(ProducerToken &token, shared_ptr &task) { + for (auto &queue : queues) { + if (queue->DequeueFromProducer(token, task)) { + return true; + } } - return task_count; -} - -idx_t ConcurrentQueue::GetApproxSize() const { - return GetTasksInQueue(); + return false; } -idx_t ConcurrentQueue::GetProducerCount() const { - lock_guard lock(qlock); - return q.size(); +TaskSchedulerPool &TaskScheduler::GetPool(TaskSchedulerType pool_type) { + return *pools[static_cast(pool_type)]; } -idx_t ConcurrentQueue::GetTaskCountForProducer(ProducerToken &token) const { - lock_guard lock(qlock); - const auto it = q.find(std::ref(*token.token)); - if (it == q.end()) { - return 0; - } - return it->second.size(); 
+TaskSchedulerQueue &TaskScheduler::GetQueue(TaskSchedulerType pool_type) const { + return *queues[static_cast(pool_type)]; } -struct QueueProducerToken { - explicit QueueProducerToken(ConcurrentQueue &queue) : queue(&queue) { - } - - ~QueueProducerToken() { - lock_guard lock(queue->qlock); - queue->q.erase(*this); - } - -private: - ConcurrentQueue *queue; -}; -#endif - -ProducerToken::ProducerToken(TaskScheduler &scheduler, unique_ptr token) - : scheduler(scheduler), token(std::move(token)) { +void TaskScheduler::SetThreadsInternal(TaskSchedulerType pool_type, idx_t n) { + GetPool(pool_type).SetThreads(n); } -ProducerToken::~ProducerToken() { +void TaskScheduler::ScheduleTask(ProducerToken &token, shared_ptr task, TaskSchedulerType pool_type) { + GetQueue(pool_type).Enqueue(token, std::move(task)); + SignalForTaskType(pool_type, 1); } -TaskScheduler::TaskScheduler(DatabaseInstance &db) - : db(db), queue(make_uniq()), - allocator_flush_threshold(db.config.options.allocator_flush_threshold), - allocator_background_threads(Settings::Get(db)), requested_thread_count(0), - current_thread_count(1) { - SetAllocatorBackgroundThreads(allocator_background_threads); +void TaskScheduler::ScheduleTasks(ProducerToken &producer, vector> &tasks, + TaskSchedulerType pool_type) { + GetQueue(pool_type).EnqueueBulk(producer, tasks); + SignalForTaskType(pool_type, tasks.size()); } -TaskScheduler::~TaskScheduler() { -#ifndef DUCKDB_NO_THREADS - try { - RelaunchThreadsInternal(0, true); - } catch (...) 
{ - // nothing we can do in the destructor if this fails +bool TaskScheduler::GetTaskInternal(shared_ptr &task) { + for (auto &queue : queues) { + if (queue->Dequeue(task)) { + return true; + } } -#endif + return false; } -TaskScheduler &TaskScheduler::GetScheduler(ClientContext &context) { - return TaskScheduler::GetScheduler(DatabaseInstance::GetDatabase(context)); -} - -TaskScheduler &TaskScheduler::GetScheduler(DatabaseInstance &db) { - return db.GetScheduler(); +bool TaskScheduler::GetTaskInternal(shared_ptr &task, TaskSchedulerType pool_type) { + return GetQueue(pool_type).Dequeue(task); } -unique_ptr TaskScheduler::CreateProducer() { - auto token = make_uniq(*queue); - return make_uniq(*this, std::move(token)); +void TaskScheduler::ExecuteForever(atomic *marker) { + ExecuteForever(marker, TaskSchedulerType::REGULAR); } -void TaskScheduler::ScheduleTask(ProducerToken &token, shared_ptr task) { - // Enqueue a task for the given producer token and signal any sleeping threads - queue->Enqueue(token, std::move(task)); -} +bool TaskScheduler::TryDequeueAndProcessTask(const DBConfig &config, TaskSchedulerQueue &queue, + shared_ptr &task) { + if (queue.Dequeue(task)) { + auto process_mode = TaskExecutionMode::PROCESS_ALL; + if (Settings::Get(config)) { + process_mode = TaskExecutionMode::PROCESS_PARTIAL; + } + auto execute_result = task->Execute(process_mode); -void TaskScheduler::ScheduleTasks(ProducerToken &producer, vector> &tasks) { - queue->EnqueueBulk(producer, tasks); -} + switch (execute_result) { + case TaskExecutionResult::TASK_FINISHED: + case TaskExecutionResult::TASK_ERROR: + task.reset(); + break; + case TaskExecutionResult::TASK_NOT_FINISHED: { + // task is not finished - reschedule immediately + auto &token = *task->token; + queue.Enqueue(token, std::move(task)); + SignalForTaskType(queue.GetPoolType(), 1); + break; + } + case TaskExecutionResult::TASK_BLOCKED: + task->Deschedule(); + task.reset(); + break; + } + return true; + } -bool 
TaskScheduler::GetTaskFromProducer(ProducerToken &token, shared_ptr &task) { - return queue->DequeueFromProducer(token, task); + if (queue.GetTasksInQueue() > 0) { + // failed to dequeue but there are still tasks remaining - signal again to retry + SignalForTaskType(queue.GetPoolType(), 1); + } + return false; } -void TaskScheduler::ExecuteForever(atomic *marker) { +void TaskScheduler::ExecuteForever(atomic *marker, const TaskSchedulerType pool_type) { #ifndef DUCKDB_NO_THREADS - static constexpr const int64_t INITIAL_FLUSH_WAIT = 500000; // initial wait time of 0.5s (in mus) before flushing + static constexpr int64_t INITIAL_FLUSH_WAIT = 500000; // initial wait time of 0.5s (in mus) before flushing const auto &block_allocator = BlockAllocator::Get(db); const auto &config = DBConfig::GetConfig(db); + auto &pool = GetPool(pool_type); shared_ptr task; // loop until the marker is set to false while (*marker) { if (!block_allocator.SupportsFlush()) { // allocator can't flush, just start an untimed wait - queue->semaphore.wait(); - } else if (!queue->semaphore.wait(INITIAL_FLUSH_WAIT)) { + pool.Wait(); + } else if (!pool.Wait(INITIAL_FLUSH_WAIT)) { // allocator can flush, we flush this threads outstanding allocations after it was idle for 0.5s - block_allocator.ThreadFlush(allocator_background_threads, allocator_flush_threshold, - NumericCast(requested_thread_count.load())); + block_allocator.ThreadFlush( + Settings::Get(db), + StringUtil::ParseFormattedBytes(Settings::Get(db)), + NumericCast(GetPool(TaskSchedulerType::REGULAR).NumberOfThreads())); auto decay_delay = Allocator::DecayDelay(); if (!decay_delay.IsValid()) { // no decay delay specified - just wait - queue->semaphore.wait(); + pool.Wait(); } else { - if (!queue->semaphore.wait(UnsafeNumericCast(decay_delay.GetIndex()) * 1000000 - - INITIAL_FLUSH_WAIT)) { + if (!pool.Wait(UnsafeNumericCast(decay_delay.GetIndex()) * 1000000 - INITIAL_FLUSH_WAIT)) { // in total, the thread was idle for the entire decay delay 
(note: seconds converted to mus) // mark it as idle and start an untimed wait Allocator::ThreadIdle(); - queue->semaphore.wait(); + pool.Wait(); } } } - if (queue->Dequeue(task)) { - auto process_mode = TaskExecutionMode::PROCESS_ALL; - if (Settings::Get(config)) { - process_mode = TaskExecutionMode::PROCESS_PARTIAL; - } - auto execute_result = task->Execute(process_mode); - switch (execute_result) { - case TaskExecutionResult::TASK_FINISHED: - case TaskExecutionResult::TASK_ERROR: - task.reset(); - break; - case TaskExecutionResult::TASK_NOT_FINISHED: { - // task is not finished - reschedule immediately - auto &token = *task->token; - queue->Enqueue(token, std::move(task)); - break; - } - case TaskExecutionResult::TASK_BLOCKED: - task->Deschedule(); - task.reset(); - break; + if (pool_type == TaskSchedulerType::REGULAR) { + // Regular thread pool picks up tasks from all pools + for (auto &queue : queues) { + if (TryDequeueAndProcessTask(config, *queue, task)) { + break; + } } - } else if (queue->GetTasksInQueue() > 0) { - // failed to dequeue but there are still tasks remaining - signal again to retry - queue->semaphore.signal(1); + } else { + TryDequeueAndProcessTask(config, GetQueue(pool_type), task); } } // this thread will exit, flush all of its outstanding allocations if (block_allocator.SupportsFlush()) { - block_allocator.ThreadFlush(allocator_background_threads, 0, NumericCast(requested_thread_count.load())); + block_allocator.ThreadFlush(Settings::Get(db), 0, + NumericCast(GetPool(TaskSchedulerType::REGULAR).NumberOfThreads())); Allocator::ThreadIdle(); } #else @@ -343,7 +213,7 @@ idx_t TaskScheduler::ExecuteTasks(atomic *marker, idx_t max_tasks) { // loop until the marker is set to false while (*marker && completed_tasks < max_tasks) { shared_ptr task; - if (!queue->Dequeue(task)) { + if (!GetTaskInternal(task)) { return completed_tasks; } auto execute_result = task->Execute(TaskExecutionMode::PROCESS_ALL); @@ -372,8 +242,8 @@ void 
TaskScheduler::ExecuteTasks(idx_t max_tasks) { #ifndef DUCKDB_NO_THREADS shared_ptr task; for (idx_t i = 0; i < max_tasks; i++) { - queue->semaphore.wait(TASK_TIMEOUT_USECS); - if (!queue->Dequeue(task)) { + GetPool(TaskSchedulerType::REGULAR).Wait(TASK_TIMEOUT_USECS); + if (!GetTaskInternal(task)) { return; } try { @@ -399,26 +269,29 @@ void TaskScheduler::ExecuteTasks(idx_t max_tasks) { #endif } -#ifndef DUCKDB_NO_THREADS -static void ThreadExecuteTasks(TaskScheduler *scheduler, atomic *marker) { - scheduler->ExecuteForever(marker); -} -#endif - int32_t TaskScheduler::NumberOfThreads() { - return current_thread_count.load(); + return GetPool(TaskSchedulerType::REGULAR).NumberOfThreads(); } idx_t TaskScheduler::GetNumberOfTasks() const { - return queue->GetTasksInQueue(); + idx_t num_tasks = 0; + for (auto &queue : queues) { + num_tasks += queue->GetTasksInQueue(); + } + return num_tasks; } idx_t TaskScheduler::GetProducerCount() const { - return queue->GetProducerCount(); + // We always create a producer in all queues, so we can just get the producer count of the regular queue here + return GetQueue(TaskSchedulerType::REGULAR).GetProducerCount(); } idx_t TaskScheduler::GetTaskCountForProducer(ProducerToken &token) const { - return queue->GetTaskCountForProducer(token); + idx_t task_count = 0; + for (uint8_t i = 0; i < TASK_SCHEDULER_TYPE_COUNT; i++) { + task_count += GetQueue(static_cast(i)).GetTaskCountForProducer(token); + } + return task_count; } void TaskScheduler::SetThreads(idx_t total_threads, idx_t external_threads) { @@ -435,23 +308,39 @@ void TaskScheduler::SetThreads(idx_t total_threads, idx_t external_threads) { "DuckDB was compiled without threads! 
Setting total_threads != external_threads is not allowed."); } #endif - requested_thread_count = NumericCast(total_threads - external_threads); + SetThreadsInternal(TaskSchedulerType::REGULAR, total_threads - external_threads); +} + +void TaskScheduler::SetAsyncThreads(idx_t n) { +#ifdef DUCKDB_NO_THREADS + if (n != 0) { + throw NotImplementedException( + "DuckDB was compiled without threads! Setting async threads != 0 is not allowed."); + } +#endif + SetThreadsInternal(TaskSchedulerType::ASYNC, n); } -void TaskScheduler::SetAllocatorFlushThreshold(idx_t threshold) { - allocator_flush_threshold = threshold; +void TaskScheduler::Signal(idx_t n) { + Signal(TaskSchedulerType::REGULAR, n); } -void TaskScheduler::SetAllocatorBackgroundThreads(bool enable) { - allocator_background_threads = enable; - Allocator::SetBackgroundThreads(enable); +void TaskScheduler::Signal(TaskSchedulerType pool_type, idx_t n) { + GetPool(pool_type).Signal(n); } -void TaskScheduler::Signal(idx_t n) { -#ifndef DUCKDB_NO_THREADS - typedef std::make_signed::type ssize_t; - queue->semaphore.signal(NumericCast(n)); -#endif +void TaskScheduler::SignalAllPools(idx_t n) { + for (auto &pool : pools) { + pool->Signal(n); + } +} + +void TaskScheduler::SignalForTaskType(TaskSchedulerType task_type, idx_t n) { + if (task_type == TaskSchedulerType::REGULAR) { + Signal(TaskSchedulerType::REGULAR, n); + } else { + SignalAllPools(n); + } } void TaskScheduler::YieldThread() { @@ -499,114 +388,9 @@ idx_t TaskScheduler::GetEstimatedCPUId() { void TaskScheduler::RelaunchThreads() { lock_guard t(thread_lock); - auto n = requested_thread_count.load(); - RelaunchThreadsInternal(n, false); -} - -#ifndef DUCKDB_NO_THREADS -static vector GetProcessCPUMask() { -#if defined(__GLIBC__) - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - if (sched_getaffinity(0, sizeof(cpu_set_t), &cpuset) != 0) { - return {}; - } - vector available_cpus; - for (int cpu = 0; cpu < CPU_SETSIZE; ++cpu) { - if (CPU_ISSET(cpu, &cpuset)) { - 
available_cpus.push_back(cpu); - } - } - return available_cpus; -#else - return {}; -#endif -} - -static void SetThreadAffinity(thread &thread, const vector &available_cpus, idx_t thread_idx) { -#if defined(__GLIBC__) - if (thread_idx < available_cpus.size()) { - const auto cpu_id = available_cpus[thread_idx]; - cpu_set_t cpuset; - CPU_ZERO(&cpuset); - CPU_SET(cpu_id, &cpuset); - - // note that we don't care about the return value here - // if we did not manage to set affinity, the thread just does not have affinity, which is OK - pthread_setaffinity_np(thread.native_handle(), sizeof(cpu_set_t), &cpuset); - } -#endif -} -#endif - -void TaskScheduler::RelaunchThreadsInternal(int32_t n, bool destroy) { -#ifndef DUCKDB_NO_THREADS - auto &config = DBConfig::GetConfig(db); - auto new_thread_count = NumericCast(n); - - idx_t external_threads = 0; - ThreadPinMode pin_thread_mode = ThreadPinMode::AUTO; - if (!destroy) { - // If we are destroying, i.e., calling ~TaskScheduler, we don't want to read the settings - external_threads = Settings::Get(config); - pin_thread_mode = Settings::Get(db); - } - - if (threads.size() == new_thread_count) { - current_thread_count = NumericCast(threads.size() + external_threads); - return; + for (auto &pool : pools) { + pool->RelaunchThreads(*this, false); } - if (threads.size() != new_thread_count) { - // we are changing the number of threads: clear all threads first - // we do this even when increasing the number of threads to make sure that all threads follow the current - // affinity mask - for (idx_t i = 0; i < threads.size(); i++) { - *markers[i] = false; - } - Signal(threads.size()); - // now join the threads to ensure they are fully stopped before erasing them - for (idx_t i = 0; i < threads.size(); i++) { - threads[i]->internal_thread->join(); - } - // erase the threads/markers - threads.clear(); - markers.clear(); - } - if (threads.size() < new_thread_count) { - // we are increasing the number of threads: launch them and run tasks 
on them - idx_t create_new_threads = new_thread_count - threads.size(); - - // Whether to pin threads to cores - static constexpr idx_t THREAD_PIN_THRESHOLD = 64; - const auto pin_threads = - pin_thread_mode == ThreadPinMode::ON || - (pin_thread_mode == ThreadPinMode::AUTO && std::thread::hardware_concurrency() > THREAD_PIN_THRESHOLD); - const auto available_cpus = pin_threads ? GetProcessCPUMask() : vector(); - // If we have fewer available cores than threads, do not pin and let OS scheduler handle it - const auto can_pin = pin_threads && new_thread_count <= available_cpus.size(); - for (idx_t i = 0; i < create_new_threads; i++) { - // launch a thread and assign it a cancellation marker - auto marker = unique_ptr>(new atomic(true)); - unique_ptr worker_thread; - try { - worker_thread = make_uniq(ThreadExecuteTasks, this, marker.get()); - if (can_pin) { - SetThreadAffinity(*worker_thread, available_cpus, threads.size()); - } - } catch (std::exception &ex) { - // thread constructor failed - this can happen when the system has too many threads allocated - // in this case we cannot allocate more threads - stop launching them - break; - } - auto thread_wrapper = make_uniq(std::move(worker_thread)); - - threads.push_back(std::move(thread_wrapper)); - markers.push_back(std::move(marker)); - } - } - current_thread_count = NumericCast(threads.size() + external_threads); - BlockAllocator::Get(db).FlushAll(); -#endif } } // namespace duckdb diff --git a/src/duckdb/src/parallel/task_scheduler_pool.cpp b/src/duckdb/src/parallel/task_scheduler_pool.cpp new file mode 100644 index 000000000..20e2c0af1 --- /dev/null +++ b/src/duckdb/src/parallel/task_scheduler_pool.cpp @@ -0,0 +1,199 @@ +#include "duckdb/parallel/task_scheduler_pool.hpp" + +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/thread.hpp" +#include 
"duckdb/storage/block_allocator.hpp" + +#ifndef DUCKDB_NO_THREADS +#include "lightweightsemaphore.h" + +#include +#endif + +#if defined(_WIN32) +#include +#elif defined(__GNUC__) +#include +#include +#if defined(__GLIBC__) +#include +#endif +#endif + +namespace duckdb { + +#ifndef DUCKDB_NO_THREADS +typedef duckdb_moodycamel::LightweightSemaphore lightweight_semaphore_t; + +struct LightWeightSemaphoreWrapper { + lightweight_semaphore_t s; +}; +#endif + +struct TaskSchedulerThread { +#ifndef DUCKDB_NO_THREADS + explicit TaskSchedulerThread(unique_ptr thread_p) : internal_thread(std::move(thread_p)) { + } + + unique_ptr internal_thread; +#endif +}; + +TaskSchedulerPool::TaskSchedulerPool(DatabaseInstance &db_p, TaskSchedulerType pool_type_p) + : db(db_p), pool_type(pool_type_p), requested_thread_count(0), + current_thread_count(pool_type == TaskSchedulerType::REGULAR ? 1 : 0) { +#ifndef DUCKDB_NO_THREADS + semaphore = make_uniq(); +#endif +} + +TaskSchedulerPool::~TaskSchedulerPool() { +} + +void TaskSchedulerPool::SetThreads(idx_t n) { + requested_thread_count = NumericCast(n); +} + +int32_t TaskSchedulerPool::NumberOfThreads() { + return current_thread_count.load(); +} + +void TaskSchedulerPool::Signal(idx_t n) { +#ifndef DUCKDB_NO_THREADS + if (n == 0) { + return; + } + typedef std::make_signed::type ssize_t; + semaphore->s.signal(NumericCast(n)); +#endif +} + +#ifndef DUCKDB_NO_THREADS +void TaskSchedulerPool::Wait() { + semaphore->s.wait(); +} + +bool TaskSchedulerPool::Wait(int64_t timeout_usecs) { + return semaphore->s.wait(timeout_usecs); +} +#endif + +#ifndef DUCKDB_NO_THREADS +static vector GetProcessCPUMask() { +#if defined(__GLIBC__) + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + if (sched_getaffinity(0, sizeof(cpu_set_t), &cpuset) != 0) { + return {}; + } + vector available_cpus; + for (int cpu = 0; cpu < CPU_SETSIZE; ++cpu) { + if (CPU_ISSET(cpu, &cpuset)) { + available_cpus.push_back(cpu); + } + } + return available_cpus; +#else + return {}; +#endif +} + 
+static void SetThreadAffinity(thread &thread, const vector &available_cpus, idx_t thread_idx) { +#if defined(__GLIBC__) + if (thread_idx < available_cpus.size()) { + const auto cpu_id = available_cpus[thread_idx]; + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpu_id, &cpuset); + + // note that we don't care about the return value here + // if we did not manage to set affinity, the thread just does not have affinity, which is OK + pthread_setaffinity_np(thread.native_handle(), sizeof(cpu_set_t), &cpuset); + } +#endif +} +#endif + +#ifndef DUCKDB_NO_THREADS +static void ThreadExecuteTasks(TaskScheduler *scheduler, atomic *marker, const TaskSchedulerType pool_type) { + scheduler->ExecuteForever(marker, pool_type); +} +#endif + +void TaskSchedulerPool::RelaunchThreads(TaskScheduler &scheduler, bool destroy) { +#ifndef DUCKDB_NO_THREADS + auto &config = DBConfig::GetConfig(db); + auto new_thread_count = NumericCast(destroy ? 0 : requested_thread_count.load()); + + idx_t external_threads = 0; + ThreadPinMode pin_thread_mode = ThreadPinMode::AUTO; + if (!destroy) { + // If we are destroying, i.e., calling ~TaskScheduler, we don't want to read the settings + external_threads = Settings::Get(config); + pin_thread_mode = Settings::Get(db); + } + + if (threads.size() == new_thread_count) { + current_thread_count = + NumericCast(threads.size() + (pool_type == TaskSchedulerType::REGULAR ? 
external_threads : 0)); + return; + } + if (threads.size() != new_thread_count) { + // we are changing the number of threads: clear all threads first + // we do this even when increasing the number of threads to make sure that all threads follow the current + // affinity mask + for (idx_t i = 0; i < threads.size(); i++) { + *markers[i] = false; + } + Signal(threads.size()); + // now join the threads to ensure they are fully stopped before erasing them + for (idx_t i = 0; i < threads.size(); i++) { + threads[i]->internal_thread->join(); + } + // erase the threads/markers + threads.clear(); + markers.clear(); + } + if (threads.size() < new_thread_count) { + // we are increasing the number of threads: launch them and run tasks on them + idx_t create_new_threads = new_thread_count - threads.size(); + + // Whether to pin threads to cores + static constexpr idx_t THREAD_PIN_THRESHOLD = 64; + const auto pin_threads = + pool_type == TaskSchedulerType::REGULAR && // Only pin regular threads! + (pin_thread_mode == ThreadPinMode::ON || + (pin_thread_mode == ThreadPinMode::AUTO && std::thread::hardware_concurrency() > THREAD_PIN_THRESHOLD)); + const auto available_cpus = pin_threads ? 
GetProcessCPUMask() : vector(); + // If we have fewer available cores than threads, do not pin and let OS scheduler handle it + const auto can_pin = pin_threads && new_thread_count <= available_cpus.size(); + for (idx_t i = 0; i < create_new_threads; i++) { + // launch a thread and assign it a cancellation marker + auto marker = unique_ptr>(new atomic(true)); + unique_ptr worker_thread; + try { + worker_thread = make_uniq(ThreadExecuteTasks, &scheduler, marker.get(), pool_type); + if (can_pin) { + SetThreadAffinity(*worker_thread, available_cpus, threads.size()); + } + } catch (std::exception &ex) { + // thread constructor failed - this can happen when the system has too many threads allocated + // in this case we cannot allocate more threads - stop launching them + break; + } + auto thread_wrapper = make_uniq(std::move(worker_thread)); + + threads.push_back(std::move(thread_wrapper)); + markers.push_back(std::move(marker)); + } + } + current_thread_count = + NumericCast(threads.size() + (pool_type == TaskSchedulerType::REGULAR ? 
external_threads : 0)); + BlockAllocator::Get(db).FlushAll(); +#endif +} + +}; // namespace duckdb diff --git a/src/duckdb/src/parallel/task_scheduler_queue.cpp b/src/duckdb/src/parallel/task_scheduler_queue.cpp new file mode 100644 index 000000000..85632e0a4 --- /dev/null +++ b/src/duckdb/src/parallel/task_scheduler_queue.cpp @@ -0,0 +1,195 @@ +#include "duckdb/parallel/task_scheduler_queue.hpp" + +#include "duckdb/parallel/task.hpp" +#include "duckdb/parallel/task_scheduler.hpp" + +#ifndef DUCKDB_NO_THREADS +#include "concurrentqueue.h" +#endif + +namespace duckdb { + +TaskSchedulerType TaskSchedulerQueue::GetPoolType() { + return pool_type; +} + +#ifndef DUCKDB_NO_THREADS +typedef duckdb_moodycamel::ConcurrentQueue> concurrent_queue_t; + +struct ConcurrentQueueWrapper { + concurrent_queue_t q; +}; + +struct QueueProducerToken { +public: + explicit QueueProducerToken(TaskSchedulerQueue &queue) : token(queue.GetQueue().q) { + } + +public: + duckdb_moodycamel::ProducerToken token; +}; + +TaskSchedulerQueue::TaskSchedulerQueue(TaskSchedulerType pool_type_p) + : pool_type(pool_type_p), queue(make_uniq()) { +} + +void TaskSchedulerQueue::Enqueue(ProducerToken &token, shared_ptr task) { + lock_guard producer_lock(token.producer_lock); + task->token = token; + if (queue->q.enqueue(token.GetQueueProducerToken(pool_type).token, std::move(task))) { + ++tasks_in_queue; + } else { + throw InternalException("Could not schedule task!"); + } +} + +void TaskSchedulerQueue::EnqueueBulk(ProducerToken &token, vector> &tasks) { + lock_guard producer_lock(token.producer_lock); + for (auto &task : tasks) { + task->token = token; + } + if (queue->q.enqueue_bulk(token.GetQueueProducerToken(pool_type).token, std::make_move_iterator(tasks.begin()), + tasks.size())) { + tasks_in_queue += tasks.size(); + } else { + throw InternalException("Could not schedule tasks!"); + } +} + +bool TaskSchedulerQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { + lock_guard 
producer_lock(token.producer_lock); + if (!queue->q.try_dequeue_from_producer(token.GetQueueProducerToken(pool_type).token, task)) { + return false; + } + --tasks_in_queue; + return true; +} + +bool TaskSchedulerQueue::Dequeue(shared_ptr &task) { + if (!queue->q.try_dequeue(task)) { + return false; + } + --tasks_in_queue; + return true; +} + +idx_t TaskSchedulerQueue::GetTasksInQueue() const { + return tasks_in_queue; +} +idx_t TaskSchedulerQueue::GetApproxSize() const { + return queue->q.size_approx(); +} +idx_t TaskSchedulerQueue::GetProducerCount() const { + return queue->q.size_producers_approx(); +} + +idx_t TaskSchedulerQueue::GetTaskCountForProducer(ProducerToken &token) const { + lock_guard producer_lock(token.producer_lock); + return queue->q.size_producer_approx(token.GetQueueProducerToken(pool_type).token); +} + +ConcurrentQueueWrapper &TaskSchedulerQueue::GetQueue() { + return *queue; +} + +#else + +struct QueueProducerToken { +public: + explicit QueueProducerToken(TaskSchedulerQueue &queue) : queue(&queue) { + } + + ~QueueProducerToken() { + queue->RemoveToken(*this); + } + +private: + TaskSchedulerQueue *queue; +}; + +TaskSchedulerQueue::TaskSchedulerQueue(TaskSchedulerType pool_type_p) : pool_type(pool_type_p) { +} + +void TaskSchedulerQueue::Enqueue(ProducerToken &token, shared_ptr task) { + lock_guard lock(qlock); + task->token = token; + q[token.GetQueueProducerToken(pool_type)].push(std::move(task)); +} + +void TaskSchedulerQueue::EnqueueBulk(ProducerToken &token, vector> &tasks) { + lock_guard lock(qlock); + for (auto &task : tasks) { + task->token = token; + q[token.GetQueueProducerToken(pool_type)].push(std::move(task)); + } +} + +bool TaskSchedulerQueue::DequeueFromProducer(ProducerToken &token, shared_ptr &task) { + lock_guard lock(qlock); + + const auto it = q.find(token.GetQueueProducerToken(pool_type)); + if (it == q.end() || it->second.empty()) { + return false; + } + + task = std::move(it->second.front()); + it->second.pop(); + + return 
true; +} + +bool TaskSchedulerQueue::Dequeue(shared_ptr &task) { + throw InternalException("Global dequeue not supported for no threads queue"); +} + +idx_t TaskSchedulerQueue::GetTasksInQueue() const { + lock_guard lock(qlock); + idx_t task_count = 0; + for (auto &producer : q) { + task_count += producer.second.size(); + } + return task_count; +} + +idx_t TaskSchedulerQueue::GetApproxSize() const { + return GetTasksInQueue(); +} + +idx_t TaskSchedulerQueue::GetProducerCount() const { + lock_guard lock(qlock); + return q.size(); +} + +idx_t TaskSchedulerQueue::GetTaskCountForProducer(ProducerToken &token) const { + lock_guard lock(qlock); + const auto it = q.find(token.GetQueueProducerToken(pool_type)); + if (it == q.end()) { + return 0; + } + return it->second.size(); +} + +void TaskSchedulerQueue::RemoveToken(QueueProducerToken &token) { + lock_guard lock(qlock); + q.erase(token); +} + +#endif + +TaskSchedulerQueue::~TaskSchedulerQueue() { +} + +ProducerToken::ProducerToken(array, TASK_SCHEDULER_TYPE_COUNT> &queues) { + for (uint8_t i = 0; i < TASK_SCHEDULER_TYPE_COUNT; i++) { + tokens[i] = make_uniq(*queues[i]); + } +} + +ProducerToken::~ProducerToken() { +} + +QueueProducerToken &ProducerToken::GetQueueProducerToken(TaskSchedulerType pool_type) { + return *tokens[static_cast(pool_type)]; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/base_expression.cpp b/src/duckdb/src/parser/base_expression.cpp index 92fa254a8..26b4ef96b 100644 --- a/src/duckdb/src/parser/base_expression.cpp +++ b/src/duckdb/src/parser/base_expression.cpp @@ -9,13 +9,13 @@ void BaseExpression::Print() const { Printer::Print(ToString()); } -string BaseExpression::GetName() const { +Identifier BaseExpression::GetName() const { #ifdef DEBUG if (DBConfigOptions::debug_print_bindings) { - return ToString(); + return Identifier(ToString()); } #endif - return !alias.empty() ? alias : ToString(); + return !alias.empty() ? 
alias : Identifier(ToString()); } bool BaseExpression::Equals(const BaseExpression &other) const { diff --git a/src/duckdb/src/parser/column_definition.cpp b/src/duckdb/src/parser/column_definition.cpp index b591a1091..0c92d0cb2 100644 --- a/src/duckdb/src/parser/column_definition.cpp +++ b/src/duckdb/src/parser/column_definition.cpp @@ -7,11 +7,11 @@ namespace duckdb { -ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p) +ColumnDefinition::ColumnDefinition(Identifier name_p, LogicalType type_p) : name(std::move(name_p)), type(std::move(type_p)) { } -ColumnDefinition::ColumnDefinition(string name_p, LogicalType type_p, unique_ptr expression, +ColumnDefinition::ColumnDefinition(Identifier name_p, LogicalType type_p, unique_ptr expression, TableColumnType category) : name(std::move(name_p)), type(std::move(type_p)), category(category), expression(std::move(expression)) { } @@ -64,10 +64,10 @@ void ColumnDefinition::SetType(const LogicalType &type) { this->type = type; } -const string &ColumnDefinition::Name() const { +const Identifier &ColumnDefinition::Name() const { return name; } -void ColumnDefinition::SetName(const string &name) { +void ColumnDefinition::SetName(const Identifier &name) { this->name = name; } @@ -156,7 +156,7 @@ string ColumnDefinition::ToSQLString() const { auto &expr = generated_expression.get(); D_ASSERT(expr.GetExpressionType() == ExpressionType::OPERATOR_CAST); auto &cast_expr = expr.Cast(); - generated_expression = *cast_expr.child; + generated_expression = cast_expr.Child(); } result += " GENERATED ALWAYS AS(" + generated_expression.get().ToString() + ")"; } else if (HasDefaultValue()) { @@ -185,7 +185,7 @@ static void InnerGetListOfDependencies(ParsedExpression &expr, vector &d if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) { auto columnref = expr.Cast(); auto &name = columnref.GetColumnName(); - dependencies.push_back(name); + dependencies.emplace_back(name); } 
ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { if (expr.GetExpressionType() == ExpressionType::LAMBDA) { @@ -200,7 +200,7 @@ void ColumnDefinition::GetListOfDependencies(vector &dependencies) const InnerGetListOfDependencies(*expression, dependencies); } -string ColumnDefinition::GetName() const { +Identifier ColumnDefinition::GetName() const { return name; } diff --git a/src/duckdb/src/parser/column_list.cpp b/src/duckdb/src/parser/column_list.cpp index f0b29f483..9bf34743b 100644 --- a/src/duckdb/src/parser/column_list.cpp +++ b/src/duckdb/src/parser/column_list.cpp @@ -38,9 +38,9 @@ void ColumnList::Finalize() { void ColumnList::AddToNameMap(ColumnDefinition &col) { if (allow_duplicate_names) { idx_t index = 1; - string base_name = col.Name(); + string base_name = col.Name().GetIdentifierName(); while (name_map.find(col.Name()) != name_map.end()) { - col.SetName(base_name + "_" + to_string(index++)); + col.SetName(Identifier(base_name + "_" + to_string(index++))); } } else { if (name_map.find(col.Name()) != name_map.end()) { @@ -66,10 +66,10 @@ ColumnDefinition &ColumnList::GetColumnMutable(PhysicalIndex physical) { return columns[logical_index]; } -ColumnDefinition &ColumnList::GetColumnMutable(const string &name) { +ColumnDefinition &ColumnList::GetColumnMutable(const Identifier &name) { auto entry = name_map.find(name); if (entry == name_map.end()) { - throw InternalException("Column with name \"%s\" does not exist", name); + throw InternalException("Column with name \"%s\" does not exist", name.GetIdentifierName()); } auto logical_index = entry->second; D_ASSERT(logical_index < columns.size()); @@ -92,7 +92,7 @@ const ColumnDefinition &ColumnList::GetColumn(PhysicalIndex physical) const { return columns[logical_index]; } -const ColumnDefinition &ColumnList::GetColumn(const string &name) const { +const ColumnDefinition &ColumnList::GetColumn(const Identifier &name) const { auto entry = name_map.find(name); if (entry == 
name_map.end()) { throw InternalException("Column with name \"%s\" does not exist", name); @@ -106,7 +106,7 @@ vector ColumnList::GetColumnNames() const { vector names; names.reserve(columns.size()); for (auto &column : columns) { - names.push_back(column.Name()); + names.emplace_back(column.Name()); } return names; } @@ -120,7 +120,7 @@ vector ColumnList::GetColumnTypes() const { return types; } -bool ColumnList::ColumnExists(const string &name) const { +bool ColumnList::ColumnExists(const Identifier &name) const { auto entry = name_map.find(name); return entry != name_map.end(); } @@ -138,7 +138,7 @@ LogicalIndex ColumnList::PhysicalToLogical(PhysicalIndex index) const { return column.Logical(); } -LogicalIndex ColumnList::GetColumnIndex(string &column_name) const { +LogicalIndex ColumnList::GetColumnIndex(Identifier &column_name) const { auto entry = name_map.find(column_name); if (entry == name_map.end()) { return LogicalIndex(DConstants::INVALID_INDEX); diff --git a/src/duckdb/src/parser/common_table_expression_info.cpp b/src/duckdb/src/parser/common_table_expression_info.cpp index 2ee703982..10e9804d3 100644 --- a/src/duckdb/src/parser/common_table_expression_info.cpp +++ b/src/duckdb/src/parser/common_table_expression_info.cpp @@ -35,7 +35,7 @@ unique_ptr CommonTableExpressionInfo::Copy() { } CTEMaterialize CommonTableExpressionInfo::GetMaterializedForSerialization(Serializer &serializer) const { - if (serializer.ShouldSerialize(7)) { + if (serializer.ShouldSerialize(StorageVersion::V1_5_0)) { return materialized; } return CTEMaterialize::CTE_MATERIALIZE_DEFAULT; diff --git a/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp b/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp index a2c7f5891..ddddac24c 100644 --- a/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp +++ b/src/duckdb/src/parser/constraints/foreign_key_constraint.cpp @@ -8,7 +8,8 @@ namespace duckdb { ForeignKeyConstraint::ForeignKeyConstraint() : 
Constraint(ConstraintType::FOREIGN_KEY) { } -ForeignKeyConstraint::ForeignKeyConstraint(vector pk_columns, vector fk_columns, ForeignKeyInfo info) +ForeignKeyConstraint::ForeignKeyConstraint(vector pk_columns, vector fk_columns, + ForeignKeyInfo info) : Constraint(ConstraintType::FOREIGN_KEY), pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), info(std::move(info)) { } @@ -25,10 +26,10 @@ string ForeignKeyConstraint::ToString() const { } base += ") REFERENCES "; if (!info.schema.empty() && info.schema != DEFAULT_SCHEMA) { - base += info.schema; + base += info.schema.GetIdentifierName(); base += "."; } - base += info.table; + base += info.table.GetIdentifierName(); if (!pk_columns.empty()) { base += "("; diff --git a/src/duckdb/src/parser/constraints/unique_constraint.cpp b/src/duckdb/src/parser/constraints/unique_constraint.cpp index 43749dbfa..648fb8928 100644 --- a/src/duckdb/src/parser/constraints/unique_constraint.cpp +++ b/src/duckdb/src/parser/constraints/unique_constraint.cpp @@ -11,12 +11,12 @@ UniqueConstraint::UniqueConstraint() : Constraint(ConstraintType::UNIQUE), index UniqueConstraint::UniqueConstraint(const LogicalIndex index, const bool is_primary_key) : Constraint(ConstraintType::UNIQUE), index(index), is_primary_key(is_primary_key) { } -UniqueConstraint::UniqueConstraint(const LogicalIndex index, string column_name_p, const bool is_primary_key) +UniqueConstraint::UniqueConstraint(const LogicalIndex index, Identifier column_name_p, const bool is_primary_key) : UniqueConstraint(index, is_primary_key) { - columns.push_back(std::move(column_name_p)); + columns.emplace_back(std::move(column_name_p)); } -UniqueConstraint::UniqueConstraint(vector columns, const bool is_primary_key) +UniqueConstraint::UniqueConstraint(vector columns, const bool is_primary_key) : Constraint(ConstraintType::UNIQUE), index(DConstants::INVALID_INDEX), columns(std::move(columns)), is_primary_key(is_primary_key) { } @@ -37,7 +37,7 @@ unique_ptr 
UniqueConstraint::Copy() const { return make_uniq(columns, is_primary_key); } - auto result = make_uniq(index, columns.empty() ? string() : columns[0], is_primary_key); + auto result = make_uniq(index, columns.empty() ? Identifier() : columns[0], is_primary_key); return std::move(result); } @@ -61,12 +61,12 @@ void UniqueConstraint::SetIndex(const LogicalIndex new_index) { index = new_index; } -const vector &UniqueConstraint::GetColumnNames() const { +const vector &UniqueConstraint::GetColumnNames() const { D_ASSERT(!columns.empty()); return columns; } -vector &UniqueConstraint::GetColumnNamesMutable() { +vector &UniqueConstraint::GetColumnNamesMutable() { D_ASSERT(!columns.empty()); return columns; } @@ -86,7 +86,7 @@ vector UniqueConstraint::GetLogicalIndexes(const ColumnList &colum return indexes; } -string UniqueConstraint::GetName(const string &table_name) const { +Identifier UniqueConstraint::GetName(const Identifier &table_name) const { auto type = IsPrimaryKey() ? IndexConstraintType::PRIMARY : IndexConstraintType::UNIQUE; auto type_name = EnumUtil::ToString(type); @@ -94,7 +94,7 @@ string UniqueConstraint::GetName(const string &table_name) const { for (const auto &column_name : GetColumnNames()) { name += "_" + column_name; } - return type_name + "_" + table_name + name; + return Identifier(type_name + "_" + table_name + name); } } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/between_expression.cpp b/src/duckdb/src/parser/expression/between_expression.cpp index 4c5e90df1..acc313c7d 100644 --- a/src/duckdb/src/parser/expression/between_expression.cpp +++ b/src/duckdb/src/parser/expression/between_expression.cpp @@ -17,23 +17,4 @@ string BetweenExpression::ToString() const { return ToString(Input(), LowerBound(), UpperBound()); } -bool BetweenExpression::Equal(const BetweenExpression &a, const BetweenExpression &b) { - if (!a.input->Equals(*b.input)) { - return false; - } - if (!a.lower->Equals(*b.lower)) { - return false; - } - if 
(!a.upper->Equals(*b.upper)) { - return false; - } - return true; -} - -unique_ptr BetweenExpression::Copy() const { - auto copy = make_uniq(input->Copy(), lower->Copy(), upper->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/case_expression.cpp b/src/duckdb/src/parser/expression/case_expression.cpp index 4b009b831..071af0a2e 100644 --- a/src/duckdb/src/parser/expression/case_expression.cpp +++ b/src/duckdb/src/parser/expression/case_expression.cpp @@ -10,35 +10,4 @@ string CaseExpression::ToString() const { return ToString(*this); } -bool CaseExpression::Equal(const CaseExpression &a, const CaseExpression &b) { - if (a.case_checks.size() != b.case_checks.size()) { - return false; - } - for (idx_t i = 0; i < a.case_checks.size(); i++) { - if (!a.case_checks[i].when_expr->Equals(*b.case_checks[i].when_expr)) { - return false; - } - if (!a.case_checks[i].then_expr->Equals(*b.case_checks[i].then_expr)) { - return false; - } - } - if (!a.else_expr->Equals(*b.else_expr)) { - return false; - } - return true; -} - -unique_ptr CaseExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - for (auto &check : case_checks) { - CaseCheck new_check; - new_check.when_expr = check.when_expr->Copy(); - new_check.then_expr = check.then_expr->Copy(); - copy->case_checks.push_back(std::move(new_check)); - } - copy->else_expr = else_expr->Copy(); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/cast_expression.cpp b/src/duckdb/src/parser/expression/cast_expression.cpp index 6cc43bc05..057ab4d07 100644 --- a/src/duckdb/src/parser/expression/cast_expression.cpp +++ b/src/duckdb/src/parser/expression/cast_expression.cpp @@ -17,23 +17,4 @@ string CastExpression::ToString() const { return ToString(*this); } -bool CastExpression::Equal(const CastExpression &a, const CastExpression &b) { - if (!a.child->Equals(*b.child)) { - 
return false; - } - if (a.cast_type != b.cast_type) { - return false; - } - if (a.try_cast != b.try_cast) { - return false; - } - return true; -} - -unique_ptr CastExpression::Copy() const { - auto copy = make_uniq(cast_type, child->Copy(), try_cast); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/collate_expression.cpp b/src/duckdb/src/parser/expression/collate_expression.cpp index 2874c0033..97d3dc18b 100644 --- a/src/duckdb/src/parser/expression/collate_expression.cpp +++ b/src/duckdb/src/parser/expression/collate_expression.cpp @@ -16,20 +16,4 @@ string CollateExpression::ToString() const { return StringUtil::Format("%s COLLATE %s", child->ToString(), SQLIdentifier(collation)); } -bool CollateExpression::Equal(const CollateExpression &a, const CollateExpression &b) { - if (!a.child->Equals(*b.child)) { - return false; - } - if (a.collation != b.collation) { - return false; - } - return true; -} - -unique_ptr CollateExpression::Copy() const { - auto copy = make_uniq(collation, child->Copy()); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/columnref_expression.cpp b/src/duckdb/src/parser/expression/columnref_expression.cpp index 7188f331e..46d2f1760 100644 --- a/src/duckdb/src/parser/expression/columnref_expression.cpp +++ b/src/duckdb/src/parser/expression/columnref_expression.cpp @@ -10,32 +10,35 @@ namespace duckdb { ColumnRefExpression::ColumnRefExpression() : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF) { } -ColumnRefExpression::ColumnRefExpression(string column_name, string table_name) - : ColumnRefExpression(table_name.empty() ? vector {std::move(column_name)} - : vector {std::move(table_name), std::move(column_name)}) { +ColumnRefExpression::ColumnRefExpression(Identifier column_name, Identifier table_name) + : ColumnRefExpression(table_name.empty() ? 
vector {std::move(column_name)} + : vector {std::move(table_name), std::move(column_name)}) { } -ColumnRefExpression::ColumnRefExpression(string column_name, const BindingAlias &alias) +ColumnRefExpression::ColumnRefExpression(Identifier column_name, const BindingAlias &alias) : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF) { if (alias.IsSet()) { if (!alias.GetCatalog().empty()) { - column_names.push_back(alias.GetCatalog()); + column_names.emplace_back(alias.GetCatalog()); } if (!alias.GetSchema().empty()) { - column_names.push_back(alias.GetSchema()); + column_names.emplace_back(alias.GetSchema()); } - column_names.push_back(alias.GetAlias()); + column_names.emplace_back(alias.GetAlias()); } - column_names.push_back(std::move(column_name)); + column_names.emplace_back(std::move(column_name)); } -ColumnRefExpression::ColumnRefExpression(string column_name) - : ColumnRefExpression(vector {std::move(column_name)}) { +ColumnRefExpression::ColumnRefExpression(Identifier column_name) + : ColumnRefExpression(vector {std::move(column_name)}) { } -ColumnRefExpression::ColumnRefExpression(vector column_names_p) - : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF), - column_names(std::move(column_names_p)) { +ColumnRefExpression::ColumnRefExpression(vector column_names_p) + : ParsedExpression(ExpressionType::COLUMN_REF, ExpressionClass::COLUMN_REF) { + column_names.reserve(column_names_p.size()); + for (auto &col_name : column_names_p) { + column_names.emplace_back(std::move(col_name)); + } #ifdef DEBUG for (auto &col_name : column_names) { D_ASSERT(!col_name.empty()); @@ -47,12 +50,12 @@ bool ColumnRefExpression::IsQualified() const { return column_names.size() > 1; } -const string &ColumnRefExpression::GetColumnName() const { +const Identifier &ColumnRefExpression::GetColumnName() const { D_ASSERT(column_names.size() <= 4); return column_names.back(); } -const string &ColumnRefExpression::GetTableName() const { +const 
Identifier &ColumnRefExpression::GetTableName() const { D_ASSERT(column_names.size() >= 2 && column_names.size() <= 4); if (column_names.size() == 4) { return column_names[2]; @@ -63,7 +66,7 @@ const string &ColumnRefExpression::GetTableName() const { return column_names[0]; } -string ColumnRefExpression::GetName() const { +Identifier ColumnRefExpression::GetName() const { return !alias.empty() ? alias : column_names.back(); } @@ -78,30 +81,4 @@ string ColumnRefExpression::ToString() const { return result; } -bool ColumnRefExpression::Equal(const ColumnRefExpression &a, const ColumnRefExpression &b) { - if (a.column_names.size() != b.column_names.size()) { - return false; - } - for (idx_t i = 0; i < a.column_names.size(); i++) { - if (!StringUtil::CIEquals(a.column_names[i], b.column_names[i])) { - return false; - } - } - return true; -} - -hash_t ColumnRefExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - for (auto &column_name : column_names) { - result = CombineHash(result, StringUtil::CIHash(column_name)); - } - return result; -} - -unique_ptr ColumnRefExpression::Copy() const { - auto copy = make_uniq(column_names); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/comparison_expression.cpp b/src/duckdb/src/parser/expression/comparison_expression.cpp index 9f7b5cfc6..1e0a9e282 100644 --- a/src/duckdb/src/parser/expression/comparison_expression.cpp +++ b/src/duckdb/src/parser/expression/comparison_expression.cpp @@ -2,6 +2,9 @@ namespace duckdb { +ComparisonExpression::ComparisonExpression() : ParsedExpression(ExpressionType::INVALID, ExpressionClass::COMPARISON) { +} + ComparisonExpression::ComparisonExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::COMPARISON) { } @@ -11,23 +14,7 @@ ComparisonExpression::ComparisonExpression(ExpressionType type, unique_ptr(type, *left, *right); -} - -bool ComparisonExpression::Equal(const 
ComparisonExpression &a, const ComparisonExpression &b) { - if (!a.left->Equals(*b.left)) { - return false; - } - if (!a.right->Equals(*b.right)) { - return false; - } - return true; -} - -unique_ptr ComparisonExpression::Copy() const { - auto copy = make_uniq(type, left->Copy(), right->Copy()); - copy->CopyProperties(*this); - return std::move(copy); + return ToString(type, Left(), Right()); } } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/conjunction_expression.cpp b/src/duckdb/src/parser/expression/conjunction_expression.cpp index 0b689e01d..59ca519c3 100644 --- a/src/duckdb/src/parser/expression/conjunction_expression.cpp +++ b/src/duckdb/src/parser/expression/conjunction_expression.cpp @@ -4,6 +4,10 @@ namespace duckdb { +ConjunctionExpression::ConjunctionExpression() + : ParsedExpression(ExpressionType::INVALID, ExpressionClass::CONJUNCTION) { +} + ConjunctionExpression::ConjunctionExpression(ExpressionType type) : ParsedExpression(type, ExpressionClass::CONJUNCTION) { } @@ -38,20 +42,4 @@ string ConjunctionExpression::ToString() const { return ToString(*this); } -bool ConjunctionExpression::Equal(const ConjunctionExpression &a, const ConjunctionExpression &b) { - return ExpressionUtil::SetEquals(a.children, b.children); -} - -unique_ptr ConjunctionExpression::Copy() const { - vector> copy_children; - copy_children.reserve(children.size()); - for (auto &expr : children) { - copy_children.push_back(expr->Copy()); - } - - auto copy = make_uniq(type, std::move(copy_children)); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/constant_expression.cpp b/src/duckdb/src/parser/expression/constant_expression.cpp index 5e19bc46b..c486d9acd 100644 --- a/src/duckdb/src/parser/expression/constant_expression.cpp +++ b/src/duckdb/src/parser/expression/constant_expression.cpp @@ -16,18 +16,4 @@ string ConstantExpression::ToString() const { return value.ToSQLString(); } -bool 
ConstantExpression::Equal(const ConstantExpression &a, const ConstantExpression &b) { - return a.value.type() == b.value.type() && !ValueOperations::DistinctFrom(a.value, b.value); -} - -hash_t ConstantExpression::Hash() const { - return value.Hash(); -} - -unique_ptr ConstantExpression::Copy() const { - auto copy = make_uniq(value); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/default_expression.cpp b/src/duckdb/src/parser/expression/default_expression.cpp index 7fe7d8f91..31553b565 100644 --- a/src/duckdb/src/parser/expression/default_expression.cpp +++ b/src/duckdb/src/parser/expression/default_expression.cpp @@ -9,10 +9,4 @@ string DefaultExpression::ToString() const { return "DEFAULT"; } -unique_ptr DefaultExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/function_expression.cpp b/src/duckdb/src/parser/expression/function_expression.cpp index 8503fa657..fcf4cbc55 100644 --- a/src/duckdb/src/parser/expression/function_expression.cpp +++ b/src/duckdb/src/parser/expression/function_expression.cpp @@ -4,106 +4,206 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/types/hash.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { FunctionExpression::FunctionExpression() : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION) { } -FunctionExpression::FunctionExpression(string catalog, string schema, const string &function_name, +FunctionExpression::FunctionExpression(Identifier catalog, Identifier schema, const Identifier &function_name, vector> children_p, unique_ptr filter, unique_ptr order_bys_p, bool distinct, bool is_operator, bool export_state_p) : ParsedExpression(ExpressionType::FUNCTION, 
ExpressionClass::FUNCTION), catalog(std::move(catalog)), - schema(std::move(schema)), function_name(StringUtil::Lower(function_name)), is_operator(is_operator), - children(std::move(children_p)), distinct(distinct), filter(std::move(filter)), order_bys(std::move(order_bys_p)), + schema(std::move(schema)), function_name(StringUtil::Lower(function_name.GetIdentifierName())), + is_operator(is_operator), distinct(distinct), filter(std::move(filter)), order_bys(std::move(order_bys_p)), export_state(export_state_p) { + arguments.reserve(children_p.size()); + for (auto &child : children_p) { + arguments.emplace_back(std::move(child)); + } D_ASSERT(!function_name.empty()); if (!order_bys) { order_bys = make_uniq(); } } -FunctionExpression::FunctionExpression(const string &function_name, vector> children_p, +FunctionExpression::FunctionExpression(const Identifier &function_name, vector> children_p, unique_ptr filter, unique_ptr order_bys, bool distinct, bool is_operator, bool export_state_p) - : FunctionExpression(INVALID_CATALOG, INVALID_SCHEMA, function_name, std::move(children_p), std::move(filter), - std::move(order_bys), distinct, is_operator, export_state_p) { + : FunctionExpression(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), function_name, + std::move(children_p), std::move(filter), std::move(order_bys), distinct, is_operator, + export_state_p) { } -string FunctionExpression::ToString() const { - return ToString(*this, catalog, schema, function_name, is_operator, distinct, - filter.get(), order_bys.get(), export_state, true); +FunctionExpression::FunctionExpression(Identifier catalog_name, Identifier schema_name, const Identifier &function_name, + vector children, unique_ptr filter, + unique_ptr order_bys_p, bool distinct, bool is_operator, + bool export_state) + : ParsedExpression(ExpressionType::FUNCTION, ExpressionClass::FUNCTION), catalog(std::move(catalog_name)), + schema(std::move(schema_name)), 
function_name(StringUtil::Lower(function_name.GetIdentifierName())), + is_operator(is_operator), arguments(std::move(children)), distinct(distinct), filter(std::move(filter)), + order_bys(std::move(order_bys_p)), export_state(export_state) { + D_ASSERT(!function_name.empty()); + if (!order_bys) { + this->order_bys = make_uniq(); + } } -bool FunctionExpression::Equal(const FunctionExpression &a, const FunctionExpression &b) { - if (a.catalog != b.catalog || a.schema != b.schema || a.function_name != b.function_name || - b.distinct != a.distinct) { - return false; - } - if (b.children.size() != a.children.size()) { - return false; - } - for (idx_t i = 0; i < a.children.size(); i++) { - if (!a.children[i]->Equals(*b.children[i])) { - return false; +FunctionExpression::FunctionExpression(const Identifier &function_name, vector children, + unique_ptr filter, unique_ptr order_bys, + bool distinct, bool is_operator, bool export_state) + : FunctionExpression(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), function_name, std::move(children), + std::move(filter), std::move(order_bys), distinct, is_operator, export_state) { +} + +string FunctionExpression::ToString() const { + if (is_operator) { + // built-in operator + D_ASSERT(!distinct); + if (arguments.size() == 1) { + if (StringUtil::Contains(function_name.GetIdentifierName(), "__postfix")) { + return "((" + arguments[0].ToString() + ")" + + StringUtil::Replace(function_name.GetIdentifierName(), "__postfix", "") + ")"; + } + return function_name.GetIdentifierName() + "(" + arguments[0].ToString() + ")"; + } + if (arguments.size() == 2) { + return StringUtil::Format("(%s %s %s)", arguments[0].ToString(), function_name.GetIdentifierName(), + arguments[1].ToString()); } } - if (!ParsedExpression::Equals(a.filter, b.filter)) { - return false; + // standard function call + string result; + if (!catalog.empty()) { + result += SQLIdentifier(catalog) + "."; } - if (!OrderModifier::Equals(a.order_bys, b.order_bys)) { - 
return false; + if (!schema.empty()) { + result += SQLIdentifier(schema) + "."; } - if (a.export_state != b.export_state) { - return false; + result += SQLIdentifier(function_name); + result += "("; + if (distinct) { + result += "DISTINCT "; } - return true; -} + result += StringUtil::Join(arguments, arguments.size(), ", ", + [&](const FunctionArgument &child) { return child.ToString(); }); + // ordered aggregate + if (order_bys && !order_bys->orders.empty()) { + if (arguments.empty()) { + result += ") WITHIN GROUP ("; + } + result += " ORDER BY "; + for (idx_t i = 0; i < order_bys->orders.size(); i++) { + if (i > 0) { + result += ", "; + } + result += order_bys->orders[i].ToString(); + } + } + result += ")"; -hash_t FunctionExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - result = CombineHash(result, duckdb::Hash(schema.c_str())); - result = CombineHash(result, duckdb::Hash(function_name.c_str())); - result = CombineHash(result, duckdb::Hash(distinct)); - result = CombineHash(result, duckdb::Hash(export_state)); - return result; -} + // filtered aggregate + if (filter) { + result += " FILTER (WHERE " + filter->ToString() + ")"; + } -unique_ptr FunctionExpression::Copy() const { - vector> copy_children; - unique_ptr filter_copy; - copy_children.reserve(children.size()); - for (auto &child : children) { - copy_children.push_back(child->Copy()); + if (export_state) { + result += " EXPORT_STATE"; } - if (filter) { - filter_copy = filter->Copy(); - } - auto order_copy = order_bys ? 
unique_ptr_cast(order_bys->Copy()) : nullptr; - auto copy = - make_uniq(catalog, schema, function_name, std::move(copy_children), std::move(filter_copy), - std::move(order_copy), distinct, is_operator, export_state); - copy->CopyProperties(*this); - return std::move(copy); + + return result; } void FunctionExpression::Verify() const { D_ASSERT(!function_name.empty()); } -optional_ptr FunctionExpression::IsLambdaFunction() const { +optional_ptr FunctionExpression::IsLambdaFunction() { // Ignore the ->> operator (JSON extension). if (function_name == "->>") { return nullptr; } // Check the children for lambda expressions. - for (auto &child : children) { - if (child->GetExpressionClass() == ExpressionClass::LAMBDA) { - return *child; + for (auto &child : arguments) { + if (child.GetExpression().GetExpressionClass() == ExpressionClass::LAMBDA) { + return *child.GetExpressionMutable(); } } return nullptr; } +void FunctionExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "function_name", function_name); + serializer.WritePropertyWithDefault(201, "schema", schema); + + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + // Legacy serialization. 
+ vector> children; + for (auto &arg : arguments) { + auto copy = arg.GetExpression().Copy(); + copy->SetAlias(arg.GetName()); + children.push_back(std::move(copy)); + } + serializer.WritePropertyWithDefault>>(202, "children", children); + } + + serializer.WritePropertyWithDefault>(203, "filter", filter); + serializer.WritePropertyWithDefault>(204, "order_bys", order_bys); + serializer.WritePropertyWithDefault(205, "distinct", distinct); + serializer.WritePropertyWithDefault(206, "is_operator", is_operator); + serializer.WritePropertyWithDefault(207, "export_state", export_state); + serializer.WritePropertyWithDefault(208, "catalog", catalog); + + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault>(209, "arguments", arguments); + } +} + +unique_ptr FunctionExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new FunctionExpression()); + deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + + // Legacy children deserialization + vector> children; + deserializer.ReadPropertyWithDefault>>(202, "children", children); + if (!children.empty()) { + result->arguments.reserve(children.size()); + for (auto &child : children) { + auto alias = child->GetAlias(); + result->arguments.emplace_back(alias, std::move(child)); + } + + // Mark this function expression as a legacy function call, so that the binder can handle it accordingly. 
+ result->is_legacy_function_call = true; + } + + deserializer.ReadPropertyWithDefault>(203, "filter", result->filter); + auto order_bys = deserializer.ReadPropertyWithDefault>(204, "order_bys"); + result->order_bys = unique_ptr_cast(std::move(order_bys)); + deserializer.ReadPropertyWithDefault(205, "distinct", result->distinct); + deserializer.ReadPropertyWithDefault(206, "is_operator", result->is_operator); + deserializer.ReadPropertyWithDefault(207, "export_state", result->export_state); + deserializer.ReadPropertyWithDefault(208, "catalog", result->catalog); + + // New children deserialization + if (children.empty()) { + deserializer.ReadPropertyWithDefault>(209, "arguments", result->arguments); + } + + return std::move(result); +} + +hash_t FunctionArgument::Hash() const { + hash_t result = duckdb::Hash(name.c_str()); + if (expression) { + result = CombineHash(result, expression->Hash()); + } + return result; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/lambda_expression.cpp b/src/duckdb/src/parser/expression/lambda_expression.cpp index 7e0a812a2..e1d48556a 100644 --- a/src/duckdb/src/parser/expression/lambda_expression.cpp +++ b/src/duckdb/src/parser/expression/lambda_expression.cpp @@ -14,13 +14,13 @@ LambdaExpression::LambdaExpression(vector named_parameters_p, unique_ptr : ParsedExpression(ExpressionType::LAMBDA, ExpressionClass::LAMBDA), syntax_type(LambdaSyntaxType::LAMBDA_KEYWORD), expr(std::move(expr)) { if (named_parameters_p.size() == 1) { - lhs = make_uniq(named_parameters_p.back()); + lhs = make_uniq(Identifier(named_parameters_p.back())); return; } // Create a dummy row function and insert the children. 
vector> children; for (const auto &name : named_parameters_p) { - auto child = make_uniq(name); + auto child = make_uniq(Identifier(name)); children.push_back(std::move(child)); } lhs = make_uniq("row", std::move(children)); @@ -36,26 +36,26 @@ vector> LambdaExpression::ExtractColumnRefExpr // since we can't distinguish between a lambda function and the JSON operator yet vector> column_refs; - if (lhs->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + if (Left().GetExpressionClass() == ExpressionClass::COLUMN_REF) { // single column reference - column_refs.emplace_back(*lhs); + column_refs.emplace_back(Left()); return column_refs; } - if (lhs->GetExpressionClass() == ExpressionClass::FUNCTION) { + if (Left().GetExpressionClass() == ExpressionClass::FUNCTION) { // list of column references - auto &func_expr = lhs->Cast(); - if (func_expr.function_name != "row") { + auto &func_expr = Left().Cast(); + if (func_expr.FunctionName() != "row") { error_message = InvalidParametersErrorMessage(); return column_refs; } - for (auto &child : func_expr.children) { - if (child->GetExpressionClass() != ExpressionClass::COLUMN_REF) { + for (auto &child : func_expr.GetArguments()) { + if (child.GetExpression().GetExpressionClass() != ExpressionClass::COLUMN_REF) { error_message = InvalidParametersErrorMessage(); return column_refs; } - column_refs.emplace_back(*child); + column_refs.emplace_back(child.GetExpression()); } } @@ -70,8 +70,8 @@ string LambdaExpression::InvalidParametersErrorMessage() { return "Invalid lambda parameters! 
Parameters must be unqualified comma-separated names like x or (x, y)."; } -bool LambdaExpression::IsLambdaParameter(const vector> &lambda_params, - const string ¶meter_name) { +bool LambdaExpression::IsLambdaParameter(const vector &lambda_params, + const Identifier ¶meter_name) { for (const auto &level : lambda_params) { if (level.find(parameter_name) != level.end()) { return true; @@ -81,8 +81,8 @@ bool LambdaExpression::IsLambdaParameter(const vector> &la } string LambdaExpression::ToString() const { - if (syntax_type != LambdaSyntaxType::LAMBDA_KEYWORD) { - return "(" + lhs->ToString() + " -> " + expr->ToString() + ")"; + if (GetLambdaSyntaxType() != LambdaSyntaxType::LAMBDA_KEYWORD) { + return "(" + Left().ToString() + " -> " + Right().ToString() + ")"; } string str = ""; @@ -99,25 +99,7 @@ string LambdaExpression::ToString() const { } str += cast_column_ref.ToString() + ", "; } - return str + ": " + expr->ToString() + ")"; -} - -bool LambdaExpression::Equal(const LambdaExpression &a, const LambdaExpression &b) { - return a.lhs->Equals(*b.lhs) && a.expr->Equals(*b.expr); -} - -hash_t LambdaExpression::Hash() const { - hash_t result = lhs->Hash(); - ParsedExpression::Hash(); - result = CombineHash(result, expr->Hash()); - return result; -} - -unique_ptr LambdaExpression::Copy() const { - auto copy = make_uniq(lhs->Copy(), expr->Copy()); - copy->syntax_type = syntax_type; - copy->CopyProperties(*this); - return std::move(copy); + return str + ": " + Right().ToString() + ")"; } } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/lambdaref_expression.cpp b/src/duckdb/src/parser/expression/lambdaref_expression.cpp index f1e7e59bf..7c2655b7c 100644 --- a/src/duckdb/src/parser/expression/lambdaref_expression.cpp +++ b/src/duckdb/src/parser/expression/lambdaref_expression.cpp @@ -5,7 +5,7 @@ namespace duckdb { -LambdaRefExpression::LambdaRefExpression(idx_t lambda_idx, string column_name_p) +LambdaRefExpression::LambdaRefExpression(idx_t lambda_idx, 
Identifier column_name_p) : ParsedExpression(ExpressionType::LAMBDA_REF, ExpressionClass::LAMBDA_REF), lambda_idx(lambda_idx), column_name(std::move(column_name_p)) { alias = column_name; @@ -15,7 +15,7 @@ bool LambdaRefExpression::IsScalar() const { throw InternalException("lambda reference expressions are transient, IsScalar should never be called"); } -string LambdaRefExpression::GetName() const { +Identifier LambdaRefExpression::GetName() const { return column_name; } @@ -23,20 +23,9 @@ string LambdaRefExpression::ToString() const { throw InternalException("lambda reference expressions are transient, ToString should never be called"); } -hash_t LambdaRefExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - result = CombineHash(result, lambda_idx); - result = CombineHash(result, StringUtil::CIHash(column_name)); - return result; -} - -unique_ptr LambdaRefExpression::Copy() const { - throw InternalException("lambda reference expressions are transient, Copy should never be called"); -} - unique_ptr LambdaRefExpression::FindMatchingBinding(optional_ptr> &lambda_bindings, - const string &column_name) { + const Identifier &column_name) { // if this is a lambda parameter, then we temporarily add a BoundLambdaRef, // which we capture and remove later diff --git a/src/duckdb/src/parser/expression/operator_expression.cpp b/src/duckdb/src/parser/expression/operator_expression.cpp index 5b248bb1c..d2e597d65 100644 --- a/src/duckdb/src/parser/expression/operator_expression.cpp +++ b/src/duckdb/src/parser/expression/operator_expression.cpp @@ -4,6 +4,9 @@ namespace duckdb { +OperatorExpression::OperatorExpression() : ParsedExpression(ExpressionType::INVALID, ExpressionClass::OPERATOR) { +} + OperatorExpression::OperatorExpression(ExpressionType type, unique_ptr left, unique_ptr right) : ParsedExpression(type, ExpressionClass::OPERATOR) { @@ -23,25 +26,4 @@ string OperatorExpression::ToString() const { return ToString(*this); } -bool 
OperatorExpression::Equal(const OperatorExpression &a, const OperatorExpression &b) { - if (a.children.size() != b.children.size()) { - return false; - } - for (idx_t i = 0; i < a.children.size(); i++) { - if (!a.children[i]->Equals(*b.children[i])) { - return false; - } - } - return true; -} - -unique_ptr OperatorExpression::Copy() const { - auto copy = make_uniq(type); - copy->CopyProperties(*this); - for (auto &it : children) { - copy->children.push_back(it->Copy()); - } - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/parameter_expression.cpp b/src/duckdb/src/parser/expression/parameter_expression.cpp index 90149d195..baff0dcec 100644 --- a/src/duckdb/src/parser/expression/parameter_expression.cpp +++ b/src/duckdb/src/parser/expression/parameter_expression.cpp @@ -13,20 +13,4 @@ string ParameterExpression::ToString() const { return "$" + identifier; } -unique_ptr ParameterExpression::Copy() const { - auto copy = make_uniq(); - copy->identifier = identifier; - copy->CopyProperties(*this); - return std::move(copy); -} - -bool ParameterExpression::Equal(const ParameterExpression &a, const ParameterExpression &b) { - return StringUtil::CIEquals(a.identifier, b.identifier); -} - -hash_t ParameterExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - return CombineHash(duckdb::Hash(identifier.c_str(), identifier.size()), result); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/positional_reference_expression.cpp b/src/duckdb/src/parser/expression/positional_reference_expression.cpp index b1a4d54b5..fd8fd4573 100644 --- a/src/duckdb/src/parser/expression/positional_reference_expression.cpp +++ b/src/duckdb/src/parser/expression/positional_reference_expression.cpp @@ -18,20 +18,4 @@ string PositionalReferenceExpression::ToString() const { return "#" + to_string(index); } -bool PositionalReferenceExpression::Equal(const PositionalReferenceExpression &a, - const 
PositionalReferenceExpression &b) { - return a.index == b.index; -} - -unique_ptr PositionalReferenceExpression::Copy() const { - auto copy = make_uniq(index); - copy->CopyProperties(*this); - return std::move(copy); -} - -hash_t PositionalReferenceExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - return CombineHash(duckdb::Hash(index), result); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/star_expression.cpp b/src/duckdb/src/parser/expression/star_expression.cpp index 8fa3a7f50..75cf690c4 100644 --- a/src/duckdb/src/parser/expression/star_expression.cpp +++ b/src/duckdb/src/parser/expression/star_expression.cpp @@ -7,7 +7,7 @@ namespace duckdb { -StarExpression::StarExpression(string relation_name_p) +StarExpression::StarExpression(Identifier relation_name_p) : ParsedExpression(ExpressionType::STAR, ExpressionClass::STAR), relation_name(std::move(relation_name_p)) { } @@ -68,31 +68,6 @@ string StarExpression::ToString() const { return result; } -bool StarExpression::Equal(const StarExpression &a, const StarExpression &b) { - if (a.relation_name != b.relation_name || a.exclude_list != b.exclude_list || a.rename_list != b.rename_list) { - return false; - } - if (a.columns != b.columns) { - return false; - } - if (a.replace_list.size() != b.replace_list.size()) { - return false; - } - for (auto &entry : a.replace_list) { - auto other_entry = b.replace_list.find(entry.first); - if (other_entry == b.replace_list.end()) { - return false; - } - if (!entry.second->Equals(*other_entry->second)) { - return false; - } - } - if (!ParsedExpression::Equals(a.expr, b.expr)) { - return false; - } - return true; -} - bool StarExpression::IsStar(const ParsedExpression &a) { if (a.GetExpressionClass() != ExpressionClass::STAR) { return false; @@ -117,11 +92,11 @@ bool StarExpression::IsColumnsUnpacked(const ParsedExpression &a) { } unique_ptr -StarExpression::DeserializeStarExpression(string &&relation_name, const 
case_insensitive_set_t &exclude_list, - case_insensitive_map_t> &&replace_list, - bool columns, unique_ptr expr, bool unpacked, +StarExpression::DeserializeStarExpression(Identifier &&relation_name, const identifier_set_t &exclude_list, + identifier_map_t> &&replace_list, bool columns, + unique_ptr expr, bool unpacked, const qualified_column_set_t &qualified_exclude_list, - qualified_column_map_t &&rename_list) { + qualified_column_map_t &&rename_list) { auto result = duckdb::unique_ptr(new StarExpression(exclude_list, qualified_exclude_list)); result->relation_name = std::move(relation_name); result->replace_list = std::move(replace_list); @@ -138,29 +113,16 @@ StarExpression::DeserializeStarExpression(string &&relation_name, const case_ins return std::move(result); } -unique_ptr StarExpression::Copy() const { - auto copy = make_uniq(relation_name); - copy->exclude_list = exclude_list; - for (auto &entry : replace_list) { - copy->replace_list[entry.first] = entry.second->Copy(); - } - copy->rename_list = rename_list; - copy->columns = columns; - copy->expr = expr ? 
expr->Copy() : nullptr; - copy->CopyProperties(*this); - return std::move(copy); -} - -StarExpression::StarExpression(const case_insensitive_set_t &exclude_list_p, qualified_column_set_t qualified_set) +StarExpression::StarExpression(const identifier_set_t &exclude_list_p, qualified_column_set_t qualified_set) : ParsedExpression(ExpressionType::STAR, ExpressionClass::STAR), exclude_list(std::move(qualified_set)) { for (auto &entry : exclude_list_p) { - exclude_list.insert(QualifiedColumnName(entry)); + exclude_list.insert(QualifiedColumnName(Identifier(entry))); } } -case_insensitive_set_t StarExpression::SerializedExcludeList() const { +identifier_set_t StarExpression::SerializedExcludeList() const { // we serialize non-qualified elements in a separate list of only column names for backwards compatibility - case_insensitive_set_t result; + identifier_set_t result; for (auto &entry : exclude_list) { if (!entry.IsQualified()) { result.insert(entry.column); diff --git a/src/duckdb/src/parser/expression/subquery_expression.cpp b/src/duckdb/src/parser/expression/subquery_expression.cpp index 759675e62..0a9ae0314 100644 --- a/src/duckdb/src/parser/expression/subquery_expression.cpp +++ b/src/duckdb/src/parser/expression/subquery_expression.cpp @@ -25,25 +25,4 @@ string SubqueryExpression::ToString() const { } } -bool SubqueryExpression::Equal(const SubqueryExpression &a, const SubqueryExpression &b) { - if (!a.subquery || !b.subquery) { - return false; - } - if (!ParsedExpression::Equals(a.child, b.child)) { - return false; - } - return a.comparison_type == b.comparison_type && a.subquery_type == b.subquery_type && - a.subquery->Equals(*b.subquery); -} - -unique_ptr SubqueryExpression::Copy() const { - auto copy = make_uniq(); - copy->CopyProperties(*this); - copy->subquery = unique_ptr_cast(subquery->Copy()); - copy->subquery_type = subquery_type; - copy->child = child ? 
child->Copy() : nullptr; - copy->comparison_type = comparison_type; - return std::move(copy); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/type_expression.cpp b/src/duckdb/src/parser/expression/type_expression.cpp index 14174fb04..05aa8c9fc 100644 --- a/src/duckdb/src/parser/expression/type_expression.cpp +++ b/src/duckdb/src/parser/expression/type_expression.cpp @@ -7,15 +7,20 @@ #include "duckdb/common/types/hash.hpp" namespace duckdb { -TypeExpression::TypeExpression(string catalog, string schema, string type_name, +TypeExpression::TypeExpression(Identifier catalog, Identifier schema, Identifier type_name, vector> children_p) : ParsedExpression(ExpressionType::TYPE, ExpressionClass::TYPE), catalog(std::move(catalog)), schema(std::move(schema)), type_name(std::move(type_name)), children(std::move(children_p)) { D_ASSERT(!this->type_name.empty()); } -TypeExpression::TypeExpression(string type_name, vector> children) - : TypeExpression(INVALID_CATALOG, INVALID_SCHEMA, std::move(type_name), std::move(children)) { +TypeExpression::TypeExpression(Identifier type_name, vector> children) + : TypeExpression(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), std::move(type_name), + std::move(children)) { +} + +TypeExpression::TypeExpression(const string &type_name, vector> children) + : TypeExpression(Identifier(type_name), std::move(children)) { } TypeExpression::TypeExpression() : ParsedExpression(ExpressionType::TYPE, ExpressionClass::TYPE) { @@ -33,16 +38,16 @@ string TypeExpression::ToString() const { auto ¶ms = children; // LIST and ARRAY have special syntax - if (result.empty() && StringUtil::CIEquals(type_name, "LIST") && params.size() == 1) { + if (result.empty() && type_name == "LIST" && params.size() == 1) { return params[0]->ToString() + "[]"; } - if (result.empty() && StringUtil::CIEquals(type_name, "ARRAY") && params.size() == 2) { + if (result.empty() && type_name == "ARRAY" && params.size() == 2) { auto &type_param = 
params[0]; auto &size_param = params[1]; return type_param->ToString() + "[" + size_param->ToString() + "]"; } // So does STRUCT, MAP and UNION - if (result.empty() && StringUtil::CIEquals(type_name, "STRUCT")) { + if (result.empty() && type_name == "STRUCT") { if (params.empty()) { return "STRUCT"; } @@ -56,7 +61,7 @@ string TypeExpression::ToString() const { struct_result += ")"; return struct_result; } - if (result.empty() && StringUtil::CIEquals(type_name, "UNION")) { + if (result.empty() && type_name == "UNION") { if (params.empty()) { return "UNION"; } @@ -71,27 +76,27 @@ string TypeExpression::ToString() const { return union_result; } - if (result.empty() && StringUtil::CIEquals(type_name, "MAP") && params.size() == 2) { + if (result.empty() && type_name == "MAP" && params.size() == 2) { return "MAP(" + params[0]->ToString() + ", " + params[1]->ToString() + ")"; } - if (result.empty() && StringUtil::CIEquals(type_name, "VARCHAR") && !params.empty()) { - if (params.back()->HasAlias() && StringUtil::CIEquals(params.back()->GetAlias(), "collation")) { + if (result.empty() && type_name == "VARCHAR" && !params.empty()) { + if (params.back()->HasAlias() && params.back()->GetAlias() == "collation") { // Special case for VARCHAR with collation auto collate_expr = params.back()->Cast(); return StringUtil::Format("VARCHAR COLLATE %s", SQLIdentifier(StringValue::Get(collate_expr.GetValue()))); } } - if (result.empty() && StringUtil::CIEquals(type_name, "INTERVAL") && !params.empty()) { + if (result.empty() && type_name == "INTERVAL" && !params.empty()) { // We ignore interval types parameters. 
return "INTERVAL"; } - auto type_id = TransformStringToLogicalTypeId(type_name); + auto type_id = TransformStringToLogicalTypeId(type_name.GetIdentifierName()); if (type_id != LogicalTypeId::UNBOUND && type_id != LogicalTypeId::SQLNULL) { // Built-in type name - result += type_name; + result += type_name.GetIdentifierName(); } else { result += SQLIdentifier(type_name); } @@ -109,36 +114,8 @@ string TypeExpression::ToString() const { return result; } -unique_ptr TypeExpression::Copy() const { - vector> copy_children; - copy_children.reserve(children.size()); - for (const auto &child : children) { - copy_children.push_back(child->Copy()); - } - - auto copy = make_uniq(catalog, schema, type_name, std::move(copy_children)); - copy->CopyProperties(*this); - - return std::move(copy); -} - -bool TypeExpression::Equal(const TypeExpression &a, const TypeExpression &b) { - if (a.catalog != b.catalog || a.schema != b.schema || a.type_name != b.type_name) { - return false; - } - return ParsedExpression::ListEquals(a.children, b.children); -} - void TypeExpression::Verify() const { D_ASSERT(!type_name.empty()); } -hash_t TypeExpression::Hash() const { - hash_t result = ParsedExpression::Hash(); - result = CombineHash(result, duckdb::Hash(catalog.c_str())); - result = CombineHash(result, duckdb::Hash(schema.c_str())); - result = CombineHash(result, duckdb::Hash(type_name.c_str())); - return result; -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/expression/window_expression.cpp b/src/duckdb/src/parser/expression/window_expression.cpp index f9eb6f559..a8de313b4 100644 --- a/src/duckdb/src/parser/expression/window_expression.cpp +++ b/src/duckdb/src/parser/expression/window_expression.cpp @@ -1,45 +1,39 @@ #include "duckdb/parser/expression/window_expression.hpp" - #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { -WindowExpression::WindowExpression(ExpressionType type, vector> children_p, - 
unique_ptr offset_expr, unique_ptr default_expr) - : ParsedExpression(type, ExpressionClass::WINDOW), children(std::move(children_p)) { - if (offset_expr) { - children.emplace_back(std::move(offset_expr)); - } - if (default_expr) { - children.emplace_back(std::move(default_expr)); - } +WindowExpression::WindowExpression() : ParsedExpression(ExpressionType::INVALID, ExpressionClass::WINDOW) { } vector> WindowExpression::SerializedChildren(Serializer &serializer) const { vector> result; - idx_t nargs = children.size(); - if (!serializer.ShouldSerialize(8) && (function_name == "lead" || function_name == "lag")) { + idx_t nargs = arguments.size(); + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && (function_name == "lead" || function_name == "lag")) { nargs = 1; } for (idx_t i = 0; i < nargs; ++i) { - result.emplace_back(children[i]->Copy()); + result.emplace_back(arguments[i].GetExpression().Copy()); } return result; } unique_ptr WindowExpression::SerializedOffset(Serializer &serializer) const { - if (!serializer.ShouldSerialize(8) && children.size() > 1 && (function_name == "lead" || function_name == "lag")) { - return children[1]->Copy(); + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && arguments.size() > 1 && + (function_name == "lead" || function_name == "lag")) { + return arguments[1].GetExpression().Copy(); } return nullptr; } unique_ptr WindowExpression::SerializedDefault(Serializer &serializer) const { - if (!serializer.ShouldSerialize(8) && children.size() > 2 && (function_name == "lead" || function_name == "lag")) { - return children[2]->Copy(); + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && arguments.size() > 2 && + (function_name == "lead" || function_name == "lag")) { + return arguments[2].GetExpression().Copy(); } return nullptr; @@ -93,85 +87,18 @@ string WindowExpression::ExpressionTypeToWindow(ExpressionType expression_type) } void WindowExpression::SetFunctionName(const string &function_name_p) { - function_name = 
function_name_p; - type = WindowToExpressionType(function_name); + function_name = Identifier(function_name_p); + type = WindowToExpressionType(function_name.GetIdentifierName()); } string WindowExpression::ToString() const { - return ToString(*this, schema, function_name); + return ToString(*this, schema.GetIdentifierName(), + function_name.GetIdentifierName()); } -bool WindowExpression::Equal(const WindowExpression &a, const WindowExpression &b) { - // check if the child expressions are equivalent - if (a.has_ignore_nulls != b.has_ignore_nulls) { - return false; - } - if (a.has_ignore_nulls && a.ignore_nulls != b.ignore_nulls) { - return false; - } - if (a.distinct != b.distinct) { - return false; - } - if (!ParsedExpression::ListEquals(a.children, b.children)) { - return false; - } - if (a.start != b.start || a.end != b.end) { - return false; - } - if (a.exclude_clause != b.exclude_clause) { - return false; - } - // check if the framing expressions are equivalent - if (!ParsedExpression::Equals(a.start_expr, b.start_expr) || !ParsedExpression::Equals(a.end_expr, b.end_expr)) { - return false; - } - - // check if the argument orderings are equivalent - if (a.arg_orders.size() != b.arg_orders.size()) { - return false; - } - for (idx_t i = 0; i < a.arg_orders.size(); i++) { - if (a.arg_orders[i].type != b.arg_orders[i].type) { - return false; - } - if (a.arg_orders[i].null_order != b.arg_orders[i].null_order) { - return false; - } - if (!a.arg_orders[i].expression->Equals(*b.arg_orders[i].expression)) { - return false; - } - } - - // check if the partitions are equivalent - if (!ParsedExpression::ListEquals(a.partitions, b.partitions)) { - return false; - } - // check if the orderings are equivalent - if (a.orders.size() != b.orders.size()) { - return false; - } - for (idx_t i = 0; i < a.orders.size(); i++) { - if (a.orders[i].type != b.orders[i].type) { - return false; - } - if (a.orders[i].null_order != b.orders[i].null_order) { - return false; - } - if 
(!a.orders[i].expression->Equals(*b.orders[i].expression)) { - return false; - } - } - // check if the filter clauses are equivalent - if (!ParsedExpression::Equals(a.filter_expr, b.filter_expr)) { - return false; - } - - return true; -} - -bool WindowExpression::HasBoundedParts() { - for (auto &child : children) { - if ((*child).GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { +bool WindowExpression::HasBoundedParts() const { + for (auto &child : arguments) { + if (child.GetExpression().GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION) { return true; } } @@ -195,38 +122,88 @@ bool WindowExpression::HasBoundedParts() { return false; } -unique_ptr WindowExpression::Copy() const { - auto new_window = make_uniq(catalog, schema, function_name); - new_window->CopyProperties(*this); - - for (auto &child : children) { - new_window->children.push_back(child->Copy()); - } - - for (auto &e : partitions) { - new_window->partitions.push_back(e->Copy()); - } - - for (auto &o : orders) { - new_window->orders.emplace_back(o.type, o.null_order, o.expression->Copy()); - } - - for (auto &o : arg_orders) { - new_window->arg_orders.emplace_back(o.type, o.null_order, o.expression->Copy()); +void WindowExpression::Serialize(Serializer &serializer) const { + ParsedExpression::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "function_name", function_name); + serializer.WritePropertyWithDefault(201, "schema", schema); + serializer.WritePropertyWithDefault(202, "catalog", catalog); + + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + // Legacy serialization. 
+ vector> children; + for (auto &arg : arguments) { + auto copy = arg.GetExpression().Copy(); + copy->SetAlias(arg.GetName()); + children.push_back(std::move(copy)); + } + serializer.WritePropertyWithDefault>>(203, "children", children); + } + + serializer.WritePropertyWithDefault>>(204, "partitions", partitions); + serializer.WritePropertyWithDefault>(205, "orders", orders); + serializer.WriteProperty(206, "start", start); + serializer.WriteProperty(207, "end", end); + serializer.WritePropertyWithDefault>(208, "start_expr", start_expr); + serializer.WritePropertyWithDefault>(209, "end_expr", end_expr); + serializer.WritePropertyWithDefault>(210, "offset_expr", SerializedOffset(serializer)); + serializer.WritePropertyWithDefault>(211, "default_expr", + SerializedDefault(serializer)); + serializer.WritePropertyWithDefault(212, "ignore_nulls", ignore_nulls); + serializer.WritePropertyWithDefault>(213, "filter_expr", filter_expr); + serializer.WritePropertyWithDefault(214, "exclude_clause", exclude_clause, + WindowExcludeMode::NO_OTHER); + serializer.WritePropertyWithDefault(215, "distinct", distinct); + serializer.WritePropertyWithDefault>(216, "arg_orders", arg_orders); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault(217, "has_ignore_nulls", has_ignore_nulls); + } + + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault>(218, "arguments", arguments); } +} - new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; - - new_window->start = start; - new_window->end = end; - new_window->exclude_clause = exclude_clause; - new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; - new_window->end_expr = end_expr ? 
end_expr->Copy() : nullptr; - new_window->has_ignore_nulls = has_ignore_nulls; - new_window->ignore_nulls = ignore_nulls; - new_window->distinct = distinct; - - return std::move(new_window); +unique_ptr WindowExpression::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new WindowExpression()); + deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault(202, "catalog", result->catalog); + + // Legacy children deserialization + vector> children; + deserializer.ReadPropertyWithDefault>>(203, "children", children); + if (!children.empty()) { + result->arguments.reserve(children.size()); + for (auto &child : children) { + auto alias = child->GetAlias(); + result->arguments.emplace_back(std::move(alias), std::move(child)); + } + // Mark this function expression as a legacy function call, so that the binder can handle it accordingly. 
+ result->is_legacy_function_call = true; + } + + deserializer.ReadPropertyWithDefault>>(204, "partitions", result->partitions); + deserializer.ReadPropertyWithDefault>(205, "orders", result->orders); + deserializer.ReadProperty(206, "start", result->start); + deserializer.ReadProperty(207, "end", result->end); + deserializer.ReadPropertyWithDefault>(208, "start_expr", result->start_expr); + deserializer.ReadPropertyWithDefault>(209, "end_expr", result->end_expr); + deserializer.ReadDeletedProperty>(210, "offset_expr"); + deserializer.ReadDeletedProperty>(211, "default_expr"); + deserializer.ReadPropertyWithDefault(212, "ignore_nulls", result->ignore_nulls); + deserializer.ReadPropertyWithDefault>(213, "filter_expr", result->filter_expr); + deserializer.ReadPropertyWithExplicitDefault(214, "exclude_clause", result->exclude_clause, + WindowExcludeMode::NO_OTHER); + deserializer.ReadPropertyWithDefault(215, "distinct", result->distinct); + deserializer.ReadPropertyWithDefault>(216, "arg_orders", result->arg_orders); + deserializer.ReadPropertyWithDefault(217, "has_ignore_nulls", result->has_ignore_nulls); + + // New children deserialization + if (children.empty()) { + deserializer.ReadPropertyWithDefault>(218, "arguments", result->arguments); + } + + return std::move(result); } } // namespace duckdb diff --git a/src/duckdb/src/parser/keyword_helper.cpp b/src/duckdb/src/parser/keyword_helper.cpp index 93b0f1e18..9334f4c9e 100644 --- a/src/duckdb/src/parser/keyword_helper.cpp +++ b/src/duckdb/src/parser/keyword_helper.cpp @@ -1,6 +1,7 @@ #include "duckdb/parser/keyword_helper.hpp" #include "duckdb/parser/peg/keyword_helper.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/identifier.hpp" namespace duckdb { @@ -82,6 +83,9 @@ string KeywordHelper::WriteOptionallyQuoted(const string &text, char quote, bool return WriteQuotedAndEscaped(text, quote); } +SQLIdentifier::SQLIdentifier(const Identifier &id) : raw_string(id.GetIdentifierName()) { +} + string 
SQLIdentifier::ToString(const string &identifier) { if (!KeywordHelper::RequiresQuotes(identifier)) { return identifier; diff --git a/src/duckdb/src/parser/parsed_data/alter_database_info.cpp b/src/duckdb/src/parser/parsed_data/alter_database_info.cpp index 7c7ab75ce..51e56f90c 100644 --- a/src/duckdb/src/parser/parsed_data/alter_database_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_database_info.cpp @@ -4,13 +4,13 @@ namespace duckdb { AlterDatabaseInfo::AlterDatabaseInfo(AlterDatabaseType alter_database_type) - : AlterInfo(AlterType::ALTER_DATABASE, string(), "", "", OnEntryNotFound::THROW_EXCEPTION), + : AlterInfo(AlterType::ALTER_DATABASE, Identifier(), Identifier(), Identifier(), OnEntryNotFound::THROW_EXCEPTION), alter_database_type(alter_database_type) { } -AlterDatabaseInfo::AlterDatabaseInfo(AlterDatabaseType alter_database_type, string catalog_p, +AlterDatabaseInfo::AlterDatabaseInfo(AlterDatabaseType alter_database_type, Identifier catalog_p, OnEntryNotFound if_not_found) - : AlterInfo(AlterType::ALTER_DATABASE, std::move(catalog_p), "", "", if_not_found), + : AlterInfo(AlterType::ALTER_DATABASE, std::move(catalog_p), Identifier(), Identifier(), if_not_found), alter_database_type(alter_database_type) { } @@ -24,7 +24,7 @@ CatalogType AlterDatabaseInfo::GetCatalogType() const { RenameDatabaseInfo::RenameDatabaseInfo() : AlterDatabaseInfo(AlterDatabaseType::RENAME_DATABASE) { } -RenameDatabaseInfo::RenameDatabaseInfo(string catalog_p, string new_name_p, OnEntryNotFound if_not_found) +RenameDatabaseInfo::RenameDatabaseInfo(Identifier catalog_p, Identifier new_name_p, OnEntryNotFound if_not_found) : AlterDatabaseInfo(AlterDatabaseType::RENAME_DATABASE, std::move(catalog_p), if_not_found), new_name(std::move(new_name_p)) { } @@ -39,7 +39,7 @@ string RenameDatabaseInfo::ToString() const { if (if_not_found == OnEntryNotFound::RETURN_NULL) { result += "IF EXISTS "; } - result += StringUtil::Format("%s SET ALIAS TO %s", SQLIdentifier(catalog), 
SQLIdentifier(new_name)); + result += StringUtil::Format("%s SET ALIAS TO %s", catalog, new_name); return result; } diff --git a/src/duckdb/src/parser/parsed_data/alter_info.cpp b/src/duckdb/src/parser/parsed_data/alter_info.cpp index 671f27663..3aead3374 100644 --- a/src/duckdb/src/parser/parsed_data/alter_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_info.cpp @@ -5,7 +5,8 @@ namespace duckdb { -AlterInfo::AlterInfo(AlterType type, string catalog_p, string schema_p, string name_p, OnEntryNotFound if_not_found) +AlterInfo::AlterInfo(AlterType type, Identifier catalog_p, Identifier schema_p, Identifier name_p, + OnEntryNotFound if_not_found) : ParseInfo(TYPE), type(type), if_not_found(if_not_found), catalog(std::move(catalog_p)), schema(std::move(schema_p)), name(std::move(name_p)), allow_internal(false) { } diff --git a/src/duckdb/src/parser/parsed_data/alter_table_info.cpp b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp index 16931df19..0f92ff3a3 100644 --- a/src/duckdb/src/parser/parsed_data/alter_table_info.cpp +++ b/src/duckdb/src/parser/parsed_data/alter_table_info.cpp @@ -8,9 +8,9 @@ namespace duckdb { //===--------------------------------------------------------------------===// // ChangeOwnershipInfo //===--------------------------------------------------------------------===// -ChangeOwnershipInfo::ChangeOwnershipInfo(CatalogType entry_catalog_type, string entry_catalog_p, string entry_schema_p, - string entry_name_p, string owner_schema_p, string owner_name_p, - OnEntryNotFound if_not_found) +ChangeOwnershipInfo::ChangeOwnershipInfo(CatalogType entry_catalog_type, Identifier entry_catalog_p, + Identifier entry_schema_p, Identifier entry_name_p, Identifier owner_schema_p, + Identifier owner_name_p, OnEntryNotFound if_not_found) : AlterInfo(AlterType::CHANGE_OWNERSHIP, std::move(entry_catalog_p), std::move(entry_schema_p), std::move(entry_name_p), if_not_found), entry_catalog_type(entry_catalog_type), 
owner_schema(std::move(owner_schema_p)), @@ -48,8 +48,8 @@ string ChangeOwnershipInfo::ToString() const { //===--------------------------------------------------------------------===// // SetCommentInfo //===--------------------------------------------------------------------===// -SetCommentInfo::SetCommentInfo(CatalogType entry_catalog_type, string entry_catalog_p, string entry_schema_p, - string entry_name_p, Value new_comment_value_p, OnEntryNotFound if_not_found) +SetCommentInfo::SetCommentInfo(CatalogType entry_catalog_type, Identifier entry_catalog_p, Identifier entry_schema_p, + Identifier entry_name_p, Value new_comment_value_p, OnEntryNotFound if_not_found) : AlterInfo(AlterType::SET_COMMENT, std::move(entry_catalog_p), std::move(entry_schema_p), std::move(entry_name_p), if_not_found), entry_catalog_type(entry_catalog_type), comment_value(std::move(new_comment_value_p)) { @@ -101,7 +101,7 @@ CatalogType AlterTableInfo::GetCatalogType() const { //===--------------------------------------------------------------------===// // RenameColumnInfo //===--------------------------------------------------------------------===// -RenameColumnInfo::RenameColumnInfo(AlterEntryData data, string old_name_p, string new_name_p) +RenameColumnInfo::RenameColumnInfo(AlterEntryData data, Identifier old_name_p, Identifier new_name_p) : AlterTableInfo(AlterTableType::RENAME_COLUMN, std::move(data)), old_name(std::move(old_name_p)), new_name(std::move(new_name_p)) { } @@ -134,7 +134,7 @@ string RenameColumnInfo::ToString() const { //===--------------------------------------------------------------------===// // RenameFieldInfo //===--------------------------------------------------------------------===// -RenameFieldInfo::RenameFieldInfo(AlterEntryData data, vector column_path_p, string new_name_p) +RenameFieldInfo::RenameFieldInfo(AlterEntryData data, vector column_path_p, Identifier new_name_p) : AlterTableInfo(AlterTableType::RENAME_FIELD, std::move(data)), 
column_path(std::move(column_path_p)), new_name(std::move(new_name_p)) { } @@ -175,7 +175,7 @@ string RenameFieldInfo::ToString() const { RenameTableInfo::RenameTableInfo() : AlterTableInfo(AlterTableType::RENAME_TABLE) { } -RenameTableInfo::RenameTableInfo(AlterEntryData data, string new_name_p) +RenameTableInfo::RenameTableInfo(AlterEntryData data, Identifier new_name_p) : AlterTableInfo(AlterTableType::RENAME_TABLE, std::move(data)), new_table_name(std::move(new_name_p)) { } @@ -246,7 +246,7 @@ AddFieldInfo::AddFieldInfo(ColumnDefinition new_field_p) : AlterTableInfo(AlterTableType::ADD_FIELD), new_field(std::move(new_field_p)) { } -AddFieldInfo::AddFieldInfo(AlterEntryData data, vector column_path_p, ColumnDefinition new_field_p, +AddFieldInfo::AddFieldInfo(AlterEntryData data, vector column_path_p, ColumnDefinition new_field_p, bool if_field_not_exists) : AlterTableInfo(AlterTableType::ADD_FIELD, std::move(data)), column_path(std::move(column_path_p)), new_field(std::move(new_field_p)), if_field_not_exists(if_field_not_exists) { @@ -293,7 +293,8 @@ RemoveColumnInfo::~RemoveColumnInfo() { } unique_ptr RemoveColumnInfo::Copy() const { - return make_uniq_base(GetAlterEntryData(), removed_column, if_column_exists, cascade); + return make_uniq_base(GetAlterEntryData(), removed_column.GetIdentifierName(), + if_column_exists, cascade); } string RemoveColumnInfo::ToString() const { @@ -321,7 +322,8 @@ string RemoveColumnInfo::ToString() const { RemoveFieldInfo::RemoveFieldInfo() : AlterTableInfo(AlterTableType::REMOVE_FIELD) { } -RemoveFieldInfo::RemoveFieldInfo(AlterEntryData data, vector column_path_p, bool if_column_exists, bool cascade) +RemoveFieldInfo::RemoveFieldInfo(AlterEntryData data, vector column_path_p, bool if_column_exists, + bool cascade) : AlterTableInfo(AlterTableType::REMOVE_FIELD, std::move(data)), column_path(std::move(column_path_p)), if_column_exists(if_column_exists), cascade(cascade) { } @@ -362,7 +364,7 @@ string RemoveFieldInfo::ToString() 
const { ChangeColumnTypeInfo::ChangeColumnTypeInfo() : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE) { } -ChangeColumnTypeInfo::ChangeColumnTypeInfo(AlterEntryData data, string column_name, LogicalType target_type, +ChangeColumnTypeInfo::ChangeColumnTypeInfo(AlterEntryData data, Identifier column_name, LogicalType target_type, unique_ptr expression) : AlterTableInfo(AlterTableType::ALTER_COLUMN_TYPE, std::move(data)), column_name(std::move(column_name)), target_type(std::move(target_type)), expression(std::move(expression)) { @@ -409,7 +411,7 @@ string ChangeColumnTypeInfo::ToString() const { SetDefaultInfo::SetDefaultInfo() : AlterTableInfo(AlterTableType::SET_DEFAULT) { } -SetDefaultInfo::SetDefaultInfo(AlterEntryData data, string column_name_p, unique_ptr new_default) +SetDefaultInfo::SetDefaultInfo(AlterEntryData data, Identifier column_name_p, unique_ptr new_default) : AlterTableInfo(AlterTableType::SET_DEFAULT, std::move(data)), column_name(std::move(column_name_p)), expression(std::move(new_default)) { } @@ -446,7 +448,7 @@ string SetDefaultInfo::ToString() const { SetNotNullInfo::SetNotNullInfo() : AlterTableInfo(AlterTableType::SET_NOT_NULL) { } -SetNotNullInfo::SetNotNullInfo(AlterEntryData data, string column_name_p) +SetNotNullInfo::SetNotNullInfo(AlterEntryData data, Identifier column_name_p) : AlterTableInfo(AlterTableType::SET_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { } SetNotNullInfo::~SetNotNullInfo() { @@ -476,7 +478,7 @@ string SetNotNullInfo::ToString() const { DropNotNullInfo::DropNotNullInfo() : AlterTableInfo(AlterTableType::DROP_NOT_NULL) { } -DropNotNullInfo::DropNotNullInfo(AlterEntryData data, string column_name_p) +DropNotNullInfo::DropNotNullInfo(AlterEntryData data, Identifier column_name_p) : AlterTableInfo(AlterTableType::DROP_NOT_NULL, std::move(data)), column_name(std::move(column_name_p)) { } DropNotNullInfo::~DropNotNullInfo() { @@ -506,8 +508,8 @@ string DropNotNullInfo::ToString() const { 
AlterForeignKeyInfo::AlterForeignKeyInfo() : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT) { } -AlterForeignKeyInfo::AlterForeignKeyInfo(AlterEntryData data, string fk_table, vector pk_columns, - vector fk_columns, vector pk_keys, +AlterForeignKeyInfo::AlterForeignKeyInfo(AlterEntryData data, Identifier fk_table, vector pk_columns, + vector fk_columns, vector pk_keys, vector fk_keys, AlterForeignKeyType type_p) : AlterTableInfo(AlterTableType::FOREIGN_KEY_CONSTRAINT, std::move(data)), fk_table(std::move(fk_table)), pk_columns(std::move(pk_columns)), fk_columns(std::move(fk_columns)), pk_keys(std::move(pk_keys)), @@ -548,7 +550,7 @@ CatalogType AlterViewInfo::GetCatalogType() const { //===--------------------------------------------------------------------===// RenameViewInfo::RenameViewInfo() : AlterViewInfo(AlterViewType::RENAME_VIEW) { } -RenameViewInfo::RenameViewInfo(AlterEntryData data, string new_name_p) +RenameViewInfo::RenameViewInfo(AlterEntryData data, Identifier new_name_p) : AlterViewInfo(AlterViewType::RENAME_VIEW, std::move(data)), new_view_name(std::move(new_name_p)) { } RenameViewInfo::~RenameViewInfo() { @@ -719,7 +721,7 @@ string SetTableOptionsInfo::ToString() const { ResetTableOptionsInfo::ResetTableOptionsInfo() : AlterTableInfo(AlterTableType::RESET_TABLE_OPTIONS) { } -ResetTableOptionsInfo::ResetTableOptionsInfo(AlterEntryData data, case_insensitive_set_t table_options) +ResetTableOptionsInfo::ResetTableOptionsInfo(AlterEntryData data, identifier_set_t table_options) : AlterTableInfo(AlterTableType::RESET_TABLE_OPTIONS, std::move(data)), table_options(std::move(table_options)) { } @@ -727,7 +729,7 @@ ResetTableOptionsInfo::~ResetTableOptionsInfo() { } unique_ptr ResetTableOptionsInfo::Copy() const { - case_insensitive_set_t table_options_copy; + identifier_set_t table_options_copy; for (auto &option : table_options) { table_options_copy.emplace(option); } @@ -743,7 +745,7 @@ string ResetTableOptionsInfo::ToString() const { if (i > 0) 
{ result += ", "; } - result += SQLString(entry); + result += SQLString(entry.GetIdentifierName()); i++; } result += ")"; diff --git a/src/duckdb/src/parser/parsed_data/comment_on_column_info.cpp b/src/duckdb/src/parser/parsed_data/comment_on_column_info.cpp index 50828afb2..6aaf4ad14 100644 --- a/src/duckdb/src/parser/parsed_data/comment_on_column_info.cpp +++ b/src/duckdb/src/parser/parsed_data/comment_on_column_info.cpp @@ -5,12 +5,13 @@ namespace duckdb { SetColumnCommentInfo::SetColumnCommentInfo() - : AlterInfo(AlterType::SET_COLUMN_COMMENT, INVALID_CATALOG, INVALID_SCHEMA, "", OnEntryNotFound::THROW_EXCEPTION), + : AlterInfo(AlterType::SET_COLUMN_COMMENT, Identifier::InvalidCatalog(), Identifier::InvalidSchema(), "", + OnEntryNotFound::THROW_EXCEPTION), catalog_entry_type(CatalogType::INVALID), column_name(""), comment_value(Value()) { } -SetColumnCommentInfo::SetColumnCommentInfo(string catalog, string schema, string name, string column_name, - Value comment_value, OnEntryNotFound if_not_found) +SetColumnCommentInfo::SetColumnCommentInfo(Identifier catalog, Identifier schema, Identifier name, + Identifier column_name, Value comment_value, OnEntryNotFound if_not_found) : AlterInfo(AlterType::SET_COLUMN_COMMENT, std::move(catalog), std::move(schema), std::move(name), if_not_found), catalog_entry_type(CatalogType::INVALID), column_name(std::move(column_name)), comment_value(std::move(comment_value)) { diff --git a/src/duckdb/src/parser/parsed_data/connect_info.cpp b/src/duckdb/src/parser/parsed_data/connect_info.cpp new file mode 100644 index 000000000..a376de802 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/connect_info.cpp @@ -0,0 +1,30 @@ +#include "duckdb/parser/parsed_data/connect_info.hpp" + +namespace duckdb { + +unique_ptr ConnectInfo::Copy() const { + auto result = make_uniq(); + result->name = name; + result->target_is_local = target_is_local; + result->name_is_string_literal = name_is_string_literal; + return result; +} + +string 
ConnectInfo::ToString() const { + if (target_is_local) { + return "CONNECT LOCAL;"; + } + if (name.empty()) { + return "CONNECT;"; + } + string result = "CONNECT "; + if (name_is_string_literal) { + result += SQLString(name.GetIdentifierName()); + } else { + result += SQLIdentifier(name); + } + result += ";"; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parsed_data/copy_info.cpp b/src/duckdb/src/parser/parsed_data/copy_info.cpp index 3ed6fdbd2..f3228ab4f 100644 --- a/src/duckdb/src/parser/parsed_data/copy_info.cpp +++ b/src/duckdb/src/parser/parsed_data/copy_info.cpp @@ -84,7 +84,7 @@ string CopyInfo::TablePartToString() const { if (!select_list.empty()) { vector options; for (auto &option : select_list) { - options.push_back(SQLIdentifier::ToString(option)); + options.push_back(SQLIdentifier::ToString(option.GetIdentifierName())); } result += " ("; result += StringUtil::Join(options, ", "); diff --git a/src/duckdb/src/parser/parsed_data/create_collation_info.cpp b/src/duckdb/src/parser/parsed_data/create_collation_info.cpp index 9ae2e6102..18ce5e81b 100644 --- a/src/duckdb/src/parser/parsed_data/create_collation_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_collation_info.cpp @@ -2,7 +2,7 @@ namespace duckdb { -CreateCollationInfo::CreateCollationInfo(string name_p, ScalarFunction function_p, bool combinable_p, +CreateCollationInfo::CreateCollationInfo(Identifier name_p, ScalarFunction function_p, bool combinable_p, bool not_required_for_equality_p) : CreateInfo(CatalogType::COLLATION_ENTRY), function(std::move(function_p)), combinable(combinable_p), not_required_for_equality(not_required_for_equality_p) { diff --git a/src/duckdb/src/parser/parsed_data/create_coordinate_system_info.cpp b/src/duckdb/src/parser/parsed_data/create_coordinate_system_info.cpp index be48d88dd..9e50ba435 100644 --- a/src/duckdb/src/parser/parsed_data/create_coordinate_system_info.cpp +++ 
b/src/duckdb/src/parser/parsed_data/create_coordinate_system_info.cpp @@ -2,8 +2,8 @@ namespace duckdb { -CreateCoordinateSystemInfo::CreateCoordinateSystemInfo(string name_p, string authority, string code, string projjson, - string wkt2_2019) +CreateCoordinateSystemInfo::CreateCoordinateSystemInfo(Identifier name_p, string authority, string code, + string projjson, string wkt2_2019) : CreateInfo(CatalogType::COORDINATE_SYSTEM_ENTRY), name(std::move(name_p)), authority(std::move(authority)), code(std::move(code)), projjson_definition(std::move(projjson)), wkt2_2019_definition(std::move(wkt2_2019)) { internal = true; diff --git a/src/duckdb/src/parser/parsed_data/create_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_function_info.cpp index b5ec484d2..add9d76b4 100644 --- a/src/duckdb/src/parser/parsed_data/create_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_function_info.cpp @@ -2,7 +2,7 @@ namespace duckdb { -CreateFunctionInfo::CreateFunctionInfo(CatalogType type, string schema) : CreateInfo(type, std::move(schema)) { +CreateFunctionInfo::CreateFunctionInfo(CatalogType type, Identifier schema) : CreateInfo(type, std::move(schema)) { D_ASSERT(type == CatalogType::SCALAR_FUNCTION_ENTRY || type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY || type == CatalogType::MACRO_ENTRY || type == CatalogType::TABLE_MACRO_ENTRY || diff --git a/src/duckdb/src/parser/parsed_data/create_index_info.cpp b/src/duckdb/src/parser/parsed_data/create_index_info.cpp index f406c2103..99c2b82da 100644 --- a/src/duckdb/src/parser/parsed_data/create_index_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_index_info.cpp @@ -5,7 +5,7 @@ namespace duckdb { -CreateIndexInfo::CreateIndexInfo() : CreateInfo(CatalogType::INDEX_ENTRY, INVALID_SCHEMA) { +CreateIndexInfo::CreateIndexInfo() : CreateInfo(CatalogType::INDEX_ENTRY, Identifier::InvalidSchema()) { } 
CreateIndexInfo::CreateIndexInfo(const duckdb::CreateIndexInfo &info) @@ -14,10 +14,10 @@ CreateIndexInfo::CreateIndexInfo(const duckdb::CreateIndexInfo &info) column_ids(info.column_ids), scan_types(info.scan_types), names(info.names) { } -static void RemoveTableQualificationRecursive(unique_ptr &root_expr, const string &table_name) { +static void RemoveTableQualificationRecursive(unique_ptr &root_expr, const Identifier &table_name) { ParsedExpressionIterator::VisitExpressionMutable( *root_expr, [&](ColumnRefExpression &col_ref) { - auto &col_names = col_ref.column_names; + auto &col_names = col_ref.ColumnNamesMutable(); if (col_ref.IsQualified() && col_ref.GetTableName() == table_name) { col_names.erase(col_names.begin()); } @@ -71,7 +71,7 @@ string CreateIndexInfo::ToString() const { } result += SQLIdentifier(index_name); result += " ON "; - result += QualifierToString(temporary ? "" : catalog, schema, table); + result += QualifierToString(temporary ? Identifier() : catalog, schema, table); if (index_type != "ART") { result += " USING "; result += SQLIdentifier(index_type); diff --git a/src/duckdb/src/parser/parsed_data/create_macro_info.cpp b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp index 64b979cb4..bdce364e8 100644 --- a/src/duckdb/src/parser/parsed_data/create_macro_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_macro_info.cpp @@ -4,12 +4,12 @@ namespace duckdb { -CreateMacroInfo::CreateMacroInfo(CatalogType type) : CreateFunctionInfo(type, INVALID_SCHEMA) { +CreateMacroInfo::CreateMacroInfo(CatalogType type) : CreateFunctionInfo(type, Identifier::InvalidSchema()) { } CreateMacroInfo::CreateMacroInfo(CatalogType type, unique_ptr function, vector> extra_functions) - : CreateFunctionInfo(type, INVALID_SCHEMA) { + : CreateFunctionInfo(type, Identifier::InvalidSchema()) { macros.push_back(std::move(function)); for (auto ¯o : extra_functions) { macros.push_back(std::move(macro)); @@ -18,7 +18,7 @@ 
CreateMacroInfo::CreateMacroInfo(CatalogType type, unique_ptr fun string CreateMacroInfo::ToString() const { auto prefix = GetCreatePrefix("MACRO"); - prefix += QualifierToString(temporary ? "" : catalog, schema, name) + " "; + prefix += QualifierToString(temporary ? Identifier() : catalog, schema, name) + " "; string definitions; for (auto &function : macros) { if (!definitions.empty()) { diff --git a/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp index 7e38ee885..ec574e6c3 100644 --- a/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_pragma_function_info.cpp @@ -8,7 +8,12 @@ CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(PragmaFunction function) functions.AddFunction(std::move(function)); internal = true; } -CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(string name, PragmaFunctionSet functions_p) +CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(PragmaFunctionSet functions_p) + : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(std::move(functions_p)) { + name = functions.name; + internal = true; +} +CreatePragmaFunctionInfo::CreatePragmaFunctionInfo(Identifier name, PragmaFunctionSet functions_p) : CreateFunctionInfo(CatalogType::PRAGMA_FUNCTION_ENTRY), functions(std::move(functions_p)) { this->name = std::move(name); internal = true; diff --git a/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp index 51598a72e..7fdfdc678 100644 --- a/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_scalar_function_info.cpp @@ -19,7 +19,7 @@ CreateScalarFunctionInfo::CreateScalarFunctionInfo(ScalarFunctionSet set) } unique_ptr CreateScalarFunctionInfo::Copy() const { - ScalarFunctionSet set(name); + ScalarFunctionSet set {name}; set.functions = 
functions.functions; auto result = make_uniq(std::move(set)); CopyFunctionProperties(*result); diff --git a/src/duckdb/src/parser/parsed_data/create_schema_info.cpp b/src/duckdb/src/parser/parsed_data/create_schema_info.cpp index 8eb6faf2d..889a73e85 100644 --- a/src/duckdb/src/parser/parsed_data/create_schema_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_schema_info.cpp @@ -13,7 +13,7 @@ unique_ptr CreateSchemaInfo::Copy() const { string CreateSchemaInfo::ToString() const { string ret = ""; - string qualified = QualifierToString(temporary ? "" : catalog, "", schema); + string qualified = QualifierToString(temporary ? Identifier() : catalog, Identifier(), schema); switch (on_conflict) { case OnCreateConflict::ALTER_ON_CONFLICT: { diff --git a/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp index 9a03769e7..1b504f270 100644 --- a/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_sequence_info.cpp @@ -3,8 +3,8 @@ namespace duckdb { CreateSequenceInfo::CreateSequenceInfo() - : CreateInfo(CatalogType::SEQUENCE_ENTRY, INVALID_SCHEMA), name(string()), usage_count(0), increment(1), - min_value(1), max_value(NumericLimits::Maximum()), start_value(1), cycle(false) { + : CreateInfo(CatalogType::SEQUENCE_ENTRY, Identifier::InvalidSchema()), name(string()), usage_count(0), + increment(1), min_value(1), max_value(NumericLimits::Maximum()), start_value(1), cycle(false) { } unique_ptr CreateSequenceInfo::Copy() const { @@ -35,7 +35,7 @@ string CreateSequenceInfo::ToString() const { if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT) { ss << " IF NOT EXISTS "; } - ss << QualifierToString(temporary ? "" : catalog, schema, name); + ss << QualifierToString(temporary ? 
Identifier() : catalog, schema, name); ss << " INCREMENT BY " << increment; ss << " MINVALUE " << min_value; ss << " MAXVALUE " << max_value; diff --git a/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp b/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp index 3e774fba9..b5b7f6f38 100644 --- a/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_table_function_info.cpp @@ -19,7 +19,7 @@ CreateTableFunctionInfo::CreateTableFunctionInfo(TableFunctionSet set) } unique_ptr CreateTableFunctionInfo::Copy() const { - TableFunctionSet set(name); + TableFunctionSet set {name}; set.functions = functions.functions; auto result = make_uniq(std::move(set)); CopyFunctionProperties(*result); diff --git a/src/duckdb/src/parser/parsed_data/create_table_info.cpp b/src/duckdb/src/parser/parsed_data/create_table_info.cpp index d556a5418..ef4af7f08 100644 --- a/src/duckdb/src/parser/parsed_data/create_table_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_table_info.cpp @@ -5,14 +5,14 @@ namespace duckdb { -CreateTableInfo::CreateTableInfo() : CreateInfo(CatalogType::TABLE_ENTRY, INVALID_SCHEMA) { +CreateTableInfo::CreateTableInfo() : CreateInfo(CatalogType::TABLE_ENTRY, Identifier::InvalidSchema()) { } -CreateTableInfo::CreateTableInfo(string catalog_p, string schema_p, string name_p) +CreateTableInfo::CreateTableInfo(Identifier catalog_p, Identifier schema_p, Identifier name_p) : CreateInfo(CatalogType::TABLE_ENTRY, std::move(schema_p), std::move(catalog_p)), table(std::move(name_p)) { } -CreateTableInfo::CreateTableInfo(SchemaCatalogEntry &schema, string name_p) +CreateTableInfo::CreateTableInfo(SchemaCatalogEntry &schema, Identifier name_p) : CreateTableInfo(schema.catalog.GetName(), schema.name, std::move(name_p)) { } @@ -69,7 +69,7 @@ string CreateTableInfo::ExtraOptionsToString() const { string CreateTableInfo::ToString() const { string ret = GetCreatePrefix("TABLE"); - ret += 
QualifierToString(temporary ? "" : catalog, schema, table); + ret += QualifierToString(temporary ? Identifier() : catalog, schema, table); if (query != nullptr) { ret += TableCatalogEntry::ColumnNamesToSQL(columns); diff --git a/src/duckdb/src/parser/parsed_data/create_trigger_info.cpp b/src/duckdb/src/parser/parsed_data/create_trigger_info.cpp index 4349fdc9d..7b01ce71b 100644 --- a/src/duckdb/src/parser/parsed_data/create_trigger_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_trigger_info.cpp @@ -5,7 +5,7 @@ namespace duckdb { CreateTriggerInfo::CreateTriggerInfo() - : CreateInfo(CatalogType::TRIGGER_ENTRY, INVALID_SCHEMA), timing(TriggerTiming::AFTER), + : CreateInfo(CatalogType::TRIGGER_ENTRY, Identifier::InvalidSchema()), timing(TriggerTiming::AFTER), event_type(TriggerEventType::INSERT_EVENT), for_each(TriggerForEach::STATEMENT) { } @@ -18,6 +18,8 @@ unique_ptr CreateTriggerInfo::Copy() const { result->event_type = event_type; result->columns = columns; result->for_each = for_each; + result->referencing_new_table = referencing_new_table; + result->referencing_old_table = referencing_old_table; result->trigger_action = trigger_action->Copy(); return std::move(result); } @@ -51,6 +53,15 @@ string CreateTriggerInfo::ToString() const { } ss << " ON "; ss << base_table->ToString(); + if (!referencing_new_table.empty() || !referencing_old_table.empty()) { + ss << " REFERENCING"; + if (!referencing_new_table.empty()) { + ss << " NEW TABLE AS " << SQLIdentifier(referencing_new_table); + } + if (!referencing_old_table.empty()) { + ss << " OLD TABLE AS " << SQLIdentifier(referencing_old_table); + } + } ss << " FOR EACH " << EnumUtil::ToString(for_each); ss << " " << trigger_action->ToString(); ss << ";"; diff --git a/src/duckdb/src/parser/parsed_data/create_type_info.cpp b/src/duckdb/src/parser/parsed_data/create_type_info.cpp index 95d9bee19..5f473d23e 100644 --- a/src/duckdb/src/parser/parsed_data/create_type_info.cpp +++ 
b/src/duckdb/src/parser/parsed_data/create_type_info.cpp @@ -24,7 +24,7 @@ unique_ptr CreateTypeInfo::Copy() const { string CreateTypeInfo::ToString() const { string result = GetCreatePrefix("TYPE"); - result += QualifierToString(temporary ? "" : catalog, schema, name); + result += QualifierToString(temporary ? Identifier() : catalog, schema, name); if (type.id() == LogicalTypeId::ENUM) { auto &values_insert_order = EnumType::GetValuesInsertOrder(type); idx_t size = EnumType::GetSize(type); diff --git a/src/duckdb/src/parser/parsed_data/create_view_info.cpp b/src/duckdb/src/parser/parsed_data/create_view_info.cpp index 493989cc4..a96e0c6b8 100644 --- a/src/duckdb/src/parser/parsed_data/create_view_info.cpp +++ b/src/duckdb/src/parser/parsed_data/create_view_info.cpp @@ -8,24 +8,24 @@ namespace duckdb { -CreateViewInfo::CreateViewInfo() : CreateInfo(CatalogType::VIEW_ENTRY, INVALID_SCHEMA) { +CreateViewInfo::CreateViewInfo() : CreateInfo(CatalogType::VIEW_ENTRY, Identifier::InvalidSchema()) { } -CreateViewInfo::CreateViewInfo(string catalog_p, string schema_p, string view_name_p) +CreateViewInfo::CreateViewInfo(Identifier catalog_p, Identifier schema_p, Identifier view_name_p) : CreateInfo(CatalogType::VIEW_ENTRY, std::move(schema_p), std::move(catalog_p)), view_name(std::move(view_name_p)) { } -CreateViewInfo::CreateViewInfo(SchemaCatalogEntry &schema, string view_name) +CreateViewInfo::CreateViewInfo(SchemaCatalogEntry &schema, Identifier view_name) : CreateViewInfo(schema.catalog.GetName(), schema.name, std::move(view_name)) { } string CreateViewInfo::ToString() const { string result = GetCreatePrefix("VIEW"); - result += QualifierToString(temporary ? "" : catalog, schema, view_name); + result += QualifierToString(temporary ? 
Identifier() : catalog, schema, view_name); if (!aliases.empty()) { result += " ("; result += - StringUtil::Join(aliases, aliases.size(), ", ", [](const string &name) { return SQLIdentifier(name); }); + StringUtil::Join(aliases, aliases.size(), ", ", [](const Identifier &name) { return SQLIdentifier(name); }); result += ")"; } if (binding_mode == CreateViewBindingMode::SKIP_BINDING) { @@ -111,20 +111,20 @@ vector CreateViewInfo::GetColumnCommentsList() const { vector result; result.resize(names.size()); for (auto &entry : column_comments_map) { - auto it = std::find(names.begin(), names.end(), entry.first); + auto it = std::find_if(names.begin(), names.end(), [&](const Identifier &n) { return entry.first == n; }); if (it == names.end()) { throw InternalException( "While serializing comments for view \"%s\" - did not find column \"%s\" in list of names", view_name, - entry.first); + entry.first.GetIdentifierName()); } result[NumericCast(it - names.begin())] = entry.second; } return result; } -CreateViewInfo::CreateViewInfo(vector names_p, vector comments, - unordered_map column_comments_p) - : CreateInfo(CatalogType::VIEW_ENTRY, INVALID_SCHEMA), names(std::move(names_p)), +CreateViewInfo::CreateViewInfo(vector names_p, vector comments, + identifier_map_t column_comments_p) + : CreateInfo(CatalogType::VIEW_ENTRY, Identifier::InvalidSchema()), names(std::move(names_p)), column_comments_map(std::move(column_comments_p)) { if (comments.empty()) { return; diff --git a/src/duckdb/src/parser/parsed_data/disconnect_info.cpp b/src/duckdb/src/parser/parsed_data/disconnect_info.cpp new file mode 100644 index 000000000..50a5ef3a5 --- /dev/null +++ b/src/duckdb/src/parser/parsed_data/disconnect_info.cpp @@ -0,0 +1,13 @@ +#include "duckdb/parser/parsed_data/disconnect_info.hpp" + +namespace duckdb { + +unique_ptr DisconnectInfo::Copy() const { + return make_uniq(); +} + +string DisconnectInfo::ToString() const { + return "DISCONNECT;"; +} + +} // namespace duckdb diff --git 
a/src/duckdb/src/parser/parsed_data/exported_table_data.cpp b/src/duckdb/src/parser/parsed_data/exported_table_data.cpp index a4b66578d..e28896b79 100644 --- a/src/duckdb/src/parser/parsed_data/exported_table_data.cpp +++ b/src/duckdb/src/parser/parsed_data/exported_table_data.cpp @@ -5,7 +5,7 @@ namespace duckdb { ExportedTableInfo::ExportedTableInfo(TableCatalogEntry &entry, ExportedTableData table_data_p, - vector ¬_null_columns_p) + vector ¬_null_columns_p) : entry(entry), table_data(std::move(table_data_p)) { table_data.not_null_columns = not_null_columns_p; } diff --git a/src/duckdb/src/parser/parsed_data/load_info.cpp b/src/duckdb/src/parser/parsed_data/load_info.cpp index 3c7c42c26..e4afe427b 100644 --- a/src/duckdb/src/parser/parsed_data/load_info.cpp +++ b/src/duckdb/src/parser/parsed_data/load_info.cpp @@ -12,12 +12,14 @@ unique_ptr LoadInfo::Copy() const { result->load_type = load_type; result->repo_is_alias = repo_is_alias; result->version = version; + result->alias = alias; return result; } static string LoadInfoToString(LoadType load_type) { switch (load_type) { case LoadType::LOAD: + case LoadType::LOAD_AS: return "LOAD"; case LoadType::INSTALL: return "INSTALL"; @@ -39,6 +41,9 @@ string LoadInfo::ToString() const { result += " FROM " + SQLString(repository); } } + if (!alias.empty()) { + result += " AS " + SQLIdentifier(alias); + } result += ";"; return result; diff --git a/src/duckdb/src/parser/parsed_data/parse_info.cpp b/src/duckdb/src/parser/parsed_data/parse_info.cpp index b0a645885..73399a861 100644 --- a/src/duckdb/src/parser/parsed_data/parse_info.cpp +++ b/src/duckdb/src/parser/parsed_data/parse_info.cpp @@ -35,7 +35,7 @@ string ParseInfo::TypeToString(CatalogType type) { } } -string ParseInfo::QualifierToString(const string &catalog, const string &schema, const string &name) { +string ParseInfo::QualifierToString(const Identifier &catalog, const Identifier &schema, const Identifier &name) { string result; if (!catalog.empty()) { result += 
SQLIdentifier(catalog) + "."; diff --git a/src/duckdb/src/parser/parsed_data/vacuum_info.cpp b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp index dc1920b6f..0aa6efef9 100644 --- a/src/duckdb/src/parser/parsed_data/vacuum_info.cpp +++ b/src/duckdb/src/parser/parsed_data/vacuum_info.cpp @@ -27,7 +27,7 @@ string VacuumInfo::ToString() const { if (!columns.empty()) { vector names; for (auto &column : columns) { - names.push_back(SQLIdentifier::ToString(column)); + names.push_back(SQLIdentifier::ToString(column.GetIdentifierName())); } result += "(" + StringUtil::Join(names, ", ") + ")"; } diff --git a/src/duckdb/src/parser/parsed_expression.cpp b/src/duckdb/src/parser/parsed_expression.cpp index 5d270b8ff..e6ff5f1f4 100644 --- a/src/duckdb/src/parser/parsed_expression.cpp +++ b/src/duckdb/src/parser/parsed_expression.cpp @@ -47,60 +47,6 @@ bool ParsedExpression::HasSubquery() const { return has_subquery; } -bool ParsedExpression::Equals(const BaseExpression &other) const { - if (!BaseExpression::Equals(other)) { - return false; - } - switch (expression_class) { - case ExpressionClass::BETWEEN: - return BetweenExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CASE: - return CaseExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CAST: - return CastExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COLLATE: - return CollateExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COLUMN_REF: - return ColumnRefExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::COMPARISON: - return ComparisonExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CONJUNCTION: - return ConjunctionExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::CONSTANT: - return ConstantExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::DEFAULT: - return true; - case ExpressionClass::FUNCTION: - return FunctionExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::LAMBDA: - return 
LambdaExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::OPERATOR: - return OperatorExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::PARAMETER: - return ParameterExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::POSITIONAL_REFERENCE: - return PositionalReferenceExpression::Equal(Cast(), - other.Cast()); - case ExpressionClass::STAR: - return StarExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::SUBQUERY: - return SubqueryExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::WINDOW: - return WindowExpression::Equal(Cast(), other.Cast()); - case ExpressionClass::TYPE: - return TypeExpression::Equal(Cast(), other.Cast()); - default: - throw SerializationException("Unsupported type for expression comparison!"); - } -} - -hash_t ParsedExpression::Hash() const { - hash_t hash = duckdb::Hash(static_cast(type)); - ParsedExpressionIterator::EnumerateChildren( - *this, [&](const ParsedExpression &child) { hash = CombineHash(child.Hash(), hash); }); - return hash; -} - bool ParsedExpression::Equals(const unique_ptr &left, const unique_ptr &right) { if (left.get() == right.get()) { return true; diff --git a/src/duckdb/src/parser/parsed_expression_iterator.cpp b/src/duckdb/src/parser/parsed_expression_iterator.cpp index c994cadbb..a5f750905 100644 --- a/src/duckdb/src/parser/parsed_expression_iterator.cpp +++ b/src/duckdb/src/parser/parsed_expression_iterator.cpp @@ -9,16 +9,17 @@ #include "duckdb/parser/query_node/delete_query_node.hpp" #include "duckdb/parser/query_node/insert_query_node.hpp" #include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" #include "duckdb/parser/tableref/list.hpp" namespace duckdb { void ParsedExpressionIterator::EnumerateChildren(const ParsedExpression &expression, const std::function &callback) { - EnumerateChildren((ParsedExpression &)expression, [&](unique_ptr 
&child) { - D_ASSERT(child); - callback(*child); - }); + for (const auto &child : expression.Children()) { + callback(child); + } } void ParsedExpressionIterator::EnumerateChildren(ParsedExpression &expr, @@ -31,136 +32,8 @@ void ParsedExpressionIterator::EnumerateChildren(ParsedExpression &expr, void ParsedExpressionIterator::EnumerateChildren( ParsedExpression &expr, const std::function &child)> &callback) { - switch (expr.GetExpressionClass()) { - case ExpressionClass::BETWEEN: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.input); - callback(cast_expr.lower); - callback(cast_expr.upper); - break; - } - case ExpressionClass::CASE: { - auto &case_expr = expr.Cast(); - for (auto &check : case_expr.case_checks) { - callback(check.when_expr); - callback(check.then_expr); - } - callback(case_expr.else_expr); - break; - } - case ExpressionClass::CAST: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.child); - break; - } - case ExpressionClass::COLLATE: { - auto &cast_expr = expr.Cast(); - callback(cast_expr.child); - break; - } - case ExpressionClass::COMPARISON: { - auto &comp_expr = expr.Cast(); - callback(comp_expr.left); - callback(comp_expr.right); - break; - } - case ExpressionClass::CONJUNCTION: { - auto &conj_expr = expr.Cast(); - for (auto &child : conj_expr.children) { - callback(child); - } - break; - } - - case ExpressionClass::FUNCTION: { - auto &func_expr = expr.Cast(); - for (auto &child : func_expr.children) { - callback(child); - } - if (func_expr.filter) { - callback(func_expr.filter); - } - if (func_expr.order_bys) { - for (auto &order : func_expr.order_bys->orders) { - callback(order.expression); - } - } - break; - } - case ExpressionClass::TYPE: { - auto &type_expr = expr.Cast(); - for (auto &child : type_expr.GetChildren()) { - callback(child); - } - break; - } - case ExpressionClass::LAMBDA: { - auto &lambda_expr = expr.Cast(); - callback(lambda_expr.lhs); - callback(lambda_expr.expr); - break; - } - case 
ExpressionClass::OPERATOR: { - auto &op_expr = expr.Cast(); - for (auto &child : op_expr.children) { - callback(child); - } - break; - } - case ExpressionClass::STAR: { - auto &star_expr = expr.Cast(); - if (star_expr.expr) { - callback(star_expr.expr); - } - for (auto &item : star_expr.replace_list) { - callback(item.second); - } - break; - } - case ExpressionClass::SUBQUERY: { - auto &subquery_expr = expr.Cast(); - if (subquery_expr.child) { - callback(subquery_expr.child); - } - break; - } - case ExpressionClass::WINDOW: { - auto &window_expr = expr.Cast(); - for (auto &partition : window_expr.partitions) { - callback(partition); - } - for (auto &order : window_expr.orders) { - callback(order.expression); - } - for (auto &child : window_expr.children) { - callback(child); - } - if (window_expr.filter_expr) { - callback(window_expr.filter_expr); - } - if (window_expr.start_expr) { - callback(window_expr.start_expr); - } - if (window_expr.end_expr) { - callback(window_expr.end_expr); - } - for (auto &order : window_expr.arg_orders) { - callback(order.expression); - } - break; - } - case ExpressionClass::BOUND_EXPRESSION: - case ExpressionClass::COLUMN_REF: - case ExpressionClass::LAMBDA_REF: - case ExpressionClass::CONSTANT: - case ExpressionClass::DEFAULT: - case ExpressionClass::PARAMETER: - case ExpressionClass::POSITIONAL_REFERENCE: - // these node types have no children - break; - default: - // called on non ParsedExpression type! 
- throw NotImplementedException("Unimplemented expression class"); + for (auto &child : expr.ChildrenMutable()) { + callback(child); } } @@ -177,17 +50,6 @@ void ParsedExpressionIterator::EnumerateQueryNodeModifiers( callback(limit_modifier.offset); } } break; - - case ResultModifierType::LIMIT_PERCENT_MODIFIER: { - auto &limit_modifier = modifier->Cast(); - if (limit_modifier.limit) { - callback(limit_modifier.limit); - } - if (limit_modifier.offset) { - callback(limit_modifier.offset); - } - } break; - case ResultModifierType::ORDER_MODIFIER: { auto &order_modifier = modifier->Cast(); for (auto &order : order_modifier.orders) { @@ -363,6 +225,40 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( } break; } + case QueryNodeType::MERGE_QUERY_NODE: { + auto &merge_node = node.Cast(); + if (merge_node.target) { + EnumerateTableRefChildren(*merge_node.target, expr_callback, ref_callback); + } + if (merge_node.source) { + EnumerateTableRefChildren(*merge_node.source, expr_callback, ref_callback); + } + if (merge_node.join_condition) { + expr_callback(merge_node.join_condition); + } + for (auto &entry : merge_node.actions) { + for (auto &action : entry.second) { + if (action->condition) { + expr_callback(action->condition); + } + if (action->update_info) { + for (auto &expr : action->update_info->expressions) { + expr_callback(expr); + } + if (action->update_info->condition) { + expr_callback(action->update_info->condition); + } + } + for (auto &expr : action->expressions) { + expr_callback(expr); + } + } + } + for (auto &expr : merge_node.returning_list) { + expr_callback(expr); + } + break; + } default: throw NotImplementedException("QueryNode type not implemented for traversal"); } diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp index 0d043b51a..d55d7646e 100644 --- a/src/duckdb/src/parser/parser.cpp +++ b/src/duckdb/src/parser/parser.cpp @@ -258,40 +258,62 @@ void Parser::ParseQuery(const string &query) { 
last_strict_extension_error.Throw(); } } - // PEG parser: tokenize then transform - auto &cache = GetCache(); - auto peg_matcher = cache.GetMatcher(); - auto peg_factory = cache.GetTransformerFactory(); - + // PEG parser: tokenize, then peel one TopLevelStatement at a time. On per-statement PEG + // failure, hand the rest of the query to parse_function extensions; the extension reports + // how many bytes it consumed and we advance the token cursor past them. vector tokens; ParserTokenizer tokenizer(query, tokens); tokenizer.TokenizeInput(); - if (!tokens.empty()) { + idx_t token_cursor = 0; + while (token_cursor < tokens.size()) { try { - auto peg_statements = peg_factory->Transform(tokens, options, peg_matcher->Root()); - for (auto &stmt : peg_statements) { + auto stmt = ParseTopLevelStatement(tokens, token_cursor); + if (stmt) { statements.push_back(std::move(stmt)); } } catch (ParserException &e) { - // fall back to parse_function extensions for unknown statement types bool parsed = false; if (options.extensions && options.extensions->HasParserExtensions()) { + idx_t failure_byte = token_cursor < tokens.size() ? tokens[token_cursor].offset : query.size(); + // SimpleToken view of the tail: text + classified type, in source order, so + // extensions can dispatch on the token stream without re-tokenizing. The extension + // reports how many of these tokens it consumed. 
+ vector simple_tokens; + simple_tokens.reserve(tokens.size() - token_cursor); + for (idx_t i = token_cursor; i < tokens.size(); i++) { + simple_tokens.emplace_back(tokens[i].text, tokens[i].type); + } for (auto &ext : options.extensions->ParserExtensions()) { if (!ext.parse_function) { continue; } - auto result = ext.parse_function(ext.parser_info.get(), query); - if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { - auto estmt = make_uniq(ext, std::move(result.parse_data)); - estmt->stmt_location = 0; - estmt->stmt_length = query.size(); - statements.push_back(std::move(estmt)); - parsed = true; - break; - } - if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { + auto result = ext.parse_function(ext.parser_info.get(), simple_tokens); + if (result.consumed_tokens < 0) { + // The extension wants to surface an error. throw ParserException::SyntaxError(query, result.error, result.error_location); } + if (result.consumed_tokens == 0) { + // The extension ran but did not claim this input — let the next one try. + continue; + } + // consumed_tokens > 0: the extension accepted that many leading tokens. + auto consumed = NumericCast(result.consumed_tokens); + if (consumed > simple_tokens.size()) { + throw ParserException( + "Extension returned consumed_tokens=%llu — only %llu tokens are available", + (uint64_t)consumed, (uint64_t)simple_tokens.size()); + } + // The claimed span runs from the failure point to the end of the last consumed + // token; advancing the cursor by consumed_tokens lands on a token boundary. 
+ auto &last_token = tokens[token_cursor + consumed - 1]; + const idx_t span_end_byte = last_token.offset + last_token.length; + auto estmt = make_uniq(ext, std::move(result.parse_data)); + estmt->stmt_location = failure_byte; + estmt->stmt_length = span_end_byte - failure_byte; + statements.push_back(std::move(estmt)); + token_cursor += consumed; + parsed = true; + break; } } if (!parsed) { @@ -317,6 +339,18 @@ void Parser::ParseQuery(const string &query) { } } +unique_ptr Parser::ParseTopLevelStatement(vector &tokens, idx_t &token_cursor) { + if (token_cursor >= tokens.size()) { + return nullptr; + } + auto &cache = GetCache(); + auto peg_matcher = cache.GetMatcher(); + auto peg_factory = cache.GetTransformerFactory(); + + return peg_factory->TransformTopLevelStatement(tokens, options, peg_matcher->TopLevelStatementMatcher(), + token_cursor); +} + vector Parser::Tokenize(const string &query) { HighlightTokenizer tokenizer(query); tokenizer.TokenizeInput(); @@ -592,7 +626,7 @@ vector Parser::ParseOrderList(const string &select_list, ParserOpti return std::move(order.orders); } -void Parser::ParseUpdateList(const string &update_list, vector &update_columns, +void Parser::ParseUpdateList(const string &update_list, vector &update_columns, vector> &expressions, ParserOptions options) { // construct a mock query string mock_query = "UPDATE tbl SET " + update_list; @@ -604,7 +638,7 @@ void Parser::ParseUpdateList(const string &update_list, vector &update_c throw ParserException("Expected a single UPDATE statement"); } auto &update = parser.statements[0]->Cast(); - update_columns = std::move(update.node->set_info->columns); + update_columns = update.node->set_info->columns; expressions = std::move(update.node->set_info->expressions); } diff --git a/src/duckdb/src/parser/peg/autocomplete_core.cpp b/src/duckdb/src/parser/peg/autocomplete_core.cpp index dc04e26bc..fa90539df 100644 --- a/src/duckdb/src/parser/peg/autocomplete_core.cpp +++ 
b/src/duckdb/src/parser/peg/autocomplete_core.cpp @@ -168,6 +168,10 @@ class AutoCompleteTokenizer : public BaseTokenizer { last_pos = 0; } + TokenType GetTerminator() const override { + return TokenType::END_OF_INPUT_AUTOCOMPLETE; + } + void OnLastToken(TokenizeState state, string last_word_p, idx_t last_pos_p) override { if (TokenizeStateToType(state) == TokenType::STRING_LITERAL) { suggestions.emplace_back(SuggestionState::SUGGEST_FILE_NAME); @@ -201,15 +205,15 @@ vector GenerateAutoCompleteSuggestions(AutoCompleteCatal string clean_sql; const string &sql_ref = Parser::StripUnicodeSpaces(sql, clean_sql) ? clean_sql : sql; AutoCompleteTokenizer tokenizer(sql_ref, state); - auto allow_complete = tokenizer.TokenizeInput(); - if (!allow_complete) { + tokenizer.TokenizeInput(); + if (!tokenizer.CanAutocomplete()) { return {}; } if (state.suggestions.empty()) { // no suggestions found during tokenizing // run the root matcher auto peg_matcher = provider.GetPEGMatcher(); - peg_matcher->Root().Match(state); + peg_matcher->ProgramMatcher().Match(state); } if (state.suggestions.empty()) { return {}; diff --git a/src/duckdb/src/parser/peg/keyword_map.cpp b/src/duckdb/src/parser/peg/keyword_map.cpp index 4173f7697..d65d7278d 100644 --- a/src/duckdb/src/parser/peg/keyword_map.cpp +++ b/src/duckdb/src/parser/peg/keyword_map.cpp @@ -125,6 +125,7 @@ void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity unreserved_keyword_map.insert("compression"); unreserved_keyword_map.insert("configuration"); unreserved_keyword_map.insert("conflict"); + unreserved_keyword_map.insert("connect"); unreserved_keyword_map.insert("connection"); unreserved_keyword_map.insert("constraints"); unreserved_keyword_map.insert("content"); @@ -156,6 +157,7 @@ void PEGKeywordHelper::InitializeKeywordMaps() { // Renamed for clarity unreserved_keyword_map.insert("dictionary"); unreserved_keyword_map.insert("disable"); unreserved_keyword_map.insert("discard"); + 
unreserved_keyword_map.insert("disconnect"); unreserved_keyword_map.insert("document"); unreserved_keyword_map.insert("domain"); unreserved_keyword_map.insert("double"); diff --git a/src/duckdb/src/parser/peg/matcher.cpp b/src/duckdb/src/parser/peg/matcher.cpp index feac02bd5..c0a2e53c6 100644 --- a/src/duckdb/src/parser/peg/matcher.cpp +++ b/src/duckdb/src/parser/peg/matcher.cpp @@ -124,7 +124,10 @@ class ListMatcher : public Matcher { auto saved_suggestion_size = suppress_suggestions ? list_state.suggestions.size() : 0; for (idx_t child_idx = 0; child_idx < matchers.size(); child_idx++) { auto &child_matcher = matchers[child_idx].get(); - if (list_state.token_index >= list_state.tokens.size()) { + bool at_autocomplete_cursor = + list_state.token_index < list_state.tokens.size() && + list_state.tokens[list_state.token_index].type == TokenType::END_OF_INPUT_AUTOCOMPLETE; + if (at_autocomplete_cursor) { if (suppress_suggestions) { // this rule should not contribute autocomplete suggestions // discard any suggestions added by earlier children @@ -133,7 +136,7 @@ class ListMatcher : public Matcher { list_state.suggestions.end()); return MatchResultType::FAIL; } - // we exhausted the tokens - push suggestions for the child matcher + // cursor is here - push suggestions for what could follow for (; child_idx < matchers.size(); child_idx++) { auto suggestion_type = matchers[child_idx].get().AddSuggestion(list_state); if (suggestion_type == SuggestionType::MANDATORY) { @@ -356,9 +359,10 @@ class RepeatMatcher : public Matcher { // update the token index we propagate upwards state.token_index = repeat_state.token_index; - // check if we have tokens left - if (repeat_state.token_index >= state.tokens.size()) { - // we exhausted the tokens - suggest the element + bool at_autocomplete_cursor = + repeat_state.token_index < state.tokens.size() && + state.tokens[repeat_state.token_index].type == TokenType::END_OF_INPUT_AUTOCOMPLETE; + if (at_autocomplete_cursor) { 
element.AddSuggestion(state); return MatchResultType::SUCCESS; } @@ -395,8 +399,8 @@ class RepeatMatcher : public Matcher { // Propagate the new state upwards. state.token_index = repeat_state.token_index; - // Check if there are any tokens left. - if (repeat_state.token_index >= state.tokens.size()) { + if (repeat_state.token_index < state.tokens.size() && + state.tokens[repeat_state.token_index].type == TokenType::END_OF_INPUT_AUTOCOMPLETE) { break; } @@ -425,6 +429,44 @@ class RepeatMatcher : public Matcher { Matcher &element; }; +//! Consumes the END_OF_INPUT sentinel; wired into the grammar's EndOfInput rule. +class EndOfInputMatcher : public Matcher { +public: + static constexpr MatcherType TYPE = MatcherType::END_OF_INPUT; + +public: + EndOfInputMatcher() : Matcher(TYPE) { + } + + MatchResultType Match(MatchState &state) const override { + if (state.token_index < state.tokens.size() && + state.tokens[state.token_index].type == TokenType::END_OF_INPUT) { + state.token_index++; + state.UpdateMaxTokenIndex(); + return MatchResultType::SUCCESS; + } + return MatchResultType::FAIL; + } + + optional_ptr MatchParseResult(MatchState &state) const override { + if (state.token_index < state.tokens.size() && + state.tokens[state.token_index].type == TokenType::END_OF_INPUT) { + state.token_index++; + state.UpdateMaxTokenIndex(); + return state.allocator.Allocate(make_uniq()); + } + return nullptr; + } + + SuggestionType AddSuggestionInternal(MatchState &state) const override { + return SuggestionType::MANDATORY; + } + + string ToString() const override { + return "EndOfInput"; + } +}; + class IdentifierMatcher : public Matcher { public: static constexpr MatcherType TYPE = MatcherType::VARIABLE; @@ -457,6 +499,9 @@ class IdentifierMatcher : public Matcher { if (IsQuoted(text)) { return true; } + if (BaseTokenizer::CharacterIsInitialNumber(text[0])) { + return false; + } return BaseTokenizer::CharacterIsKeyword(text[0]); } @@ -575,6 +620,9 @@ class IdentifierMatcher : 
public Matcher { private: bool MatchIdentifier(MatchState &state) const { + if (state.token_index >= state.tokens.size()) { + return false; + } // variable matchers match anything except for reserved keywords auto &token_text = state.tokens[state.token_index].text; const auto &keyword_helper = PEGKeywordHelper::Instance(); @@ -660,6 +708,9 @@ class ReservedIdentifierMatcher : public IdentifierMatcher { private: bool MatchReservedIdentifier(MatchState &state) const { + if (state.token_index >= state.tokens.size()) { + return false; + } auto &token_text = state.tokens[state.token_index].text; if (!IsIdentifier(token_text)) { return false; @@ -791,8 +842,11 @@ class NumberLiteralMatcher : public Matcher { private: static bool MatchNumberLiteral(MatchState &state) { + if (state.token_index >= state.tokens.size()) { + return false; + } auto &token_text = state.tokens[state.token_index].text; - if (!BaseTokenizer::CharacterIsInitialNumber(token_text[0])) { + if (token_text.empty() || !BaseTokenizer::CharacterIsInitialNumber(token_text[0])) { return false; } // A lone '.' is a dot operator, not a number literal (e.g., '?.method()' should not consume '.') @@ -892,6 +946,9 @@ class OperatorMatcher : public Matcher { private: static bool MatchOperator(MatchState &state) { + if (state.token_index >= state.tokens.size()) { + return false; + } auto &token_text = state.tokens[state.token_index].text; // Exclude the lambda arrow and JSON arrow — these have dedicated grammar roles if (token_text == "->" || token_text == "->>") { @@ -959,6 +1016,9 @@ class ArithmeticOperatorMatcher : public Matcher { private: static bool MatchArithmeticOperator(MatchState &state) { + if (state.token_index >= state.tokens.size()) { + return false; + } auto &token_text = state.tokens[state.token_index].text; for (auto &c : token_text) { if (!IsArithmeticOperatorChar(c)) { @@ -993,6 +1053,9 @@ class MatcherFactory { //! 
Create a matcher from a PEG grammar Matcher &CreateMatcher(const char *grammar, const char *root_rule); + //! Look up a matcher for a rule that was already built (as a sub-rule of a previous + //! CreateMatcher call). Throws if the rule has not been built. + Matcher &GetMatcher(const string &rule_name); private: // Base primitives @@ -1002,28 +1065,8 @@ class MatcherFactory { Matcher &Choice(vector> matchers) const; Matcher &Optional(Matcher &matcher) const; Matcher &Repeat(Matcher &matcher) const; - Matcher &Variable() const; - Matcher &CatalogName() const; - Matcher &SchemaName() const; - Matcher &TypeName() const; - Matcher &TableName() const; - Matcher &ColumnName() const; - Matcher &StringLiteral() const; - Matcher &NumberLiteral() const; - Matcher &Operator() const; - Matcher &ArithmeticOperator() const; - Matcher &ScalarFunctionName() const; - Matcher &TableFunctionName() const; - Matcher &PragmaName() const; - Matcher &SettingName() const; - Matcher &CopyOptionName() const; - Matcher &ReservedSchemaName() const; - Matcher &ReservedTableName() const; - Matcher &ReservedColumnName() const; - Matcher &ReservedScalarFunctionName() const; - Matcher &ReservedVariable() const; - - void AddKeywordOverride(const char *name, uint32_t score, char extra_char = ' '); + + void AddKeywordOverride(const char *name, int32_t score, char extra_char = ' '); void AddRuleOverride(const char *name, Matcher &matcher); void SuppressSuggestions(const char *name); Matcher &CreateMatcher(PEGParser &parser, string_t rule_name); @@ -1064,83 +1107,12 @@ Matcher &MatcherFactory::Repeat(Matcher &matcher) const { return allocator.Allocate(make_uniq(matcher)); } -Matcher &MatcherFactory::Variable() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE)); -} - -Matcher &MatcherFactory::ReservedVariable() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE)); -} -Matcher &MatcherFactory::CatalogName() const { - return 
allocator.Allocate(make_uniq(SuggestionState::SUGGEST_CATALOG_NAME)); -} - -Matcher &MatcherFactory::SchemaName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCHEMA_NAME)); -} - -Matcher &MatcherFactory::ReservedSchemaName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCHEMA_NAME)); -} - -Matcher &MatcherFactory::TableName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_NAME)); -} - -Matcher &MatcherFactory::ReservedTableName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_NAME)); -} - -Matcher &MatcherFactory::ColumnName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_COLUMN_NAME)); -} - -Matcher &MatcherFactory::ReservedColumnName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_COLUMN_NAME)); -} - -Matcher &MatcherFactory::TypeName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TYPE_NAME)); -} - -Matcher &MatcherFactory::ScalarFunctionName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME)); -} - -Matcher &MatcherFactory::ReservedScalarFunctionName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME)); -} - -Matcher &MatcherFactory::TableFunctionName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_FUNCTION_NAME)); -} - -Matcher &MatcherFactory::PragmaName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_PRAGMA_NAME)); -} - -Matcher &MatcherFactory::SettingName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SETTING_NAME)); -} - -Matcher &MatcherFactory::CopyOptionName() const { - return allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE)); -} - -Matcher &MatcherFactory::NumberLiteral() const { - return allocator.Allocate(make_uniq()); -} - -Matcher &MatcherFactory::StringLiteral() const { - return 
allocator.Allocate(make_uniq()); -} - -Matcher &MatcherFactory::Operator() const { - return allocator.Allocate(make_uniq()); -} - -Matcher &MatcherFactory::ArithmeticOperator() const { - return allocator.Allocate(make_uniq()); +Matcher &MatcherFactory::GetMatcher(const string &rule_name) { + auto entry = matchers.find(rule_name); + if (entry == matchers.end()) { + throw InternalException("Matcher for rule '%s' has not been built", rule_name); + } + return entry->second.get(); } Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name) { @@ -1177,7 +1149,9 @@ struct MatcherList { throw InternalException("Choice matcher should never be the root in the matcher stack"); } root_matcher.Cast().matchers.push_back(matcher); - matchers.pop_back(); + if (!matchers.empty()) { + matchers.pop_back(); + } break; default: throw InternalException("Cannot add matcher to root matcher of this type"); @@ -1311,7 +1285,9 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve final_matcher = Repeat(final_matcher.get()); } auto &replaced_matcher = Optional(final_matcher); - list_matcher.matchers.pop_back(); + if (!list_matcher.matchers.empty()) { + list_matcher.matchers.pop_back(); + } list_matcher.matchers.push_back(replaced_matcher); break; } @@ -1327,7 +1303,9 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve } auto &final_matcher = list_matcher.matchers.back(); final_matcher = Repeat(final_matcher.get()); - list_matcher.matchers.pop_back(); + if (!list_matcher.matchers.empty()) { + list_matcher.matchers.pop_back(); + } list_matcher.matchers.push_back(final_matcher); break; } @@ -1350,7 +1328,9 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve choice_options.push_back(previous_matcher); auto &new_choice_matcher = Choice(choice_options); - list_matcher.matchers.pop_back(); + if (!list_matcher.matchers.empty()) { + list_matcher.matchers.pop_back(); + } 
list_matcher.matchers.push_back(new_choice_matcher); list.AddRootMatcher(new_choice_matcher); @@ -1395,7 +1375,7 @@ Matcher &MatcherFactory::CreateMatcher(PEGParser &parser, string_t rule_name, ve return matcher; } -void MatcherFactory::AddKeywordOverride(const char *name, uint32_t score, char extra_char) { +void MatcherFactory::AddKeywordOverride(const char *name, int32_t score, char extra_char) { auto &keyword_matcher = allocator.Allocate(make_uniq(name, score, extra_char)); keyword_overrides.insert(make_pair(name, reference(keyword_matcher))); } @@ -1418,31 +1398,54 @@ Matcher &MatcherFactory::CreateMatcher(const char *grammar, const char *root_rul AddKeywordOverride(".", 0, '\0'); AddKeywordOverride("(", 0, '\0'); // rule overrides - AddRuleOverride("Identifier", Variable()); - AddRuleOverride("ReservedIdentifier", ReservedVariable()); - - AddRuleOverride("CatalogName", CatalogName()); - AddRuleOverride("SchemaName", SchemaName()); - AddRuleOverride("ReservedSchemaName", ReservedSchemaName()); - AddRuleOverride("TableName", TableName()); - AddRuleOverride("ReservedTableName", ReservedTableName()); - AddRuleOverride("ColumnName", ColumnName()); - AddRuleOverride("ReservedColumnName", ReservedColumnName()); - AddRuleOverride("IndexName", Variable()); - AddRuleOverride("SequenceName", Variable()); - - AddRuleOverride("FunctionName", ScalarFunctionName()); - AddRuleOverride("ReservedFunctionName", ReservedScalarFunctionName()); - AddRuleOverride("TableFunctionName", TableFunctionName()); - - AddRuleOverride("TypeName", TypeName()); - AddRuleOverride("PragmaName", PragmaName()); - AddRuleOverride("SettingName", SettingName()); - AddRuleOverride("CopyOptionName", CopyOptionName()); - - AddRuleOverride("NumberLiteral", NumberLiteral()); - AddRuleOverride("StringLiteral", StringLiteral()); - AddRuleOverride("OperatorLiteral", Operator()); + //===--------------------------------------------------------------------===// + // START GENERATED RULE OVERRIDES + 
//===--------------------------------------------------------------------===// + AddRuleOverride("Identifier", allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("ReservedIdentifier", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("CatalogName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_CATALOG_NAME))); + AddRuleOverride("SchemaName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCHEMA_NAME))); + AddRuleOverride("ReservedSchemaName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCHEMA_NAME))); + AddRuleOverride("TableName", allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_NAME))); + AddRuleOverride("ReservedTableName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_NAME))); + AddRuleOverride("ColumnName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_COLUMN_NAME))); + AddRuleOverride("ReservedColumnName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_COLUMN_NAME))); + AddRuleOverride("IndexName", allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("ReservedIndexName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("SequenceName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("FunctionName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME))); + AddRuleOverride("ReservedFunctionName", allocator.Allocate(make_uniq( + SuggestionState::SUGGEST_SCALAR_FUNCTION_NAME))); + AddRuleOverride("TableFunctionName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TABLE_FUNCTION_NAME))); + AddRuleOverride("TypeName", allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TYPE_NAME))); + AddRuleOverride("ReservedTypeName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_TYPE_NAME))); + AddRuleOverride("PragmaName", + 
allocator.Allocate(make_uniq(SuggestionState::SUGGEST_PRAGMA_NAME))); + AddRuleOverride("SettingName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_SETTING_NAME))); + AddRuleOverride("CopyOptionName", + allocator.Allocate(make_uniq(SuggestionState::SUGGEST_VARIABLE))); + AddRuleOverride("NumberLiteral", allocator.Allocate(make_uniq())); + AddRuleOverride("StringLiteral", allocator.Allocate(make_uniq())); + AddRuleOverride("OperatorLiteral", allocator.Allocate(make_uniq())); + //===--------------------------------------------------------------------===// + // END GENERATED RULE OVERRIDES + //===--------------------------------------------------------------------===// + + // EndOfInput has no grammar body; satisfied here (outside the regenerated block). + AddRuleOverride("EndOfInput", allocator.Allocate(make_uniq())); // suppress suggestions for catch-all rules that would pollute statement-level autocomplete SuppressSuggestions("ExpressionStatement"); @@ -1476,10 +1479,12 @@ shared_ptr ParserCache::GetMatcher() { buffer << t.rdbuf(); auto grammar_string = buffer.str(); - new_matcher->root = factory.CreateMatcher(grammar_string.c_str(), "Program"); + new_matcher->program_matcher = factory.CreateMatcher(grammar_string.c_str(), "Program"); #else - new_matcher->root = factory.CreateMatcher(const_char_ptr_cast(INLINED_PEG_GRAMMAR), "Program"); + new_matcher->program_matcher = factory.CreateMatcher(const_char_ptr_cast(INLINED_PEG_GRAMMAR), "Program"); #endif + // TopLevelStatement is referenced by Program, so it has already been built and cached. 
+ new_matcher->top_level_statement_matcher = factory.GetMatcher("TopLevelStatement"); std::unique_lock lock(mutex); if (!matcher) { matcher = std::move(new_matcher); diff --git a/src/duckdb/src/parser/peg/peg_parser.cpp b/src/duckdb/src/parser/peg/peg_parser.cpp index 6238dbe04..6caf39a09 100644 --- a/src/duckdb/src/parser/peg/peg_parser.cpp +++ b/src/duckdb/src/parser/peg/peg_parser.cpp @@ -1,4 +1,5 @@ #include "duckdb/parser/peg/peg_parser.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { @@ -56,7 +57,7 @@ void PEGParser::ParseRules(const char *grammar) { if (c == start_pos) { throw InternalException("Failed to parse grammar - expected an alpha-numeric rule name (pos %d)", c); } - rule_name = string_t(grammar + start_pos, c - start_pos); + rule_name = string_t(grammar + start_pos, UnsafeNumericCast(c - start_pos)); rule.Clear(); parse_state = PEGParseState::RULE_SEPARATOR; break; @@ -76,7 +77,8 @@ void PEGParser::ParseRules(const char *grammar) { throw InternalException("Failed to parse grammar - expected a parameter at position %d", c); } rule.parameters.insert( - make_pair(string_t(grammar + parameter_start, c - parameter_start), rule.parameters.size())); + make_pair(string_t(grammar + parameter_start, UnsafeNumericCast(c - parameter_start)), + rule.parameters.size())); if (grammar[c] != ')') { throw InternalException("Failed to parse grammar - expected closing bracket at position %d", c); } @@ -111,7 +113,7 @@ void PEGParser::ParseRules(const char *grammar) { throw InternalException("Failed to parse grammar - did not find closing ' (pos %d)", c); } PEGToken token; - token.text = string_t(grammar + literal_start, c - literal_start); + token.text = string_t(grammar + literal_start, UnsafeNumericCast(c - literal_start)); token.type = PEGTokenType::LITERAL; rule.tokens.push_back(token); c++; @@ -126,7 +128,7 @@ void PEGParser::ParseRules(const char *grammar) { c++; } PEGToken token; - token.text = string_t(grammar + rule_start, c - rule_start); + 
token.text = string_t(grammar + rule_start, UnsafeNumericCast(c - rule_start)); if (grammar[c] == '(') { // this is a function call c++; @@ -151,7 +153,7 @@ void PEGParser::ParseRules(const char *grammar) { } c++; PEGToken token; - token.text = string_t(grammar + rule_start, c - rule_start); + token.text = string_t(grammar + rule_start, UnsafeNumericCast(c - rule_start)); token.type = PEGTokenType::REGEX; rule.tokens.push_back(token); } else if (IsPEGOperator(grammar[c])) { @@ -176,6 +178,7 @@ void PEGParser::ParseRules(const char *grammar) { throw InternalException("Unrecognized rule contents in rule %s (character %s)", rule_name.GetString(), string(1, grammar[c])); } + break; } default: break; diff --git a/src/duckdb/src/parser/peg/sql_formatter.cpp b/src/duckdb/src/parser/peg/sql_formatter.cpp index 934058aed..57f25eb79 100644 --- a/src/duckdb/src/parser/peg/sql_formatter.cpp +++ b/src/duckdb/src/parser/peg/sql_formatter.cpp @@ -89,6 +89,12 @@ SQLFormatter::SQLFormatter(const FormatterConfig &config) : config(config) { string SQLFormatter::Format(const string &sql) { HighlightTokenizer tokenizer(sql); tokenizer.TokenizeInput(); + if (!tokenizer.tokens.empty()) { + auto back_type = tokenizer.tokens.back().type; + if (back_type == TokenType::END_OF_INPUT || back_type == TokenType::END_OF_INPUT_AUTOCOMPLETE) { + tokenizer.tokens.pop_back(); + } + } const auto &tokens = tokenizer.tokens; if (tokens.empty()) { diff --git a/src/duckdb/src/parser/peg/tokenizer/base_tokenizer.cpp b/src/duckdb/src/parser/peg/tokenizer/base_tokenizer.cpp index f60960474..ff7d7735d 100644 --- a/src/duckdb/src/parser/peg/tokenizer/base_tokenizer.cpp +++ b/src/duckdb/src/parser/peg/tokenizer/base_tokenizer.cpp @@ -204,7 +204,7 @@ bool BaseTokenizer::IsValidDollarTagCharacter(char c) { if (c == '_') { return true; } - if (c >= '\200' && c <= '\377') { + if ((unsigned char)c >= (unsigned char)'\200') { return true; } return false; @@ -221,7 +221,19 @@ bool 
BaseTokenizer::IsUnterminatedState(TokenizeState state) { } } -bool BaseTokenizer::TokenizeInput() { +bool BaseTokenizer::CanAutocomplete() const { + return !tokens.empty() && tokens.back().type == TokenType::END_OF_INPUT_AUTOCOMPLETE; +} + +void BaseTokenizer::TokenizeInput() { + if (TokenizeInputInternal()) { + tokens.emplace_back("", sql.size(), GetTerminator()); + } else { + tokens.emplace_back("", sql.size(), TokenType::END_OF_INPUT); + } +} + +bool BaseTokenizer::TokenizeInputInternal() { auto state = TokenizeState::STANDARD; idx_t last_pos = 0; string dollar_quote_marker; @@ -412,7 +424,8 @@ bool BaseTokenizer::TokenizeInput() { break; case TokenizeState::KEYWORD: // keyword - check if this is still a keyword - if (!CharacterIsKeyword(c)) { + // '$' is valid as a non-initial identifier character in PostgreSQL + if (c != '$' && !CharacterIsKeyword(c)) { // not a keyword - return to standard state auto word = sql.substr(last_pos, i - last_pos); auto token_type = keyword_helper.IsKeyword(word) ? 
TokenType::KEYWORD : TokenType::IDENTIFIER; @@ -506,12 +519,10 @@ bool BaseTokenizer::TokenizeInput() { } } - // finished processing - check the final state switch (state) { case TokenizeState::SINGLE_LINE_COMMENT: case TokenizeState::MULTI_LINE_COMMENT: PushToken(last_pos, sql.size(), TokenType::COMMENT); - // no suggestions in comments or dollar-quoted strings return false; case TokenizeState::DOLLAR_QUOTED_STRING: PushToken(last_pos, sql.size(), TokenType::STRING_LITERAL, true); diff --git a/src/duckdb/src/parser/peg/tokenizer/parser_tokenizer.cpp b/src/duckdb/src/parser/peg/tokenizer/parser_tokenizer.cpp index 7043f648c..1e9b7c472 100644 --- a/src/duckdb/src/parser/peg/tokenizer/parser_tokenizer.cpp +++ b/src/duckdb/src/parser/peg/tokenizer/parser_tokenizer.cpp @@ -1,4 +1,5 @@ #include "duckdb/parser/peg/tokenizer/parser_tokenizer.hpp" +#include "duckdb/common/exception/parser_exception.hpp" namespace duckdb { @@ -11,4 +12,16 @@ void ParserTokenizer::OnStatementEnd(idx_t pos) { tokens.emplace_back(";", pos, TokenType::TERMINATOR); } +void ParserTokenizer::OnLastToken(TokenizeState state, string last_word, idx_t last_pos) { + switch (state) { + case TokenizeState::STRING_LITERAL: + throw ParserException::SyntaxError(sql, "unterminated string literal", optional_idx(last_pos)); + case TokenizeState::QUOTED_IDENTIFIER: + throw ParserException::SyntaxError(sql, "unterminated quoted identifier", optional_idx(last_pos)); + default: + break; + } + BaseTokenizer::OnLastToken(state, std::move(last_word), last_pos); +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp b/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp index 033ff8146..b97ae11d8 100644 --- a/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp +++ b/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp @@ -24,7 +24,7 @@ void PEGTransformer::ParamTypeCheck(PreparedParamType last_type, PreparedParamTy } } -bool PEGTransformer::GetParam(const 
string &identifier, idx_t &index, PreparedParamType type) { +bool PEGTransformer::GetParam(const Identifier &identifier, idx_t &index, PreparedParamType type) { ParamTypeCheck(last_param_type, type); auto entry = named_parameter_map.find(identifier); if (entry == named_parameter_map.end()) { @@ -34,7 +34,7 @@ bool PEGTransformer::GetParam(const string &identifier, idx_t &index, PreparedPa return true; } -void PEGTransformer::SetParam(const string &identifier, idx_t index, PreparedParamType type) { +void PEGTransformer::SetParam(const Identifier &identifier, idx_t index, PreparedParamType type) { ParamTypeCheck(last_param_type, type); last_param_type = type; D_ASSERT(!named_parameter_map.count(identifier)); @@ -65,9 +65,9 @@ unique_ptr PEGTransformer::GenerateCreateEnumStmt(unique_ptr(); info->temporary = true; info->internal = false; - info->catalog = INVALID_CATALOG; - info->schema = INVALID_SCHEMA; - info->name = std::move(entry->enum_name); + info->catalog = Identifier::InvalidCatalog(); + info->schema = Identifier::InvalidSchema(); + info->name = Identifier(std::move(entry->enum_name)); info->on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; // generate the query that will result in the enum creation @@ -112,10 +112,13 @@ unique_ptr PEGTransformer::CreatePivotStatement(unique_ptrcolumn->ToString()); } - result->statements.push_back(GenerateCreateEnumStmt(std::move(pivot))); + auto enum_stmt = GenerateCreateEnumStmt(std::move(pivot)); + enum_stmt->query = enum_stmt->ToString(); + result->statements.push_back(std::move(enum_stmt)); } result->stmt_location = statement->stmt_location; result->stmt_length = statement->stmt_length; + statement->query = statement->ToString(); result->statements.push_back(std::move(statement)); return std::move(result); } @@ -149,8 +152,8 @@ bool PEGTransformer::IsWindowFrameDefault(WindowBoundary start, WindowBoundary e return start_is_default && end_is_default; } -unique_ptr PEGTransformer::GetWindowClause(const string 
&window_name) { - auto it = window_clauses.find(string(window_name)); +unique_ptr PEGTransformer::GetWindowClause(const Identifier &window_name) { + auto it = window_clauses.find(window_name); if (it == window_clauses.end()) { throw ParserException("window \"%s\" does not exist", window_name); } diff --git a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp index 35775dfd8..c3be3b8c1 100644 --- a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp +++ b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp @@ -53,166 +53,72 @@ static unique_ptr ExtractAndTransformStatement(PEGTransformer &tra return stmt; } -vector> PEGTransformerFactory::Transform(vector &tokens, ParserOptions &options, - Matcher &root_matcher) { - if (tokens.empty()) { - return {}; - } - string token_stream; - for (auto &token : tokens) { - token_stream += token.text + " "; +unique_ptr PEGTransformerFactory::TransformTopLevelStatement(vector &tokens, + ParserOptions &options, + Matcher &root_matcher, idx_t &token_cursor) { + if (token_cursor >= tokens.size()) { + return nullptr; } vector suggestions; ParseResultAllocator parse_result_allocator; - idx_t max_token_index = 0; - MatchState state(tokens, suggestions, parse_result_allocator, max_token_index, options.preserve_identifier_case); - auto &matcher = root_matcher; - auto match_result = matcher.MatchParseResult(state); - if (match_result == nullptr || state.token_index < state.tokens.size()) { + idx_t max_token_index = token_cursor; + MatchState state(tokens, suggestions, parse_result_allocator, max_token_index, options.preserve_identifier_case, + token_cursor); + auto match_result = root_matcher.MatchParseResult(state); + if (match_result == nullptr) { + // syntax error — surface as a parser exception in the same shape as Transform() + string token_stream; + for (auto &token : tokens) { + token_stream += token.text + " "; + } idx_t 
error_token_idx = state.GetMaxTokenIndex(); if (error_token_idx >= tokens.size()) { error_token_idx = tokens.size() - 1; } - idx_t stmt_start = error_token_idx; - while (stmt_start > 0 && tokens[stmt_start - 1].text != ";") { - stmt_start--; - } - idx_t stmt_end = error_token_idx; - while (stmt_end < tokens.size() && tokens[stmt_end].text != ";") { - stmt_end++; + // Walk back past the EOI sentinel so the error message names a real token. + if (error_token_idx > 0 && (tokens[error_token_idx].type == TokenType::END_OF_INPUT || + tokens[error_token_idx].type == TokenType::END_OF_INPUT_AUTOCOMPLETE)) { + error_token_idx--; } - if (stmt_start < stmt_end && (stmt_start > 0 || stmt_end < tokens.size())) { - vector statement_tokens; - statement_tokens.reserve(stmt_end - stmt_start); - for (idx_t i = stmt_start; i < stmt_end; i++) { - statement_tokens.push_back(tokens[i]); - } - Transform(statement_tokens, options, root_matcher); - } - auto &error_token = tokens[error_token_idx]; auto error_message = "syntax error at or near \"" + error_token.text + "\""; throw ParserException::SyntaxError(token_stream, error_message, error_token.offset); } - match_result->name = "Program"; - // Program <- Statement? (';'+ Statement)* ';'* - // Program[0] = optional first Statement - // Program[1] = optional repeat of groups: (';'+ Statement) - // Program[2] = optional repeat of trailing ';' - auto &prog = match_result->Cast(); + // Advance the caller's cursor past the consumed tokens. + token_cursor = state.token_index; + + // TopLevelStatement <- Statement? 
(';'+ / EndOfInput) + // child 0: Optional + // child 1: bracket-wrapper list around Choice<';'+ | EndOfInput> + auto &tls = match_result->Cast(); + auto &stmt_opt = tls.Child(0); + if (!stmt_opt.HasResult()) { + // separator-only or EOI-only TopLevelStatement — no statement to yield + return nullptr; + } + auto &term_wrapper = tls.Child(1); + auto &term_inner = term_wrapper.Child(0).GetResult(); + optional_idx terminator_offset; + if (term_inner.type != ParseResultType::END_OF_INPUT) { + auto semi_children = term_inner.Cast().GetChildren(); + if (!semi_children.empty()) { + terminator_offset = semi_children[0].get().offset; + } + } ArenaAllocator transformer_allocator(Allocator::DefaultAllocator()); PEGTransformerState transformer_state(tokens); PEGTransformer transformer(transformer_allocator, transformer_state, sql_transform_functions, parser.rules, enum_mappings, options); - vector> result; - optional_ptr current_stmt; - auto &first_stmt = prog.Child(0); - if (first_stmt.HasResult()) { - current_stmt = first_stmt.GetResult(); - } - - auto &statement_repeat = prog.Child(1); - if (statement_repeat.HasResult()) { - auto &repeat = statement_repeat.GetResult().Cast(); - for (auto &child : repeat.GetChildren()) { - auto &child_list = child.get().Cast(); - auto &separators = child_list.Child(0); - auto &next_stmt = child_list.GetChild(1); - if (current_stmt) { - auto separator_children = separators.GetChildren(); - auto &separator_terminator = separator_children[0].get(); - result.push_back( - ExtractAndTransformStatement(transformer, tokens, *current_stmt, separator_terminator.offset)); - } - current_stmt = next_stmt; - } - } - if (current_stmt) { - optional_idx trailing_terminator; - auto &trailing_repeat = prog.Child(2); - if (trailing_repeat.HasResult()) { - auto trailing_children = trailing_repeat.GetResult().Cast().GetChildren(); - if (!trailing_children.empty()) { - trailing_terminator = trailing_children[0].get().offset; - } - } - 
result.push_back(ExtractAndTransformStatement(transformer, tokens, *current_stmt, trailing_terminator)); - } - return result; + return ExtractAndTransformStatement(transformer, tokens, stmt_opt.GetResult(), terminator_offset); } #define REGISTER_TRANSFORM(FUNCTION) Register(string(#FUNCTION).substr(9), &FUNCTION) -void PEGTransformerFactory::RegisterAlter() { - // alter.gram - REGISTER_TRANSFORM(TransformAlterStatement); - REGISTER_TRANSFORM(TransformAlterOptions); - REGISTER_TRANSFORM(TransformAlterTableStmt); - REGISTER_TRANSFORM(TransformAlterViewStmt); - REGISTER_TRANSFORM(TransformAlterSchemaStmt); - REGISTER_TRANSFORM(TransformAlterDatabaseStmt); - REGISTER_TRANSFORM(TransformAlterSequenceStmt); - REGISTER_TRANSFORM(TransformAlterSequenceOptions); - REGISTER_TRANSFORM(TransformSetSequenceOption); - REGISTER_TRANSFORM(TransformAlterTableOptions); - REGISTER_TRANSFORM(TransformAddColumn); - REGISTER_TRANSFORM(TransformAddColumnEntry); - REGISTER_TRANSFORM(TransformDropColumn); - REGISTER_TRANSFORM(TransformSetPartitionedBy); - REGISTER_TRANSFORM(TransformResetPartitionedBy); - REGISTER_TRANSFORM(TransformAlterColumn); - REGISTER_TRANSFORM(TransformAlterColumnEntry); - REGISTER_TRANSFORM(TransformDropDefault); - REGISTER_TRANSFORM(TransformChangeNullability); - REGISTER_TRANSFORM(TransformAlterType); - REGISTER_TRANSFORM(TransformUsingExpression); - REGISTER_TRANSFORM(TransformDropOrSet); - REGISTER_TRANSFORM(TransformAddOrDropDefault); - REGISTER_TRANSFORM(TransformAddDefault); - REGISTER_TRANSFORM(TransformRenameColumn); - REGISTER_TRANSFORM(TransformRenameAlter); - REGISTER_TRANSFORM(TransformAddConstraint); - REGISTER_TRANSFORM(TransformQualifiedSequenceName); - REGISTER_TRANSFORM(TransformSequenceName); - REGISTER_TRANSFORM(TransformSetSortedBy); - REGISTER_TRANSFORM(TransformResetSortedBy); - REGISTER_TRANSFORM(TransformSetOptions); - REGISTER_TRANSFORM(TransformResetOptions); -} - -void PEGTransformerFactory::RegisterAnalyze() { - // analyze.gram - 
REGISTER_TRANSFORM(TransformAnalyzeStatement); - REGISTER_TRANSFORM(TransformAnalyzeTarget); -} - -void PEGTransformerFactory::RegisterAttach() { - // attach.gram - REGISTER_TRANSFORM(TransformAttachStatement); - REGISTER_TRANSFORM(TransformAttachAlias); - REGISTER_TRANSFORM(TransformAttachOptions); - REGISTER_TRANSFORM(TransformGenericCopyOptionList); - REGISTER_TRANSFORM(TransformGenericCopyOption); - REGISTER_TRANSFORM(TransformDatabasePath); -} - -void PEGTransformerFactory::RegisterCall() { - // call.gram - REGISTER_TRANSFORM(TransformCallStatement); -} - -void PEGTransformerFactory::RegisterCheckpoint() { - // checkpoint.gram - REGISTER_TRANSFORM(TransformCheckpointStatement); -} - void PEGTransformerFactory::RegisterComment() { // comment.gram - REGISTER_TRANSFORM(TransformCommentStatement); - REGISTER_TRANSFORM(TransformCommentOnType); REGISTER_TRANSFORM(TransformCommentValue); } @@ -220,259 +126,22 @@ void PEGTransformerFactory::RegisterCommon() { // common.gram REGISTER_TRANSFORM(TransformNumberLiteral); REGISTER_TRANSFORM(TransformStringLiteral); - REGISTER_TRANSFORM(TransformType); - REGISTER_TRANSFORM(TransformArrayBounds); - REGISTER_TRANSFORM(TransformSquareBracketsArray); - REGISTER_TRANSFORM(TransformTimeType); - REGISTER_TRANSFORM(TransformTimeZone); - REGISTER_TRANSFORM(TransformWithOrWithout); - REGISTER_TRANSFORM(TransformTimeOrTimestamp); - REGISTER_TRANSFORM(TransformNumericType); - REGISTER_TRANSFORM(TransformSimpleNumericType); - REGISTER_TRANSFORM(TransformDecimalNumericType); - REGISTER_TRANSFORM(TransformFloatType); - REGISTER_TRANSFORM(TransformDecimalType); - REGISTER_TRANSFORM(TransformTypeModifiers); - REGISTER_TRANSFORM(TransformSimpleType); - REGISTER_TRANSFORM(TransformQualifiedTypeName); - REGISTER_TRANSFORM(TransformCharacterType); - REGISTER_TRANSFORM(TransformMapType); - REGISTER_TRANSFORM(TransformRowType); - REGISTER_TRANSFORM(TransformGeometryType); - REGISTER_TRANSFORM(TransformVariantType); - 
REGISTER_TRANSFORM(TransformUnionType); - REGISTER_TRANSFORM(TransformColIdTypeList); - REGISTER_TRANSFORM(TransformColIdType); - REGISTER_TRANSFORM(TransformBitType); - REGISTER_TRANSFORM(TransformIntervalType); - REGISTER_TRANSFORM(TransformIntervalInterval); - REGISTER_TRANSFORM(TransformInterval); - REGISTER_TRANSFORM(TransformIntervalToInterval); - REGISTER_TRANSFORM(TransformSetofType); - Register("NumericModType", &TransformDecimalType); - Register("DecType", &TransformDecimalType); -} - -void PEGTransformerFactory::RegisterCopy() { - // copy.gram - REGISTER_TRANSFORM(TransformCopyStatement); - REGISTER_TRANSFORM(TransformCopySelect); - REGISTER_TRANSFORM(TransformCopyFromDatabase); - REGISTER_TRANSFORM(TransformCopyDatabaseFlag); - REGISTER_TRANSFORM(TransformCopyTable); - REGISTER_TRANSFORM(TransformFromOrTo); - REGISTER_TRANSFORM(TransformCopyFileName); - REGISTER_TRANSFORM(TransformCopyOptions); - REGISTER_TRANSFORM(TransformSpecializedOptionList); - REGISTER_TRANSFORM(TransformSpecializedOption); - REGISTER_TRANSFORM(TransformSingleOption); - REGISTER_TRANSFORM(TransformEncodingOption); - REGISTER_TRANSFORM(TransformForceQuoteOption); - REGISTER_TRANSFORM(TransformQuoteAsOption); - REGISTER_TRANSFORM(TransformForceNullOption); - REGISTER_TRANSFORM(TransformPartitionByOption); - REGISTER_TRANSFORM(TransformNullAsOption); - REGISTER_TRANSFORM(TransformDelimiterAsOption); - REGISTER_TRANSFORM(TransformEscapeAsOption); - REGISTER_TRANSFORM(TransformSchemaOrData); -} - -void PEGTransformerFactory::RegisterCreateIndex() { - // create_index.gram - REGISTER_TRANSFORM(TransformCreateIndexStmt); - REGISTER_TRANSFORM(TransformIndexType); - REGISTER_TRANSFORM(TransformIndexElement); - REGISTER_TRANSFORM(TransformWithList); - REGISTER_TRANSFORM(TransformRelOptionOrOids); - REGISTER_TRANSFORM(TransformRelOptionList); - REGISTER_TRANSFORM(TransformOids); - REGISTER_TRANSFORM(TransformRelOptionName); - REGISTER_TRANSFORM(TransformRelOptionArgumentOpt); - 
REGISTER_TRANSFORM(TransformRelOption); - REGISTER_TRANSFORM(TransformIndexName); + REGISTER_TRANSFORM(TransformIntervalToIntervalAsType); } void PEGTransformerFactory::RegisterCreateMacro() { // create_macro.gram - REGISTER_TRANSFORM(TransformCreateMacroStmt); - REGISTER_TRANSFORM(TransformMacroDefinition); - REGISTER_TRANSFORM(TransformTableMacroDefinition); - REGISTER_TRANSFORM(TransformScalarMacroDefinition); - REGISTER_TRANSFORM(TransformMacroParameters); - REGISTER_TRANSFORM(TransformMacroParameter); - REGISTER_TRANSFORM(TransformSimpleParameter); -} - -void PEGTransformerFactory::RegisterCreateSchema() { - REGISTER_TRANSFORM(TransformCreateSchemaStmt); -} - -void PEGTransformerFactory::RegisterCreateSecret() { - // create_secret.gram - REGISTER_TRANSFORM(TransformCreateSecretStmt); - REGISTER_TRANSFORM(TransformSecretStorageSpecifier); - REGISTER_TRANSFORM(TransformSecretName); -} - -void PEGTransformerFactory::RegisterCreateSequence() { - REGISTER_TRANSFORM(TransformCreateSequenceStmt); - REGISTER_TRANSFORM(TransformSequenceOption); - REGISTER_TRANSFORM(TransformSeqSetCycle); - REGISTER_TRANSFORM(TransformSeqSetIncrement); - REGISTER_TRANSFORM(TransformSeqSetMinMax); - REGISTER_TRANSFORM(TransformSeqMinOrMax); - REGISTER_TRANSFORM(TransformSeqNoMinMax); - REGISTER_TRANSFORM(TransformSeqStartWith); - REGISTER_TRANSFORM(TransformSeqOwnedBy); } void PEGTransformerFactory::RegisterCreateTable() { // create_table.gram - REGISTER_TRANSFORM(TransformCreateStatement); - REGISTER_TRANSFORM(TransformTemporary); - REGISTER_TRANSFORM(TransformCreateStatementVariation); - REGISTER_TRANSFORM(TransformCreateTableStmt); - REGISTER_TRANSFORM(TransformCreateTableAs); - REGISTER_TRANSFORM(TransformIdentifierList); - REGISTER_TRANSFORM(TransformCreateColumnList); - REGISTER_TRANSFORM(TransformCreateTableColumnList); - REGISTER_TRANSFORM(TransformPartitionSortedOptions); - REGISTER_TRANSFORM(TransformPartitionOptSortedOptions); - 
REGISTER_TRANSFORM(TransformSortedOptPartitionOptions); - REGISTER_TRANSFORM(TransformPartitionOptions); - REGISTER_TRANSFORM(TransformSortedOptions); - REGISTER_TRANSFORM(TransformIdentifierOrStringLiteral); - REGISTER_TRANSFORM(TransformColIdOrString); REGISTER_TRANSFORM(TransformColLabelOrString); - REGISTER_TRANSFORM(TransformColId); - REGISTER_TRANSFORM(TransformColumnIdList); - REGISTER_TRANSFORM(TransformTypeFuncName); REGISTER_TRANSFORM(TransformIdentifier); - REGISTER_TRANSFORM(TransformDottedIdentifier); - REGISTER_TRANSFORM(TransformColumnDefinition); - REGISTER_TRANSFORM(TransformTopLevelConstraint); - REGISTER_TRANSFORM(TransformTopLevelConstraintList); - REGISTER_TRANSFORM(TransformTopPrimaryKeyConstraint); - REGISTER_TRANSFORM(TransformTopUniqueConstraint); - REGISTER_TRANSFORM(TransformCheckConstraint); - REGISTER_TRANSFORM(TransformTopForeignKeyConstraint); - REGISTER_TRANSFORM(TransformForeignKeyConstraint); - REGISTER_TRANSFORM(TransformDefaultValue); - REGISTER_TRANSFORM(TransformGeneratedColumn); - REGISTER_TRANSFORM(TransformColumnCompression); - REGISTER_TRANSFORM(TransformPrimaryKeyConstraint); - REGISTER_TRANSFORM(TransformUniqueConstraint); - REGISTER_TRANSFORM(TransformNotNullConstraint); - REGISTER_TRANSFORM(TransformKeyActions); - REGISTER_TRANSFORM(TransformKeyAction); - REGISTER_TRANSFORM(TransformNoKeyAction); - REGISTER_TRANSFORM(TransformRestrictKeyAction); - REGISTER_TRANSFORM(TransformCascadeKeyAction); - REGISTER_TRANSFORM(TransformSetNullKeyAction); - REGISTER_TRANSFORM(TransformSetDefaultKeyAction); REGISTER_TRANSFORM(TransformCubeOrRollup); - - REGISTER_TRANSFORM(TransformUpdateAction); - REGISTER_TRANSFORM(TransformDeleteAction); - REGISTER_TRANSFORM(TransformColumnCollation); - REGISTER_TRANSFORM(TransformWithData); - REGISTER_TRANSFORM(TransformCommitAction); - REGISTER_TRANSFORM(TransformPreserveOrDelete); - REGISTER_TRANSFORM(TransformGeneratedColumnType); -} - -void PEGTransformerFactory::RegisterCreateType() { - // 
create_type.gram - REGISTER_TRANSFORM(TransformCreateTypeStmt); - REGISTER_TRANSFORM(TransformCreateType); - REGISTER_TRANSFORM(TransformEnumSelectType); - REGISTER_TRANSFORM(TransformEnumStringLiteralList); -} - -void PEGTransformerFactory::RegisterCreateView() { - REGISTER_TRANSFORM(TransformCreateViewStmt); -} - -void PEGTransformerFactory::RegisterCreateTrigger() { - REGISTER_TRANSFORM(TransformCreateTriggerStmt); - REGISTER_TRANSFORM(TransformForEachClause); - REGISTER_TRANSFORM(TransformTriggerName); - REGISTER_TRANSFORM(TransformTriggerTiming); - REGISTER_TRANSFORM(TransformTriggerEvent); - REGISTER_TRANSFORM(TransformTriggerEventInsert); - REGISTER_TRANSFORM(TransformTriggerEventDelete); - REGISTER_TRANSFORM(TransformTriggerEventUpdate); - REGISTER_TRANSFORM(TransformTriggerEventUpdateOf); - REGISTER_TRANSFORM(TransformTriggerColumnList); - REGISTER_TRANSFORM(TransformTriggerBody); -} - -void PEGTransformerFactory::RegisterDeallocate() { - // deallocate.gram - REGISTER_TRANSFORM(TransformDeallocateStatement); -} - -void PEGTransformerFactory::RegisterDelete() { - // delete.gram - REGISTER_TRANSFORM(TransformDeleteStatement); - REGISTER_TRANSFORM(TransformTargetOptAlias); - REGISTER_TRANSFORM(TransformDeleteUsingClause); - REGISTER_TRANSFORM(TransformTruncateStatement); -} - -void PEGTransformerFactory::RegisterDescribe() { - // describe.gram - REGISTER_TRANSFORM(TransformDescribeStatement); - REGISTER_TRANSFORM(TransformShowSelect); - REGISTER_TRANSFORM(TransformShowTables); - REGISTER_TRANSFORM(TransformShowAllTables); - REGISTER_TRANSFORM(TransformShowQualifiedName); - REGISTER_TRANSFORM(TransformShowOrDescribeOrSummarize); - REGISTER_TRANSFORM(TransformShowOrDescribe); - REGISTER_TRANSFORM(TransformSummarize); -} - -void PEGTransformerFactory::RegisterDrop() { - // drop.gram - REGISTER_TRANSFORM(TransformDropStatement); - REGISTER_TRANSFORM(TransformDropEntries); - REGISTER_TRANSFORM(TransformDropTable); - REGISTER_TRANSFORM(TransformTableOrView); - 
REGISTER_TRANSFORM(TransformDropTableFunction); - REGISTER_TRANSFORM(TransformDropFunction); - REGISTER_TRANSFORM(TransformDropSchema); - REGISTER_TRANSFORM(TransformQualifiedSchemaName); - REGISTER_TRANSFORM(TransformDropIndex); - REGISTER_TRANSFORM(TransformQualifiedIndexName); - REGISTER_TRANSFORM(TransformDropSequence); - REGISTER_TRANSFORM(TransformCollationName); - REGISTER_TRANSFORM(TransformDropCollation); - REGISTER_TRANSFORM(TransformDropType); - REGISTER_TRANSFORM(TransformDropBehavior); - REGISTER_TRANSFORM(TransformIfExists); - REGISTER_TRANSFORM(TransformDropSecret); - REGISTER_TRANSFORM(TransformDropSecretStorage); - REGISTER_TRANSFORM(TransformDropTrigger); -} - -void PEGTransformerFactory::RegisterExecute() { - // execute.gram - REGISTER_TRANSFORM(TransformExecuteStatement); -} - -void PEGTransformerFactory::RegisterExplain() { - // explain.gram - REGISTER_TRANSFORM(TransformExplainStatement); - REGISTER_TRANSFORM(TransformExplainableStatements); - REGISTER_TRANSFORM(TransformExplainOptionList); - REGISTER_TRANSFORM(TransformExplainOption); - Register("ExplainOptionName", &TransformIdentifierOrKeyword); } void PEGTransformerFactory::RegisterExpression() { // expression.gram - REGISTER_TRANSFORM(TransformExpressionStatement); - REGISTER_TRANSFORM(TransformExpressionAlias); REGISTER_TRANSFORM(TransformBaseExpression); REGISTER_TRANSFORM(TransformExpression); Register("ColumnDefaultExpr", &TransformExpression); @@ -524,11 +193,9 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformAtTimeZoneExpression); REGISTER_TRANSFORM(TransformPrefixExpression); - REGISTER_TRANSFORM(TransformNestedColumnName); REGISTER_TRANSFORM(TransformColumnReference); REGISTER_TRANSFORM(TransformCatalogReservedSchemaTableColumnName); REGISTER_TRANSFORM(TransformSchemaReservedTableColumnName); - REGISTER_TRANSFORM(TransformReservedTableQualification); REGISTER_TRANSFORM(TransformParameter); REGISTER_TRANSFORM(TransformAnonymousParameter); @@ -541,6 
+208,11 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformParensExpression); REGISTER_TRANSFORM(TransformSingleExpression); REGISTER_TRANSFORM(TransformConstantLiteral); + REGISTER_TRANSFORM(TransformFalseLiteral); + REGISTER_TRANSFORM(TransformTrueLiteral); + REGISTER_TRANSFORM(TransformNullLiteral); + REGISTER_TRANSFORM(TransformUnknownLiteral); + REGISTER_TRANSFORM(TransformPrefixOperator); REGISTER_TRANSFORM(TransformListExpression); REGISTER_TRANSFORM(TransformStructExpression); @@ -564,7 +236,7 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformStepSliceBound); REGISTER_TRANSFORM(TransformTableReservedColumnName); - REGISTER_TRANSFORM(TransformTableQualification); + REGISTER_TRANSFORM(TransformColIdDot); REGISTER_TRANSFORM(TransformStarExpression); REGISTER_TRANSFORM(TransformExcludeList); REGISTER_TRANSFORM(TransformExcludeNameList); @@ -614,9 +286,18 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformSubstringArguments); REGISTER_TRANSFORM(TransformSubstringExpressionList); REGISTER_TRANSFORM(TransformSubstringParameters); + REGISTER_TRANSFORM(TransformSubstringFromFor); + REGISTER_TRANSFORM(TransformSubstringFromOptionalFor); + REGISTER_TRANSFORM(TransformSubstringFor); REGISTER_TRANSFORM(TransformTrimExpression); REGISTER_TRANSFORM(TransformTrimDirection); REGISTER_TRANSFORM(TransformTrimSource); + REGISTER_TRANSFORM(TransformOverlayExpression); + REGISTER_TRANSFORM(TransformOverlayArguments); + REGISTER_TRANSFORM(TransformOverlayParameters); + REGISTER_TRANSFORM(TransformFromExpression); + REGISTER_TRANSFORM(TransformForExpression); + REGISTER_TRANSFORM(TransformOverlayExpressionList); REGISTER_TRANSFORM(TransformPositionExpression); REGISTER_TRANSFORM(TransformCastExpression); REGISTER_TRANSFORM(TransformCastOrTryCast); @@ -644,87 +325,17 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformIgnoreOrRespectNulls); } -void 
PEGTransformerFactory::RegisterInsert() { - // insert.gram - REGISTER_TRANSFORM(TransformInsertStatement); - REGISTER_TRANSFORM(TransformOrAction); - REGISTER_TRANSFORM(TransformInsertTarget); - REGISTER_TRANSFORM(TransformInsertAlias); - REGISTER_TRANSFORM(TransformOnConflictClause); - REGISTER_TRANSFORM(TransformOnConflictTarget); - REGISTER_TRANSFORM(TransformOnConflictExpressionTarget); - REGISTER_TRANSFORM(TransformOnConflictIndexTarget); - REGISTER_TRANSFORM(TransformOnConflictAction); - REGISTER_TRANSFORM(TransformOnConflictUpdate); - REGISTER_TRANSFORM(TransformOnConflictNothing); - REGISTER_TRANSFORM(TransformInsertValues); - REGISTER_TRANSFORM(TransformByNameOrPosition); - REGISTER_TRANSFORM(TransformInsertColumnList); - REGISTER_TRANSFORM(TransformColumnList); - REGISTER_TRANSFORM(TransformReturningClause); -} - -void PEGTransformerFactory::RegisterLoad() { - // load.gram - REGISTER_TRANSFORM(TransformLoadStatement); - REGISTER_TRANSFORM(TransformInstallStatement); - REGISTER_TRANSFORM(TransformUpdateExtensionsStatement); - REGISTER_TRANSFORM(TransformFromSource); - REGISTER_TRANSFORM(TransformVersionNumber); -} - -void PEGTransformerFactory::RegisterMergeInto() { - // merge_into.gram - REGISTER_TRANSFORM(TransformMergeIntoStatement); - REGISTER_TRANSFORM(TransformMergeIntoUsingClause); - REGISTER_TRANSFORM(TransformMergeMatch); - REGISTER_TRANSFORM(TransformMatchedClause); - REGISTER_TRANSFORM(TransformMatchedClauseAction); - REGISTER_TRANSFORM(TransformUpdateMatchClause); - REGISTER_TRANSFORM(TransformUpdateMatchInfo); - REGISTER_TRANSFORM(TransformDeleteMatchClause); - REGISTER_TRANSFORM(TransformInsertMatchClause); - REGISTER_TRANSFORM(TransformDoNothingMatchClause); - REGISTER_TRANSFORM(TransformErrorMatchClause); - REGISTER_TRANSFORM(TransformUpdateMatchSetClause); - REGISTER_TRANSFORM(TransformAndExpression); - REGISTER_TRANSFORM(TransformNotMatchedClause); - REGISTER_TRANSFORM(TransformBySourceOrTarget); - 
REGISTER_TRANSFORM(TransformInsertMatchInfo); - REGISTER_TRANSFORM(TransformInsertDefaultValues); - REGISTER_TRANSFORM(TransformInsertByNameOrPosition); - REGISTER_TRANSFORM(TransformInsertValuesList); +void PEGTransformerFactory::RegisterConnect() { + // connect.gram — both rules are hand-written; the generator skips them because of the + // optional SessionTarget sub-rule. + REGISTER_TRANSFORM(TransformConnectStatement); } void PEGTransformerFactory::RegisterPivot() { - // pivot.gram + // PivotStatement and UnpivotStatement measure parameter usage while transforming + // the source table, so their top-level wrappers remain manual. REGISTER_TRANSFORM(TransformPivotStatement); - REGISTER_TRANSFORM(TransformPivotUsing); - REGISTER_TRANSFORM(TransformPivotOn); - REGISTER_TRANSFORM(TransformPivotColumnList); - REGISTER_TRANSFORM(TransformPivotColumnEntry); - REGISTER_TRANSFORM(TransformPivotColumnSubquery); - REGISTER_TRANSFORM(TransformPivotColumnEntryInternal); REGISTER_TRANSFORM(TransformUnpivotStatement); - REGISTER_TRANSFORM(TransformIntoNameValues); - - REGISTER_TRANSFORM(TransformUnpivotHeader); - REGISTER_TRANSFORM(TransformUnpivotHeaderSingle); - REGISTER_TRANSFORM(TransformUnpivotHeaderList); - REGISTER_TRANSFORM(TransformIncludeOrExcludeNulls); -} - -void PEGTransformerFactory::RegisterPragma() { - // pragma.gram - REGISTER_TRANSFORM(TransformPragmaStatement); - REGISTER_TRANSFORM(TransformPragmaAssign); - REGISTER_TRANSFORM(TransformPragmaFunction); - REGISTER_TRANSFORM(TransformPragmaParameters); -} - -void PEGTransformerFactory::RegisterPrepare() { - // prepare.gram - REGISTER_TRANSFORM(TransformPrepareStatement); } void PEGTransformerFactory::RegisterSelect() { @@ -756,12 +367,6 @@ void PEGTransformerFactory::RegisterSelect() { REGISTER_TRANSFORM(TransformBaseTableName); REGISTER_TRANSFORM(TransformSchemaReservedTable); REGISTER_TRANSFORM(TransformCatalogReservedSchemaTable); - REGISTER_TRANSFORM(TransformSchemaQualification); - 
REGISTER_TRANSFORM(TransformCatalogQualification); - REGISTER_TRANSFORM(TransformQualifiedName); - REGISTER_TRANSFORM(TransformCatalogReservedSchemaIdentifier); - REGISTER_TRANSFORM(TransformSchemaReservedIdentifierOrStringLiteral); - REGISTER_TRANSFORM(TransformReservedIdentifierOrStringLiteral); REGISTER_TRANSFORM(TransformWhereClause); REGISTER_TRANSFORM(TransformTableFunctionArguments); @@ -868,48 +473,6 @@ void PEGTransformerFactory::RegisterSelect() { REGISTER_TRANSFORM(TransformRepeatableSample); } -void PEGTransformerFactory::RegisterSet() { - // set.gram - REGISTER_TRANSFORM(TransformResetStatement); - REGISTER_TRANSFORM(TransformSetAssignment); - REGISTER_TRANSFORM(TransformSetSetting); - REGISTER_TRANSFORM(TransformSetStatement); - REGISTER_TRANSFORM(TransformSetTimeZone); - REGISTER_TRANSFORM(TransformSetVariable); - REGISTER_TRANSFORM(TransformStandardAssignment); - REGISTER_TRANSFORM(TransformVariableList); - REGISTER_TRANSFORM(TransformZoneValue); - REGISTER_TRANSFORM(TransformZoneIntervalWithInterval); - REGISTER_TRANSFORM(TransformZoneIntervalWithPrecision); -} - -void PEGTransformerFactory::RegisterTransaction() { - REGISTER_TRANSFORM(TransformReadOnlyOrReadWrite); -} - -void PEGTransformerFactory::RegisterUpdate() { - // update.gram - REGISTER_TRANSFORM(TransformUpdateStatement); - REGISTER_TRANSFORM(TransformUpdateTarget); - REGISTER_TRANSFORM(TransformBaseTableSet); - REGISTER_TRANSFORM(TransformBaseTableAliasSet); - REGISTER_TRANSFORM(TransformUpdateAlias); - REGISTER_TRANSFORM(TransformUpdateSetClause); - REGISTER_TRANSFORM(TransformUpdateSetTuple); - REGISTER_TRANSFORM(TransformUpdateSetElementList); - REGISTER_TRANSFORM(TransformUpdateSetElement); - REGISTER_TRANSFORM(TransformUpdateSetColumnTarget); -} - -void PEGTransformerFactory::RegisterVacuum() { - REGISTER_TRANSFORM(TransformVacuumStatement); - REGISTER_TRANSFORM(TransformVacuumOptions); - REGISTER_TRANSFORM(TransformVacuumLegacyOptions); - 
REGISTER_TRANSFORM(TransformVacuumParensOptions); - REGISTER_TRANSFORM(TransformVacuumOption); - REGISTER_TRANSFORM(TransformNameList); -} - void PEGTransformerFactory::RegisterKeywordsAndIdentifiers() { Register("PragmaName", &TransformIdentifierOrKeyword); Register("TypeName", &TransformIdentifierOrKeyword); @@ -922,90 +485,16 @@ void PEGTransformerFactory::RegisterKeywordsAndIdentifiers() { Register("FuncNameKeyword", &TransformIdentifierOrKeyword); Register("TypeNameKeyword", &TransformIdentifierOrKeyword); Register("SettingName", &TransformIdentifierOrKeyword); - Register("ReservedSchemaQualification", &TransformSchemaQualification); + Register("ExplainOptionName", &TransformIdentifierOrKeyword); } void PEGTransformerFactory::RegisterEnums() { - RegisterEnum("LocalScope", SetScope::LOCAL); - RegisterEnum("GlobalScope", SetScope::GLOBAL); - RegisterEnum("SessionScope", SetScope::SESSION); - RegisterEnum("VariableScope", SetScope::VARIABLE); - - RegisterEnum("FalseLiteral", Value(false)); - RegisterEnum("TrueLiteral", Value(true)); - RegisterEnum("NullLiteral", Value()); - RegisterEnum("UnknownLiteral", Value()); - - RegisterEnum("ReadOnly", TransactionModifierType::TRANSACTION_READ_ONLY); - RegisterEnum("ReadWrite", TransactionModifierType::TRANSACTION_READ_WRITE); - - RegisterEnum("CopySchema", CopyDatabaseType::COPY_SCHEMA); - RegisterEnum("CopyData", CopyDatabaseType::COPY_DATA); - - RegisterEnum("IntType", LogicalTypeIdToString(LogicalTypeId::INTEGER)); - RegisterEnum("IntegerType", LogicalTypeIdToString(LogicalTypeId::INTEGER)); - RegisterEnum("SmallintType", LogicalTypeIdToString(LogicalTypeId::SMALLINT)); - RegisterEnum("BigintType", LogicalTypeIdToString(LogicalTypeId::BIGINT)); - RegisterEnum("RealType", LogicalTypeIdToString(LogicalTypeId::FLOAT)); - RegisterEnum("DoubleType", LogicalTypeIdToString(LogicalTypeId::DOUBLE)); - RegisterEnum("BooleanType", LogicalTypeIdToString(LogicalTypeId::BOOLEAN)); - - RegisterEnum("YearKeyword", 
DatePartSpecifier::YEAR); - RegisterEnum("MonthKeyword", DatePartSpecifier::MONTH); - RegisterEnum("DayKeyword", DatePartSpecifier::DAY); - RegisterEnum("HourKeyword", DatePartSpecifier::HOUR); - RegisterEnum("MinuteKeyword", DatePartSpecifier::MINUTE); - RegisterEnum("SecondKeyword", DatePartSpecifier::SECOND); - RegisterEnum("MillisecondKeyword", DatePartSpecifier::MILLISECONDS); - RegisterEnum("MicrosecondKeyword", DatePartSpecifier::MICROSECONDS); - RegisterEnum("WeekKeyword", DatePartSpecifier::WEEK); - RegisterEnum("QuarterKeyword", DatePartSpecifier::QUARTER); - RegisterEnum("DecadeKeyword", DatePartSpecifier::DECADE); - RegisterEnum("CenturyKeyword", DatePartSpecifier::CENTURY); - RegisterEnum("MillenniumKeyword", DatePartSpecifier::MILLENNIUM); - - RegisterEnum("TimeTypeId", LogicalTypeId::TIME); - RegisterEnum("TimestampTypeId", LogicalTypeId::TIMESTAMP); - RegisterEnum("WithRule", true); - RegisterEnum("WithoutRule", false); - - RegisterEnum("TempPersistent", SecretPersistType::TEMPORARY); - RegisterEnum("TemporaryPersistent", SecretPersistType::TEMPORARY); - RegisterEnum("Persistent", SecretPersistType::PERSISTENT); - - RegisterEnum("CommentTable", CatalogType::TABLE_ENTRY); - RegisterEnum("CommentSequence", CatalogType::SEQUENCE_ENTRY); - RegisterEnum("CommentFunction", CatalogType::MACRO_ENTRY); - RegisterEnum("CommentMacroTable", CatalogType::TABLE_MACRO_ENTRY); - RegisterEnum("CommentMacro", CatalogType::MACRO_ENTRY); - RegisterEnum("CommentView", CatalogType::VIEW_ENTRY); RegisterEnum("MaterializedViewEntry", CatalogType::VIEW_ENTRY); - RegisterEnum("CommentDatabase", CatalogType::DATABASE_ENTRY); - RegisterEnum("CommentIndex", CatalogType::INDEX_ENTRY); - RegisterEnum("CommentSchema", CatalogType::SCHEMA_ENTRY); - RegisterEnum("CommentType", CatalogType::TYPE_ENTRY); - RegisterEnum("CommentColumn", CatalogType::INVALID); - - RegisterEnum("TriggerBefore", TriggerTiming::BEFORE); - RegisterEnum("TriggerAfter", TriggerTiming::AFTER); - 
RegisterEnum("TriggerInsteadOf", TriggerTiming::INSTEAD_OF); - RegisterEnum("ForEachRow", TriggerForEach::ROW); - RegisterEnum("ForEachStatement", TriggerForEach::STATEMENT); - - RegisterEnum("MinValue", "minvalue"); - RegisterEnum("MaxValue", "maxvalue"); RegisterEnum("MinusPrefixOperator", "-"); RegisterEnum("PlusPrefixOperator", "+"); RegisterEnum("TildePrefixOperator", "~"); - RegisterEnum("SummarizeRule", ShowType::SUMMARY); - RegisterEnum("ShowRule", ShowType::DESCRIBE); - RegisterEnum("DescribeRule", ShowType::DESCRIBE); - - RegisterEnum("InsertByName", InsertColumnOrder::INSERT_BY_NAME); - RegisterEnum("InsertByPosition", InsertColumnOrder::INSERT_BY_POSITION); - RegisterEnum("DescendingOrder", OrderType::DESCENDING); RegisterEnum("AscendingOrder", OrderType::ASCENDING); RegisterEnum("NullsFirst", OrderByNullType::NULLS_FIRST); @@ -1040,25 +529,11 @@ void PEGTransformerFactory::RegisterEnums() { RegisterEnum("NotLikeOp", "!~~"); RegisterEnum("NotSimilarToOp", "!~"); - RegisterEnum("OptAnalyze", "analyze"); - RegisterEnum("OptFreeze", "freeze"); - RegisterEnum("OptFull", "full"); - RegisterEnum("OptVerbose", "verbose"); - - RegisterEnum("BySource", MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE); - RegisterEnum("ByTarget", MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET); - RegisterEnum("ExcludeCurrentRow", WindowExcludeMode::CURRENT_ROW); RegisterEnum("ExcludeGroup", WindowExcludeMode::GROUP); RegisterEnum("ExcludeTies", WindowExcludeMode::TIES); RegisterEnum("ExcludeNoOthers", WindowExcludeMode::NO_OTHER); - RegisterEnum("BinaryOption", GenericCopyOption("format", Value("binary"))); - RegisterEnum("FreezeOption", GenericCopyOption("freeze", Value())); - RegisterEnum("OidsOption", GenericCopyOption("oids", Value())); - RegisterEnum("CsvOption", GenericCopyOption("format", Value("csv"))); - RegisterEnum("HeaderOption", GenericCopyOption("header", Value(true))); - RegisterEnum("SamplePercentage", true); RegisterEnum("SampleRows", false); @@ -1075,41 
+550,14 @@ void PEGTransformerFactory::RegisterEnums() { PEGTransformerFactory::PEGTransformerFactory() { RegisterGenerated(); REGISTER_TRANSFORM(TransformStatement); - RegisterAlter(); - RegisterAttach(); - RegisterAnalyze(); - RegisterCall(); - RegisterCheckpoint(); RegisterComment(); RegisterCommon(); - RegisterCopy(); - RegisterCreateIndex(); RegisterCreateMacro(); - RegisterCreateSchema(); - RegisterCreateSequence(); - RegisterCreateSecret(); RegisterCreateTable(); - RegisterCreateType(); - RegisterCreateView(); - RegisterCreateTrigger(); - RegisterDeallocate(); - RegisterDelete(); - RegisterDescribe(); - RegisterDrop(); - RegisterExecute(); - RegisterExplain(); RegisterExpression(); - RegisterInsert(); - RegisterLoad(); - RegisterMergeInto(); + RegisterConnect(); RegisterPivot(); - RegisterPragma(); - RegisterPrepare(); RegisterSelect(); - RegisterSet(); - RegisterTransaction(); - RegisterUpdate(); - RegisterVacuum(); RegisterKeywordsAndIdentifiers(); RegisterEnums(); } @@ -1136,12 +584,12 @@ ParseResult &PEGTransformerFactory::ExtractResultFromParens(ParseResult &parse_r return list_pr.GetChild(1); } -bool PEGTransformerFactory::ExpressionIsEmptyStar(ParsedExpression &expr) { +bool PEGTransformerFactory::ExpressionIsEmptyStar(const ParsedExpression &expr) { if (expr.GetExpressionClass() != ExpressionClass::STAR) { return false; } auto &star = expr.Cast(); - if (!star.columns && star.exclude_list.empty() && star.replace_list.empty()) { + if (!star.IsColumns() && star.ExcludeList().empty() && star.ReplaceList().empty()) { return true; } return false; @@ -1153,17 +601,17 @@ QualifiedName PEGTransformerFactory::StringToQualifiedName(vector input) throw InternalException("QualifiedName cannot be made with an empty input."); } if (input.size() == 1) { - result.catalog = INVALID_CATALOG; - result.schema = INVALID_SCHEMA; - result.name = input[0]; + result.catalog = Identifier::InvalidCatalog(); + result.schema = Identifier::InvalidSchema(); + result.name = 
Identifier(input[0]); } else if (input.size() == 2) { - result.catalog = INVALID_CATALOG; - result.schema = input[0]; - result.name = input[1]; + result.catalog = Identifier::InvalidCatalog(); + result.schema = Identifier(input[0]); + result.name = Identifier(input[1]); } else if (input.size() == 3) { - result.catalog = input[0]; - result.schema = input[1]; - result.name = input[2]; + result.catalog = Identifier(input[0]); + result.schema = Identifier(input[1]); + result.name = Identifier(input[2]); } else { throw ParserException("Too many qualifications found - expected [catalog.schema.name] or [schema.name]"); } @@ -1198,28 +646,29 @@ bool PEGTransformerFactory::ConstructConstantFromExpression(const ParsedExpressi switch (expr.GetExpressionType()) { case ExpressionType::FUNCTION: { auto &function = expr.Cast(); - if (function.function_name == "struct_pack") { - unordered_set unique_names; + if (function.FunctionName() == "struct_pack") { + identifier_set_t unique_names; child_list_t values; - values.reserve(function.children.size()); - for (const auto &child : function.children) { - if (!unique_names.insert(child->GetAlias()).second) { - throw BinderException("Duplicate struct entry name \"%s\"", child->GetAlias()); + values.reserve(function.GetArguments().size()); + for (const auto &child : function.GetArguments()) { + if (!unique_names.insert(child.GetExpression().GetAlias()).second) { + throw BinderException("Duplicate struct entry name \"%s\"", + child.GetExpression().GetAlias().GetIdentifierName()); } Value child_value; - if (!ConstructConstantFromExpression(*child, child_value)) { + if (!ConstructConstantFromExpression(child.GetExpression(), child_value)) { return false; } - values.emplace_back(child->GetAlias(), std::move(child_value)); + values.emplace_back(child.GetExpression().GetAlias(), std::move(child_value)); } value = Value::STRUCT(std::move(values)); return true; - } else if (function.function_name == "list_value") { + } else if 
(function.FunctionName() == "list_value") { vector values; - values.reserve(function.children.size()); - for (const auto &child : function.children) { + values.reserve(function.GetArguments().size()); + for (const auto &child : function.GetArguments()) { Value child_value; - if (!ConstructConstantFromExpression(*child, child_value)) { + if (!ConstructConstantFromExpression(child.GetExpression(), child_value)) { return false; } values.emplace_back(std::move(child_value)); @@ -1228,20 +677,20 @@ bool PEGTransformerFactory::ConstructConstantFromExpression(const ParsedExpressi // figure out child type LogicalType child_type(LogicalTypeId::SQLNULL); for (auto &child_value : values) { - child_type = LogicalType::ForceMaxLogicalType(child_type, child_value.type()); + child_type = LogicalType::DefaultForceMaxLogicalType(child_type, child_value.type()); } // finally create the list value = Value::LIST(child_type, values); return true; - } else if (function.function_name == "map") { + } else if (function.FunctionName() == "map") { Value keys; - if (!ConstructConstantFromExpression(*function.children[0], keys)) { + if (!ConstructConstantFromExpression(function.GetArguments()[0].GetExpression(), keys)) { return false; } Value values; - if (!ConstructConstantFromExpression(*function.children[1], values)) { + if (!ConstructConstantFromExpression(function.GetArguments()[1].GetExpression(), values)) { return false; } @@ -1263,14 +712,14 @@ bool PEGTransformerFactory::ConstructConstantFromExpression(const ParsedExpressi case ExpressionType::OPERATOR_CAST: { auto &cast = expr.Cast(); Value dummy_value; - if (!ConstructConstantFromExpression(*cast.child, dummy_value)) { + if (!ConstructConstantFromExpression(cast.Child(), dummy_value)) { return false; } string error_message; - if (!dummy_value.DefaultTryCastAs(cast.cast_type, value, &error_message)) { + if (!dummy_value.DefaultTryCastAs(cast.TargetType(), value, &error_message)) { throw ConversionException("Unable to cast %s to %s", 
dummy_value.ToString(), - EnumUtil::ToString(cast.cast_type.id())); + EnumUtil::ToString(cast.TargetType().id())); } return true; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_alter.cpp b/src/duckdb/src/parser/peg/transformer/transform_alter.cpp index 83ac71ead..837308e34 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_alter.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_alter.cpp @@ -1,4 +1,5 @@ #include "duckdb/parser/peg/ast/add_column_entry.hpp" +#include "duckdb/parser/peg/ast/column_constraint_entry.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" #include "duckdb/parser/statement/alter_statement.hpp" #include "duckdb/parser/parsed_data/alter_info.hpp" @@ -12,10 +13,9 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformAlterStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr alter_options) { auto result = make_uniq(); - result->info = transformer.Transform>(list_pr.Child(1)); + result->info = std::move(alter_options); if (result->info->type != AlterType::ALTER_TABLE) { return std::move(result); } @@ -36,118 +36,103 @@ unique_ptr PEGTransformerFactory::TransformAlterStatement(PEGTrans TransformAndMaterializeAlter(alter_entry_data, make_uniq(add_column.GetAlterEntryData(), std::move(null_column), add_column.if_column_not_exists), - column_entry.GetName(), column_entry.DefaultValue().Copy()))); + column_entry.GetName().GetIdentifierName(), column_entry.DefaultValue().Copy()))); } -unique_ptr PEGTransformerFactory::TransformAlterOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - -unique_ptr PEGTransformerFactory::TransformAlterTableStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_exists = list_pr.Child(1).HasResult(); - auto table = 
transformer.Transform>(list_pr.Child(2)); - auto alter_option_list = ExtractParseResultsFromList(list_pr.GetChild(3)); - if (alter_option_list.size() > 1) { +unique_ptr +PEGTransformerFactory::TransformAlterTableStmt(PEGTransformer &transformer, const bool &if_exists, + unique_ptr base_table_name, + vector> alter_table_options) { + if (alter_table_options.size() > 1) { throw ParserException("Only one ALTER command per statement is supported"); } - auto result = transformer.Transform>(alter_option_list[0]); + auto result = std::move(alter_table_options[0]); result->if_not_found = if_exists ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; - result->catalog = table->catalog_name; - result->schema = table->schema_name; - result->name = table->table_name; + result->catalog = base_table_name->catalog_name; + result->schema = base_table_name->schema_name; + result->name = base_table_name->table_name; return std::move(result); } unique_ptr PEGTransformerFactory::TransformAlterDatabaseStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_exists = list_pr.Child(1).HasResult(); + const bool &if_exists, + const Identifier &identifier, + const Identifier &identifier_1) { OnEntryNotFound not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; - - auto catalog_name = list_pr.Child(2).identifier; - auto new_name = list_pr.Child(6).identifier; + auto catalog_name = identifier; + auto new_name = identifier_1; auto result = make_uniq(catalog_name, new_name, not_found); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformAlterViewStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_exists = list_pr.Child(1).HasResult(); - auto base_table = transformer.Transform>(list_pr.Child(2)); - auto alter_table_info = transformer.Transform>(list_pr.Child(3)); - auto rename_table = unique_ptr_cast(std::move(alter_table_info)); +unique_ptr PEGTransformerFactory::TransformAlterViewStmt(PEGTransformer &transformer, const bool &if_exists, + unique_ptr base_table_name, + unique_ptr rename_alter) { + auto rename_table = unique_ptr_cast(std::move(rename_alter)); auto result = make_uniq(AlterEntryData(), rename_table->new_table_name); - result->catalog = base_table->catalog_name; - result->schema = base_table->schema_name; - result->name = base_table->table_name; + result->catalog = base_table_name->catalog_name; + result->schema = base_table_name->schema_name; + result->name = base_table_name->table_name; result->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; return std::move(result); } unique_ptr PEGTransformerFactory::TransformAlterSchemaStmt(PEGTransformer &transformer, - ParseResult &parse_result) { + const bool &if_exists, + const QualifiedName &qualified_name, + unique_ptr rename_alter) { throw NotImplementedException("Altering schemas is not yet supported"); } unique_ptr PEGTransformerFactory::TransformAlterSequenceStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_exists = list_pr.Child(1).HasResult(); - auto sequence_name = transformer.Transform(list_pr.Child(2)); - auto alter_info = transformer.Transform>(list_pr.Child(3)); - if (sequence_name.schema.empty()) { - alter_info->schema = sequence_name.catalog; + const bool &if_exists, + const QualifiedName &qualified_sequence_name, + unique_ptr alter_sequence_options) { + if (qualified_sequence_name.schema.empty()) { + alter_sequence_options->schema = qualified_sequence_name.catalog; } else { - alter_info->catalog = sequence_name.catalog; - alter_info->schema = sequence_name.schema; + alter_sequence_options->catalog = qualified_sequence_name.catalog; + alter_sequence_options->schema = qualified_sequence_name.schema; } - alter_info->name = sequence_name.name; - alter_info->if_not_found = if_exists ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; - return alter_info; + alter_sequence_options->name = qualified_sequence_name.name; + alter_sequence_options->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; + return alter_sequence_options; } QualifiedName PEGTransformerFactory::TransformQualifiedSequenceName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const Identifier &catalog_qualification, + const Identifier &schema_qualification, + const Identifier &sequence_name) { QualifiedName result; - result.catalog = INVALID_CATALOG; - result.schema = INVALID_SCHEMA; - transformer.TransformOptional(list_pr, 0, result.catalog); - transformer.TransformOptional(list_pr, 1, result.schema); - result.name = list_pr.Child(2).identifier; + result.catalog = catalog_qualification.empty() ? INVALID_CATALOG : catalog_qualification; + result.schema = schema_qualification.empty() ? INVALID_SCHEMA : schema_qualification; + result.name = sequence_name; return result; } unique_ptr PEGTransformerFactory::TransformAlterSequenceOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - if (list_pr.Child(0).GetResult().name == "RenameAlter") { - return transformer.Transform>(list_pr.Child(0).GetResult()); + ParseResult &choice_result) { + if (choice_result.name == "RenameAlter") { + return transformer.Transform>(choice_result); } - return transformer.Transform>(list_pr.Child(0).GetResult()); + return transformer.Transform>(choice_result); } -unique_ptr PEGTransformerFactory::TransformSetSequenceOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &repeat_pr = list_pr.Child(0); +unique_ptr +PEGTransformerFactory::TransformSetSequenceOption(PEGTransformer &transformer, + vector>> sequence_option) { bool has_owned = false; unique_ptr owned_info; - for (auto &seq_option_pr : repeat_pr.GetChildren()) { - auto seq_option = transformer.Transform>>(seq_option_pr); + for (auto &seq_option : sequence_option) { if (seq_option.first == "owned") { if (has_owned) { throw 
ParserException("Owned by value should be passed at most once"); } has_owned = true; auto owned_by = unique_ptr_cast(std::move(seq_option.second)); - auto schema = owned_by->qualified_name.schema.empty() ? DEFAULT_SCHEMA : owned_by->qualified_name.schema; + auto schema = + owned_by->qualified_name.schema.empty() ? Identifier::DefaultSchema() : owned_by->qualified_name.schema; owned_info = make_uniq(CatalogType::SEQUENCE_ENTRY, "", "", "", schema, owned_by->qualified_name.name, OnEntryNotFound::THROW_EXCEPTION); @@ -159,12 +144,6 @@ unique_ptr PEGTransformerFactory::TransformSetSequenceOption(PEGTrans throw NotImplementedException("ALTER SEQUENCE option not yet supported"); } -unique_ptr PEGTransformerFactory::TransformAlterTableOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - void PEGTransformerFactory::AddToMultiStatement(const unique_ptr &multi_statement, unique_ptr alter_info) { auto alter_statement = make_uniq(); @@ -187,7 +166,7 @@ void PEGTransformerFactory::AddUpdateToMultiStatement(const unique_ptr(); - set_info->columns.push_back(column_name); + set_info->columns.emplace_back(column_name); set_info->expressions.push_back(original_expression->Copy()); node.set_info = std::move(set_info); @@ -220,101 +199,94 @@ unique_ptr PEGTransformerFactory::TransformAndMaterializeAlter( // 3. `ALTER TABLE t ALTER u SET DEFAULT ;` // Reinstate the original default expression. 
- AddToMultiStatement(multi_statement, make_uniq(data, column_name, std::move(expression))); + AddToMultiStatement(multi_statement, + make_uniq(data, Identifier(column_name), std::move(expression))); return multi_statement; } unique_ptr PEGTransformerFactory::TransformAddColumn(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - AddColumnEntry new_column = transformer.Transform(list_pr.Child(3)); - bool if_not_exists = list_pr.Child(2).HasResult(); - auto column_definition = ColumnDefinition(new_column.column_path.back(), new_column.type); - if (new_column.default_value) { - column_definition.SetDefaultValue(std::move(new_column.default_value)); + const bool &if_not_exists, + AddColumnEntry add_column_entry) { + auto column_definition = ColumnDefinition(add_column_entry.column_path.back(), add_column_entry.type); + if (add_column_entry.default_value) { + column_definition.SetDefaultValue(std::move(add_column_entry.default_value)); } unique_ptr result; - if (new_column.column_path.size() == 1) { + if (add_column_entry.column_path.size() == 1) { result = make_uniq(AlterEntryData(), std::move(column_definition), if_not_exists); } else { - const auto parent_path = vector(new_column.column_path.begin(), new_column.column_path.end() - 1); + const auto parent_path = + vector(add_column_entry.column_path.begin(), add_column_entry.column_path.end() - 1); result = make_uniq(AlterEntryData(), parent_path, std::move(column_definition), if_not_exists); } return result; } -AddColumnEntry PEGTransformerFactory::TransformAddColumnEntry(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +AddColumnEntry PEGTransformerFactory::TransformAddColumnEntry(PEGTransformer &transformer, + const vector &dotted_identifier, + const LogicalType &type, + GeneratedColumnDefinition generated_column, + vector column_constraint) { AddColumnEntry new_column; - new_column.column_path = 
transformer.Transform>(list_pr.Child(0)); - bool has_generated = list_pr.Child(2).HasResult(); - bool has_type = list_pr.Child(1).HasResult(); + new_column.column_path = StringsToIdentifiers(dotted_identifier); + bool has_type = type != LogicalType::INVALID; + bool has_generated = generated_column.expr != nullptr; + // TODO(Dtenwolde) this checking logic should be moved to the binder if (!has_type && !has_generated) { throw ParserException("Column definition requires a type or generated expression"); } if (has_generated) { throw ParserException("Adding generated columns after table creation is not supported yet"); } - transformer.TransformOptional(list_pr, 1, new_column.type); - auto &constraints_opt = list_pr.Child(3); - if (!constraints_opt.HasResult()) { - return new_column; - } - auto &constraints_repeat = constraints_opt.GetResult().Cast(); - for (auto &constraint_entry : constraints_repeat.GetChildren()) { - auto &constraint_list = constraint_entry.get().Cast(); - auto &constraint = constraint_list.Child(0).GetResult(); - if (constraint.name == "DefaultValue") { + new_column.type = type; + for (auto &constraint : column_constraint) { + if (constraint.constraint_name == "DefaultValue") { if (new_column.default_value) { throw ParserException("Cannot define a default value twice"); } - new_column.default_value = transformer.Transform>(constraint); - } else { - throw ParserException("Adding columns with constraints not yet supported"); + new_column.default_value = std::move(constraint.expression); } } return new_column; } -unique_ptr PEGTransformerFactory::TransformDropColumn(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - bool cascade = false; - transformer.TransformOptional(list_pr, 4, cascade); - bool if_exists = list_pr.Child(2).HasResult(); - auto nested_column = transformer.Transform>(list_pr.Child(3)); - if (nested_column->column_names.size() == 1) { - auto result = make_uniq(AlterEntryData(), 
nested_column->column_names[0], if_exists, cascade); +unique_ptr +PEGTransformerFactory::TransformDropColumn(PEGTransformer &transformer, const bool &if_exists, + unique_ptr nested_column_name, + const bool &drop_behavior) { + if (nested_column_name->ColumnNames().size() == 1) { + auto result = make_uniq( + AlterEntryData(), nested_column_name->ColumnNames()[0].GetIdentifierName(), if_exists, drop_behavior); return std::move(result); } - auto result = make_uniq(AlterEntryData(), nested_column->column_names, if_exists, cascade); + auto result = + make_uniq(AlterEntryData(), nested_column_name->ColumnNames(), if_exists, drop_behavior); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformAlterColumn(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto nested_column_name = transformer.Transform>(list_pr.Child(2)); - auto alter_column_entry = transformer.Transform>(list_pr.Child(3)); +unique_ptr +PEGTransformerFactory::TransformAlterColumn(PEGTransformer &transformer, + unique_ptr nested_column_name, + unique_ptr alter_column_entry) { if (alter_column_entry->alter_table_type == AlterTableType::SET_DEFAULT) { auto set_default_entry = unique_ptr_cast(std::move(alter_column_entry)); // TODO(Dtenwolde) Figure out with nested names; - set_default_entry->column_name = nested_column_name->column_names[0]; + set_default_entry->column_name = nested_column_name->ColumnNames()[0]; return std::move(set_default_entry); } else if (alter_column_entry->alter_table_type == AlterTableType::DROP_NOT_NULL) { auto drop_not_null = unique_ptr_cast(std::move(alter_column_entry)); - drop_not_null->column_name = nested_column_name->column_names[0]; + drop_not_null->column_name = nested_column_name->ColumnNames()[0]; return std::move(drop_not_null); } else if (alter_column_entry->alter_table_type == AlterTableType::SET_NOT_NULL) { auto set_not_null = unique_ptr_cast(std::move(alter_column_entry)); - set_not_null->column_name 
= nested_column_name->column_names[0]; + set_not_null->column_name = nested_column_name->ColumnNames()[0]; return std::move(set_not_null); } else if (alter_column_entry->alter_table_type == AlterTableType::ALTER_COLUMN_TYPE) { auto change_column_type = unique_ptr_cast(std::move(alter_column_entry)); - change_column_type->column_name = nested_column_name->column_names[0]; + change_column_type->column_name = nested_column_name->ColumnNames()[0]; if (!change_column_type->expression) { change_column_type->expression = make_uniq(change_column_type->target_type, std::move(nested_column_name)); @@ -325,21 +297,12 @@ unique_ptr PEGTransformerFactory::TransformAlterColumn(PEGTransf } } -unique_ptr PEGTransformerFactory::TransformAlterColumnEntry(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - -unique_ptr PEGTransformerFactory::TransformDropDefault(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformDropDefault(PEGTransformer &transformer) { return make_uniq(AlterEntryData(), "", nullptr); } unique_ptr PEGTransformerFactory::TransformChangeNullability(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - string drop_or_set = transformer.Transform(list_pr.Child(0)); + const string &drop_or_set) { if (StringUtil::CIEquals(drop_or_set, "drop")) { return make_uniq(AlterEntryData(), ""); } else { @@ -348,132 +311,82 @@ unique_ptr PEGTransformerFactory::TransformChangeNullability(PEG } unique_ptr PEGTransformerFactory::TransformAlterType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - LogicalType target_type = LogicalType::UNKNOWN; - transformer.TransformOptional(list_pr, 2, target_type); - auto &using_expr_opt = list_pr.Child(3); - if (target_type == LogicalType::UNKNOWN && !using_expr_opt.HasResult()) { + const 
LogicalType &type, + unique_ptr using_expression) { + if (type == LogicalType::INVALID && !using_expression) { throw ParserException("Omitting the type is only possible in combination with USING"); } - unique_ptr expr; - transformer.TransformOptional>(list_pr, 3, expr); - return make_uniq(AlterEntryData(), "", target_type, std::move(expr)); + auto alter_type = type == LogicalType::INVALID ? LogicalType::UNKNOWN : type; + return make_uniq(AlterEntryData(), "", alter_type, std::move(using_expression)); } unique_ptr PEGTransformerFactory::TransformUsingExpression(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(1)); + unique_ptr expression) { + return expression; } -string PEGTransformerFactory::TransformDropOrSet(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice = list_pr.Child(0); - return choice.GetResult().Cast().keyword; -} - -unique_ptr PEGTransformerFactory::TransformAddOrDropDefault(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); +unique_ptr PEGTransformerFactory::TransformAddDefault(PEGTransformer &transformer, + unique_ptr expression) { + return make_uniq(AlterEntryData(), "", std::move(expression)); } -unique_ptr PEGTransformerFactory::TransformAddDefault(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr = transformer.Transform>(list_pr.Child(2)); - return make_uniq(AlterEntryData(), "", std::move(expr)); -} - -unique_ptr PEGTransformerFactory::TransformRenameColumn(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto nested_column = transformer.Transform>(list_pr.Child(2)); - auto new_column_name = list_pr.Child(4).identifier; - if (nested_column->column_names.size() == 1) { - 
auto result = make_uniq(AlterEntryData(), nested_column->column_names[0], new_column_name); +unique_ptr PEGTransformerFactory::TransformRenameColumn( + PEGTransformer &transformer, unique_ptr nested_column_name, const Identifier &identifier) { + if (nested_column_name->ColumnNames().size() == 1) { + auto result = make_uniq(AlterEntryData(), nested_column_name->ColumnNames()[0], identifier); return std::move(result); } - auto result = make_uniq(AlterEntryData(), nested_column->column_names, new_column_name); + auto result = make_uniq(AlterEntryData(), nested_column_name->ColumnNames(), identifier); return std::move(result); } unique_ptr PEGTransformerFactory::TransformRenameAlter(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto new_table_name = list_pr.Child(2).identifier; - auto result = make_uniq(AlterEntryData(), new_table_name); - return std::move(result); + const Identifier &identifier) { + return make_uniq(AlterEntryData(), identifier); } -unique_ptr PEGTransformerFactory::TransformSetPartitionedBy(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(3)); - auto expr_list = ExtractParseResultsFromList(extract_parens); - vector> partition_keys; - for (auto expr : expr_list) { - partition_keys.push_back(transformer.Transform>(expr)); - } - return make_uniq(AlterEntryData(), std::move(partition_keys)); +unique_ptr +PEGTransformerFactory::TransformSetPartitionedBy(PEGTransformer &transformer, + vector> expression) { + return make_uniq(AlterEntryData(), std::move(expression)); } -unique_ptr PEGTransformerFactory::TransformResetPartitionedBy(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformResetPartitionedBy(PEGTransformer &transformer) { vector> partition_keys; return make_uniq(AlterEntryData(), std::move(partition_keys)); } unique_ptr 
PEGTransformerFactory::TransformAddConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto constraint = transformer.Transform>(list_pr.Child(1)); - return make_uniq(AlterEntryData(), std::move(constraint)); -} - -string PEGTransformerFactory::TransformSequenceName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; + unique_ptr top_level_constraint) { + return make_uniq(AlterEntryData(), std::move(top_level_constraint)); } unique_ptr PEGTransformerFactory::TransformSetSortedBy(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(3)); - auto order_by_exprs = transformer.Transform>(extract_parens); - auto result = make_uniq(AlterEntryData(), std::move(order_by_exprs)); + vector order_by_expressions) { + auto result = make_uniq(AlterEntryData(), std::move(order_by_expressions)); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformResetSortedBy(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformResetSortedBy(PEGTransformer &transformer) { vector order_by_exprs; auto result = make_uniq(AlterEntryData(), std::move(order_by_exprs)); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformSetOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // SetOptions <- 'SET' RelOptionList - // child 0: 'SET' keyword, child 1: RelOptionList - auto options = - transformer.Transform>>(list_pr.Child(1)); - return make_uniq(AlterEntryData(), std::move(options)); -} - -unique_ptr PEGTransformerFactory::TransformResetOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // ResetOptions <- 'RESET' RelOptionList - // child 0: 'RESET' keyword, 
child 1: RelOptionList - auto options_map = - transformer.Transform>>(list_pr.Child(1)); - case_insensitive_set_t option_names; - for (auto &opt : options_map) { +unique_ptr +PEGTransformerFactory::TransformSetOptions(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_list) { + return make_uniq(AlterEntryData(), std::move(rel_option_list)); +} + +unique_ptr +PEGTransformerFactory::TransformResetOptions(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_list) { + identifier_set_t option_names; + for (auto &opt : rel_option_list) { if (!opt.second) { - option_names.insert(opt.first); + option_names.insert(Identifier(opt.first)); + continue; } if (opt.second->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Reset option \"%s\" cannot set any value. Did you mean to use SET?", opt.first); @@ -482,9 +395,29 @@ unique_ptr PEGTransformerFactory::TransformResetOptions(PEGTrans if (!const_expr.GetValue().IsNull()) { throw ParserException("Reset option \"%s\" cannot set any value. 
Did you mean to use SET?", opt.first); } - option_names.insert(opt.first); + option_names.insert(Identifier(opt.first)); } return make_uniq(AlterEntryData(), std::move(option_names)); } +unique_ptr +PEGTransformerFactory::TransformNestedColumnName(PEGTransformer &transformer, const vector &identifier_dot, + const Identifier &column_name) { + vector column_names = identifier_dot; + column_names.push_back(column_name); + return make_uniq(column_names); +} + +Identifier PEGTransformerFactory::TransformIdentifierDot(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; +} + +string PEGTransformerFactory::TransformDropNullability(PEGTransformer &transformer) { + return "drop"; +} + +string PEGTransformerFactory::TransformSetNullability(PEGTransformer &transformer) { + return "set"; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp b/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp index 2232ee532..2b64f363a 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp @@ -5,29 +5,32 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformAnalyzeStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const bool &analyze_verbose, + AnalyzeTarget analyze_target) { VacuumOptions vacuum_options; vacuum_options.analyze = true; auto result = make_uniq(vacuum_options); - if (list_pr.Child(1).HasResult()) { + if (analyze_verbose) { throw NotImplementedException("ANALYZE VERBOSE is not implemented yet"); } - auto &target_opt = list_pr.Child(2); - if (target_opt.HasResult()) { - auto target = transformer.Transform(target_opt.GetResult()); - result->info->columns = target.columns; - result->info->ref = std::move(target.ref); + if (analyze_target.ref) { + result->info->columns = analyze_target.columns; + result->info->ref = std::move(analyze_target.ref); 
result->info->has_table = true; } return std::move(result); } -AnalyzeTarget PEGTransformerFactory::TransformAnalyzeTarget(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +AnalyzeTarget PEGTransformerFactory::TransformAnalyzeTarget(PEGTransformer &transformer, + unique_ptr base_table_name, + const vector &name_list) { AnalyzeTarget result; - result.ref = transformer.Transform>(list_pr.Child(0)); - transformer.TransformOptional>(list_pr, 1, result.columns); + result.ref = std::move(base_table_name); + result.columns = StringsToIdentifiers(name_list); return result; } + +bool PEGTransformerFactory::TransformAnalyzeVerbose(PEGTransformer &transformer) { + return true; +} } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_attach.cpp b/src/duckdb/src/parser/peg/transformer/transform_attach.cpp index 5a4978809..174e78b9d 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_attach.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_attach.cpp @@ -1,21 +1,17 @@ #include "duckdb/parser/peg/ast/generic_copy_option.hpp" -#include "duckdb/parser/expression/cast_expression.hpp" -#include "duckdb/parser/expression/columnref_expression.hpp" -#include "duckdb/parser/expression/operator_expression.hpp" #include "duckdb/parser/statement/attach_statement.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::TransformAttachStatement(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr +PEGTransformerFactory::TransformAttachStatement(PEGTransformer &transformer, const bool &or_replace, + const bool &if_not_exists, unique_ptr database_path, + const Identifier &attach_alias, + const vector &attach_options) { auto result = make_uniq(); auto info = make_uniq(); - auto &list_pr = parse_result.Cast(); - auto or_replace = list_pr.Child(1).HasResult(); - auto if_not_exists = list_pr.Child(2).HasResult(); - if 
(or_replace && if_not_exists) { throw ParserException("Cannot specify both OR REPLACE and IF NOT EXISTS at the same time"); } @@ -28,116 +24,43 @@ unique_ptr PEGTransformerFactory::TransformAttachStatement(PEGTran info->on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; } - info->parsed_path = transformer.Transform>(list_pr.Child(4)); - transformer.TransformOptional(list_pr, 5, info->name); - vector copy_options; - transformer.TransformOptional>(list_pr, 6, copy_options); - for (auto ©_option : copy_options) { - if (copy_option.expression) { - info->parsed_options[copy_option.name] = std::move(copy_option.expression); + info->parsed_path = std::move(database_path); + info->name = Identifier(attach_alias); + for (const auto &attach_option : attach_options) { + if (attach_option.expression) { + info->parsed_options[attach_option.name.GetIdentifierName()] = attach_option.expression->Copy(); continue; } - if (copy_option.children.empty()) { - info->options[copy_option.name] = Value(true); - } else if (copy_option.children.size() == 1) { - auto val = copy_option.children[0]; + if (attach_option.children.empty()) { + info->options[attach_option.name.GetIdentifierName()] = Value(true); + } else if (attach_option.children.size() == 1) { + auto val = attach_option.children[0]; if (val.IsNull()) { throw BinderException("NULL is not supported as a valid option for ATTACH option \"%s\"", - copy_option.name); + attach_option.name); } - info->options[copy_option.name] = std::move(copy_option.children[0]); + info->options[attach_option.name.GetIdentifierName()] = attach_option.children[0]; } else { - throw ParserException("Option %s can only have one argument", copy_option.name); + throw ParserException("Option %s can only have one argument", attach_option.name); } } result->info = std::move(info); return std::move(result); } -string PEGTransformerFactory::TransformAttachAlias(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto 
&col_id = list_pr.Child(1).Child(0); - if (col_id.GetResult().type == ParseResultType::IDENTIFIER) { - return col_id.GetResult().Cast().identifier; - } - return transformer.Transform(col_id.GetResult()); -} - -vector PEGTransformerFactory::TransformAttachOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0)); +Identifier PEGTransformerFactory::TransformAttachAlias(PEGTransformer &transformer, const Identifier &col_id) { + return Identifier(col_id); } -vector PEGTransformerFactory::TransformGenericCopyOptionList(PEGTransformer &transformer, - ParseResult &parse_result) { - vector result; - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto option_list = ExtractParseResultsFromList(extract_parens); - for (auto &option : option_list) { - result.push_back(transformer.Transform(option)); - } - return result; -} - -GenericCopyOption PEGTransformerFactory::TransformGenericCopyOption(PEGTransformer &transformer, - ParseResult &parse_result) { - GenericCopyOption copy_option; - - auto &list_pr = parse_result.Cast(); - copy_option.name = StringUtil::Lower(list_pr.Child(0).identifier); - auto &optional_expression = list_pr.Child(1); - if (!optional_expression.HasResult()) { - return copy_option; - } - // TODO(Dtenwolde) Rework to switch statement - auto expression = transformer.Transform>(optional_expression.GetResult()); - if (expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - copy_option.children.push_back(Value(expression->Cast().GetValue())); - } else if (expression->GetExpressionType() == ExpressionType::COLUMN_REF) { - copy_option.children.push_back(Value(expression->Cast().GetColumnName())); - } else if (expression->GetExpressionType() == ExpressionType::PLACEHOLDER) { - auto &op_expr = expression->Cast(); - for (auto &child : op_expr.children) { - if (child->GetExpressionClass() == 
ExpressionClass::CONSTANT) { - copy_option.children.push_back(Value(child->Cast().GetValue())); - } else if (child->GetExpressionClass() == ExpressionClass::COLUMN_REF) { - copy_option.children.push_back(Value(child->Cast().GetColumnName())); - } else { - throw InternalException("Unexpected expression type %s encountered for GenericCopyOption", - ExpressionClassToString(child->GetExpressionClass())); - } - } - } else if (expression->GetExpressionType() == ExpressionType::FUNCTION) { - copy_option.expression = std::move(expression); - } else if (expression->GetExpressionType() == ExpressionType::STAR) { - copy_option.children.push_back(Value("*")); - } else if (expression->GetExpressionType() == ExpressionType::OPERATOR_CAST) { - auto &cast_expr = expression->Cast(); - if (cast_expr.child->GetExpressionClass() == ExpressionClass::CONSTANT) { - auto &const_expr = cast_expr.child->Cast(); - if (const_expr.GetValue().GetValue() == "t") { - copy_option.children.push_back(Value(true)); - } else if (const_expr.GetValue().GetValue() == "f") { - copy_option.children.push_back(Value(false)); - } else { - copy_option.expression = std::move(expression); - } - } else { - copy_option.expression = std::move(expression); - } - } else { - throw NotImplementedException("Unrecognized expression type %s", - ExpressionTypeToString(expression->GetExpressionType())); - } - return copy_option; +vector +PEGTransformerFactory::TransformAttachOptions(PEGTransformer &transformer, + const vector &generic_copy_option_list) { + return generic_copy_option_list; } unique_ptr PEGTransformerFactory::TransformDatabasePath(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.GetChild(0)); + unique_ptr expression) { + return expression; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_call.cpp b/src/duckdb/src/parser/peg/transformer/transform_call.cpp index c2be8b917..be8699c24 
100644 --- a/src/duckdb/src/parser/peg/transformer/transform_call.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_call.cpp @@ -5,15 +5,14 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformCallStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto table_function_name = transformer.Transform(list_pr.Child(1)); - auto function_children = - transformer.Transform>>(list_pr.Child(2)); +unique_ptr +PEGTransformerFactory::TransformCallStatement(PEGTransformer &transformer, + const QualifiedName &qualified_table_function, + vector table_function_arguments) { auto result = make_uniq(); - auto function_expression = make_uniq(table_function_name.catalog, table_function_name.schema, - table_function_name.name, std::move(function_children)); + auto function_expression = + make_uniq(qualified_table_function.catalog, qualified_table_function.schema, + qualified_table_function.name, std::move(table_function_arguments)); result->function = std::move(function_expression); return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp b/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp index 89e61757c..ef6942d1f 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp @@ -5,24 +5,22 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCheckpointStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - - auto &force_pr = list_pr.Child(0); - auto checkpoint_name = force_pr.HasResult() ? "force_checkpoint" : "checkpoint"; + const bool &checkpoint_force, + const Identifier &catalog_name) { + auto checkpoint_name = checkpoint_force ? 
"force_checkpoint" : "checkpoint"; auto result = make_uniq(); vector> children; auto function = make_uniq(checkpoint_name, std::move(children)); - function->catalog = SYSTEM_CATALOG; - function->schema = DEFAULT_SCHEMA; - - auto &catalog_name_pr = list_pr.Child(2); - if (catalog_name_pr.HasResult()) { - function->children.push_back( - make_uniq(catalog_name_pr.GetResult().Cast().identifier)); + function->CatalogMutable() = SYSTEM_CATALOG; + function->SchemaMutable() = DEFAULT_SCHEMA; + if (!catalog_name.empty()) { + function->GetArgumentsMutable().emplace_back(make_uniq(catalog_name)); } result->function = std::move(function); return std::move(result); } +bool PEGTransformerFactory::TransformCheckpointForce(PEGTransformer &transformer) { + return true; +} } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_comment.cpp b/src/duckdb/src/parser/peg/transformer/transform_comment.cpp index c7630990f..583272458 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_comment.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_comment.cpp @@ -5,24 +5,22 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCommentStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto comment_on_type = transformer.Transform(list_pr.Child(2)); - auto dotted_identifier = transformer.Transform>(list_pr.Child(3)); - auto comment_value = transformer.Transform(list_pr.Child(5)); - + const CatalogType &comment_on_type, + const vector &dotted_identifier, + const Value &comment_value) { auto result = make_uniq(); unique_ptr info; - string column_name; + Identifier column_name; if (comment_on_type == CatalogType::INVALID) { // Column type returned - column_name = dotted_identifier.back(); - dotted_identifier.pop_back(); - if (dotted_identifier.empty()) { - throw ParserException("Invalid column reference: '%s'", column_name); + auto identifier = dotted_identifier; + column_name = 
Identifier(identifier.back()); + identifier.pop_back(); + if (identifier.empty()) { + throw ParserException("Invalid column reference: '%s'", column_name.GetIdentifierName()); } - auto qualified_name = StringToQualifiedName(dotted_identifier); + auto qualified_name = StringToQualifiedName(identifier); info = make_uniq(qualified_name.catalog, qualified_name.schema, qualified_name.name, column_name, comment_value, OnEntryNotFound::THROW_EXCEPTION); } else if (comment_on_type == CatalogType::DATABASE_ENTRY) { @@ -41,19 +39,58 @@ unique_ptr PEGTransformerFactory::TransformCommentStatement(PEGTra return std::move(result); } -CatalogType PEGTransformerFactory::TransformCommentOnType(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +CatalogType PEGTransformerFactory::TransformCommentTable(PEGTransformer &transformer) { + return CatalogType::TABLE_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentSequence(PEGTransformer &transformer) { + return CatalogType::SEQUENCE_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentFunction(PEGTransformer &transformer) { + return CatalogType::MACRO_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentMacroTable(PEGTransformer &transformer) { + return CatalogType::TABLE_MACRO_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentMacro(PEGTransformer &transformer) { + return CatalogType::MACRO_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentView(PEGTransformer &transformer) { + return CatalogType::VIEW_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentDatabase(PEGTransformer &transformer) { + return CatalogType::DATABASE_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentIndex(PEGTransformer &transformer) { + return CatalogType::INDEX_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentSchema(PEGTransformer &transformer) { + 
return CatalogType::SCHEMA_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentType(PEGTransformer &transformer) { + return CatalogType::TYPE_ENTRY; +} + +CatalogType PEGTransformerFactory::TransformCommentColumn(PEGTransformer &transformer) { + return CatalogType::INVALID; } -Value PEGTransformerFactory::TransformCommentValue(PEGTransformer &transformer, ParseResult &parse_result) { - // CommentValue <- 'NULL'i / StringLiteral - auto &list_pr = parse_result.Cast(); +Value PEGTransformerFactory::TransformCommentValue(PEGTransformer &transformer, ParseResult &choice_result) { + // CommentValue <- NullLiteral / StringLiteral + auto &list_pr = choice_result.Cast(); auto &choice_pr = list_pr.Child(0); if (choice_pr.GetResult().type == ParseResultType::STRING) { return Value(choice_pr.GetResult().Cast().result); } - return Value(); + return transformer.Transform(choice_pr.GetResult()); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_common.cpp b/src/duckdb/src/parser/peg/transformer/transform_common.cpp index 374ae128f..b3ebbd1fc 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_common.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_common.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/extra_type_info.hpp" +#include "duckdb/common/enums/date_part_specifier.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" @@ -12,7 +13,7 @@ namespace duckdb { string PEGTransformerFactory::TransformIdentifierOrKeyword(PEGTransformer &transformer, ParseResult &parse_result) { if (parse_result.type == ParseResultType::IDENTIFIER) { - return parse_result.Cast().identifier; + return parse_result.Cast().identifier.GetIdentifierName(); } if (parse_result.type == ParseResultType::KEYWORD) { return parse_result.Cast().keyword; @@ -31,7 +32,7 @@ string PEGTransformerFactory::TransformIdentifierOrKeyword(PEGTransformer 
&trans if (child.get().type == ParseResultType::CHOICE) { auto &choice_result = child.get().Cast().GetResult(); if (choice_result.type == ParseResultType::IDENTIFIER) { - return choice_result.Cast().identifier; + return choice_result.Cast().identifier.GetIdentifierName(); } if (choice_result.type == ParseResultType::KEYWORD) { return choice_result.Cast().keyword; @@ -39,7 +40,7 @@ string PEGTransformerFactory::TransformIdentifierOrKeyword(PEGTransformer &trans return transformer.Transform(choice_result); } if (child.get().type == ParseResultType::IDENTIFIER) { - return child.get().Cast().identifier; + return child.get().Cast().identifier.GetIdentifierName(); } throw InternalException("Unexpected IdentifierOrKeyword type encountered %s.", ParseResultToString(child.get().type)); @@ -48,51 +49,38 @@ string PEGTransformerFactory::TransformIdentifierOrKeyword(PEGTransformer &trans throw ParserException("Unexpected ParseResult type in identifier transformation."); } -LogicalType PEGTransformerFactory::TransformType(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &type_pr = list_pr.Child(0); - auto type = transformer.Transform>(type_pr.Child(0).GetResult()); - auto &opt_array_bounds_pr = list_pr.Child(1); - if (!opt_array_bounds_pr.HasResult()) { - return LogicalType::UNBOUND(std::move(type)); - } - auto &array_bounds_repeat = opt_array_bounds_pr.GetResult().Cast(); - for (auto &array_bound : array_bounds_repeat.GetChildren()) { +LogicalType PEGTransformerFactory::TransformType(PEGTransformer &transformer, + unique_ptr type_variations, + const vector &array_bounds) { + auto type = std::move(type_variations); + for (auto array_size : array_bounds) { vector> children_types; children_types.push_back(std::move(type)); - auto array_size = transformer.Transform(array_bound); if (array_size < 0) { - type = make_uniq("list", std::move(children_types)); + type = make_uniq(Identifier("list"), std::move(children_types)); } else 
{ children_types.push_back(make_uniq(Value::BIGINT(array_size))); - type = make_uniq("array", std::move(children_types)); + type = make_uniq(Identifier("array"), std::move(children_types)); } } return LogicalType::UNBOUND(std::move(type)); } -int64_t PEGTransformerFactory::TransformArrayBounds(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (choice_pr.name.empty()) { - return -1; - } - return transformer.Transform(list_pr.Child(0).GetResult()); +int64_t PEGTransformerFactory::TransformArrayKeyword(PEGTransformer &transformer) { + return -1; } -int64_t PEGTransformerFactory::TransformSquareBracketsArray(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &opt_array_size = list_pr.Child(1); - if (!opt_array_size.HasResult()) { +int64_t PEGTransformerFactory::TransformSquareBracketsArray(PEGTransformer &transformer, + unique_ptr expression) { + if (!expression) { // Empty array so we return -1 to signify it's a list return -1; } - auto number_expr = transformer.Transform>(opt_array_size.GetResult()); - if (number_expr->GetExpressionClass() != ExpressionClass::CONSTANT) { + if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected a constant number as array size"); } - auto &const_number = number_expr->Cast(); + auto &const_number = expression->Cast(); if (!const_number.GetValue().type().IsIntegral()) { throw BinderException("Expected an integer as array bound instead of %s", const_number.GetValue().ToString()); } @@ -103,35 +91,30 @@ int64_t PEGTransformerFactory::TransformSquareBracketsArray(PEGTransformer &tran return number_val; } -unique_ptr PEGTransformerFactory::TransformTimeType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - LogicalTypeId type = transformer.Transform(list_pr.Child(0)); - auto &opt_type_modifiers 
= list_pr.Child(1); - vector> modifiers; - if (opt_type_modifiers.HasResult()) { - modifiers = transformer.Transform>>(opt_type_modifiers.GetResult()); - } - auto &opt_timezone = list_pr.Child(2); - const bool with_timezone = opt_timezone.HasResult() ? transformer.Transform(opt_timezone.GetResult()) : false; +unique_ptr +PEGTransformerFactory::TransformTimeType(PEGTransformer &transformer, const LogicalTypeId &time_or_timestamp, + vector> type_modifiers, const bool &time_zone) { + auto type = time_or_timestamp; + auto modifiers = std::move(type_modifiers); + auto with_timezone = time_zone; if (type == LogicalTypeId::TIME) { if (!modifiers.empty()) { throw ParserException("Type TIME does not allow any modifiers"); } if (with_timezone) { - return make_uniq(EnumUtil::ToString(LogicalType::TIME_TZ), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIME_TZ)), vector> {}); } - return make_uniq(EnumUtil::ToString(LogicalType::TIME), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIME)), vector> {}); } if (type == LogicalTypeId::TIMESTAMP) { if (modifiers.empty()) { if (with_timezone) { - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP_TZ), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP_TZ)), vector> {}); } - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP)), vector> {}); } if (modifiers.size() > 1) { @@ -148,276 +131,357 @@ unique_ptr PEGTransformerFactory::TransformTimeType(PEGTransfo throw ParserException("TIMESTAMP precision should be between 0 and 10 (inclusive)"); } if (timestamp_precision == 0) { - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP_S), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP_S)), vector> {}); } if (timestamp_precision <= 3) { - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP_MS), + return 
make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP_MS)), vector> {}); } if (timestamp_precision <= 6) { // Corresponds to microseconds, which is the default TIMESTAMP - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP)), vector> {}); } - return make_uniq(EnumUtil::ToString(LogicalType::TIMESTAMP_NS), + return make_uniq(Identifier(EnumUtil::ToString(LogicalType::TIMESTAMP_NS)), vector> {}); } throw ParserException("Unexpected time type encountered"); } -bool PEGTransformerFactory::TransformTimeZone(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0)); +bool PEGTransformerFactory::TransformTimeZone(PEGTransformer &transformer, const bool &with_or_without) { + return with_or_without; +} + +bool PEGTransformerFactory::TransformWithRule(PEGTransformer &transformer) { + return true; } -bool PEGTransformerFactory::TransformWithOrWithout(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0)); +bool PEGTransformerFactory::TransformWithoutRule(PEGTransformer &transformer) { + return false; } -LogicalTypeId PEGTransformerFactory::TransformTimeOrTimestamp(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &time_or_timestamp = list_pr.Child(0).GetResult(); - return transformer.TransformEnum(time_or_timestamp); +LogicalTypeId PEGTransformerFactory::TransformTimeTypeId(PEGTransformer &transformer) { + return LogicalTypeId::TIME; } -unique_ptr PEGTransformerFactory::TransformNumericType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); +LogicalTypeId PEGTransformerFactory::TransformTimestampTypeId(PEGTransformer &transformer) { + return 
LogicalTypeId::TIMESTAMP; } unique_ptr PEGTransformerFactory::TransformSimpleNumericType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto numeric = transformer.TransformEnum(list_pr.Child(0).GetResult()); - return make_uniq(numeric, vector> {}); + const string &child) { + return make_uniq(Identifier(child), vector> {}); } -unique_ptr PEGTransformerFactory::TransformDecimalNumericType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); +string PEGTransformerFactory::TransformIntType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::INTEGER); } -unique_ptr PEGTransformerFactory::TransformFloatType(PEGTransformer &, ParseResult &) { - return make_uniq("FLOAT", vector> {}); +string PEGTransformerFactory::TransformIntegerType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::INTEGER); } -unique_ptr PEGTransformerFactory::TransformDecimalType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &opt_type_modifier = list_pr.Child(1); - if (!opt_type_modifier.HasResult()) { - return make_uniq("DECIMAL", vector> {}); - } - auto type_modifiers = transformer.Transform>>(opt_type_modifier.GetResult()); - return make_uniq("DECIMAL", std::move(type_modifiers)); -} - -vector> PEGTransformerFactory::TransformTypeModifiers(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector> result; - auto &extract_list = ExtractResultFromParens(list_pr.Child(0)).Cast(); - if (extract_list.HasResult()) { - auto expressions = ExtractParseResultsFromList(extract_list.GetResult()); - for (auto expression : expressions) { - auto expr = transformer.Transform>(expression); - if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { - throw ParserException("Expected a constant 
as type modifier"); - } - result.push_back(std::move(expr)); - } - } +string PEGTransformerFactory::TransformSmallintType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::SMALLINT); +} - return result; +string PEGTransformerFactory::TransformBigintType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::BIGINT); } -unique_ptr PEGTransformerFactory::TransformSimpleType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &opt_modifiers = list_pr.Child(1); - vector> children; - if (opt_modifiers.HasResult()) { - children = transformer.Transform>>(opt_modifiers.GetResult()); - } - auto &qualified_type_or_character = list_pr.Child(0); - auto &type_or_character_pr = qualified_type_or_character.Child(0).GetResult(); - if (type_or_character_pr.name == "QualifiedTypeName") { - auto qualified_type_name = transformer.Transform(type_or_character_pr); - if (qualified_type_name.schema.empty()) { - qualified_type_name.schema = qualified_type_name.catalog; - qualified_type_name.catalog = INVALID_CATALOG; - } - return make_uniq(qualified_type_name.catalog, qualified_type_name.schema, - qualified_type_name.name, std::move(children)); - } - if (type_or_character_pr.name == "CharacterType") { - return transformer.Transform>(type_or_character_pr); - } - throw InternalException("Unexpected rule %s encountered in SimpleType", type_or_character_pr.name); -} - -QualifiedName PEGTransformerFactory::TransformQualifiedTypeName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector qualified_typename; - auto &opt_identifiers = list_pr.Child(0); - if (opt_identifiers.HasResult()) { - auto &repeat_identifiers = opt_identifiers.GetResult().Cast(); - for (auto &child : repeat_identifiers.GetChildren()) { - auto &repeat_list = child.get().Cast(); - qualified_typename.push_back(repeat_list.Child(0).identifier); +string 
PEGTransformerFactory::TransformRealType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::FLOAT); +} + +string PEGTransformerFactory::TransformBooleanType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::BOOLEAN); +} + +string PEGTransformerFactory::TransformDoubleType(PEGTransformer &transformer) { + return LogicalTypeIdToString(LogicalTypeId::DOUBLE); +} + +unique_ptr PEGTransformerFactory::TransformFloatType(PEGTransformer &transformer, + unique_ptr number_literal) { + return make_uniq(Identifier("FLOAT"), vector> {}); +} + +unique_ptr +PEGTransformerFactory::TransformDecimalType(PEGTransformer &transformer, + vector> type_modifiers) { + return make_uniq(Identifier("DECIMAL"), std::move(type_modifiers)); +} + +unique_ptr +PEGTransformerFactory::TransformDecType(PEGTransformer &transformer, + vector> type_modifiers) { + return TransformDecimalType(transformer, std::move(type_modifiers)); +} + +unique_ptr +PEGTransformerFactory::TransformNumericModType(PEGTransformer &transformer, + vector> type_modifiers) { + return TransformDecimalType(transformer, std::move(type_modifiers)); +} + +vector> +PEGTransformerFactory::TransformTypeModifiers(PEGTransformer &transformer, + vector> expression) { + for (auto &expr : expression) { + if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { + throw ParserException("Expected a constant as type modifier"); } } + return expression; +} + +unique_ptr +PEGTransformerFactory::TransformCharacterSimpleType(PEGTransformer &transformer, const string &character_type, + vector> type_modifiers) { + return make_uniq(character_type, std::move(type_modifiers)); +} - if (list_pr.GetChild(1).type == ParseResultType::IDENTIFIER) { - qualified_typename.push_back(list_pr.Child(1).identifier); - } else { - qualified_typename.push_back(transformer.Transform(list_pr.Child(2))); +unique_ptr +PEGTransformerFactory::TransformQualifiedSimpleType(PEGTransformer &transformer, + const 
QualifiedName &qualified_type_name, + vector> type_modifiers) { + auto result = qualified_type_name; + if (result.schema.empty()) { + result.schema = result.catalog; + result.catalog = INVALID_CATALOG; } - return StringToQualifiedName(qualified_typename); + return make_uniq(result.catalog, result.schema, result.name, std::move(type_modifiers)); } -unique_ptr PEGTransformerFactory::TransformCharacterType(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_uniq("VARCHAR", vector> {}); +QualifiedName PEGTransformerFactory::TransformTypeNameAsQualifiedName(PEGTransformer &transformer, + const Identifier &type_name) { + QualifiedName result; + result.catalog = INVALID_CATALOG; + result.schema = INVALID_SCHEMA; + result.name = type_name; + return result; +} + +QualifiedName PEGTransformerFactory::TransformSchemaReservedTypeName(PEGTransformer &transformer, + const Identifier &schema_qualification, + const Identifier &reserved_type_name) { + QualifiedName result; + result.catalog = INVALID_CATALOG; + result.schema = schema_qualification; + result.name = reserved_type_name; + return result; +} + +QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaTypeName( + PEGTransformer &transformer, const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, const Identifier &reserved_type_name) { + QualifiedName result; + result.catalog = catalog_qualification; + result.schema = reserved_schema_qualification; + result.name = reserved_type_name; + return result; +} + +string PEGTransformerFactory::TransformCharacterType(PEGTransformer &transformer) { + return "VARCHAR"; } unique_ptr PEGTransformerFactory::TransformMapType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); - auto type_list = ExtractParseResultsFromList(extract_parens); - if (type_list.size() != 2) { + const vector &type) { + if 
(type.size() != 2) { throw ParserException("Map type needs exactly two entries, key and value type."); } - auto key_type = transformer.Transform(type_list[0]); - auto value_type = transformer.Transform(type_list[1]); vector> map_children; - map_children.push_back(UnboundType::GetTypeExpression(key_type)->Copy()); - map_children.push_back(UnboundType::GetTypeExpression(value_type)->Copy()); - return make_uniq("MAP", std::move(map_children)); + map_children.push_back(UnboundType::GetTypeExpression(type[0])->Copy()); + map_children.push_back(UnboundType::GetTypeExpression(type[1])->Copy()); + return make_uniq(Identifier("MAP"), std::move(map_children)); } -unique_ptr PEGTransformerFactory::TransformRowType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto colid_list = transformer.Transform>(list_pr.Child(1)); +unique_ptr +PEGTransformerFactory::TransformRowType(PEGTransformer &transformer, + const child_list_t &col_id_type_list) { vector> struct_children; - for (auto &child : colid_list) { + for (auto &child : col_id_type_list) { auto &type_expr = UnboundType::GetTypeExpression(child.second); auto new_type_expr = type_expr->Copy(); new_type_expr->SetAlias(child.first); struct_children.push_back(std::move(new_type_expr)); } - return make_uniq("STRUCT", std::move(struct_children)); + return make_uniq(Identifier("STRUCT"), std::move(struct_children)); } unique_ptr PEGTransformerFactory::TransformGeometryType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &crs_opt = list_pr.Child(1); - if (!crs_opt.HasResult()) { - return make_uniq("GEOMETRY", vector> {}); + unique_ptr expression) { + if (!expression) { + return make_uniq(Identifier("GEOMETRY"), vector> {}); } - auto &extract_parens = ExtractResultFromParens(crs_opt.GetResult()); - auto crs = transformer.Transform>(extract_parens); vector> geo_children; - if (crs->GetExpressionClass() != 
ExpressionClass::CONSTANT) { + if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected a constant as type modifier"); } - geo_children.push_back(std::move(crs)); - return make_uniq("GEOMETRY", std::move(geo_children)); + geo_children.push_back(std::move(expression)); + return make_uniq(Identifier("GEOMETRY"), std::move(geo_children)); } -unique_ptr PEGTransformerFactory::TransformVariantType(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_uniq("VARIANT", vector> {}); +unique_ptr PEGTransformerFactory::TransformVariantType(PEGTransformer &transformer) { + return make_uniq(Identifier("VARIANT"), vector> {}); } -unique_ptr PEGTransformerFactory::TransformUnionType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto colid_list = transformer.Transform>(list_pr.Child(1)); - case_insensitive_string_set_t union_names; +unique_ptr +PEGTransformerFactory::TransformUnionType(PEGTransformer &transformer, + const child_list_t &col_id_type_list) { + identifier_set_t union_names; vector> union_children; - for (auto &colid : colid_list) { + for (auto &colid : col_id_type_list) { union_names.insert(colid.first); auto &type_expr = UnboundType::GetTypeExpression(colid.second); auto new_type_expr = type_expr->Copy(); new_type_expr->SetAlias(colid.first); union_children.push_back(std::move(new_type_expr)); } - return make_uniq("UNION", std::move(union_children)); + return make_uniq(Identifier("UNION"), std::move(union_children)); } -child_list_t PEGTransformerFactory::TransformColIdTypeList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_list = ExtractResultFromParens(list_pr.Child(0)); - auto colid_type_list = ExtractParseResultsFromList(extract_list); +child_list_t +PEGTransformerFactory::TransformColIdTypeList(PEGTransformer &transformer, + const vector> &col_id_type) { + return col_id_type; +} 
- child_list_t result; - for (auto colid_type : colid_type_list) { - result.push_back(transformer.Transform>(colid_type)); - } - return result; +pair PEGTransformerFactory::TransformColIdType(PEGTransformer &transformer, + const Identifier &col_id, + const LogicalType &type) { + return make_pair(Identifier(col_id), type); } -pair PEGTransformerFactory::TransformColIdType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto colid = transformer.Transform(list_pr.Child(0)); - auto type = transformer.Transform(list_pr.Child(1)); - return make_pair(colid, type); +unique_ptr PEGTransformerFactory::TransformBitType( + PEGTransformer &transformer, + // NOLINTNEXTLINE(performance-unnecessary-value-param): fixed generated signature; BIT takes no modifiers + vector> expression) { + return make_uniq(Identifier("BIT"), vector> {}); } -unique_ptr PEGTransformerFactory::TransformBitType(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_uniq("BIT", vector> {}); +unique_ptr PEGTransformerFactory::TransformIntervalWithoutSpecifier(PEGTransformer &transformer) { + return make_uniq(Identifier("INTERVAL"), vector> {}); } -unique_ptr PEGTransformerFactory::TransformIntervalType(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_uniq("INTERVAL", vector> {}); +DatePartSpecifier PEGTransformerFactory::TransformIntervalToIntervalAsType(PEGTransformer &transformer, + ParseResult &parse_result) { + return DatePartSpecifier::INVALID; } -unique_ptr PEGTransformerFactory::TransformIntervalInterval(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &opt_interval = list_pr.Child(1); - if (opt_interval.HasResult()) { - auto logical_type = transformer.Transform(opt_interval.GetResult()); - return make_uniq(LogicalTypeIdToString(logical_type.id()), - vector> {}); - } - return make_uniq("INTERVAL", vector> {}); +unique_ptr 
+PEGTransformerFactory::TransformIntervalWithRangeSpecifier(PEGTransformer &transformer, + const DatePartSpecifier &interval_to_interval_as_type) { + return make_uniq(Identifier("INTERVAL"), vector> {}); } -DatePartSpecifier PEGTransformerFactory::TransformInterval(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - if (StringUtil::CIEquals(choice_pr.name, "IntervalToInterval")) { - return transformer.Transform(choice_pr.GetResult()); - } else { - return transformer.TransformEnum(choice_pr); - } +unique_ptr +PEGTransformerFactory::TransformIntervalWithSimpleSpecifier(PEGTransformer &transformer, + const DatePartSpecifier &interval) { + return make_uniq(Identifier("INTERVAL"), vector> {}); +} + +DatePartSpecifier PEGTransformerFactory::TransformYearKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::YEAR; +} + +DatePartSpecifier PEGTransformerFactory::TransformMonthKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::MONTH; +} + +DatePartSpecifier PEGTransformerFactory::TransformDayKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::DAY; +} + +DatePartSpecifier PEGTransformerFactory::TransformHourKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::HOUR; +} + +DatePartSpecifier PEGTransformerFactory::TransformMinuteKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::MINUTE; +} + +DatePartSpecifier PEGTransformerFactory::TransformSecondKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::SECOND; +} + +DatePartSpecifier PEGTransformerFactory::TransformMillisecondKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::MILLISECONDS; +} + +DatePartSpecifier PEGTransformerFactory::TransformMicrosecondKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::MICROSECONDS; +} + +DatePartSpecifier PEGTransformerFactory::TransformWeekKeyword(PEGTransformer &transformer) { + return 
DatePartSpecifier::WEEK; +} + +DatePartSpecifier PEGTransformerFactory::TransformQuarterKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::QUARTER; +} + +DatePartSpecifier PEGTransformerFactory::TransformDecadeKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::DECADE; +} + +DatePartSpecifier PEGTransformerFactory::TransformCenturyKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::CENTURY; +} + +DatePartSpecifier PEGTransformerFactory::TransformMillenniumKeyword(PEGTransformer &transformer) { + return DatePartSpecifier::MILLENNIUM; } -DatePartSpecifier PEGTransformerFactory::TransformIntervalToInterval(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - auto &inner_list = choice_pr.GetResult().Cast(); - auto first_interval = transformer.TransformEnum(inner_list.GetChild(0)); - auto second_interval = transformer.TransformEnum(inner_list.GetChild(2)); +static DatePartSpecifier UnsupportedIntervalRange(DatePartSpecifier first_interval, DatePartSpecifier second_interval) { throw ParserException("%s TO %s is not supported", EnumUtil::ToString(first_interval), EnumUtil::ToString(second_interval)); } +DatePartSpecifier PEGTransformerFactory::TransformYearToMonth(PEGTransformer &transformer, + const DatePartSpecifier &year_keyword, + const DatePartSpecifier &month_keyword) { + return UnsupportedIntervalRange(year_keyword, month_keyword); +} + +DatePartSpecifier PEGTransformerFactory::TransformDayToHour(PEGTransformer &transformer, + const DatePartSpecifier &day_keyword, + const DatePartSpecifier &hour_keyword) { + return UnsupportedIntervalRange(day_keyword, hour_keyword); +} + +DatePartSpecifier PEGTransformerFactory::TransformDayToMinute(PEGTransformer &transformer, + const DatePartSpecifier &day_keyword, + const DatePartSpecifier &minute_keyword) { + return UnsupportedIntervalRange(day_keyword, minute_keyword); +} + +DatePartSpecifier 
PEGTransformerFactory::TransformDayToSecond(PEGTransformer &transformer, + const DatePartSpecifier &day_keyword, + const DatePartSpecifier &second_keyword) { + return UnsupportedIntervalRange(day_keyword, second_keyword); +} + +DatePartSpecifier PEGTransformerFactory::TransformHourToMinute(PEGTransformer &transformer, + const DatePartSpecifier &hour_keyword, + const DatePartSpecifier &minute_keyword) { + return UnsupportedIntervalRange(hour_keyword, minute_keyword); +} + +DatePartSpecifier PEGTransformerFactory::TransformHourToSecond(PEGTransformer &transformer, + const DatePartSpecifier &hour_keyword, + const DatePartSpecifier &second_keyword) { + return UnsupportedIntervalRange(hour_keyword, second_keyword); +} + +DatePartSpecifier PEGTransformerFactory::TransformMinuteToSecond(PEGTransformer &transformer, + const DatePartSpecifier &minute_keyword, + const DatePartSpecifier &second_keyword) { + return UnsupportedIntervalRange(minute_keyword, second_keyword); +} + unique_ptr PEGTransformerFactory::TryNegateValue(const ConstantExpression &expr) { auto &val = expr.GetValue(); @@ -547,10 +611,8 @@ unique_ptr PEGTransformerFactory::TransformNumberLiteral(PEGTr } unique_ptr PEGTransformerFactory::TransformSetofType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto setof_type = transformer.Transform(list_pr.Child(1)); - return UnboundType::GetTypeExpression(setof_type)->Copy(); + const LogicalType &type) { + return UnboundType::GetTypeExpression(type)->Copy(); } // StringLiteral <- '\'' [^\']* '\'' @@ -558,4 +620,9 @@ string PEGTransformerFactory::TransformStringLiteral(PEGTransformer &transformer auto &string_literal_pr = parse_result.Cast(); return string_literal_pr.result; } + +Identifier PEGTransformerFactory::TransformConstraintName(PEGTransformer &transformer, + const Identifier &col_id_or_string) { + return Identifier(col_id_or_string); +} } // namespace duckdb diff --git 
a/src/duckdb/src/parser/peg/transformer/transform_connect.cpp b/src/duckdb/src/parser/peg/transformer/transform_connect.cpp new file mode 100644 index 000000000..a0ad3b38c --- /dev/null +++ b/src/duckdb/src/parser/peg/transformer/transform_connect.cpp @@ -0,0 +1,61 @@ +#include "duckdb/parser/peg/transformer/peg_transformer.hpp" +#include "duckdb/parser/statement/connect_statement.hpp" +#include "duckdb/parser/statement/disconnect_statement.hpp" + +namespace duckdb { + +//! Shape captured from `SessionTarget <- 'LOCAL' / StringLiteral / CatalogName`. The framework +//! wraps the named sub-rule in a List whose only child is the Choice over the alternatives. +struct SessionTargetCapture { + string name; + bool target_is_local = false; + bool name_is_string_literal = false; +}; + +static SessionTargetCapture TransformSessionTarget(PEGTransformer &transformer, ParseResult &target_result) { + auto &list = target_result.Cast(); + auto &inner = list.Child(0).GetResult(); + SessionTargetCapture result; + switch (inner.type) { + case ParseResultType::KEYWORD: + // 'LOCAL' alternative — name stays empty, just flip the flag. + result.target_is_local = true; + break; + case ParseResultType::STRING: + result.name = inner.Cast().GetRawString(); + result.name_is_string_literal = true; + break; + case ParseResultType::IDENTIFIER: + result.name = inner.Cast().identifier.GetIdentifierName(); + break; + default: + throw InternalException("Unexpected SessionTarget alternative type: %s", ParseResultToString(inner.type)); + } + return result; +} + +// ConnectStatement <- 'CONNECT' SessionTarget? 
+unique_ptr PEGTransformerFactory::TransformConnectStatement(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto info = make_uniq(); + auto &target_opt = list_pr.Child(1); + if (target_opt.HasResult()) { + auto captured = TransformSessionTarget(transformer, target_opt.GetResult()); + info->name = Identifier(std::move(captured.name)); + info->target_is_local = captured.target_is_local; + info->name_is_string_literal = captured.name_is_string_literal; + } + auto result = make_uniq(); + result->info = std::move(info); + return std::move(result); +} + +// DisconnectStatement <- 'DISCONNECT' +unique_ptr PEGTransformerFactory::TransformDisconnectStatement(PEGTransformer &transformer) { + auto result = make_uniq(); + result->info = make_uniq(); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_copy.cpp b/src/duckdb/src/parser/peg/transformer/transform_copy.cpp index 650c5721c..ffd44aed0 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_copy.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_copy.cpp @@ -8,52 +8,52 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCopyStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto ©_mode = list_pr.Child(1); - return transformer.Transform>(copy_mode.Child(0).GetResult()); + unique_ptr copy_variations) { + return copy_variations; } void SetCopyOptions(unique_ptr &info, vector &options) { case_insensitive_string_set_t option_names; for (auto &option : options) { - if (option_names.find(option.name) != option_names.end()) { + if (option_names.find(option.name.GetIdentifierName()) != option_names.end()) { throw ParserException("Unexpected duplicate option \"%s\"", option.name); } - option_names.insert(option.name); - if (StringUtil::CIEquals(option.name, "PARTITION_BY") || StringUtil::CIEquals(option.name, "FORCE_QUOTE") || - 
StringUtil::CIEquals(option.name, "FORCE_NOT_NULL") || StringUtil::CIEquals(option.name, "FORCE_NULL")) { + option_names.insert(option.name.GetIdentifierName()); + if (option.name == "PARTITION_BY" || option.name == "FORCE_QUOTE" || option.name == "FORCE_NOT_NULL" || + option.name == "FORCE_NULL") { if (option.expression) { - info->parsed_options[option.name] = std::move(option.expression); + info->parsed_options[option.name.GetIdentifierName()] = std::move(option.expression); } else { if (option.children.empty()) { throw BinderException("\"%s\" expects a column list or * as parameter", option.name); } vector> func_children; for (const auto &partition : option.children) { - func_children.push_back(make_uniq(partition.GetValue())); + func_children.push_back(make_uniq(Identifier(partition.GetValue()))); } auto row_func = make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, "row", std::move(func_children)); - info->parsed_options[option.name] = std::move(row_func); + info->parsed_options[option.name.GetIdentifierName()] = std::move(row_func); } - } else if (StringUtil::CIEquals(option.name, "HEADER") || StringUtil::CIEquals(option.name, "ESCAPE")) { + } else if (option.name == "HEADER" || option.name == "ESCAPE") { if (option.children.empty()) { - info->parsed_options[option.name] = nullptr; + info->parsed_options[option.name.GetIdentifierName()] = nullptr; } else { - info->parsed_options[option.name] = make_uniq(option.children[0]); + info->parsed_options[option.name.GetIdentifierName()] = + make_uniq(option.children[0]); } - } else if (StringUtil::CIEquals(option.name, "NULL") || StringUtil::CIEquals(option.name, "NULLSTR")) { + } else if (option.name == "NULL" || option.name == "NULLSTR") { if (option.children.empty()) { - info->parsed_options[option.name] = std::move(option.expression); + info->parsed_options[option.name.GetIdentifierName()] = std::move(option.expression); } else { - info->parsed_options[option.name] = make_uniq(option.children[0]); + 
info->parsed_options[option.name.GetIdentifierName()] = + make_uniq(option.children[0]); } } else { if (option.expression) { - info->parsed_options[option.name] = std::move(option.expression); + info->parsed_options[option.name.GetIdentifierName()] = std::move(option.expression); } else { - info->options[option.name] = option.children; + info->options[option.name.GetIdentifierName()] = option.children; } } } @@ -68,60 +68,45 @@ void SetCopyOptions(unique_ptr &info, vector &optio } } -unique_ptr PEGTransformerFactory::TransformCopySelect(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &select_parens = ExtractResultFromParens(list_pr.Child(0)); - auto select_statement = transformer.Transform>(select_parens); +unique_ptr PEGTransformerFactory::TransformCopySelect( + PEGTransformer &transformer, unique_ptr select_statement_internal, + unique_ptr copy_file_name, const vector ©_options) { auto result = make_uniq(); auto info = make_uniq(); info->is_from = false; - auto file_name = transformer.Transform>(list_pr.Child(2)); - if (file_name->GetExpressionClass() == ExpressionClass::CONSTANT) { - auto &const_expr = file_name->Cast(); + if (copy_file_name->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = copy_file_name->Cast(); info->file_path = const_expr.GetValue().GetValue(); } else { - info->file_path_expression = std::move(file_name); + info->file_path_expression = std::move(copy_file_name); } - auto &options_opt = list_pr.Child(3); - if (options_opt.HasResult()) { - auto options = transformer.Transform>(options_opt.GetResult()); - SetCopyOptions(info, options); - } - info->select_statement = std::move(select_statement->node); + auto options = copy_options; + SetCopyOptions(info, options); + info->select_statement = std::move(select_statement_internal->node); result->info = std::move(info); return std::move(result); } -unique_ptr 
PEGTransformerFactory::TransformCopyFromDatabase(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - - auto from_database = transformer.Transform(list_pr.Child(2)); - auto to_database = transformer.Transform(list_pr.Child(4)); +unique_ptr +PEGTransformerFactory::TransformCopyFromDatabaseWithFlag(PEGTransformer &transformer, const Identifier &col_id, + const Identifier &col_id_1, + const CopyDatabaseType ©_database_flag) { + return make_uniq(Identifier(col_id), Identifier(col_id_1), copy_database_flag); +} - auto ©_database_flag = list_pr.Child(5); - if (copy_database_flag.HasResult()) { - auto copy_type = transformer.Transform(copy_database_flag.GetResult()); - return make_uniq(from_database, to_database, copy_type); - } +unique_ptr PEGTransformerFactory::TransformCopyFromDatabaseWithoutFlag(PEGTransformer &transformer, + const Identifier &col_id, + const Identifier &col_id_1) { auto result = make_uniq(); result->info->name = "copy_database"; - result->info->parameters.emplace_back(make_uniq(Value(from_database))); - result->info->parameters.emplace_back(make_uniq(Value(to_database))); + result->info->parameters.emplace_back(make_uniq(Value(col_id))); + result->info->parameters.emplace_back(make_uniq(Value(col_id_1))); return std::move(result); } CopyDatabaseType PEGTransformerFactory::TransformCopyDatabaseFlag(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - return transformer.Transform(extract_parens); -} - -CopyDatabaseType PEGTransformerFactory::TransformSchemaOrData(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); + const CopyDatabaseType &schema_or_data) { + return schema_or_data; } string PEGTransformerFactory::ExtractFormat(const string &file_path) { @@ -140,150 +125,110 @@ 
string PEGTransformerFactory::ExtractFormat(const string &file_path) { } unique_ptr PEGTransformerFactory::TransformCopyTable(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - + unique_ptr base_table_name, + const vector &insert_column_list, + const bool &from_or_to, + unique_ptr copy_file_name, + const vector ©_options) { auto result = make_uniq(); auto info = make_uniq(); - auto base_table = transformer.Transform>(list_pr.Child(0)); - info->table = base_table->table_name; - info->schema = base_table->schema_name; - info->catalog = base_table->catalog_name; - auto &insert_column_list = list_pr.Child(1); - if (insert_column_list.HasResult()) { - info->select_list = transformer.Transform>(insert_column_list.GetResult()); - } - info->is_from = transformer.Transform(list_pr.Child(2)); - auto file_name = transformer.Transform>(list_pr.Child(3)); - if (file_name->GetExpressionClass() == ExpressionClass::CONSTANT) { - auto &const_expr = file_name->Cast(); + info->table = base_table_name->table_name; + info->schema = base_table_name->schema_name; + info->catalog = base_table_name->catalog_name; + info->select_list = StringsToIdentifiers(insert_column_list); + info->is_from = from_or_to; + if (copy_file_name->GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = copy_file_name->Cast(); info->file_path = const_expr.GetValue().GetValue(); } else { - info->file_path_expression = std::move(file_name); + info->file_path_expression = std::move(copy_file_name); } info->format = ExtractFormat(info->file_path); - auto ©_options_pr = list_pr.Child(4); - if (copy_options_pr.HasResult()) { - auto generic_options = transformer.Transform>(copy_options_pr.GetResult()); - SetCopyOptions(info, generic_options); - } + auto generic_options = copy_options; + SetCopyOptions(info, generic_options); result->info = std::move(info); return std::move(result); } -bool PEGTransformerFactory::TransformFromOrTo(PEGTransformer 
&transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &from_or_to = list_pr.Child(0).GetResult(); - auto &keyword = from_or_to.Cast(); - return StringUtil::CIEquals(keyword.keyword, "from"); +bool PEGTransformerFactory::TransformCopyFrom(PEGTransformer &transformer) { + return true; } -unique_ptr PEGTransformerFactory::TransformCopyFileName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (choice_pr.name == "ParensExpression" || choice_pr.name == "Parameter") { - return transformer.Transform>(choice_pr); - } - string file_name; - if (choice_pr.type == ParseResultType::IDENTIFIER) { - file_name = choice_pr.Cast().identifier; - if (StringUtil::CIEquals(file_name, "stdout")) { - file_name = "/dev/stdout"; - } - } else { - file_name = transformer.Transform(choice_pr); - } - return make_uniq(Value(file_name)); +bool PEGTransformerFactory::TransformCopyTo(PEGTransformer &transformer) { + return false; } -string PEGTransformerFactory::TransformIdentifierColId(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - string result; - result += list_pr.Child(0).name; - result += "."; - result += transformer.Transform(list_pr.Child(2)); - return result; +unique_ptr PEGTransformerFactory::TransformCopyFileNameStringLiteral(PEGTransformer &transformer, + const string &string_literal) { + return make_uniq(Value(string_literal)); } -vector PEGTransformerFactory::TransformCopyOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // CopyOptions <- 'WITH'? GenericCopyOptionList / SpecializedOptions - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(1).GetResult()); +unique_ptr PEGTransformerFactory::TransformCopyFileNameIdentifier(PEGTransformer &transformer, + const Identifier &identifier) { + string file_name = identifier == "stdout" ? 
"/dev/stdout" : identifier.GetIdentifierName(); + return make_uniq(Value(file_name)); } -vector PEGTransformerFactory::TransformSpecializedOptionList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &options_opt = list_pr.Child(0); - if (!options_opt.HasResult()) { - return {}; - } - auto &options = options_opt.GetResult().Cast(); - vector result; - for (auto option : options.GetChildren()) { - result.push_back(transformer.Transform(option)); - } - return result; +unique_ptr +PEGTransformerFactory::TransformCopyFileNameIdentifierColId(PEGTransformer &transformer, + const Identifier &identifier_col_id) { + return make_uniq(Value(identifier_col_id)); +} + +Identifier PEGTransformerFactory::TransformIdentifierColId(PEGTransformer &transformer, const Identifier &identifier, + const Identifier &col_id) { + string result; + result += identifier.GetIdentifierName(); + result += "."; + result += col_id.GetIdentifierName(); + return Identifier(result); } -GenericCopyOption PEGTransformerFactory::TransformSpecializedOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); +vector +PEGTransformerFactory::TransformCopyOptions(PEGTransformer &transformer, + const vector ©_option_list) { + return copy_option_list; } -GenericCopyOption PEGTransformerFactory::TransformSingleOption(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +vector +PEGTransformerFactory::TransformSpecializedOptionList(PEGTransformer &transformer, + const vector &specialized_option) { + return specialized_option; } GenericCopyOption PEGTransformerFactory::TransformEncodingOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto string_literal = list_pr.Child(1).result; + const 
string &string_literal) { return GenericCopyOption("encoding", string_literal); } -GenericCopyOption PEGTransformerFactory::TransformForceQuoteOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - bool force_quote = list_pr.Child(0).HasResult(); +GenericCopyOption PEGTransformerFactory::TransformForceQuoteOption(PEGTransformer &transformer, const bool &force_quote, + const vector &star_symbol_column_list) { string func_name = force_quote ? "force_quote" : "quote"; - auto &star_or_column_list_pr = list_pr.Child(2); - auto &star_or_column_list = star_or_column_list_pr.Child(0).GetResult(); auto result = GenericCopyOption(); - result.name = func_name; - if (StringUtil::CIEquals(star_or_column_list.name, "StarSymbol")) { + result.name = Identifier(func_name); + if (star_symbol_column_list.empty()) { result.expression = make_uniq(); - } else if (StringUtil::CIEquals(star_or_column_list.name, "ColumnList")) { - auto column_list = transformer.Transform>(star_or_column_list); - for (auto &col : column_list) { - result.children.push_back(Value(col)); - } + return result; + } + for (auto &col : star_symbol_column_list) { + result.children.push_back(Value(col)); } - return result; } GenericCopyOption PEGTransformerFactory::TransformQuoteAsOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto string_literal = list_pr.Child(2).result; + const string &string_literal) { return GenericCopyOption("quote", string_literal); } GenericCopyOption PEGTransformerFactory::TransformForceNullOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - bool is_not = list_pr.Child(1).HasResult(); + const bool &force_not_null, + const vector &column_list) { auto result = GenericCopyOption(); - result.name = is_not ? 
"force_not_null" : "force_null"; - auto column_list = transformer.Transform>(list_pr.Child(3)); + result.name = force_not_null ? "force_not_null" : "force_null"; for (auto &col : column_list) { result.children.push_back(Value(col)); } @@ -291,41 +236,68 @@ GenericCopyOption PEGTransformerFactory::TransformForceNullOption(PEGTransformer } GenericCopyOption PEGTransformerFactory::TransformPartitionByOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const vector &star_symbol_column_list) { auto result = GenericCopyOption(); - auto &star_or_column_list_pr = list_pr.Child(2); - auto &star_or_column_list = star_or_column_list_pr.Child(0).GetResult(); result.name = "partition_by"; - if (StringUtil::CIEquals(star_or_column_list.name, "StarSymbol")) { + if (star_symbol_column_list.empty()) { result.expression = make_uniq(); - } else if (StringUtil::CIEquals(star_or_column_list.name, "ColumnList")) { - auto column_list = transformer.Transform>(star_or_column_list); - for (auto &col : column_list) { - result.children.push_back(Value(col)); - } + return result; + } + for (auto &col : star_symbol_column_list) { + result.children.push_back(Value(col)); } return result; } -GenericCopyOption PEGTransformerFactory::TransformNullAsOption(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto string_literal = list_pr.Child(2).result; +GenericCopyOption PEGTransformerFactory::TransformNullAsOption(PEGTransformer &transformer, + const string &string_literal) { return GenericCopyOption("null", string_literal); } GenericCopyOption PEGTransformerFactory::TransformDelimiterAsOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto string_literal = list_pr.Child(2).result; + const string &string_literal) { return GenericCopyOption("delimiter", string_literal); } GenericCopyOption 
PEGTransformerFactory::TransformEscapeAsOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto string_literal = list_pr.Child(2).result; + const string &string_literal) { return GenericCopyOption("escape", string_literal); } +GenericCopyOption PEGTransformerFactory::TransformBinaryOption(PEGTransformer &transformer) { + return GenericCopyOption("format", Value("binary")); +} + +GenericCopyOption PEGTransformerFactory::TransformFreezeOption(PEGTransformer &transformer) { + return GenericCopyOption("freeze", Value()); +} + +GenericCopyOption PEGTransformerFactory::TransformOidsOption(PEGTransformer &transformer) { + return GenericCopyOption("oids", Value()); +} + +GenericCopyOption PEGTransformerFactory::TransformCsvOption(PEGTransformer &transformer) { + return GenericCopyOption("format", Value("csv")); +} + +GenericCopyOption PEGTransformerFactory::TransformHeaderOption(PEGTransformer &transformer) { + return GenericCopyOption("header", Value(true)); +} + +CopyDatabaseType PEGTransformerFactory::TransformCopySchema(PEGTransformer &transformer) { + return CopyDatabaseType::COPY_SCHEMA; +} + +CopyDatabaseType PEGTransformerFactory::TransformCopyData(PEGTransformer &transformer) { + return CopyDatabaseType::COPY_DATA; +} + +bool PEGTransformerFactory::TransformForceQuote(PEGTransformer &transformer) { + return true; +} + +bool PEGTransformerFactory::TransformForceNotNull(PEGTransformer &transformer) { + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp index b16b399de..f4a6d23bf 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp @@ -3,165 +3,144 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformCreateIndexStmt(PEGTransformer &transformer, - ParseResult &parse_result) 
{ - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformCreateIndexStmt( + PEGTransformer &transformer, const bool &unique_index, const bool &if_not_exists, const Identifier &index_name, + unique_ptr base_table_name, const vector &insert_column_list, const Identifier &index_type, + vector> index_element, case_insensitive_map_t> with_list, + unique_ptr where_clause) { auto result = make_uniq(); auto index_info = make_uniq(); - bool unique = list_pr.Child(0).HasResult(); - index_info->constraint_type = unique ? IndexConstraintType::UNIQUE : IndexConstraintType::NONE; - bool if_not_exists = list_pr.Child(2).HasResult(); + index_info->constraint_type = unique_index ? IndexConstraintType::UNIQUE : IndexConstraintType::NONE; index_info->on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - auto &index_name_opt = list_pr.Child(3); - if (index_name_opt.HasResult()) { - index_info->index_name = index_name_opt.GetResult().Cast().identifier; - } else { + if (index_name.empty()) { throw NotImplementedException("Please provide an index name, e.g., CREATE INDEX my_name ..."); } - auto table = transformer.Transform>(list_pr.Child(5)); - index_info->table = table->table_name; - index_info->catalog = table->catalog_name; - index_info->schema = table->schema_name; - index_info->index_type = "ART"; - auto &column_list_opt = list_pr.Child(6); - if (column_list_opt.HasResult()) { - auto column_list = transformer.Transform>(column_list_opt.GetResult()); - for (auto &column : column_list) { - index_info->expressions.push_back(make_uniq(column, table->table_name)); - index_info->parsed_expressions.push_back(make_uniq(column, table->table_name)); - } - } - transformer.TransformOptional(list_pr, 7, index_info->index_type); - auto &index_elements_opt = list_pr.Child(8); - if (index_elements_opt.HasResult()) { - auto &extract_parens = ExtractResultFromParens(index_elements_opt.GetResult()); - auto 
index_element_list = ExtractParseResultsFromList(extract_parens); - for (auto index_element : index_element_list) { - auto expr = transformer.Transform>(index_element); - if (expr->GetExpressionType() == ExpressionType::COLLATE) { - throw NotImplementedException("Index with collation not supported yet!"); - } - index_info->expressions.push_back(expr->Copy()); - index_info->parsed_expressions.push_back(expr->Copy()); - } + index_info->index_name = index_name; + index_info->table = base_table_name->table_name; + index_info->catalog = base_table_name->catalog_name; + index_info->schema = base_table_name->schema_name; + index_info->index_type = index_type.empty() ? "ART" : index_type.GetIdentifierName(); + for (auto &column : insert_column_list) { + index_info->expressions.push_back( + make_uniq(Identifier(column), base_table_name->table_name)); + index_info->parsed_expressions.push_back( + make_uniq(Identifier(column), base_table_name->table_name)); } - - auto &with_list_opt = list_pr.Child(9); - if (with_list_opt.HasResult()) { - auto options_expr = - transformer.Transform>>(with_list_opt.GetResult()); - for (auto &option_entry : options_expr) { - if (option_entry.second->GetExpressionClass() != ExpressionClass::CONSTANT) { - throw InvalidInputException("Create index option must be a constant value"); - } - index_info->options[option_entry.first] = option_entry.second->Cast().GetValue(); + for (auto &expr : index_element) { + if (expr->GetExpressionType() == ExpressionType::COLLATE) { + throw NotImplementedException("Index with collation not supported yet!"); } + index_info->expressions.push_back(expr->Copy()); + index_info->parsed_expressions.push_back(std::move(expr)); } - auto &where_opt = list_pr.Child(10); - if (where_opt.HasResult()) { + if (where_clause) { throw NotImplementedException("Creating partial indexes is not supported currently"); } + for (auto &option_entry : with_list) { + if (option_entry.second->GetExpressionClass() != ExpressionClass::CONSTANT) 
{ + throw InvalidInputException("Create index option must be a constant value"); + } + index_info->options[option_entry.first] = option_entry.second->Cast().GetValue(); + } result->info = std::move(index_info); return result; } -string PEGTransformerFactory::TransformIndexType(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(1).identifier; +string PEGTransformerFactory::TransformDottedIdentifierString(PEGTransformer &transformer, + const vector &dotted_identifier) { + return StringUtil::Join(dotted_identifier, "."); +} + +Identifier PEGTransformerFactory::TransformIndexType(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; } unique_ptr PEGTransformerFactory::TransformIndexElement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // TODO(Dtenwolde): We currently ignore DescOrAsc? and NullsFirstOrLast? - return transformer.Transform>(list_pr.Child(0)); + unique_ptr expression, + const OrderType &desc_or_asc, + const OrderByNullType &nulls_first_or_last) { + // TODO(Dtenwolde): We currently ignore desc_or_asc and nulls_first_or_last + return expression; } -case_insensitive_map_t> -PEGTransformerFactory::TransformWithList(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>( - list_pr.Child(1)); +bool PEGTransformerFactory::TransformUniqueIndex(PEGTransformer &transformer) { + return true; } case_insensitive_map_t> -PEGTransformerFactory::TransformRelOptionOrOids(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>( - list_pr.Child(0).GetResult()); +PEGTransformerFactory::TransformWithList(PEGTransformer &transformer, + case_insensitive_map_t> rel_option_or_oids) { + return rel_option_or_oids; } case_insensitive_map_t> 
-PEGTransformerFactory::TransformRelOptionList(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +PEGTransformerFactory::TransformRelOptionList(PEGTransformer &transformer, + vector>> rel_option) { case_insensitive_map_t> result; - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto rel_option_list = ExtractParseResultsFromList(extract_parens); - for (auto rel_option : rel_option_list) { - auto option = transformer.Transform>>(rel_option); - result.insert({option.first, std::move(option.second)}); + for (auto &option : rel_option) { + result.insert({option.first.GetIdentifierName(), std::move(option.second)}); } return result; } -case_insensitive_map_t> PEGTransformerFactory::TransformOids(PEGTransformer &transformer, - ParseResult &parse_result) { - throw NotImplementedException("Oids for index are not yet implemented."); +// Oids <- WithOrWithoutOids 'OIDS' +case_insensitive_map_t> +PEGTransformerFactory::TransformOids(PEGTransformer &transformer, const bool &with_or_without_oids) { + throw NotImplementedException("OIDS for index are not yet implemented."); } -string PEGTransformerFactory::TransformRelOptionName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // RelOptionName <- DottedIdentifier / StringLiteral - auto &choice = list_pr.Child(0).GetResult(); - if (StringUtil::CIEquals(choice.name, "DottedIdentifier")) { - auto dotted_identifier = transformer.Transform>(choice); - return StringUtil::Join(dotted_identifier, "."); - } else { - return choice.Cast().GetRawString(); - } +// WithOids <- 'WITH' +bool PEGTransformerFactory::TransformWithOids(PEGTransformer &transformer) { + return true; } -unique_ptr PEGTransformerFactory::TransformRelOptionArgumentOpt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // RelOptionArgumentOpt <- '=' DefArg - // child 0: '=' keyword, child 1: DefArg - // 
DefArg <- ReservedKeyword / StringLiteral / NumberLiteral / NoneLiteral / Expression - auto &defarg_list = list_pr.Child(1); - auto &def_arg_choice = defarg_list.Child(0).GetResult(); - if (def_arg_choice.name == "ReservedKeyword") { - auto &rw_list = def_arg_choice.Cast(); - auto keyword = rw_list.Child(0).GetResult().Cast().keyword; - return make_uniq(Value(keyword)); - } else if (def_arg_choice.name == "StringLiteral") { - return make_uniq(Value(transformer.Transform(def_arg_choice))); - } else if (def_arg_choice.name == "NoneLiteral" || def_arg_choice.name == "NullLiteral") { - return make_uniq(Value()); - } else if (def_arg_choice.name == "NumberLiteral" || def_arg_choice.name == "Expression") { - return transformer.Transform>(def_arg_choice); - } else { - throw ParserException("Unexpected rule encountered in TransformRelOptionArgumentOpt: %s", def_arg_choice.name); - } +// WithoutOids <- 'WITHOUT' +bool PEGTransformerFactory::TransformWithoutOids(PEGTransformer &transformer) { + return false; +} + +Identifier PEGTransformerFactory::TransformRelOptionName(PEGTransformer &transformer, const string &child) { + return Identifier(child); } -pair> PEGTransformerFactory::TransformRelOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // RelOption <- RelOptionName RelOptionArgumentOpt? - // child 0: RelOptionName, child 1: RelOptionArgumentOpt? 
- auto option_name = transformer.Transform(list_pr.Child(0)); - auto &arg_opt = list_pr.Child(1); - if (!arg_opt.HasResult()) { - return {option_name, make_uniq(Value())}; +pair> +PEGTransformerFactory::TransformRelOption(PEGTransformer &transformer, const Identifier &rel_option_name, + unique_ptr rel_option_argument_opt) { + if (!rel_option_argument_opt) { + return {rel_option_name, make_uniq(Value())}; } - auto value = transformer.Transform>(arg_opt.GetResult()); - return {option_name, std::move(value)}; + return {rel_option_name, std::move(rel_option_argument_opt)}; +} + +// RelOptionArgumentOpt <- '=' DefArg +unique_ptr +PEGTransformerFactory::TransformRelOptionArgumentOpt(PEGTransformer &transformer, + unique_ptr def_arg) { + return def_arg; +} + +// DefArgNull <- NullLiteral +unique_ptr PEGTransformerFactory::TransformDefArgNull(PEGTransformer &transformer, + const Value &null_literal) { + return make_uniq(Value()); +} + +// DefArgKeyword <- ReservedKeyword +unique_ptr PEGTransformerFactory::TransformDefArgKeyword(PEGTransformer &transformer, + const string &reserved_keyword) { + return make_uniq(Value(reserved_keyword)); +} + +// DefArgStringLiteral <- StringLiteral +unique_ptr PEGTransformerFactory::TransformDefArgStringLiteral(PEGTransformer &transformer, + const string &string_literal) { + return make_uniq(Value(string_literal)); } -string PEGTransformerFactory::TransformIndexName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; +// NoneLiteral <- 'NONE' +unique_ptr PEGTransformerFactory::TransformNoneLiteral(PEGTransformer &transformer) { + return make_uniq(Value()); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp index 7ac803b42..1f2e991ca 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp +++ 
b/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp @@ -5,14 +5,13 @@ #include "duckdb/function/scalar_macro_function.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::TransformCreateMacroStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformCreateMacroStmt(PEGTransformer &transformer, const bool ¯o_or_function, + const bool &if_not_exists, const QualifiedName &qualified_name, + vector> macro_definition) { auto result = make_uniq(); auto info = make_uniq(CatalogType::MACRO_ENTRY); - auto if_not_exists = list_pr.Child(1).HasResult(); - auto qualified_name = transformer.Transform(list_pr.Child(2)); if (qualified_name.schema.empty()) { info->schema = qualified_name.catalog; } else { @@ -20,17 +19,15 @@ unique_ptr PEGTransformerFactory::TransformCreateMacroStmt(PEGT info->schema = qualified_name.schema; } info->name = qualified_name.name; - auto macro_definition_list = ExtractParseResultsFromList(list_pr.Child(3)); info->on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - vector> macro_functions; - for (auto macro_definition : macro_definition_list) { - info->macros.push_back(transformer.Transform>(macro_definition)); + for (auto ¯o_function : macro_definition) { + info->macros.push_back(std::move(macro_function)); } D_ASSERT(!info->macros.empty()); auto macro_type = info->macros[0]->type; if (info->macros.size() > 1) { - for (idx_t i = 1; i < macro_definition_list.size(); ++i) { + for (idx_t i = 1; i < info->macros.size(); ++i) { if (info->macros[i]->type != macro_type) { throw ParserException("Cannot mix table and scalar macro function definitions"); } @@ -42,84 +39,73 @@ unique_ptr PEGTransformerFactory::TransformCreateMacroStmt(PEGT return result; } -unique_ptr PEGTransformerFactory::TransformMacroDefinition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &nested_list = list_pr.Child(2); +bool PEGTransformerFactory::TransformMacroKeyword(PEGTransformer &transformer) { + return true; +} - auto macro_function = - transformer.Transform>(nested_list.Child(0).GetResult()); - auto ¶meters_pr = ExtractResultFromParens(list_pr.Child(0)).Cast(); - if (parameters_pr.HasResult()) { - bool default_value_found = false; - auto parameters = transformer.Transform>(parameters_pr.GetResult()); - case_insensitive_string_set_t parameter_names; - for (auto ¶meter : parameters) { - D_ASSERT(!parameter.name.empty()); - if (parameter_names.find(parameter.name) != parameter_names.end()) { - throw ParserException("Duplicate parameter '%s' in macro definition", parameter.name); - } - parameter_names.insert(parameter.name); - if (parameter.is_default) { - auto default_expr = std::move(parameter.expression); - default_expr->SetAlias(parameter.name); - macro_function->default_parameters[parameter.name] = std::move(default_expr); - macro_function->parameters.push_back(make_uniq(parameter.name)); - default_value_found = true; 
- } else { - if (default_value_found) { - throw ParserException("Parameter without a default follows parameter with a default"); - } - macro_function->parameters.push_back(std::move(parameter.expression)); +bool PEGTransformerFactory::TransformFunctionKeyword(PEGTransformer &transformer) { + return false; +} + +unique_ptr +PEGTransformerFactory::TransformMacroDefinition(PEGTransformer &transformer, vector macro_parameters, + unique_ptr macro_definition_body) { + bool default_value_found = false; + identifier_set_t parameter_names; + for (auto ¶meter : macro_parameters) { + D_ASSERT(!parameter.name.empty()); + if (parameter_names.find(parameter.name) != parameter_names.end()) { + throw ParserException("Duplicate parameter '%s' in macro definition", parameter.name.GetIdentifierName()); + } + parameter_names.insert(parameter.name); + if (parameter.is_default) { + auto default_expr = std::move(parameter.expression); + default_expr->SetAlias(parameter.name); + macro_definition_body->default_parameters[parameter.name] = std::move(default_expr); + macro_definition_body->parameters.push_back(make_uniq(parameter.name)); + default_value_found = true; + } else { + if (default_value_found) { + throw ParserException("Parameter without a default follows parameter with a default"); } - macro_function->types.push_back(parameter.type); + macro_definition_body->parameters.push_back(std::move(parameter.expression)); } + macro_definition_body->types.push_back(parameter.type); } - return macro_function; + return macro_definition_body; } -unique_ptr PEGTransformerFactory::TransformTableMacroDefinition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformTableMacroDefinition(PEGTransformer &transformer, + unique_ptr select_statement_internal) { auto result = make_uniq(); - auto select_statement = transformer.Transform>(list_pr.Child(1)); - result->query_node = std::move(select_statement->node); + 
result->query_node = std::move(select_statement_internal->node); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformScalarMacroDefinition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformScalarMacroDefinition(PEGTransformer &transformer, + unique_ptr expression) { auto result = make_uniq(); - result->expression = transformer.Transform>(list_pr.Child(0)); + result->expression = std::move(expression); return std::move(result); } vector PEGTransformerFactory::TransformMacroParameters(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto parameter_list = ExtractParseResultsFromList(list_pr.Child(0)); - vector parameters; - for (auto parameter : parameter_list) { - parameters.push_back(transformer.Transform(parameter)); - } - return parameters; -} - -MacroParameter PEGTransformerFactory::TransformMacroParameter(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - return transformer.Transform(choice_pr); + vector macro_parameter) { + return macro_parameter; } -MacroParameter PEGTransformerFactory::TransformSimpleParameter(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto parameter = transformer.Transform(list_pr.Child(0)); +MacroParameter PEGTransformerFactory::TransformSimpleParameter(PEGTransformer &transformer, + const Identifier &type_func_name, + const LogicalType &type) { MacroParameter result; - result.name = parameter; - result.expression = make_uniq(parameter); - transformer.TransformOptional(list_pr, 1, result.type); + result.name = Identifier(type_func_name); + result.expression = make_uniq(Identifier(type_func_name)); + if (type.id() != LogicalTypeId::INVALID) { + result.type = type; + } result.is_default = false; return result; } diff --git 
a/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp index d06961c8c..7063518e3 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp @@ -2,12 +2,9 @@ #include "duckdb/parser/peg/transformer/peg_transformer.hpp" namespace duckdb { - unique_ptr PEGTransformerFactory::TransformCreateSchemaStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_not_exists = list_pr.Child(1).HasResult(); - auto qualified_name = transformer.Transform(list_pr.Child(2)); + const bool &if_not_exists, + const QualifiedName &qualified_name) { if (!qualified_name.catalog.empty()) { throw ParserException("CREATE SCHEMA too many dots: expected \"catalog.schema\" or \"schema\""); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp index ca0673f57..2e4da4ddd 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp @@ -13,24 +13,20 @@ Value PEGTransformerFactory::GetConstantExpressionValue(unique_ptr PEGTransformerFactory::TransformCreateSecretStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformCreateSecretStmt( + PEGTransformer &transformer, const bool &if_not_exists, const Identifier &secret_name, + const Identifier &secret_storage_specifier, const vector &generic_copy_option_list) { auto result = make_uniq(); - auto if_not_exists = list_pr.Child(1).HasResult(); auto on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; auto info = make_uniq(on_conflict, SecretPersistType::DEFAULT); - auto &secret_name_pr = list_pr.Child(2); - if (secret_name_pr.HasResult()) { - info->name = transformer.Transform(secret_name_pr.GetResult()); + if (!secret_name.empty()) { + info->name = secret_name; } - auto &secret_storage_specifier_pr = list_pr.Child(3); - if (secret_storage_specifier_pr.HasResult()) { - info->storage_type = StringUtil::Lower(transformer.Transform(secret_storage_specifier_pr.GetResult())); + if (!secret_storage_specifier.empty()) { + info->storage_type = Identifier(StringUtil::Lower(secret_storage_specifier.GetIdentifierName())); } - auto option_list = transformer.Transform>(list_pr.Child(4)); - for (auto option : option_list) { - auto lower_name = StringUtil::Lower(option.name); + for (const auto &option : generic_copy_option_list) { + auto lower_name = StringUtil::Lower(option.name.GetIdentifierName()); if (lower_name == "scope") { info->scope = option.GetFirstChildOrExpression(); continue; @@ -58,20 +54,19 @@ unique_ptr PEGTransformerFactory::TransformCreateSecretStmt(PEG "Can not combine a non-constant expression for the secret type with a default-named secret. 
Either " "provide an explicit secret name or use a constant expression for the secret type."); } - info->name = "__default_" + StringUtil::Lower(value.ToString()); + info->name = Identifier("__default_" + StringUtil::Lower(value.ToString())); } result->info = std::move(info); return result; } -string PEGTransformerFactory::TransformSecretName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0)); +Identifier PEGTransformerFactory::TransformSecretStorageSpecifier(PEGTransformer &transformer, + const Identifier &identifier) { + return identifier; } -string PEGTransformerFactory::TransformSecretStorageSpecifier(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(1).identifier; +Identifier PEGTransformerFactory::TransformSecretName(PEGTransformer &transformer, const Identifier &col_id) { + return Identifier(col_id); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp index b9c23aefe..a779e4b9b 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp @@ -4,30 +4,24 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformCreateSequenceStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_not_exists = list_pr.Child(1).HasResult(); - auto qualified_name = transformer.Transform(list_pr.Child(2)); +unique_ptr +PEGTransformerFactory::TransformCreateSequenceStmt(PEGTransformer &transformer, const bool &if_not_exists, + const QualifiedName &qualified_name, + vector>> sequence_option) { auto result = make_uniq(); auto info = make_uniq(); info->catalog = qualified_name.catalog; info->schema = qualified_name.schema; info->name = 
qualified_name.name; info->on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - auto &opt_sequence_options = list_pr.Child(3); case_insensitive_map_t> sequence_options; - if (opt_sequence_options.HasResult()) { - auto &repeat_seq_options = opt_sequence_options.GetResult().Cast(); - for (auto seq_option : repeat_seq_options.GetChildren()) { - auto seq_result = transformer.Transform>>(seq_option); - if (sequence_options.find(seq_result.first) != sequence_options.end()) { - auto seq_option_capital = StringUtil::Lower(seq_result.first); - seq_option_capital[0] = StringUtil::CharacterToUpper(seq_option_capital[0]); - throw ParserException("%s should be passed at most once", seq_option_capital); - } - sequence_options.insert(std::move(seq_result)); + for (auto &seq_option : sequence_option) { + if (sequence_options.find(seq_option.first) != sequence_options.end()) { + auto seq_option_capital = StringUtil::Lower(seq_option.first); + seq_option_capital[0] = StringUtil::CharacterToUpper(seq_option_capital[0]); + throw ParserException("%s should be passed at most once", seq_option_capital); } + sequence_options.insert(std::move(seq_option)); } bool no_min = false; bool no_max = false; @@ -117,103 +111,88 @@ unique_ptr PEGTransformerFactory::TransformCreateSequenceStmt(P return result; } -pair> PEGTransformerFactory::TransformSequenceOption(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>( - list_pr.Child(0).GetResult()); +pair> PEGTransformerFactory::TransformSeqCycle(PEGTransformer &transformer) { + return make_pair("cycle", make_uniq(SequenceInfo::SEQ_CYCLE, Value(true))); } -pair> PEGTransformerFactory::TransformSeqSetCycle(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - if (list_pr.Child(0).HasResult()) { - return make_pair("cycle", make_uniq(SequenceInfo::SEQ_CYCLE, Value(false))); 
- } - return make_pair("cycle", make_uniq(SequenceInfo::SEQ_CYCLE, Value(true))); +pair> PEGTransformerFactory::TransformSeqNoCycle(PEGTransformer &transformer) { + return make_pair("cycle", make_uniq(SequenceInfo::SEQ_CYCLE, Value(false))); } -pair> PEGTransformerFactory::TransformSeqSetIncrement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr = transformer.Transform>(list_pr.Child(2)); - if (expr->GetExpressionClass() == ExpressionClass::FUNCTION) { - auto func_expr = unique_ptr_cast(std::move(expr)); - if (func_expr->function_name != "-") { - throw InvalidInputException("Expected a minus function instead of %s", func_expr->function_name); +pair> +PEGTransformerFactory::TransformSeqSetIncrement(PEGTransformer &transformer, unique_ptr expression) { + if (expression->GetExpressionClass() == ExpressionClass::FUNCTION) { + auto func_expr = unique_ptr_cast(std::move(expression)); + if (func_expr->FunctionName() != "-") { + throw InvalidInputException("Expected a minus function instead of %s", func_expr->FunctionName()); } - D_ASSERT(!func_expr->children.empty()); - if (func_expr->children[0]->GetExpressionClass() != ExpressionClass::CONSTANT) { + D_ASSERT(!func_expr->GetArguments().empty()); + if (func_expr->GetArguments()[0].GetExpression().GetExpressionClass() != ExpressionClass::CONSTANT) { throw InvalidInputException("Expected constant expression as child of minus function"); } - const auto const_value = func_expr->children[0]->Cast().GetValue().GetValue(); - expr = make_uniq(Value::Numeric(LogicalType::BIGINT, -const_value)); + const auto const_value = + func_expr->GetArguments()[0].GetExpression().Cast().GetValue().GetValue(); + expression = make_uniq(Value::Numeric(LogicalType::BIGINT, -const_value)); } - if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { + if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected constant expression."); } - 
auto const_expr = expr->Cast(); + auto const_expr = expression->Cast(); return make_pair("increment", make_uniq(SequenceInfo::SEQ_INC, const_expr.GetValue())); } -pair> PEGTransformerFactory::TransformSeqSetMinMax(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr = transformer.Transform>(list_pr.Child(1)); - auto rule_name = transformer.Transform(list_pr.Child(0)); - - if (expr->GetExpressionClass() == ExpressionClass::FUNCTION) { - auto func_expr = unique_ptr_cast(std::move(expr)); - if (func_expr->function_name != "-") { - throw InvalidInputException("Expected a minus function instead of %s", func_expr->function_name); +pair> +PEGTransformerFactory::TransformSeqSetMinMax(PEGTransformer &transformer, const string &seq_min_or_max, + unique_ptr expression) { + if (expression->GetExpressionClass() == ExpressionClass::FUNCTION) { + auto func_expr = unique_ptr_cast(std::move(expression)); + if (func_expr->FunctionName() != "-") { + throw InvalidInputException("Expected a minus function instead of %s", func_expr->FunctionName()); } - D_ASSERT(!func_expr->children.empty()); - if (func_expr->children[0]->GetExpressionClass() != ExpressionClass::CONSTANT) { + D_ASSERT(!func_expr->GetArguments().empty()); + if (func_expr->GetArguments()[0].GetExpression().GetExpressionClass() != ExpressionClass::CONSTANT) { throw InvalidInputException("Expected constant expression as child of minus function"); } - const auto const_value = func_expr->children[0]->Cast().GetValue().GetValue(); - expr = make_uniq(Value::Numeric(LogicalType::BIGINT, -const_value)); + const auto const_value = + func_expr->GetArguments()[0].GetExpression().Cast().GetValue().GetValue(); + expression = make_uniq(Value::Numeric(LogicalType::BIGINT, -const_value)); } - if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { + if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected constant expression."); } 
- auto const_expr = expr->Cast(); - auto seq_info = rule_name == "minvalue" ? SequenceInfo::SEQ_MIN : SequenceInfo::SEQ_MAX; - return make_pair(rule_name, make_uniq(seq_info, const_expr.GetValue())); -} - -string PEGTransformerFactory::TransformSeqMinOrMax(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); + auto const_expr = expression->Cast(); + auto seq_info = seq_min_or_max == "minvalue" ? SequenceInfo::SEQ_MIN : SequenceInfo::SEQ_MAX; + return make_pair(seq_min_or_max, make_uniq(seq_info, const_expr.GetValue())); } pair> PEGTransformerFactory::TransformSeqNoMinMax(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto rule_name = transformer.Transform(list_pr.Child(1)); - auto seq_info = rule_name == "minvalue" ? SequenceInfo::SEQ_MIN : SequenceInfo::SEQ_MAX; - return make_pair("no" + rule_name, make_uniq(seq_info, true)); + const string &seq_min_or_max) { + auto seq_info = seq_min_or_max == "minvalue" ? 
SequenceInfo::SEQ_MIN : SequenceInfo::SEQ_MAX; + return make_pair("no" + seq_min_or_max, make_uniq(seq_info, true)); } -pair> PEGTransformerFactory::TransformSeqStartWith(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr = transformer.Transform>(list_pr.Child(2)); - - if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { +pair> +PEGTransformerFactory::TransformSeqStartWith(PEGTransformer &transformer, unique_ptr expression) { + if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected constant expression."); } - auto const_expr = expr->Cast(); + auto const_expr = expression->Cast(); return make_pair("start", make_uniq(SequenceInfo::SEQ_START, const_expr.GetValue())); } -pair> PEGTransformerFactory::TransformSeqOwnedBy(PEGTransformer &transformer, - ParseResult &parse_result) { +pair> +PEGTransformerFactory::TransformSeqOwnedBy(PEGTransformer &transformer, const QualifiedName &qualified_name) { // Unused by old transformer - auto &list_pr = parse_result.Cast(); - auto qualified_name = transformer.Transform(list_pr.Child(2)); return make_pair("owned", make_uniq(SequenceInfo::SEQ_OWN, qualified_name)); } +string PEGTransformerFactory::TransformMinValue(PEGTransformer &transformer) { + return "minvalue"; +} + +string PEGTransformerFactory::TransformMaxValue(PEGTransformer &transformer) { + return "maxvalue"; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp index 031f99324..0b3ac06fb 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp @@ -1,6 +1,8 @@ +#include "duckdb/parser/peg/ast/column_constraint_entry.hpp" #include "duckdb/parser/peg/ast/column_constraints.hpp" #include "duckdb/parser/peg/ast/column_elements.hpp" -#include 
"duckdb/parser/peg/ast/create_table_as.hpp" +#include "duckdb/parser/peg/ast/create_table_column_element.hpp" +#include "duckdb/parser/peg/ast/create_table_definition.hpp" #include "duckdb/parser/peg/ast/generated_column_definition.hpp" #include "duckdb/parser/peg/ast/key_actions.hpp" #include "duckdb/parser/peg/ast/partition_sorted_options.hpp" @@ -15,103 +17,79 @@ #include "duckdb/parser/constraints/not_null_constraint.hpp" #include "duckdb/parser/expression/cast_expression.hpp" #include "duckdb/parser/expression/type_expression.hpp" +#include "duckdb/catalog/default/default_types.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::TransformCreateStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - bool replace = list_pr.Child(1).HasResult(); - auto result = transformer.Transform>(list_pr.Child(3)); +unique_ptr +PEGTransformerFactory::TransformCreateStatement(PEGTransformer &transformer, const bool &or_replace, + const SecretPersistType &temporary, + unique_ptr create_statement_variation) { + auto result = std::move(create_statement_variation); auto &conflict_policy = result->info->on_conflict; - if (replace) { + if (or_replace) { if (conflict_policy == OnCreateConflict::IGNORE_ON_CONFLICT) { throw ParserException("Cannot specify both OR REPLACE and IF NOT EXISTS within single create statement"); } conflict_policy = OnCreateConflict::REPLACE_ON_CONFLICT; } - auto persistent_type = SecretPersistType::DEFAULT; - transformer.TransformOptional(list_pr, 2, persistent_type); if (result->info->type == CatalogType::SECRET_ENTRY) { auto &secret_info = result->info->Cast(); - secret_info.persist_type = persistent_type; + secret_info.persist_type = temporary; } - result->info->temporary = persistent_type == SecretPersistType::TEMPORARY; + result->info->temporary = temporary == SecretPersistType::TEMPORARY; return std::move(result); } -SecretPersistType PEGTransformerFactory::TransformTemporary(PEGTransformer 
&transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +SecretPersistType PEGTransformerFactory::TransformPersistent(PEGTransformer &transformer) { + return SecretPersistType::PERSISTENT; } -unique_ptr PEGTransformerFactory::TransformCreateStatementVariation(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - return transformer.Transform>(choice_pr.GetResult()); +SecretPersistType PEGTransformerFactory::TransformTempPersistent(PEGTransformer &transformer) { + return SecretPersistType::TEMPORARY; } -unique_ptr PEGTransformerFactory::TransformCreateTableStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +SecretPersistType PEGTransformerFactory::TransformTemporaryPersistent(PEGTransformer &transformer) { + return SecretPersistType::TEMPORARY; +} + +unique_ptr PEGTransformerFactory::TransformCreateTableStmt( + PEGTransformer &transformer, const bool &if_not_exists, const QualifiedName &qualified_name, + CreateTableDefinition create_table_definition, const bool &commit_action) { auto result = make_uniq(); - QualifiedName table_name = transformer.Transform(list_pr.Child(2)); - if (table_name.name.empty()) { + if (qualified_name.name.empty()) { throw ParserException("Empty table name not supported"); } // Use appropriate constructor - auto info = make_uniq(table_name.catalog, table_name.schema, table_name.name); + auto info = make_uniq(qualified_name.catalog, qualified_name.schema, qualified_name.name); - bool if_not_exists = list_pr.Child(1).HasResult(); info->on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - auto &table_as_or_column_list = list_pr.Child(3).Child(0); - if (table_as_or_column_list.name == "CreateTableAs") { - auto create_table_as = transformer.Transform(table_as_or_column_list.GetResult()); - info->query = std::move(create_table_as.select_statement); - info->columns = std::move(create_table_as.column_names); - info->partition_keys = std::move(create_table_as.partition_keys); - info->sort_keys = std::move(create_table_as.sort_keys); - info->options = std::move(create_table_as.options); - } else { - auto column_list = transformer.Transform(table_as_or_column_list.GetResult()); - info->columns = std::move(column_list.columns); - info->constraints = std::move(column_list.constraints); - info->partition_keys = std::move(column_list.partition_keys); - info->sort_keys = std::move(column_list.sort_keys); - info->options = std::move(column_list.options); - } - // On COMMIT is unused apart from checking whether it is ON COMMIT DELETE, which is unsupported - bool on_commit = false; - transformer.TransformOptional(list_pr, 4, on_commit); + info->query = std::move(create_table_definition.select_statement); + info->columns = std::move(create_table_definition.columns); + info->constraints = std::move(create_table_definition.constraints); + info->partition_keys = std::move(create_table_definition.partition_keys); + info->sort_keys = std::move(create_table_definition.sort_keys); + info->options = std::move(create_table_definition.options); result->info = std::move(info); return result; } -CreateTableAs PEGTransformerFactory::TransformCreateTableAs(PEGTransformer &transformer, ParseResult &parse_result) { - // CreateTableAs <- IdentifierList? PartitionSortedOptions? WithList? 'AS' SelectStatementInternal WithData? - // child 0: IdentifierList? - // child 1: PartitionSortedOptions? - // child 2: WithList? - // child 3: 'AS' - // child 4: Statement - // child 5: WithData? 
- auto &list_pr = parse_result.Cast(); - CreateTableAs result; - transformer.TransformOptional(list_pr, 0, result.column_names); - PartitionSortedOptions pso; - transformer.TransformOptional(list_pr, 1, pso); - result.partition_keys = std::move(pso.partition_keys); - result.sort_keys = std::move(pso.sort_keys); - transformer.TransformOptional>>(list_pr, 2, result.options); - auto stmt = transformer.Transform>(list_pr.Child(4)); - if (stmt->type != StatementType::SELECT_STATEMENT) { +CreateTableDefinition +PEGTransformerFactory::TransformCreateTableAs(PEGTransformer &transformer, ColumnList identifier_list, + PartitionSortedOptions partition_sorted_options, + case_insensitive_map_t> with_list, + unique_ptr statement, const bool &with_data) { + CreateTableDefinition result; + result.columns = std::move(identifier_list); + result.partition_keys = std::move(partition_sorted_options.partition_keys); + result.sort_keys = std::move(partition_sorted_options.sort_keys); + result.options = std::move(with_list); + if (statement->type != StatementType::SELECT_STATEMENT) { throw ParserException("CREATE TABLE AS requires a SELECT clause"); } - result.select_statement = unique_ptr_cast(std::move(stmt)); - transformer.TransformOptional(list_pr, 5, result.with_data); - if (result.with_data) { + result.select_statement = unique_ptr_cast(std::move(statement)); + if (with_data) { auto limit_modifier = make_uniq(); limit_modifier->limit = make_uniq(0); result.select_statement->node->modifiers.push_back(std::move(limit_modifier)); @@ -119,49 +97,47 @@ CreateTableAs PEGTransformerFactory::TransformCreateTableAs(PEGTransformer &tran return result; } -ColumnList PEGTransformerFactory::TransformIdentifierList(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto identifier_list = ExtractParseResultsFromList(extract_parens); +ColumnList 
PEGTransformerFactory::TransformIdentifierList(PEGTransformer &transformer, + const vector &identifier) { ColumnList result; - for (auto identifier : identifier_list) { - result.AddColumn( - ColumnDefinition(identifier.get().Cast().identifier, LogicalType::UNKNOWN)); + for (auto &name : identifier) { + result.AddColumn(ColumnDefinition(name, LogicalType::UNKNOWN)); } return result; } -ColumnElements PEGTransformerFactory::TransformCreateColumnList(PEGTransformer &transformer, - ParseResult &parse_result) { - // CreateColumnList <- Parens(CreateTableColumnList?) PartitionSortedOptions? WithList? - // child 0: Parens(CreateTableColumnList?) - // child 1: PartitionSortedOptions? - // child 2: WithList? - auto &list_pr = parse_result.Cast(); - auto &create_table_column_list = - ExtractResultFromParens(list_pr.Child(0)).Cast(); - if (!create_table_column_list.HasResult()) { +CreateTableDefinition +PEGTransformerFactory::TransformCreateColumnList(PEGTransformer &transformer, ColumnElements create_table_column_list, + PartitionSortedOptions partition_sorted_options, + case_insensitive_map_t> with_list) { + if (create_table_column_list.columns.empty()) { throw ParserException("Table must have at least one column!"); } - auto result = transformer.Transform(create_table_column_list.GetResult()); - PartitionSortedOptions pso; - transformer.TransformOptional(list_pr, 1, pso); - result.partition_keys = std::move(pso.partition_keys); - result.sort_keys = std::move(pso.sort_keys); - transformer.TransformOptional>>(list_pr, 2, result.options); + CreateTableDefinition result; + result.columns = std::move(create_table_column_list.columns); + result.constraints = std::move(create_table_column_list.constraints); + result.partition_keys = std::move(partition_sorted_options.partition_keys); + result.sort_keys = std::move(partition_sorted_options.sort_keys); + result.options = std::move(with_list); return result; } -ColumnElements 
PEGTransformerFactory::TransformCreateTableColumnList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto column_elements = ExtractParseResultsFromList(list_pr.Child(0)); +bool PEGTransformerFactory::TransformOrReplace(PEGTransformer &transformer) { + return true; +} + +bool PEGTransformerFactory::TransformIfNotExists(PEGTransformer &transformer) { + return true; +} + +ColumnElements +PEGTransformerFactory::TransformCreateTableColumnList(PEGTransformer &transformer, + vector create_table_column_element) { ColumnElements result; - for (idx_t col_idx = 0; col_idx < column_elements.size(); ++col_idx) { - auto &column_element_child = - column_elements[col_idx].get().Cast().Child(0).GetResult(); - if (column_element_child.name == "ColumnDefinition") { - auto column_result = transformer.Transform(column_element_child); + for (idx_t col_idx = 0; col_idx < create_table_column_element.size(); ++col_idx) { + auto &column_element = create_table_column_element[col_idx]; + if (column_element.column_definition) { + auto &column_result = *column_element.column_definition; for (auto &constraint : column_result.constraints) { result.constraints.push_back(std::move(constraint)); } @@ -174,38 +150,38 @@ ColumnElements PEGTransformerFactory::TransformCreateTableColumnList(PEGTransfor } } result.columns.AddColumn(std::move(column_result.column_definition)); - } else if (column_element_child.name == "TopLevelConstraint") { - result.constraints.push_back(transformer.Transform>(column_element_child)); } else { - throw NotImplementedException("Unknown column type encountered: %s", column_element_child.name); + result.constraints.push_back(std::move(column_element.constraint)); } } return result; } -// IdentifierOrStringLiteral <- Identifier / StringLiteral +CreateTableColumnElement +PEGTransformerFactory::TransformCreateTableColumnDefinition(PEGTransformer &transformer, + ConstraintColumnDefinition column_definition) { + 
CreateTableColumnElement result; + result.column_definition = make_uniq(std::move(column_definition)); + return result; +} + +CreateTableColumnElement +PEGTransformerFactory::TransformCreateTableConstraint(PEGTransformer &transformer, + unique_ptr top_level_constraint) { + CreateTableColumnElement result; + result.constraint = std::move(top_level_constraint); + return result; +} + QualifiedName PEGTransformerFactory::TransformIdentifierOrStringLiteral(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); + const string &child) { QualifiedName result; result.catalog = INVALID_CATALOG; result.schema = INVALID_SCHEMA; - if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { - result.name = choice_pr.GetResult().Cast().identifier; - } - if (choice_pr.GetResult().type == ParseResultType::STRING) { - result.name = choice_pr.GetResult().Cast().result; - } + result.name = Identifier(child); return result; } -string PEGTransformerFactory::TransformColIdOrString(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - return transformer.Transform(choice_pr.GetResult()); -} - string PEGTransformerFactory::TransformColLabelOrString(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto &choice_pr = list_pr.Child(0); @@ -215,367 +191,320 @@ string PEGTransformerFactory::TransformColLabelOrString(PEGTransformer &transfor return transformer.Transform(choice_pr.GetResult()); } -string PEGTransformerFactory::TransformColId(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { - return choice_pr.GetResult().Cast().identifier; +Identifier PEGTransformerFactory::TransformColIdOrString(PEGTransformer &transformer, ParseResult 
&choice_result) { + if (choice_result.type == ParseResultType::STRING) { + return Identifier(choice_result.Cast().result); } - return transformer.Transform(choice_pr.GetResult()); + return transformer.Transform(choice_result); } string PEGTransformerFactory::TransformIdentifier(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; + return list_pr.Child(0).identifier.GetIdentifierName(); } vector PEGTransformerFactory::TransformDottedIdentifier(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector parts; - - parts.push_back(list_pr.Child(0).identifier); - - auto &optional_elements = list_pr.Child(1); - if (optional_elements.HasResult()) { - auto &repeat_elements = optional_elements.GetResult().Cast(); - for (auto &child_ref : repeat_elements.GetChildren()) { - auto &sub_list = child_ref.get().Cast(); - parts.push_back(transformer.Transform(sub_list.GetChild(1))); - } - } + const Identifier &identifier, + const vector &dot_col_label) { + vector parts {identifier.GetIdentifierName()}; + parts.insert(parts.end(), dot_col_label.begin(), dot_col_label.end()); return parts; } -ConstraintColumnDefinition PEGTransformerFactory::TransformColumnDefinition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - - auto dotted_identifier = transformer.Transform>(list_pr.Child(0)); +ConstraintColumnDefinition +PEGTransformerFactory::TransformColumnDefinition(PEGTransformer &transformer, const vector &dotted_identifier, + const LogicalType &type, GeneratedColumnDefinition generated_column, + vector column_constraint) { auto qualified_name = StringToQualifiedName(dotted_identifier); - auto &type_opt = list_pr.Child(1); - auto &generated_opt = list_pr.Child(2); - LogicalType type = LogicalType::ANY; - if (!type_opt.HasResult() && !generated_opt.HasResult()) { + bool has_type = type != LogicalType::INVALID; 
+ bool has_generated = generated_column.expr != nullptr; + if (!has_type && !has_generated) { throw ParserException("Column %s must have a type or be defined as a GENERATED column.", qualified_name.ToString()); } - transformer.TransformOptional(list_pr, 1, type); - auto &constraints_opt = list_pr.Child(4); + auto column_type = has_type ? type : LogicalType::ANY; CompressionType compression_type = CompressionType::COMPRESSION_AUTO; - ColumnConstraint column_constraint; - if (constraints_opt.HasResult()) { - auto &constraints_repeat = constraints_opt.GetResult().Cast(); - for (auto &constraint_entry : constraints_repeat.GetChildren()) { - auto &constraint_list = constraint_entry.get().Cast(); - auto &constraint = constraint_list.Child(0).GetResult(); - if (constraint.name == "DefaultValue") { - if (column_constraint.default_value) { - throw ParserException("Cannot define a default value twice"); - } - column_constraint.default_value = transformer.Transform>(constraint); - } else if (constraint.name == "NotNullConstraint" || constraint.name == "UniqueConstraint" || - constraint.name == "PrimaryKeyConstraint") { - column_constraint.constraint_types.push_back( - transformer.Transform>(constraint)); - } else if (constraint.name == "ColumnCompression") { - compression_type = transformer.Transform(constraint); - if (compression_type == CompressionType::COMPRESSION_AUTO) { - throw ParserException( - "Unrecognized option for column compression, expected none, uncompressed, rle, " - "dictionary, pfor, bitpacking, fsst, chimp, patas, zstd, alp, alprd or roaring"); - } - } else if (constraint.name == "ForeignKeyConstraint") { - auto fk_constraint = transformer.Transform>(constraint); - fk_constraint->fk_columns.push_back(qualified_name.name); - column_constraint.constraints.push_back(std::move(fk_constraint)); - } else if (constraint.name == "ColumnCollation") { - if (generated_opt.HasResult()) { - throw ParserException("Collations are not supported on generated columns"); + 
ColumnConstraint accumulated_constraints; + for (auto &cc_entry : column_constraint) { + if (cc_entry.constraint_name == "DefaultValue") { + if (accumulated_constraints.default_value) { + throw ParserException("Cannot define a default value twice"); + } + accumulated_constraints.default_value = std::move(cc_entry.expression); + } else if (cc_entry.constraint_name == "NotNullConstraint" || cc_entry.constraint_name == "UniqueConstraint" || + cc_entry.constraint_name == "PrimaryKeyConstraint") { + accumulated_constraints.constraint_types.push_back(cc_entry.constraint_type_info); + } else if (cc_entry.constraint_name == "ColumnCompression") { + compression_type = cc_entry.compression_type; + if (compression_type == CompressionType::COMPRESSION_AUTO) { + throw ParserException("Unrecognized option for column compression, expected none, uncompressed, rle, " + "dictionary, pfor, bitpacking, fsst, chimp, patas, zstd, alp, alprd or roaring"); + } + } else if (cc_entry.constraint_name == "ForeignKeyConstraint") { + auto &fk_constraint = cc_entry.constraint->Cast(); + fk_constraint.fk_columns.push_back(qualified_name.name); + accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); + } else if (cc_entry.constraint_name == "ColumnCollation") { + if (has_generated) { + throw ParserException("Collations are not supported on generated columns"); + } + if (column_type.id() == LogicalTypeId::ANY) { + throw ParserException("Specify the VARCHAR type for column \"%s\" with collation.", + qualified_name.ToString()); + } else if (column_type.IsUnbound()) { + auto &expr = UnboundType::GetTypeExpression(column_type); + if (expr->GetExpressionClass() != ExpressionClass::TYPE) { + throw InternalException("Expected a type expression"); } - if (type.id() == LogicalTypeId::ANY) { - throw ParserException("Specify the VARCHAR type for column \"%s\" with collation.", - qualified_name.ToString()); - } else if (type.IsUnbound()) { - auto &expr = 
UnboundType::GetTypeExpression(type); - if (expr->GetExpressionClass() != ExpressionClass::TYPE) { - throw InternalException("Expected a type expression"); - } - auto &type_expr = expr->Cast(); - if (!StringUtil::CIEquals(type_expr.GetTypeName(), "VARCHAR")) { - throw ParserException("Only VARCHAR columns can have collations!"); - } - } else { - throw InternalException("Expected only unbound types here"); + auto &type_expr = expr->Cast(); + if (DefaultTypeGenerator::GetDefaultType(type_expr.GetTypeName()) != LogicalTypeId::VARCHAR) { + throw ParserException("Only VARCHAR columns can have collations!"); } - auto collation = transformer.Transform>(constraint); - vector> type_children; - type_children.push_back(std::move(collation)); - type = LogicalType::UNBOUND(make_uniq("VARCHAR", std::move(type_children))); } else { - column_constraint.constraints.push_back(transformer.Transform>(constraint)); + throw InternalException("Expected only unbound types here"); } + vector> type_children; + type_children.push_back(std::move(cc_entry.expression)); + column_type = + LogicalType::UNBOUND(make_uniq(Identifier("VARCHAR"), std::move(type_children))); + } else { + accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); } } - if (generated_opt.HasResult()) { - auto generated = transformer.Transform(generated_opt.GetResult()); + if (has_generated) { + auto generated = std::move(generated_column); if (generated.expr->HasSubquery()) { throw ParserException("Expression of generated column \"%s\" contains a subquery, which isn't allowed", qualified_name.name); } - if (type != LogicalType::ANY) { - generated.expr = make_uniq(type, std::move(generated.expr)); + if (column_type != LogicalType::ANY) { + generated.expr = make_uniq(column_type, std::move(generated.expr)); } if (generated.expr->HasSubquery()) { throw ParserException("Expression of generated column \"%s\" contains a subquery, which isn't allowed", qualified_name.name); } - ColumnDefinition 
col(qualified_name.name, type, std::move(generated.expr), TableColumnType::GENERATED); + ColumnDefinition col(qualified_name.name, column_type, std::move(generated.expr), TableColumnType::GENERATED); col.SetCompressionType(compression_type); - if (column_constraint.default_value) { + if (accumulated_constraints.default_value) { throw ParserException("Not allowed to set default on a generated column"); } - ConstraintColumnDefinition result = {std::move(col), column_constraint.constraint_types, - std::move(column_constraint.constraints)}; + ConstraintColumnDefinition result = {std::move(col), accumulated_constraints.constraint_types, + std::move(accumulated_constraints.constraints)}; return result; } - ColumnDefinition col(qualified_name.name, type); + ColumnDefinition col(qualified_name.name, column_type); - if (column_constraint.default_value) { - col.SetDefaultValue(std::move(column_constraint.default_value)); + if (accumulated_constraints.default_value) { + col.SetDefaultValue(std::move(accumulated_constraints.default_value)); } col.SetCompressionType(compression_type); - ConstraintColumnDefinition result = {std::move(col), column_constraint.constraint_types, - std::move(column_constraint.constraints)}; + ConstraintColumnDefinition result = {std::move(col), accumulated_constraints.constraint_types, + std::move(accumulated_constraints.constraints)}; return result; } GeneratedColumnDefinition PEGTransformerFactory::TransformGeneratedColumn(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr expression, + const bool &generated_column_type) { GeneratedColumnDefinition generated; - auto &extract_parens = ExtractResultFromParens(list_pr.Child(2)); - generated.expr = transformer.Transform>(extract_parens); + generated.expr = std::move(expression); VerifyColumnRefs(*generated.expr); - auto &generated_column_type = list_pr.Child(3); - if (generated_column_type.HasResult()) { - 
transformer.Transform(generated_column_type.GetResult()); - } return generated; } -unique_ptr PEGTransformerFactory::TransformDefaultValue(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(1)); +ColumnConstraintEntry PEGTransformerFactory::TransformDefaultValue(PEGTransformer &transformer, + unique_ptr column_default_expr) { + ColumnConstraintEntry entry; + entry.constraint_name = "DefaultValue"; + entry.expression = std::move(column_default_expr); + return entry; } -unique_ptr PEGTransformerFactory::TransformTopLevelConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto result = transformer.Transform>(list_pr.Child(1)); - return result; +unique_ptr +PEGTransformerFactory::TransformTopLevelConstraint(PEGTransformer &transformer, + unique_ptr top_level_constraint_list) { + return top_level_constraint_list; } unique_ptr PEGTransformerFactory::TransformTopLevelConstraintList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); + ParseResult &choice_result) { + if (choice_result.name == "CheckConstraint") { + auto cc_entry = transformer.Transform(choice_result); + return std::move(cc_entry.constraint); + } + return transformer.Transform>(choice_result); } unique_ptr PEGTransformerFactory::TransformTopPrimaryKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto column_list = transformer.Transform>(list_pr.Child(2)); - auto result = make_uniq(column_list, true); + const vector &column_id_list) { + auto result = make_uniq(StringsToIdentifiers(column_id_list), true); return std::move(result); } unique_ptr PEGTransformerFactory::TransformTopUniqueConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = 
parse_result.Cast(); - auto column_list = transformer.Transform>(list_pr.Child(1)); - return make_uniq(column_list, false); + const vector &column_id_list) { + return make_uniq(StringsToIdentifiers(column_id_list), false); } -unique_ptr PEGTransformerFactory::TransformCheckConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); - auto check_expr = transformer.Transform>(extract_parens); - if (check_expr->HasSubquery()) { +ColumnConstraintEntry PEGTransformerFactory::TransformCheckConstraint(PEGTransformer &transformer, + unique_ptr expression) { + if (expression->HasSubquery()) { throw ParserException("subqueries prohibited in CHECK constraints"); } - auto result = make_uniq(std::move(check_expr)); - return std::move(result); + ColumnConstraintEntry entry; + entry.constraint_name = "CheckConstraint"; + entry.constraint = make_uniq(std::move(expression)); + return entry; } -unique_ptr PEGTransformerFactory::TransformTopForeignKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto pk_list = transformer.Transform>(list_pr.Child(2)); - - auto fk_constraint = transformer.Transform>(list_pr.Child(3)); - fk_constraint->fk_columns = pk_list; - if (!fk_constraint->pk_columns.empty() && fk_constraint->fk_columns.size() != fk_constraint->pk_columns.size()) { +unique_ptr PEGTransformerFactory::TransformTopForeignKeyConstraint( + PEGTransformer &transformer, const vector &column_id_list, ColumnConstraintEntry foreign_key_constraint) { + auto &fk_constraint = foreign_key_constraint.constraint->Cast(); + fk_constraint.fk_columns = StringsToIdentifiers(column_id_list); + if (!fk_constraint.pk_columns.empty() && fk_constraint.fk_columns.size() != fk_constraint.pk_columns.size()) { throw ParserException("The number of referencing and referenced columns for foreign keys must be the same"); } - return 
std::move(fk_constraint); + return std::move(foreign_key_constraint.constraint); } -vector PEGTransformerFactory::TransformColumnIdList(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector result; - auto &colid_list = ExtractResultFromParens(list_pr.Child(0)); - auto colids = ExtractParseResultsFromList(colid_list); - for (auto colid : colids) { - result.push_back(transformer.Transform(colid)); - } - return result; -} - -string PEGTransformerFactory::TransformTypeFuncName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (choice_pr.type == ParseResultType::IDENTIFIER) { - return choice_pr.Cast().identifier; - } - return transformer.Transform(choice_pr); +vector PEGTransformerFactory::TransformColumnIdList(PEGTransformer &transformer, + const vector &col_id) { + return IdentifiersToStrings(col_id); } -CompressionType PEGTransformerFactory::TransformColumnCompression(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto compression_string = transformer.Transform(list_pr.Child(2)); - return EnumUtil::FromString(StringUtil::Lower(compression_string)); +ColumnConstraintEntry PEGTransformerFactory::TransformColumnCompression(PEGTransformer &transformer, + const Identifier &col_id_or_string) { + ColumnConstraintEntry entry; + entry.constraint_name = "ColumnCompression"; + entry.compression_type = + EnumUtil::FromString(StringUtil::Lower(col_id_or_string.GetIdentifierName())); + return entry; } -unique_ptr PEGTransformerFactory::TransformForeignKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +ColumnConstraintEntry PEGTransformerFactory::TransformForeignKeyConstraint(PEGTransformer &transformer, + unique_ptr base_table_name, + const vector &column_list, + const KeyActions &key_actions) { ForeignKeyInfo 
fk_info; - auto base_table = transformer.Transform>(list_pr.Child(1)); - fk_info.schema = base_table->schema_name; - fk_info.table = base_table->table_name; + fk_info.schema = base_table_name->schema_name; + fk_info.table = base_table_name->table_name; fk_info.type = ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE; - vector pk_list; - auto &col_list_opt = list_pr.Child(2); - if (col_list_opt.HasResult()) { - auto &extract_parens = ExtractResultFromParens(col_list_opt.GetResult()); - pk_list = transformer.Transform>(extract_parens); - } - auto key_actions = transformer.Transform(list_pr.Child(3)); - return make_uniq(pk_list, vector(), fk_info); + ColumnConstraintEntry entry; + entry.constraint_name = "ForeignKeyConstraint"; + entry.constraint = + make_uniq(StringsToIdentifiers(column_list), vector(), fk_info); + return entry; } -KeyActions PEGTransformerFactory::TransformKeyActions(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +KeyActions PEGTransformerFactory::TransformKeyActions(PEGTransformer &transformer, const string &update_action, + const string &delete_action) { KeyActions results; - transformer.TransformOptional(list_pr, 0, results.update_action); - transformer.TransformOptional(list_pr, 1, results.delete_action); + results.update_action = update_action; + results.delete_action = delete_action; return results; } -string PEGTransformerFactory::TransformUpdateAction(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(2)); -} - -string PEGTransformerFactory::TransformDeleteAction(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(2)); +string PEGTransformerFactory::TransformUpdateAction(PEGTransformer &transformer, const string &key_action) { + return key_action; } -string PEGTransformerFactory::TransformKeyAction(PEGTransformer 
&transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); +string PEGTransformerFactory::TransformDeleteAction(PEGTransformer &transformer, const string &key_action) { + return key_action; } -string PEGTransformerFactory::TransformNoKeyAction(PEGTransformer &transformer, ParseResult &parse_result) { +string PEGTransformerFactory::TransformNoKeyAction(PEGTransformer &transformer) { return "NoKeyAction"; } -string PEGTransformerFactory::TransformRestrictKeyAction(PEGTransformer &transformer, ParseResult &parse_result) { +string PEGTransformerFactory::TransformRestrictKeyAction(PEGTransformer &transformer) { return "Restrict"; } -string PEGTransformerFactory::TransformCascadeKeyAction(PEGTransformer &transformer, ParseResult &parse_result) { +string PEGTransformerFactory::TransformCascadeKeyAction(PEGTransformer &transformer) { throw ParserException("FOREIGN KEY constraints cannot use CASCADE, SET NULL or SET DEFAULT"); } -string PEGTransformerFactory::TransformSetNullKeyAction(PEGTransformer &transformer, ParseResult &parse_result) { +string PEGTransformerFactory::TransformSetNullKeyAction(PEGTransformer &transformer) { throw ParserException("FOREIGN KEY constraints cannot use CASCADE, SET NULL or SET DEFAULT"); } -string PEGTransformerFactory::TransformSetDefaultKeyAction(PEGTransformer &transformer, ParseResult &parse_result) { +string PEGTransformerFactory::TransformSetDefaultKeyAction(PEGTransformer &transformer) { throw ParserException("FOREIGN KEY constraints cannot use CASCADE, SET NULL or SET DEFAULT"); } -pair PEGTransformerFactory::TransformPrimaryKeyConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_pair(true, ConstraintType::UNIQUE); +ColumnConstraintEntry PEGTransformerFactory::TransformPrimaryKeyConstraint(PEGTransformer &transformer) { + ColumnConstraintEntry entry; + entry.constraint_name = "PrimaryKeyConstraint"; + 
entry.constraint_type_info = make_pair(true, ConstraintType::UNIQUE); + return entry; } -pair PEGTransformerFactory::TransformUniqueConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - return make_pair(false, ConstraintType::UNIQUE); +ColumnConstraintEntry PEGTransformerFactory::TransformUniqueConstraint(PEGTransformer &transformer) { + ColumnConstraintEntry entry; + entry.constraint_name = "UniqueConstraint"; + entry.constraint_type_info = make_pair(false, ConstraintType::UNIQUE); + return entry; } -pair PEGTransformerFactory::TransformNotNullConstraint(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto not_null = list_pr.Child(0).HasResult(); - if (not_null) { - return make_pair(false, ConstraintType::NOT_NULL); - } - return make_pair(false, ConstraintType::INVALID); +bool PEGTransformerFactory::TransformNullConstraint(PEGTransformer &transformer) { + return false; } -unique_ptr PEGTransformerFactory::TransformColumnCollation(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto dotted_identifier = transformer.Transform>(list_pr.Child(1)); +bool PEGTransformerFactory::TransformNotNullColumnConstraint(PEGTransformer &transformer) { + return true; +} + +ColumnConstraintEntry PEGTransformerFactory::TransformNotNullConstraint(PEGTransformer &transformer, + const bool &child) { + ColumnConstraintEntry entry; + entry.constraint_name = "NotNullConstraint"; + entry.constraint_type_info = make_pair(false, child ? 
ConstraintType::NOT_NULL : ConstraintType::INVALID); + return entry; +} + +ColumnConstraintEntry PEGTransformerFactory::TransformColumnCollation(PEGTransformer &transformer, + const vector &dotted_identifier) { string collation = StringUtil::Join(dotted_identifier, "."); auto expr = make_uniq(Value(collation)); expr->SetAlias("collation"); - return std::move(expr); + ColumnConstraintEntry entry; + entry.constraint_name = "ColumnCollation"; + entry.expression = std::move(expr); + return entry; } -bool PEGTransformerFactory::TransformWithData(PEGTransformer &transformer, ParseResult &parse_result) { - // 'WITH' 'NO'? 'DATA' - auto &list_pr = parse_result.Cast(); - auto no_data = list_pr.Child(1).HasResult(); - return no_data; +bool PEGTransformerFactory::TransformWithDataOnly(PEGTransformer &transformer) { + return false; } -bool PEGTransformerFactory::TransformCommitAction(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(2)); +bool PEGTransformerFactory::TransformWithNoData(PEGTransformer &transformer) { + return true; } -bool PEGTransformerFactory::TransformPreserveOrDelete(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - auto &preserve_or_delete = choice_pr.Cast().keyword; - if (StringUtil::CIEquals(preserve_or_delete, "delete")) { - throw NotImplementedException("Only ON COMMIT PRESERVE ROWS is supported"); - } +bool PEGTransformerFactory::TransformCommitAction(PEGTransformer &transformer, const bool &preserve_or_delete) { + return preserve_or_delete; +} + +bool PEGTransformerFactory::TransformPreserveRows(PEGTransformer &transformer) { return true; } -bool PEGTransformerFactory::TransformGeneratedColumnType(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - auto &keyword = 
choice_pr.Cast().keyword; - if (StringUtil::CIEquals(keyword, "stored")) { - throw InvalidInputException("Can not create a STORED generated column!"); - } +bool PEGTransformerFactory::TransformDeleteRows(PEGTransformer &transformer) { + throw NotImplementedException("Only ON COMMIT PRESERVE ROWS is supported"); +} + +bool PEGTransformerFactory::TransformVirtualGeneratedColumn(PEGTransformer &transformer) { return true; } +bool PEGTransformerFactory::TransformStoredGeneratedColumn(PEGTransformer &transformer) { + throw InvalidInputException("Can not create a STORED generated column!"); +} + void PEGTransformerFactory::VerifyColumnRefs(const ParsedExpression &expr) { ParsedExpressionIterator::VisitExpression(expr, [&](const ColumnRefExpression &column_ref) { if (column_ref.IsQualified()) { @@ -585,62 +514,36 @@ void PEGTransformerFactory::VerifyColumnRefs(const ParsedExpression &expr) { }); } -vector> PEGTransformerFactory::TransformPartitionOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // PartitionOptions <- 'PARTITIONED' 'BY' Parens(List(Expression)) - // child 0: 'PARTITIONED', child 1: 'BY', child 2: Parens(List(Expression)) - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(2)); - auto expr_list = ExtractParseResultsFromList(extract_parens); - vector> result; - for (auto expr_pr : expr_list) { - result.push_back(transformer.Transform>(expr_pr)); - } - return result; +vector> +PEGTransformerFactory::TransformPartitionOptions(PEGTransformer &transformer, + vector> expression) { + return expression; } -vector> PEGTransformerFactory::TransformSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // SortedOptions <- 'SORTED' 'BY' Parens(List(Expression)) - // child 0: 'SORTED', child 1: 'BY', child 2: Parens(List(Expression)) - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(2)); - auto expr_list = 
ExtractParseResultsFromList(extract_parens); - vector> result; - for (auto expr_pr : expr_list) { - result.push_back(transformer.Transform>(expr_pr)); - } - return result; +vector> +PEGTransformerFactory::TransformSortedOptions(PEGTransformer &transformer, + vector> expression) { + return expression; } -PartitionSortedOptions PEGTransformerFactory::TransformPartitionOptSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // PartitionOptSortedOptions <- PartitionOptions SortedOptions? - // child 0: PartitionOptions, child 1: SortedOptions? - auto &list_pr = parse_result.Cast(); +PartitionSortedOptions +PEGTransformerFactory::TransformPartitionOptSortedOptions(PEGTransformer &transformer, + vector> partition_options, + vector> sorted_options) { PartitionSortedOptions result; - result.partition_keys = - transformer.Transform>>(list_pr.Child(0)); - transformer.TransformOptional>>(list_pr, 1, result.sort_keys); + result.partition_keys = std::move(partition_options); + result.sort_keys = std::move(sorted_options); return result; } -PartitionSortedOptions PEGTransformerFactory::TransformSortedOptPartitionOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // SortedOptPartitionOptions <- SortedOptions PartitionOptions? - // child 0: SortedOptions, child 1: PartitionOptions? 
- auto &list_pr = parse_result.Cast(); +PartitionSortedOptions +PEGTransformerFactory::TransformSortedOptPartitionOptions(PEGTransformer &transformer, + vector> sorted_options, + vector> partition_options) { PartitionSortedOptions result; - result.sort_keys = transformer.Transform>>(list_pr.Child(0)); - transformer.TransformOptional>>(list_pr, 1, result.partition_keys); + result.sort_keys = std::move(sorted_options); + result.partition_keys = std::move(partition_options); return result; } -PartitionSortedOptions PEGTransformerFactory::TransformPartitionSortedOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - // PartitionSortedOptions <- PartitionOptSortedOptions / SortedOptPartitionOptions - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); -} - } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp index c0f6e72ae..565ab846f 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp @@ -22,110 +22,132 @@ static unique_ptr ExtractQueryNode(unique_ptr stmt) { } } -unique_ptr PEGTransformerFactory::TransformCreateTriggerStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - // CreateTriggerStmt <- 'TRIGGER' IfNotExists? TriggerName TriggerTiming TriggerEvent 'ON' BaseTableName - // ForEachClause? 
TriggerBody - auto &list_pr = parse_result.Cast(); - auto if_not_exists = list_pr.Child(1).HasResult(); - auto trigger_name = transformer.Transform(list_pr.Child(2)); // TriggerName - auto timing = transformer.Transform(list_pr.Child(3)); - auto trigger_event = transformer.Transform(list_pr.Child(4)); - // index 5 is 'ON' - auto base_table = transformer.Transform>(list_pr.Child(6)); - auto &for_each_opt = list_pr.Child(7); - TriggerForEach for_each = TriggerForEach::STATEMENT; - if (for_each_opt.HasResult()) { - for_each = transformer.Transform(for_each_opt.GetResult()); - } - auto trigger_action = transformer.Transform>(list_pr.Child(8)); - +unique_ptr PEGTransformerFactory::TransformCreateTriggerStmt( + PEGTransformer &transformer, const bool &if_not_exists, const Identifier &trigger_name, + const TriggerTiming &trigger_timing, const TriggerEventInfo &trigger_event, + unique_ptr base_table_name, const TriggerTableReferencingInfo &referencing_clause, + const TriggerForEach &for_each_clause, unique_ptr trigger_body) { auto result = make_uniq(); auto info = make_uniq(); info->on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; info->trigger_name = trigger_name; - info->timing = timing; + info->timing = trigger_timing; info->event_type = trigger_event.event_type; - info->columns = std::move(trigger_event.columns); - info->base_table = std::move(base_table); - info->for_each = for_each; - info->trigger_action = ExtractQueryNode(std::move(trigger_action)); + info->columns = trigger_event.columns; + info->base_table = std::move(base_table_name); + info->referencing_new_table = referencing_clause.new_table; + info->referencing_old_table = referencing_clause.old_table; + info->for_each = for_each_clause; + info->trigger_action = ExtractQueryNode(std::move(trigger_body)); result->info = std::move(info); return result; } -TriggerForEach PEGTransformerFactory::TransformForEachClause(PEGTransformer &transformer, ParseResult &parse_result) { - // ForEachClause <- ForEachRow / ForEachStatement - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +Identifier PEGTransformerFactory::TransformTriggerName(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; } -string PEGTransformerFactory::TransformTriggerName(PEGTransformer &transformer, ParseResult &parse_result) { +TriggerForEach PEGTransformerFactory::TransformForEachClause(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; + return transformer.TransformEnum(list_pr.Child(0).GetResult()); } TriggerTiming PEGTransformerFactory::TransformTriggerTiming(PEGTransformer &transformer, ParseResult &parse_result) { - // TriggerTiming <- TriggerBefore / TriggerAfter / TriggerInsteadOf auto &list_pr = parse_result.Cast(); return transformer.TransformEnum(list_pr.Child(0).GetResult()); } TriggerEventInfo PEGTransformerFactory::TransformTriggerEvent(PEGTransformer &transformer, ParseResult &parse_result) { - // TriggerEvent <- 
TriggerEventUpdateOf / TriggerEventInsert / TriggerEventDelete / TriggerEventUpdate auto &list_pr = parse_result.Cast(); return transformer.Transform(list_pr.Child(0).GetResult()); } -TriggerEventInfo PEGTransformerFactory::TransformTriggerEventInsert(PEGTransformer &transformer, - ParseResult &parse_result) { +TriggerTiming PEGTransformerFactory::TransformTriggerBefore(PEGTransformer &transformer) { + return TriggerTiming::BEFORE; +} + +TriggerTiming PEGTransformerFactory::TransformTriggerAfter(PEGTransformer &transformer) { + return TriggerTiming::AFTER; +} + +TriggerTiming PEGTransformerFactory::TransformTriggerInsteadOf(PEGTransformer &transformer) { + return TriggerTiming::INSTEAD_OF; +} + +TriggerEventInfo PEGTransformerFactory::TransformTriggerEventInsert(PEGTransformer &transformer) { TriggerEventInfo result; result.event_type = TriggerEventType::INSERT_EVENT; return result; } -TriggerEventInfo PEGTransformerFactory::TransformTriggerEventDelete(PEGTransformer &transformer, - ParseResult &parse_result) { +TriggerEventInfo PEGTransformerFactory::TransformTriggerEventDelete(PEGTransformer &transformer) { TriggerEventInfo result; result.event_type = TriggerEventType::DELETE_EVENT; return result; } -TriggerEventInfo PEGTransformerFactory::TransformTriggerEventUpdate(PEGTransformer &transformer, - ParseResult &parse_result) { +TriggerEventInfo PEGTransformerFactory::TransformTriggerEventUpdate(PEGTransformer &transformer) { TriggerEventInfo result; result.event_type = TriggerEventType::UPDATE_EVENT; return result; } TriggerEventInfo PEGTransformerFactory::TransformTriggerEventUpdateOf(PEGTransformer &transformer, - ParseResult &parse_result) { - // TriggerEventUpdateOf <- 'UPDATE' 'OF' TriggerColumnList - auto &list_pr = parse_result.Cast(); + const vector &trigger_column_list) { TriggerEventInfo result; result.event_type = TriggerEventType::UPDATE_EVENT; - result.columns = transformer.Transform>(list_pr.Child(2)); + result.columns = 
StringsToIdentifiers(trigger_column_list); return result; } vector PEGTransformerFactory::TransformTriggerColumnList(PEGTransformer &transformer, - ParseResult &parse_result) { - // TriggerColumnList <- List(ColId) - auto &list_pr = parse_result.Cast(); - auto column_list = ExtractParseResultsFromList(list_pr.Child(0)); - vector result; - for (auto &column : column_list) { - result.push_back(transformer.Transform(column)); + const vector &col_id) { + return IdentifiersToStrings(col_id); +} + +TriggerTableReferencingInfo PEGTransformerFactory::TransformReferencingNewTableAs(PEGTransformer &transformer, + const Identifier &col_id) { + TriggerTableReferencingInfo info; + info.new_table = Identifier(col_id); + return info; +} + +TriggerTableReferencingInfo PEGTransformerFactory::TransformReferencingOldTableAs(PEGTransformer &transformer, + const Identifier &col_id) { + TriggerTableReferencingInfo info; + info.old_table = Identifier(col_id); + return info; +} + +TriggerTableReferencingInfo +PEGTransformerFactory::TransformReferencingClause(PEGTransformer &transformer, + const TriggerTableReferencingInfo &referencing_item, + const TriggerTableReferencingInfo &referencing_item_1) { + auto result = referencing_item; + if (!referencing_item_1.new_table.empty()) { + if (!result.new_table.empty()) { + throw ParserException("NEW TABLE cannot be specified multiple times in REFERENCING clause"); + } + result.new_table = referencing_item_1.new_table; + } + if (!referencing_item_1.old_table.empty()) { + if (!result.old_table.empty()) { + throw ParserException("OLD TABLE cannot be specified multiple times in REFERENCING clause"); + } + result.old_table = referencing_item_1.old_table; + } + if (!result.new_table.empty() && !result.old_table.empty() && result.new_table == result.old_table) { + throw ParserException("REFERENCING aliases must be distinct"); } return result; } -unique_ptr PEGTransformerFactory::TransformTriggerBody(PEGTransformer &transformer, - ParseResult 
&parse_result) { - // TriggerBody <- InsertStatement / UpdateStatement / DeleteStatement - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - return transformer.Transform>(choice_pr.GetResult()); +TriggerForEach PEGTransformerFactory::TransformForEachRow(PEGTransformer &transformer) { + return TriggerForEach::ROW; +} + +TriggerForEach PEGTransformerFactory::TransformForEachStatement(PEGTransformer &transformer) { + return TriggerForEach::STATEMENT; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp index a58432d7e..179c4b8f0 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp @@ -6,59 +6,45 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCreateTypeStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const bool &if_not_exists, + const QualifiedName &qualified_name, + unique_ptr create_type) { auto result = make_uniq(); - auto if_not_exists = list_pr.Child(1).HasResult(); - auto qualified_name = transformer.Transform(list_pr.Child(2)); - auto create_type_info = transformer.Transform>(list_pr.Child(4)); - create_type_info->catalog = qualified_name.catalog; - create_type_info->schema = qualified_name.schema; - create_type_info->name = qualified_name.name; - create_type_info->on_conflict = + create_type->catalog = qualified_name.catalog; + create_type->schema = qualified_name.schema; + create_type->name = qualified_name.name; + create_type->on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - result->info = std::move(create_type_info); + result->info = std::move(create_type); return result; } -unique_ptr PEGTransformerFactory::TransformCreateType(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformCreateTypeFromType(PEGTransformer &transformer, + const LogicalType &type) { auto result = make_uniq(); - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - if (choice_pr.GetResult().name == "EnumSelectType") { - result->query = transformer.Transform>(choice_pr.GetResult()); - result->type = LogicalType::INVALID; - } else { - result->type = transformer.Transform(choice_pr.GetResult()); - } + result->type = type; return result; } -unique_ptr PEGTransformerFactory::TransformEnumSelectType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); - return transformer.Transform>(extract_parens); +unique_ptr +PEGTransformerFactory::TransformEnumSelectType(PEGTransformer &transformer, + unique_ptr select_statement_internal) { + auto result = make_uniq(); + result->query = std::move(select_statement_internal); + result->type = LogicalType::INVALID; + return result; } -LogicalType PEGTransformerFactory::TransformEnumStringLiteralList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); - auto &string_list_opt = extract_parens.Cast(); - if (!string_list_opt.HasResult()) { - Vector enum_vector(LogicalType::VARCHAR, (idx_t)0); - return LogicalType::ENUM(enum_vector, 0); - } - auto string_literal_list = ExtractParseResultsFromList(string_list_opt.GetResult()); - - Vector enum_vector(LogicalType::VARCHAR, string_literal_list.size()); - auto string_data = FlatVector::Writer(enum_vector, 
string_literal_list.size()); - for (auto string_literal : string_literal_list) { - string_data.WriteValue(string_t(string_literal.get().Cast().result)); +unique_ptr PEGTransformerFactory::TransformEnumStringLiteralList(PEGTransformer &transformer, + const vector &string_literal) { + auto result = make_uniq(); + Vector enum_vector(LogicalType::VARCHAR, string_literal.size()); + auto string_data = FlatVector::Writer(enum_vector, string_literal.size()); + for (auto &literal : string_literal) { + string_data.WriteValue(string_t(literal)); } - return LogicalType::ENUM(enum_vector, string_literal_list.size()); + result->type = LogicalType::ENUM(enum_vector, string_literal.size()); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp index 895266866..a281e7604 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp @@ -4,8 +4,8 @@ #include "duckdb/parser/query_node/set_operation_node.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::ToRecursiveCTE(unique_ptr node, const string &name, - vector &aliases, +unique_ptr PEGTransformerFactory::ToRecursiveCTE(unique_ptr node, const Identifier &name, + vector &aliases, vector> &key_targets) { if (node->type != QueryNodeType::SET_OPERATION_NODE) { return node; @@ -29,15 +29,12 @@ unique_ptr PEGTransformerFactory::ToRecursiveCTE(unique_ptr(std::move(node)); recursive_node->union_all = owned_set_node->setop_all; - if (!owned_set_node->modifiers.empty()) { - for (auto &modifier : owned_set_node->modifiers) { - if (modifier->type == ResultModifierType::LIMIT_MODIFIER || - modifier->type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { - throw ParserException("LIMIT or OFFSET in a recursive query is not allowed"); - } - if (modifier->type == ResultModifierType::ORDER_MODIFIER) { - throw ParserException("ORDER 
BY in a recursive query is not allowed"); - } + for (auto &modifier : owned_set_node->modifiers) { + if (modifier->type == ResultModifierType::LIMIT_MODIFIER) { + throw ParserException("LIMIT or OFFSET in a recursive query is not allowed"); + } + if (modifier->type == ResultModifierType::ORDER_MODIFIER) { + throw ParserException("ORDER BY in a recursive query is not allowed"); } } if (owned_set_node->children.size() == 2) { @@ -89,28 +86,21 @@ void PEGTransformerFactory::ConvertToRecursiveView(unique_ptr &i WrapRecursiveView(info, std::move(result_node)); } -unique_ptr PEGTransformerFactory::TransformCreateViewStmt(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto if_not_exists = list_pr.Child(2).HasResult(); - auto qualified_name = transformer.Transform(list_pr.Child(3)); - auto &insert_column_list_pr = list_pr.Child(4); - vector column_list; - if (insert_column_list_pr.HasResult()) { - column_list = transformer.Transform>(insert_column_list_pr.GetResult()); - } +unique_ptr +PEGTransformerFactory::TransformCreateViewStmt(PEGTransformer &transformer, const bool &create_recursive, + const bool &if_not_exists, const QualifiedName &qualified_name, + const vector &insert_column_list, + case_insensitive_map_t> with_list, + unique_ptr select_statement_internal) { auto result = make_uniq(); auto info = make_uniq(); info->on_conflict = if_not_exists ? 
OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; info->catalog = qualified_name.catalog; info->schema = qualified_name.schema; info->view_name = qualified_name.name; - info->aliases = column_list; - auto &with_list_opt = list_pr.Child(5); - if (with_list_opt.HasResult()) { - auto options_expr = - transformer.Transform>>(with_list_opt.GetResult()); - for (auto &option_entry : options_expr) { + info->aliases = StringsToIdentifiers(insert_column_list); + if (!with_list.empty()) { + for (auto &option_entry : with_list) { if (!StringUtil::CIEquals(option_entry.first, "defer_binding")) { throw ParserException("Only DEFER_BINDING is currently supported as option for CREATE VIEW"); } @@ -127,15 +117,18 @@ unique_ptr PEGTransformerFactory::TransformCreateViewStmt(PEGTr } } } - auto select_statement = transformer.Transform>(list_pr.Child(7)); - if (list_pr.Child(0).HasResult()) { - ConvertToRecursiveView(info, select_statement->node); + if (create_recursive) { + ConvertToRecursiveView(info, select_statement_internal->node); } else { - info->query = std::move(select_statement); + info->query = std::move(select_statement_internal); } transformer.PivotEntryCheck("view"); result->info = std::move(info); return result; } +bool PEGTransformerFactory::TransformCreateRecursive(PEGTransformer &transformer) { + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp b/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp index 505c8d27b..c1403a0a8 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp @@ -4,12 +4,16 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDeallocateStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const bool &deallocate_prepare, + const Identifier &identifier) { auto result = make_uniq(); 
result->info->type = CatalogType::PREPARED_STATEMENT; - result->info->name = list_pr.Child(2).identifier; + result->info->name = identifier; return std::move(result); } +bool PEGTransformerFactory::TransformDeallocatePrepare(PEGTransformer &transformer) { + return true; +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_delete.cpp b/src/duckdb/src/parser/peg/transformer/transform_delete.cpp index c67b7cbb9..5e0ff952f 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_delete.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_delete.cpp @@ -4,47 +4,40 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformDeleteStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - +unique_ptr PEGTransformerFactory::TransformDeleteStatement( + PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr target_opt_alias, + vector> delete_using_clause, unique_ptr where_clause, + vector> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - auto &with_opt = list_pr.Child(0); - if (with_opt.HasResult()) { - node.cte_map = transformer.Transform(with_opt.GetResult()); + if (!with_clause.map.empty()) { + node.cte_map = std::move(with_clause); } - node.table = transformer.Transform>(list_pr.Child(3)); - transformer.TransformOptional>>(list_pr, 4, node.using_clauses); - transformer.TransformOptional>(list_pr, 5, node.condition); - transformer.TransformOptional>>(list_pr, 6, node.returning_list); + node.table = std::move(target_opt_alias); + node.using_clauses = std::move(delete_using_clause); + node.condition = std::move(where_clause); + node.returning_list = std::move(returning_clause); return std::move(result); } unique_ptr PEGTransformerFactory::TransformTargetOptAlias(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto table_ref = transformer.Transform>(list_pr.Child(0)); - 
transformer.TransformOptional(list_pr, 2, table_ref->alias); - return table_ref; + unique_ptr base_table_name, + const Identifier &col_id) { + if (!col_id.empty()) { + base_table_name->alias = Identifier(col_id); + } + return base_table_name; } vector> PEGTransformerFactory::TransformDeleteUsingClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto table_ref_list = ExtractParseResultsFromList(list_pr.Child(1)); - vector> result; - for (auto table_ref : table_ref_list) { - result.push_back(transformer.Transform>(table_ref)); - } - return result; + vector> table_ref) { + return table_ref; } unique_ptr PEGTransformerFactory::TransformTruncateStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr base_table_name) { auto result = make_uniq(); - result->node->table = transformer.Transform>(list_pr.Child(2)); + result->node->table = std::move(base_table_name); return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_describe.cpp b/src/duckdb/src/parser/peg/transformer/transform_describe.cpp index 58a6a5911..6302ea496 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_describe.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_describe.cpp @@ -4,21 +4,18 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDescribeStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr child) { auto select_statement = make_uniq(); - select_statement->node = - transformer.Transform>(list_pr.Child(0).GetResult()); + select_statement->node = std::move(child); return select_statement; } -unique_ptr PEGTransformerFactory::TransformShowSelect(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformShowSelect(PEGTransformer &transformer, const ShowType 
&show_or_describe_or_summarize, + unique_ptr select_statement_internal) { auto result = make_uniq(); - result->show_type = transformer.Transform(list_pr.Child(0)); - auto select_statement = transformer.Transform>(list_pr.Child(1)); - result->query = std::move(select_statement->node); + result->show_type = show_or_describe_or_summarize; + result->query = std::move(select_statement_internal->node); auto select_node = make_uniq(); select_node->select_list.push_back(make_uniq()); select_node->from_table = std::move(result); @@ -26,11 +23,10 @@ unique_ptr PEGTransformerFactory::TransformShowSelect(PEGTransformer } unique_ptr PEGTransformerFactory::TransformShowTables(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const ShowType &show_or_describe, + const QualifiedName &qualified_name) { auto showref = make_uniq(); showref->show_type = ShowType::SHOW_FROM; - auto qualified_name = transformer.Transform(list_pr.Child(3)); if (!IsInvalidCatalog(qualified_name.catalog)) { throw ParserException("Expected \"SHOW TABLES FROM database\", \"SHOW TABLES FROM schema\", or " "\"SHOW TABLES FROM database.schema\""); @@ -48,7 +44,7 @@ unique_ptr PEGTransformerFactory::TransformShowTables(PEGTransformer } unique_ptr PEGTransformerFactory::TransformShowAllTables(PEGTransformer &transformer, - ParseResult &parse_result) { + const ShowType &show_or_describe) { auto result = make_uniq(); // Legacy reasons, see bind_showref.cpp result->table_name = "__show_tables_expanded"; @@ -60,38 +56,33 @@ unique_ptr PEGTransformerFactory::TransformShowAllTables(PEGTransform } unique_ptr PEGTransformerFactory::TransformShowQualifiedName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const ShowType &show_or_describe_or_summarize, + DescribeTarget describe_target) { auto showref = make_uniq(); + showref->show_type = show_or_describe_or_summarize; - showref->show_type = 
transformer.Transform(list_pr.Child(0)); - - auto &opt_table_name_parens = list_pr.Child(1); - if (opt_table_name_parens.HasResult()) { - auto &base_table_or_string = opt_table_name_parens.GetResult().Cast(); - auto &choice_pr = base_table_or_string.Child(0); - - if (choice_pr.GetResult().type == ParseResultType::STRING) { + if (describe_target.is_table_name || describe_target.table_ref) { + if (describe_target.is_table_name) { // Case: SHOW 'something' or DESCRIBE 'something' - showref->table_name = choice_pr.GetResult().Cast().result; + showref->table_name = describe_target.table_name; } else { // Case: A relation/table reference - auto base_table = transformer.Transform>(choice_pr.GetResult()); + auto &base_table = *describe_target.table_ref; if (showref->show_type == ShowType::SHOW_FROM) { // Logic for SHOW TABLES FROM [database].[schema] - if (IsInvalidSchema(base_table->schema_name)) { - showref->schema_name = base_table->table_name; + if (IsInvalidSchema(base_table.schema_name)) { + showref->schema_name = base_table.table_name; } else { - showref->catalog_name = base_table->schema_name; - showref->schema_name = base_table->table_name; + showref->catalog_name = base_table.schema_name; + showref->schema_name = base_table.table_name; } - } else if (IsInvalidSchema(base_table->schema_name)) { + } else if (IsInvalidSchema(base_table.schema_name)) { // Logic for unqualified relations (databases, tables, variables) - auto table_name = StringUtil::Lower(base_table->table_name); + auto table_name = StringUtil::Lower(base_table.table_name.GetIdentifierName()); if (table_name == "databases" || table_name == "tables" || table_name == "schemas" || table_name == "variables") { - showref->table_name = "\"" + table_name + "\""; + showref->table_name = Identifier("\"" + table_name + "\""); showref->show_type = ShowType::SHOW_UNQUALIFIED; } } @@ -99,14 +90,14 @@ unique_ptr PEGTransformerFactory::TransformShowQualifiedName(PEGTrans if (showref->table_name.empty() && 
showref->show_type != ShowType::SHOW_FROM) { auto show_select_node = make_uniq(); show_select_node->select_list.push_back(make_uniq()); - if (choice_pr.GetResult().type == ParseResultType::STRING) { + if (describe_target.is_table_name) { // Case: SHOW 'something' or DESCRIBE 'something' auto table_ref = make_uniq(); - table_ref->table_name = choice_pr.GetResult().Cast().result; + table_ref->table_name = describe_target.table_name; show_select_node->from_table = std::move(table_ref); } else { // Case: A relation/table reference - show_select_node->from_table = transformer.Transform>(choice_pr.GetResult()); + show_select_node->from_table = std::move(describe_target.table_ref); } showref->query = std::move(show_select_node); } @@ -125,20 +116,36 @@ unique_ptr PEGTransformerFactory::TransformShowQualifiedName(PEGTrans return std::move(select_node); } -ShowType PEGTransformerFactory::TransformShowOrDescribeOrSummarize(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); + +DescribeTarget PEGTransformerFactory::TransformDescribeBaseTableName(PEGTransformer &transformer, + unique_ptr base_table_name) { + DescribeTarget result; + result.table_ref = std::move(base_table_name); + return result; +} + +DescribeTarget PEGTransformerFactory::TransformDescribeStringLiteral(PEGTransformer &transformer, + const string &string_literal) { + DescribeTarget result; + result.is_table_name = true; + result.table_name = Identifier(string_literal); + return result; +} + +ShowType PEGTransformerFactory::TransformSummarizeRule(PEGTransformer &transformer) { + return ShowType::SUMMARY; +} + +ShowType PEGTransformerFactory::TransformShowRule(PEGTransformer &transformer) { + return ShowType::DESCRIBE; } -ShowType PEGTransformerFactory::TransformShowOrDescribe(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return 
transformer.TransformEnum(list_pr.Child(0).GetResult()); +ShowType PEGTransformerFactory::TransformDescribeLongRule(PEGTransformer &transformer) { + return ShowType::DESCRIBE; } -ShowType PEGTransformerFactory::TransformSummarize(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0)); +ShowType PEGTransformerFactory::TransformDescRule(PEGTransformer &transformer) { + return ShowType::DESCRIBE; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_detach.cpp b/src/duckdb/src/parser/peg/transformer/transform_detach.cpp index ecd7c7b33..6ada2bd52 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_detach.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_detach.cpp @@ -3,8 +3,9 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformDetachStatement(const bool &if_exists, - const string &catalog_name) { +unique_ptr PEGTransformerFactory::TransformDetachStatement(PEGTransformer &transformer, + const bool &if_exists, + const Identifier &catalog_name) { auto result = make_uniq(); auto info = make_uniq(); info->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; diff --git a/src/duckdb/src/parser/peg/transformer/transform_drop.cpp b/src/duckdb/src/parser/peg/transformer/transform_drop.cpp index 93aa4e286..d835c7cc5 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_drop.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_drop.cpp @@ -5,81 +5,71 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDropStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto drop_entry = transformer.Transform>(list_pr.Child(1)); - transformer.TransformOptional(list_pr, 2, drop_entry->info->cascade); - return std::move(drop_entry); -} - -unique_ptr PEGTransformerFactory::TransformDropEntries(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); + unique_ptr drop_entries, + const bool &drop_behavior) { + drop_entries->info->cascade = drop_behavior; + return std::move(drop_entries); } unique_ptr PEGTransformerFactory::TransformDropTable(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const CatalogType &table_or_view, + const bool &if_exists, + vector> base_table_name) { auto result = make_uniq(); auto info = make_uniq(); - auto catalog_type = transformer.Transform(list_pr.Child(0)); - bool if_exists = list_pr.Child(1).HasResult(); - auto base_table_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (base_table_list.size() > 1) { + if (base_table_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto base_table = transformer.Transform>(base_table_list[0]); + auto base_table = std::move(base_table_name[0]); info->catalog = base_table->catalog_name; info->schema = base_table->schema_name; info->name = base_table->table_name; - info->type = catalog_type; + info->type = table_or_view; 
info->if_not_found = if_exists ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; result->info = std::move(info); return result; } -CatalogType PEGTransformerFactory::TransformTableOrView(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (choice_pr.name == "MaterializedViewEntry") { - throw NotImplementedException("Cannot drop MATERIALIZED VIEW yet"); - } - return transformer.TransformEnum(choice_pr); +CatalogType PEGTransformerFactory::TransformMaterializedViewEntry(PEGTransformer &transformer) { + throw NotImplementedException("Cannot drop MATERIALIZED VIEW yet"); +} + +bool PEGTransformerFactory::TransformFunctionTypeMacroKeyword(PEGTransformer &transformer) { + return true; } -unique_ptr PEGTransformerFactory::TransformDropTableFunction(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +bool PEGTransformerFactory::TransformFunctionTypeFunction(PEGTransformer &transformer) { + return false; +} + +unique_ptr +PEGTransformerFactory::TransformDropTableFunction(PEGTransformer &transformer, const CatalogType &comment_macro_table, + const bool &if_exists, + const vector &table_function_name) { auto result = make_uniq(); auto info = make_uniq(); - auto catalog_type = transformer.TransformEnum(list_pr.Child(0)); - bool if_exists = list_pr.Child(1).HasResult(); - auto table_function_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (table_function_list.size() > 1) { + if (table_function_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - info->name = table_function_list[0].get().Cast().identifier; + info->name = table_function_name[0]; info->catalog = INVALID_CATALOG; info->schema = INVALID_SCHEMA; - info->type = catalog_type; + info->type = comment_macro_table; info->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; result->info = std::move(info); return result; } -unique_ptr PEGTransformerFactory::TransformDropFunction(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformDropFunction(PEGTransformer &transformer, const bool &function_type_macro, + const bool &if_exists, const vector &function_identifier) { auto result = make_uniq(); auto info = make_uniq(); auto catalog_type = CatalogType::MACRO_ENTRY; - bool if_exists = list_pr.Child(1).HasResult(); - auto function_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (function_list.size() > 1) { + if (function_identifier.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto function = transformer.Transform(function_list[0]); + const auto &function = function_identifier[0]; info->catalog = function.catalog.empty() ? INVALID_CATALOG : function.catalog; info->schema = function.schema; info->name = function.name; @@ -89,17 +79,15 @@ unique_ptr PEGTransformerFactory::TransformDropFunction(PEGTransf return result; } -unique_ptr PEGTransformerFactory::TransformDropSchema(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformDropSchema(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_schema_name) { auto result = make_uniq(); auto info = make_uniq(); - bool if_exists = list_pr.Child(1).HasResult(); - auto schema_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (schema_list.size() > 1) { + if (qualified_schema_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto schema = transformer.Transform(schema_list[0]); + const auto &schema = qualified_schema_name[0]; info->catalog = schema.catalog; info->name = schema.schema; info->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; @@ -108,28 +96,31 @@ unique_ptr PEGTransformerFactory::TransformDropSchema(PEGTransfor return result; } -QualifiedName PEGTransformerFactory::TransformQualifiedSchemaName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +QualifiedName PEGTransformerFactory::TransformQualifiedSchemaNameString(PEGTransformer &transformer, + const Identifier &schema_name) { QualifiedName result; - string catalog = INVALID_CATALOG; - transformer.TransformOptional(list_pr, 0, catalog); - result.catalog = catalog; - result.schema = list_pr.Child(1).identifier; + result.catalog = INVALID_CATALOG; + result.schema = schema_name; return result; } -unique_ptr PEGTransformerFactory::TransformDropIndex(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +QualifiedName PEGTransformerFactory::TransformCatalogReservedSchema(PEGTransformer &transformer, + const Identifier &catalog_qualification, + const Identifier &reserved_schema_name) { + QualifiedName result; + result.catalog = catalog_qualification; + result.schema = reserved_schema_name; + return result; +} + +unique_ptr PEGTransformerFactory::TransformDropIndex(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_index_name) { auto result = make_uniq(); auto info = make_uniq(); - bool if_exists = list_pr.Child(1).HasResult(); - auto index_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (index_list.size() > 1) { + if (qualified_index_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto index = transformer.Transform(index_list[0]); + const auto &index = qualified_index_name[0]; info->catalog = index.catalog; info->schema = index.schema; info->name = index.name; @@ -139,29 +130,44 @@ unique_ptr PEGTransformerFactory::TransformDropIndex(PEGTransform return result; } -QualifiedName 
PEGTransformerFactory::TransformQualifiedIndexName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +QualifiedName PEGTransformerFactory::TransformQualifiedIndexNameString(PEGTransformer &transformer, + const Identifier &index_name) { QualifiedName result; result.catalog = INVALID_CATALOG; result.schema = INVALID_SCHEMA; - transformer.TransformOptional(list_pr, 0, result.catalog); - transformer.TransformOptional(list_pr, 1, result.schema); - result.name = list_pr.Child(2).identifier; + result.name = index_name; return result; } -unique_ptr PEGTransformerFactory::TransformDropSequence(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +QualifiedName PEGTransformerFactory::TransformSchemaReservedIndex(PEGTransformer &transformer, + const Identifier &schema_qualification, + const Identifier &reserved_index_name) { + QualifiedName result; + result.catalog = INVALID_CATALOG; + result.schema = schema_qualification; + result.name = reserved_index_name; + return result; +} + +QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaIndex( + PEGTransformer &transformer, const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, const Identifier &reserved_index_name) { + QualifiedName result; + result.catalog = catalog_qualification; + result.schema = reserved_schema_qualification; + result.name = reserved_index_name; + return result; +} + +unique_ptr +PEGTransformerFactory::TransformDropSequence(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_sequence_name) { auto result = make_uniq(); auto info = make_uniq(); - bool if_exists = list_pr.Child(1).HasResult(); - auto sequence_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (sequence_list.size() > 1) { + if (qualified_sequence_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto sequence = 
transformer.Transform(sequence_list[0]); + const auto &sequence = qualified_sequence_name[0]; if (sequence.schema.empty()) { info->schema = sequence.catalog; } else { @@ -175,24 +181,21 @@ unique_ptr PEGTransformerFactory::TransformDropSequence(PEGTransf return result; } -string PEGTransformerFactory::TransformCollationName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; +Identifier PEGTransformerFactory::TransformCollationName(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; } unique_ptr PEGTransformerFactory::TransformDropCollation(PEGTransformer &transformer, - ParseResult &parse_result) { + const bool &if_exists, + const vector &collation_name) { throw NotImplementedException("Cannot drop collation yet"); /* - *auto &list_pr = parse_result.Cast(); auto result = make_uniq(); auto info = make_uniq(); - bool if_exists = list_pr.Child(1).HasResult(); - auto collation_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (collation_list.size() > 1) { + if (collation_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto collation = transformer.Transform(collation_list[0]); + auto collation = collation_name[0]; info->catalog = INVALID_CATALOG; info->schema = INVALID_SCHEMA; info->name = collation; @@ -203,17 +206,14 @@ unique_ptr PEGTransformerFactory::TransformDropCollation(PEGTrans */ } -unique_ptr PEGTransformerFactory::TransformDropType(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformDropType(PEGTransformer &transformer, const bool &if_exists, + const vector &qualified_type_name) { auto result = make_uniq(); auto info = make_uniq(); - bool if_exists = list_pr.Child(1).HasResult(); - auto type_list = ExtractParseResultsFromList(list_pr.Child(2)); - if (type_list.size() > 1) { + if 
(qualified_type_name.size() > 1) { throw NotImplementedException("Can only drop one object at a time"); } - auto type = transformer.Transform(type_list[0]); + const auto &type = qualified_type_name[0]; if (type.schema.empty()) { info->schema = type.catalog; } else { @@ -227,55 +227,55 @@ unique_ptr PEGTransformerFactory::TransformDropType(PEGTransforme return result; } -bool PEGTransformerFactory::TransformDropBehavior(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - return StringUtil::CIEquals(choice_pr.Cast().keyword, "cascade"); +bool PEGTransformerFactory::TransformCascadeDropBehavior(PEGTransformer &transformer) { + return true; +} + +bool PEGTransformerFactory::TransformRestrictDropBehavior(PEGTransformer &transformer) { + return false; } -bool PEGTransformerFactory::TransformIfExists(PEGTransformer &transformer, ParseResult &parse_result) { +bool PEGTransformerFactory::TransformIfExists(PEGTransformer &transformer) { return true; } unique_ptr PEGTransformerFactory::TransformDropSecret(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const SecretPersistType &temporary, + const bool &if_exists, + const Identifier &secret_name, + const Identifier &drop_secret_storage) { auto result = make_uniq(); auto info = make_uniq(); info->type = CatalogType::SECRET_ENTRY; auto extra_drop_info = make_uniq(); - extra_drop_info->persist_mode = SecretPersistType::DEFAULT; - transformer.TransformOptional(list_pr, 0, extra_drop_info->persist_mode); + extra_drop_info->persist_mode = temporary; - bool if_exists = list_pr.Child(2).HasResult(); info->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; - info->name = transformer.Transform(list_pr.Child(3)); - transformer.TransformOptional(list_pr, 4, extra_drop_info->secret_storage); + info->name = secret_name; + extra_drop_info->secret_storage = drop_secret_storage.GetIdentifierName(); info->extra_drop_info = std::move(extra_drop_info); result->info = std::move(info); return result; } -string PEGTransformerFactory::TransformDropSecretStorage(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(1).identifier; +Identifier PEGTransformerFactory::TransformDropSecretStorage(PEGTransformer &transformer, + const Identifier &identifier) { + return identifier; } unique_ptr PEGTransformerFactory::TransformDropTrigger(PEGTransformer &transformer, - ParseResult &parse_result) { - // DropTrigger <- 'TRIGGER' IfExists? TriggerName 'ON' BaseTableName - auto &list_pr = parse_result.Cast(); + const bool &if_exists, + const Identifier &trigger_name, + unique_ptr base_table_name) { auto result = make_uniq(); auto info = make_uniq(); info->type = CatalogType::TRIGGER_ENTRY; - bool if_exists = list_pr.Child(1).HasResult(); info->if_not_found = if_exists ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; - info->name = transformer.Transform(list_pr.Child(2)); // TriggerName + info->name = trigger_name; - auto base_table = transformer.Transform>(list_pr.Child(4)); auto extra_info = make_uniq(); - extra_info->base_table = std::move(base_table); + extra_info->base_table = std::move(base_table_name); info->extra_drop_info = std::move(extra_info); result->info = std::move(info); diff --git a/src/duckdb/src/parser/peg/transformer/transform_execute.cpp b/src/duckdb/src/parser/peg/transformer/transform_execute.cpp index b08adde45..1275e60fc 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_execute.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_execute.cpp @@ -3,36 +3,33 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformExecuteStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformExecuteStatement(PEGTransformer &transformer, const Identifier &identifier, + vector table_function_arguments) { auto result = make_uniq(); - result->name = list_pr.Child(1).identifier; - auto &table_function_opt = list_pr.Child(2); - if (table_function_opt.HasResult()) { - idx_t param_idx = 0; - auto table_function_arguments = - transformer.Transform>>(table_function_opt.GetResult()); - for (idx_t i = 0; i < table_function_arguments.size(); i++) { - auto &expr = table_function_arguments[i]; - if (!table_function_arguments[i]->IsScalar()) { - throw InvalidInputException("Only scalar parameters, named parameters or NULL supported for EXECUTE"); - } - if (!table_function_arguments[i]->GetAlias().empty() && param_idx != 0) { + result->name = identifier; + if (table_function_arguments.empty()) { + return std::move(result); + } + idx_t param_idx = 0; + for (idx_t i = 0; i < table_function_arguments.size(); i++) { + auto &arg = table_function_arguments[i]; + if 
(!table_function_arguments[i].GetExpression().IsScalar()) { + throw InvalidInputException("Only scalar parameters, named parameters or NULL supported for EXECUTE"); + } + if (!table_function_arguments[i].GetName().empty() && param_idx != 0) { + throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); + } + auto param_name = arg.GetName(); + if (table_function_arguments[i].GetName().empty()) { + param_name = Identifier(std::to_string(param_idx + 1)); + if (param_idx != i) { throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); } - auto param_name = expr->GetAlias(); - if (table_function_arguments[i]->GetAlias().empty()) { - param_name = std::to_string(param_idx + 1); - if (param_idx != i) { - throw NotImplementedException( - "Mixing named parameters and positional parameters is not supported yet"); - } - param_idx++; - } - expr->ClearAlias(); - result->named_values[param_name] = std::move(expr); + param_idx++; } + arg.GetExpressionMutable()->ClearAlias(); + result->named_values[param_name] = std::move(arg.GetExpressionMutable()); } return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_explain.cpp b/src/duckdb/src/parser/peg/transformer/transform_explain.cpp index d8b292204..6a726808e 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_explain.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_explain.cpp @@ -19,26 +19,24 @@ ExplainFormat ParseExplainFormat(const Value &val) { return it->second; } vector options_list; - for (auto &it : format_mapping) { - options_list.push_back(it.first); + for (auto &format : format_mapping) { + options_list.push_back(format.first); } auto allowed_options = StringUtil::Join(options_list, ", "); throw InvalidInputException("\"%s\" is not a valid FORMAT argument, valid options are: %s", format_val, allowed_options); } -unique_ptr 
PEGTransformerFactory::TransformExplainStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &explain_analyze = list_pr.Child(1); - auto explain_type = explain_analyze.HasResult() ? ExplainType::EXPLAIN_ANALYZE : ExplainType::EXPLAIN_STANDARD; +unique_ptr +PEGTransformerFactory::TransformExplainStatement(PEGTransformer &transformer, const bool &explain_analyze, + const vector &explain_option_list, + unique_ptr explainable_statements) { + auto explain_type = explain_analyze ? ExplainType::EXPLAIN_ANALYZE : ExplainType::EXPLAIN_STANDARD; bool format_is_set = false; auto explain_format = ExplainFormat::DEFAULT; - auto &explain_options_opt = list_pr.Child(2); - if (explain_options_opt.HasResult()) { - auto options = transformer.Transform>(explain_options_opt.GetResult()); - for (auto option : options) { - auto option_name = StringUtil::Lower(option.name); + if (!explain_option_list.empty()) { + for (auto option : explain_option_list) { + auto option_name = StringUtil::Lower(option.name.GetIdentifierName()); if (option_name == "format") { if (format_is_set) { throw InvalidInputException("FORMAT can not be provided more than once"); @@ -52,42 +50,34 @@ unique_ptr PEGTransformerFactory::TransformExplainStatement(PEGTra } } } - auto statement = transformer.Transform>(list_pr.Child(3)); + auto statement = std::move(explainable_statements); return make_uniq(std::move(statement), explain_type, explain_format); } -unique_ptr PEGTransformerFactory::TransformExplainableStatements(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (StringUtil::CIEquals(choice_pr.name, "SelectStatementInternal")) { - return transformer.Transform>(list_pr.Child(0).GetResult()); - } - return transformer.Transform>(list_pr.Child(0).GetResult()); +bool PEGTransformerFactory::TransformExplainAnalyze(PEGTransformer &transformer) { + return 
true; } -vector PEGTransformerFactory::TransformExplainOptionList(PEGTransformer &transformer, - ParseResult &parse_result) { - vector result; - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto option_list = ExtractParseResultsFromList(extract_parens); - for (auto &option : option_list) { - result.push_back(transformer.Transform(option)); - } - return result; +unique_ptr +PEGTransformerFactory::TransformExplainSelectStatement(PEGTransformer &transformer, + unique_ptr select_statement_internal) { + return std::move(select_statement_internal); +} + +vector +PEGTransformerFactory::TransformExplainOptionList(PEGTransformer &transformer, + const vector &explain_option) { + return explain_option; } GenericCopyOption PEGTransformerFactory::TransformExplainOption(PEGTransformer &transformer, - ParseResult &parse_result) { + const Identifier &explain_option_name, + unique_ptr expression) { GenericCopyOption copy_option; - auto &list_pr = parse_result.Cast(); - copy_option.name = StringUtil::Lower(transformer.Transform(list_pr.GetChild(0))); - auto &optional_expression = list_pr.Child(1); - if (!optional_expression.HasResult()) { + copy_option.name = Identifier(StringUtil::Lower(explain_option_name.GetIdentifierName())); + if (!expression) { return copy_option; } - auto expression = transformer.Transform>(optional_expression.GetResult()); if (expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { copy_option.children.push_back(Value(expression->Cast().GetValue())); } else if (expression->GetExpressionType() == ExpressionType::COLUMN_REF) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_export.cpp b/src/duckdb/src/parser/peg/transformer/transform_export.cpp index 18b2a71e1..0f7e53ea2 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_export.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_export.cpp @@ -1,25 +1,27 @@ #include "duckdb/parser/parsed_data/copy_info.hpp" 
#include "duckdb/parser/statement/export_statement.hpp" +#include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" namespace duckdb { unique_ptr -PEGTransformerFactory::TransformExportStatement(const string &export_source, const string &string_literal, - vector generic_copy_option_list) { +PEGTransformerFactory::TransformExportStatement(PEGTransformer &transformer, const string &export_source, + const string &string_literal, + const vector &generic_copy_option_list) { auto info = make_uniq(); info->file_path = string_literal; info->format = "csv"; info->is_from = false; - for (auto &option : generic_copy_option_list) { + for (const auto &option : generic_copy_option_list) { if (option.name == "format") { info->format = option.children[0].GetValue(); info->is_format_auto_detected = false; } else if (option.expression) { - info->parsed_options[StringUtil::Upper(option.name)] = std::move(option.expression); + info->parsed_options[StringUtil::Upper(option.name.GetIdentifierName())] = option.expression->Copy(); } else { - info->options[StringUtil::Upper(option.name)] = option.children; + info->options[StringUtil::Upper(option.name.GetIdentifierName())] = option.children; } } @@ -28,8 +30,16 @@ PEGTransformerFactory::TransformExportStatement(const string &export_source, con return std::move(result); } -string PEGTransformerFactory::TransformExportSource(const string &catalog_name) { - return catalog_name; +string PEGTransformerFactory::TransformExportSource(PEGTransformer &transformer, const Identifier &catalog_name) { + return catalog_name.GetIdentifierName(); +} + +unique_ptr PEGTransformerFactory::TransformImportStatement(PEGTransformer &transformer, + const string &string_literal) { + auto result = make_uniq(); + result->info->name = "import_database"; + result->info->parameters.emplace_back(make_uniq(Value(string_literal))); + return std::move(result); } } // namespace duckdb diff --git 
a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp index 352a02dd4..fbfaecfae 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp @@ -19,16 +19,10 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformExpressionStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr_list = ExtractParseResultsFromList(list_pr.GetChild(0)); - - vector> expressions; - for (auto &expr : expr_list) { - expressions.push_back(transformer.Transform>(expr)); - } - +unique_ptr +PEGTransformerFactory::TransformExpressionStatement(PEGTransformer &transformer, + vector> expression_alias) { + auto expressions = std::move(expression_alias); auto select_statement = make_uniq(); auto select_node = make_uniq(); @@ -49,17 +43,17 @@ unique_ptr PEGTransformerFactory::TransformExpressionStatement(PEG // Single COLUMN_REF: treat as table scan auto &col_expr = expressions[0]->Cast(); if (col_expr.IsQualified()) { - if (col_expr.column_names.size() >= 4) { + if (col_expr.ColumnNames().size() >= 4) { throw ParserException( "Too many qualifiers encountered. 
Expected \"catalog.schema.table\" or \"schema.table\""); } - if (col_expr.column_names.size() == 3) { + if (col_expr.ColumnNames().size() == 3) { auto table_description = - TableDescription(col_expr.column_names[0], col_expr.column_names[1], col_expr.column_names[2]); + TableDescription(col_expr.ColumnNames()[0], col_expr.ColumnNames()[1], col_expr.ColumnNames()[2]); select_node->from_table = make_uniq(table_description); - } else if (col_expr.column_names.size() == 2) { + } else if (col_expr.ColumnNames().size() == 2) { auto table_description = - TableDescription(INVALID_CATALOG, col_expr.column_names[0], col_expr.column_names[1]); + TableDescription(INVALID_CATALOG, col_expr.ColumnNames()[0], col_expr.ColumnNames()[1]); select_node->from_table = make_uniq(table_description); } } else { @@ -101,7 +95,7 @@ unique_ptr PEGTransformerFactory::TransformBaseExpression(PEGT auto indirection_expr = transformer.Transform>(child); if (indirection_expr->GetExpressionClass() == ExpressionClass::CAST) { auto cast_expr = unique_ptr_cast(std::move(indirection_expr)); - cast_expr->child = std::move(expr); + cast_expr->ChildMutable() = std::move(expr); expr = std::move(cast_expr); prev_indirection_was_cast = true; } else if (indirection_expr->GetExpressionClass() == ExpressionClass::OPERATOR) { @@ -111,12 +105,12 @@ unique_ptr PEGTransformerFactory::TransformBaseExpression(PEGT "allowed). 
Wrap the cast in parentheses: (x::TYPE)[1:3]"); } auto operator_expr = unique_ptr_cast(std::move(indirection_expr)); - operator_expr->children.insert(operator_expr->children.begin(), std::move(expr)); + operator_expr->GetChildrenMutable().insert(operator_expr->GetChildrenMutable().begin(), std::move(expr)); expr = std::move(operator_expr); prev_indirection_was_cast = false; } else if (indirection_expr->GetExpressionClass() == ExpressionClass::FUNCTION) { auto function_expr = unique_ptr_cast(std::move(indirection_expr)); - function_expr->children.insert(function_expr->children.begin(), std::move(expr)); + function_expr->GetArgumentsMutable().insert(function_expr->GetArgumentsMutable().begin(), std::move(expr)); expr = std::move(function_expr); prev_indirection_was_cast = false; } else if (indirection_expr->GetExpressionClass() == ExpressionClass::CONSTANT) { @@ -134,22 +128,6 @@ unique_ptr PEGTransformerFactory::TransformBaseExpression(PEGT return expr; } -unique_ptr PEGTransformerFactory::TransformNestedColumnName(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector column_names; - auto &opt_identifiers = list_pr.Child(0); - if (opt_identifiers.HasResult()) { - auto &repeat_identifiers = opt_identifiers.GetResult().Cast(); - for (auto &child : repeat_identifiers.GetChildren()) { - auto &repeat_list = child.get().Cast(); - column_names.push_back(repeat_list.Child(0).identifier); - } - } - column_names.push_back(list_pr.Child(1).identifier); - return make_uniq(std::move(column_names)); -} - // ColumnReference <- CatalogReservedSchemaTableColumnName / SchemaReservedTableColumnName / TableReservedColumnName / // NestedColumnName unique_ptr PEGTransformerFactory::TransformColumnReference(PEGTransformer &transformer, @@ -162,10 +140,10 @@ unique_ptr PEGTransformerFactory::TransformCatalogReservedSchemaTableColumnName(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - 
vector column_names; - column_names.push_back(transformer.Transform(list_pr.Child(0))); - column_names.push_back(transformer.Transform(list_pr.Child(1))); - column_names.push_back(transformer.Transform(list_pr.Child(2))); + vector column_names; + column_names.push_back(transformer.Transform(list_pr.Child(0))); + column_names.push_back(transformer.Transform(list_pr.Child(1))); + column_names.push_back(transformer.Transform(list_pr.Child(2))); column_names.push_back(list_pr.Child(3).identifier); return make_uniq(std::move(column_names)); } @@ -173,17 +151,16 @@ PEGTransformerFactory::TransformCatalogReservedSchemaTableColumnName(PEGTransfor unique_ptr PEGTransformerFactory::TransformSchemaReservedTableColumnName(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector column_names; - column_names.push_back(transformer.Transform(list_pr.Child(0))); - column_names.push_back(transformer.Transform(list_pr.Child(1))); + vector column_names; + column_names.push_back(transformer.Transform(list_pr.Child(0))); + column_names.push_back(transformer.Transform(list_pr.Child(1))); column_names.push_back(list_pr.Child(2).identifier); return make_uniq(std::move(column_names)); } -string PEGTransformerFactory::TransformReservedTableQualification(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; +Identifier PEGTransformerFactory::TransformReservedTableQualification(PEGTransformer &transformer, + const Identifier &reserved_table_name) { + return reserved_table_name; } unique_ptr PEGTransformerFactory::TransformFunctionExpression(PEGTransformer &transformer, @@ -195,11 +172,11 @@ unique_ptr PEGTransformerFactory::TransformFunctionExpression( transformer.TransformOptional(extract_parens, 0, distinct); auto &function_arg_opt = extract_parens.Child(1); - vector> function_children; + vector function_children; if (function_arg_opt.HasResult()) { auto 
function_argument_list = ExtractParseResultsFromList(function_arg_opt.GetResult()); for (auto function_argument : function_argument_list) { - function_children.push_back(transformer.Transform>(function_argument)); + function_children.push_back(transformer.Transform(function_argument)); } } auto order_modifier = make_uniq(); @@ -212,12 +189,12 @@ unique_ptr PEGTransformerFactory::TransformFunctionExpression( transformer.TransformOptional(extract_parens, 3, ignore_nulls); auto &export_opt = list_pr.Child(4); - if (function_children.size() == 1 && ExpressionIsEmptyStar(*function_children[0]) && !distinct && - order_modifier->orders.empty()) { + if (function_children.size() == 1 && ExpressionIsEmptyStar(*function_children[0].GetExpressionMutable()) && + !distinct && order_modifier->orders.empty()) { // COUNT(*) gets converted into COUNT() function_children.clear(); } - auto lowercase_name = StringUtil::Lower(qualified_function.name); + auto lowercase_name = StringUtil::Lower(qualified_function.name.GetIdentifierName()); auto &over_opt = list_pr.Child(5); if (over_opt.HasResult()) { @@ -236,16 +213,19 @@ unique_ptr PEGTransformerFactory::TransformFunctionExpression( transformer.in_window_definition = true; auto expr = transformer.Transform>(over_opt.GetResult()); - expr->catalog = qualified_function.catalog; - expr->schema = qualified_function.schema; + expr->CatalogMutable() = qualified_function.catalog; + expr->SchemaMutable() = qualified_function.schema; expr->SetFunctionName(lowercase_name); - expr->children = std::move(function_children); - expr->has_ignore_nulls = has_ignore_nulls_result; - expr->ignore_nulls = ignore_nulls; - expr->filter_expr = std::move(filter_expr); - expr->arg_orders = std::move(order_modifier->orders); - expr->distinct = distinct; + for (auto &arg : function_children) { + expr->GetArgumentsMutable().push_back(std::move(arg)); + } + + expr->HasIgnoreNullsMutable() = has_ignore_nulls_result; + expr->IgnoreNullsMutable() = ignore_nulls; + 
expr->FilterMutable() = std::move(filter_expr); + expr->ArgOrdersMutable() = std::move(order_modifier->orders); + expr->DistinctMutable() = distinct; transformer.in_window_definition = false; return std::move(expr); } @@ -257,46 +237,78 @@ unique_ptr PEGTransformerFactory::TransformFunctionExpression( if (function_children.size() != 3) { throw ParserException("Wrong number of arguments to IF."); } + for (auto &arg : function_children) { + if (arg.HasName()) { + throw ParserException("Named arguments are not supported in IF expressions"); + } + } + auto expr = make_uniq(); CaseCheck check; - check.when_expr = std::move(function_children[0]); - check.then_expr = std::move(function_children[1]); - expr->case_checks.push_back(std::move(check)); - expr->else_expr = std::move(function_children[2]); + check.when_expr = std::move(function_children[0].GetExpressionMutable()); + check.then_expr = std::move(function_children[1].GetExpressionMutable()); + expr->CaseChecksMutable().push_back(std::move(check)); + expr->ElseMutable() = std::move(function_children[2].GetExpressionMutable()); return std::move(expr); - } else if (lowercase_name == "unpack") { + } + if (lowercase_name == "unpack") { if (function_children.size() != 1) { throw ParserException("Wrong number of arguments to the UNPACK operator"); } auto expr = make_uniq(ExpressionType::OPERATOR_UNPACK); - expr->children = std::move(function_children); + for (auto &arg : function_children) { + if (arg.HasName()) { + throw ParserException("Named arguments are not supported in UNPACK operator"); + } + expr->GetChildrenMutable().push_back(std::move(arg.GetExpressionMutable())); + } return std::move(expr); - } else if (lowercase_name == "try") { + } + if (lowercase_name == "try") { if (function_children.size() != 1) { throw ParserException("Wrong number of arguments provided to TRY expression"); } auto try_expression = make_uniq(ExpressionType::OPERATOR_TRY); - try_expression->children = std::move(function_children); + for 
(auto &arg : function_children) { + if (arg.HasName()) { + throw ParserException("Named arguments are not supported in TRY expression"); + } + try_expression->GetChildrenMutable().push_back(std::move(arg.GetExpressionMutable())); + } return std::move(try_expression); - } else if (lowercase_name == "construct_array") { + } + if (lowercase_name == "construct_array") { auto construct_array = make_uniq(ExpressionType::ARRAY_CONSTRUCTOR); - construct_array->children = std::move(function_children); + for (auto &arg : function_children) { + if (arg.HasName()) { + throw ParserException("Named arguments are not supported in array constructors"); + } + construct_array->GetChildrenMutable().push_back(std::move(arg.GetExpressionMutable())); + } return std::move(construct_array); - } else if (lowercase_name == "ifnull") { + } + if (lowercase_name == "ifnull") { if (function_children.size() != 2) { throw ParserException("Wrong number of arguments to IFNULL."); } + for (auto &arg : function_children) { + if (arg.HasName()) { + throw ParserException("Named arguments are not supported in IFNULL expressions"); + } + } // Two-argument COALESCE auto coalesce_op = make_uniq(ExpressionType::OPERATOR_COALESCE); - coalesce_op->children.push_back(std::move(function_children[0])); - coalesce_op->children.push_back(std::move(function_children[1])); + coalesce_op->GetChildrenMutable().push_back(std::move(function_children[0].GetExpressionMutable())); + coalesce_op->GetChildrenMutable().push_back(std::move(function_children[1].GetExpressionMutable())); return std::move(coalesce_op); - } else if (lowercase_name == "date") { + } + if (lowercase_name == "date") { if (function_children.size() != 1) { throw ParserException("Wrong number of arguments provided to DATE function"); } - return std::move(make_uniq(LogicalType::DATE, std::move(function_children[0]))); + return std::move( + make_uniq(LogicalType::DATE, std::move(function_children[0].GetExpressionMutable()))); } if (has_ignore_nulls_result) 
{ throw ParserException("RESPECT/IGNORE NULLS is not supported for non-window functions"); @@ -336,9 +348,9 @@ unique_ptr PEGTransformerFactory::TransformFunctionExpression( throw ParserException("Unknown ordered aggregate \"%s\".", qualified_function.name); } } - auto result = make_uniq(qualified_function.catalog, qualified_function.schema, lowercase_name, - std::move(function_children), std::move(filter_expr), - std::move(order_modifier), distinct, false, export_opt.HasResult()); + auto result = make_uniq( + qualified_function.catalog, qualified_function.schema, Identifier(lowercase_name), std::move(function_children), + std::move(filter_expr), std::move(order_modifier), distinct, false, export_opt.HasResult()); return std::move(result); } @@ -369,7 +381,7 @@ QualifiedName PEGTransformerFactory::TransformSchemaReservedFunctionName(PEGTran auto &list_pr = parse_result.Cast(); QualifiedName result; result.catalog = INVALID_CATALOG; - result.schema = transformer.Transform(list_pr.Child(0)); + result.schema = Identifier(transformer.Transform(list_pr.Child(0))); result.name = list_pr.Child(1).identifier; return result; } @@ -380,10 +392,10 @@ QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaFunctionName( QualifiedName result; auto &opt_schema = list_pr.Child(1); if (opt_schema.HasResult()) { - result.catalog = transformer.Transform(list_pr.Child(0)); - result.schema = transformer.Transform(opt_schema.GetResult()); + result.catalog = Identifier(transformer.Transform(list_pr.Child(0))); + result.schema = Identifier(transformer.Transform(opt_schema.GetResult())); } else { - result.schema = transformer.Transform(list_pr.Child(0)); + result.schema = Identifier(transformer.Transform(list_pr.Child(0))); } result.name = list_pr.Child(2).identifier; return result; @@ -434,9 +446,9 @@ unique_ptr PEGTransformerFactory::TransformParenthesisExpressi void PEGTransformerFactory::RemoveOrderQualificationRecursive(unique_ptr &root_expr) { 
ParsedExpressionIterator::VisitExpressionMutable( *root_expr, [&](ColumnRefExpression &col_ref) { - auto &col_names = col_ref.column_names; + auto &col_names = col_ref.ColumnNamesMutable(); if (col_names.size() > 1) { - col_names = vector {col_names.back()}; + col_names = vector {col_names.back()}; } }); } @@ -446,16 +458,16 @@ unique_ptr PEGTransformerFactory::TransformArrayParensSelect(P auto &list_pr = parse_result.Cast(); auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); auto subquery_expr = make_uniq(); - subquery_expr->subquery = transformer.Transform>(extract_parens); + subquery_expr->SubqueryMutable() = transformer.Transform>(extract_parens); // ARRAY expression // wrap subquery into // "SELECT CASE WHEN ARRAY_AGG(col) IS NULL THEN [] ELSE ARRAY_AGG(col) END FROM (...) tbl" auto select_node = make_uniq(); unique_ptr array_agg_child; optional_ptr sub_select; - if (subquery_expr->subquery->node->type == QueryNodeType::SELECT_NODE) { + if (subquery_expr->Subquery()->node->type == QueryNodeType::SELECT_NODE) { // easy case - subquery is a SELECT - sub_select = subquery_expr->subquery->node->Cast(); + sub_select = subquery_expr->Subquery()->node->Cast(); if (sub_select->select_list.size() != 1) { throw BinderException(*subquery_expr, "Subquery returns %zu columns - expected 1", sub_select->select_list.size()); @@ -465,7 +477,7 @@ unique_ptr PEGTransformerFactory::TransformArrayParensSelect(P // subquery is not a SELECT but a UNION or CTE // we can still support this but it is more challenging since we can't push columns for the ORDER BY auto columns_star = make_uniq(); - columns_star->columns = true; + columns_star->IsColumnsMutable() = true; array_agg_child = std::move(columns_star); } @@ -474,16 +486,16 @@ unique_ptr PEGTransformerFactory::TransformArrayParensSelect(P children.push_back(std::move(array_agg_child)); auto aggr = make_uniq("array_agg", std::move(children)); // push ORDER BY modifiers into the array_agg - for (auto &modifier : 
subquery_expr->subquery->node->modifiers) { + for (auto &modifier : subquery_expr->SubqueryMutable()->node->modifiers) { if (modifier->type == ResultModifierType::ORDER_MODIFIER) { - aggr->order_bys = unique_ptr_cast(modifier->Copy()); + aggr->OrderByMutable() = unique_ptr_cast(modifier->Copy()); break; } } // transform constants (e.g. ORDER BY 1) into positional references (ORDER BY #1) idx_t array_idx = 0; - if (aggr->order_bys) { - for (auto &order : aggr->order_bys->orders) { + if (aggr->OrderBy()) { + for (auto &order : aggr->OrderByMutable()->orders) { if (order.expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { auto &constant_expr = order.expression->Cast(); Value bigint_value; @@ -496,9 +508,9 @@ unique_ptr PEGTransformerFactory::TransformArrayParensSelect(P } else if (sub_select) { // if we have a SELECT we can push the ORDER BY clause into the SELECT list and reference it auto alias = "__array_internal_idx_" + to_string(++array_idx); - order.expression->SetAlias(alias); + order.expression->SetAlias(Identifier(alias)); sub_select->select_list.push_back(std::move(order.expression)); - order.expression = make_uniq(alias); + order.expression = make_uniq(Identifier(alias)); } else { // otherwise we remove order qualifications RemoveOrderQualificationRecursive(order.expression); @@ -515,20 +527,20 @@ unique_ptr PEGTransformerFactory::TransformArrayParensSelect(P CaseCheck check; check.when_expr = std::move(agg_is_null); check.then_expr = std::move(empty_list); - case_expr->case_checks.push_back(std::move(check)); - case_expr->else_expr = std::move(aggr); + case_expr->CaseChecksMutable().push_back(std::move(check)); + case_expr->ElseMutable() = std::move(aggr); select_node->select_list.push_back(std::move(case_expr)); // FROM (...) 
tbl - auto child_subquery = make_uniq(std::move(subquery_expr->subquery)); + auto child_subquery = make_uniq(std::move(subquery_expr->SubqueryMutable())); select_node->from_table = std::move(child_subquery); auto new_subquery = make_uniq(); new_subquery->node = std::move(select_node); - subquery_expr->subquery = std::move(new_subquery); + subquery_expr->SubqueryMutable() = std::move(new_subquery); - subquery_expr->subquery_type = SubqueryType::SCALAR; + subquery_expr->GetSubqueryTypeMutable() = SubqueryType::SCALAR; return std::move(subquery_expr); } @@ -536,22 +548,22 @@ unique_ptr PEGTransformerFactory::TransformStructExpression(PE ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto func_name = "struct_pack"; - vector> struct_children; + vector struct_children; auto struct_children_list = ExtractParseResultsFromList(list_pr.GetChild(1)); for (auto struct_child : struct_children_list) { - struct_children.push_back(transformer.Transform>(struct_child)); + struct_children.push_back(transformer.Transform(struct_child)); } return make_uniq(INVALID_CATALOG, "main", func_name, std::move(struct_children)); } -unique_ptr PEGTransformerFactory::TransformStructField(PEGTransformer &transformer, - ParseResult &parse_result) { +FunctionArgument PEGTransformerFactory::TransformStructField(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto alias = transformer.Transform(list_pr.Child(0)); auto expr = transformer.Transform>(list_pr.Child(2)); - expr->SetAlias(alias); - return expr; + expr->SetAlias(Identifier(alias)); + + return FunctionArgument(Identifier(std::move(alias)), std::move(expr)); } vector> PEGTransformerFactory::TransformBoundedListExpression(PEGTransformer &transformer, @@ -666,11 +678,11 @@ unique_ptr PEGTransformerFactory::TransformIsExpression(PEGTra auto is_expr = transformer.Transform>(is_test_expr); if (is_expr->GetExpressionClass() == ExpressionClass::COMPARISON) { auto compare_expr = 
unique_ptr_cast(std::move(is_expr)); - compare_expr->left = make_uniq(LogicalType::BOOLEAN, std::move(expr)); + compare_expr->LeftMutable() = make_uniq(LogicalType::BOOLEAN, std::move(expr)); expr = std::move(compare_expr); } else if (is_expr->GetExpressionClass() == ExpressionClass::OPERATOR) { auto operator_expr = unique_ptr_cast(std::move(is_expr)); - operator_expr->children.insert(operator_expr->children.begin(), std::move(expr)); + operator_expr->GetChildrenMutable().insert(operator_expr->GetChildrenMutable().begin(), std::move(expr)); expr = std::move(operator_expr); } else { throw InternalException("Unexpected expression encountered in IsExpression: %s", @@ -691,7 +703,7 @@ unique_ptr PEGTransformerFactory::TransformIsLiteral(PEGTransf auto &list_pr = parse_result.Cast(); auto ¬_expr = list_pr.Child(1); auto &inner_list_pr = list_pr.Child(2); - auto literal_value = transformer.TransformEnum(inner_list_pr.Child(0).GetResult()); + auto literal_value = transformer.Transform(inner_list_pr.Child(0).GetResult()); if (literal_value.IsNull()) { auto expr_type = not_expr.HasResult() ? 
ExpressionType::OPERATOR_IS_NOT_NULL : ExpressionType::OPERATOR_IS_NULL; return make_uniq(expr_type, nullptr); @@ -767,7 +779,7 @@ ExpressionType PEGTransformerFactory::TransformComparisonOperator(PEGTransformer return transformer.TransformEnum(list_pr.Child(0).GetResult()); } -bool TryNegateLikeFunction(string &function_name) { +bool TryNegateLikeFunction(Identifier &function_name) { if (function_name == "~~") { function_name = "!~~"; return true; @@ -796,7 +808,7 @@ unique_ptr PEGTransformerFactory::TransformBetweenInLikeExpres bool has_not = op_list.Child(0).HasResult(); if (between_in_like_expr->GetExpressionClass() == ExpressionClass::BETWEEN) { auto between_expr = unique_ptr_cast(std::move(between_in_like_expr)); - between_expr->input = std::move(expr); + between_expr->InputMutable() = std::move(expr); if (has_not) { expr = make_uniq(ExpressionType::OPERATOR_NOT, std::move(between_expr)); } else { @@ -804,28 +816,28 @@ unique_ptr PEGTransformerFactory::TransformBetweenInLikeExpres } } else if (between_in_like_expr->GetExpressionClass() == ExpressionClass::FUNCTION) { auto func_expr = unique_ptr_cast(std::move(between_in_like_expr)); - if (func_expr->function_name == "contains") { - func_expr->children.push_back(std::move(expr)); + if (func_expr->FunctionName() == "contains") { + func_expr->GetArgumentsMutable().push_back(std::move(expr)); } else { - func_expr->children.insert(func_expr->children.begin(), std::move(expr)); + func_expr->GetArgumentsMutable().insert(func_expr->GetArgumentsMutable().begin(), std::move(expr)); } if (has_not) { - if (!TryNegateLikeFunction(func_expr->function_name)) { + if (!TryNegateLikeFunction(func_expr->FunctionNameMutable())) { // If it wasn't a special "Like" function, wrap it in a standard NOT operator expr = make_uniq(ExpressionType::OPERATOR_NOT, std::move(func_expr)); } else { expr = std::move(func_expr); } - } else if (func_expr->function_name == "!~") { - func_expr->function_name = "regexp_full_match"; - 
func_expr->is_operator = false; + } else if (func_expr->FunctionName() == "!~") { + func_expr->FunctionNameMutable() = "regexp_full_match"; + func_expr->IsOperatorMutable() = false; expr = make_uniq(ExpressionType::OPERATOR_NOT, std::move(func_expr)); } else { expr = std::move(func_expr); } } else if (between_in_like_expr->GetExpressionClass() == ExpressionClass::OPERATOR) { auto &operator_expr = between_in_like_expr->Cast(); - operator_expr.children.insert(operator_expr.children.begin(), std::move(expr)); + operator_expr.GetChildrenMutable().insert(operator_expr.GetChildrenMutable().begin(), std::move(expr)); if (has_not) { expr = make_uniq(ExpressionType::OPERATOR_NOT, std::move(between_in_like_expr)); } else { @@ -833,7 +845,7 @@ unique_ptr PEGTransformerFactory::TransformBetweenInLikeExpres } } else if (between_in_like_expr->GetExpressionClass() == ExpressionClass::SUBQUERY) { auto &subquery_expr = between_in_like_expr->Cast(); - subquery_expr.child = std::move(expr); + subquery_expr.GetChildMutable() = std::move(expr); if (has_not) { expr = make_uniq(ExpressionType::OPERATOR_NOT, std::move(between_in_like_expr)); } else { @@ -883,9 +895,9 @@ unique_ptr PEGTransformerFactory::TransformInExpressionList(PE if (in_children.size() == 1 && in_children[0]->GetExpressionClass() == ExpressionClass::SUBQUERY) { auto &subquery_expr = in_children[0]->Cast(); auto result = make_uniq(); - result->subquery_type = SubqueryType::ANY; - result->comparison_type = ExpressionType::COMPARE_EQUAL; - result->subquery = std::move(subquery_expr.subquery); + result->GetSubqueryTypeMutable() = SubqueryType::ANY; + result->GetComparisonTypeMutable() = ExpressionType::COMPARE_EQUAL; + result->SubqueryMutable() = std::move(subquery_expr.SubqueryMutable()); return std::move(result); } auto result = make_uniq(ExpressionType::COMPARE_IN, std::move(in_children)); @@ -897,9 +909,9 @@ unique_ptr PEGTransformerFactory::TransformInSelectStatement(P auto &list_pr = parse_result.Cast(); auto 
&extract_parens = ExtractResultFromParens(list_pr.Child(0)); auto result = make_uniq(); - result->subquery_type = SubqueryType::ANY; - result->comparison_type = ExpressionType::COMPARE_EQUAL; - result->subquery = transformer.Transform>(extract_parens); + result->GetSubqueryTypeMutable() = SubqueryType::ANY; + result->GetComparisonTypeMutable() = ExpressionType::COMPARE_EQUAL; + result->SubqueryMutable() = transformer.Transform>(extract_parens); return std::move(result); } @@ -927,9 +939,9 @@ unique_ptr PEGTransformerFactory::TransformLikeClause(PEGTrans } like_children.push_back(transformer.Transform>(escape_opt.GetResult())); } - auto result = make_uniq(like_variation, std::move(like_children)); + auto result = make_uniq(Identifier(like_variation), std::move(like_children)); if (like_variation != "regexp_full_match") { - result->is_operator = true; + result->IsOperatorMutable() = true; } return std::move(result); } @@ -973,16 +985,17 @@ unique_ptr PEGTransformerFactory::TransformOtherOperatorExpres if (expression_type == ExpressionType::INVALID) { throw ParserException("ANY and ALL operators require one of =,<>,>,<,>=,<= comparisons!"); } - subquery_expr->subquery_type = SubqueryType::ANY; - subquery_expr->comparison_type = expression_type; + subquery_expr->GetSubqueryTypeMutable() = SubqueryType::ANY; + subquery_expr->GetComparisonTypeMutable() = expression_type; auto &right_expr_subquery = right_expr->Cast(); - subquery_expr->subquery = std::move(right_expr_subquery.subquery); - subquery_expr->child = std::move(expr); + subquery_expr->SubqueryMutable() = std::move(right_expr_subquery.SubqueryMutable()); + subquery_expr->GetChildMutable() = std::move(expr); if (!is_any) { // ALL sublink is equivalent to NOT(ANY) with inverted comparison // e.g. 
[= ALL()] is equivalent to [NOT(<> ANY())] // first invert the comparison type - subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); + subquery_expr->GetComparisonTypeMutable() = + NegateComparisonExpression(subquery_expr->GetComparisonType()); return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); } expr = std::move(subquery_expr); @@ -1000,15 +1013,16 @@ unique_ptr PEGTransformerFactory::TransformOtherOperatorExpres select_node->select_list.push_back(make_uniq("UNNEST", std::move(children))); select_node->from_table = make_uniq(); select_statement->node = std::move(select_node); - subquery_expr->subquery = std::move(select_statement); - subquery_expr->subquery_type = SubqueryType::ANY; - subquery_expr->child = std::move(expr); - subquery_expr->comparison_type = expression_type; + subquery_expr->SubqueryMutable() = std::move(select_statement); + subquery_expr->GetSubqueryTypeMutable() = SubqueryType::ANY; + subquery_expr->GetChildMutable() = std::move(expr); + subquery_expr->GetComparisonTypeMutable() = expression_type; if (!is_any) { // ALL sublink is equivalent to NOT(ANY) with inverted comparison // e.g. 
[= ALL()] is equivalent to [NOT(<> ANY())] // first invert the comparison type - subquery_expr->comparison_type = NegateComparisonExpression(subquery_expr->comparison_type); + subquery_expr->GetComparisonTypeMutable() = + NegateComparisonExpression(subquery_expr->GetComparisonType()); return make_uniq(ExpressionType::OPERATOR_NOT, std::move(subquery_expr)); } return std::move(subquery_expr); @@ -1018,8 +1032,22 @@ unique_ptr PEGTransformerFactory::TransformOtherOperatorExpres vector> children_function; children_function.push_back(std::move(expr)); children_function.push_back(std::move(right_expr)); - auto func_expr = make_uniq(std::move(other_operator), std::move(children_function)); - func_expr->is_operator = true; + vector split_operator = StringUtil::Split(other_operator, "."); + string schema_name = INVALID_SCHEMA; + string func_name = ""; + if (split_operator.size() == 1) { + func_name = split_operator[0]; + } else if (split_operator.size() == 2) { + schema_name = split_operator[0]; + func_name = split_operator[1]; + } else { + throw ParserException("Too many identifiers found, expected schema.operator or operator"); + } + + auto func_expr = + make_uniq(INVALID_CATALOG, Identifier(std::move(schema_name)), + Identifier(std::move(func_name)), std::move(children_function)); + func_expr->IsOperatorMutable() = true; expr = std::move(func_expr); } } @@ -1036,11 +1064,20 @@ string PEGTransformerFactory::TransformOtherOperator(PEGTransformer &transformer return transformer.Transform(child); } -// QualifiedOperator <- 'OPERATOR' Parens(AnyOp) +// QualifiedOperator <- 'OPERATOR' Parens(ColId* AnyOp) string PEGTransformerFactory::TransformQualifiedOperator(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto &any_op_pr = ExtractResultFromParens(list_pr.GetChild(1)); - return transformer.Transform(any_op_pr); + auto &any_op_pr = ExtractResultFromParens(list_pr.GetChild(1)).Cast(); + auto &repeat_colid_opt = 
any_op_pr.Child(0); + vector result; + if (repeat_colid_opt.HasResult()) { + auto &repeat_colid = repeat_colid_opt.GetResult().Cast(); + for (auto &colid : repeat_colid.GetChildren()) { + result.push_back(transformer.Transform(colid)); + } + } + result.push_back(transformer.Transform(any_op_pr.GetChild(1))); + return StringUtil::Join(result, "."); } // AnyOp <- '!~~*' / '>>=' / ... / '!' @@ -1104,8 +1141,8 @@ unique_ptr PEGTransformerFactory::TransformBitwiseExpression(P bit_children.push_back(std::move(expr)); bit_children.push_back( transformer.Transform>(inner_list_pr.Child(1))); - auto func_expr = make_uniq(std::move(bit), std::move(bit_children)); - func_expr->is_operator = true; + auto func_expr = make_uniq(Identifier(std::move(bit)), std::move(bit_children)); + func_expr->IsOperatorMutable() = true; expr = std::move(func_expr); } return expr; @@ -1135,8 +1172,8 @@ unique_ptr PEGTransformerFactory::TransformAdditiveExpression( term_children.push_back(std::move(expr)); term_children.push_back( transformer.Transform>(inner_list_pr.Child(1))); - auto func_expr = make_uniq(std::move(term), std::move(term_children)); - func_expr->is_operator = true; + auto func_expr = make_uniq(Identifier(std::move(term)), std::move(term_children)); + func_expr->IsOperatorMutable() = true; if (inner_list_pr.offset.IsValid()) { transformer.SetQueryLocation(*func_expr, inner_list_pr.offset); } @@ -1172,8 +1209,8 @@ unique_ptr PEGTransformerFactory::TransformMultiplicativeExpre factor_children.push_back(std::move(expr)); factor_children.push_back( transformer.Transform>(inner_list_pr.Child(1))); - auto func_expr = make_uniq(std::move(factor), std::move(factor_children)); - func_expr->is_operator = true; + auto func_expr = make_uniq(Identifier(std::move(factor)), std::move(factor_children)); + func_expr->IsOperatorMutable() = true; expr = std::move(func_expr); } return expr; @@ -1202,8 +1239,8 @@ unique_ptr PEGTransformerFactory::TransformExponentiationExpre 
exponent_children.push_back(std::move(expr)); exponent_children.push_back( transformer.Transform>(inner_list_pr.Child(1))); - auto func_expr = make_uniq(std::move(exponent), std::move(exponent_children)); - func_expr->is_operator = true; + auto func_expr = make_uniq(Identifier(std::move(exponent)), std::move(exponent_children)); + func_expr->IsOperatorMutable() = true; expr = std::move(func_expr); } return expr; @@ -1236,7 +1273,7 @@ unique_ptr PEGTransformerFactory::TransformCollateExpression(P collate_string = const_expr.GetValue().GetValue(); } else if (collate_string_expr->GetExpressionClass() == ExpressionClass::COLUMN_REF) { auto &col_ref = collate_string_expr->Cast(); - collate_string = StringUtil::Join(col_ref.column_names, "."); + collate_string = StringUtil::Join(col_ref.ColumnNames(), "."); } else { throw NotImplementedException("Unexpected expression encountered for collate, %s", EnumUtil::ToString(collate_string_expr->GetExpressionClass())); @@ -1347,8 +1384,8 @@ unique_ptr PEGTransformerFactory::TransformPrefixExpression(PE vector> children; children.push_back(std::move(expr)); - auto func_expr = make_uniq(prefix, std::move(children)); - func_expr->is_operator = true; + auto func_expr = make_uniq(Identifier(prefix), std::move(children)); + func_expr->IsOperatorMutable() = true; expr = std::move(func_expr); } return expr; @@ -1379,10 +1416,10 @@ unique_ptr PEGTransformerFactory::TransformAnonymousParameter( string identifier = StringUtil::Format("%d", known_param_index); // Register it - transformer.SetParam(identifier, known_param_index, PreparedParamType::AUTO_INCREMENT); + transformer.SetParam(Identifier(identifier), known_param_index, PreparedParamType::AUTO_INCREMENT); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); - expr->identifier = identifier; + expr->IdentifierMutable() = Identifier(identifier); return std::move(expr); } @@ -1403,14 +1440,14 @@ unique_ptr PEGTransformerFactory::TransformQuestionMarkNumbere 
string identifier = const_expr.GetValue().ToString(); idx_t known_param_index = DConstants::INVALID_INDEX; - transformer.GetParam(identifier, known_param_index, PreparedParamType::POSITIONAL); + transformer.GetParam(Identifier(identifier), known_param_index, PreparedParamType::POSITIONAL); if (known_param_index == DConstants::INVALID_INDEX) { known_param_index = NumericCast(param_number); - transformer.SetParam(identifier, known_param_index, PreparedParamType::POSITIONAL); + transformer.SetParam(Identifier(identifier), known_param_index, PreparedParamType::POSITIONAL); } - expr->identifier = identifier; + expr->IdentifierMutable() = Identifier(identifier); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); return std::move(expr); } @@ -1432,14 +1469,14 @@ unique_ptr PEGTransformerFactory::TransformNumberedParameter(P string identifier = const_expr.GetValue().ToString(); idx_t known_param_index = DConstants::INVALID_INDEX; - transformer.GetParam(identifier, known_param_index, PreparedParamType::POSITIONAL); + transformer.GetParam(Identifier(identifier), known_param_index, PreparedParamType::POSITIONAL); if (known_param_index == DConstants::INVALID_INDEX) { known_param_index = NumericCast(param_number); - transformer.SetParam(identifier, known_param_index, PreparedParamType::POSITIONAL); + transformer.SetParam(Identifier(identifier), known_param_index, PreparedParamType::POSITIONAL); } - expr->identifier = identifier; + expr->IdentifierMutable() = Identifier(identifier); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); return std::move(expr); } @@ -1453,15 +1490,15 @@ unique_ptr PEGTransformerFactory::TransformColLabelParameter(P auto expr = make_uniq(); idx_t known_param_index = DConstants::INVALID_INDEX; - transformer.GetParam(identifier, known_param_index, PreparedParamType::NAMED); + transformer.GetParam(Identifier(identifier), known_param_index, PreparedParamType::NAMED); if (known_param_index == 
DConstants::INVALID_INDEX) { // New named parameter gets the next available index known_param_index = transformer.ParamCount() + 1; - transformer.SetParam(identifier, known_param_index, PreparedParamType::NAMED); + transformer.SetParam(Identifier(identifier), known_param_index, PreparedParamType::NAMED); } - expr->identifier = identifier; + expr->IdentifierMutable() = Identifier(identifier); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); return std::move(expr); } @@ -1501,13 +1538,24 @@ unique_ptr PEGTransformerFactory::TransformParensExpression(PE unique_ptr PEGTransformerFactory::TransformConstantLiteral(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto val = transformer.TransformEnum(list_pr.Child(0).GetResult()); - if (val.IsNull()) { - return make_uniq(val); - } else { - auto varchar_val = make_uniq(val.GetValue().substr(0, 1)); - return make_uniq(LogicalType::BOOLEAN, std::move(varchar_val)); - } + auto val = transformer.Transform(list_pr.Child(0).GetResult()); + return make_uniq(val); +} + +Value PEGTransformerFactory::TransformFalseLiteral(PEGTransformer &transformer, ParseResult &parse_result) { + return Value(false); +} + +Value PEGTransformerFactory::TransformTrueLiteral(PEGTransformer &transformer, ParseResult &parse_result) { + return Value(true); +} + +Value PEGTransformerFactory::TransformNullLiteral(PEGTransformer &transformer, ParseResult &parse_result) { + return Value(); +} + +Value PEGTransformerFactory::TransformUnknownLiteral(PEGTransformer &transformer, ParseResult &parse_result) { + return Value(); } // SingleExpression <- LiteralExpression / @@ -1583,14 +1631,14 @@ unique_ptr PEGTransformerFactory::TransformMethodExpression(PE bool distinct = false; transformer.TransformOptional(extract_parens, 0, distinct); auto &function_arg_opt = extract_parens.Child(1); - vector> function_children; + vector function_children; if (function_arg_opt.HasResult()) { auto 
function_argument_list = ExtractParseResultsFromList(function_arg_opt.GetResult()); for (auto function_argument : function_argument_list) { - function_children.push_back(transformer.Transform>(function_argument)); + function_children.push_back(transformer.Transform(function_argument)); } } - if (function_children.size() == 1 && ExpressionIsEmptyStar(*function_children[0])) { + if (function_children.size() == 1 && ExpressionIsEmptyStar(function_children[0].GetExpression())) { // COUNT(*) gets converted into COUNT() function_children.clear(); } @@ -1602,13 +1650,13 @@ unique_ptr PEGTransformerFactory::TransformMethodExpression(PE if (has_ignore_nulls_result) { throw ParserException("RESPECT/IGNORE NULLS is not supported for non-window functions"); } - auto result = - make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, collabel, std::move(function_children)); - result->distinct = distinct; + auto result = make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, Identifier(collabel), + std::move(function_children)); + result->DistinctMutable() = distinct; if (!order_by.empty()) { auto order_by_modifier = make_uniq(); order_by_modifier->orders = std::move(order_by); - result->order_bys = std::move(order_by_modifier); + result->OrderByMutable() = std::move(order_by_modifier); } return std::move(result); } @@ -1694,14 +1742,18 @@ unique_ptr PEGTransformerFactory::TransformTableReservedCol auto &list_pr = parse_result.Cast(); auto table = transformer.Transform(list_pr.Child(0)); auto column = list_pr.Child(1).identifier; - return make_uniq(column, table); + return make_uniq(column, Identifier(table)); } -string PEGTransformerFactory::TransformTableQualification(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; +Identifier PEGTransformerFactory::TransformTableQualification(PEGTransformer &transformer, + const Identifier &table_name) { + return table_name; } +string 
PEGTransformerFactory::TransformColIdDot(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + return transformer.Transform(list_pr.GetChild(0)); +} unique_ptr PEGTransformerFactory::TransformStarExpression(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); @@ -1713,16 +1765,19 @@ unique_ptr PEGTransformerFactory::TransformStarExpression(PEGT if (repeat_colid.GetChildren().size() > 1) { throw ParserException("Did not expect more than one column in front of a star expression"); } - auto &colid_list = repeat_colid.GetChildren()[0].get().Cast(); - result->relation_name = transformer.Transform(colid_list.Child(0)); + result->RelationNameMutable() = + Identifier(transformer.Transform(repeat_colid.Child(0))); } - transformer.TransformOptional(list_pr, 2, result->exclude_list); + transformer.TransformOptional(list_pr, 2, result->ExcludeListMutable()); auto &replace_list_opt = list_pr.Child(3); if (replace_list_opt.HasResult()) { - result->replace_list = + auto replace_string_map = transformer.Transform>>(replace_list_opt.GetResult()); - for (auto &replace_entry : result->replace_list) { - if (result->exclude_list.find(QualifiedColumnName(replace_entry.first)) != result->exclude_list.end()) { + for (auto &replace_entry : replace_string_map) { + result->ReplaceListMutable()[Identifier(replace_entry.first)] = std::move(replace_entry.second); + } + for (auto &replace_entry : result->ReplaceList()) { + if (result->ExcludeList().find(QualifiedColumnName(replace_entry.first)) != result->ExcludeList().end()) { throw ParserException("Column \"%s\" cannot occur in both EXCLUDE and REPLACE list", replace_entry.first); } @@ -1730,13 +1785,16 @@ unique_ptr PEGTransformerFactory::TransformStarExpression(PEGT } auto &rename_list_opt = list_pr.Child(4); if (rename_list_opt.HasResult()) { - result->rename_list = transformer.Transform>(rename_list_opt.GetResult()); - for (auto &rename_column : 
result->rename_list) { - if (result->exclude_list.find(rename_column.first) != result->exclude_list.end()) { + auto rename_string_map = transformer.Transform>(rename_list_opt.GetResult()); + for (auto &rename_entry : rename_string_map) { + result->RenameListMutable()[rename_entry.first] = Identifier(rename_entry.second); + } + for (auto &rename_column : result->RenameList()) { + if (result->ExcludeList().find(rename_column.first) != result->ExcludeList().end()) { throw ParserException("Column \"%s\" cannot occur in both EXCLUDE and RENAME list", rename_column.first.ToString()); } - if (result->replace_list.find(rename_column.first.column) != result->replace_list.end()) { + if (result->ReplaceList().find(rename_column.first.column) != result->ReplaceList().end()) { throw ParserException("Column \"%s\" cannot occur in both REPLACE and RENAME list", rename_column.first.ToString()); } @@ -1785,7 +1843,7 @@ QualifiedColumnName PEGTransformerFactory::TransformExcludeName(PEGTransformer & return QualifiedColumnName::Parse(result_string); } else if (StringUtil::CIEquals(choice_pr.name, "colidorstring")) { auto result = transformer.Transform(choice_pr); - return QualifiedColumnName(result); + return QualifiedColumnName(Identifier(result)); } else { throw InternalException("Unexpected option encountered for ExcludeName"); } @@ -1814,8 +1872,8 @@ unique_ptr PEGTransformerFactory::TransformParensIdentifier(PE auto &extract_parens = ExtractResultFromParens(list_pr.GetChild(0)); auto window_name = extract_parens.Cast().identifier; auto window_clause = transformer.GetWindowClause(window_name); - if (window_clause->start_expr || window_clause->end_expr || - !transformer.IsWindowFrameDefault(window_clause->start, window_clause->end)) { + if (window_clause->StartExpr() || window_clause->EndExpr() || + !transformer.IsWindowFrameDefault(window_clause->WindowStart(), window_clause->WindowEnd())) { throw ParserException("cannot copy window \"%s\" because it has a frame clause", 
window_name); } return window_clause; @@ -1849,35 +1907,35 @@ unique_ptr PEGTransformerFactory::TransformWindowFrameNameCont if (window_name.empty()) { return window_frame_contents; } - auto copied_window = transformer.GetWindowClause(window_name); - if (copied_window->start_expr || copied_window->end_expr || - !transformer.IsWindowFrameDefault(copied_window->start, copied_window->end)) { + auto copied_window = transformer.GetWindowClause(Identifier(window_name)); + if (copied_window->StartExpr() || copied_window->EndExpr() || + !transformer.IsWindowFrameDefault(copied_window->WindowStart(), copied_window->WindowEnd())) { throw ParserException("cannot copy window \"%s\" because it has a frame clause", window_name); } - copied_window->start = window_frame_contents->start; - copied_window->end = window_frame_contents->end; - copied_window->exclude_clause = window_frame_contents->exclude_clause; - copied_window->start_expr = std::move(window_frame_contents->start_expr); - copied_window->end_expr = std::move(window_frame_contents->end_expr); + copied_window->WindowStartMutable() = window_frame_contents->WindowStart(); + copied_window->WindowEndMutable() = window_frame_contents->WindowEnd(); + copied_window->WindowExcludeMutable() = window_frame_contents->WindowExclude(); + copied_window->StartExprMutable() = std::move(window_frame_contents->StartExprMutable()); + copied_window->EndExprMutable() = std::move(window_frame_contents->EndExprMutable()); - if (!copied_window->orders.empty() && !window_frame_contents->orders.empty()) { + if (!copied_window->OrderBy().empty() && !window_frame_contents->OrderBy().empty()) { throw ParserException("Cannot override ORDER BY clause of window \"%s\"", window_name); } - if (copied_window->orders.empty()) { - copied_window->orders = std::move(window_frame_contents->orders); + if (copied_window->OrderBy().empty()) { + copied_window->OrderByMutable() = std::move(window_frame_contents->OrderByMutable()); } - if 
(!copied_window->partitions.empty() && !window_frame_contents->partitions.empty()) { + if (!copied_window->Partitions().empty() && !window_frame_contents->Partitions().empty()) { throw ParserException("Cannot override PARTITION BY clause of window \"%s\"", window_name); } - if (copied_window->partitions.empty()) { - copied_window->partitions = std::move(window_frame_contents->partitions); + if (copied_window->Partitions().empty()) { + copied_window->PartitionsMutable() = std::move(window_frame_contents->PartitionsMutable()); } return copied_window; } string PEGTransformerFactory::TransformBaseWindowName(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; + return list_pr.Child(0).identifier.GetIdentifierName(); } unique_ptr PEGTransformerFactory::TransformWindowFrameContents(PEGTransformer &transformer, @@ -1887,15 +1945,16 @@ unique_ptr PEGTransformerFactory::TransformWindowFrameContents auto result = make_uniq(INVALID_CATALOG, INVALID_SCHEMA, string()); auto &partition_opt = list_pr.Child(0); if (partition_opt.HasResult()) { - result->partitions = transformer.Transform>>(partition_opt.GetResult()); + result->PartitionsMutable() = + transformer.Transform>>(partition_opt.GetResult()); } auto &order_by_opt = list_pr.Child(1); if (order_by_opt.HasResult()) { - result->orders = transformer.Transform>(order_by_opt.GetResult()); - for (auto &order : result->orders) { + result->OrderByMutable() = transformer.Transform>(order_by_opt.GetResult()); + for (auto &order : result->OrderByMutable()) { if (order.expression->GetExpressionType() == ExpressionType::STAR) { auto &star = order.expression->Cast(); - if (!star.expr) { + if (!star.Expression()) { throw ParserException("Cannot ORDER BY ALL in a window expression"); } } @@ -1904,14 +1963,14 @@ unique_ptr PEGTransformerFactory::TransformWindowFrameContents auto &frame_opt = list_pr.Child(2); if (frame_opt.HasResult()) { auto window_frame = 
transformer.Transform(frame_opt.GetResult()); - result->start = window_frame.start; - result->end = window_frame.end; - result->start_expr = std::move(window_frame.start_expr); - result->end_expr = std::move(window_frame.end_expr); - result->exclude_clause = window_frame.exclude_clause; + result->WindowStartMutable() = window_frame.start; + result->WindowEndMutable() = window_frame.end; + result->StartExprMutable() = std::move(window_frame.start_expr); + result->EndExprMutable() = std::move(window_frame.end_expr); + result->WindowExcludeMutable() = window_frame.exclude_clause; } else { - result->start = WindowBoundary::UNBOUNDED_PRECEDING; - result->end = WindowBoundary::CURRENT_ROW_RANGE; + result->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + result->WindowEndMutable() = WindowBoundary::CURRENT_ROW_RANGE; } return result; } @@ -2086,7 +2145,7 @@ unique_ptr PEGTransformerFactory::TransformCoalesceExpression( auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); auto expr_list = ExtractParseResultsFromList(extract_parens); for (auto expr : expr_list) { - result->children.push_back(transformer.Transform>(expr)); + result->GetChildrenMutable().push_back(transformer.Transform>(expr)); } return std::move(result); } @@ -2096,7 +2155,7 @@ unique_ptr PEGTransformerFactory::TransformUnpackExpression(PE auto &list_pr = parse_result.Cast(); auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); auto result = make_uniq(ExpressionType::OPERATOR_UNPACK); - result->children.push_back(transformer.Transform>(extract_parens)); + result->GetChildrenMutable().push_back(transformer.Transform>(extract_parens)); return std::move(result); } @@ -2105,7 +2164,7 @@ unique_ptr PEGTransformerFactory::TransformTryExpression(PEGTr auto &list_pr = parse_result.Cast(); auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); auto result = make_uniq(ExpressionType::OPERATOR_TRY); - result->children.push_back(transformer.Transform>(extract_parens)); + 
result->GetChildrenMutable().push_back(transformer.Transform>(extract_parens)); return std::move(result); } @@ -2119,8 +2178,8 @@ unique_ptr PEGTransformerFactory::TransformColumnsExpression(P transformer.Transform>(ExtractResultFromParens(list_pr.Child(2))); if (expr->GetExpressionType() == ExpressionType::STAR) { auto star_expr = unique_ptr_cast(std::move(expr)); - if (star_expr->columns) { - result->expr = std::move(star_expr); + if (star_expr->IsColumns()) { + result->ExpressionMutable() = std::move(star_expr); } else { result = std::move(star_expr); } @@ -2129,11 +2188,11 @@ unique_ptr PEGTransformerFactory::TransformColumnsExpression(P children.push_back(make_uniq()); children.push_back(std::move(expr)); auto list_filter = make_uniq("list_filter", std::move(children)); - result->expr = std::move(list_filter); + result->ExpressionMutable() = std::move(list_filter); } else { - result->expr = std::move(expr); + result->ExpressionMutable() = std::move(expr); } - result->columns = true; + result->IsColumnsMutable() = true; if (unpack) { return make_uniq(ExpressionType::OPERATOR_UNPACK, std::move(result)); } @@ -2162,7 +2221,7 @@ unique_ptr PEGTransformerFactory::TransformExtractArgument(PEG if (choice_pr.type == ParseResultType::STRING) { return make_uniq(Value(choice_pr.Cast().result)); } - auto date_part = transformer.TransformEnum(choice_pr); + auto date_part = transformer.Transform(choice_pr); return make_uniq(EnumUtil::ToString(date_part)); } @@ -2236,11 +2295,45 @@ PEGTransformerFactory::TransformSubstringExpressionList(PEGTransformer &transfor vector> PEGTransformerFactory::TransformSubstringParameters(PEGTransformer &transformer, ParseResult &parse_result) { + // SubstringParameters <- Expression SubstringFromFor auto &list_pr = parse_result.Cast(); vector> results; results.push_back(transformer.Transform>(list_pr.GetChild(0))); - results.push_back(transformer.Transform>(list_pr.GetChild(2))); - 
results.push_back(transformer.Transform>(list_pr.GetChild(4))); + auto from_for = transformer.Transform>>(list_pr.GetChild(1)); + for (auto &arg : from_for) { + results.push_back(std::move(arg)); + } + return results; +} + +vector> PEGTransformerFactory::TransformSubstringFromFor(PEGTransformer &transformer, + ParseResult &parse_result) { + // SubstringFromFor <- SubstringFromOptionalFor / SubstringFor + auto &list_pr = parse_result.Cast(); + return transformer.Transform>>(list_pr.Child(0).GetResult()); +} + +vector> +PEGTransformerFactory::TransformSubstringFromOptionalFor(PEGTransformer &transformer, ParseResult &parse_result) { + // SubstringFromOptionalFor <- FromExpression ForExpression? + auto &list_pr = parse_result.Cast(); + vector> results; + results.push_back(transformer.Transform>(list_pr.GetChild(0))); + auto &for_opt = list_pr.Child(1); + if (for_opt.HasResult()) { + results.push_back(transformer.Transform>(for_opt.GetResult())); + } + return results; +} + +vector> PEGTransformerFactory::TransformSubstringFor(PEGTransformer &transformer, + ParseResult &parse_result) { + // SubstringFor <- ForExpression + // SUBSTRING(str FOR len) => substring(str, 1, len) + auto &list_pr = parse_result.Cast(); + vector> results; + results.push_back(make_uniq(Value::INTEGER(1))); + results.push_back(transformer.Transform>(list_pr.GetChild(0))); return results; } @@ -2263,7 +2356,8 @@ unique_ptr PEGTransformerFactory::TransformTrimExpression(PEGT trim_expressions.push_back(std::move(trim_source_expr)); } } - return make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, function_name, std::move(trim_expressions)); + return make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, Identifier(function_name), + std::move(trim_expressions)); } string PEGTransformerFactory::TransformTrimDirection(PEGTransformer &transformer, ParseResult &parse_result) { @@ -2281,6 +2375,63 @@ unique_ptr PEGTransformerFactory::TransformTrimSource(PEGTrans return nullptr; } +unique_ptr 
PEGTransformerFactory::TransformOverlayExpression(PEGTransformer &transformer, + ParseResult &parse_result) { + // OverlayExpression <- 'OVERLAY' Parens(OverlayArguments) + auto &list_pr = parse_result.Cast(); + auto &extract_parens = ExtractResultFromParens(list_pr.Child(1)); + auto args = transformer.Transform>>(extract_parens); + return make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, "overlay", std::move(args)); +} + +vector> PEGTransformerFactory::TransformOverlayArguments(PEGTransformer &transformer, + ParseResult &parse_result) { + // OverlayArguments <- OverlayParameters / OverlayExpressionList + auto &list_pr = parse_result.Cast(); + return transformer.Transform>>(list_pr.Child(0).GetResult()); +} + +vector> PEGTransformerFactory::TransformOverlayParameters(PEGTransformer &transformer, + ParseResult &parse_result) { + // OverlayParameters <- Expression 'PLACING' Expression FromExpression ForExpression? + // Children: 0=expr(A), 1='PLACING', 2=expr(B), 3=FromExpression, 4=optional(ForExpression) + auto &list_pr = parse_result.Cast(); + vector> results; + results.push_back(transformer.Transform>(list_pr.GetChild(0))); // A + results.push_back(transformer.Transform>(list_pr.GetChild(2))); // B (after PLACING) + results.push_back(transformer.Transform>(list_pr.GetChild(3))); // C FromExpression + auto &for_opt = list_pr.Child(4); // D ForExpression? 
+ if (for_opt.HasResult()) { + auto &for_pr = for_opt.GetResult(); + results.push_back(transformer.Transform>(for_pr)); + } + return results; +} + +unique_ptr PEGTransformerFactory::TransformFromExpression(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + return transformer.Transform>(list_pr.GetChild(1)); // after FROM +} + +unique_ptr PEGTransformerFactory::TransformForExpression(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + return transformer.Transform>(list_pr.GetChild(1)); // after FOR +} + +vector> PEGTransformerFactory::TransformOverlayExpressionList(PEGTransformer &transformer, + ParseResult &parse_result) { + // OverlayExpressionList <- List(Expression) + auto &list_pr = parse_result.Cast(); + vector> results; + auto expr_list = ExtractParseResultsFromList(list_pr.Child(0)); + for (const auto expr : expr_list) { + results.push_back(transformer.Transform>(expr)); + } + return results; +} + unique_ptr PEGTransformerFactory::TransformPositionExpression(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); @@ -2329,13 +2480,13 @@ unique_ptr PEGTransformerFactory::TransformCaseExpression(PEGT new_case.when_expr = std::move(case_expr.when_expr); } new_case.then_expr = std::move(case_expr.then_expr); - result->case_checks.push_back(std::move(new_case)); + result->CaseChecksMutable().push_back(std::move(new_case)); } auto &else_expr_opt = list_pr.Child(3); if (else_expr_opt.HasResult()) { - result->else_expr = transformer.Transform>(else_expr_opt.GetResult()); + result->ElseMutable() = transformer.Transform>(else_expr_opt.GetResult()); } else { - result->else_expr = make_uniq(Value()); + result->ElseMutable() = make_uniq(Value()); } return std::move(result); } @@ -2363,7 +2514,7 @@ unique_ptr PEGTransformerFactory::TransformTypeLiteral(PEGTran throw ParserException("Cannot convert to type %s, requires exactly one type 
modifier", EnumUtil::ToString(type.id())); } - if (type == LogicalTypeId::UNBOUND) { + if (type == LogicalTypeId::UNBOUND || type.InternalType() == PhysicalType::INVALID) { type = LogicalType::UNBOUND(make_uniq(colid, vector>())); } auto string_literal = list_pr.Child(1).result; @@ -2399,7 +2550,7 @@ unique_ptr PEGTransformerFactory::TransformIntervalLiteral(PEG } vector> children; children.push_back(std::move(expr)); - auto result = make_uniq(func_name, std::move(children)); + auto result = make_uniq(Identifier(func_name), std::move(children)); return std::move(result); } @@ -2422,20 +2573,20 @@ unique_ptr PEGTransformerFactory::TransformSubqueryExpression( auto result = make_uniq(); if (is_exists) { - result->subquery_type = SubqueryType::EXISTS; + result->GetSubqueryTypeMutable() = SubqueryType::EXISTS; } else { - result->subquery_type = SubqueryType::SCALAR; + result->GetSubqueryTypeMutable() = SubqueryType::SCALAR; } if (subquery_reference->type == TableReferenceType::SUBQUERY) { auto &subquery_ref = subquery_reference->Cast(); - result->subquery = std::move(subquery_ref.subquery); + result->SubqueryMutable() = std::move(subquery_ref.subquery); } else { auto select_statement = make_uniq(); auto select_node = make_uniq(); select_node->select_list.push_back(make_uniq()); select_node->from_table = std::move(subquery_reference); select_statement->node = std::move(select_node); - result->subquery = std::move(select_statement); + result->SubqueryMutable() = std::move(select_statement); } if (is_not) { vector> children; @@ -2513,12 +2664,9 @@ unique_ptr PEGTransformerFactory::TransformListComprehensionEx auto filter_expr = transformer.Transform>(list_comprehension_filter.GetResult()); // STAGE 1: list_apply(in_expr, x -> struct_pack(filter := ..., result := ...)) - filter_expr->SetAlias("filter"); - result_expr->SetAlias("result"); - - vector> struct_children; - struct_children.push_back(std::move(filter_expr)); - struct_children.push_back(std::move(result_expr)); + 
vector struct_children; + struct_children.emplace_back("filter", std::move(filter_expr)); + struct_children.emplace_back("result", std::move(result_expr)); auto struct_pack = make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, "struct_pack", std::move(struct_children)); @@ -2618,7 +2766,7 @@ pair> PEGTransformerFactory::TransformRepla } auto &col_ref = column_reference->Cast(); auto column_name = col_ref.GetColumnName(); - return make_pair(column_name, std::move(expr)); + return make_pair(column_name.GetIdentifierName(), std::move(expr)); } ExpressionType PEGTransformerFactory::TransformIsDistinctFromOp(PEGTransformer &transformer, @@ -2682,7 +2830,7 @@ pair PEGTransformerFactory::TransformRenameEntry(PE auto &list_pr = parse_result.Cast(); auto column_name = transformer.Transform(list_pr.GetChild(0)); auto alias = list_pr.Child(2).identifier; - return make_pair(column_name, alias); + return make_pair(column_name, alias.GetIdentifierName()); } bool PEGTransformerFactory::TransformIgnoreOrRespectNulls(PEGTransformer &transformer, ParseResult &parse_result) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp index d5d1fa115..110fe9037 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp @@ -3,45 +3,4901 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformUseStatementInternal(PEGTransformer &transformer, +unique_ptr PEGTransformerFactory::TransformAlterStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto alter_options = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformAlterStatement(transformer, std::move(alter_options)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto 
&list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterTableStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto base_table_name = transformer.Transform>(list_pr.GetChild(2)); + vector> alter_table_options; + auto alter_table_options_items = ExtractParseResultsFromList(list_pr.GetChild(3)); + for (auto &alter_table_options_item : alter_table_options_items) { + auto alter_table_options_value = + transformer.Transform>(alter_table_options_item.get()); + alter_table_options.push_back(std::move(alter_table_options_value)); + } + auto result = + TransformAlterTableStmt(transformer, if_exists, std::move(base_table_name), std::move(alter_table_options)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterSchemaStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + auto rename_alter = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformAlterSchemaStmt(transformer, if_exists, qualified_name, std::move(rename_alter)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterTableOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = 
transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAddConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto top_level_constraint = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformAddConstraint(transformer, std::move(top_level_constraint)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAddColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto add_column_entry = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformAddColumn(transformer, if_not_exists, std::move(add_column_entry)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAddColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto dotted_identifier = transformer.Transform>(list_pr.GetChild(0)); + LogicalType type {}; + auto &type_opt = list_pr.GetChild(1).Cast(); + if (type_opt.HasResult()) { + type = transformer.Transform(type_opt.GetResult()); + } + GeneratedColumnDefinition generated_column {}; + auto &generated_column_opt = list_pr.GetChild(2).Cast(); + if (generated_column_opt.HasResult()) { + generated_column = transformer.Transform(generated_column_opt.GetResult()); + } + vector column_constraint {}; + auto &column_constraint_opt = list_pr.GetChild(3).Cast(); + if (column_constraint_opt.HasResult()) { + auto &column_constraint_repeat_1 = column_constraint_opt.GetResult().Cast(); + for (auto &column_constraint_item_1 : column_constraint_repeat_1.GetChildren()) { + auto column_constraint_value_1 = + 
transformer.Transform(column_constraint_item_1.get()); + column_constraint.push_back(std::move(column_constraint_value_1)); + } + } + auto result = TransformAddColumnEntry(transformer, dotted_identifier, type, std::move(generated_column), + std::move(column_constraint)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(2).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto nested_column_name = transformer.Transform>(list_pr.GetChild(3)); + bool drop_behavior {}; + auto &drop_behavior_opt = list_pr.GetChild(4).Cast(); + if (drop_behavior_opt.HasResult()) { + drop_behavior = transformer.Transform(drop_behavior_opt.GetResult()); + } + auto result = TransformDropColumn(transformer, if_exists, std::move(nested_column_name), drop_behavior); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto nested_column_name = transformer.Transform>(list_pr.GetChild(2)); + auto alter_column_entry = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformAlterColumn(transformer, std::move(nested_column_name), std::move(alter_column_entry)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformRenameColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto nested_column_name = transformer.Transform>(list_pr.GetChild(2)); + auto identifier = list_pr.GetChild(4).Cast().identifier; + auto result = TransformRenameColumn(transformer, std::move(nested_column_name), identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformNestedColumnNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector identifier_dot {}; + auto &identifier_dot_opt = list_pr.GetChild(0).Cast(); + if (identifier_dot_opt.HasResult()) { + auto &identifier_dot_repeat_1 = identifier_dot_opt.GetResult().Cast(); + for (auto &identifier_dot_item_1 : identifier_dot_repeat_1.GetChildren()) { + auto identifier_dot_value_1 = transformer.Transform(identifier_dot_item_1.get()); + identifier_dot.push_back(identifier_dot_value_1); + } + } + auto column_name = list_pr.GetChild(1).Cast().identifier; + auto result = TransformNestedColumnName(transformer, identifier_dot, column_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIdentifierDotInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto result = TransformIdentifierDot(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformRenameAlterInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(2).Cast().identifier; + auto result = TransformRenameAlter(transformer, identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetPartitionedByInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression; + auto expression_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(3))); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformSetPartitionedBy(transformer, std::move(expression)); + return 
make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformResetPartitionedByInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformResetPartitionedBy(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetSortedByInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto order_by_expressions = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(3))); + auto result = TransformSetSortedBy(transformer, std::move(order_by_expressions)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformResetSortedByInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformResetSortedBy(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto rel_option_list = + transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformSetOptions(transformer, std::move(rel_option_list)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformResetOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto rel_option_list = + transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformResetOptions(transformer, std::move(rel_option_list)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformAddOrDropDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAddDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformAddDefault(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDropDefault(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformChangeNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto drop_or_set = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformChangeNullability(transformer, drop_or_set); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropOrSetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDropNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDropNullability(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetNullabilityInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSetNullability(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformAlterTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + LogicalType type {}; + auto &type_opt = list_pr.GetChild(2).Cast(); + if (type_opt.HasResult()) { + type = transformer.Transform(type_opt.GetResult()); + } + unique_ptr using_expression {}; + auto &using_expression_opt = list_pr.GetChild(3).Cast(); + if (using_expression_opt.HasResult()) { + using_expression = transformer.Transform>(using_expression_opt.GetResult()); + } + auto result = TransformAlterType(transformer, type, std::move(using_expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUsingExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformUsingExpression(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterViewStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto base_table_name = transformer.Transform>(list_pr.GetChild(2)); + auto rename_alter = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformAlterViewStmt(transformer, if_exists, std::move(base_table_name), std::move(rename_alter)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterSequenceStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + 
auto qualified_sequence_name = transformer.Transform(list_pr.GetChild(2)); + auto alter_sequence_options = transformer.Transform>(list_pr.GetChild(3)); + auto result = + TransformAlterSequenceStmt(transformer, if_exists, qualified_sequence_name, std::move(alter_sequence_options)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformQualifiedSequenceNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + Identifier catalog_qualification {}; + auto &catalog_qualification_opt = list_pr.GetChild(0).Cast(); + if (catalog_qualification_opt.HasResult()) { + catalog_qualification = transformer.Transform(catalog_qualification_opt.GetResult()); + } + Identifier schema_qualification {}; + auto &schema_qualification_opt = list_pr.GetChild(1).Cast(); + if (schema_qualification_opt.HasResult()) { + schema_qualification = transformer.Transform(schema_qualification_opt.GetResult()); + } + auto sequence_name = list_pr.GetChild(2).Cast().identifier; + auto result = + TransformQualifiedSequenceName(transformer, catalog_qualification, schema_qualification, sequence_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformAlterSequenceOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = TransformAlterSequenceOptions(transformer, choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetSequenceOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector>> sequence_option; + auto &sequence_option_repeat = list_pr.GetChild(0).Cast(); + for (auto &sequence_option_item : sequence_option_repeat.GetChildren()) { + auto sequence_option_value = + transformer.Transform>>(sequence_option_item.get()); + 
sequence_option.push_back(std::move(sequence_option_value)); + } + auto result = TransformSetSequenceOption(transformer, std::move(sequence_option)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAlterDatabaseStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto identifier = list_pr.GetChild(2).Cast().identifier; + auto identifier_1 = list_pr.GetChild(6).Cast().identifier; + auto result = TransformAlterDatabaseStmt(transformer, if_exists, identifier, identifier_1); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAnalyzeStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool analyze_verbose {}; + auto &analyze_verbose_opt = list_pr.GetChild(1).Cast(); + if (analyze_verbose_opt.HasResult()) { + analyze_verbose = transformer.Transform(analyze_verbose_opt.GetResult()); + } + AnalyzeTarget analyze_target {}; + auto &analyze_target_opt = list_pr.GetChild(2).Cast(); + if (analyze_target_opt.HasResult()) { + analyze_target = transformer.Transform(analyze_target_opt.GetResult()); + } + auto result = TransformAnalyzeStatement(transformer, analyze_verbose, std::move(analyze_target)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAnalyzeTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + vector name_list {}; + auto &name_list_opt = list_pr.GetChild(1).Cast(); + if (name_list_opt.HasResult()) { + name_list = transformer.Transform>(name_list_opt.GetResult()); + } + auto result = TransformAnalyzeTarget(transformer, 
std::move(base_table_name), name_list); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAnalyzeVerboseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformAnalyzeVerbose(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformAttachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool or_replace {}; + auto &or_replace_opt = list_pr.GetChild(1).Cast(); + if (or_replace_opt.HasResult()) { + or_replace = transformer.Transform(or_replace_opt.GetResult()); + } + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto database_path = transformer.Transform>(list_pr.GetChild(4)); + Identifier attach_alias {}; + auto &attach_alias_opt = list_pr.GetChild(5).Cast(); + if (attach_alias_opt.HasResult()) { + attach_alias = transformer.Transform(attach_alias_opt.GetResult()); + } + vector attach_options {}; + auto &attach_options_opt = list_pr.GetChild(6).Cast(); + if (attach_options_opt.HasResult()) { + attach_options = transformer.Transform>(attach_options_opt.GetResult()); + } + auto result = TransformAttachStatement(transformer, or_replace, if_not_exists, std::move(database_path), + attach_alias, attach_options); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDatabasePathInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformDatabasePath(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAttachAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = 
parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformAttachAlias(transformer, col_id); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformAttachOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto generic_copy_option_list = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformAttachOptions(transformer, generic_copy_option_list); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformCallStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto qualified_table_function = transformer.Transform(list_pr.GetChild(1)); + auto table_function_arguments = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformCallStatement(transformer, qualified_table_function, std::move(table_function_arguments)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCheckpointStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool checkpoint_force {}; + auto &checkpoint_force_opt = list_pr.GetChild(0).Cast(); + if (checkpoint_force_opt.HasResult()) { + checkpoint_force = transformer.Transform(checkpoint_force_opt.GetResult()); + } + Identifier catalog_name {}; + auto &catalog_name_opt = list_pr.GetChild(2).Cast(); + if (catalog_name_opt.HasResult()) { + catalog_name = catalog_name_opt.GetResult().Cast().identifier; + } + auto result = TransformCheckpointStatement(transformer, checkpoint_force, catalog_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCheckpointForceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCheckpointForce(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformCommentStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto comment_on_type = transformer.Transform(list_pr.GetChild(2)); + auto dotted_identifier = transformer.Transform>(list_pr.GetChild(3)); + auto comment_value = transformer.Transform(list_pr.GetChild(5)); + auto result = TransformCommentStatement(transformer, comment_on_type, dotted_identifier, comment_value); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCommentOnTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentTableInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentTable(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentSequenceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentSequence(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentFunction(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentMacroTableInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentMacroTable(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentMacroInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentMacro(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentViewInternal(PEGTransformer &transformer, + ParseResult 
&parse_result) { + auto result = TransformCommentView(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentDatabaseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentDatabase(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentIndex(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentSchema(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommentColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCommentColumn(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformExpressionStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression_alias; + auto expression_alias_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &expression_alias_item : expression_alias_items) { + auto expression_alias_value = transformer.Transform>(expression_alias_item.get()); + expression_alias.push_back(std::move(expression_alias_value)); + } + auto result = TransformExpressionStatement(transformer, std::move(expression_alias)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExpressionAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = 
list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformConstraintNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_or_string = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformConstraintName(transformer, col_id_or_string); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCollationNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto result = TransformCollationName(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto type_variations = transformer.Transform>(list_pr.GetChild(0)); + vector array_bounds {}; + auto &array_bounds_opt = list_pr.GetChild(1).Cast(); + if (array_bounds_opt.HasResult()) { + auto &array_bounds_repeat_1 = array_bounds_opt.GetResult().Cast(); + for (auto &array_bounds_item_1 : array_bounds_repeat_1.GetChildren()) { + auto array_bounds_value_1 = transformer.Transform(array_bounds_item_1.get()); + array_bounds.push_back(array_bounds_value_1); + } + } + auto result = TransformType(transformer, std::move(type_variations), array_bounds); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTypeVariationsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSimpleTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = 
parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCharacterSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto character_type = transformer.Transform(list_pr.GetChild(0)); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + auto result = TransformCharacterSimpleType(transformer, character_type, std::move(type_modifiers)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformQualifiedSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto qualified_type_name = transformer.Transform(list_pr.GetChild(0)); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + auto result = TransformQualifiedSimpleType(transformer, qualified_type_name, std::move(type_modifiers)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCharacterTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCharacterType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIntervalTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIntervalIntervalInternal(PEGTransformer &transformer, + ParseResult 
&parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformIntervalWithSpecifierInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformIntervalWithRangeSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto interval_to_interval_as_type = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformIntervalWithRangeSpecifier(transformer, interval_to_interval_as_type); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformIntervalWithSimpleSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto interval = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformIntervalWithSimpleSpecifier(transformer, interval); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformIntervalWithoutSpecifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIntervalWithoutSpecifier(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformYearKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformYearKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMonthKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMonthKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformDayKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDayKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformHourKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformHourKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMinuteKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMinuteKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSecondKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMillisecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMillisecondKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMicrosecondKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMicrosecondKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWeekKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWeekKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformQuarterKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformQuarterKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDecadeKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDecadeKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCenturyKeywordInternal(PEGTransformer 
&transformer, + ParseResult &parse_result) { + auto result = TransformCenturyKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMillenniumKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMillenniumKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIntervalInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIntervalToIntervalInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformYearToMonthInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto year_keyword = transformer.Transform(list_pr.GetChild(0)); + auto month_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformYearToMonth(transformer, year_keyword, month_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDayToHourInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto day_keyword = transformer.Transform(list_pr.GetChild(0)); + auto hour_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformDayToHour(transformer, day_keyword, hour_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDayToMinuteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto day_keyword = transformer.Transform(list_pr.GetChild(0)); + auto 
minute_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformDayToMinute(transformer, day_keyword, minute_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDayToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto day_keyword = transformer.Transform(list_pr.GetChild(0)); + auto second_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformDayToSecond(transformer, day_keyword, second_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformHourToMinuteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto hour_keyword = transformer.Transform(list_pr.GetChild(0)); + auto minute_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformHourToMinute(transformer, hour_keyword, minute_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformHourToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto hour_keyword = transformer.Transform(list_pr.GetChild(0)); + auto second_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformHourToSecond(transformer, hour_keyword, second_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMinuteToSecondInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto minute_keyword = transformer.Transform(list_pr.GetChild(0)); + auto second_keyword = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformMinuteToSecond(transformer, minute_keyword, second_keyword); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformBitTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + 
vector> expression {}; + auto &expression_opt = list_pr.GetChild(2).Cast(); + if (expression_opt.HasResult()) { + auto expression_items_1 = ExtractParseResultsFromList(ExtractResultFromParens(expression_opt.GetResult())); + for (auto &expression_item_1 : expression_items_1) { + auto expression_value_1 = transformer.Transform>(expression_item_1.get()); + expression.push_back(std::move(expression_value_1)); + } + } + auto result = TransformBitType(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformGeometryTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr expression {}; + auto &expression_opt = list_pr.GetChild(1).Cast(); + if (expression_opt.HasResult()) { + expression = + transformer.Transform>(ExtractResultFromParens(expression_opt.GetResult())); + } + auto result = TransformGeometryType(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformVariantTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformVariantType(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSimpleNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto &choice_result = choice_pr.GetResult(); + string child; + if (choice_result.type == ParseResultType::IDENTIFIER) { + child = choice_result.Cast().identifier.GetIdentifierName(); + } else if (choice_result.type == 
ParseResultType::KEYWORD) { + child = choice_result.Cast().keyword; + } else if (choice_result.type == ParseResultType::STRING) { + child = choice_result.Cast().result; + } else { + child = transformer.Transform(choice_result); + } + auto result = TransformSimpleNumericType(transformer, child); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDecimalNumericTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIntTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIntType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIntegerTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIntegerType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSmallintTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSmallintType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformBigintTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformBigintType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformRealTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformRealType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformBooleanTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformBooleanType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDoubleTypeInternal(PEGTransformer &transformer, + 
ParseResult &parse_result) { + auto result = TransformDoubleType(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformFloatTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr number_literal {}; + auto &number_literal_opt = list_pr.GetChild(1).Cast(); + if (number_literal_opt.HasResult()) { + number_literal = transformer.Transform>( + ExtractResultFromParens(number_literal_opt.GetResult())); + } + auto result = TransformFloatType(transformer, std::move(number_literal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDecimalTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + auto result = TransformDecimalType(transformer, std::move(type_modifiers)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDecTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + auto result = TransformDecType(transformer, std::move(type_modifiers)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformNumericModTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + auto result = 
TransformNumericModType(transformer, std::move(type_modifiers)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformQualifiedTypeNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformTypeNameAsQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto type_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformTypeNameAsQualifiedName(transformer, type_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogReservedSchemaTypeNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_schema_qualification = transformer.Transform(list_pr.GetChild(1)); + auto reserved_type_name = list_pr.GetChild(2).Cast().identifier; + auto result = TransformCatalogReservedSchemaTypeName(transformer, catalog_qualification, + reserved_schema_qualification, reserved_type_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSchemaReservedTypeNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_type_name = list_pr.GetChild(1).Cast().identifier; + auto result = TransformSchemaReservedTypeName(transformer, schema_qualification, reserved_type_name); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTypeModifiersInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression {}; + auto 
&expression_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); + if (expression_opt.HasResult()) { + auto expression_items_1 = ExtractParseResultsFromList(expression_opt.GetResult()); + for (auto &expression_item_1 : expression_items_1) { + auto expression_value_1 = transformer.Transform>(expression_item_1.get()); + expression.push_back(std::move(expression_value_1)); + } + } + auto result = TransformTypeModifiers(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformRowTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_type_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformRowType(transformer, col_id_type_list); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetofTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto type = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformSetofType(transformer, type); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUnionTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_type_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformUnionType(transformer, col_id_type_list); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColIdTypeListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> col_id_type; + auto col_id_type_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &col_id_type_item : col_id_type_items) { + auto col_id_type_value = transformer.Transform>(col_id_type_item.get()); + col_id_type.push_back(col_id_type_value); + } + auto result = 
TransformColIdTypeList(transformer, col_id_type); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformMapTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector type; + auto type_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(1))); + for (auto &type_item : type_items) { + auto type_value = transformer.Transform(type_item.get()); + type.push_back(type_value); + } + auto result = TransformMapType(transformer, type); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColIdTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(0)); + auto type = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformColIdType(transformer, col_id, type); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformArrayBoundsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformArrayKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformArrayKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSquareBracketsArrayInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr expression {}; + auto &expression_opt = list_pr.GetChild(1).Cast(); + if (expression_opt.HasResult()) { + expression = transformer.Transform>(expression_opt.GetResult()); + } + auto result = TransformSquareBracketsArray(transformer, std::move(expression)); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformTimeTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto time_or_timestamp = transformer.Transform(list_pr.GetChild(0)); + vector> type_modifiers {}; + auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); + if (type_modifiers_opt.HasResult()) { + type_modifiers = transformer.Transform>>(type_modifiers_opt.GetResult()); + } + bool time_zone {}; + auto &time_zone_opt = list_pr.GetChild(2).Cast(); + if (time_zone_opt.HasResult()) { + time_zone = transformer.Transform(time_zone_opt.GetResult()); + } + auto result = TransformTimeType(transformer, time_or_timestamp, std::move(type_modifiers), time_zone); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTimeOrTimestampInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTimeTypeIdInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTimeTypeId(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTimestampTypeIdInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTimestampTypeId(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto with_or_without = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformTimeZone(transformer, with_or_without); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithOrWithoutInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = 
list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithRule(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithoutRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithoutRule(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformDisconnectStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformDisconnectStatement(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCopyStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto copy_variations = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformCopyStatement(transformer, std::move(copy_variations)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCopyVariationsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCopyTableInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + vector insert_column_list {}; + auto &insert_column_list_opt = list_pr.GetChild(1).Cast(); + if (insert_column_list_opt.HasResult()) { + insert_column_list = transformer.Transform>(insert_column_list_opt.GetResult()); + } + auto from_or_to = transformer.Transform(list_pr.GetChild(2)); + auto copy_file_name = 
transformer.Transform>(list_pr.GetChild(3)); + vector copy_options {}; + auto ©_options_opt = list_pr.GetChild(4).Cast(); + if (copy_options_opt.HasResult()) { + copy_options = transformer.Transform>(copy_options_opt.GetResult()); + } + auto result = TransformCopyTable(transformer, std::move(base_table_name), insert_column_list, from_or_to, + std::move(copy_file_name), copy_options); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformFromOrToInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCopyFromInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCopyFrom(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCopyToInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCopyTo(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCopySelectInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto select_statement_internal = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(0))); + auto copy_file_name = transformer.Transform>(list_pr.GetChild(2)); + vector copy_options {}; + auto ©_options_opt = list_pr.GetChild(3).Cast(); + if (copy_options_opt.HasResult()) { + copy_options = transformer.Transform>(copy_options_opt.GetResult()); + } + auto result = + TransformCopySelect(transformer, std::move(select_statement_internal), std::move(copy_file_name), copy_options); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCopyFileNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + 
auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFileNameExpressionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFileNameStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformCopyFileNameStringLiteral(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFileNameIdentifierInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto result = TransformCopyFileNameIdentifier(transformer, identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFileNameIdentifierColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier_col_id = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformCopyFileNameIdentifierColId(transformer, identifier_col_id); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIdentifierColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto col_id = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformIdentifierColId(transformer, identifier, col_id); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformCopyOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto copy_option_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformCopyOptions(transformer, copy_option_list); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformCopyOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSpecializedOptionListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector specialized_option {}; + auto &specialized_option_opt = list_pr.GetChild(0).Cast(); + if (specialized_option_opt.HasResult()) { + auto &specialized_option_repeat_1 = specialized_option_opt.GetResult().Cast(); + for (auto &specialized_option_item_1 : specialized_option_repeat_1.GetChildren()) { + auto specialized_option_value_1 = transformer.Transform(specialized_option_item_1.get()); + specialized_option.push_back(specialized_option_value_1); + } + } + auto result = TransformSpecializedOptionList(transformer, specialized_option); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformSpecializedOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSingleOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformBinaryOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformBinaryOption(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformFreezeOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformFreezeOption(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformOidsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOidsOption(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCsvOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCsvOption(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformHeaderOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformHeaderOption(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformNullAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformNullAsOption(transformer, string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDelimiterAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformDelimiterAsOption(transformer, string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformQuoteAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformQuoteAsOption(transformer, 
string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformEscapeAsOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformEscapeAsOption(transformer, string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformEncodingOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformEncodingOption(transformer, string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformForceQuoteOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool force_quote {}; + auto &force_quote_opt = list_pr.GetChild(0).Cast(); + if (force_quote_opt.HasResult()) { + force_quote = transformer.Transform(force_quote_opt.GetResult()); + } + auto star_symbol_column_list = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformForceQuoteOption(transformer, force_quote, star_symbol_column_list); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformStarSymbolColumnListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto &choice_result = choice_pr.GetResult(); + vector result {}; + if (!(choice_result.name == "StarSymbol")) { + result = transformer.Transform>(choice_result); + } + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformForceQuoteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformForceQuote(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformPartitionByOptionInternal(PEGTransformer 
&transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto star_symbol_column_list = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformPartitionByOption(transformer, star_symbol_column_list); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformForceNullOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool force_not_null {}; + auto &force_not_null_opt = list_pr.GetChild(1).Cast(); + if (force_not_null_opt.HasResult()) { + force_not_null = transformer.Transform(force_not_null_opt.GetResult()); + } + auto column_list = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformForceNullOption(transformer, force_not_null, column_list); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformForceNotNullInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformForceNotNull(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformGenericCopyOptionListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector generic_copy_option; + auto generic_copy_option_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &generic_copy_option_item : generic_copy_option_items) { + auto generic_copy_option_value = transformer.Transform(generic_copy_option_item.get()); + generic_copy_option.push_back(generic_copy_option_value); + } + auto result = TransformGenericCopyOptionList(transformer, generic_copy_option); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformGenericCopyOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto copy_option_name = list_pr.GetChild(0).Cast().identifier; + GenericCopyOptionValue generic_copy_option_value {}; + auto 
&generic_copy_option_value_opt = list_pr.GetChild(1).Cast(); + if (generic_copy_option_value_opt.HasResult()) { + generic_copy_option_value = + transformer.Transform(generic_copy_option_value_opt.GetResult()); + } + auto result = TransformGenericCopyOption(transformer, copy_option_name, std::move(generic_copy_option_value)); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformGenericCopyOptionValueInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformGenericCopyOptionOrderListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto generic_copy_option_parenthesized_expression_list = + transformer.Transform>(list_pr.GetChild(0)); + auto result = + TransformGenericCopyOptionOrderList(transformer, std::move(generic_copy_option_parenthesized_expression_list)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformGenericCopyOptionExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformGenericCopyOptionExpression(transformer, std::move(expression)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformGenericCopyOptionParenthesizedExpressionListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto order_by_expression_list = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(0))); + auto result = + TransformGenericCopyOptionParenthesizedExpressionList(transformer, std::move(order_by_expression_list)); + return make_uniq>>(std::move(result)); +} + 
+unique_ptr PEGTransformerFactory::TransformCopyFromDatabaseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFromDatabaseWithFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(2)); + auto col_id_1 = transformer.Transform(list_pr.GetChild(4)); + auto copy_database_flag = transformer.Transform(list_pr.GetChild(5)); + auto result = TransformCopyFromDatabaseWithFlag(transformer, col_id, col_id_1, copy_database_flag); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCopyFromDatabaseWithoutFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(2)); + auto col_id_1 = transformer.Transform(list_pr.GetChild(4)); + auto result = TransformCopyFromDatabaseWithoutFlag(transformer, col_id, col_id_1); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCopyDatabaseFlagInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_or_data = transformer.Transform(ExtractResultFromParens(list_pr.GetChild(0))); + auto result = TransformCopyDatabaseFlag(transformer, schema_or_data); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSchemaOrDataInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformCopySchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCopySchema(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCopyDataInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCopyData(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateIndexStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool unique_index {}; + auto &unique_index_opt = list_pr.GetChild(0).Cast(); + if (unique_index_opt.HasResult()) { + unique_index = transformer.Transform(unique_index_opt.GetResult()); + } + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + Identifier index_name {}; + auto &index_name_opt = list_pr.GetChild(3).Cast(); + if (index_name_opt.HasResult()) { + index_name = index_name_opt.GetResult().Cast().identifier; + } + auto base_table_name = transformer.Transform>(list_pr.GetChild(5)); + vector insert_column_list {}; + auto &insert_column_list_opt = list_pr.GetChild(6).Cast(); + if (insert_column_list_opt.HasResult()) { + insert_column_list = transformer.Transform>(insert_column_list_opt.GetResult()); + } + Identifier index_type {}; + auto &index_type_opt = list_pr.GetChild(7).Cast(); + if (index_type_opt.HasResult()) { + index_type = transformer.Transform(index_type_opt.GetResult()); + } + vector> index_element {}; + auto &index_element_opt = list_pr.GetChild(8).Cast(); + if (index_element_opt.HasResult()) { + auto index_element_items_1 = + ExtractParseResultsFromList(ExtractResultFromParens(index_element_opt.GetResult())); + for (auto &index_element_item_1 : index_element_items_1) { + auto index_element_value_1 = + 
transformer.Transform>(index_element_item_1.get()); + index_element.push_back(std::move(index_element_value_1)); + } + } + case_insensitive_map_t> with_list {}; + auto &with_list_opt = list_pr.GetChild(9).Cast(); + if (with_list_opt.HasResult()) { + with_list = + transformer.Transform>>(with_list_opt.GetResult()); + } + unique_ptr where_clause {}; + auto &where_clause_opt = list_pr.GetChild(10).Cast(); + if (where_clause_opt.HasResult()) { + where_clause = transformer.Transform>(where_clause_opt.GetResult()); + } + auto result = TransformCreateIndexStmt(transformer, unique_index, if_not_exists, index_name, + std::move(base_table_name), insert_column_list, index_type, + std::move(index_element), std::move(with_list), std::move(where_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformWithListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto rel_option_or_oids = + transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformWithList(transformer, std::move(rel_option_or_oids)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformRelOptionOrOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>>(choice_pr.GetResult()); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformRelOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector>> rel_option; + auto rel_option_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &rel_option_item : rel_option_items) { + auto rel_option_value = + transformer.Transform>>(rel_option_item.get()); + rel_option.push_back(std::move(rel_option_value)); + } + auto result = 
TransformRelOptionList(transformer, std::move(rel_option)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto with_or_without_oids = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformOids(transformer, with_or_without_oids); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformWithOrWithoutOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithOids(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithoutOidsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithoutOids(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIndexElementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(0)); + OrderType desc_or_asc {}; + auto &desc_or_asc_opt = list_pr.GetChild(1).Cast(); + if (desc_or_asc_opt.HasResult()) { + desc_or_asc = transformer.Transform(desc_or_asc_opt.GetResult()); + } + OrderByNullType nulls_first_or_last {}; + auto &nulls_first_or_last_opt = list_pr.GetChild(2).Cast(); + if (nulls_first_or_last_opt.HasResult()) { + nulls_first_or_last = transformer.Transform(nulls_first_or_last_opt.GetResult()); + } + auto result = TransformIndexElement(transformer, std::move(expression), desc_or_asc, nulls_first_or_last); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformUniqueIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformUniqueIndex(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIndexTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformIndexType(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformRelOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto rel_option_name = transformer.Transform(list_pr.GetChild(0)); + unique_ptr rel_option_argument_opt {}; + auto &rel_option_argument_opt_opt = list_pr.GetChild(1).Cast(); + if (rel_option_argument_opt_opt.HasResult()) { + rel_option_argument_opt = + transformer.Transform>(rel_option_argument_opt_opt.GetResult()); + } + auto result = TransformRelOption(transformer, rel_option_name, std::move(rel_option_argument_opt)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformRelOptionNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto &choice_result = choice_pr.GetResult(); + string child; + if (choice_result.type == ParseResultType::IDENTIFIER) { + child = choice_result.Cast().identifier.GetIdentifierName(); + } else if (choice_result.type == ParseResultType::KEYWORD) { + child = choice_result.Cast().keyword; + } else if (choice_result.type == ParseResultType::STRING) { + child = choice_result.Cast().result; + } else { + child = transformer.Transform(choice_result); + } + auto result = TransformRelOptionName(transformer, child); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformDottedIdentifierStringInternal(PEGTransformer &transformer, 
ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto dotted_identifier = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformDottedIdentifierString(transformer, dotted_identifier); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformRelOptionArgumentOptInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto def_arg = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformRelOptionArgumentOpt(transformer, std::move(def_arg)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDefArgInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDefArgNullInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto null_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformDefArgNull(transformer, null_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDefArgKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto reserved_keyword = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformDefArgKeyword(transformer, reserved_keyword); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformDefArgStringLiteralInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformDefArgStringLiteral(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformNoneLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformNoneLiteral(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateMacroStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto macro_or_function = transformer.Transform(list_pr.GetChild(0)); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + vector> macro_definition; + auto macro_definition_items = ExtractParseResultsFromList(list_pr.GetChild(3)); + for (auto ¯o_definition_item : macro_definition_items) { + auto macro_definition_value = transformer.Transform>(macro_definition_item.get()); + macro_definition.push_back(std::move(macro_definition_value)); + } + auto result = TransformCreateMacroStmt(transformer, macro_or_function, if_not_exists, qualified_name, + std::move(macro_definition)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformMacroOrFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMacroKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMacroKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformFunctionKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformFunctionKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformMacroDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector macro_parameters {}; + auto ¯o_parameters_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); + if (macro_parameters_opt.HasResult()) { + macro_parameters = transformer.Transform>(macro_parameters_opt.GetResult()); + } + auto macro_definition_body = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformMacroDefinition(transformer, std::move(macro_parameters), std::move(macro_definition_body)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformMacroDefinitionBodyInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformMacroParametersInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector macro_parameter; + auto macro_parameter_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto ¯o_parameter_item : macro_parameter_items) { + auto macro_parameter_value = transformer.Transform(macro_parameter_item.get()); + macro_parameter.push_back(std::move(macro_parameter_value)); + } + auto result = TransformMacroParameters(transformer, std::move(macro_parameter)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformMacroParameterInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSimpleParameterInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + 
auto &list_pr = parse_result.Cast(); + auto type_func_name = transformer.Transform(list_pr.GetChild(0)); + LogicalType type {}; + auto &type_opt = list_pr.GetChild(1).Cast(); + if (type_opt.HasResult()) { + type = transformer.Transform(type_opt.GetResult()); + } + auto result = TransformSimpleParameter(transformer, type_func_name, type); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformScalarMacroDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformScalarMacroDefinition(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformTableMacroDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto select_statement_internal = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformTableMacroDefinition(transformer, std::move(select_statement_internal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateSchemaStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformCreateSchemaStmt(transformer, if_not_exists, qualified_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateSecretStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { 
+ if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + Identifier secret_name {}; + auto &secret_name_opt = list_pr.GetChild(2).Cast(); + if (secret_name_opt.HasResult()) { + secret_name = transformer.Transform(secret_name_opt.GetResult()); + } + Identifier secret_storage_specifier {}; + auto &secret_storage_specifier_opt = list_pr.GetChild(3).Cast(); + if (secret_storage_specifier_opt.HasResult()) { + secret_storage_specifier = transformer.Transform(secret_storage_specifier_opt.GetResult()); + } + auto generic_copy_option_list = transformer.Transform>(list_pr.GetChild(4)); + auto result = TransformCreateSecretStmt(transformer, if_not_exists, secret_name, secret_storage_specifier, + generic_copy_option_list); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformSecretStorageSpecifierInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformSecretStorageSpecifier(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSecretNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformSecretName(transformer, col_id); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateSequenceStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + vector>> sequence_option {}; + auto &sequence_option_opt = list_pr.GetChild(3).Cast(); + if 
(sequence_option_opt.HasResult()) { + auto &sequence_option_repeat_1 = sequence_option_opt.GetResult().Cast(); + for (auto &sequence_option_item_1 : sequence_option_repeat_1.GetChildren()) { + auto sequence_option_value_1 = + transformer.Transform>>(sequence_option_item_1.get()); + sequence_option.push_back(std::move(sequence_option_value_1)); + } + } + auto result = TransformCreateSequenceStmt(transformer, if_not_exists, qualified_name, std::move(sequence_option)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSequenceOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>>(choice_pr.GetResult()); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqSetCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>>(choice_pr.GetResult()); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSeqCycle(transformer); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqNoCycleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSeqNoCycle(transformer); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqSetIncrementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformSeqSetIncrement(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformSeqSetMinMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto seq_min_or_max = transformer.Transform(list_pr.GetChild(0)); + auto expression = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformSeqSetMinMax(transformer, seq_min_or_max, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqNoMinMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto seq_min_or_max = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformSeqNoMinMax(transformer, seq_min_or_max); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqStartWithInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformSeqStartWith(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqOwnedByInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformSeqOwnedBy(transformer, qualified_name); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSeqMinOrMaxInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMinValueInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMinValue(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformMaxValueInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformMaxValue(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool or_replace {}; + auto &or_replace_opt = list_pr.GetChild(1).Cast(); + if (or_replace_opt.HasResult()) { + or_replace = transformer.Transform(or_replace_opt.GetResult()); + } + SecretPersistType temporary {}; + auto &temporary_opt = list_pr.GetChild(2).Cast(); + if (temporary_opt.HasResult()) { + temporary = transformer.Transform(temporary_opt.GetResult()); + } + auto create_statement_variation = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformCreateStatement(transformer, or_replace, temporary, std::move(create_statement_variation)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCreateStatementVariationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOrReplaceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOrReplace(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTemporaryInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformPersistentInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformPersistent(transformer); + return make_uniq>(result); 
+} + +unique_ptr PEGTransformerFactory::TransformTempPersistentInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTempPersistent(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformTemporaryPersistentInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformTemporaryPersistent(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateTableStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + auto create_table_definition = transformer.Transform(list_pr.GetChild(3)); + bool commit_action {}; + auto &commit_action_opt = list_pr.GetChild(4).Cast(); + if (commit_action_opt.HasResult()) { + commit_action = transformer.Transform(commit_action_opt.GetResult()); + } + auto result = TransformCreateTableStmt(transformer, if_not_exists, qualified_name, + std::move(create_table_definition), commit_action); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCreateTableDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateTableAsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + ColumnList identifier_list {}; + auto &identifier_list_opt = list_pr.GetChild(0).Cast(); + if (identifier_list_opt.HasResult()) { + identifier_list = 
transformer.Transform(identifier_list_opt.GetResult()); + } + PartitionSortedOptions partition_sorted_options {}; + auto &partition_sorted_options_opt = list_pr.GetChild(1).Cast(); + if (partition_sorted_options_opt.HasResult()) { + partition_sorted_options = + transformer.Transform(partition_sorted_options_opt.GetResult()); + } + case_insensitive_map_t> with_list {}; + auto &with_list_opt = list_pr.GetChild(2).Cast(); + if (with_list_opt.HasResult()) { + with_list = + transformer.Transform>>(with_list_opt.GetResult()); + } + auto statement = transformer.Transform>(list_pr.GetChild(4)); + bool with_data {}; + auto &with_data_opt = list_pr.GetChild(5).Cast(); + if (with_data_opt.HasResult()) { + with_data = transformer.Transform(with_data_opt.GetResult()); + } + auto result = TransformCreateTableAs(transformer, std::move(identifier_list), std::move(partition_sorted_options), + std::move(with_list), std::move(statement), with_data); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPartitionSortedOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPartitionOptSortedOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto partition_options = transformer.Transform>>(list_pr.GetChild(0)); + vector> sorted_options {}; + auto &sorted_options_opt = list_pr.GetChild(1).Cast(); + if (sorted_options_opt.HasResult()) { + sorted_options = transformer.Transform>>(sorted_options_opt.GetResult()); + } + auto result = + TransformPartitionOptSortedOptions(transformer, std::move(partition_options), std::move(sorted_options)); + return make_uniq>(std::move(result)); +} + +unique_ptr 
+PEGTransformerFactory::TransformSortedOptPartitionOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto sorted_options = transformer.Transform>>(list_pr.GetChild(0)); + vector> partition_options {}; + auto &partition_options_opt = list_pr.GetChild(1).Cast(); + if (partition_options_opt.HasResult()) { + partition_options = + transformer.Transform>>(partition_options_opt.GetResult()); + } + auto result = + TransformSortedOptPartitionOptions(transformer, std::move(sorted_options), std::move(partition_options)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPartitionOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression; + auto expression_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(2))); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformPartitionOptions(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSortedOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression; + auto expression_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(2))); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformSortedOptions(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformWithDataInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = 
list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithDataOnlyInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithDataOnly(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformWithNoDataInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformWithNoData(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIdentifierListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector identifier; + auto identifier_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &identifier_item : identifier_items) { + auto identifier_value = identifier_item.get().Cast().identifier; + identifier.push_back(identifier_value); + } + auto result = TransformIdentifierList(transformer, identifier); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + ColumnElements create_table_column_list {}; + auto &create_table_column_list_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); + if (create_table_column_list_opt.HasResult()) { + create_table_column_list = transformer.Transform(create_table_column_list_opt.GetResult()); + } + PartitionSortedOptions partition_sorted_options {}; + auto &partition_sorted_options_opt = list_pr.GetChild(1).Cast(); + if (partition_sorted_options_opt.HasResult()) { + partition_sorted_options = + transformer.Transform(partition_sorted_options_opt.GetResult()); + } + case_insensitive_map_t> with_list {}; + auto &with_list_opt = list_pr.GetChild(2).Cast(); + if (with_list_opt.HasResult()) { + with_list = + 
transformer.Transform>>(with_list_opt.GetResult()); + } + auto result = TransformCreateColumnList(transformer, std::move(create_table_column_list), + std::move(partition_sorted_options), std::move(with_list)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIfNotExistsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIfNotExists(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSchemaReservedIdentifierOrStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_identifier_or_string_literal = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformSchemaReservedIdentifierOrStringLiteral(transformer, schema_qualification, + reserved_identifier_or_string_literal); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogReservedSchemaIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_schema_qualification = transformer.Transform(list_pr.GetChild(1)); + auto reserved_identifier_or_string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformCatalogReservedSchemaIdentifier( + transformer, catalog_qualification, reserved_schema_qualification, reserved_identifier_or_string_literal); + return make_uniq>(result); +} + +unique_ptr 
+PEGTransformerFactory::TransformIdentifierOrStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto &choice_result = choice_pr.GetResult(); + string child; + if (choice_result.type == ParseResultType::IDENTIFIER) { + child = choice_result.Cast().identifier.GetIdentifierName(); + } else if (choice_result.type == ParseResultType::KEYWORD) { + child = choice_result.Cast().keyword; + } else if (choice_result.type == ParseResultType::STRING) { + child = choice_result.Cast().result; + } else { + child = transformer.Transform(choice_result); + } + auto result = TransformIdentifierOrStringLiteral(transformer, child); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformReservedIdentifierOrStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + Identifier result; + if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { + result = choice_pr.GetResult().Cast().identifier; + } else if (choice_pr.GetResult().type == ParseResultType::KEYWORD) { + result = Identifier(choice_pr.GetResult().Cast().keyword); + } else if (choice_pr.GetResult().type == ParseResultType::STRING) { + result = Identifier(choice_pr.GetResult().Cast().result); + } else { + result = transformer.Transform(choice_pr.GetResult()); + } + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogQualificationInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformCatalogQualification(transformer, catalog_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSchemaQualificationInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = 
parse_result.Cast(); + auto schema_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformSchemaQualification(transformer, schema_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformReservedSchemaQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto reserved_schema_name = list_pr.GetChild(0).Cast().identifier; + auto result = reserved_schema_name; + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTableQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto table_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformTableQualification(transformer, table_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformReservedTableQualificationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto reserved_table_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformReservedTableQualification(transformer, reserved_table_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCreateTableColumnListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector create_table_column_element; + auto create_table_column_element_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &create_table_column_element_item : create_table_column_element_items) { + auto create_table_column_element_value = + transformer.Transform(create_table_column_element_item.get()); + create_table_column_element.push_back(std::move(create_table_column_element_value)); + } + auto result = TransformCreateTableColumnList(transformer, std::move(create_table_column_element)); + return make_uniq>(std::move(result)); +} + +unique_ptr 
+PEGTransformerFactory::TransformCreateTableColumnElementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCreateTableColumnDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_definition = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformCreateTableColumnDefinition(transformer, std::move(column_definition)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCreateTableConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto top_level_constraint = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformCreateTableConstraint(transformer, std::move(top_level_constraint)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColumnDefinitionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto dotted_identifier = transformer.Transform>(list_pr.GetChild(0)); + LogicalType type {}; + auto &type_opt = list_pr.GetChild(1).Cast(); + if (type_opt.HasResult()) { + type = transformer.Transform(type_opt.GetResult()); + } + GeneratedColumnDefinition generated_column {}; + auto &generated_column_opt = list_pr.GetChild(2).Cast(); + if (generated_column_opt.HasResult()) { + generated_column = transformer.Transform(generated_column_opt.GetResult()); + } + vector column_constraint {}; + auto &column_constraint_opt = list_pr.GetChild(4).Cast(); + if (column_constraint_opt.HasResult()) { + auto &column_constraint_repeat_1 = column_constraint_opt.GetResult().Cast(); + for (auto &column_constraint_item_1 : 
column_constraint_repeat_1.GetChildren()) { + auto column_constraint_value_1 = + transformer.Transform(column_constraint_item_1.get()); + column_constraint.push_back(std::move(column_constraint_value_1)); + } + } + auto result = TransformColumnDefinition(transformer, dotted_identifier, type, std::move(generated_column), + std::move(column_constraint)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColumnConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformNotNullConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto child = transformer.Transform(choice_pr.GetResult()); + auto result = TransformNotNullConstraint(transformer, child); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformNullConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformNullConstraint(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformNotNullColumnConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformNotNullColumnConstraint(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformUniqueConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformUniqueConstraint(transformer); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPrimaryKeyConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformPrimaryKeyConstraint(transformer); + return 
make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDefaultValueInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_default_expr = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformDefaultValue(transformer, std::move(column_default_expr)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCheckConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(1))); + auto result = TransformCheckConstraint(transformer, std::move(expression)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformForeignKeyConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(1)); + vector column_list {}; + auto &column_list_opt = list_pr.GetChild(2).Cast(); + if (column_list_opt.HasResult()) { + column_list = transformer.Transform>(ExtractResultFromParens(column_list_opt.GetResult())); + } + auto key_actions = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformForeignKeyConstraint(transformer, std::move(base_table_name), column_list, key_actions); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColumnCollationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto dotted_identifier = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformColumnCollation(transformer, dotted_identifier); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColumnCompressionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto 
col_id_or_string = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformColumnCompression(transformer, col_id_or_string); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformKeyActionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + string update_action {}; + auto &update_action_opt = list_pr.GetChild(0).Cast(); + if (update_action_opt.HasResult()) { + update_action = transformer.Transform(update_action_opt.GetResult()); + } + string delete_action {}; + auto &delete_action_opt = list_pr.GetChild(1).Cast(); + if (delete_action_opt.HasResult()) { + delete_action = transformer.Transform(delete_action_opt.GetResult()); + } + auto result = TransformKeyActions(transformer, update_action, delete_action); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformUpdateActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto key_action = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformUpdateAction(transformer, key_action); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDeleteActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto key_action = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformDeleteAction(transformer, key_action); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformNoKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformNoKeyAction(transformer); + return 
make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformRestrictKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformRestrictKeyAction(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCascadeKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCascadeKeyAction(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetNullKeyActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSetNullKeyAction(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSetDefaultKeyActionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformSetDefaultKeyAction(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTopLevelConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto top_level_constraint_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformTopLevelConstraint(transformer, std::move(top_level_constraint_list)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformTopLevelConstraintListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = TransformTopLevelConstraintList(transformer, choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformTopPrimaryKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_id_list = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformTopPrimaryKeyConstraint(transformer, column_id_list); + return 
make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformTopUniqueConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_id_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformTopUniqueConstraint(transformer, column_id_list); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformTopForeignKeyConstraintInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_id_list = transformer.Transform>(list_pr.GetChild(2)); + auto foreign_key_constraint = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformTopForeignKeyConstraint(transformer, column_id_list, std::move(foreign_key_constraint)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformColumnIdListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector col_id; + auto col_id_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &col_id_item : col_id_items) { + auto col_id_value = transformer.Transform(col_id_item.get()); + col_id.push_back(col_id_value); + } + auto result = TransformColumnIdList(transformer, col_id); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformDottedIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + vector dot_col_label {}; + auto &dot_col_label_opt = list_pr.GetChild(1).Cast(); + if (dot_col_label_opt.HasResult()) { + auto &dot_col_label_repeat_1 = dot_col_label_opt.GetResult().Cast(); + for (auto &dot_col_label_item_1 : dot_col_label_repeat_1.GetChildren()) { + auto dot_col_label_value_1 = transformer.Transform(dot_col_label_item_1.get()); + 
dot_col_label.push_back(dot_col_label_value_1); + } + } + auto result = TransformDottedIdentifier(transformer, identifier, dot_col_label); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformDotColLabelInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_label = transformer.Transform(list_pr.GetChild(1)); + auto result = col_label; + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformColIdInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + Identifier result; + if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { + result = choice_pr.GetResult().Cast().identifier; + } else if (choice_pr.GetResult().type == ParseResultType::KEYWORD) { + result = Identifier(choice_pr.GetResult().Cast().keyword); + } else if (choice_pr.GetResult().type == ParseResultType::STRING) { + result = Identifier(choice_pr.GetResult().Cast().result); + } else { + result = transformer.Transform(choice_pr.GetResult()); + } + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformColIdOrStringInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = TransformColIdOrString(transformer, choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTypeFuncNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + Identifier result; + if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { + result = choice_pr.GetResult().Cast().identifier; + } else if (choice_pr.GetResult().type == ParseResultType::KEYWORD) { + result = Identifier(choice_pr.GetResult().Cast().keyword); + } else if (choice_pr.GetResult().type == 
ParseResultType::STRING) { + result = Identifier(choice_pr.GetResult().Cast().result); + } else { + result = transformer.Transform(choice_pr.GetResult()); + } + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformGeneratedColumnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(2))); + bool generated_column_type {}; + auto &generated_column_type_opt = list_pr.GetChild(3).Cast(); + if (generated_column_type_opt.HasResult()) { + generated_column_type = transformer.Transform(generated_column_type_opt.GetResult()); + } + auto result = TransformGeneratedColumn(transformer, std::move(expression), generated_column_type); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformGeneratedColumnTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCommitActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto preserve_or_delete = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformCommitAction(transformer, preserve_or_delete); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformPreserveOrDeleteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformPreserveRowsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformPreserveRows(transformer); + return make_uniq>(result); +} + 
+unique_ptr PEGTransformerFactory::TransformDeleteRowsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDeleteRows(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformVirtualGeneratedColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformVirtualGeneratedColumn(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformStoredGeneratedColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformStoredGeneratedColumn(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateTriggerStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto trigger_name = transformer.Transform(list_pr.GetChild(2)); + auto trigger_timing = transformer.Transform(list_pr.GetChild(3)); + auto trigger_event = transformer.Transform(list_pr.GetChild(4)); + auto base_table_name = transformer.Transform>(list_pr.GetChild(6)); + TriggerTableReferencingInfo referencing_clause {}; + auto &referencing_clause_opt = list_pr.GetChild(7).Cast(); + if (referencing_clause_opt.HasResult()) { + referencing_clause = transformer.Transform(referencing_clause_opt.GetResult()); + } + TriggerForEach for_each_clause {}; + auto &for_each_clause_opt = list_pr.GetChild(8).Cast(); + if (for_each_clause_opt.HasResult()) { + for_each_clause = transformer.Transform(for_each_clause_opt.GetResult()); + } + auto trigger_body = transformer.Transform>(list_pr.GetChild(9)); + auto result = TransformCreateTriggerStmt(transformer, if_not_exists, trigger_name, trigger_timing, trigger_event, + std::move(base_table_name), referencing_clause, 
for_each_clause, + std::move(trigger_body)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTriggerBodyInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTriggerNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto result = TransformTriggerName(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformReferencingClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto referencing_item = transformer.Transform(list_pr.GetChild(1)); + TriggerTableReferencingInfo referencing_item_1 {}; + auto &referencing_item_1_opt = list_pr.GetChild(2).Cast(); + if (referencing_item_1_opt.HasResult()) { + referencing_item_1 = transformer.Transform(referencing_item_1_opt.GetResult()); + } + auto result = TransformReferencingClause(transformer, referencing_item, referencing_item_1); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformReferencingItemInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformReferencingNewTableAsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformReferencingNewTableAs(transformer, col_id); + return make_uniq>(result); +} + +unique_ptr 
+PEGTransformerFactory::TransformReferencingOldTableAsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformReferencingOldTableAs(transformer, col_id); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerTimingInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerBeforeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerBefore(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerAfterInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerAfter(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerInsteadOfInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerInsteadOf(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerEventInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerEventInsertInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerEventInsert(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerEventDeleteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerEventDelete(transformer); + return 
make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerEventUpdateInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformTriggerEventUpdate(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformTriggerEventUpdateOfInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto trigger_column_list = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformTriggerEventUpdateOf(transformer, trigger_column_list); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformTriggerColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector col_id; + auto col_id_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &col_id_item : col_id_items) { + auto col_id_value = transformer.Transform(col_id_item.get()); + col_id.push_back(col_id_value); + } + auto result = TransformTriggerColumnList(transformer, col_id); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformForEachClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformForEachRowInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformForEachRow(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformForEachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformForEachStatement(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformCreateTypeStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = 
parse_result.Cast(); + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(2)); + auto create_type = transformer.Transform>(list_pr.GetChild(4)); + auto result = TransformCreateTypeStmt(transformer, if_not_exists, qualified_name, std::move(create_type)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateTypeFromTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto type = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformCreateTypeFromType(transformer, type); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformEnumSelectTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto select_statement_internal = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(1))); + auto result = TransformEnumSelectType(transformer, std::move(select_statement_internal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformEnumStringLiteralListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector string_literal {}; + auto &string_literal_opt = ExtractResultFromParens(list_pr.GetChild(1)).Cast(); + if (string_literal_opt.HasResult()) { + auto string_literal_items_1 = ExtractParseResultsFromList(string_literal_opt.GetResult()); + for (auto 
&string_literal_item_1 : string_literal_items_1) { + auto string_literal_value_1 = transformer.Transform(string_literal_item_1.get()); + string_literal.push_back(string_literal_value_1); + } + } + auto result = TransformEnumStringLiteralList(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateViewStmtInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool create_recursive {}; + auto &create_recursive_opt = list_pr.GetChild(0).Cast(); + if (create_recursive_opt.HasResult()) { + create_recursive = transformer.Transform(create_recursive_opt.GetResult()); + } + bool if_not_exists {}; + auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); + if (if_not_exists_opt.HasResult()) { + if_not_exists = transformer.Transform(if_not_exists_opt.GetResult()); + } + auto qualified_name = transformer.Transform(list_pr.GetChild(3)); + vector insert_column_list {}; + auto &insert_column_list_opt = list_pr.GetChild(4).Cast(); + if (insert_column_list_opt.HasResult()) { + insert_column_list = transformer.Transform>(insert_column_list_opt.GetResult()); + } + case_insensitive_map_t> with_list {}; + auto &with_list_opt = list_pr.GetChild(5).Cast(); + if (with_list_opt.HasResult()) { + with_list = + transformer.Transform>>(with_list_opt.GetResult()); + } + auto select_statement_internal = transformer.Transform>(list_pr.GetChild(7)); + auto result = + TransformCreateViewStmt(transformer, create_recursive, if_not_exists, qualified_name, insert_column_list, + std::move(with_list), std::move(select_statement_internal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformCreateRecursiveInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformCreateRecursive(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformDeallocateStatementInternal(PEGTransformer 
&transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool deallocate_prepare {}; + auto &deallocate_prepare_opt = list_pr.GetChild(1).Cast(); + if (deallocate_prepare_opt.HasResult()) { + deallocate_prepare = transformer.Transform(deallocate_prepare_opt.GetResult()); + } + auto identifier = list_pr.GetChild(2).Cast().identifier; + auto result = TransformDeallocateStatement(transformer, deallocate_prepare, identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDeallocatePrepareInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDeallocatePrepare(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDeleteStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + CommonTableExpressionMap with_clause {}; + auto &with_clause_opt = list_pr.GetChild(0).Cast(); + if (with_clause_opt.HasResult()) { + with_clause = transformer.Transform(with_clause_opt.GetResult()); + } + auto target_opt_alias = transformer.Transform>(list_pr.GetChild(3)); + vector> delete_using_clause {}; + auto &delete_using_clause_opt = list_pr.GetChild(4).Cast(); + if (delete_using_clause_opt.HasResult()) { + delete_using_clause = transformer.Transform>>(delete_using_clause_opt.GetResult()); + } + unique_ptr where_clause {}; + auto &where_clause_opt = list_pr.GetChild(5).Cast(); + if (where_clause_opt.HasResult()) { + where_clause = transformer.Transform>(where_clause_opt.GetResult()); + } + vector> returning_clause {}; + auto &returning_clause_opt = list_pr.GetChild(6).Cast(); + if (returning_clause_opt.HasResult()) { + returning_clause = + transformer.Transform>>(returning_clause_opt.GetResult()); + } + auto result = + TransformDeleteStatement(transformer, std::move(with_clause), std::move(target_opt_alias), + std::move(delete_using_clause), std::move(where_clause), 
std::move(returning_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTruncateStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformTruncateStatement(transformer, std::move(base_table_name)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTargetOptAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + Identifier col_id {}; + auto &col_id_opt = list_pr.GetChild(2).Cast(); + if (col_id_opt.HasResult()) { + col_id = transformer.Transform(col_id_opt.GetResult()); + } + auto result = TransformTargetOptAlias(transformer, std::move(base_table_name), col_id); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDeleteUsingClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> table_ref; + auto table_ref_items = ExtractParseResultsFromList(list_pr.GetChild(1)); + for (auto &table_ref_item : table_ref_items) { + auto table_ref_value = transformer.Transform>(table_ref_item.get()); + table_ref.push_back(std::move(table_ref_value)); + } + auto result = TransformDeleteUsingClause(transformer, std::move(table_ref)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDescribeStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto child = transformer.Transform>(choice_pr.GetResult()); + auto result = TransformDescribeStatement(transformer, std::move(child)); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformShowSelectInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto show_or_describe_or_summarize = transformer.Transform(list_pr.GetChild(0)); + auto select_statement_internal = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformShowSelect(transformer, show_or_describe_or_summarize, std::move(select_statement_internal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformShowAllTablesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto show_or_describe = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformShowAllTables(transformer, show_or_describe); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformShowQualifiedNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto show_or_describe_or_summarize = transformer.Transform(list_pr.GetChild(0)); + DescribeTarget describe_target {}; + auto &describe_target_opt = list_pr.GetChild(1).Cast(); + if (describe_target_opt.HasResult()) { + describe_target = transformer.Transform(describe_target_opt.GetResult()); + } + auto result = TransformShowQualifiedName(transformer, show_or_describe_or_summarize, std::move(describe_target)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformShowTablesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto show_or_describe = transformer.Transform(list_pr.GetChild(0)); + auto qualified_name = transformer.Transform(list_pr.GetChild(3)); + auto result = TransformShowTables(transformer, show_or_describe, qualified_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDescribeTargetInternal(PEGTransformer &transformer, + ParseResult 
&parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformDescribeBaseTableNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformDescribeBaseTableName(transformer, std::move(base_table_name)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformDescribeStringLiteralInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformDescribeStringLiteral(transformer, string_literal); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformShowOrDescribeOrSummarizeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSummarizeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto summarize_rule = transformer.Transform(list_pr.GetChild(0)); + auto result = summarize_rule; + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSummarizeRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSummarizeRule(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformShowOrDescribeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = 
transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformShowRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformShowRule(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDescribeRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDescribeLongRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDescribeLongRule(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDescRuleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformDescRule(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDetachStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(2).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto catalog_name = list_pr.GetChild(3).Cast().identifier; + auto result = TransformDetachStatement(transformer, if_exists, catalog_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto drop_entries = transformer.Transform>(list_pr.GetChild(1)); + bool drop_behavior {}; + auto &drop_behavior_opt = list_pr.GetChild(2).Cast(); + if (drop_behavior_opt.HasResult()) { + drop_behavior = transformer.Transform(drop_behavior_opt.GetResult()); + } + auto result = 
TransformDropStatement(transformer, std::move(drop_entries), drop_behavior); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropEntriesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropTriggerInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto trigger_name = transformer.Transform(list_pr.GetChild(2)); + auto base_table_name = transformer.Transform>(list_pr.GetChild(4)); + auto result = TransformDropTrigger(transformer, if_exists, trigger_name, std::move(base_table_name)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropTableInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto table_or_view = transformer.Transform(list_pr.GetChild(0)); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector> base_table_name; + auto base_table_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &base_table_name_item : base_table_name_items) { + auto base_table_name_value = transformer.Transform>(base_table_name_item.get()); + base_table_name.push_back(std::move(base_table_name_value)); + } + auto result = TransformDropTable(transformer, table_or_view, if_exists, std::move(base_table_name)); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformDropTableFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto comment_macro_table = transformer.Transform(list_pr.GetChild(0)); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector table_function_name; + auto table_function_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &table_function_name_item : table_function_name_items) { + auto table_function_name_value = table_function_name_item.get().Cast().identifier; + table_function_name.push_back(table_function_name_value); + } + auto result = TransformDropTableFunction(transformer, comment_macro_table, if_exists, table_function_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto function_type_macro = transformer.Transform(list_pr.GetChild(0)); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector function_identifier; + auto function_identifier_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &function_identifier_item : function_identifier_items) { + auto function_identifier_value = transformer.Transform(function_identifier_item.get()); + function_identifier.push_back(function_identifier_value); + } + auto result = TransformDropFunction(transformer, function_type_macro, if_exists, function_identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropSchemaInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto 
&if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector qualified_schema_name; + auto qualified_schema_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &qualified_schema_name_item : qualified_schema_name_items) { + auto qualified_schema_name_value = transformer.Transform(qualified_schema_name_item.get()); + qualified_schema_name.push_back(qualified_schema_name_value); + } + auto result = TransformDropSchema(transformer, if_exists, qualified_schema_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector qualified_index_name; + auto qualified_index_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &qualified_index_name_item : qualified_index_name_items) { + auto qualified_index_name_value = transformer.Transform(qualified_index_name_item.get()); + qualified_index_name.push_back(qualified_index_name_value); + } + auto result = TransformDropIndex(transformer, if_exists, qualified_index_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformQualifiedIndexNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformQualifiedIndexNameStringInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto index_name = list_pr.GetChild(0).Cast().identifier; + auto result = 
TransformQualifiedIndexNameString(transformer, index_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSchemaReservedIndexInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_index_name = list_pr.GetChild(1).Cast().identifier; + auto result = TransformSchemaReservedIndex(transformer, schema_qualification, reserved_index_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogReservedSchemaIndexInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_schema_qualification = transformer.Transform(list_pr.GetChild(1)); + auto reserved_index_name = list_pr.GetChild(2).Cast().identifier; + auto result = TransformCatalogReservedSchemaIndex(transformer, catalog_qualification, reserved_schema_qualification, + reserved_index_name); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDropSequenceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector qualified_sequence_name; + auto qualified_sequence_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &qualified_sequence_name_item : qualified_sequence_name_items) { + auto qualified_sequence_name_value = transformer.Transform(qualified_sequence_name_item.get()); + qualified_sequence_name.push_back(qualified_sequence_name_value); + } + auto result = TransformDropSequence(transformer, if_exists, qualified_sequence_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformDropCollationInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector collation_name; + auto collation_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &collation_name_item : collation_name_items) { + auto collation_name_value = transformer.Transform(collation_name_item.get()); + collation_name.push_back(collation_name_value); + } + auto result = TransformDropCollation(transformer, if_exists, collation_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropTypeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(1).Cast(); + if (if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + vector qualified_type_name; + auto qualified_type_name_items = ExtractParseResultsFromList(list_pr.GetChild(2)); + for (auto &qualified_type_name_item : qualified_type_name_items) { + auto qualified_type_name_value = transformer.Transform(qualified_type_name_item.get()); + qualified_type_name.push_back(qualified_type_name_value); + } + auto result = TransformDropType(transformer, if_exists, qualified_type_name); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDropSecretInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + SecretPersistType temporary {}; + auto &temporary_opt = list_pr.GetChild(0).Cast(); + if (temporary_opt.HasResult()) { + temporary = transformer.Transform(temporary_opt.GetResult()); + } + bool if_exists {}; + auto &if_exists_opt = list_pr.GetChild(2).Cast(); + if 
(if_exists_opt.HasResult()) { + if_exists = transformer.Transform(if_exists_opt.GetResult()); + } + auto secret_name = transformer.Transform(list_pr.GetChild(3)); + Identifier drop_secret_storage {}; + auto &drop_secret_storage_opt = list_pr.GetChild(4).Cast(); + if (drop_secret_storage_opt.HasResult()) { + drop_secret_storage = transformer.Transform(drop_secret_storage_opt.GetResult()); + } + auto result = TransformDropSecret(transformer, temporary, if_exists, secret_name, drop_secret_storage); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTableOrViewInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformMaterializedViewEntryInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformMaterializedViewEntry(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformFunctionTypeMacroInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformFunctionTypeMacroKeywordInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformFunctionTypeMacroKeyword(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformFunctionTypeFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformFunctionTypeFunction(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDropBehaviorInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = 
parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCascadeDropBehaviorInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformCascadeDropBehavior(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformRestrictDropBehaviorInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformRestrictDropBehavior(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIfExistsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIfExists(transformer); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformQualifiedSchemaNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformQualifiedSchemaNameStringInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformQualifiedSchemaNameString(transformer, schema_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogReservedSchemaInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_qualification = transformer.Transform(list_pr.GetChild(0)); + auto reserved_schema_name = list_pr.GetChild(1).Cast().identifier; + auto result = TransformCatalogReservedSchema(transformer, catalog_qualification, reserved_schema_name); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformDropSecretStorageInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformDropSecretStorage(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformExecuteStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + vector table_function_arguments {}; + auto &table_function_arguments_opt = list_pr.GetChild(2).Cast(); + if (table_function_arguments_opt.HasResult()) { + table_function_arguments = + transformer.Transform>(table_function_arguments_opt.GetResult()); + } + auto result = TransformExecuteStatement(transformer, identifier, std::move(table_function_arguments)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExplainStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + bool explain_analyze {}; + auto &explain_analyze_opt = list_pr.GetChild(1).Cast(); + if (explain_analyze_opt.HasResult()) { + explain_analyze = transformer.Transform(explain_analyze_opt.GetResult()); + } + vector explain_option_list {}; + auto &explain_option_list_opt = list_pr.GetChild(2).Cast(); + if (explain_option_list_opt.HasResult()) { + explain_option_list = transformer.Transform>(explain_option_list_opt.GetResult()); + } + auto explainable_statements = transformer.Transform>(list_pr.GetChild(3)); + auto result = + TransformExplainStatement(transformer, explain_analyze, explain_option_list, std::move(explainable_statements)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExplainAnalyzeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformExplainAnalyze(transformer); + return 
make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformExplainOptionListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector explain_option; + auto explain_option_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &explain_option_item : explain_option_items) { + auto explain_option_value = transformer.Transform(explain_option_item.get()); + explain_option.push_back(explain_option_value); + } + auto result = TransformExplainOptionList(transformer, explain_option); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformExplainOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto explain_option_name = transformer.Transform(list_pr.GetChild(0)); + unique_ptr expression {}; + auto &expression_opt = list_pr.GetChild(1).Cast(); + if (expression_opt.HasResult()) { + expression = transformer.Transform>(expression_opt.GetResult()); + } + auto result = TransformExplainOption(transformer, explain_option_name, std::move(expression)); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformExplainSelectStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto select_statement_internal = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformExplainSelectStatement(transformer, std::move(select_statement_internal)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformExplainableStatementsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExportStatementInternal(PEGTransformer &transformer, + 
ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + string export_source {}; + auto &export_source_opt = list_pr.GetChild(2).Cast(); + if (export_source_opt.HasResult()) { + export_source = transformer.Transform(export_source_opt.GetResult()); + } + auto string_literal = transformer.Transform(list_pr.GetChild(3)); + vector generic_copy_option_list {}; + auto &generic_copy_option_list_opt = list_pr.GetChild(4).Cast(); + if (generic_copy_option_list_opt.HasResult()) { + generic_copy_option_list = + transformer.Transform>(generic_copy_option_list_opt.GetResult()); + } + auto result = TransformExportStatement(transformer, export_source, string_literal, generic_copy_option_list); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExportSourceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformExportSource(transformer, catalog_name); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformImportStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformImportStatement(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformInsertStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + CommonTableExpressionMap with_clause {}; + auto &with_clause_opt = list_pr.GetChild(0).Cast(); + if (with_clause_opt.HasResult()) { + with_clause = transformer.Transform(with_clause_opt.GetResult()); + } + OnConflictAction or_action {}; + auto &or_action_opt = list_pr.GetChild(2).Cast(); + if (or_action_opt.HasResult()) { + or_action = transformer.Transform(or_action_opt.GetResult()); + } + auto 
insert_target = transformer.Transform>(list_pr.GetChild(4)); + InsertColumnOrder by_name_or_position {}; + auto &by_name_or_position_opt = list_pr.GetChild(5).Cast(); + if (by_name_or_position_opt.HasResult()) { + by_name_or_position = transformer.Transform(by_name_or_position_opt.GetResult()); + } + vector insert_column_list {}; + auto &insert_column_list_opt = list_pr.GetChild(6).Cast(); + if (insert_column_list_opt.HasResult()) { + insert_column_list = transformer.Transform>(insert_column_list_opt.GetResult()); + } + auto insert_values = transformer.Transform(list_pr.GetChild(7)); + unique_ptr on_conflict_clause {}; + auto &on_conflict_clause_opt = list_pr.GetChild(8).Cast(); + if (on_conflict_clause_opt.HasResult()) { + on_conflict_clause = transformer.Transform>(on_conflict_clause_opt.GetResult()); + } + vector> returning_clause {}; + auto &returning_clause_opt = list_pr.GetChild(9).Cast(); + if (returning_clause_opt.HasResult()) { + returning_clause = + transformer.Transform>>(returning_clause_opt.GetResult()); + } + auto result = TransformInsertStatement(transformer, std::move(with_clause), or_action, std::move(insert_target), + by_name_or_position, insert_column_list, std::move(insert_values), + std::move(on_conflict_clause), std::move(returning_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOrActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertOrReplaceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformInsertOrReplace(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertOrIgnoreInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = 
TransformInsertOrIgnore(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformByNameOrPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertByNameOrderInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto insert_by_name = transformer.Transform(list_pr.GetChild(1)); + auto result = insert_by_name; + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformInsertByPositionOrderInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto insert_by_position = transformer.Transform(list_pr.GetChild(1)); + auto result = insert_by_position; + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertByNameInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformInsertByName(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertByPositionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformInsertByPosition(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + Identifier insert_alias {}; + auto &insert_alias_opt = list_pr.GetChild(1).Cast(); + if (insert_alias_opt.HasResult()) { + insert_alias = transformer.Transform(insert_alias_opt.GetResult()); + } + auto result = TransformInsertTarget(transformer, std::move(base_table_name), insert_alias); + return make_uniq>>(std::move(result)); +} + 
+unique_ptr PEGTransformerFactory::TransformInsertAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformInsertAlias(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector col_id; + auto col_id_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &col_id_item : col_id_items) { + auto col_id_value = transformer.Transform(col_id_item.get()); + col_id.push_back(col_id_value); + } + auto result = TransformColumnList(transformer, col_id); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_list = transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(0))); + auto result = TransformInsertColumnList(transformer, column_list); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformInsertValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSelectInsertValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto select_statement_internal = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformSelectInsertValues(transformer, std::move(select_statement_internal)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDefaultValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + 
auto result = TransformDefaultValues(transformer); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOnConflictClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + OnConflictExpressionTarget on_conflict_target {}; + auto &on_conflict_target_opt = list_pr.GetChild(2).Cast(); + if (on_conflict_target_opt.HasResult()) { + on_conflict_target = transformer.Transform(on_conflict_target_opt.GetResult()); + } + auto on_conflict_action = transformer.Transform>(list_pr.GetChild(3)); + auto result = TransformOnConflictClause(transformer, std::move(on_conflict_target), std::move(on_conflict_action)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOnConflictTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformOnConflictExpressionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_id_list = transformer.Transform>(list_pr.GetChild(0)); + unique_ptr where_clause {}; + auto &where_clause_opt = list_pr.GetChild(1).Cast(); + if (where_clause_opt.HasResult()) { + where_clause = transformer.Transform>(where_clause_opt.GetResult()); + } + auto result = TransformOnConflictExpressionTarget(transformer, column_id_list, std::move(where_clause)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformOnConflictIndexTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto constraint_name = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformOnConflictIndexTarget(transformer, constraint_name); + return 
make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOnConflictActionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOnConflictUpdateInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto update_set_clause = transformer.Transform>(list_pr.GetChild(3)); + unique_ptr where_clause {}; + auto &where_clause_opt = list_pr.GetChild(4).Cast(); + if (where_clause_opt.HasResult()) { + where_clause = transformer.Transform>(where_clause_opt.GetResult()); + } + auto result = TransformOnConflictUpdate(transformer, std::move(update_set_clause), std::move(where_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformOnConflictNothingInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOnConflictNothing(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformReturningClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto target_list = transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformReturningClause(transformer, std::move(target_list)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformLoadStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_or_string = transformer.Transform(list_pr.GetChild(1)); + Identifier extension_alias {}; + auto &extension_alias_opt = list_pr.GetChild(2).Cast(); + if (extension_alias_opt.HasResult()) { + extension_alias = transformer.Transform(extension_alias_opt.GetResult()); + } + auto result 
= TransformLoadStatement(transformer, col_id_or_string, extension_alias); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformExtensionAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformExtensionAlias(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformInstallStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier_or_string_literal = transformer.Transform(list_pr.GetChild(2)); + ExtensionRepositoryInfo from_source {}; + auto &from_source_opt = list_pr.GetChild(3).Cast(); + if (from_source_opt.HasResult()) { + from_source = transformer.Transform(from_source_opt.GetResult()); + } + string version_number {}; + auto &version_number_opt = list_pr.GetChild(4).Cast(); + if (version_number_opt.HasResult()) { + version_number = transformer.Transform(version_number_opt.GetResult()); + } + auto result = TransformInstallStatement(transformer, identifier_or_string_literal, from_source, version_number); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateExtensionsStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector identifier {}; + auto &identifier_opt = list_pr.GetChild(2).Cast(); + if (identifier_opt.HasResult()) { + auto identifier_items_1 = ExtractParseResultsFromList(ExtractResultFromParens(identifier_opt.GetResult())); + for (auto &identifier_item_1 : identifier_items_1) { + auto identifier_value_1 = identifier_item_1.get().Cast().identifier; + identifier.push_back(identifier_value_1); + } + } + auto result = TransformUpdateExtensionsStatement(transformer, identifier); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformFromSourceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformFromSourceIdentifierInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformFromSourceIdentifier(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformFromSourceStringInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformFromSourceString(transformer, string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformVersionNumberInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier_or_string_literal = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformVersionNumber(transformer, identifier_or_string_literal); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformMergeIntoStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + CommonTableExpressionMap with_clause {}; + auto &with_clause_opt = list_pr.GetChild(0).Cast(); + if (with_clause_opt.HasResult()) { + with_clause = transformer.Transform(with_clause_opt.GetResult()); + } + auto target_opt_alias = transformer.Transform>(list_pr.GetChild(3)); + auto merge_into_using_clause = transformer.Transform>(list_pr.GetChild(4)); + auto join_qualifier = transformer.Transform(list_pr.GetChild(5)); + vector>> merge_match; + auto &merge_match_repeat = list_pr.GetChild(6).Cast(); 
+ for (auto &merge_match_item : merge_match_repeat.GetChildren()) { + auto merge_match_value = + transformer.Transform>>(merge_match_item.get()); + merge_match.push_back(std::move(merge_match_value)); + } + vector> returning_clause {}; + auto &returning_clause_opt = list_pr.GetChild(7).Cast(); + if (returning_clause_opt.HasResult()) { + returning_clause = + transformer.Transform>>(returning_clause_opt.GetResult()); + } + auto result = TransformMergeIntoStatement(transformer, std::move(with_clause), std::move(target_opt_alias), + std::move(merge_into_using_clause), std::move(join_qualifier), + std::move(merge_match), std::move(returning_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformMergeIntoUsingClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto table_ref = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformMergeIntoUsingClause(transformer, std::move(table_ref)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformMergeMatchInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>>(choice_pr.GetResult()); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformMatchedClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr and_expression {}; + auto &and_expression_opt = list_pr.GetChild(2).Cast(); + if (and_expression_opt.HasResult()) { + and_expression = transformer.Transform>(and_expression_opt.GetResult()); + } + auto matched_clause_action = transformer.Transform>(list_pr.GetChild(4)); + auto result = TransformMatchedClause(transformer, std::move(and_expression), std::move(matched_clause_action)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr 
+PEGTransformerFactory::TransformMatchedClauseActionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr update_match_info {}; + auto &update_match_info_opt = list_pr.GetChild(1).Cast(); + if (update_match_info_opt.HasResult()) { + update_match_info = transformer.Transform>(update_match_info_opt.GetResult()); + } + auto result = TransformUpdateMatchClause(transformer, std::move(update_match_info)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateMatchInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateMatchSetActionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto update_match_set_clause = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformUpdateMatchSetAction(transformer, std::move(update_match_set_clause)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateByNameOrPositionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto by_name_or_position = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformUpdateByNameOrPosition(transformer, by_name_or_position); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformDeleteMatchClauseInternal(PEGTransformer 
&transformer, + ParseResult &parse_result) { + auto result = TransformDeleteMatchClause(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformInsertMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr insert_match_info {}; + auto &insert_match_info_opt = list_pr.GetChild(1).Cast(); + if (insert_match_info_opt.HasResult()) { + insert_match_info = transformer.Transform>(insert_match_info_opt.GetResult()); + } + auto result = TransformInsertMatchClause(transformer, std::move(insert_match_info)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformInsertMatchInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformInsertDefaultValuesInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformInsertDefaultValues(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformInsertByNameOrPositionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + InsertColumnOrder by_name_or_position {}; + auto &by_name_or_position_opt = list_pr.GetChild(0).Cast(); + if (by_name_or_position_opt.HasResult()) { + by_name_or_position = transformer.Transform(by_name_or_position_opt.GetResult()); + } + auto result = TransformInsertByNameOrPosition(transformer, by_name_or_position); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformInsertValuesListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector insert_column_list {}; + auto &insert_column_list_opt = 
list_pr.GetChild(0).Cast(); + if (insert_column_list_opt.HasResult()) { + insert_column_list = transformer.Transform>(insert_column_list_opt.GetResult()); + } + vector> expression; + auto expression_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(2))); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformInsertValuesList(transformer, insert_column_list, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformDoNothingMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformDoNothingMatchClause(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformErrorMatchClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + unique_ptr expression {}; + auto &expression_opt = list_pr.GetChild(1).Cast(); + if (expression_opt.HasResult()) { + expression = transformer.Transform>(expression_opt.GetResult()); + } + auto result = TransformErrorMatchClause(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateMatchSetClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto update_match_set_info = transformer.Transform>(list_pr.GetChild(1)); + auto result = std::move(update_match_set_info); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateMatchSetInfoInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto &choice_result = choice_pr.GetResult(); + unique_ptr result {}; + if (!(choice_result.name == "StarSymbol")) { + 
result = transformer.Transform>(choice_result); + } + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformAndExpressionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformAndExpression(transformer, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformNotMatchedClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + MergeActionCondition by_source_or_target {}; + auto &by_source_or_target_opt = list_pr.GetChild(3).Cast(); + if (by_source_or_target_opt.HasResult()) { + by_source_or_target = transformer.Transform(by_source_or_target_opt.GetResult()); + } + unique_ptr and_expression {}; + auto &and_expression_opt = list_pr.GetChild(4).Cast(); + if (and_expression_opt.HasResult()) { + and_expression = transformer.Transform>(and_expression_opt.GetResult()); + } + auto matched_clause_action = transformer.Transform>(list_pr.GetChild(6)); + auto result = TransformNotMatchedClause(transformer, by_source_or_target, std::move(and_expression), + std::move(matched_clause_action)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformBySourceOrTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformBySourceInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformBySource(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformByTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = 
TransformByTarget(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformPivotOnInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto pivot_column_list = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformPivotOn(transformer, std::move(pivot_column_list)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPivotUsingInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto target_list = transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformPivotUsing(transformer, std::move(target_list)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPivotColumnListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector pivot_column_entry; + auto pivot_column_entry_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &pivot_column_entry_item : pivot_column_entry_items) { + auto pivot_column_entry_value = transformer.Transform(pivot_column_entry_item.get()); + pivot_column_entry.push_back(std::move(pivot_column_entry_value)); + } + auto result = TransformPivotColumnList(transformer, std::move(pivot_column_entry)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPivotColumnEntryInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPivotColumnExpressionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto expression = transformer.Transform>(list_pr.GetChild(0)); + auto result = 
TransformPivotColumnExpression(transformer, std::move(expression)); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPivotColumnSubqueryInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_expression = transformer.Transform>(list_pr.GetChild(0)); + auto select_statement_internal = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(2))); + auto result = + TransformPivotColumnSubquery(transformer, std::move(base_expression), std::move(select_statement_internal)); + return make_uniq>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformIntoNameValuesInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_or_string = transformer.Transform(list_pr.GetChild(2)); + vector identifier; + auto identifier_items = ExtractParseResultsFromList(list_pr.GetChild(4)); + for (auto &identifier_item : identifier_items) { + auto identifier_value = identifier_item.get().Cast().identifier; + identifier.push_back(identifier_value); + } + auto result = TransformIntoNameValues(transformer, col_id_or_string, identifier); + return make_uniq>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformIncludeOrExcludeNullsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformIncludeNullsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformIncludeNulls(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformExcludeNullsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformExcludeNulls(transformer); + return make_uniq>(result); +} + +unique_ptr 
PEGTransformerFactory::TransformUnpivotHeaderInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(result); +} + +unique_ptr +PEGTransformerFactory::TransformUnpivotHeaderSingleInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id_or_string = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformUnpivotHeaderSingle(transformer, col_id_or_string); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformUnpivotHeaderListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector col_id_or_string; + auto col_id_or_string_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &col_id_or_string_item : col_id_or_string_items) { + auto col_id_or_string_value = transformer.Transform(col_id_or_string_item.get()); + col_id_or_string.push_back(col_id_or_string_value); + } + auto result = TransformUnpivotHeaderList(transformer, col_id_or_string); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformPragmaStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto pragma_assign_or_function = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformPragmaStatement(transformer, std::move(pragma_assign_or_function)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformPragmaAssignOrFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr 
PEGTransformerFactory::TransformPragmaAssignInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto setting_name = list_pr.GetChild(0).Cast().identifier; + auto variable_list = transformer.Transform>>(list_pr.GetChild(2)); + auto result = TransformPragmaAssign(transformer, setting_name, std::move(variable_list)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPragmaFunctionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto pragma_name = list_pr.GetChild(0).Cast().identifier; + vector> pragma_parameters {}; + auto &pragma_parameters_opt = list_pr.GetChild(1).Cast(); + if (pragma_parameters_opt.HasResult()) { + pragma_parameters = + transformer.Transform>>(pragma_parameters_opt.GetResult()); + } + auto result = TransformPragmaFunction(transformer, pragma_name, std::move(pragma_parameters)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPragmaParametersInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression; + auto expression_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformPragmaParameters(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformPrepareStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + vector type_list {}; + auto &type_list_opt = list_pr.GetChild(2).Cast(); + if (type_list_opt.HasResult()) { + type_list = 
transformer.Transform>(type_list_opt.GetResult()); + } + auto statement = transformer.Transform>(list_pr.GetChild(4)); + auto result = TransformPrepareStatement(transformer, identifier, type_list, std::move(statement)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformTypeListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector type; + auto type_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &type_item : type_items) { + auto type_value = transformer.Transform(type_item.get()); + type.push_back(type_value); + } + auto result = TransformTypeList(transformer, type); + return make_uniq>>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto use_target = transformer.Transform(list_pr, 1); - auto result = TransformUseStatement(use_target); + auto set_assignment_or_time_zone = transformer.Transform>(list_pr.GetChild(1)); + auto result = TransformSetStatement(transformer, std::move(set_assignment_or_time_zone)); return make_uniq>>(std::move(result)); } -unique_ptr PEGTransformerFactory::TransformUseTargetInternal(PEGTransformer &transformer, +unique_ptr +PEGTransformerFactory::TransformSetAssignmentOrTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformResetStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto set_variable_or_setting = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformResetStatement(transformer, set_variable_or_setting); + return 
make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformStandardAssignmentInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto set_variable_or_setting = transformer.Transform(list_pr.GetChild(0)); + auto set_assignment = transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformStandardAssignment(transformer, set_variable_or_setting, std::move(set_assignment)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformSetVariableOrSettingInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetTimeZoneInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto zone_value = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformSetTimeZone(transformer, std::move(zone_value)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformZoneValueInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto &choice_pr = list_pr.Child(0); - auto result = TransformUseTarget(transformer, choice_pr.GetResult()); - return make_uniq>(result); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformZoneLocalInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformZoneLocal(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformZoneDefaultInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformZoneDefault(transformer); + return make_uniq>>(std::move(result)); 
+} + +unique_ptr PEGTransformerFactory::TransformZoneStringLiteralInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformZoneStringLiteral(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformZoneIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(0).Cast().identifier; + auto result = TransformZoneIdentifier(transformer, identifier); + return make_uniq>>(std::move(result)); } unique_ptr -PEGTransformerFactory::TransformUseTargetCatalogSchemaInternal(PEGTransformer &transformer, ParseResult &parse_result) { +PEGTransformerFactory::TransformZoneIntervalWithIntervalInternal(PEGTransformer &transformer, + ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto catalog_name = list_pr.Child(0).identifier; - auto reserved_schema_name = list_pr.Child(2).identifier; - auto &dot_identifier_opt = list_pr.Child(3); - vector dot_identifier; - if (dot_identifier_opt.HasResult()) { - auto &dot_identifier_repeat = dot_identifier_opt.GetResult().Cast(); - for (auto &dot_identifier_item : dot_identifier_repeat.GetChildren()) { - dot_identifier.push_back(transformer.Transform(dot_identifier_item)); - } + auto string_literal = transformer.Transform(list_pr.GetChild(1)); + DatePartSpecifier interval {}; + auto &interval_opt = list_pr.GetChild(2).Cast(); + if (interval_opt.HasResult()) { + interval = transformer.Transform(interval_opt.GetResult()); } - auto result = TransformUseTargetCatalogSchema(catalog_name, reserved_schema_name, dot_identifier); - return make_uniq>(result); + auto result = TransformZoneIntervalWithInterval(transformer, string_literal, interval); + return make_uniq>>(std::move(result)); } -unique_ptr 
PEGTransformerFactory::TransformDotIdentifierInternal(PEGTransformer &transformer, +unique_ptr +PEGTransformerFactory::TransformZoneIntervalWithPrecisionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto number_literal = + transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(1))); + auto string_literal = transformer.Transform(list_pr.GetChild(2)); + auto result = TransformZoneIntervalWithPrecision(transformer, std::move(number_literal), string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformSetSettingInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + SetScope setting_scope {}; + auto &setting_scope_opt = list_pr.GetChild(0).Cast(); + if (setting_scope_opt.HasResult()) { + setting_scope = transformer.Transform(setting_scope_opt.GetResult()); + } + auto setting_name = list_pr.GetChild(1).Cast().identifier; + auto result = TransformSetSetting(transformer, setting_scope, setting_name); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetVariableInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto variable_scope = transformer.Transform(list_pr.GetChild(0)); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformSetVariable(transformer, variable_scope, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformVariableScopeInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto result = TransformVariableScope(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSettingScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto identifier = list_pr.Child(1).identifier; - auto result = TransformDotIdentifier(identifier); - 
return make_uniq>(result); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformLocalScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformLocalScope(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSessionScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformSessionScope(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformGlobalScopeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformGlobalScope(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformSetAssignmentInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto variable_list = transformer.Transform>>(list_pr.GetChild(1)); + auto result = TransformSetAssignment(transformer, std::move(variable_list)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformVariableListInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector> expression; + auto expression_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &expression_item : expression_items) { + auto expression_value = transformer.Transform>(expression_item.get()); + expression.push_back(std::move(expression_value)); + } + auto result = TransformVariableList(transformer, std::move(expression)); + return make_uniq>>>(std::move(result)); } unique_ptr @@ -56,84 +4912,895 @@ unique_ptr PEGTransformerFactory::TransformBeginTransactio ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); TransactionModifierType read_or_write {}; - transformer.TransformOptional(list_pr, 2, read_or_write); - auto result = 
TransformBeginTransaction(read_or_write); + auto &read_or_write_opt = list_pr.GetChild(2).Cast(); + if (read_or_write_opt.HasResult()) { + read_or_write = transformer.Transform(read_or_write_opt.GetResult()); + } + auto result = TransformBeginTransaction(transformer, read_or_write); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformRollbackTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result) { - auto result = TransformRollbackTransaction(); + auto result = TransformRollbackTransaction(transformer); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformCommitTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result) { - auto result = TransformCommitTransaction(); + auto result = TransformCommitTransaction(transformer); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformReadOrWriteInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto read_only_or_read_write = transformer.Transform(list_pr, 1); - auto result = TransformReadOrWrite(read_only_or_read_write); + auto read_only_or_read_write = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformReadOrWrite(transformer, read_only_or_read_write); return make_uniq>(result); } -unique_ptr PEGTransformerFactory::TransformDetachStatementInternal(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr +PEGTransformerFactory::TransformReadOnlyOrReadWriteInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; - transformer.TransformOptional(list_pr, 2, if_exists); - auto catalog_name = list_pr.Child(3).identifier; - auto result = TransformDetachStatement(if_exists, catalog_name); - return make_uniq>>(std::move(result)); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return 
make_uniq>(result); } -unique_ptr PEGTransformerFactory::TransformExportStatementInternal(PEGTransformer &transformer, +unique_ptr PEGTransformerFactory::TransformReadOnlyInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformReadOnly(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformReadWriteInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformReadWrite(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformUpdateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - string export_source {}; - transformer.TransformOptional(list_pr, 2, export_source); - auto string_literal = transformer.Transform(list_pr, 3); - vector generic_copy_option_list {}; - transformer.TransformOptional(list_pr, 4, generic_copy_option_list); - auto result = TransformExportStatement(export_source, string_literal, std::move(generic_copy_option_list)); + CommonTableExpressionMap with_clause {}; + auto &with_clause_opt = list_pr.GetChild(0).Cast(); + if (with_clause_opt.HasResult()) { + with_clause = transformer.Transform(with_clause_opt.GetResult()); + } + auto update_target = transformer.Transform>(list_pr.GetChild(2)); + auto update_set_clause = transformer.Transform>(list_pr.GetChild(3)); + unique_ptr from_clause {}; + auto &from_clause_opt = list_pr.GetChild(4).Cast(); + if (from_clause_opt.HasResult()) { + from_clause = transformer.Transform>(from_clause_opt.GetResult()); + } + unique_ptr where_clause {}; + auto &where_clause_opt = list_pr.GetChild(5).Cast(); + if (where_clause_opt.HasResult()) { + where_clause = transformer.Transform>(where_clause_opt.GetResult()); + } + vector> returning_clause {}; + auto &returning_clause_opt = list_pr.GetChild(6).Cast(); + if (returning_clause_opt.HasResult()) { + returning_clause = + 
transformer.Transform>>(returning_clause_opt.GetResult()); + } + auto result = TransformUpdateStatement(transformer, std::move(with_clause), std::move(update_target), + std::move(update_set_clause), std::move(from_clause), + std::move(where_clause), std::move(returning_clause)); return make_uniq>>(std::move(result)); } -unique_ptr PEGTransformerFactory::TransformExportSourceInternal(PEGTransformer &transformer, +unique_ptr PEGTransformerFactory::TransformUpdateTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformBaseTableSetInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto catalog_name = list_pr.Child(0).identifier; - auto result = TransformExportSource(catalog_name); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + auto result = TransformBaseTableSet(transformer, std::move(base_table_name)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformBaseTableAliasSetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); + Identifier update_alias {}; + auto &update_alias_opt = list_pr.GetChild(1).Cast(); + if (update_alias_opt.HasResult()) { + update_alias = transformer.Transform(update_alias_opt.GetResult()); + } + auto result = TransformBaseTableAliasSet(transformer, std::move(base_table_name), update_alias); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateAliasInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_id = transformer.Transform(list_pr.GetChild(1)); + auto 
result = TransformUpdateAlias(transformer, col_id); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformUpdateSetClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateSetTupleInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector column_name; + auto column_name_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &column_name_item : column_name_items) { + auto column_name_value = column_name_item.get().Cast().identifier; + column_name.push_back(column_name_value); + } + auto expression = transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformUpdateSetTuple(transformer, column_name, std::move(expression)); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateSetElementListInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector>> update_set_element; + auto update_set_element_items = ExtractParseResultsFromList(list_pr.GetChild(0)); + for (auto &update_set_element_item : update_set_element_items) { + auto update_set_element_value = + transformer.Transform>>(update_set_element_item.get()); + update_set_element.push_back(std::move(update_set_element_value)); + } + auto result = TransformUpdateSetElementList(transformer, std::move(update_set_element)); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUpdateSetElementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto update_set_column_target = transformer.Transform(list_pr.GetChild(0)); + auto expression = 
transformer.Transform>(list_pr.GetChild(2)); + auto result = TransformUpdateSetElement(transformer, update_set_column_target, std::move(expression)); + return make_uniq>>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateSetColumnTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto column_name = list_pr.GetChild(0).Cast().identifier; + vector dot_identifier {}; + auto &dot_identifier_opt = list_pr.GetChild(1).Cast(); + if (dot_identifier_opt.HasResult()) { + auto &dot_identifier_repeat_1 = dot_identifier_opt.GetResult().Cast(); + for (auto &dot_identifier_item_1 : dot_identifier_repeat_1.GetChildren()) { + auto dot_identifier_value_1 = transformer.Transform(dot_identifier_item_1.get()); + dot_identifier.push_back(dot_identifier_value_1); + } + } + auto result = TransformUpdateSetColumnTarget(transformer, column_name, dot_identifier); return make_uniq>(result); } -unique_ptr PEGTransformerFactory::TransformImportStatementInternal(PEGTransformer &transformer, +unique_ptr PEGTransformerFactory::TransformUseStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto use_target = transformer.Transform(list_pr.GetChild(1)); + auto result = TransformUseStatement(transformer, use_target); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformUseTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformSchemaNameAsUseTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto schema_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformSchemaNameAsUseTarget(transformer, 
schema_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogNameAsUseTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformCatalogNameAsUseTarget(transformer, catalog_name); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformUseTargetCatalogSchemaInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_name = list_pr.GetChild(0).Cast().identifier; + auto reserved_schema_name = list_pr.GetChild(2).Cast().identifier; + vector dot_identifier {}; + auto &dot_identifier_opt = list_pr.GetChild(3).Cast(); + if (dot_identifier_opt.HasResult()) { + auto &dot_identifier_repeat_1 = dot_identifier_opt.GetResult().Cast(); + for (auto &dot_identifier_item_1 : dot_identifier_repeat_1.GetChildren()) { + auto dot_identifier_value_1 = transformer.Transform(dot_identifier_item_1.get()); + dot_identifier.push_back(dot_identifier_value_1); + } + } + auto result = TransformUseTargetCatalogSchema(transformer, catalog_name, reserved_schema_name, dot_identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformDotIdentifierInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto identifier = list_pr.GetChild(1).Cast().identifier; + auto result = TransformDotIdentifier(transformer, identifier); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformVacuumStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto string_literal = transformer.Transform(list_pr, 2); - auto result = TransformImportStatement(string_literal); + VacuumOptions vacuum_options {}; + auto &vacuum_options_opt = list_pr.GetChild(1).Cast(); + if (vacuum_options_opt.HasResult()) { + 
vacuum_options = transformer.Transform(vacuum_options_opt.GetResult()); + } + AnalyzeTarget analyze_target {}; + auto &analyze_target_opt = list_pr.GetChild(2).Cast(); + if (analyze_target_opt.HasResult()) { + analyze_target = transformer.Transform(analyze_target_opt.GetResult()); + } + auto result = TransformVacuumStatement(transformer, vacuum_options, std::move(analyze_target)); return make_uniq>>(std::move(result)); } +unique_ptr PEGTransformerFactory::TransformVacuumOptionsInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform(choice_pr.GetResult()); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformVacuumParensOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector vacuum_option; + auto vacuum_option_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &vacuum_option_item : vacuum_option_items) { + auto vacuum_option_value = transformer.Transform(vacuum_option_item.get()); + vacuum_option.push_back(vacuum_option_value); + } + auto result = TransformVacuumParensOptions(transformer, vacuum_option); + return make_uniq>(result); +} + +unique_ptr +PEGTransformerFactory::TransformVacuumLegacyOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + string opt_full {}; + auto &opt_full_opt = list_pr.GetChild(0).Cast(); + if (opt_full_opt.HasResult()) { + opt_full = transformer.Transform(opt_full_opt.GetResult()); + } + string opt_freeze {}; + auto &opt_freeze_opt = list_pr.GetChild(1).Cast(); + if (opt_freeze_opt.HasResult()) { + opt_freeze = transformer.Transform(opt_freeze_opt.GetResult()); + } + string opt_verbose {}; + auto &opt_verbose_opt = list_pr.GetChild(2).Cast(); + if (opt_verbose_opt.HasResult()) { + opt_verbose = 
transformer.Transform(opt_verbose_opt.GetResult()); + } + string opt_analyze {}; + auto &opt_analyze_opt = list_pr.GetChild(3).Cast(); + if (opt_analyze_opt.HasResult()) { + opt_analyze = transformer.Transform(opt_analyze_opt.GetResult()); + } + auto result = TransformVacuumLegacyOptions(transformer, opt_full, opt_freeze, opt_verbose, opt_analyze); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformVacuumOptionInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + string result; + if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { + result = choice_pr.GetResult().Cast().identifier.GetIdentifierName(); + } else if (choice_pr.GetResult().type == ParseResultType::KEYWORD) { + result = choice_pr.GetResult().Cast().keyword; + } else if (choice_pr.GetResult().type == ParseResultType::STRING) { + result = choice_pr.GetResult().Cast().result; + } else { + result = transformer.Transform(choice_pr.GetResult()); + } + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformOptAnalyzeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOptAnalyze(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformOptFullInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOptFull(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformOptFreezeInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOptFreeze(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformOptVerboseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformOptVerbose(transformer); + return make_uniq>(result); +} + +unique_ptr PEGTransformerFactory::TransformNameListInternal(PEGTransformer &transformer, + 
ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + vector col_id; + auto col_id_items = ExtractParseResultsFromList(ExtractResultFromParens(list_pr.GetChild(0))); + for (auto &col_id_item : col_id_items) { + auto col_id_value = transformer.Transform(col_id_item.get()); + col_id.push_back(col_id_value); + } + auto result = TransformNameList(transformer, col_id); + return make_uniq>>(result); +} + void PEGTransformerFactory::RegisterGenerated() { static const TransformRule builtin_transform_rules[] = { - {"UseStatement", &PEGTransformerFactory::TransformUseStatementInternal}, - {"UseTarget", &PEGTransformerFactory::TransformUseTargetInternal}, - {"UseTargetCatalogSchema", &PEGTransformerFactory::TransformUseTargetCatalogSchemaInternal}, - {"DotIdentifier", &PEGTransformerFactory::TransformDotIdentifierInternal}, + {"AlterStatement", &PEGTransformerFactory::TransformAlterStatementInternal}, + {"AlterOptions", &PEGTransformerFactory::TransformAlterOptionsInternal}, + {"AlterTableStmt", &PEGTransformerFactory::TransformAlterTableStmtInternal}, + {"AlterSchemaStmt", &PEGTransformerFactory::TransformAlterSchemaStmtInternal}, + {"AlterTableOptions", &PEGTransformerFactory::TransformAlterTableOptionsInternal}, + {"AddConstraint", &PEGTransformerFactory::TransformAddConstraintInternal}, + {"AddColumn", &PEGTransformerFactory::TransformAddColumnInternal}, + {"AddColumnEntry", &PEGTransformerFactory::TransformAddColumnEntryInternal}, + {"DropColumn", &PEGTransformerFactory::TransformDropColumnInternal}, + {"AlterColumn", &PEGTransformerFactory::TransformAlterColumnInternal}, + {"RenameColumn", &PEGTransformerFactory::TransformRenameColumnInternal}, + {"NestedColumnName", &PEGTransformerFactory::TransformNestedColumnNameInternal}, + {"IdentifierDot", &PEGTransformerFactory::TransformIdentifierDotInternal}, + {"RenameAlter", &PEGTransformerFactory::TransformRenameAlterInternal}, + {"SetPartitionedBy", &PEGTransformerFactory::TransformSetPartitionedByInternal}, 
+ {"ResetPartitionedBy", &PEGTransformerFactory::TransformResetPartitionedByInternal}, + {"SetSortedBy", &PEGTransformerFactory::TransformSetSortedByInternal}, + {"ResetSortedBy", &PEGTransformerFactory::TransformResetSortedByInternal}, + {"SetOptions", &PEGTransformerFactory::TransformSetOptionsInternal}, + {"ResetOptions", &PEGTransformerFactory::TransformResetOptionsInternal}, + {"AlterColumnEntry", &PEGTransformerFactory::TransformAlterColumnEntryInternal}, + {"AddOrDropDefault", &PEGTransformerFactory::TransformAddOrDropDefaultInternal}, + {"AddDefault", &PEGTransformerFactory::TransformAddDefaultInternal}, + {"DropDefault", &PEGTransformerFactory::TransformDropDefaultInternal}, + {"ChangeNullability", &PEGTransformerFactory::TransformChangeNullabilityInternal}, + {"DropOrSet", &PEGTransformerFactory::TransformDropOrSetInternal}, + {"DropNullability", &PEGTransformerFactory::TransformDropNullabilityInternal}, + {"SetNullability", &PEGTransformerFactory::TransformSetNullabilityInternal}, + {"AlterType", &PEGTransformerFactory::TransformAlterTypeInternal}, + {"UsingExpression", &PEGTransformerFactory::TransformUsingExpressionInternal}, + {"AlterViewStmt", &PEGTransformerFactory::TransformAlterViewStmtInternal}, + {"AlterSequenceStmt", &PEGTransformerFactory::TransformAlterSequenceStmtInternal}, + {"QualifiedSequenceName", &PEGTransformerFactory::TransformQualifiedSequenceNameInternal}, + {"AlterSequenceOptions", &PEGTransformerFactory::TransformAlterSequenceOptionsInternal}, + {"SetSequenceOption", &PEGTransformerFactory::TransformSetSequenceOptionInternal}, + {"AlterDatabaseStmt", &PEGTransformerFactory::TransformAlterDatabaseStmtInternal}, + {"AnalyzeStatement", &PEGTransformerFactory::TransformAnalyzeStatementInternal}, + {"AnalyzeTarget", &PEGTransformerFactory::TransformAnalyzeTargetInternal}, + {"AnalyzeVerbose", &PEGTransformerFactory::TransformAnalyzeVerboseInternal}, + {"AttachStatement", &PEGTransformerFactory::TransformAttachStatementInternal}, + 
{"DatabasePath", &PEGTransformerFactory::TransformDatabasePathInternal}, + {"AttachAlias", &PEGTransformerFactory::TransformAttachAliasInternal}, + {"AttachOptions", &PEGTransformerFactory::TransformAttachOptionsInternal}, + {"CallStatement", &PEGTransformerFactory::TransformCallStatementInternal}, + {"CheckpointStatement", &PEGTransformerFactory::TransformCheckpointStatementInternal}, + {"CheckpointForce", &PEGTransformerFactory::TransformCheckpointForceInternal}, + {"CommentStatement", &PEGTransformerFactory::TransformCommentStatementInternal}, + {"CommentOnType", &PEGTransformerFactory::TransformCommentOnTypeInternal}, + {"CommentTable", &PEGTransformerFactory::TransformCommentTableInternal}, + {"CommentSequence", &PEGTransformerFactory::TransformCommentSequenceInternal}, + {"CommentFunction", &PEGTransformerFactory::TransformCommentFunctionInternal}, + {"CommentMacroTable", &PEGTransformerFactory::TransformCommentMacroTableInternal}, + {"CommentMacro", &PEGTransformerFactory::TransformCommentMacroInternal}, + {"CommentView", &PEGTransformerFactory::TransformCommentViewInternal}, + {"CommentDatabase", &PEGTransformerFactory::TransformCommentDatabaseInternal}, + {"CommentIndex", &PEGTransformerFactory::TransformCommentIndexInternal}, + {"CommentSchema", &PEGTransformerFactory::TransformCommentSchemaInternal}, + {"CommentType", &PEGTransformerFactory::TransformCommentTypeInternal}, + {"CommentColumn", &PEGTransformerFactory::TransformCommentColumnInternal}, + {"ExpressionStatement", &PEGTransformerFactory::TransformExpressionStatementInternal}, + {"ExpressionAlias", &PEGTransformerFactory::TransformExpressionAliasInternal}, + {"ConstraintName", &PEGTransformerFactory::TransformConstraintNameInternal}, + {"CollationName", &PEGTransformerFactory::TransformCollationNameInternal}, + {"Type", &PEGTransformerFactory::TransformTypeInternal}, + {"TypeVariations", &PEGTransformerFactory::TransformTypeVariationsInternal}, + {"SimpleType", 
&PEGTransformerFactory::TransformSimpleTypeInternal}, + {"CharacterSimpleType", &PEGTransformerFactory::TransformCharacterSimpleTypeInternal}, + {"QualifiedSimpleType", &PEGTransformerFactory::TransformQualifiedSimpleTypeInternal}, + {"CharacterType", &PEGTransformerFactory::TransformCharacterTypeInternal}, + {"IntervalType", &PEGTransformerFactory::TransformIntervalTypeInternal}, + {"IntervalInterval", &PEGTransformerFactory::TransformIntervalIntervalInternal}, + {"IntervalWithSpecifier", &PEGTransformerFactory::TransformIntervalWithSpecifierInternal}, + {"IntervalWithRangeSpecifier", &PEGTransformerFactory::TransformIntervalWithRangeSpecifierInternal}, + {"IntervalWithSimpleSpecifier", &PEGTransformerFactory::TransformIntervalWithSimpleSpecifierInternal}, + {"IntervalWithoutSpecifier", &PEGTransformerFactory::TransformIntervalWithoutSpecifierInternal}, + {"YearKeyword", &PEGTransformerFactory::TransformYearKeywordInternal}, + {"MonthKeyword", &PEGTransformerFactory::TransformMonthKeywordInternal}, + {"DayKeyword", &PEGTransformerFactory::TransformDayKeywordInternal}, + {"HourKeyword", &PEGTransformerFactory::TransformHourKeywordInternal}, + {"MinuteKeyword", &PEGTransformerFactory::TransformMinuteKeywordInternal}, + {"SecondKeyword", &PEGTransformerFactory::TransformSecondKeywordInternal}, + {"MillisecondKeyword", &PEGTransformerFactory::TransformMillisecondKeywordInternal}, + {"MicrosecondKeyword", &PEGTransformerFactory::TransformMicrosecondKeywordInternal}, + {"WeekKeyword", &PEGTransformerFactory::TransformWeekKeywordInternal}, + {"QuarterKeyword", &PEGTransformerFactory::TransformQuarterKeywordInternal}, + {"DecadeKeyword", &PEGTransformerFactory::TransformDecadeKeywordInternal}, + {"CenturyKeyword", &PEGTransformerFactory::TransformCenturyKeywordInternal}, + {"MillenniumKeyword", &PEGTransformerFactory::TransformMillenniumKeywordInternal}, + {"Interval", &PEGTransformerFactory::TransformIntervalInternal}, + {"IntervalToInterval", 
&PEGTransformerFactory::TransformIntervalToIntervalInternal}, + {"YearToMonth", &PEGTransformerFactory::TransformYearToMonthInternal}, + {"DayToHour", &PEGTransformerFactory::TransformDayToHourInternal}, + {"DayToMinute", &PEGTransformerFactory::TransformDayToMinuteInternal}, + {"DayToSecond", &PEGTransformerFactory::TransformDayToSecondInternal}, + {"HourToMinute", &PEGTransformerFactory::TransformHourToMinuteInternal}, + {"HourToSecond", &PEGTransformerFactory::TransformHourToSecondInternal}, + {"MinuteToSecond", &PEGTransformerFactory::TransformMinuteToSecondInternal}, + {"BitType", &PEGTransformerFactory::TransformBitTypeInternal}, + {"GeometryType", &PEGTransformerFactory::TransformGeometryTypeInternal}, + {"VariantType", &PEGTransformerFactory::TransformVariantTypeInternal}, + {"NumericType", &PEGTransformerFactory::TransformNumericTypeInternal}, + {"SimpleNumericType", &PEGTransformerFactory::TransformSimpleNumericTypeInternal}, + {"DecimalNumericType", &PEGTransformerFactory::TransformDecimalNumericTypeInternal}, + {"IntType", &PEGTransformerFactory::TransformIntTypeInternal}, + {"IntegerType", &PEGTransformerFactory::TransformIntegerTypeInternal}, + {"SmallintType", &PEGTransformerFactory::TransformSmallintTypeInternal}, + {"BigintType", &PEGTransformerFactory::TransformBigintTypeInternal}, + {"RealType", &PEGTransformerFactory::TransformRealTypeInternal}, + {"BooleanType", &PEGTransformerFactory::TransformBooleanTypeInternal}, + {"DoubleType", &PEGTransformerFactory::TransformDoubleTypeInternal}, + {"FloatType", &PEGTransformerFactory::TransformFloatTypeInternal}, + {"DecimalType", &PEGTransformerFactory::TransformDecimalTypeInternal}, + {"DecType", &PEGTransformerFactory::TransformDecTypeInternal}, + {"NumericModType", &PEGTransformerFactory::TransformNumericModTypeInternal}, + {"QualifiedTypeName", &PEGTransformerFactory::TransformQualifiedTypeNameInternal}, + {"TypeNameAsQualifiedName", &PEGTransformerFactory::TransformTypeNameAsQualifiedNameInternal}, 
+ {"CatalogReservedSchemaTypeName", &PEGTransformerFactory::TransformCatalogReservedSchemaTypeNameInternal}, + {"SchemaReservedTypeName", &PEGTransformerFactory::TransformSchemaReservedTypeNameInternal}, + {"TypeModifiers", &PEGTransformerFactory::TransformTypeModifiersInternal}, + {"RowType", &PEGTransformerFactory::TransformRowTypeInternal}, + {"SetofType", &PEGTransformerFactory::TransformSetofTypeInternal}, + {"UnionType", &PEGTransformerFactory::TransformUnionTypeInternal}, + {"ColIdTypeList", &PEGTransformerFactory::TransformColIdTypeListInternal}, + {"MapType", &PEGTransformerFactory::TransformMapTypeInternal}, + {"ColIdType", &PEGTransformerFactory::TransformColIdTypeInternal}, + {"ArrayBounds", &PEGTransformerFactory::TransformArrayBoundsInternal}, + {"ArrayKeyword", &PEGTransformerFactory::TransformArrayKeywordInternal}, + {"SquareBracketsArray", &PEGTransformerFactory::TransformSquareBracketsArrayInternal}, + {"TimeType", &PEGTransformerFactory::TransformTimeTypeInternal}, + {"TimeOrTimestamp", &PEGTransformerFactory::TransformTimeOrTimestampInternal}, + {"TimeTypeId", &PEGTransformerFactory::TransformTimeTypeIdInternal}, + {"TimestampTypeId", &PEGTransformerFactory::TransformTimestampTypeIdInternal}, + {"TimeZone", &PEGTransformerFactory::TransformTimeZoneInternal}, + {"WithOrWithout", &PEGTransformerFactory::TransformWithOrWithoutInternal}, + {"WithRule", &PEGTransformerFactory::TransformWithRuleInternal}, + {"WithoutRule", &PEGTransformerFactory::TransformWithoutRuleInternal}, + {"DisconnectStatement", &PEGTransformerFactory::TransformDisconnectStatementInternal}, + {"CopyStatement", &PEGTransformerFactory::TransformCopyStatementInternal}, + {"CopyVariations", &PEGTransformerFactory::TransformCopyVariationsInternal}, + {"CopyTable", &PEGTransformerFactory::TransformCopyTableInternal}, + {"FromOrTo", &PEGTransformerFactory::TransformFromOrToInternal}, + {"CopyFrom", &PEGTransformerFactory::TransformCopyFromInternal}, + {"CopyTo", 
&PEGTransformerFactory::TransformCopyToInternal}, + {"CopySelect", &PEGTransformerFactory::TransformCopySelectInternal}, + {"CopyFileName", &PEGTransformerFactory::TransformCopyFileNameInternal}, + {"CopyFileNameExpression", &PEGTransformerFactory::TransformCopyFileNameExpressionInternal}, + {"CopyFileNameStringLiteral", &PEGTransformerFactory::TransformCopyFileNameStringLiteralInternal}, + {"CopyFileNameIdentifier", &PEGTransformerFactory::TransformCopyFileNameIdentifierInternal}, + {"CopyFileNameIdentifierColId", &PEGTransformerFactory::TransformCopyFileNameIdentifierColIdInternal}, + {"IdentifierColId", &PEGTransformerFactory::TransformIdentifierColIdInternal}, + {"CopyOptions", &PEGTransformerFactory::TransformCopyOptionsInternal}, + {"CopyOptionList", &PEGTransformerFactory::TransformCopyOptionListInternal}, + {"SpecializedOptionList", &PEGTransformerFactory::TransformSpecializedOptionListInternal}, + {"SpecializedOption", &PEGTransformerFactory::TransformSpecializedOptionInternal}, + {"SingleOption", &PEGTransformerFactory::TransformSingleOptionInternal}, + {"BinaryOption", &PEGTransformerFactory::TransformBinaryOptionInternal}, + {"FreezeOption", &PEGTransformerFactory::TransformFreezeOptionInternal}, + {"OidsOption", &PEGTransformerFactory::TransformOidsOptionInternal}, + {"CsvOption", &PEGTransformerFactory::TransformCsvOptionInternal}, + {"HeaderOption", &PEGTransformerFactory::TransformHeaderOptionInternal}, + {"NullAsOption", &PEGTransformerFactory::TransformNullAsOptionInternal}, + {"DelimiterAsOption", &PEGTransformerFactory::TransformDelimiterAsOptionInternal}, + {"QuoteAsOption", &PEGTransformerFactory::TransformQuoteAsOptionInternal}, + {"EscapeAsOption", &PEGTransformerFactory::TransformEscapeAsOptionInternal}, + {"EncodingOption", &PEGTransformerFactory::TransformEncodingOptionInternal}, + {"ForceQuoteOption", &PEGTransformerFactory::TransformForceQuoteOptionInternal}, + {"StarSymbolColumnList", 
&PEGTransformerFactory::TransformStarSymbolColumnListInternal}, + {"ForceQuote", &PEGTransformerFactory::TransformForceQuoteInternal}, + {"PartitionByOption", &PEGTransformerFactory::TransformPartitionByOptionInternal}, + {"ForceNullOption", &PEGTransformerFactory::TransformForceNullOptionInternal}, + {"ForceNotNull", &PEGTransformerFactory::TransformForceNotNullInternal}, + {"GenericCopyOptionList", &PEGTransformerFactory::TransformGenericCopyOptionListInternal}, + {"GenericCopyOption", &PEGTransformerFactory::TransformGenericCopyOptionInternal}, + {"GenericCopyOptionValue", &PEGTransformerFactory::TransformGenericCopyOptionValueInternal}, + {"GenericCopyOptionOrderList", &PEGTransformerFactory::TransformGenericCopyOptionOrderListInternal}, + {"GenericCopyOptionExpression", &PEGTransformerFactory::TransformGenericCopyOptionExpressionInternal}, + {"GenericCopyOptionParenthesizedExpressionList", + &PEGTransformerFactory::TransformGenericCopyOptionParenthesizedExpressionListInternal}, + {"CopyFromDatabase", &PEGTransformerFactory::TransformCopyFromDatabaseInternal}, + {"CopyFromDatabaseWithFlag", &PEGTransformerFactory::TransformCopyFromDatabaseWithFlagInternal}, + {"CopyFromDatabaseWithoutFlag", &PEGTransformerFactory::TransformCopyFromDatabaseWithoutFlagInternal}, + {"CopyDatabaseFlag", &PEGTransformerFactory::TransformCopyDatabaseFlagInternal}, + {"SchemaOrData", &PEGTransformerFactory::TransformSchemaOrDataInternal}, + {"CopySchema", &PEGTransformerFactory::TransformCopySchemaInternal}, + {"CopyData", &PEGTransformerFactory::TransformCopyDataInternal}, + {"CreateIndexStmt", &PEGTransformerFactory::TransformCreateIndexStmtInternal}, + {"WithList", &PEGTransformerFactory::TransformWithListInternal}, + {"RelOptionOrOids", &PEGTransformerFactory::TransformRelOptionOrOidsInternal}, + {"RelOptionList", &PEGTransformerFactory::TransformRelOptionListInternal}, + {"Oids", &PEGTransformerFactory::TransformOidsInternal}, + {"WithOrWithoutOids", 
&PEGTransformerFactory::TransformWithOrWithoutOidsInternal}, + {"WithOids", &PEGTransformerFactory::TransformWithOidsInternal}, + {"WithoutOids", &PEGTransformerFactory::TransformWithoutOidsInternal}, + {"IndexElement", &PEGTransformerFactory::TransformIndexElementInternal}, + {"UniqueIndex", &PEGTransformerFactory::TransformUniqueIndexInternal}, + {"IndexType", &PEGTransformerFactory::TransformIndexTypeInternal}, + {"RelOption", &PEGTransformerFactory::TransformRelOptionInternal}, + {"RelOptionName", &PEGTransformerFactory::TransformRelOptionNameInternal}, + {"DottedIdentifierString", &PEGTransformerFactory::TransformDottedIdentifierStringInternal}, + {"RelOptionArgumentOpt", &PEGTransformerFactory::TransformRelOptionArgumentOptInternal}, + {"DefArg", &PEGTransformerFactory::TransformDefArgInternal}, + {"DefArgNull", &PEGTransformerFactory::TransformDefArgNullInternal}, + {"DefArgKeyword", &PEGTransformerFactory::TransformDefArgKeywordInternal}, + {"DefArgStringLiteral", &PEGTransformerFactory::TransformDefArgStringLiteralInternal}, + {"NoneLiteral", &PEGTransformerFactory::TransformNoneLiteralInternal}, + {"CreateMacroStmt", &PEGTransformerFactory::TransformCreateMacroStmtInternal}, + {"MacroOrFunction", &PEGTransformerFactory::TransformMacroOrFunctionInternal}, + {"MacroKeyword", &PEGTransformerFactory::TransformMacroKeywordInternal}, + {"FunctionKeyword", &PEGTransformerFactory::TransformFunctionKeywordInternal}, + {"MacroDefinition", &PEGTransformerFactory::TransformMacroDefinitionInternal}, + {"MacroDefinitionBody", &PEGTransformerFactory::TransformMacroDefinitionBodyInternal}, + {"MacroParameters", &PEGTransformerFactory::TransformMacroParametersInternal}, + {"MacroParameter", &PEGTransformerFactory::TransformMacroParameterInternal}, + {"SimpleParameter", &PEGTransformerFactory::TransformSimpleParameterInternal}, + {"ScalarMacroDefinition", &PEGTransformerFactory::TransformScalarMacroDefinitionInternal}, + {"TableMacroDefinition", 
&PEGTransformerFactory::TransformTableMacroDefinitionInternal}, + {"CreateSchemaStmt", &PEGTransformerFactory::TransformCreateSchemaStmtInternal}, + {"CreateSecretStmt", &PEGTransformerFactory::TransformCreateSecretStmtInternal}, + {"SecretStorageSpecifier", &PEGTransformerFactory::TransformSecretStorageSpecifierInternal}, + {"SecretName", &PEGTransformerFactory::TransformSecretNameInternal}, + {"CreateSequenceStmt", &PEGTransformerFactory::TransformCreateSequenceStmtInternal}, + {"SequenceOption", &PEGTransformerFactory::TransformSequenceOptionInternal}, + {"SeqSetCycle", &PEGTransformerFactory::TransformSeqSetCycleInternal}, + {"SeqCycle", &PEGTransformerFactory::TransformSeqCycleInternal}, + {"SeqNoCycle", &PEGTransformerFactory::TransformSeqNoCycleInternal}, + {"SeqSetIncrement", &PEGTransformerFactory::TransformSeqSetIncrementInternal}, + {"SeqSetMinMax", &PEGTransformerFactory::TransformSeqSetMinMaxInternal}, + {"SeqNoMinMax", &PEGTransformerFactory::TransformSeqNoMinMaxInternal}, + {"SeqStartWith", &PEGTransformerFactory::TransformSeqStartWithInternal}, + {"SeqOwnedBy", &PEGTransformerFactory::TransformSeqOwnedByInternal}, + {"SeqMinOrMax", &PEGTransformerFactory::TransformSeqMinOrMaxInternal}, + {"MinValue", &PEGTransformerFactory::TransformMinValueInternal}, + {"MaxValue", &PEGTransformerFactory::TransformMaxValueInternal}, + {"CreateStatement", &PEGTransformerFactory::TransformCreateStatementInternal}, + {"CreateStatementVariation", &PEGTransformerFactory::TransformCreateStatementVariationInternal}, + {"OrReplace", &PEGTransformerFactory::TransformOrReplaceInternal}, + {"Temporary", &PEGTransformerFactory::TransformTemporaryInternal}, + {"Persistent", &PEGTransformerFactory::TransformPersistentInternal}, + {"TempPersistent", &PEGTransformerFactory::TransformTempPersistentInternal}, + {"TemporaryPersistent", &PEGTransformerFactory::TransformTemporaryPersistentInternal}, + {"CreateTableStmt", &PEGTransformerFactory::TransformCreateTableStmtInternal}, + 
{"CreateTableDefinition", &PEGTransformerFactory::TransformCreateTableDefinitionInternal}, + {"CreateTableAs", &PEGTransformerFactory::TransformCreateTableAsInternal}, + {"PartitionSortedOptions", &PEGTransformerFactory::TransformPartitionSortedOptionsInternal}, + {"PartitionOptSortedOptions", &PEGTransformerFactory::TransformPartitionOptSortedOptionsInternal}, + {"SortedOptPartitionOptions", &PEGTransformerFactory::TransformSortedOptPartitionOptionsInternal}, + {"PartitionOptions", &PEGTransformerFactory::TransformPartitionOptionsInternal}, + {"SortedOptions", &PEGTransformerFactory::TransformSortedOptionsInternal}, + {"WithData", &PEGTransformerFactory::TransformWithDataInternal}, + {"WithDataOnly", &PEGTransformerFactory::TransformWithDataOnlyInternal}, + {"WithNoData", &PEGTransformerFactory::TransformWithNoDataInternal}, + {"IdentifierList", &PEGTransformerFactory::TransformIdentifierListInternal}, + {"CreateColumnList", &PEGTransformerFactory::TransformCreateColumnListInternal}, + {"IfNotExists", &PEGTransformerFactory::TransformIfNotExistsInternal}, + {"QualifiedName", &PEGTransformerFactory::TransformQualifiedNameInternal}, + {"SchemaReservedIdentifierOrStringLiteral", + &PEGTransformerFactory::TransformSchemaReservedIdentifierOrStringLiteralInternal}, + {"CatalogReservedSchemaIdentifier", &PEGTransformerFactory::TransformCatalogReservedSchemaIdentifierInternal}, + {"IdentifierOrStringLiteral", &PEGTransformerFactory::TransformIdentifierOrStringLiteralInternal}, + {"ReservedIdentifierOrStringLiteral", + &PEGTransformerFactory::TransformReservedIdentifierOrStringLiteralInternal}, + {"CatalogQualification", &PEGTransformerFactory::TransformCatalogQualificationInternal}, + {"SchemaQualification", &PEGTransformerFactory::TransformSchemaQualificationInternal}, + {"ReservedSchemaQualification", &PEGTransformerFactory::TransformReservedSchemaQualificationInternal}, + {"TableQualification", &PEGTransformerFactory::TransformTableQualificationInternal}, + 
{"ReservedTableQualification", &PEGTransformerFactory::TransformReservedTableQualificationInternal}, + {"CreateTableColumnList", &PEGTransformerFactory::TransformCreateTableColumnListInternal}, + {"CreateTableColumnElement", &PEGTransformerFactory::TransformCreateTableColumnElementInternal}, + {"CreateTableColumnDefinition", &PEGTransformerFactory::TransformCreateTableColumnDefinitionInternal}, + {"CreateTableConstraint", &PEGTransformerFactory::TransformCreateTableConstraintInternal}, + {"ColumnDefinition", &PEGTransformerFactory::TransformColumnDefinitionInternal}, + {"ColumnConstraint", &PEGTransformerFactory::TransformColumnConstraintInternal}, + {"NotNullConstraint", &PEGTransformerFactory::TransformNotNullConstraintInternal}, + {"NullConstraint", &PEGTransformerFactory::TransformNullConstraintInternal}, + {"NotNullColumnConstraint", &PEGTransformerFactory::TransformNotNullColumnConstraintInternal}, + {"UniqueConstraint", &PEGTransformerFactory::TransformUniqueConstraintInternal}, + {"PrimaryKeyConstraint", &PEGTransformerFactory::TransformPrimaryKeyConstraintInternal}, + {"DefaultValue", &PEGTransformerFactory::TransformDefaultValueInternal}, + {"CheckConstraint", &PEGTransformerFactory::TransformCheckConstraintInternal}, + {"ForeignKeyConstraint", &PEGTransformerFactory::TransformForeignKeyConstraintInternal}, + {"ColumnCollation", &PEGTransformerFactory::TransformColumnCollationInternal}, + {"ColumnCompression", &PEGTransformerFactory::TransformColumnCompressionInternal}, + {"KeyActions", &PEGTransformerFactory::TransformKeyActionsInternal}, + {"UpdateAction", &PEGTransformerFactory::TransformUpdateActionInternal}, + {"DeleteAction", &PEGTransformerFactory::TransformDeleteActionInternal}, + {"KeyAction", &PEGTransformerFactory::TransformKeyActionInternal}, + {"NoKeyAction", &PEGTransformerFactory::TransformNoKeyActionInternal}, + {"RestrictKeyAction", &PEGTransformerFactory::TransformRestrictKeyActionInternal}, + {"CascadeKeyAction", 
&PEGTransformerFactory::TransformCascadeKeyActionInternal}, + {"SetNullKeyAction", &PEGTransformerFactory::TransformSetNullKeyActionInternal}, + {"SetDefaultKeyAction", &PEGTransformerFactory::TransformSetDefaultKeyActionInternal}, + {"TopLevelConstraint", &PEGTransformerFactory::TransformTopLevelConstraintInternal}, + {"TopLevelConstraintList", &PEGTransformerFactory::TransformTopLevelConstraintListInternal}, + {"TopPrimaryKeyConstraint", &PEGTransformerFactory::TransformTopPrimaryKeyConstraintInternal}, + {"TopUniqueConstraint", &PEGTransformerFactory::TransformTopUniqueConstraintInternal}, + {"TopForeignKeyConstraint", &PEGTransformerFactory::TransformTopForeignKeyConstraintInternal}, + {"ColumnIdList", &PEGTransformerFactory::TransformColumnIdListInternal}, + {"DottedIdentifier", &PEGTransformerFactory::TransformDottedIdentifierInternal}, + {"DotColLabel", &PEGTransformerFactory::TransformDotColLabelInternal}, + {"ColId", &PEGTransformerFactory::TransformColIdInternal}, + {"ColIdOrString", &PEGTransformerFactory::TransformColIdOrStringInternal}, + {"TypeFuncName", &PEGTransformerFactory::TransformTypeFuncNameInternal}, + {"GeneratedColumn", &PEGTransformerFactory::TransformGeneratedColumnInternal}, + {"GeneratedColumnType", &PEGTransformerFactory::TransformGeneratedColumnTypeInternal}, + {"CommitAction", &PEGTransformerFactory::TransformCommitActionInternal}, + {"PreserveOrDelete", &PEGTransformerFactory::TransformPreserveOrDeleteInternal}, + {"PreserveRows", &PEGTransformerFactory::TransformPreserveRowsInternal}, + {"DeleteRows", &PEGTransformerFactory::TransformDeleteRowsInternal}, + {"VirtualGeneratedColumn", &PEGTransformerFactory::TransformVirtualGeneratedColumnInternal}, + {"StoredGeneratedColumn", &PEGTransformerFactory::TransformStoredGeneratedColumnInternal}, + {"CreateTriggerStmt", &PEGTransformerFactory::TransformCreateTriggerStmtInternal}, + {"TriggerBody", &PEGTransformerFactory::TransformTriggerBodyInternal}, + {"TriggerName", 
&PEGTransformerFactory::TransformTriggerNameInternal}, + {"ReferencingClause", &PEGTransformerFactory::TransformReferencingClauseInternal}, + {"ReferencingItem", &PEGTransformerFactory::TransformReferencingItemInternal}, + {"ReferencingNewTableAs", &PEGTransformerFactory::TransformReferencingNewTableAsInternal}, + {"ReferencingOldTableAs", &PEGTransformerFactory::TransformReferencingOldTableAsInternal}, + {"TriggerTiming", &PEGTransformerFactory::TransformTriggerTimingInternal}, + {"TriggerBefore", &PEGTransformerFactory::TransformTriggerBeforeInternal}, + {"TriggerAfter", &PEGTransformerFactory::TransformTriggerAfterInternal}, + {"TriggerInsteadOf", &PEGTransformerFactory::TransformTriggerInsteadOfInternal}, + {"TriggerEvent", &PEGTransformerFactory::TransformTriggerEventInternal}, + {"TriggerEventInsert", &PEGTransformerFactory::TransformTriggerEventInsertInternal}, + {"TriggerEventDelete", &PEGTransformerFactory::TransformTriggerEventDeleteInternal}, + {"TriggerEventUpdate", &PEGTransformerFactory::TransformTriggerEventUpdateInternal}, + {"TriggerEventUpdateOf", &PEGTransformerFactory::TransformTriggerEventUpdateOfInternal}, + {"TriggerColumnList", &PEGTransformerFactory::TransformTriggerColumnListInternal}, + {"ForEachClause", &PEGTransformerFactory::TransformForEachClauseInternal}, + {"ForEachRow", &PEGTransformerFactory::TransformForEachRowInternal}, + {"ForEachStatement", &PEGTransformerFactory::TransformForEachStatementInternal}, + {"CreateTypeStmt", &PEGTransformerFactory::TransformCreateTypeStmtInternal}, + {"CreateType", &PEGTransformerFactory::TransformCreateTypeInternal}, + {"CreateTypeFromType", &PEGTransformerFactory::TransformCreateTypeFromTypeInternal}, + {"EnumSelectType", &PEGTransformerFactory::TransformEnumSelectTypeInternal}, + {"EnumStringLiteralList", &PEGTransformerFactory::TransformEnumStringLiteralListInternal}, + {"CreateViewStmt", &PEGTransformerFactory::TransformCreateViewStmtInternal}, + {"CreateRecursive", 
&PEGTransformerFactory::TransformCreateRecursiveInternal}, + {"DeallocateStatement", &PEGTransformerFactory::TransformDeallocateStatementInternal}, + {"DeallocatePrepare", &PEGTransformerFactory::TransformDeallocatePrepareInternal}, + {"DeleteStatement", &PEGTransformerFactory::TransformDeleteStatementInternal}, + {"TruncateStatement", &PEGTransformerFactory::TransformTruncateStatementInternal}, + {"TargetOptAlias", &PEGTransformerFactory::TransformTargetOptAliasInternal}, + {"DeleteUsingClause", &PEGTransformerFactory::TransformDeleteUsingClauseInternal}, + {"DescribeStatement", &PEGTransformerFactory::TransformDescribeStatementInternal}, + {"ShowSelect", &PEGTransformerFactory::TransformShowSelectInternal}, + {"ShowAllTables", &PEGTransformerFactory::TransformShowAllTablesInternal}, + {"ShowQualifiedName", &PEGTransformerFactory::TransformShowQualifiedNameInternal}, + {"ShowTables", &PEGTransformerFactory::TransformShowTablesInternal}, + {"DescribeTarget", &PEGTransformerFactory::TransformDescribeTargetInternal}, + {"DescribeBaseTableName", &PEGTransformerFactory::TransformDescribeBaseTableNameInternal}, + {"DescribeStringLiteral", &PEGTransformerFactory::TransformDescribeStringLiteralInternal}, + {"ShowOrDescribeOrSummarize", &PEGTransformerFactory::TransformShowOrDescribeOrSummarizeInternal}, + {"Summarize", &PEGTransformerFactory::TransformSummarizeInternal}, + {"SummarizeRule", &PEGTransformerFactory::TransformSummarizeRuleInternal}, + {"ShowOrDescribe", &PEGTransformerFactory::TransformShowOrDescribeInternal}, + {"ShowRule", &PEGTransformerFactory::TransformShowRuleInternal}, + {"DescribeRule", &PEGTransformerFactory::TransformDescribeRuleInternal}, + {"DescribeLongRule", &PEGTransformerFactory::TransformDescribeLongRuleInternal}, + {"DescRule", &PEGTransformerFactory::TransformDescRuleInternal}, + {"DetachStatement", &PEGTransformerFactory::TransformDetachStatementInternal}, + {"DropStatement", &PEGTransformerFactory::TransformDropStatementInternal}, + 
{"DropEntries", &PEGTransformerFactory::TransformDropEntriesInternal}, + {"DropTrigger", &PEGTransformerFactory::TransformDropTriggerInternal}, + {"DropTable", &PEGTransformerFactory::TransformDropTableInternal}, + {"DropTableFunction", &PEGTransformerFactory::TransformDropTableFunctionInternal}, + {"DropFunction", &PEGTransformerFactory::TransformDropFunctionInternal}, + {"DropSchema", &PEGTransformerFactory::TransformDropSchemaInternal}, + {"DropIndex", &PEGTransformerFactory::TransformDropIndexInternal}, + {"QualifiedIndexName", &PEGTransformerFactory::TransformQualifiedIndexNameInternal}, + {"QualifiedIndexNameString", &PEGTransformerFactory::TransformQualifiedIndexNameStringInternal}, + {"SchemaReservedIndex", &PEGTransformerFactory::TransformSchemaReservedIndexInternal}, + {"CatalogReservedSchemaIndex", &PEGTransformerFactory::TransformCatalogReservedSchemaIndexInternal}, + {"DropSequence", &PEGTransformerFactory::TransformDropSequenceInternal}, + {"DropCollation", &PEGTransformerFactory::TransformDropCollationInternal}, + {"DropType", &PEGTransformerFactory::TransformDropTypeInternal}, + {"DropSecret", &PEGTransformerFactory::TransformDropSecretInternal}, + {"TableOrView", &PEGTransformerFactory::TransformTableOrViewInternal}, + {"MaterializedViewEntry", &PEGTransformerFactory::TransformMaterializedViewEntryInternal}, + {"FunctionTypeMacro", &PEGTransformerFactory::TransformFunctionTypeMacroInternal}, + {"FunctionTypeMacroKeyword", &PEGTransformerFactory::TransformFunctionTypeMacroKeywordInternal}, + {"FunctionTypeFunction", &PEGTransformerFactory::TransformFunctionTypeFunctionInternal}, + {"DropBehavior", &PEGTransformerFactory::TransformDropBehaviorInternal}, + {"CascadeDropBehavior", &PEGTransformerFactory::TransformCascadeDropBehaviorInternal}, + {"RestrictDropBehavior", &PEGTransformerFactory::TransformRestrictDropBehaviorInternal}, + {"IfExists", &PEGTransformerFactory::TransformIfExistsInternal}, + {"QualifiedSchemaName", 
&PEGTransformerFactory::TransformQualifiedSchemaNameInternal}, + {"QualifiedSchemaNameString", &PEGTransformerFactory::TransformQualifiedSchemaNameStringInternal}, + {"CatalogReservedSchema", &PEGTransformerFactory::TransformCatalogReservedSchemaInternal}, + {"DropSecretStorage", &PEGTransformerFactory::TransformDropSecretStorageInternal}, + {"ExecuteStatement", &PEGTransformerFactory::TransformExecuteStatementInternal}, + {"ExplainStatement", &PEGTransformerFactory::TransformExplainStatementInternal}, + {"ExplainAnalyze", &PEGTransformerFactory::TransformExplainAnalyzeInternal}, + {"ExplainOptionList", &PEGTransformerFactory::TransformExplainOptionListInternal}, + {"ExplainOption", &PEGTransformerFactory::TransformExplainOptionInternal}, + {"ExplainSelectStatement", &PEGTransformerFactory::TransformExplainSelectStatementInternal}, + {"ExplainableStatements", &PEGTransformerFactory::TransformExplainableStatementsInternal}, + {"ExportStatement", &PEGTransformerFactory::TransformExportStatementInternal}, + {"ExportSource", &PEGTransformerFactory::TransformExportSourceInternal}, + {"ImportStatement", &PEGTransformerFactory::TransformImportStatementInternal}, + {"InsertStatement", &PEGTransformerFactory::TransformInsertStatementInternal}, + {"OrAction", &PEGTransformerFactory::TransformOrActionInternal}, + {"InsertOrReplace", &PEGTransformerFactory::TransformInsertOrReplaceInternal}, + {"InsertOrIgnore", &PEGTransformerFactory::TransformInsertOrIgnoreInternal}, + {"ByNameOrPosition", &PEGTransformerFactory::TransformByNameOrPositionInternal}, + {"InsertByNameOrder", &PEGTransformerFactory::TransformInsertByNameOrderInternal}, + {"InsertByPositionOrder", &PEGTransformerFactory::TransformInsertByPositionOrderInternal}, + {"InsertByName", &PEGTransformerFactory::TransformInsertByNameInternal}, + {"InsertByPosition", &PEGTransformerFactory::TransformInsertByPositionInternal}, + {"InsertTarget", &PEGTransformerFactory::TransformInsertTargetInternal}, + {"InsertAlias", 
&PEGTransformerFactory::TransformInsertAliasInternal}, + {"ColumnList", &PEGTransformerFactory::TransformColumnListInternal}, + {"InsertColumnList", &PEGTransformerFactory::TransformInsertColumnListInternal}, + {"InsertValues", &PEGTransformerFactory::TransformInsertValuesInternal}, + {"SelectInsertValues", &PEGTransformerFactory::TransformSelectInsertValuesInternal}, + {"DefaultValues", &PEGTransformerFactory::TransformDefaultValuesInternal}, + {"OnConflictClause", &PEGTransformerFactory::TransformOnConflictClauseInternal}, + {"OnConflictTarget", &PEGTransformerFactory::TransformOnConflictTargetInternal}, + {"OnConflictExpressionTarget", &PEGTransformerFactory::TransformOnConflictExpressionTargetInternal}, + {"OnConflictIndexTarget", &PEGTransformerFactory::TransformOnConflictIndexTargetInternal}, + {"OnConflictAction", &PEGTransformerFactory::TransformOnConflictActionInternal}, + {"OnConflictUpdate", &PEGTransformerFactory::TransformOnConflictUpdateInternal}, + {"OnConflictNothing", &PEGTransformerFactory::TransformOnConflictNothingInternal}, + {"ReturningClause", &PEGTransformerFactory::TransformReturningClauseInternal}, + {"LoadStatement", &PEGTransformerFactory::TransformLoadStatementInternal}, + {"ExtensionAlias", &PEGTransformerFactory::TransformExtensionAliasInternal}, + {"InstallStatement", &PEGTransformerFactory::TransformInstallStatementInternal}, + {"UpdateExtensionsStatement", &PEGTransformerFactory::TransformUpdateExtensionsStatementInternal}, + {"FromSource", &PEGTransformerFactory::TransformFromSourceInternal}, + {"FromSourceIdentifier", &PEGTransformerFactory::TransformFromSourceIdentifierInternal}, + {"FromSourceString", &PEGTransformerFactory::TransformFromSourceStringInternal}, + {"VersionNumber", &PEGTransformerFactory::TransformVersionNumberInternal}, + {"MergeIntoStatement", &PEGTransformerFactory::TransformMergeIntoStatementInternal}, + {"MergeIntoUsingClause", &PEGTransformerFactory::TransformMergeIntoUsingClauseInternal}, + {"MergeMatch", 
&PEGTransformerFactory::TransformMergeMatchInternal}, + {"MatchedClause", &PEGTransformerFactory::TransformMatchedClauseInternal}, + {"MatchedClauseAction", &PEGTransformerFactory::TransformMatchedClauseActionInternal}, + {"UpdateMatchClause", &PEGTransformerFactory::TransformUpdateMatchClauseInternal}, + {"UpdateMatchInfo", &PEGTransformerFactory::TransformUpdateMatchInfoInternal}, + {"UpdateMatchSetAction", &PEGTransformerFactory::TransformUpdateMatchSetActionInternal}, + {"UpdateByNameOrPosition", &PEGTransformerFactory::TransformUpdateByNameOrPositionInternal}, + {"DeleteMatchClause", &PEGTransformerFactory::TransformDeleteMatchClauseInternal}, + {"InsertMatchClause", &PEGTransformerFactory::TransformInsertMatchClauseInternal}, + {"InsertMatchInfo", &PEGTransformerFactory::TransformInsertMatchInfoInternal}, + {"InsertDefaultValues", &PEGTransformerFactory::TransformInsertDefaultValuesInternal}, + {"InsertByNameOrPosition", &PEGTransformerFactory::TransformInsertByNameOrPositionInternal}, + {"InsertValuesList", &PEGTransformerFactory::TransformInsertValuesListInternal}, + {"DoNothingMatchClause", &PEGTransformerFactory::TransformDoNothingMatchClauseInternal}, + {"ErrorMatchClause", &PEGTransformerFactory::TransformErrorMatchClauseInternal}, + {"UpdateMatchSetClause", &PEGTransformerFactory::TransformUpdateMatchSetClauseInternal}, + {"UpdateMatchSetInfo", &PEGTransformerFactory::TransformUpdateMatchSetInfoInternal}, + {"AndExpression", &PEGTransformerFactory::TransformAndExpressionInternal}, + {"NotMatchedClause", &PEGTransformerFactory::TransformNotMatchedClauseInternal}, + {"BySourceOrTarget", &PEGTransformerFactory::TransformBySourceOrTargetInternal}, + {"BySource", &PEGTransformerFactory::TransformBySourceInternal}, + {"ByTarget", &PEGTransformerFactory::TransformByTargetInternal}, + {"PivotOn", &PEGTransformerFactory::TransformPivotOnInternal}, + {"PivotUsing", &PEGTransformerFactory::TransformPivotUsingInternal}, + {"PivotColumnList", 
&PEGTransformerFactory::TransformPivotColumnListInternal}, + {"PivotColumnEntry", &PEGTransformerFactory::TransformPivotColumnEntryInternal}, + {"PivotColumnExpression", &PEGTransformerFactory::TransformPivotColumnExpressionInternal}, + {"PivotColumnSubquery", &PEGTransformerFactory::TransformPivotColumnSubqueryInternal}, + {"IntoNameValues", &PEGTransformerFactory::TransformIntoNameValuesInternal}, + {"IncludeOrExcludeNulls", &PEGTransformerFactory::TransformIncludeOrExcludeNullsInternal}, + {"IncludeNulls", &PEGTransformerFactory::TransformIncludeNullsInternal}, + {"ExcludeNulls", &PEGTransformerFactory::TransformExcludeNullsInternal}, + {"UnpivotHeader", &PEGTransformerFactory::TransformUnpivotHeaderInternal}, + {"UnpivotHeaderSingle", &PEGTransformerFactory::TransformUnpivotHeaderSingleInternal}, + {"UnpivotHeaderList", &PEGTransformerFactory::TransformUnpivotHeaderListInternal}, + {"PragmaStatement", &PEGTransformerFactory::TransformPragmaStatementInternal}, + {"PragmaAssignOrFunction", &PEGTransformerFactory::TransformPragmaAssignOrFunctionInternal}, + {"PragmaAssign", &PEGTransformerFactory::TransformPragmaAssignInternal}, + {"PragmaFunction", &PEGTransformerFactory::TransformPragmaFunctionInternal}, + {"PragmaParameters", &PEGTransformerFactory::TransformPragmaParametersInternal}, + {"PrepareStatement", &PEGTransformerFactory::TransformPrepareStatementInternal}, + {"TypeList", &PEGTransformerFactory::TransformTypeListInternal}, + {"SetStatement", &PEGTransformerFactory::TransformSetStatementInternal}, + {"SetAssignmentOrTimeZone", &PEGTransformerFactory::TransformSetAssignmentOrTimeZoneInternal}, + {"ResetStatement", &PEGTransformerFactory::TransformResetStatementInternal}, + {"StandardAssignment", &PEGTransformerFactory::TransformStandardAssignmentInternal}, + {"SetVariableOrSetting", &PEGTransformerFactory::TransformSetVariableOrSettingInternal}, + {"SetTimeZone", &PEGTransformerFactory::TransformSetTimeZoneInternal}, + {"ZoneValue", 
&PEGTransformerFactory::TransformZoneValueInternal}, + {"ZoneLocal", &PEGTransformerFactory::TransformZoneLocalInternal}, + {"ZoneDefault", &PEGTransformerFactory::TransformZoneDefaultInternal}, + {"ZoneStringLiteral", &PEGTransformerFactory::TransformZoneStringLiteralInternal}, + {"ZoneIdentifier", &PEGTransformerFactory::TransformZoneIdentifierInternal}, + {"ZoneIntervalWithInterval", &PEGTransformerFactory::TransformZoneIntervalWithIntervalInternal}, + {"ZoneIntervalWithPrecision", &PEGTransformerFactory::TransformZoneIntervalWithPrecisionInternal}, + {"SetSetting", &PEGTransformerFactory::TransformSetSettingInternal}, + {"SetVariable", &PEGTransformerFactory::TransformSetVariableInternal}, + {"VariableScope", &PEGTransformerFactory::TransformVariableScopeInternal}, + {"SettingScope", &PEGTransformerFactory::TransformSettingScopeInternal}, + {"LocalScope", &PEGTransformerFactory::TransformLocalScopeInternal}, + {"SessionScope", &PEGTransformerFactory::TransformSessionScopeInternal}, + {"GlobalScope", &PEGTransformerFactory::TransformGlobalScopeInternal}, + {"SetAssignment", &PEGTransformerFactory::TransformSetAssignmentInternal}, + {"VariableList", &PEGTransformerFactory::TransformVariableListInternal}, {"TransactionStatement", &PEGTransformerFactory::TransformTransactionStatementInternal}, {"BeginTransaction", &PEGTransformerFactory::TransformBeginTransactionInternal}, {"RollbackTransaction", &PEGTransformerFactory::TransformRollbackTransactionInternal}, {"CommitTransaction", &PEGTransformerFactory::TransformCommitTransactionInternal}, {"ReadOrWrite", &PEGTransformerFactory::TransformReadOrWriteInternal}, - {"DetachStatement", &PEGTransformerFactory::TransformDetachStatementInternal}, - {"ExportStatement", &PEGTransformerFactory::TransformExportStatementInternal}, - {"ExportSource", &PEGTransformerFactory::TransformExportSourceInternal}, - {"ImportStatement", &PEGTransformerFactory::TransformImportStatementInternal}, + {"ReadOnlyOrReadWrite", 
&PEGTransformerFactory::TransformReadOnlyOrReadWriteInternal}, + {"ReadOnly", &PEGTransformerFactory::TransformReadOnlyInternal}, + {"ReadWrite", &PEGTransformerFactory::TransformReadWriteInternal}, + {"UpdateStatement", &PEGTransformerFactory::TransformUpdateStatementInternal}, + {"UpdateTarget", &PEGTransformerFactory::TransformUpdateTargetInternal}, + {"BaseTableSet", &PEGTransformerFactory::TransformBaseTableSetInternal}, + {"BaseTableAliasSet", &PEGTransformerFactory::TransformBaseTableAliasSetInternal}, + {"UpdateAlias", &PEGTransformerFactory::TransformUpdateAliasInternal}, + {"UpdateSetClause", &PEGTransformerFactory::TransformUpdateSetClauseInternal}, + {"UpdateSetTuple", &PEGTransformerFactory::TransformUpdateSetTupleInternal}, + {"UpdateSetElementList", &PEGTransformerFactory::TransformUpdateSetElementListInternal}, + {"UpdateSetElement", &PEGTransformerFactory::TransformUpdateSetElementInternal}, + {"UpdateSetColumnTarget", &PEGTransformerFactory::TransformUpdateSetColumnTargetInternal}, + {"UseStatement", &PEGTransformerFactory::TransformUseStatementInternal}, + {"UseTarget", &PEGTransformerFactory::TransformUseTargetInternal}, + {"SchemaNameAsUseTarget", &PEGTransformerFactory::TransformSchemaNameAsUseTargetInternal}, + {"CatalogNameAsUseTarget", &PEGTransformerFactory::TransformCatalogNameAsUseTargetInternal}, + {"UseTargetCatalogSchema", &PEGTransformerFactory::TransformUseTargetCatalogSchemaInternal}, + {"DotIdentifier", &PEGTransformerFactory::TransformDotIdentifierInternal}, + {"VacuumStatement", &PEGTransformerFactory::TransformVacuumStatementInternal}, + {"VacuumOptions", &PEGTransformerFactory::TransformVacuumOptionsInternal}, + {"VacuumParensOptions", &PEGTransformerFactory::TransformVacuumParensOptionsInternal}, + {"VacuumLegacyOptions", &PEGTransformerFactory::TransformVacuumLegacyOptionsInternal}, + {"VacuumOption", &PEGTransformerFactory::TransformVacuumOptionInternal}, + {"OptAnalyze", 
&PEGTransformerFactory::TransformOptAnalyzeInternal}, + {"OptFull", &PEGTransformerFactory::TransformOptFullInternal}, + {"OptFreeze", &PEGTransformerFactory::TransformOptFreezeInternal}, + {"OptVerbose", &PEGTransformerFactory::TransformOptVerboseInternal}, + {"NameList", &PEGTransformerFactory::TransformNameListInternal}, }; for (const auto &rule : builtin_transform_rules) { sql_transform_functions[rule.name] = rule.transform; diff --git a/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp b/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp new file mode 100644 index 000000000..dde50816a --- /dev/null +++ b/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp @@ -0,0 +1,137 @@ +#include "duckdb/parser/peg/ast/generic_copy_option.hpp" +#include "duckdb/parser/expression/cast_expression.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/peg/transformer/peg_transformer.hpp" + +namespace duckdb { + +vector +PEGTransformerFactory::TransformGenericCopyOptionList(PEGTransformer &transformer, + const vector &generic_copy_option) { + return generic_copy_option; +} + +static void SetGenericCopyOptionExpression(GenericCopyOption ©_option, unique_ptr expression) { + if (expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { + copy_option.children.push_back(Value(expression->Cast().GetValue())); + } else if (expression->GetExpressionType() == ExpressionType::COLUMN_REF) { + copy_option.children.push_back(Value(expression->Cast().GetColumnName())); + } else if (expression->GetExpressionType() == ExpressionType::PLACEHOLDER) { + auto &op_expr = expression->Cast(); + for (auto &child : op_expr.GetChildren()) { + if (child->GetExpressionClass() == ExpressionClass::CONSTANT) { + 
copy_option.children.push_back(Value(child->Cast().GetValue())); + } else if (child->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + copy_option.children.push_back(Value(child->Cast().GetColumnName())); + } else { + throw InternalException("Unexpected expression type %s encountered for GenericCopyOption", + ExpressionClassToString(child->GetExpressionClass())); + } + } + } else if (expression->GetExpressionType() == ExpressionType::FUNCTION) { + copy_option.expression = std::move(expression); + } else if (expression->GetExpressionType() == ExpressionType::STAR) { + copy_option.children.push_back(Value("*")); + } else if (expression->GetExpressionType() == ExpressionType::OPERATOR_CAST) { + auto &cast_expr = expression->Cast(); + if (cast_expr.Child().GetExpressionClass() == ExpressionClass::CONSTANT) { + auto &const_expr = cast_expr.Child().Cast(); + if (const_expr.GetValue().GetValue() == "t") { + copy_option.children.push_back(Value(true)); + } else if (const_expr.GetValue().GetValue() == "f") { + copy_option.children.push_back(Value(false)); + } else { + copy_option.expression = std::move(expression); + } + } else { + copy_option.expression = std::move(expression); + } + } else { + throw NotImplementedException("Unrecognized expression type %s", + ExpressionTypeToString(expression->GetExpressionType())); + } +} + +static unique_ptr CreateRowFunction(vector> &&children) { + return make_uniq(INVALID_CATALOG, DEFAULT_SCHEMA, "row", std::move(children)); +} + +static unique_ptr CreateOrderByRowFunction(const vector &orders) { + vector> children; + children.reserve(orders.size()); + for (auto &order : orders) { + children.push_back(make_uniq(Value(order.ToString()))); + } + return CreateRowFunction(std::move(children)); +} + +static unique_ptr CreateExpressionRowFunction(vector &orders) { + vector> children; + children.reserve(orders.size()); + for (auto &order : orders) { + children.push_back(std::move(order.expression)); + } + return 
CreateRowFunction(std::move(children)); +} + +GenericCopyOption PEGTransformerFactory::TransformGenericCopyOption(PEGTransformer &transformer, + const Identifier ©_option_name, + GenericCopyOptionValue generic_copy_option_value) { + GenericCopyOption copy_option; + copy_option.name = Identifier(StringUtil::Lower(copy_option_name.GetIdentifierName())); + if (!generic_copy_option_value.has_value) { + return copy_option; + } + + if (generic_copy_option_value.is_order_list) { + auto &orders = generic_copy_option_value.order_list; + bool has_order_modifier = false; + for (auto &order : orders) { + if (order.type != OrderType::ORDER_DEFAULT || order.null_order != OrderByNullType::ORDER_DEFAULT) { + has_order_modifier = true; + break; + } + } + + if (copy_option.name == "ORDER_BY") { + copy_option.expression = CreateOrderByRowFunction(orders); + } else if (has_order_modifier) { + throw ParserException("ORDER BY modifiers are only supported in the ORDER_BY option"); + } else if (orders.size() == 1) { + SetGenericCopyOptionExpression(copy_option, std::move(orders[0].expression)); + } else { + copy_option.expression = CreateExpressionRowFunction(generic_copy_option_value.order_list); + } + } else { + SetGenericCopyOptionExpression(copy_option, std::move(generic_copy_option_value.expression)); + } + return copy_option; +} + +GenericCopyOptionValue PEGTransformerFactory::TransformGenericCopyOptionOrderList( + PEGTransformer &transformer, vector generic_copy_option_parenthesized_expression_list) { + GenericCopyOptionValue result; + result.has_value = true; + result.is_order_list = true; + result.order_list = std::move(generic_copy_option_parenthesized_expression_list); + return result; +} + +GenericCopyOptionValue +PEGTransformerFactory::TransformGenericCopyOptionExpression(PEGTransformer &transformer, + unique_ptr expression) { + GenericCopyOptionValue result; + result.has_value = true; + result.expression = std::move(expression); + return result; +} + +vector 
PEGTransformerFactory::TransformGenericCopyOptionParenthesizedExpressionList( + PEGTransformer &transformer, vector order_by_expression_list) { + return order_by_expression_list; +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_import.cpp b/src/duckdb/src/parser/peg/transformer/transform_import.cpp deleted file mode 100644 index 1d54d470d..000000000 --- a/src/duckdb/src/parser/peg/transformer/transform_import.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "duckdb/parser/statement/pragma_statement.hpp" -#include "duckdb/parser/peg/transformer/peg_transformer.hpp" - -namespace duckdb { - -unique_ptr PEGTransformerFactory::TransformImportStatement(const string &string_literal) { - auto result = make_uniq(); - result->info->name = "import_database"; - result->info->parameters.emplace_back(make_uniq(Value(string_literal))); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_insert.cpp b/src/duckdb/src/parser/peg/transformer/transform_insert.cpp index 3473ef001..f2bb7484c 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_insert.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_insert.cpp @@ -6,24 +6,19 @@ namespace duckdb { -unique_ptr PEGTransformerFactory::TransformInsertStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformInsertStatement( + PEGTransformer &transformer, CommonTableExpressionMap with_clause, const OnConflictAction &or_action, + unique_ptr insert_target, const InsertColumnOrder &by_name_or_position, + const vector &insert_column_list, InsertValues insert_values, unique_ptr on_conflict_clause, + vector> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - auto &with_opt = list_pr.Child(0); - if (with_opt.HasResult()) { - node.cte_map = transformer.Transform(with_opt.GetResult()); - } - auto &or_action_opt = 
list_pr.Child(2); - auto insert_target = transformer.Transform>(list_pr.Child(4)); - // TODO(Dtenwolde) What about the insert alias? + node.cte_map = std::move(with_clause); node.catalog = insert_target->catalog_name; node.schema = insert_target->schema_name; node.table = insert_target->table_name; - transformer.TransformOptional(list_pr, 5, node.column_order); - transformer.TransformOptional>(list_pr, 6, node.columns); - auto insert_values = transformer.Transform(list_pr.Child(7)); + node.column_order = by_name_or_position; + node.columns = StringsToIdentifiers(insert_column_list); if (!node.columns.empty() && insert_values.default_values) { throw ParserException( "You can not provide both a column list and DEFAULT VALUES, please remove one of the two"); @@ -34,156 +29,118 @@ unique_ptr PEGTransformerFactory::TransformInsertStatement(PEGTran if (insert_values.select_statement) { node.select_statement = std::move(insert_values.select_statement); } - auto on_conflict_info = make_uniq(); - auto &on_conflict_clause = list_pr.Child(8); - if (on_conflict_clause.HasResult()) { - if (or_action_opt.HasResult()) { + if (on_conflict_clause) { + if (or_action != OnConflictAction::THROW) { // OR REPLACE | OR IGNORE are shorthands for the ON CONFLICT clause throw ParserException("You can not provide both OR REPLACE|IGNORE and an ON CONFLICT clause, please remove " - "the first if you want to have more granual control"); + "the first if you want to have more granular control"); } - on_conflict_info = transformer.Transform>(on_conflict_clause.GetResult()); - node.on_conflict_info = std::move(on_conflict_info); + node.on_conflict_info = std::move(on_conflict_clause); node.table_ref = std::move(insert_target); - } else if (or_action_opt.HasResult()) { - on_conflict_info->action_type = transformer.Transform(or_action_opt.GetResult()); + } else if (or_action != OnConflictAction::THROW) { + auto on_conflict_info = make_uniq(); + on_conflict_info->action_type = or_action; 
node.on_conflict_info = std::move(on_conflict_info); node.table_ref = std::move(insert_target); } - transformer.TransformOptional>>(list_pr, 9, node.returning_list); + node.returning_list = std::move(returning_clause); return std::move(result); } -OnConflictAction PEGTransformerFactory::TransformOrAction(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &replace_or_ignore = list_pr.Child(1).GetResult(); - auto &replace_or_ignore_keyword = replace_or_ignore.Cast().keyword; - if (StringUtil::CIEquals(replace_or_ignore_keyword, "replace")) { - return OnConflictAction::REPLACE; - } else if (StringUtil::CIEquals(replace_or_ignore_keyword, "ignore")) { - return OnConflictAction::NOTHING; - } else { - throw InternalException("Unexpected keyword %s encountered for OrAction", replace_or_ignore_keyword); - } +OnConflictAction PEGTransformerFactory::TransformInsertOrReplace(PEGTransformer &transformer) { + return OnConflictAction::REPLACE; +} + +OnConflictAction PEGTransformerFactory::TransformInsertOrIgnore(PEGTransformer &transformer) { + return OnConflictAction::NOTHING; } unique_ptr PEGTransformerFactory::TransformInsertTarget(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto table_ref = transformer.Transform>(list_pr.Child(0)); - transformer.TransformOptional(list_pr, 1, table_ref->alias); - return table_ref; -} - -string PEGTransformerFactory::TransformInsertAlias(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(1).identifier; -} - -unique_ptr PEGTransformerFactory::TransformOnConflictClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto result = transformer.Transform>(list_pr.Child(3)); - auto &on_conflict_target_opt = list_pr.Child(2); - if (on_conflict_target_opt.HasResult()) { - auto expression_target = 
transformer.Transform(on_conflict_target_opt.GetResult()); - result->indexed_columns = expression_target.indexed_columns; - if (expression_target.where_clause) { - result->condition = std::move(expression_target.where_clause); - } - } - return result; + unique_ptr base_table_name, + const Identifier &insert_alias) { + base_table_name->alias = insert_alias; + return base_table_name; +} + +Identifier PEGTransformerFactory::TransformInsertAlias(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; } -OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictTarget(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); +unique_ptr +PEGTransformerFactory::TransformOnConflictClause(PEGTransformer &transformer, + OnConflictExpressionTarget on_conflict_target, + unique_ptr on_conflict_action) { + on_conflict_action->indexed_columns = on_conflict_target.indexed_columns; + if (on_conflict_target.where_clause) { + on_conflict_action->condition = std::move(on_conflict_target.where_clause); + } + return on_conflict_action; } -OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictExpressionTarget(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictExpressionTarget( + PEGTransformer &transformer, const vector &column_id_list, unique_ptr where_clause) { OnConflictExpressionTarget result; - result.indexed_columns = transformer.Transform>(list_pr.Child(0)); - transformer.TransformOptional>(list_pr, 1, result.where_clause); + result.indexed_columns = StringsToIdentifiers(column_id_list); + result.where_clause = std::move(where_clause); return result; } OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictIndexTarget(PEGTransformer &transformer, - ParseResult &parse_result) { + const Identifier 
&constraint_name) { throw NotImplementedException("ON CONSTRAINT conflict target is not supported yet"); } -unique_ptr PEGTransformerFactory::TransformOnConflictAction(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - unique_ptr PEGTransformerFactory::TransformOnConflictUpdate(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr update_set_clause, + unique_ptr where_clause) { auto result = make_uniq(); result->action_type = OnConflictAction::UPDATE; - result->set_info = transformer.Transform>(list_pr.Child(3)); - transformer.TransformOptional>(list_pr, 4, result->set_info->condition); + result->set_info = std::move(update_set_clause); + result->set_info->condition = std::move(where_clause); return result; } -unique_ptr PEGTransformerFactory::TransformOnConflictNothing(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformOnConflictNothing(PEGTransformer &transformer) { auto result = make_uniq(); result->action_type = OnConflictAction::NOTHING; return result; } -InsertValues PEGTransformerFactory::TransformInsertValues(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +InsertValues PEGTransformerFactory::TransformSelectInsertValues(PEGTransformer &transformer, + unique_ptr select_statement_internal) { InsertValues result; - auto &choice_pr = list_pr.Child(0); - if (choice_pr.GetResult().name == "DefaultValues") { - result.default_values = true; - result.select_statement = nullptr; - return result; - } - if (choice_pr.GetResult().name == "SelectStatementInternal") { - result.default_values = false; - result.select_statement = transformer.Transform>(choice_pr.GetResult()); - return result; - } - throw InternalException("Unexpected choice in InsertValues statement."); + result.select_statement = 
std::move(select_statement_internal); + return result; +} + +InsertValues PEGTransformerFactory::TransformDefaultValues(PEGTransformer &transformer) { + InsertValues result; + result.default_values = true; + return result; } -InsertColumnOrder PEGTransformerFactory::TransformByNameOrPosition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(1).GetResult()); +InsertColumnOrder PEGTransformerFactory::TransformInsertByName(PEGTransformer &transformer) { + return InsertColumnOrder::INSERT_BY_NAME; +} + +InsertColumnOrder PEGTransformerFactory::TransformInsertByPosition(PEGTransformer &transformer) { + return InsertColumnOrder::INSERT_BY_POSITION; } vector PEGTransformerFactory::TransformInsertColumnList(PEGTransformer &transformer, - ParseResult &parse_result) { - // InsertColumnList <- Parens(ColumnList) - auto &list_pr = parse_result.Cast(); - auto &column_list = ExtractResultFromParens(list_pr.Child(0)); - return transformer.Transform>(column_list); -} - -vector PEGTransformerFactory::TransformColumnList(PEGTransformer &transformer, ParseResult &parse_result) { - // ColumnList <- List(ColId) - auto &list_pr = parse_result.Cast(); - auto column_list = ExtractParseResultsFromList(list_pr.Child(0)); - vector result; - for (auto &column : column_list) { - result.push_back(transformer.Transform(column)); - } - return result; + const vector &column_list) { + return column_list; +} + +vector PEGTransformerFactory::TransformColumnList(PEGTransformer &transformer, + const vector &col_id) { + return IdentifiersToStrings(col_id); } -vector> PEGTransformerFactory::TransformReturningClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>(list_pr.Child(1)); +vector> +PEGTransformerFactory::TransformReturningClause(PEGTransformer &transformer, + vector> target_list) { + return target_list; } } // 
namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_load.cpp b/src/duckdb/src/parser/peg/transformer/transform_load.cpp index c06026fc5..fc3c6f068 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_load.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_load.cpp @@ -7,79 +7,74 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformLoadStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const Identifier &col_id_or_string, + const Identifier &extension_alias) { auto result = make_uniq(); auto info = make_uniq(); - info->load_type = LoadType::LOAD; info->repo_is_alias = false; - info->filename = transformer.Transform(list_pr.Child(1)); + info->filename = col_id_or_string.GetIdentifierName(); + if (!extension_alias.empty()) { + info->alias = extension_alias; + info->load_type = LoadType::LOAD_AS; + } else { + info->load_type = LoadType::LOAD; + } result->info = std::move(info); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformInstallStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +Identifier PEGTransformerFactory::TransformExtensionAlias(PEGTransformer &transformer, const Identifier &identifier) { + return identifier; +} + +unique_ptr PEGTransformerFactory::TransformInstallStatement( + PEGTransformer &transformer, const QualifiedName &identifier_or_string_literal, + const ExtensionRepositoryInfo &from_source, const string &version_number) { auto result = make_uniq(); auto info = make_uniq(); - auto opt_force = list_pr.Child(0).HasResult(); - info->load_type = opt_force ? 
LoadType::FORCE_INSTALL : LoadType::INSTALL; - auto extension_name = transformer.Transform(list_pr.Child(2)); - info->filename = extension_name.name; + info->load_type = LoadType::INSTALL; + info->filename = identifier_or_string_literal.name.GetIdentifierName(); info->repo_is_alias = false; - auto &from_source_opt = list_pr.Child(3); - if (from_source_opt.HasResult()) { - auto repository_info = transformer.Transform(from_source_opt.GetResult()); - info->repository = repository_info.name; - info->repo_is_alias = repository_info.repository_is_alias; + if (!from_source.name.empty()) { + info->repository = from_source.name.GetIdentifierName(); + info->repo_is_alias = from_source.repository_is_alias; + } + if (!version_number.empty()) { + info->version = version_number; } - transformer.TransformOptional(list_pr, 4, info->version); result->info = std::move(info); return std::move(result); } -ExtensionRepositoryInfo PEGTransformerFactory::TransformFromSource(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &sub_list_pr = list_pr.Child(1); - auto &string_or_identifier = sub_list_pr.Child(0); +ExtensionRepositoryInfo PEGTransformerFactory::TransformFromSourceIdentifier(PEGTransformer &transformer, + const Identifier &identifier) { ExtensionRepositoryInfo result; - if (string_or_identifier.GetResult().type == ParseResultType::STRING) { - result.name = transformer.Transform(string_or_identifier.GetResult()); - result.repository_is_alias = false; - } else { - result.name = string_or_identifier.GetResult().Cast().identifier; - result.repository_is_alias = true; - } + result.name = identifier; + result.repository_is_alias = true; return result; } -unique_ptr PEGTransformerFactory::TransformUpdateExtensionsStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - // UpdateExtensionsStatement <- 'UPDATE' 'EXTENSIONS' Parens(List(Identifier))? 
- // child 0: 'UPDATE', child 1: 'EXTENSIONS', child 2: optional Parens(List(Identifier)) - auto &list_pr = parse_result.Cast(); +ExtensionRepositoryInfo PEGTransformerFactory::TransformFromSourceString(PEGTransformer &transformer, + const string &string_literal) { + ExtensionRepositoryInfo result; + result.name = Identifier(string_literal); + result.repository_is_alias = false; + return result; +} + +unique_ptr +PEGTransformerFactory::TransformUpdateExtensionsStatement(PEGTransformer &transformer, + const vector &identifier) { auto result = make_uniq(); auto info = make_uniq(); - - auto &opt = list_pr.Child(2); - if (opt.HasResult()) { - auto &inner = ExtractResultFromParens(opt.GetResult()); - auto ext_list = ExtractParseResultsFromList(inner); - for (auto ext : ext_list) { - info->extensions_to_update.emplace_back(ext.get().Cast().identifier); - } - } - + info->extensions_to_update = identifier; result->info = std::move(info); return std::move(result); } -string PEGTransformerFactory::TransformVersionNumber(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto version = transformer.Transform(list_pr.Child(1)); - return version.name; +string PEGTransformerFactory::TransformVersionNumber(PEGTransformer &transformer, + const QualifiedName &identifier_or_string_literal) { + return identifier_or_string_literal.name.GetIdentifierName(); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp b/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp index 4579c44fc..6d416dbaa 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp @@ -1,210 +1,165 @@ +#include "duckdb/common/set.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" namespace duckdb { 
-unique_ptr PEGTransformerFactory::TransformMergeIntoStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformMergeIntoStatement( + PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr target_opt_alias, + unique_ptr merge_into_using_clause, JoinQualifier join_qualifier, + vector>> merge_match, + vector> returning_clause) { auto result = make_uniq(); - transformer.TransformOptional(list_pr, 0, result->cte_map); - result->target = transformer.Transform>(list_pr.Child(3)); - result->source = transformer.Transform>(list_pr.Child(4)); - auto join_condition = transformer.Transform(list_pr.Child(5)); - if (join_condition.on_clause) { - result->join_condition = std::move(join_condition.on_clause); + auto &node = *result->node; + node.cte_map = std::move(with_clause); + node.target = std::move(target_opt_alias); + node.source = std::move(merge_into_using_clause); + if (join_qualifier.on_clause) { + node.join_condition = std::move(join_qualifier.on_clause); } else { - result->using_columns = std::move(join_condition.using_columns); + node.using_columns = join_qualifier.using_columns; } - auto &merge_match_repeat = list_pr.Child(6); - map> unconditional_actions; - for (auto merge_match : merge_match_repeat.GetChildren()) { - auto merge_match_result = - transformer.Transform>>(merge_match); + set unconditional_actions; + for (auto &merge_match_result : merge_match) { auto action_condition = merge_match_result.first; auto &action = merge_match_result.second; + // Once an unconditional clause has been seen for a condition type, no further clauses + // of the same type are allowed: they would be unreachable. 
+ if (unconditional_actions.count(action_condition)) { + string action_condition_str = MergeQueryNode::ActionConditionToString(action_condition); + throw ParserException( + "Unconditional %s clause was already defined - any following %s clause would be unreachable", + action_condition_str, action_condition_str); + } if (!action->condition) { - auto entry = unconditional_actions.find(action_condition); - if (entry != unconditional_actions.end()) { - string action_condition_str = MergeIntoStatement::ActionConditionToString(action_condition); - throw ParserException( - "Unconditional %s clause was already defined - only one unconditional %s clause is supported", - action_condition_str, action_condition_str); - } - unconditional_actions.emplace(action_condition, std::move(action)); - continue; + unconditional_actions.insert(action_condition); } - result->actions[action_condition].push_back(std::move(action)); - } - // finally add the unconditional actions - for (auto &entry : unconditional_actions) { - result->actions[entry.first].push_back(std::move(entry.second)); + node.actions[action_condition].push_back(std::move(action)); } - transformer.TransformOptional>>(list_pr, 7, result->returning_list); + node.returning_list = std::move(returning_clause); return std::move(result); } unique_ptr PEGTransformerFactory::TransformMergeIntoUsingClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(1)); -} - -pair> -PEGTransformerFactory::TransformMergeMatch(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>( - list_pr.Child(0).GetResult()); + unique_ptr table_ref) { + return table_ref; } pair> -PEGTransformerFactory::TransformMatchedClause(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto condition = MergeActionCondition::WHEN_MATCHED; - auto 
merge_into_action = transformer.Transform>(list_pr.Child(4)); - transformer.TransformOptional>(list_pr, 2, merge_into_action->condition); - return pair>(condition, std::move(merge_into_action)); -} - -unique_ptr PEGTransformerFactory::TransformMatchedClauseAction(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - -unique_ptr PEGTransformerFactory::TransformUpdateMatchClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - unique_ptr result; - auto &update_info = list_pr.Child(1); - if (update_info.HasResult()) { - result = transformer.Transform>(update_info.GetResult()); - } else { +PEGTransformerFactory::TransformMatchedClause(PEGTransformer &transformer, unique_ptr and_expression, + unique_ptr matched_clause_action) { + matched_clause_action->condition = std::move(and_expression); + return pair>(MergeActionCondition::WHEN_MATCHED, + std::move(matched_clause_action)); +} + +unique_ptr +PEGTransformerFactory::TransformUpdateMatchClause(PEGTransformer &transformer, + unique_ptr update_match_info) { + auto result = std::move(update_match_info); + if (!result) { result = make_uniq(); } result->action_type = MergeActionType::MERGE_UPDATE; return result; } -unique_ptr PEGTransformerFactory::TransformUpdateMatchInfo(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformUpdateMatchSetAction(PEGTransformer &transformer, + unique_ptr update_match_set_clause) { auto result = make_uniq(); - auto &choice_pr = list_pr.Child(0).GetResult(); - if (choice_pr.name == "ByNameOrPosition") { - result->column_order = transformer.Transform(choice_pr); - } else { - result->update_info = transformer.Transform>(choice_pr); - } + result->update_info = std::move(update_match_set_clause); return result; } -unique_ptr 
PEGTransformerFactory::TransformUpdateMatchSetClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &inner_list = list_pr.Child(1); - auto &choice_pr = inner_list.Child(0).GetResult(); - if (choice_pr.type == ParseResultType::KEYWORD) { - // We found the '*' - return nullptr; - } - return transformer.Transform>(choice_pr); +unique_ptr +PEGTransformerFactory::TransformUpdateByNameOrPosition(PEGTransformer &transformer, + const InsertColumnOrder &by_name_or_position) { + auto result = make_uniq(); + result->column_order = by_name_or_position; + return result; } -unique_ptr PEGTransformerFactory::TransformDeleteMatchClause(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformDeleteMatchClause(PEGTransformer &transformer) { auto result = make_uniq(); result->action_type = MergeActionType::MERGE_DELETE; return result; } -unique_ptr PEGTransformerFactory::TransformInsertMatchClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - unique_ptr result; - auto &insert_info = list_pr.Child(1); - if (insert_info.HasResult()) { - result = transformer.Transform>(insert_info.GetResult()); - } else { +unique_ptr +PEGTransformerFactory::TransformInsertMatchClause(PEGTransformer &transformer, + unique_ptr insert_match_info) { + auto result = std::move(insert_match_info); + if (!result) { result = make_uniq(); } result->action_type = MergeActionType::MERGE_INSERT; return result; } -unique_ptr PEGTransformerFactory::TransformInsertMatchInfo(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - -unique_ptr PEGTransformerFactory::TransformInsertDefaultValues(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformInsertDefaultValues(PEGTransformer &transformer) { auto 
result = make_uniq(); result->default_values = true; return result; } -unique_ptr PEGTransformerFactory::TransformInsertByNameOrPosition(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformInsertByNameOrPosition(PEGTransformer &transformer, + const InsertColumnOrder &by_name_or_position) { auto result = make_uniq(); - transformer.TransformOptional(list_pr, 0, result->column_order); + result->column_order = by_name_or_position; return result; } -unique_ptr PEGTransformerFactory::TransformInsertValuesList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformInsertValuesList(PEGTransformer &transformer, const vector &insert_column_list, + vector> expression) { auto result = make_uniq(); - transformer.TransformOptional>(list_pr, 0, result->insert_columns); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(2)); - auto expr_list = ExtractParseResultsFromList(extract_parens); - for (auto &expr : expr_list) { - result->expressions.push_back(transformer.Transform>(expr)); - } + result->insert_columns = StringsToIdentifiers(insert_column_list); + result->expressions = std::move(expression); return result; } -unique_ptr PEGTransformerFactory::TransformDoNothingMatchClause(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr PEGTransformerFactory::TransformDoNothingMatchClause(PEGTransformer &transformer) { auto result = make_uniq(); result->action_type = MergeActionType::MERGE_DO_NOTHING; return result; } unique_ptr PEGTransformerFactory::TransformErrorMatchClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr expression) { auto result = make_uniq(); result->action_type = MergeActionType::MERGE_ERROR; - auto &expression_opt = list_pr.Child(1); - if (expression_opt.HasResult()) { - 
result->expressions.push_back(transformer.Transform>(expression_opt.GetResult())); + if (expression) { + result->expressions.push_back(std::move(expression)); } return result; } unique_ptr PEGTransformerFactory::TransformAndExpression(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(1)); + unique_ptr expression) { + return expression; } -pair> -PEGTransformerFactory::TransformNotMatchedClause(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto result = transformer.Transform>(list_pr.Child(6)); - transformer.TransformOptional>(list_pr, 4, result->condition); - MergeActionCondition action_condition = MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET; - transformer.TransformOptional(list_pr, 3, action_condition); - return pair>(action_condition, std::move(result)); +pair> PEGTransformerFactory::TransformNotMatchedClause( + PEGTransformer &transformer, const MergeActionCondition &by_source_or_target, + unique_ptr and_expression, unique_ptr matched_clause_action) { + matched_clause_action->condition = std::move(and_expression); + auto action_condition = by_source_or_target; + if (action_condition == MergeActionCondition::WHEN_MATCHED) { + action_condition = MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET; + } + return pair>(action_condition, std::move(matched_clause_action)); } -MergeActionCondition PEGTransformerFactory::TransformBySourceOrTarget(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +MergeActionCondition PEGTransformerFactory::TransformBySource(PEGTransformer &transformer) { + return MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE; } + +MergeActionCondition PEGTransformerFactory::TransformByTarget(PEGTransformer &transformer) { + return MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET; +} + } // 
namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_pivot.cpp b/src/duckdb/src/parser/peg/transformer/transform_pivot.cpp index 7574fb4f7..cf6ea5911 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_pivot.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_pivot.cpp @@ -17,18 +17,14 @@ void PEGTransformerFactory::AddPivotEntry(PEGTransformer &transformer, string en transformer.pivot_entries.push_back(std::move(result)); } -vector PEGTransformerFactory::TransformPivotOn(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.GetChild(1)); +vector PEGTransformerFactory::TransformPivotOn(PEGTransformer &transformer, + vector pivot_column_list) { + return pivot_column_list; } vector PEGTransformerFactory::TransformPivotColumnList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto pivot_column_list = ExtractParseResultsFromList(list_pr.GetChild(0)); - vector result; - for (auto &pivot_column : pivot_column_list) { - auto col = transformer.Transform(pivot_column); + vector pivot_column_entry) { + for (auto &col : pivot_column_entry) { // Re-wrap multiple pivot expressions into a single row() only when IN entries are // scalar (size 1). For tuple IN entries the unpacked expressions must stay separate // so that each value maps to its corresponding pivot expression. 
@@ -55,32 +51,23 @@ vector PEGTransformerFactory::TransformPivotColumnList(PEGTransform throw ParserException(expr->GetQueryLocation(), "Cannot pivot on subquery \"%s\"", expr->ToString()); } } - result.push_back(std::move(col)); } - return result; -} - -PivotColumn PEGTransformerFactory::TransformPivotColumnEntry(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); + return pivot_column_entry; } PivotColumn PEGTransformerFactory::TransformPivotColumnSubquery(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + unique_ptr base_expression, + unique_ptr select_statement_internal) { PivotColumn result; - result.pivot_expressions.push_back(transformer.Transform>(list_pr.GetChild(0))); - auto &select_subquery = ExtractResultFromParens(list_pr.GetChild(2)); - auto select = transformer.Transform>(select_subquery); - result.subquery = std::move(select->node); + result.pivot_expressions.push_back(std::move(base_expression)); + result.subquery = std::move(select_statement_internal->node); return result; } -PivotColumn PEGTransformerFactory::TransformPivotColumnEntryInternal(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +PivotColumn PEGTransformerFactory::TransformPivotColumnExpression(PEGTransformer &transformer, + unique_ptr expression) { PivotColumn result; - result.pivot_expressions.push_back(transformer.Transform>(list_pr.GetChild(0))); + result.pivot_expressions.push_back(std::move(expression)); return result; } @@ -102,7 +89,7 @@ unique_ptr PEGTransformerFactory::TransformPivotStatement(PEGTr GroupingSet set; for (idx_t gr = 0; gr < pivot_group_list.size(); gr++) { auto &group = pivot_group_list[gr]; - auto colref = make_uniq(group); + auto colref = make_uniq(Identifier(group)); select_node->select_list.push_back(colref->Copy()); 
select_node->groups.group_expressions.push_back(std::move(colref)); set.insert(ProjectionIndex(gr)); @@ -135,7 +122,7 @@ unique_ptr PEGTransformerFactory::TransformPivotStatement(PEGTr new_select->from_table = source->Copy(); AddPivotEntry(transformer, enum_name, std::move(new_select), col.pivot_expressions[0]->Copy(), std::move(col.subquery), has_parameters); - col.pivot_enum = enum_name; + col.pivot_enum = Identifier(enum_name); } // Generate the actual query, including the pivot @@ -154,7 +141,7 @@ unique_ptr PEGTransformerFactory::TransformPivotStatement(PEGTr pivot_ref->aggregates.push_back(std::move(function)); } if (pivot_group.HasResult()) { - pivot_ref->groups = transformer.Transform>(pivot_group.GetResult()); + pivot_ref->groups = transformer.Transform>(pivot_group.GetResult()); } pivot_ref->pivots = std::move(columns); select_node->from_table = std::move(pivot_ref); @@ -163,40 +150,30 @@ unique_ptr PEGTransformerFactory::TransformPivotStatement(PEGTr return select_statement; } -vector> PEGTransformerFactory::TransformPivotUsing(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>(list_pr.GetChild(1)); -} - -vector PEGTransformerFactory::TransformUnpivotHeader(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); +vector> +PEGTransformerFactory::TransformPivotUsing(PEGTransformer &transformer, + vector> target_list) { + return target_list; } vector PEGTransformerFactory::TransformUnpivotHeaderSingle(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const Identifier &col_id_or_string) { vector result; - result.push_back(transformer.Transform(list_pr.GetChild(0))); + result.push_back(col_id_or_string.GetIdentifierName()); return result; } vector PEGTransformerFactory::TransformUnpivotHeaderList(PEGTransformer &transformer, - 
ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.GetChild(0)); - auto col_list = ExtractParseResultsFromList(extract_parens); - vector result; - for (auto col : col_list) { - result.push_back(transformer.Transform(col)); - } - return result; + const vector &col_id_or_string) { + return IdentifiersToStrings(col_id_or_string); } -bool PEGTransformerFactory::TransformIncludeOrExcludeNulls(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +bool PEGTransformerFactory::TransformIncludeNulls(PEGTransformer &transformer) { + return true; +} + +bool PEGTransformerFactory::TransformExcludeNulls(PEGTransformer &transformer) { + return false; } unique_ptr PEGTransformerFactory::TransformUnpivotStatement(PEGTransformer &transformer, @@ -252,7 +229,7 @@ unique_ptr PEGTransformerFactory::TransformUnpivotStatement(PEG new_select->from_table = source->Copy(); AddPivotEntry(transformer, enum_name, std::move(new_select), col.pivot_expressions[0]->Copy(), std::move(col.subquery), has_parameters); - col.pivot_enum = enum_name; + col.pivot_enum = Identifier(enum_name); } auto pivot_ref = make_uniq(); @@ -273,16 +250,13 @@ unique_ptr PEGTransformerFactory::TransformUnpivotStatement(PEG } UnpivotNameValues PEGTransformerFactory::TransformIntoNameValues(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const Identifier &col_id_or_string, + const vector &identifier) { UnpivotNameValues result; PivotColumn column; - column.unpivot_names.push_back(transformer.Transform(list_pr.GetChild(2))); + column.unpivot_names.push_back(Identifier(col_id_or_string)); result.column = std::move(column); - auto unpivot_name_list = ExtractParseResultsFromList(list_pr.GetChild(4)); - for (auto identifier : unpivot_name_list) { - 
result.unpivot_names.push_back(identifier.get().Cast().identifier); - } + result.unpivot_names = identifier; return result; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp b/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp index 357ab0854..8381a41d5 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp @@ -3,28 +3,25 @@ #include "duckdb/parser/expression/comparison_expression.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::TransformPragmaStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - // Rule: PragmaStatement <- 'PRAGMA'i (PragmaAssign / PragmaFunction) - auto &list_pr = parse_result.Cast(); - auto &matched_child = list_pr.Child(1).Child(0); - return transformer.Transform>(matched_child.GetResult()); +unique_ptr +PEGTransformerFactory::TransformPragmaStatement(PEGTransformer &transformer, + unique_ptr pragma_assign_or_function) { + return pragma_assign_or_function; } -unique_ptr PEGTransformerFactory::TransformPragmaAssign(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr +PEGTransformerFactory::TransformPragmaAssign(PEGTransformer &transformer, const Identifier &setting_name, + vector> variable_list) { // Rule: PragmaAssign <- SettingName '=' Expression - auto &list_pr = parse_result.Cast(); auto result = make_uniq(); auto &info = *result->info; - info.name = list_pr.Child(0).identifier; - auto value_list = transformer.Transform>>(list_pr.Child(2)); - if (value_list.size() != 1) { + info.name = setting_name; + if (variable_list.size() != 1) { throw ParserException("PRAGMA statement with assignment should contain exactly one parameter"); } - auto &expr = value_list[0]; + auto &expr = variable_list[0]; if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &colref = value_list[0]->Cast(); + auto &colref = expr->Cast(); if (!colref.IsQualified()) { 
info.parameters.emplace_back(make_uniq(Value(colref.GetColumnName()))); } else { @@ -37,60 +34,49 @@ unique_ptr PEGTransformerFactory::TransformPragmaAssign(PEGTransfo // "PRAGMA table_info='integers'" // "PRAGMA table_info('integers')" // for compatibility, any pragmas that match the SQLite ones are parsed as calls - case_insensitive_set_t sqlite_compat_pragmas {"table_info"}; + identifier_set_t sqlite_compat_pragmas {"table_info"}; if (sqlite_compat_pragmas.find(info.name) != sqlite_compat_pragmas.end()) { return std::move(result); } auto set_statement = make_uniq(info.name, std::move(info.parameters[0]), SetScope::AUTOMATIC); - return std::move(set_statement); } -unique_ptr PEGTransformerFactory::TransformPragmaFunction(PEGTransformer &transformer, - ParseResult &parse_result) { +unique_ptr +PEGTransformerFactory::TransformPragmaFunction(PEGTransformer &transformer, const Identifier &pragma_name, + vector> pragma_parameters) { // Rule: PragmaFunction <- PragmaName PragmaParameters? auto result = make_uniq(); - auto &list_pr = parse_result.Cast(); - result->info->name = list_pr.Child(0).identifier; - auto &optional_parameters_pr = list_pr.Child(1); - if (optional_parameters_pr.HasResult()) { - auto parameters = - transformer.Transform>>(optional_parameters_pr.GetResult()); - for (auto ¶meter : parameters) { - if (parameter->GetExpressionType() == ExpressionType::COMPARE_EQUAL) { - auto &comp = parameter->Cast(); - if (comp.left->GetExpressionType() != ExpressionType::COLUMN_REF) { - throw ParserException("Named parameter requires a column reference on the LHS"); - } - auto &columnref = comp.left->Cast(); - result->info->named_parameters.insert(make_pair(columnref.GetName(), std::move(comp.right))); - } else if (parameter->GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &colref = parameter->Cast(); - if (!colref.IsQualified()) { - result->info->parameters.emplace_back(make_uniq(Value(colref.GetColumnName()))); - } else { - 
result->info->parameters.emplace_back(make_uniq(Value(parameter->ToString()))); - } + result->info->name = pragma_name; + if (pragma_parameters.empty()) { + return std::move(result); + } + for (auto ¶meter : pragma_parameters) { + if (parameter->GetExpressionType() == ExpressionType::COMPARE_EQUAL) { + auto &comp = parameter->Cast(); + if (comp.Left().GetExpressionType() != ExpressionType::COLUMN_REF) { + throw ParserException("Named parameter requires a column reference on the LHS"); + } + auto &columnref = comp.Left().Cast(); + result->info->named_parameters.insert(make_pair(columnref.GetName(), std::move(comp.RightMutable()))); + } else if (parameter->GetExpressionType() == ExpressionType::COLUMN_REF) { + auto &colref = parameter->Cast(); + if (!colref.IsQualified()) { + result->info->parameters.emplace_back(make_uniq(Value(colref.GetColumnName()))); } else { - result->info->parameters.emplace_back(std::move(parameter)); + result->info->parameters.emplace_back(make_uniq(Value(parameter->ToString()))); } + } else { + result->info->parameters.emplace_back(std::move(parameter)); } } return std::move(result); } -vector> PEGTransformerFactory::TransformPragmaParameters(PEGTransformer &transformer, - ParseResult &parse_result) { - // TODO(Dtenwolde) Check about named parameters - // PragmaParameters <- List(Expression) - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto expr_list = ExtractParseResultsFromList(extract_parens); - vector> parameters; - for (auto expr : expr_list) { - parameters.push_back(transformer.Transform>(expr)); - } - return parameters; +vector> +PEGTransformerFactory::TransformPragmaParameters(PEGTransformer &transformer, + vector> expression) { + return expression; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp b/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp index d70124e73..132e591e0 100644 --- 
a/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp @@ -17,21 +17,22 @@ bool IsPrepareableStatement(StatementType type) { } unique_ptr PEGTransformerFactory::TransformPrepareStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); + const Identifier &identifier, + const vector &type_list, + unique_ptr statement) { auto result = make_uniq(); - result->name = list_pr.Child(1).identifier; - auto &type_list_opt = list_pr.Child(2); - if (type_list_opt.HasResult()) { - throw NotImplementedException("TypeList for prepared statement has not been implemented."); + result->name = identifier; + if (!IsPrepareableStatement(statement->type)) { + throw ParserException("%s is not a preparable statement", EnumUtil::ToString(statement->type)); } - auto stmt = transformer.Transform>(list_pr.Child(4)); - if (!IsPrepareableStatement(stmt->type)) { - throw ParserException("%s is not a preparable statement", EnumUtil::ToString(stmt->type)); - } - result->statement = std::move(stmt); + result->statement = std::move(statement); transformer.ClearParameters(); return std::move(result); } +vector PEGTransformerFactory::TransformTypeList(PEGTransformer &transformer, + const vector &type) { + throw NotImplementedException("TypeList for prepared statement has not been implemented."); +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_select.cpp b/src/duckdb/src/parser/peg/transformer/transform_select.cpp index d1fd4091c..2fe6a9848 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_select.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_select.cpp @@ -202,12 +202,12 @@ unique_ptr PEGTransformerFactory::TransformSimpleSelect(PEGTran D_ASSERT(!window_func->GetAlias().empty()); string window_name(window_func->GetAlias()); window_func->ClearAlias(); - auto it = transformer.window_clauses.find(window_name); + auto it = 
transformer.window_clauses.find(Identifier(window_name)); if (it != transformer.window_clauses.end()) { throw ParserException("window \"%s\" is already defined", window_name); } auto window_function = unique_ptr_cast(std::move(window_func)); - transformer.window_clauses[window_name] = std::move(window_function); + transformer.window_clauses[Identifier(window_name)] = std::move(window_function); } } auto select_node = transformer.Transform>(list_pr.Child(0)); @@ -342,40 +342,42 @@ vector> PEGTransformerFactory::TransformDistinctOnT return result; } -unique_ptr PEGTransformerFactory::TransformFunctionArgument(PEGTransformer &transformer, - ParseResult &parse_result) { +FunctionArgument PEGTransformerFactory::TransformFunctionArgument(PEGTransformer &transformer, + ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto &choice_pr = list_pr.Child(0).GetResult(); if (choice_pr.name == "NamedParameter") { auto parameter = transformer.Transform(choice_pr); parameter.expression->SetAlias(parameter.name); - return std::move(parameter.expression); + return FunctionArgument(parameter.name, std::move(parameter.expression)); } - return transformer.Transform>(choice_pr); + + return FunctionArgument(transformer.Transform>(choice_pr)); } MacroParameter PEGTransformerFactory::TransformNamedParameter(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); MacroParameter parameter; parameter.expression = transformer.Transform>(list_pr.Child(3)); - parameter.name = transformer.Transform(list_pr.Child(0)); + parameter.name = Identifier(transformer.Transform(list_pr.Child(0))); parameter.is_default = true; transformer.TransformOptional(list_pr, 1, parameter.type); return parameter; } -vector> PEGTransformerFactory::TransformTableFunctionArguments(PEGTransformer &transformer, - ParseResult &parse_result) { +vector PEGTransformerFactory::TransformTableFunctionArguments(PEGTransformer &transformer, + ParseResult &parse_result) { // 
TableFunctionArguments <- Parens(List(FunctionArgument)?) - vector> result; + vector result; auto &list_pr = parse_result.Cast(); auto &stripped_parens = ExtractResultFromParens(list_pr.Child(0)).Cast(); if (stripped_parens.HasResult()) { auto argument_list = ExtractParseResultsFromList(stripped_parens.GetResult()); for (auto &argument : argument_list) { - result.push_back(transformer.Transform>(argument)); + result.push_back(transformer.Transform(argument)); } } + return result; } @@ -395,7 +397,7 @@ unique_ptr PEGTransformerFactory::TransformSchemaReservedTable(PEG ParseResult &parse_result) { // SchemaReservedTable <- SchemaQualification ReservedTableName auto &list_pr = parse_result.Cast(); - auto schema = transformer.Transform(list_pr.Child(0)); + auto schema = transformer.Transform(list_pr.Child(0)); auto table_name = list_pr.Child(1).identifier; const auto description = TableDescription(INVALID_CATALOG, schema, table_name); @@ -406,26 +408,21 @@ unique_ptr PEGTransformerFactory::TransformCatalogReservedSchemaTa ParseResult &parse_result) { // CatalogReservedSchemaTable <- CatalogQualification ReservedSchemaQualification ReservedTableName auto &list_pr = parse_result.Cast(); - auto catalog = transformer.Transform(list_pr.Child(0)); - auto schema = transformer.Transform(list_pr.Child(1)); + auto catalog = transformer.Transform(list_pr.Child(0)); + auto schema = transformer.Transform(list_pr.Child(1)); auto table_name = list_pr.Child(2).identifier; const auto description = TableDescription(catalog, schema, table_name); return make_uniq(description); } -string PEGTransformerFactory::TransformSchemaQualification(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; -} - -string PEGTransformerFactory::TransformCatalogQualification(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return list_pr.Child(0).identifier; +Identifier 
PEGTransformerFactory::TransformSchemaQualification(PEGTransformer &transformer, + const Identifier &schema_name) { + return schema_name; } -QualifiedName PEGTransformerFactory::TransformQualifiedName(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); +Identifier PEGTransformerFactory::TransformCatalogQualification(PEGTransformer &transformer, + const Identifier &catalog_name) { + return catalog_name; } QualifiedName @@ -433,49 +430,39 @@ PEGTransformerFactory::TransformCatalogReservedSchemaIdentifierOrStringLiteral(P optional_ptr parse_result) { QualifiedName result; auto &list_pr = parse_result->Cast(); - result.catalog = transformer.Transform(list_pr.Child(0)); - result.schema = transformer.Transform(list_pr.Child(1)); - result.name = transformer.Transform(list_pr.Child(2)); + result.catalog = transformer.Transform(list_pr.Child(0)); + result.schema = transformer.Transform(list_pr.Child(1)); + result.name = transformer.Transform(list_pr.Child(2)); return result; } -QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaIdentifier(PEGTransformer &transformer, - ParseResult &parse_result) { +QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaIdentifier( + PEGTransformer &transformer, const Identifier &catalog_qualification, + const Identifier &reserved_schema_qualification, const Identifier &reserved_identifier_or_string_literal) { QualifiedName result; - auto &list_pr = parse_result.Cast(); - result.catalog = transformer.Transform(list_pr.Child(0)); - result.schema = transformer.Transform(list_pr.Child(1)); - result.name = transformer.Transform(list_pr.Child(2)); + result.catalog = catalog_qualification; + result.schema = reserved_schema_qualification; + result.name = Identifier(reserved_identifier_or_string_literal); return result; } -QualifiedName 
PEGTransformerFactory::TransformSchemaReservedIdentifierOrStringLiteral(PEGTransformer &transformer, - ParseResult &parse_result) { +QualifiedName PEGTransformerFactory::TransformSchemaReservedIdentifierOrStringLiteral( + PEGTransformer &transformer, const Identifier &schema_qualification, + const Identifier &reserved_identifier_or_string_literal) { QualifiedName result; - auto &list_pr = parse_result.Cast(); result.catalog = INVALID_CATALOG; - result.schema = transformer.Transform(list_pr.Child(0)); - result.name = transformer.Transform(list_pr.Child(1)); + result.schema = schema_qualification; + result.name = Identifier(reserved_identifier_or_string_literal); return result; } -string PEGTransformerFactory::TransformReservedIdentifierOrStringLiteral(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { - return choice_pr.GetResult().Cast().identifier; - } - return transformer.Transform(choice_pr.GetResult()); -} - QualifiedName PEGTransformerFactory::TransformTableNameIdentifierOrStringLiteral(PEGTransformer &transformer, ParseResult &parse_result) { QualifiedName result; auto &list_pr = parse_result.Cast(); result.catalog = INVALID_CATALOG; result.schema = INVALID_SCHEMA; - result.name = transformer.Transform(list_pr.Child(0)); + result.name = transformer.Transform(list_pr.Child(0)); return result; } @@ -507,7 +494,7 @@ unique_ptr PEGTransformerFactory::TransformExpressionAsCollabe auto &list_pr = parse_result.Cast(); auto expr = transformer.Transform>(list_pr.Child(0)); auto &collabel_or_string = list_pr.Child(2); - expr->SetAlias(transformer.Transform(collabel_or_string)); + expr->SetAlias(transformer.Transform(collabel_or_string)); return expr; } @@ -516,7 +503,7 @@ unique_ptr PEGTransformerFactory::TransformColIdExpression(PEG auto &list_pr = parse_result.Cast(); auto &colid = list_pr.Child(0); auto expr = 
transformer.Transform>(list_pr.Child(2)); - expr->SetAlias(transformer.Transform(colid)); + expr->SetAlias(transformer.Transform(colid)); return expr; } @@ -541,7 +528,7 @@ TableAlias PEGTransformerFactory::TransformTableAliasAs(PEGTransformer &transfor TableAlias result; auto qualified_name = transformer.Transform(list_pr.Child(1)); result.name = qualified_name.name; - transformer.TransformOptional>(list_pr, 2, result.column_name_alias); + transformer.TransformOptional>(list_pr, 2, result.column_name_alias); return result; } @@ -549,7 +536,7 @@ TableAlias PEGTransformerFactory::TransformTableAliasWithoutAs(PEGTransformer &t auto &list_pr = parse_result.Cast(); TableAlias result; result.name = list_pr.Child(0).identifier; - transformer.TransformOptional>(list_pr, 1, result.column_name_alias); + transformer.TransformOptional>(list_pr, 1, result.column_name_alias); return result; } @@ -652,7 +639,7 @@ unique_ptr PEGTransformerFactory::TransformTableUnpivotClause(PEGTrans transformer.TransformOptional(list_pr, 1, result->include_nulls); auto &extract_parens = ExtractResultFromParens(list_pr.Child(2)); auto &inner_list = extract_parens.Cast(); - result->unpivot_names = transformer.Transform>(inner_list.GetChild(0)); + result->unpivot_names = transformer.Transform>(inner_list.GetChild(0)); auto &pivot_values_list = inner_list.Child(2); for (auto pivot_value : pivot_values_list.GetChildren()) { result->pivots.push_back(transformer.Transform(pivot_value)); @@ -670,7 +657,7 @@ unique_ptr PEGTransformerFactory::TransformTableUnpivotClause(PEGTrans PivotColumn PEGTransformerFactory::TransformUnpivotValueList(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); PivotColumn result; - result.unpivot_names = transformer.Transform>(list_pr.GetChild(0)); + result.unpivot_names = transformer.Transform>(list_pr.GetChild(0)); if (result.unpivot_names.size() != 1) { throw ParserException("UNPIVOT requires a single column name for the PIVOT IN 
clause"); } @@ -684,20 +671,21 @@ void PEGTransformerFactory::GetValueFromExpression(unique_ptr result.push_back(const_expr.GetValue()); } else if (expr->GetExpressionClass() == ExpressionClass::COLUMN_REF) { auto &col_ref_expr = expr->Cast(); - for (auto &col : col_ref_expr.column_names) { + for (auto &col : col_ref_expr.ColumnNames()) { result.push_back(Value(col)); } } else if (expr->GetExpressionClass() == ExpressionClass::FUNCTION) { auto &func_expr = expr->Cast(); - if (func_expr.function_name == "row") { - for (auto &col : func_expr.children) { - GetValueFromExpression(col, result); + if (func_expr.FunctionName() == "row") { + for (auto &col : func_expr.GetArgumentsMutable()) { + GetValueFromExpression(col.GetExpressionMutable(), result); } } } } bool PEGTransformerFactory::TransformPivotInList(unique_ptr &expr, PivotColumnEntry &entry) { + auto initial_size = entry.values.size(); switch (expr->GetExpressionType()) { case ExpressionType::COLUMN_REF: { auto &colref = expr->Cast(); @@ -709,11 +697,12 @@ bool PEGTransformerFactory::TransformPivotInList(unique_ptr &e } case ExpressionType::FUNCTION: { auto &function = expr->Cast(); - if (function.function_name != "row") { + if (function.FunctionName() != "row") { return false; } - for (auto &child : function.children) { - if (!TransformPivotInList(child, entry)) { + for (auto &child : function.GetArgumentsMutable()) { + if (!TransformPivotInList(child.GetExpressionMutable(), entry)) { + entry.values.resize(initial_size); return false; } } @@ -730,6 +719,17 @@ bool PEGTransformerFactory::TransformPivotInList(unique_ptr &e } } +static bool PivotEntryIsTuple(const PivotColumnEntry &entry) { + if (entry.values.size() > 1) { + return true; + } + if (!entry.expr || entry.expr->GetExpressionType() != ExpressionType::FUNCTION) { + return false; + } + auto &function = entry.expr->Cast(); + return function.FunctionName() == "row"; +} + vector PEGTransformerFactory::TransformUnpivotTargetList(PEGTransformer &transformer, 
ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); @@ -757,7 +757,7 @@ unique_ptr PEGTransformerFactory::TransformTablePivotClause(PEGTransfo for (auto pivot_value : pivot_values_list.GetChildren()) { result->pivots.push_back(transformer.Transform(pivot_value)); } - transformer.TransformOptional>(inner_list, 3, result->groups); + transformer.TransformOptional>(inner_list, 3, result->groups); TableAlias table_alias; transformer.TransformOptional(list_pr, 2, table_alias); result->alias = table_alias.name; @@ -792,7 +792,7 @@ PivotColumn PEGTransformerFactory::TransformPivotValueList(PEGTransformer &trans return result; } auto &func_expr = pivot_expression->Cast(); - if (func_expr.function_name != "row") { + if (func_expr.FunctionName() != "row") { result.pivot_expressions.push_back(std::move(pivot_expression)); return result; } @@ -801,13 +801,15 @@ PivotColumn PEGTransformerFactory::TransformPivotValueList(PEGTransformer &trans // so pivot_expressions.size() matches entry.values.size() (both 1). bool has_tuple_entries = false; for (auto &entry : result.entries) { - if (entry.values.size() > 1) { + if (PivotEntryIsTuple(entry)) { has_tuple_entries = true; break; } } if (has_tuple_entries) { - result.pivot_expressions = std::move(func_expr.children); + for (auto &child : func_expr.GetArgumentsMutable()) { + result.pivot_expressions.emplace_back(std::move(child.GetExpressionMutable())); + } } else { result.pivot_expressions.push_back(std::move(pivot_expression)); } @@ -957,8 +959,7 @@ unique_ptr PEGTransformerFactory::TransformTableFunctionLateralOpt(PEG auto result = make_uniq(); auto qualified_table_function = transformer.Transform(list_pr.Child(1)); - auto table_function_arguments = - transformer.Transform>>(list_pr.Child(2)); + auto table_function_arguments = transformer.Transform>(list_pr.Child(2)); result->with_ordinality = list_pr.Child(3).HasResult() ? 
OrdinalityType::WITH_ORDINALITY : OrdinalityType::WITHOUT_ORDINALITY; result->function = @@ -980,8 +981,7 @@ unique_ptr PEGTransformerFactory::TransformTableFunctionAliasColon(PEG auto table_alias = transformer.Transform(list_pr.Child(0)); auto qualified_table_function = transformer.Transform(list_pr.Child(1)); - auto table_function_arguments = - transformer.Transform>>(list_pr.Child(2)); + auto table_function_arguments = transformer.Transform>(list_pr.Child(2)); auto result = make_uniq(); result->with_ordinality = list_pr.Child(3).HasResult() ? OrdinalityType::WITH_ORDINALITY @@ -989,7 +989,7 @@ unique_ptr PEGTransformerFactory::TransformTableFunctionAliasColon(PEG result->function = make_uniq(qualified_table_function.catalog, qualified_table_function.schema, qualified_table_function.name, std::move(table_function_arguments)); - result->alias = table_alias; + result->alias = Identifier(table_alias); transformer.TransformOptional>(list_pr, 4, result->sample); return std::move(result); } @@ -1005,13 +1005,13 @@ QualifiedName PEGTransformerFactory::TransformQualifiedTableFunction(PEGTransfor QualifiedName result; auto &opt_catalog = list_pr.Child(0); if (opt_catalog.HasResult()) { - result.catalog = transformer.Transform(opt_catalog.GetResult()); + result.catalog = transformer.Transform(opt_catalog.GetResult()); } else { result.catalog = INVALID_CATALOG; } auto &opt_schema = list_pr.Child(1); if (opt_schema.HasResult()) { - result.schema = transformer.Transform(opt_schema.GetResult()); + result.schema = transformer.Transform(opt_schema.GetResult()); } else { result.schema = INVALID_SCHEMA; } @@ -1051,7 +1051,7 @@ unique_ptr PEGTransformerFactory::TransformBaseTableRef(PEGTransformer auto result = transformer.Transform>(list_pr.Child(1)); auto &table_alias_colon_opt = list_pr.Child(0); if (table_alias_colon_opt.HasResult()) { - result->alias = transformer.Transform(table_alias_colon_opt.GetResult()); + result->alias = 
transformer.Transform(table_alias_colon_opt.GetResult()); } auto &table_alias_opt = list_pr.Child(2); if (table_alias_opt.HasResult() && table_alias_colon_opt.HasResult()) { @@ -1087,7 +1087,7 @@ unique_ptr PEGTransformerFactory::TransformParensTableRef(PEGTransform select_statement->node = std::move(select_node); auto subquery = make_uniq(std::move(select_statement)); if (table_alias_colon_opt.HasResult()) { - subquery->alias = transformer.Transform(table_alias_colon_opt.GetResult()); + subquery->alias = transformer.Transform(table_alias_colon_opt.GetResult()); } if (table_alias_opt.HasResult()) { auto table_alias = transformer.Transform(table_alias_opt.GetResult()); @@ -1225,7 +1225,7 @@ vector PEGTransformerFactory::TransformOrderByAll(PEGTransformer &t order_by_null_type = transformer.Transform(order_by_null_pr.GetResult()); } auto star_expr = make_uniq(); - star_expr->columns = true; + star_expr->IsColumnsMutable() = true; result.push_back(OrderByNode(order_type, order_by_null_type, std::move(star_expr))); return result; } @@ -1281,13 +1281,8 @@ unique_ptr PEGTransformerFactory::VerifyLimitOffset(LimitPercent if (offset.is_percent) { throw ParserException("Percentage for offsets are not supported."); } - if (limit.is_percent) { - auto result = make_uniq(); - result->limit = std::move(limit.expression); - result->offset = std::move(offset.expression); - return std::move(result); - } auto result = make_uniq(); + result->limit_type = limit.is_percent ? 
LimitValueType::PERCENTAGE : LimitValueType::ROW_COUNT; if (limit.expression) { result->limit = std::move(limit.expression); } @@ -1409,9 +1404,9 @@ void PEGTransformerFactory::AddGroupByExpression(unique_ptr ex GroupByNode &result, vector &result_set) { if (expression->GetExpressionType() == ExpressionType::FUNCTION) { auto &func = expression->Cast(); - if (func.function_name == "row") { - for (auto &child : func.children) { - AddGroupByExpression(std::move(child), map, result, result_set); + if (func.FunctionName() == "row") { + for (auto &child : func.GetArgumentsMutable()) { + AddGroupByExpression(std::move(child.GetExpressionMutable()), map, result, result_set); } return; } @@ -1584,17 +1579,17 @@ CommonTableExpressionMap PEGTransformerFactory::TransformWithClause(PEGTransform throw ParserException("Recursive CTEs with DML statements are not supported"); } // Now safe to call on SELECT, VALUES, etc. - query_node = ToRecursiveCTE(std::move(query_node), with_entry.first, with_entry.second->aliases, + query_node = ToRecursiveCTE(std::move(query_node), Identifier(with_entry.first), with_entry.second->aliases, with_entry.second->key_targets); } auto cte_name = string(with_entry.first); - auto it = result.map.find(cte_name); + auto it = result.map.find(Identifier(cte_name)); if (it != result.map.end()) { // can't have two CTEs with same name throw ParserException("Duplicate CTE name \"%s\"", cte_name); } - result.map.insert(with_entry.first, std::move(with_entry.second)); + result.map.insert(Identifier(with_entry.first), std::move(with_entry.second)); } return result; } @@ -1604,7 +1599,7 @@ PEGTransformerFactory::TransformWithStatement(PEGTransformer &transformer, Parse auto &list_pr = parse_result.Cast(); auto result = make_uniq(); auto cte_name = transformer.Transform(list_pr.Child(0)); - transformer.TransformOptional>(list_pr, 1, result->aliases); + transformer.TransformOptional>(list_pr, 1, result->aliases); transformer.TransformOptional>>(list_pr, 2, 
result->key_targets); auto &materialized_opt = list_pr.Child(4); if (materialized_opt.HasResult()) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_set.cpp b/src/duckdb/src/parser/peg/transformer/transform_set.cpp index bcfa213a0..6890ef587 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_set.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_set.cpp @@ -4,146 +4,141 @@ namespace duckdb { -// ResetStatement <- 'RESET' (SetVariable / SetSetting) +// ResetStatement <- 'RESET' SetVariableOrSetting unique_ptr PEGTransformerFactory::TransformResetStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &child_pr = list_pr.Child(1); - auto &choice_pr = child_pr.Child(0); - - SettingInfo setting_info = transformer.Transform(choice_pr.GetResult()); - if (setting_info.scope == SetScope::LOCAL) { + const SettingInfo &set_variable_or_setting) { + if (set_variable_or_setting.scope == SetScope::LOCAL) { throw NotImplementedException("RESET LOCAL is not implemented."); } - return make_uniq(setting_info.name, setting_info.scope); + return make_uniq(set_variable_or_setting.name, set_variable_or_setting.scope); } // SetAssignment <- VariableAssign VariableList -vector> PEGTransformerFactory::TransformSetAssignment(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>>(list_pr, 1); +vector> +PEGTransformerFactory::TransformSetAssignment(PEGTransformer &transformer, + vector> variable_list) { + return variable_list; } // SetSetting <- SettingScope? 
SettingName -SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &optional_scope_pr = list_pr.Child(0); - +SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer, const SetScope &setting_scope, + const Identifier &setting_name) { SettingInfo result; - result.name = list_pr.Child(1).identifier; - if (optional_scope_pr.HasResult()) { - auto &setting_scope = optional_scope_pr.GetResult().Cast(); - auto &scope_value = setting_scope.Child(0); - result.scope = transformer.TransformEnum(scope_value); - } + result.name = setting_name; + result.scope = setting_scope; return result; } -// SetStatement <- 'SET' (StandardAssignment / SetTimeZone) -unique_ptr PEGTransformerFactory::TransformSetStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &child_pr = list_pr.Child(1); - return transformer.Transform>(child_pr.Child(0).GetResult()); +// SetStatement <- 'SET' SetAssignmentOrTimeZone +unique_ptr +PEGTransformerFactory::TransformSetStatement(PEGTransformer &transformer, + unique_ptr set_assignment_or_time_zone) { + return std::move(set_assignment_or_time_zone); } -// ZoneIntervalWithInterval <- 'INTERVAL' StringLiteral Interval? -unique_ptr PEGTransformerFactory::TransformZoneIntervalWithInterval(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // child 0 = 'INTERVAL' keyword, child 1 = StringLiteral, child 2 = Interval? 
- auto &str_pr = list_pr.Child(1); - auto expr = make_uniq(Value(str_pr.result)); - return make_uniq(LogicalType::INTERVAL, std::move(expr)); +// ZoneLocal <- 'LOCAL' +unique_ptr PEGTransformerFactory::TransformZoneLocal(PEGTransformer &transformer) { + return make_uniq(); } -// ZoneIntervalWithPrecision <- 'INTERVAL' Parens(NumberLiteral) StringLiteral -unique_ptr PEGTransformerFactory::TransformZoneIntervalWithPrecision(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - // child 0 = 'INTERVAL' keyword, child 1 = Parens(NumberLiteral), child 2 = StringLiteral - auto &str_pr = list_pr.Child(2); - auto expr = make_uniq(Value(str_pr.result)); - return make_uniq(LogicalType::INTERVAL, std::move(expr)); +// ZoneDefault <- 'DEFAULT' +unique_ptr PEGTransformerFactory::TransformZoneDefault(PEGTransformer &transformer) { + return make_uniq(); } -// ZoneValue <- ZoneIntervalWithPrecision / ZoneIntervalWithInterval / StringLiteral / Identifier / NumberLiteral / -// 'DEFAULT' / 'LOCAL' -unique_ptr PEGTransformerFactory::TransformZoneValue(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &choice_pr = list_pr.Child(0); - if (choice_pr.GetResult().type == ParseResultType::STRING) { - return make_uniq(Value(choice_pr.GetResult().Cast().result)); - } - if (choice_pr.GetResult().type == ParseResultType::IDENTIFIER) { - return make_uniq(Value(choice_pr.GetResult().Cast().identifier)); - } - const auto &name = choice_pr.GetResult().name; - if (name == "ZoneIntervalWithPrecision" || name == "ZoneIntervalWithInterval" || name == "NumberLiteral") { - return transformer.Transform>(choice_pr.GetResult()); - } - return make_uniq(); +// ZoneStringLiteral <- StringLiteral +unique_ptr PEGTransformerFactory::TransformZoneStringLiteral(PEGTransformer &transformer, + const string &string_literal) { + return make_uniq(Value(string_literal)); +} + +// ZoneIdentifier <- Identifier +unique_ptr 
PEGTransformerFactory::TransformZoneIdentifier(PEGTransformer &transformer, + const Identifier &identifier) { + return make_uniq(Value(identifier)); } // SetTimeZone <- 'TIME' 'ZONE' ZoneValue unique_ptr PEGTransformerFactory::TransformSetTimeZone(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr = transformer.Transform>(list_pr.Child(2)); - if (expr->GetExpressionClass() == ExpressionClass::DEFAULT) { + unique_ptr zone_value) { + if (zone_value->GetExpressionClass() == ExpressionClass::DEFAULT) { return make_uniq("timezone", SetScope::AUTOMATIC); } - return make_uniq("timezone", std::move(expr), SetScope::AUTOMATIC); + return make_uniq("timezone", std::move(zone_value), SetScope::AUTOMATIC); } // SetVariable <- VariableScope Identifier -SettingInfo PEGTransformerFactory::TransformSetVariable(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - +SettingInfo PEGTransformerFactory::TransformSetVariable(PEGTransformer &transformer, const SetScope &variable_scope, + const Identifier &identifier) { SettingInfo result; - result.scope = transformer.TransformEnum(list_pr.Child(0)); - result.name = list_pr.Child(1).identifier; + result.name = identifier; + result.scope = variable_scope; return result; } -// StandardAssignment <- (SetVariable / SetSetting) SetAssignment -unique_ptr PEGTransformerFactory::TransformStandardAssignment(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &first_sub_rule = list_pr.Child(0); - - auto &setting_or_var_pr = first_sub_rule.Child(0); - SettingInfo setting_info = transformer.Transform(setting_or_var_pr.GetResult()); - if (setting_info.scope == SetScope::LOCAL) { +// StandardAssignment <- SetVariableOrSetting SetAssignment +unique_ptr +PEGTransformerFactory::TransformStandardAssignment(PEGTransformer &transformer, + const SettingInfo &set_variable_or_setting, + vector> 
set_assignment) { + if (set_variable_or_setting.scope == SetScope::LOCAL) { throw NotImplementedException("SET LOCAL is not implemented."); } - auto &set_assignment_pr = list_pr.Child(1); - auto values = transformer.Transform>>(set_assignment_pr); - if (values.size() > 1) { + if (set_assignment.size() > 1) { throw ParserException("SET can only contain a single value"); } - auto value = std::move(values[0]); + auto value = std::move(set_assignment[0]); if (value->GetExpressionClass() == ExpressionClass::COLUMN_REF) { // SET value cannot be a column reference auto &col_ref = value->Cast(); value = make_uniq(col_ref.GetColumnName()); } else if (value->GetExpressionClass() == ExpressionClass::DEFAULT) { - return make_uniq(setting_info.name, setting_info.scope); + return make_uniq(set_variable_or_setting.name, set_variable_or_setting.scope); } - return make_uniq(setting_info.name, std::move(value), setting_info.scope); + return make_uniq(set_variable_or_setting.name, std::move(value), + set_variable_or_setting.scope); } // VariableList <- List(Expression) -vector> PEGTransformerFactory::TransformVariableList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto expr_list = ExtractParseResultsFromList(list_pr.Child(0)); - vector> expressions; - for (auto &expr : expr_list) { - expressions.push_back(transformer.Transform>(expr)); - } - return expressions; +vector> +PEGTransformerFactory::TransformVariableList(PEGTransformer &transformer, + vector> expression) { + return expression; } + +// VariableScope <- 'VARIABLE' +SetScope PEGTransformerFactory::TransformVariableScope(PEGTransformer &transformer) { + return SetScope::VARIABLE; +} + +// LocalScope <- 'LOCAL' +SetScope PEGTransformerFactory::TransformLocalScope(PEGTransformer &transformer) { + return SetScope::LOCAL; +} + +// SessionScope <- 'SESSION' +SetScope PEGTransformerFactory::TransformSessionScope(PEGTransformer &transformer) { + return SetScope::SESSION; +} + 
+// GlobalScope <- 'GLOBAL' +SetScope PEGTransformerFactory::TransformGlobalScope(PEGTransformer &transformer) { + return SetScope::GLOBAL; +} + +// ZoneIntervalWithInterval <- 'INTERVAL' StringLiteral Interval? +unique_ptr +PEGTransformerFactory::TransformZoneIntervalWithInterval(PEGTransformer &transformer, const string &string_literal, + const DatePartSpecifier &interval) { + auto expr = make_uniq(Value(string_literal)); + return make_uniq(LogicalType::INTERVAL, std::move(expr)); +} + +// ZoneIntervalWithPrecision <- 'INTERVAL' Parens(NumberLiteral) StringLiteral +unique_ptr PEGTransformerFactory::TransformZoneIntervalWithPrecision( + PEGTransformer &transformer, unique_ptr number_literal, const string &string_literal) { + auto expr = make_uniq(Value(string_literal)); + return make_uniq(LogicalType::INTERVAL, std::move(expr)); +} + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp b/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp index fedb2e773..f852077a9 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp @@ -4,28 +4,32 @@ namespace duckdb { unique_ptr -PEGTransformerFactory::TransformBeginTransaction(const TransactionModifierType &read_or_write) { +PEGTransformerFactory::TransformBeginTransaction(PEGTransformer &transformer, + const TransactionModifierType &read_or_write) { auto info = make_uniq(TransactionType::BEGIN_TRANSACTION); info->modifier = read_or_write; return make_uniq(std::move(info)); } TransactionModifierType -PEGTransformerFactory::TransformReadOrWrite(const TransactionModifierType &read_only_or_read_write) { +PEGTransformerFactory::TransformReadOrWrite(PEGTransformer &transformer, + const TransactionModifierType &read_only_or_read_write) { return read_only_or_read_write; } -TransactionModifierType PEGTransformerFactory::TransformReadOnlyOrReadWrite(PEGTransformer &transformer, - 
ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.TransformEnum(list_pr.Child(0).GetResult()); +TransactionModifierType PEGTransformerFactory::TransformReadOnly(PEGTransformer &transformer) { + return TransactionModifierType::TRANSACTION_READ_ONLY; } -unique_ptr PEGTransformerFactory::TransformCommitTransaction() { +TransactionModifierType PEGTransformerFactory::TransformReadWrite(PEGTransformer &transformer) { + return TransactionModifierType::TRANSACTION_READ_WRITE; +} + +unique_ptr PEGTransformerFactory::TransformCommitTransaction(PEGTransformer &transformer) { return make_uniq(make_uniq(TransactionType::COMMIT)); } -unique_ptr PEGTransformerFactory::TransformRollbackTransaction() { +unique_ptr PEGTransformerFactory::TransformRollbackTransaction(PEGTransformer &transformer) { return make_uniq(make_uniq(TransactionType::ROLLBACK)); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_update.cpp b/src/duckdb/src/parser/peg/transformer/transform_update.cpp index 12510f000..e3e5dd239 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_update.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_update.cpp @@ -1,123 +1,96 @@ -#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/query_node/update_query_node.hpp" +#include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/parser/peg/transformer/peg_transformer.hpp" namespace duckdb { -unique_ptr PEGTransformerFactory::TransformUpdateStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformUpdateStatement( + PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr update_target, + unique_ptr update_set_clause, unique_ptr from_clause, + unique_ptr where_clause, vector> returning_clause) { auto result = make_uniq(); auto 
&node = *result->node; - auto &with_opt = list_pr.Child(0); - if (with_opt.HasResult()) { - node.cte_map = transformer.Transform(with_opt.GetResult()); - } - node.table = transformer.Transform>(list_pr.Child(2)); - node.set_info = transformer.Transform>(list_pr.Child(3)); - transformer.TransformOptional>(list_pr, 4, node.from_table); - transformer.TransformOptional>(list_pr, 5, node.set_info->condition); - transformer.TransformOptional>>(list_pr, 6, node.returning_list); + node.cte_map = std::move(with_clause); + node.table = std::move(update_target); + node.set_info = std::move(update_set_clause); + node.from_table = std::move(from_clause); + node.set_info->condition = std::move(where_clause); + node.returning_list = std::move(returning_clause); return std::move(result); } -unique_ptr PEGTransformerFactory::TransformUpdateTarget(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); -} - unique_ptr PEGTransformerFactory::TransformBaseTableSet(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0)); + unique_ptr base_table_name) { + return std::move(base_table_name); } unique_ptr PEGTransformerFactory::TransformBaseTableAliasSet(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto table_ref = transformer.Transform>(list_pr.Child(0)); - transformer.TransformOptional(list_pr, 1, table_ref->alias); - return std::move(table_ref); + unique_ptr base_table_name, + const Identifier &update_alias) { + base_table_name->alias = update_alias; + return std::move(base_table_name); } -string PEGTransformerFactory::TransformUpdateAlias(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(1)); -} - -unique_ptr 
PEGTransformerFactory::TransformUpdateSetClause(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform>(list_pr.Child(0).GetResult()); +Identifier PEGTransformerFactory::TransformUpdateAlias(PEGTransformer &transformer, const Identifier &col_id) { + return Identifier(col_id); } unique_ptr PEGTransformerFactory::TransformUpdateSetTuple(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto column_list = ExtractParseResultsFromList(extract_parens); + const vector &column_name, + unique_ptr expression) { auto result = make_uniq(); - result->columns.reserve(column_list.size()); - for (auto column : column_list) { - result->columns.push_back(column.get().Cast().identifier); - } + result->columns = column_name; - auto expr = transformer.Transform>(list_pr.Child(2)); bool is_row_assignment = false; - if (expr->GetExpressionClass() == ExpressionClass::FUNCTION) { - auto &func_ref = expr->Cast(); - if (StringUtil::CIEquals(func_ref.function_name, "row")) { + if (expression->GetExpressionClass() == ExpressionClass::FUNCTION) { + auto &func_ref = expression->Cast(); + if (func_ref.FunctionName() == "row") { is_row_assignment = true; } } if (is_row_assignment) { - auto &func_expr = expr->Cast(); - if (func_expr.children.size() != result->columns.size()) { + auto &func_expr = expression->Cast(); + if (func_expr.GetArguments().size() != result->columns.size()) { throw ParserException("Could not perform assignment, expected %d values, got %d", result->columns.size(), - func_expr.children.size()); + func_expr.GetArguments().size()); + } + for (auto &arg : func_expr.GetArgumentsMutable()) { + result->expressions.push_back(std::move(arg.GetExpressionMutable())); } - result->expressions = std::move(func_expr.children); } else { result->expressions.reserve(result->columns.size()); for (idx_t i 
= 0; i < result->columns.size(); i++) { - result->expressions.push_back(expr->Copy()); + result->expressions.push_back(expression->Copy()); } } return result; } -unique_ptr PEGTransformerFactory::TransformUpdateSetElementList(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr PEGTransformerFactory::TransformUpdateSetElementList( + PEGTransformer &transformer, vector>> update_set_element) { auto result = make_uniq(); - auto update_element_list = ExtractParseResultsFromList(list_pr.Child(0)); - for (auto &update_element : update_element_list) { - auto column_expr = transformer.Transform>>(update_element); - result->columns.push_back(column_expr.first); - result->expressions.push_back(std::move(column_expr.second)); + for (auto &element : update_set_element) { + result->columns.emplace_back(std::move(element.first)); + result->expressions.push_back(std::move(element.second)); } return result; } -pair> PEGTransformerFactory::TransformUpdateSetElement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto column_name = transformer.Transform(list_pr.Child(0)); - auto expr = transformer.Transform>(list_pr.Child(2)); - return make_pair(column_name, std::move(expr)); +pair> +PEGTransformerFactory::TransformUpdateSetElement(PEGTransformer &transformer, const string &update_set_column_target, + unique_ptr expression) { + return {update_set_column_target, std::move(expression)}; } -// UpdateSetColumnTarget <- ColumnName ('.' 
Identifier)* -string PEGTransformerFactory::TransformUpdateSetColumnTarget(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto column_name = list_pr.Child(0).identifier; - auto &extra_opt = list_pr.Child(1); - if (extra_opt.HasResult()) { +string PEGTransformerFactory::TransformUpdateSetColumnTarget(PEGTransformer &transformer, const Identifier &column_name, + const vector &dot_identifier) { + if (!dot_identifier.empty()) { throw ParserException("Qualified column names in UPDATE .. SET not supported"); } - return column_name; + return column_name.GetIdentifierName(); } + } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_use.cpp b/src/duckdb/src/parser/peg/transformer/transform_use.cpp index ac2bc4dfc..061f815ab 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_use.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_use.cpp @@ -4,10 +4,11 @@ namespace duckdb { // UseStatement <- 'USE' UseTarget -unique_ptr PEGTransformerFactory::TransformUseStatement(const QualifiedName &use_target) { +unique_ptr PEGTransformerFactory::TransformUseStatement(PEGTransformer &transformer, + const QualifiedName &use_target) { string value_str; if (IsInvalidSchema(use_target.schema)) { - value_str = SQLIdentifier::ToString(use_target.name); + value_str = SQLIdentifier::ToString(use_target.name.GetIdentifierName()); } else { value_str = SQLIdentifier(use_target.schema) + "." 
+ SQLIdentifier(use_target.name); } @@ -17,30 +18,36 @@ unique_ptr PEGTransformerFactory::TransformUseStatement(const Qual } // UseTarget <- UseTargetCatalogSchema / SchemaName / CatalogName -QualifiedName PEGTransformerFactory::TransformUseTarget(PEGTransformer &transformer, ParseResult &pr) { - if (pr.type == ParseResultType::IDENTIFIER) { - QualifiedName result; - result.name = pr.Cast().identifier; - return result; - } - return transformer.Transform(pr); +QualifiedName PEGTransformerFactory::TransformSchemaNameAsUseTarget(PEGTransformer &transformer, + const Identifier &schema_name) { + QualifiedName result; + result.name = schema_name; + return result; +} + +QualifiedName PEGTransformerFactory::TransformCatalogNameAsUseTarget(PEGTransformer &transformer, + const Identifier &catalog_name) { + QualifiedName result; + result.name = catalog_name; + return result; } // UseTargetCatalogSchema <- CatalogName '.' ReservedSchemaName DotIdentifier* -QualifiedName PEGTransformerFactory::TransformUseTargetCatalogSchema(const string &catalog_name, - const string &reserved_schema_name, - const vector &dot_identifier) { +QualifiedName PEGTransformerFactory::TransformUseTargetCatalogSchema(PEGTransformer &transformer, + const Identifier &catalog_name, + const Identifier &reserved_schema_name, + const vector &dot_identifier) { if (!dot_identifier.empty()) { throw ParserException("Expected \"USE database\" or \"USE database.schema\""); } QualifiedName result; - result.catalog = INVALID_CATALOG; + result.catalog = Identifier::InvalidCatalog(); result.schema = catalog_name; result.name = reserved_schema_name; return result; } -string PEGTransformerFactory::TransformDotIdentifier(const string &identifier) { +Identifier PEGTransformerFactory::TransformDotIdentifier(PEGTransformer &transformer, const Identifier &identifier) { return identifier; } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp 
b/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp index c62ca358d..7d0139980 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp @@ -4,90 +4,77 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformVacuumStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - VacuumOptions options; - transformer.TransformOptional(list_pr, 1, options); - auto result = make_uniq(options); - auto &target_opt = list_pr.Child(2); - if (target_opt.HasResult()) { - auto target = transformer.Transform(target_opt.GetResult()); - result->info->columns = target.columns; - result->info->ref = std::move(target.ref); + const VacuumOptions &vacuum_options, + AnalyzeTarget analyze_target) { + auto result = make_uniq(vacuum_options); + if (analyze_target.ref) { + result->info->columns = analyze_target.columns; + result->info->ref = std::move(analyze_target.ref); result->info->has_table = true; } return std::move(result); } -VacuumOptions PEGTransformerFactory::TransformVacuumOptions(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - return transformer.Transform(list_pr.Child(0).GetResult()); -} - -VacuumOptions PEGTransformerFactory::TransformVacuumLegacyOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +VacuumOptions PEGTransformerFactory::TransformVacuumLegacyOptions(PEGTransformer &transformer, const string &opt_full, + const string &opt_freeze, const string &opt_verbose, + const string &opt_analyze) { VacuumOptions options; options.vacuum = true; - options.analyze = list_pr.Child(3).HasResult(); - if (list_pr.Child(0).HasResult()) { + options.analyze = !opt_analyze.empty(); + if (!opt_full.empty()) { throw NotImplementedException("FULL is not yet implemented"); } - if (list_pr.Child(1).HasResult()) { + if (!opt_freeze.empty()) { 
throw NotImplementedException("FREEZE is not yet implemented"); } - if (list_pr.Child(2).HasResult()) { + if (!opt_verbose.empty()) { throw NotImplementedException("VERBOSE is not yet implemented"); } return options; } VacuumOptions PEGTransformerFactory::TransformVacuumParensOptions(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto option_list = ExtractParseResultsFromList(extract_parens); + const vector &vacuum_option) { VacuumOptions options; options.vacuum = true; - for (auto &option : option_list) { - string option_name = transformer.Transform(option); - if (StringUtil::CIEquals(option_name, "disable_page_skipping")) { + for (auto &option : vacuum_option) { + if (StringUtil::CIEquals(option, "disable_page_skipping")) { throw NotImplementedException("Disable Page Skipping vacuum option"); } - if (StringUtil::CIEquals(option_name, "freeze")) { + if (StringUtil::CIEquals(option, "freeze")) { throw NotImplementedException("FREEZE is not yet implemented"); } - if (StringUtil::CIEquals(option_name, "full")) { + if (StringUtil::CIEquals(option, "full")) { throw NotImplementedException("FULL is not yet implemented"); } - if (StringUtil::CIEquals(option_name, "verbose")) { + if (StringUtil::CIEquals(option, "verbose")) { throw NotImplementedException("VERBOSE is not yet implemented"); } - if (StringUtil::CIEquals(option_name, "analyze")) { + if (StringUtil::CIEquals(option, "analyze")) { options.analyze = true; } } return options; } -string PEGTransformerFactory::TransformVacuumOption(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - auto &option = list_pr.Child(0).GetResult(); - if (option.type == ParseResultType::IDENTIFIER) { - return option.Cast().identifier; - } - return transformer.TransformEnum(option); +vector PEGTransformerFactory::TransformNameList(PEGTransformer &transformer, const vector 
&col_id) { + return IdentifiersToStrings(col_id); } -vector PEGTransformerFactory::TransformNameList(PEGTransformer &transformer, ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); - vector result; - auto &extract_parens = ExtractResultFromParens(list_pr.Child(0)); - auto colid_list = ExtractParseResultsFromList(extract_parens); - for (auto &colid : colid_list) { - result.push_back(transformer.Transform(colid)); - } - return result; +string PEGTransformerFactory::TransformOptAnalyze(PEGTransformer &transformer) { + return "analyze"; +} + +string PEGTransformerFactory::TransformOptFull(PEGTransformer &transformer) { + return "full"; +} + +string PEGTransformerFactory::TransformOptFreeze(PEGTransformer &transformer) { + return "freeze"; +} + +string PEGTransformerFactory::TransformOptVerbose(PEGTransformer &transformer) { + return "verbose"; } } // namespace duckdb diff --git a/src/duckdb/src/parser/qualified_name.cpp b/src/duckdb/src/parser/qualified_name.cpp index fcb239ccf..d8b94dacb 100644 --- a/src/duckdb/src/parser/qualified_name.cpp +++ b/src/duckdb/src/parser/qualified_name.cpp @@ -8,8 +8,8 @@ string QualifiedName::ToString() const { return ParseInfo::QualifierToString(catalog, schema, name); } -vector QualifiedName::ParseComponents(const string &input) { - vector result; +vector QualifiedName::ParseComponents(const string &input) { + vector result; idx_t idx = 0; string entry; @@ -26,7 +26,7 @@ vector QualifiedName::ParseComponents(const string &input) { } goto end; separator: - result.push_back(entry); + result.push_back(Identifier(entry)); entry = ""; idx++; goto normal; @@ -43,15 +43,15 @@ vector QualifiedName::ParseComponents(const string &input) { throw ParserException("Unterminated quote in qualified name! 
(input: %s)", input); end: if (!entry.empty()) { - result.push_back(entry); + result.push_back(Identifier(entry)); } return result; } QualifiedName QualifiedName::Parse(const string &input) { - string catalog; - string schema; - string name; + Identifier catalog; + Identifier schema; + Identifier name; auto entries = ParseComponents(input); if (entries.empty()) { @@ -78,12 +78,12 @@ QualifiedName QualifiedName::Parse(const string &input) { QualifiedColumnName::QualifiedColumnName() { } -QualifiedColumnName::QualifiedColumnName(string column_p) : column(std::move(column_p)) { +QualifiedColumnName::QualifiedColumnName(Identifier column_p) : column(std::move(column_p)) { } -QualifiedColumnName::QualifiedColumnName(string table_p, string column_p) +QualifiedColumnName::QualifiedColumnName(Identifier table_p, Identifier column_p) : table(std::move(table_p)), column(std::move(column_p)) { } -QualifiedColumnName::QualifiedColumnName(const BindingAlias &alias, string column_p) +QualifiedColumnName::QualifiedColumnName(const BindingAlias &alias, Identifier column_p) : catalog(alias.GetCatalog()), schema(alias.GetSchema()), table(alias.GetAlias()), column(std::move(column_p)) { } @@ -133,8 +133,7 @@ bool QualifiedColumnName::IsQualified() const { } bool QualifiedColumnName::operator==(const QualifiedColumnName &rhs) const { - return StringUtil::CIEquals(catalog, rhs.catalog) && StringUtil::CIEquals(schema, rhs.schema) && - StringUtil::CIEquals(table, rhs.table) && StringUtil::CIEquals(column, rhs.column); + return catalog == rhs.catalog && schema == rhs.schema && table == rhs.table && column == rhs.column; } } // namespace duckdb diff --git a/src/duckdb/src/parser/query_node.cpp b/src/duckdb/src/parser/query_node.cpp index a3e3c8ba7..719f9c2e6 100644 --- a/src/duckdb/src/parser/query_node.cpp +++ b/src/duckdb/src/parser/query_node.cpp @@ -111,19 +111,15 @@ string QueryNode::ResultModifiersToString() const { } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER) { 
auto &limit_modifier = modifier.Cast(); if (limit_modifier.limit) { - result += " LIMIT " + limit_modifier.limit->ToString(); + if (limit_modifier.limit_type == LimitValueType::PERCENTAGE) { + result += " LIMIT (" + limit_modifier.limit->ToString() + ") %"; + } else { + result += " LIMIT " + limit_modifier.limit->ToString(); + } } if (limit_modifier.offset) { result += " OFFSET " + limit_modifier.offset->ToString(); } - } else if (modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { - auto &limit_p_modifier = modifier.Cast(); - if (limit_p_modifier.limit) { - result += " LIMIT (" + limit_p_modifier.limit->ToString() + ") %"; - } - if (limit_p_modifier.offset) { - result += " OFFSET " + limit_p_modifier.offset->ToString(); - } } } return result; @@ -173,7 +169,7 @@ bool QueryNode::Equals(const QueryNode *other) const { return false; } } - return other->type == type; + return true; } void QueryNode::CopyProperties(QueryNode &other) const { @@ -209,8 +205,7 @@ void QueryNode::AddDistinct() { // we have a DISTINCT without an ON clause - this distinct does not need to be added return; } - } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER || - modifier.type == ResultModifierType::LIMIT_PERCENT_MODIFIER) { + } else if (modifier.type == ResultModifierType::LIMIT_MODIFIER) { // we encountered a LIMIT or LIMIT PERCENT - these change the result of DISTINCT, so we do need to push a // DISTINCT relation break; diff --git a/src/duckdb/src/parser/query_node/insert_query_node.cpp b/src/duckdb/src/parser/query_node/insert_query_node.cpp index e8153181c..6b837b701 100644 --- a/src/duckdb/src/parser/query_node/insert_query_node.cpp +++ b/src/duckdb/src/parser/query_node/insert_query_node.cpp @@ -52,7 +52,7 @@ string InsertQueryNode::ToString() const { if (values_list) { D_ASSERT(!default_values); auto saved_alias = values_list->alias; - values_list->alias = string(); + values_list->alias = Identifier(string()); result += values_list->ToString(); 
values_list->alias = saved_alias; } else if (select_statement) { @@ -70,7 +70,7 @@ string InsertQueryNode::ToString() const { result += "("; auto &cols = conflict_info.indexed_columns; for (auto it = cols.begin(); it != cols.end();) { - result += StringUtil::Lower(*it); + result += StringUtil::Lower(it->GetIdentifierName()); if (++it != cols.end()) { result += ", "; } @@ -95,7 +95,7 @@ string InsertQueryNode::ToString() const { if (i) { result += ", "; } - result += StringUtil::Lower(column) + " = " + expr->ToString(); + result += StringUtil::Lower(column.GetIdentifierName()) + " = " + expr->ToString(); } // (optional) where clause if (set_info.condition) { diff --git a/src/duckdb/src/parser/query_node/merge_query_node.cpp b/src/duckdb/src/parser/query_node/merge_query_node.cpp new file mode 100644 index 000000000..5dd8f466f --- /dev/null +++ b/src/duckdb/src/parser/query_node/merge_query_node.cpp @@ -0,0 +1,135 @@ +#include "duckdb/parser/query_node/merge_query_node.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/keyword_helper.hpp" +#include "duckdb/common/string_util.hpp" + +namespace duckdb { + +MergeQueryNode::MergeQueryNode() : QueryNode(QueryNodeType::MERGE_QUERY_NODE) { +} + +string MergeQueryNode::ActionConditionToString(MergeActionCondition condition) { + switch (condition) { + case MergeActionCondition::WHEN_MATCHED: + return "WHEN MATCHED"; + case MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET: + return "WHEN NOT MATCHED"; + case MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE: + return "WHEN NOT MATCHED BY SOURCE"; + default: + throw InternalException("Unknown match condition"); + } +} + +string MergeQueryNode::ToString() const { + string result; + result = cte_map.ToString(); + result += "MERGE INTO "; + result += target->ToString(); + result += " USING "; + result += source->ToString(); + if (join_condition) { + result += " ON "; + result += join_condition->ToString(); + } else { + result += " USING ("; + 
for (idx_t c = 0; c < using_columns.size(); c++) { + if (c > 0) { + result += ", "; + } + result += using_columns[c].GetIdentifierName(); + } + result += ")"; + } + for (auto &entry : actions) { + for (auto &action : entry.second) { + result += " "; + result += MergeQueryNode::ActionConditionToString(entry.first); + result += " "; + result += action->ToString(); + } + } + if (!returning_list.empty()) { + result += " RETURNING "; + for (idx_t i = 0; i < returning_list.size(); i++) { + if (i > 0) { + result += ", "; + } + auto column = returning_list[i]->ToString(); + if (!returning_list[i]->GetAlias().empty()) { + column += StringUtil::Format(" AS %s", SQLIdentifier(returning_list[i]->GetAlias())); + } + result += column; + } + } + return result; +} + +bool MergeQueryNode::Equals(const QueryNode *other_p) const { + if (this == other_p) { + return true; + } + if (!QueryNode::Equals(other_p)) { + return false; + } + auto &other = other_p->Cast(); + if (!TableRef::Equals(target, other.target)) { + return false; + } + if (!TableRef::Equals(source, other.source)) { + return false; + } + if (!ParsedExpression::Equals(join_condition, other.join_condition)) { + return false; + } + if (using_columns != other.using_columns) { + return false; + } + if (actions.size() != other.actions.size()) { + return false; + } + for (auto &entry : actions) { + auto other_entry = other.actions.find(entry.first); + if (other_entry == other.actions.end()) { + return false; + } + if (entry.second.size() != other_entry->second.size()) { + return false; + } + for (idx_t i = 0; i < entry.second.size(); i++) { + if (!MergeIntoAction::Equals(*entry.second[i], *other_entry->second[i])) { + return false; + } + } + } + if (returning_list.size() != other.returning_list.size()) { + return false; + } + for (idx_t i = 0; i < returning_list.size(); i++) { + if (!ParsedExpression::Equals(returning_list[i], other.returning_list[i])) { + return false; + } + } + return true; +} + +unique_ptr 
MergeQueryNode::Copy() const { + auto result = make_uniq(); + result->target = target ? target->Copy() : nullptr; + result->source = source ? source->Copy() : nullptr; + result->join_condition = join_condition ? join_condition->Copy() : nullptr; + result->using_columns = using_columns; + for (auto &entry : actions) { + auto &action_list = result->actions[entry.first]; + for (auto &action : entry.second) { + action_list.push_back(action->Copy()); + } + } + for (auto &expr : returning_list) { + result->returning_list.push_back(expr->Copy()); + } + CopyProperties(*result); + return std::move(result); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/set_operation_node.cpp b/src/duckdb/src/parser/query_node/set_operation_node.cpp index 9ab492ed5..a6c3808ee 100644 --- a/src/duckdb/src/parser/query_node/set_operation_node.cpp +++ b/src/duckdb/src/parser/query_node/set_operation_node.cpp @@ -87,6 +87,11 @@ SetOperationNode::SetOperationNode(SetOperationType setop_type, unique_ptr SetOperationNode::SerializeChildNode(Serializer &serializer, idx_t index) const { @@ -127,7 +132,7 @@ unique_ptr SetOperationNode::SerializeChildNode(Serializer &serialize } bool SetOperationNode::SerializeChildList(Serializer &serializer) const { - return serializer.ShouldSerialize(7); + return serializer.ShouldSerialize(StorageVersion::V1_5_0); } } // namespace duckdb diff --git a/src/duckdb/src/parser/result_modifier.cpp b/src/duckdb/src/parser/result_modifier.cpp index e22108aea..308592726 100644 --- a/src/duckdb/src/parser/result_modifier.cpp +++ b/src/duckdb/src/parser/result_modifier.cpp @@ -12,6 +12,9 @@ bool LimitModifier::Equals(const ResultModifier &other_p) const { return false; } auto &other = other_p.Cast(); + if (limit_type != other.limit_type) { + return false; + } if (!ParsedExpression::Equals(limit, other.limit)) { return false; } @@ -23,6 +26,7 @@ bool LimitModifier::Equals(const ResultModifier &other_p) const { unique_ptr LimitModifier::Copy() const { 
auto copy = make_uniq(); + copy->limit_type = limit_type; if (limit) { copy->limit = limit->Copy(); } @@ -32,6 +36,21 @@ unique_ptr LimitModifier::Copy() const { return std::move(copy); } +bool LimitModifier::UseLegacySerialization() const { + return limit_type == LimitValueType::PERCENTAGE; +} + +void LimitModifier::LegacySerialize(Serializer &serializer) const { + LegacyLimitPercentModifier legacy_serialization; + if (limit) { + legacy_serialization.limit = limit->Copy(); + } + if (offset) { + legacy_serialization.offset = offset->Copy(); + } + legacy_serialization.Serialize(serializer); +} + bool DistinctModifier::Equals(const ResultModifier &other_p) const { if (!ResultModifier::Equals(other_p)) { return false; @@ -114,29 +133,22 @@ string OrderByNode::ToString() const { return str; } -bool LimitPercentModifier::Equals(const ResultModifier &other_p) const { - if (!ResultModifier::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - if (!ParsedExpression::Equals(limit, other.limit)) { - return false; - } - if (!ParsedExpression::Equals(offset, other.offset)) { - return false; - } - return true; +bool LegacyLimitPercentModifier::Equals(const ResultModifier &other) const { + throw InternalException("Legacy Limit Percent modifier should not be used anymore"); } -unique_ptr LimitPercentModifier::Copy() const { - auto copy = make_uniq(); - if (limit) { - copy->limit = limit->Copy(); - } - if (offset) { - copy->offset = offset->Copy(); - } - return std::move(copy); +unique_ptr LegacyLimitPercentModifier::Copy() const { + throw InternalException("Legacy Limit Percent modifier should not be used anymore"); +} + +unique_ptr +LegacyLimitPercentModifier::DeserializeLegacyLimitPercentModifier(unique_ptr limit_p, + unique_ptr offset_p) { + auto result = make_uniq(); + result->limit = std::move(limit_p); + result->offset = std::move(offset_p); + result->limit_type = LimitValueType::PERCENTAGE; + return std::move(result); } } // namespace duckdb diff --git 
a/src/duckdb/src/parser/statement/connect_statement.cpp b/src/duckdb/src/parser/statement/connect_statement.cpp new file mode 100644 index 000000000..224054c35 --- /dev/null +++ b/src/duckdb/src/parser/statement/connect_statement.cpp @@ -0,0 +1,19 @@ +#include "duckdb/parser/statement/connect_statement.hpp" + +namespace duckdb { + +ConnectStatement::ConnectStatement() : SQLStatement(StatementType::CONNECT_STATEMENT) { +} + +ConnectStatement::ConnectStatement(const ConnectStatement &other) : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr ConnectStatement::Copy() const { + return unique_ptr(new ConnectStatement(*this)); +} + +string ConnectStatement::ToString() const { + return info->ToString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/copy_database_statement.cpp b/src/duckdb/src/parser/statement/copy_database_statement.cpp index c3b8304ee..afd1c83ef 100644 --- a/src/duckdb/src/parser/statement/copy_database_statement.cpp +++ b/src/duckdb/src/parser/statement/copy_database_statement.cpp @@ -3,7 +3,8 @@ namespace duckdb { -CopyDatabaseStatement::CopyDatabaseStatement(string from_database_p, string to_database_p, CopyDatabaseType copy_type) +CopyDatabaseStatement::CopyDatabaseStatement(Identifier from_database_p, Identifier to_database_p, + CopyDatabaseType copy_type) : SQLStatement(StatementType::COPY_DATABASE_STATEMENT), from_database(std::move(from_database_p)), to_database(std::move(to_database_p)), copy_type(copy_type) { } diff --git a/src/duckdb/src/parser/statement/disconnect_statement.cpp b/src/duckdb/src/parser/statement/disconnect_statement.cpp new file mode 100644 index 000000000..db5dfe1f4 --- /dev/null +++ b/src/duckdb/src/parser/statement/disconnect_statement.cpp @@ -0,0 +1,20 @@ +#include "duckdb/parser/statement/disconnect_statement.hpp" + +namespace duckdb { + +DisconnectStatement::DisconnectStatement() : SQLStatement(StatementType::DISCONNECT_STATEMENT) { +} + 
+DisconnectStatement::DisconnectStatement(const DisconnectStatement &other) + : SQLStatement(other), info(other.info->Copy()) { +} + +unique_ptr DisconnectStatement::Copy() const { + return unique_ptr(new DisconnectStatement(*this)); +} + +string DisconnectStatement::ToString() const { + return info->ToString(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/statement/execute_statement.cpp b/src/duckdb/src/parser/statement/execute_statement.cpp index b052b71b9..6a4439708 100644 --- a/src/duckdb/src/parser/statement/execute_statement.cpp +++ b/src/duckdb/src/parser/statement/execute_statement.cpp @@ -22,7 +22,7 @@ string ExecuteStatement::ToString() const { if (!named_values.empty()) { vector stringified; for (auto &val : named_values) { - stringified.push_back(StringUtil::Format("\"%s\" := %s", val.first, val.second->ToString())); + stringified.push_back(StringUtil::Format("%s := %s", val.first, val.second->ToString())); } result += "(" + StringUtil::Join(stringified, ", ") + ")"; } diff --git a/src/duckdb/src/parser/statement/merge_into_statement.cpp b/src/duckdb/src/parser/statement/merge_into_statement.cpp index 393fe5bb7..3bd980eaf 100644 --- a/src/duckdb/src/parser/statement/merge_into_statement.cpp +++ b/src/duckdb/src/parser/statement/merge_into_statement.cpp @@ -1,86 +1,55 @@ #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" namespace duckdb { -MergeIntoStatement::MergeIntoStatement() : SQLStatement(StatementType::MERGE_INTO_STATEMENT) { +MergeIntoStatement::MergeIntoStatement() + : SQLStatement(StatementType::MERGE_INTO_STATEMENT), node(make_uniq()) { } -MergeIntoStatement::MergeIntoStatement(const MergeIntoStatement &other) : SQLStatement(other) { - target = other.target->Copy(); - source = other.source->Copy(); - join_condition = other.join_condition ? 
other.join_condition->Copy() : nullptr; - using_columns = other.using_columns; - for (auto &entry : other.actions) { - auto &action_list = actions[entry.first]; - for (auto &action : entry.second) { - action_list.push_back(action->Copy()); - } - } - for (auto &entry : other.returning_list) { - returning_list.push_back(entry->Copy()); - } - cte_map = other.cte_map.Copy(); +MergeIntoStatement::MergeIntoStatement(const MergeIntoStatement &other) + : SQLStatement(other), node(unique_ptr_cast(other.node->Copy())) { } -string MergeIntoStatement::ActionConditionToString(MergeActionCondition condition) { - switch (condition) { - case MergeActionCondition::WHEN_MATCHED: - return "WHEN MATCHED"; - case MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET: - return "WHEN NOT MATCHED"; - case MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE: - return "WHEN NOT MATCHED BY SOURCE"; - default: - throw InternalException("Unknown match condition"); - } +string MergeIntoStatement::ToString() const { + return node->ToString(); } -string MergeIntoStatement::ToString() const { - string result; - result = cte_map.ToString(); - result += "MERGE INTO "; - result += target->ToString(); - result += " USING "; - result += source->ToString(); - if (join_condition) { - result += " ON "; - result += join_condition->ToString(); - } else { - result += " USING ("; - for (idx_t c = 0; c < using_columns.size(); c++) { - if (c > 0) { - result += ", "; - } - result += using_columns[c]; - } - result += ")"; +unique_ptr MergeIntoStatement::Copy() const { + return unique_ptr(new MergeIntoStatement(*this)); +} + +bool MergeIntoAction::Equals(const MergeIntoAction &left, const MergeIntoAction &right) { + if (left.action_type != right.action_type) { + return false; } - for (auto &entry : actions) { - for (auto &action : entry.second) { - result += " "; - result += MergeIntoStatement::ActionConditionToString(entry.first); - result += " "; - result += action->ToString(); - } + if 
(!ParsedExpression::Equals(left.condition, right.condition)) { + return false; } - if (!returning_list.empty()) { - result += " RETURNING "; - for (idx_t i = 0; i < returning_list.size(); i++) { - if (i > 0) { - result += ", "; - } - auto column = returning_list[i]->ToString(); - if (!returning_list[i]->GetAlias().empty()) { - column += StringUtil::Format(" AS %s", SQLIdentifier(returning_list[i]->GetAlias())); - } - result += column; + if (!UpdateSetInfo::Equals(left.update_info, right.update_info)) { + return false; + } + if (left.insert_columns != right.insert_columns) { + return false; + } + if (left.column_order != right.column_order) { + return false; + } + if (left.default_values != right.default_values) { + return false; + } + if (left.exclude_columns != right.exclude_columns) { + return false; + } + if (left.expressions.size() != right.expressions.size()) { + return false; + } + for (idx_t i = 0; i < left.expressions.size(); i++) { + if (!ParsedExpression::Equals(left.expressions[i], right.expressions[i])) { + return false; } } - return result; -} - -unique_ptr MergeIntoStatement::Copy() const { - return unique_ptr(new MergeIntoStatement(*this)); + return true; } string MergeIntoAction::ToString() const { diff --git a/src/duckdb/src/parser/statement/set_statement.cpp b/src/duckdb/src/parser/statement/set_statement.cpp index 6c4138643..2d1b465a5 100644 --- a/src/duckdb/src/parser/statement/set_statement.cpp +++ b/src/duckdb/src/parser/statement/set_statement.cpp @@ -3,13 +3,13 @@ namespace duckdb { -SetStatement::SetStatement(string name_p, SetScope scope_p, SetType type_p) +SetStatement::SetStatement(Identifier name_p, SetScope scope_p, SetType type_p) : SQLStatement(StatementType::SET_STATEMENT), name(std::move(name_p)), scope(scope_p), set_type(type_p) { } // Set Variable -SetVariableStatement::SetVariableStatement(string name_p, unique_ptr value_p, SetScope scope_p) +SetVariableStatement::SetVariableStatement(Identifier name_p, unique_ptr value_p, 
SetScope scope_p) : SetStatement(std::move(name_p), scope_p, SetType::SET), value(std::move(value_p)) { } @@ -44,7 +44,7 @@ string SetVariableStatement::ToString() const { // Reset Variable -ResetVariableStatement::ResetVariableStatement(std::string name_p, SetScope scope_p) +ResetVariableStatement::ResetVariableStatement(Identifier name_p, SetScope scope_p) : SetStatement(std::move(name_p), scope_p, SetType::RESET) { } diff --git a/src/duckdb/src/parser/tableref.cpp b/src/duckdb/src/parser/tableref.cpp index 1b742cfac..807732c6c 100644 --- a/src/duckdb/src/parser/tableref.cpp +++ b/src/duckdb/src/parser/tableref.cpp @@ -8,14 +8,14 @@ namespace duckdb { string TableRef::BaseToString(string result) const { - vector column_name_alias; + vector column_name_alias; return BaseToString(std::move(result), column_name_alias); } -string TableRef::AliasToString(const vector &column_name_alias) const { +string TableRef::AliasToString(const vector &column_name_alias) const { string result; if (!alias.empty()) { - result += StringUtil::Format(" AS %s", SQLIdentifier(alias)); + result += StringUtil::Format(" AS %s", alias); } if (!column_name_alias.empty()) { D_ASSERT(!alias.empty()); @@ -43,7 +43,7 @@ string TableRef::SampleToString() const { return result; } -string TableRef::BaseToString(string result, const vector &column_name_alias) const { +string TableRef::BaseToString(string result, const vector &column_name_alias) const { result += AliasToString(column_name_alias); result += SampleToString(); return result; diff --git a/src/duckdb/src/parser/tableref/column_data_ref.cpp b/src/duckdb/src/parser/tableref/column_data_ref.cpp index 6336e9204..edaca66a4 100644 --- a/src/duckdb/src/parser/tableref/column_data_ref.cpp +++ b/src/duckdb/src/parser/tableref/column_data_ref.cpp @@ -3,7 +3,7 @@ namespace duckdb { -ColumnDataRef::ColumnDataRef(optionally_owned_ptr collection_p, vector expected_names) +ColumnDataRef::ColumnDataRef(optionally_owned_ptr collection_p, vector 
expected_names) : TableRef(TableReferenceType::COLUMN_DATA), expected_names(std::move(expected_names)), collection(std::move(collection_p)) { } @@ -37,7 +37,7 @@ bool ColumnDataRef::Equals(const TableRef &other_p) const { if (this_type != other_type) { return false; } - if (!StringUtil::CIEquals(this_name, other_name)) { + if (this_name != other_name) { return false; } } diff --git a/src/duckdb/src/parser/tableref/pivotref.cpp b/src/duckdb/src/parser/tableref/pivotref.cpp index 5bc47d3fd..2174ed719 100644 --- a/src/duckdb/src/parser/tableref/pivotref.cpp +++ b/src/duckdb/src/parser/tableref/pivotref.cpp @@ -142,28 +142,28 @@ static bool TryFoldConstantForBackwardsCompatibility(const ParsedExpression &exp switch (expr.GetExpressionType()) { case ExpressionType::FUNCTION: { auto &function = expr.Cast(); - if (function.function_name == "struct_pack") { - unordered_set unique_names; + if (function.FunctionName() == "struct_pack") { + identifier_set_t unique_names; child_list_t values; - values.reserve(function.children.size()); - for (const auto &child : function.children) { - if (!unique_names.insert(child->GetAlias()).second) { + values.reserve(function.GetArguments().size()); + for (const auto &child : function.GetArguments()) { + if (!unique_names.insert(child.GetExpression().GetAlias()).second) { return false; } Value child_value; - if (!TryFoldConstantForBackwardsCompatibility(*child, child_value)) { + if (!TryFoldConstantForBackwardsCompatibility(child.GetExpression(), child_value)) { return false; } - values.emplace_back(child->GetAlias(), std::move(child_value)); + values.emplace_back(child.GetExpression().GetAlias(), std::move(child_value)); } value = Value::STRUCT(std::move(values)); return true; - } else if (function.function_name == "list_value") { + } else if (function.FunctionName() == "list_value") { vector values; - values.reserve(function.children.size()); - for (const auto &child : function.children) { + 
values.reserve(function.GetArguments().size()); + for (const auto &child : function.GetArguments()) { Value child_value; - if (!TryFoldConstantForBackwardsCompatibility(*child, child_value)) { + if (!TryFoldConstantForBackwardsCompatibility(child.GetExpression(), child_value)) { return false; } values.emplace_back(std::move(child_value)); @@ -172,20 +172,20 @@ static bool TryFoldConstantForBackwardsCompatibility(const ParsedExpression &exp // figure out child type LogicalType child_type(LogicalTypeId::SQLNULL); for (auto &child_value : values) { - child_type = LogicalType::ForceMaxLogicalType(child_type, child_value.type()); + child_type = LogicalType::DefaultForceMaxLogicalType(child_type, child_value.type()); } // finally create the list value = Value::LIST(child_type, values); return true; - } else if (function.function_name == "map") { + } else if (function.FunctionName() == "map") { Value keys; - if (!TryFoldConstantForBackwardsCompatibility(*function.children[0], keys)) { + if (!TryFoldConstantForBackwardsCompatibility(function.GetArguments()[0].GetExpression(), keys)) { return false; } Value values; - if (!TryFoldConstantForBackwardsCompatibility(*function.children[1], values)) { + if (!TryFoldConstantForBackwardsCompatibility(function.GetArguments()[1].GetExpression(), values)) { return false; } @@ -207,14 +207,14 @@ static bool TryFoldConstantForBackwardsCompatibility(const ParsedExpression &exp case ExpressionType::OPERATOR_CAST: { auto &cast = expr.Cast(); Value dummy_value; - if (!TryFoldConstantForBackwardsCompatibility(*cast.child, dummy_value)) { + if (!TryFoldConstantForBackwardsCompatibility(cast.Child(), dummy_value)) { return false; } // Try to default bind cast LogicalType cast_type; try { - cast_type = UnboundType::TryDefaultBind(cast.cast_type); + cast_type = UnboundType::TryDefaultBind(cast.TargetType()); } catch (...) 
{ return false; } @@ -250,11 +250,11 @@ static bool TryFoldForBackwardsCompatibility(const unique_ptr } case ExpressionType::FUNCTION: { auto &function = expr->Cast(); - if (function.function_name != "row") { + if (function.FunctionName() != "row") { return false; } - for (auto &child : function.children) { - if (!TryFoldForBackwardsCompatibility(child, values)) { + for (auto &child : function.GetArgumentsMutable()) { + if (!TryFoldForBackwardsCompatibility(child.GetExpressionMutable(), values)) { return false; } } @@ -274,7 +274,7 @@ static bool TryFoldForBackwardsCompatibility(const unique_ptr vector PivotColumn::GetEntriesForSerialization(Serializer &serializer) const { vector result; - if (serializer.ShouldSerialize(7)) { + if (serializer.ShouldSerialize(StorageVersion::V1_5_0)) { // Latest version, serialize as is. // Unfortunately, we have to make a deep copy to return vector by value. for (auto &entry : entries) { @@ -312,7 +312,7 @@ vector PivotColumn::GetEntriesForSerialization(Serializer &ser // Otherwise this is a PIVOT with an expression we could not fold. // Older versions of DuckDB do not support this, so throw an exception. 
- const auto target_version = serializer.GetOptions().serialization_compatibility.duckdb_version; + const auto target_version = serializer.GetOptions().storage_compatibility.duckdb_version; throw SerializationException( "Cannot serialize non-constant expression '%s' in pivot list when targeting database storage version '%s'", diff --git a/src/duckdb/src/parser/tableref/showref.cpp b/src/duckdb/src/parser/tableref/showref.cpp index 97b9d1ad0..6d6636c0b 100644 --- a/src/duckdb/src/parser/tableref/showref.cpp +++ b/src/duckdb/src/parser/tableref/showref.cpp @@ -29,7 +29,7 @@ string ShowRef::ToString() const { result += query->ToString(); result += ")"; } else if (table_name != "__show_tables_expanded") { - result += table_name; + result += table_name.GetIdentifierName(); } return result; } diff --git a/src/duckdb/src/parser/tableref/subqueryref.cpp b/src/duckdb/src/parser/tableref/subqueryref.cpp index 4d9d09880..20f03283c 100644 --- a/src/duckdb/src/parser/tableref/subqueryref.cpp +++ b/src/duckdb/src/parser/tableref/subqueryref.cpp @@ -12,7 +12,7 @@ string SubqueryRef::ToString() const { SubqueryRef::SubqueryRef() : TableRef(TableReferenceType::SUBQUERY) { } -SubqueryRef::SubqueryRef(unique_ptr subquery_p, string alias_p) +SubqueryRef::SubqueryRef(unique_ptr subquery_p, Identifier alias_p) : TableRef(TableReferenceType::SUBQUERY), subquery(std::move(subquery_p)) { this->alias = std::move(alias_p); } diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp index adac33218..dd20e0f92 100644 --- a/src/duckdb/src/planner/bind_context.cpp +++ b/src/duckdb/src/planner/bind_context.cpp @@ -25,16 +25,16 @@ BindContext::BindContext(Binder &binder) : binder(binder) { } string MinimumUniqueAlias(const BindingAlias &alias, const BindingAlias &other) { - if (!StringUtil::CIEquals(alias.GetAlias(), other.GetAlias())) { - return alias.GetAlias(); + if (alias.GetAlias() != other.GetAlias()) { + return alias.GetAlias().GetIdentifierName(); } - 
if (!StringUtil::CIEquals(alias.GetSchema(), other.GetSchema())) { + if (alias.GetSchema() != other.GetSchema()) { return alias.GetSchema() + "." + alias.GetAlias(); } return alias.ToString(); } -optional_ptr BindContext::GetMatchingBinding(const string &column_name, QueryErrorContext context) { +optional_ptr BindContext::GetMatchingBinding(const Identifier &column_name, QueryErrorContext context) { optional_ptr result; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; @@ -57,29 +57,29 @@ optional_ptr BindContext::GetMatchingBinding(const string &column_name, return result; } -vector BindContext::GetSimilarBindings(const string &column_name) { +vector BindContext::GetSimilarBindings(const Identifier &column_name) { vector> scores; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; for (auto &name : binding.GetColumnNames()) { - double distance = StringUtil::SimilarityRating(name, column_name); + double distance = StringUtil::SimilarityRating(name.GetIdentifierName(), column_name.GetIdentifierName()); // check if we need to qualify the column auto matching_bindings = GetMatchingBindings(name); if (matching_bindings.size() > 1) { - scores.emplace_back(binding.GetAlias() + "." + name, distance); + scores.emplace_back(binding.GetAlias() + "." 
+ name.GetIdentifierName(), distance); } else { scores.emplace_back(name, distance); } } } - return StringUtil::TopNStrings(scores); + return StringsToIdentifiers(StringUtil::TopNStrings(scores)); } -void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set) { +void BindContext::AddUsingBinding(const Identifier &column_name, UsingColumnSet &set) { using_columns[column_name].insert(set); } -optional_ptr BindContext::GetUsingBinding(const string &column_name) { +optional_ptr BindContext::GetUsingBinding(const Identifier &column_name) { auto entry = using_columns.find(column_name); if (entry == using_columns.end()) { return nullptr; @@ -111,7 +111,7 @@ optional_ptr BindContext::GetUsingBinding(const string &column_n return &using_bindings.begin()->get(); } -optional_ptr BindContext::GetUsingBinding(const string &column_name, const BindingAlias &binding) { +optional_ptr BindContext::GetUsingBinding(const Identifier &column_name, const BindingAlias &binding) { if (!binding.IsSet()) { throw InternalException("GetUsingBinding: expected non-empty binding_name"); } @@ -132,7 +132,7 @@ optional_ptr BindContext::GetUsingBinding(const string &column_n return nullptr; } -void BindContext::RemoveUsingBinding(const string &column_name, UsingColumnSet &set) { +void BindContext::RemoveUsingBinding(const Identifier &column_name, UsingColumnSet &set) { auto entry = using_columns.find(column_name); if (entry == using_columns.end()) { throw InternalException("Attempting to remove using binding that is not there"); @@ -147,23 +147,23 @@ void BindContext::RemoveUsingBinding(const string &column_name, UsingColumnSet & } void BindContext::TransferUsingBinding(BindContext ¤t_context, optional_ptr current_set, - UsingColumnSet &new_set, const string &using_column) { + UsingColumnSet &new_set, const Identifier &using_column) { AddUsingBinding(using_column, new_set); if (current_set) { current_context.RemoveUsingBinding(using_column, *current_set); } } -string 
BindContext::GetActualColumnName(Binding &binding, const string &column_name) { +string BindContext::GetActualColumnName(Binding &binding, const Identifier &column_name) { column_t binding_index; if (!binding.TryGetBindingIndex(column_name, binding_index)) { // LCOV_EXCL_START throw InternalException("Binding with name \"%s\" does not have a column named \"%s\"", binding.GetAlias(), column_name); } // LCOV_EXCL_STOP - return binding.GetColumnNames()[binding_index]; + return binding.GetColumnNames()[binding_index].GetIdentifierName(); } -string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const string &column_name) { +string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const Identifier &column_name) { ErrorData error; auto binding = GetBinding(binding_alias, error); if (!binding) { @@ -172,7 +172,7 @@ string BindContext::GetActualColumnName(const BindingAlias &binding_alias, const return GetActualColumnName(*binding, column_name); } -vector> BindContext::GetMatchingBindings(const string &column_name) { +vector> BindContext::GetMatchingBindings(const Identifier &column_name) { vector> result; for (auto &binding_ptr : bindings_list) { auto &binding = *binding_ptr; @@ -184,22 +184,24 @@ vector> BindContext::GetMatchingBindings(const string &column } unique_ptr BindContext::ExpandGeneratedColumn(TableBinding &table_binding, - const string &column_name) { + const Identifier &column_name) { auto result = table_binding.ExpandGeneratedColumn(column_name); result->SetAlias(column_name); return result; } unique_ptr BindContext::CreateColumnReference(const BindingAlias &table_alias, - const string &column_name, ColumnBindType bind_type) { + const Identifier &column_name, + ColumnBindType bind_type) { return CreateColumnReference(table_alias.GetCatalog(), table_alias.GetSchema(), table_alias.GetAlias(), column_name, bind_type); } -unique_ptr BindContext::CreateColumnReference(const string &table_name, const string &column_name, 
+unique_ptr BindContext::CreateColumnReference(const Identifier &table_name, + const Identifier &column_name, ColumnBindType bind_type) { string schema_name; - return CreateColumnReference(schema_name, table_name, column_name, bind_type); + return CreateColumnReference(Identifier(schema_name), table_name, column_name, bind_type); } static bool ColumnIsGenerated(Binding &binding, column_t index) { @@ -219,11 +221,13 @@ static bool ColumnIsGenerated(Binding &binding, column_t index) { return table_entry.GetColumn(LogicalIndex(index)).Generated(); } -unique_ptr BindContext::CreateColumnReference(const string &catalog_name, const string &schema_name, - const string &table_name, const string &column_name, +unique_ptr BindContext::CreateColumnReference(const Identifier &catalog_name, + const Identifier &schema_name, + const Identifier &table_name, + const Identifier &column_name, ColumnBindType bind_type) { ErrorData error; - vector names; + vector names; if (!catalog_name.empty()) { names.push_back(catalog_name); } @@ -233,8 +237,8 @@ unique_ptr BindContext::CreateColumnReference(const string &ca names.push_back(table_name); names.push_back(column_name); - BindingAlias alias(catalog_name, schema_name, table_name); - auto result = make_uniq(std::move(names)); + BindingAlias alias {catalog_name, schema_name, table_name}; + auto result = make_uniq(names); auto binding = GetBinding(alias, column_name, error); if (!binding) { return std::move(result); @@ -252,9 +256,11 @@ unique_ptr BindContext::CreateColumnReference(const string &ca return std::move(result); } -unique_ptr BindContext::CreateColumnReference(const string &schema_name, const string &table_name, - const string &column_name, ColumnBindType bind_type) { - string catalog_name; +unique_ptr BindContext::CreateColumnReference(const Identifier &schema_name, + const Identifier &table_name, + const Identifier &column_name, + ColumnBindType bind_type) { + Identifier catalog_name; return 
CreateColumnReference(catalog_name, schema_name, table_name, column_name, bind_type); } @@ -335,7 +341,7 @@ string BindContext::AmbiguityException(const BindingAlias &alias, const vector BindContext::GetBinding(const BindingAlias &alias, const string &column_name, +optional_ptr BindContext::GetBinding(const BindingAlias &alias, const Identifier &column_name, ErrorData &out_error) { auto matching_bindings = GetBindings(alias, out_error); if (matching_bindings.empty()) { @@ -378,21 +384,21 @@ optional_ptr BindContext::GetBinding(const BindingAlias &alias, ErrorDa return &matching_bindings[0].get(); } -optional_ptr BindContext::GetBinding(const string &name, ErrorData &out_error) { +optional_ptr BindContext::GetBinding(const Identifier &name, ErrorData &out_error) { return GetBinding(BindingAlias(name), out_error); } BindingAlias GetBindingAlias(ColumnRefExpression &colref) { - if (colref.column_names.size() <= 1 || colref.column_names.size() > 4) { + if (colref.ColumnNames().size() <= 1 || colref.ColumnNames().size() > 4) { throw InternalException("Cannot get binding alias from column ref unless it has 2..4 entries"); } - if (colref.column_names.size() >= 4) { - return BindingAlias(colref.column_names[0], colref.column_names[1], colref.column_names[2]); + if (colref.ColumnNames().size() >= 4) { + return BindingAlias(colref.ColumnNames()[0], colref.ColumnNames()[1], colref.ColumnNames()[2]); } - if (colref.column_names.size() == 3) { - return BindingAlias(colref.column_names[0], colref.column_names[1]); + if (colref.ColumnNames().size() == 3) { + return BindingAlias(colref.ColumnNames()[0], colref.ColumnNames()[1]); } - return BindingAlias(colref.column_names[0]); + return BindingAlias(colref.ColumnNames()[0]); } BindResult BindContext::BindColumn(ColumnRefExpression &colref, idx_t depth) { @@ -409,14 +415,14 @@ BindResult BindContext::BindColumn(ColumnRefExpression &colref, idx_t depth) { return binding->Bind(colref, depth); } -string 
BindContext::BindColumn(PositionalReferenceExpression &ref, string &table_name, string &column_name) { +string BindContext::BindColumn(PositionalReferenceExpression &ref, Identifier &table_name, Identifier &column_name) { idx_t total_columns = 0; - idx_t current_position = ref.index - 1; + idx_t current_position = ref.Index() - 1; for (auto &entry : bindings_list) { auto &binding = *entry; auto &column_names = binding.GetColumnNames(); idx_t entry_column_count = column_names.size(); - if (ref.index == 0) { + if (ref.Index() == 0) { // this is a row id table_name = binding.GetAlias(); column_name = "rowid"; @@ -431,11 +437,11 @@ string BindContext::BindColumn(PositionalReferenceExpression &ref, string &table current_position -= entry_column_count; } } - return StringUtil::Format("Positional reference %d out of range (total %d columns)", ref.index, total_columns); + return StringUtil::Format("Positional reference %d out of range (total %d columns)", ref.Index(), total_columns); } unique_ptr BindContext::PositionToColumn(PositionalReferenceExpression &ref) { - string table_name, column_name; + Identifier table_name, column_name; string error = BindColumn(ref, table_name, column_name); if (!error.empty()) { @@ -450,13 +456,13 @@ struct ExclusionListInfo { } vector> &new_select_list; - case_insensitive_set_t excluded_columns; + identifier_set_t excluded_columns; qualified_column_set_t excluded_qualified_columns; - case_insensitive_set_t replaced_columns; + identifier_set_t replaced_columns; }; bool CheckExclusionList(StarExpression &expr, const QualifiedColumnName &qualified_name, ExclusionListInfo &info) { - if (expr.exclude_list.find(qualified_name) != expr.exclude_list.end()) { + if (expr.ExcludeList().find(qualified_name) != expr.ExcludeList().end()) { info.excluded_qualified_columns.insert(qualified_name); return true; } @@ -465,8 +471,8 @@ bool CheckExclusionList(StarExpression &expr, const QualifiedColumnName &qualifi bool HandleRename(StarExpression &expr, const 
QualifiedColumnName &qualified_name, unique_ptr &new_expr, ExclusionListInfo &info) { - auto replace_entry = expr.replace_list.find(qualified_name.column); - if (replace_entry != expr.replace_list.end()) { + auto replace_entry = expr.ReplaceList().find(qualified_name.column); + if (replace_entry != expr.ReplaceList().end()) { if (info.replaced_columns.find(replace_entry->first) == info.replaced_columns.end()) { new_expr = replace_entry->second->Copy(); new_expr->SetAlias(replace_entry->first); @@ -476,8 +482,8 @@ bool HandleRename(StarExpression &expr, const QualifiedColumnName &qualified_nam return false; } } - auto rename_entry = expr.rename_list.find(qualified_name); - if (rename_entry != expr.rename_list.end()) { + auto rename_entry = expr.RenameList().find(qualified_name); + if (rename_entry != expr.RenameList().end()) { new_expr->SetAlias(rename_entry->second); } return true; @@ -489,7 +495,7 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, throw BinderException("* expression without FROM clause!"); } ExclusionListInfo exclusion_info(new_select_list); - if (expr.relation_name.empty()) { + if (expr.RelationName().empty()) { // SELECT * case // bind all expressions of each table in-order reference_set_t handled_using_columns; @@ -518,7 +524,7 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, auto coalesce = make_uniq_base(ExpressionType::OPERATOR_COALESCE); for (auto &child_binding : using_binding.bindings) { - coalesce->Cast().children.push_back( + coalesce->Cast().GetChildrenMutable().push_back( make_uniq(column_name, child_binding)); } coalesce->SetAlias(column_name); @@ -547,10 +553,10 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, // SELECT tbl.* case // SELECT struct.* case ErrorData error; - auto binding = GetBinding(expr.relation_name, error); + auto binding = GetBinding(expr.RelationName(), error); bool is_struct_ref = false; if (!binding) { - binding = 
GetMatchingBinding(expr.relation_name, expr); + binding = GetMatchingBinding(expr.RelationName(), expr); if (!binding) { error.Throw(); } @@ -561,16 +567,16 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, auto &column_types = binding->GetColumnTypes(); if (is_struct_ref) { - auto col_idx = binding->GetBindingIndex(expr.relation_name); + auto col_idx = binding->GetBindingIndex(expr.RelationName()); auto col_type = column_types[col_idx]; if (col_type.id() != LogicalTypeId::STRUCT) { throw BinderException(StringUtil::Format( "Cannot extract field from expression \"%s\" because it is not a struct", expr.ToString())); } auto &struct_children = StructType::GetChildTypes(col_type); - vector column_names(3); + vector column_names(3); column_names[0] = binding->GetAlias(); - column_names[1] = expr.relation_name; + column_names[1] = expr.RelationName(); for (auto &child : struct_children) { QualifiedColumnName qualified_name(child.first); if (CheckExclusionList(expr, qualified_name, exclusion_info)) { @@ -601,29 +607,30 @@ void BindContext::GenerateAllColumnExpressions(StarExpression &expr, binder.GetBindingMode() == BindingMode::EXTRACT_QUALIFIED_NAMES) { //! We only care about extracting the names of the referenced columns //! remove the exclude + replace lists - expr.exclude_list.clear(); - expr.replace_list.clear(); + expr.ExcludeListMutable().clear(); + expr.ReplaceListMutable().clear(); } //! Verify correctness of the exclude list - for (auto &excluded : expr.exclude_list) { + for (auto &excluded : expr.ExcludeList()) { if (exclusion_info.excluded_qualified_columns.find(excluded) == exclusion_info.excluded_qualified_columns.end()) { throw BinderException("Column \"%s\" in EXCLUDE list not found in %s", excluded.ToString(), - expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); + expr.RelationName().empty() ? "FROM clause" : expr.RelationName().c_str()); } } //! 
Verify correctness of the replace list - for (auto &entry : expr.replace_list) { + for (auto &entry : expr.ReplaceList()) { if (exclusion_info.excluded_columns.find(entry.first) == exclusion_info.excluded_columns.end()) { - throw BinderException("Column \"%s\" in REPLACE list not found in %s", entry.first, - expr.relation_name.empty() ? "FROM clause" : expr.relation_name.c_str()); + throw BinderException("Column \"%s\" in REPLACE list not found in %s", entry.first.GetIdentifierName(), + expr.RelationName().empty() ? string("FROM clause") + : expr.RelationName().GetIdentifierName()); } } } -void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { +void BindContext::GetTypesAndNames(vector &result_names, vector &result_types) { for (auto &binding_entry : bindings_list) { auto &binding = *binding_entry; auto &column_names = binding.GetColumnNames(); @@ -640,14 +647,14 @@ void BindContext::AddBinding(unique_ptr binding) { bindings_list.push_back(std::move(binding)); } -void BindContext::AddBaseTable(TableIndex index, const string &alias, const vector &names, +void BindContext::AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, TableCatalogEntry &entry, virtual_column_map_t virtual_columns) { AddBinding( make_uniq(alias, types, names, bound_column_ids, &entry, index, std::move(virtual_columns))); } -void BindContext::AddBaseTable(TableIndex index, const string &alias, const vector &names, +void BindContext::AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, TableCatalogEntry &entry, bool add_virtual_columns) { virtual_column_map_t virtual_columns; @@ -657,39 +664,39 @@ void BindContext::AddBaseTable(TableIndex index, const string &alias, const vect AddBaseTable(index, alias, names, types, bound_column_ids, entry, std::move(virtual_columns)); } -void BindContext::AddBaseTable(TableIndex index, const 
string &alias, const vector &names, +void BindContext::AddBaseTable(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, - const string &table_name) { + const Identifier &table_name) { virtual_column_map_t virtual_columns; AddBinding(make_uniq(alias.empty() ? table_name : alias, types, names, bound_column_ids, nullptr, index, std::move(virtual_columns))); } -void BindContext::AddTableFunction(TableIndex index, const string &alias, const vector &names, +void BindContext::AddTableFunction(TableIndex index, const Identifier &alias, const vector &names, const vector &types, vector &bound_column_ids, optional_ptr entry, virtual_column_map_t virtual_columns) { AddBinding( make_uniq(alias, types, names, bound_column_ids, entry, index, std::move(virtual_columns))); } -static string AddColumnNameToBinding(const string &base_name, case_insensitive_set_t ¤t_names) { +static Identifier AddColumnNameToBinding(const Identifier &base_name, identifier_set_t ¤t_names) { idx_t index = 1; - string name = base_name; + Identifier name = base_name; while (current_names.find(name) != current_names.end()) { - name = base_name + "_" + std::to_string(index++); + name = Identifier(base_name + "_" + std::to_string(index++)); } current_names.insert(name); return name; } -vector BindContext::AliasColumnNames(const string &table_name, const vector &names, - const vector &column_aliases) { - vector result; +vector BindContext::AliasColumnNames(const Identifier &table_name, const vector &names, + const vector &column_aliases) { + vector result; if (column_aliases.size() > names.size()) { throw BinderException("table \"%s\" has %lld columns available but %lld columns specified", table_name, names.size(), column_aliases.size()); } - case_insensitive_set_t current_names; + identifier_set_t current_names; // use any provided column aliases first for (idx_t i = 0; i < column_aliases.size(); i++) { 
result.push_back(AddColumnNameToBinding(column_aliases[i], current_names)); @@ -701,28 +708,29 @@ vector BindContext::AliasColumnNames(const string &table_name, const vec return result; } -void BindContext::AddSubquery(TableIndex index, const string &alias, SubqueryRef &ref, BoundStatement &subquery) { +void BindContext::AddSubquery(TableIndex index, const Identifier &alias, SubqueryRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } -void BindContext::AddEntryBinding(TableIndex index, const string &alias, const vector &names, +void BindContext::AddEntryBinding(TableIndex index, const Identifier &alias, const vector &names, const vector &types, StandardEntry &entry) { AddBinding(make_uniq(alias, types, names, index, entry)); } -void BindContext::AddView(TableIndex index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, +void BindContext::AddView(TableIndex index, const Identifier &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddEntryBinding(index, alias, names, subquery.types, view.Cast()); } -void BindContext::AddSubquery(TableIndex index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery) { +void BindContext::AddSubquery(TableIndex index, const Identifier &alias, TableFunctionRef &ref, + BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } -void BindContext::AddGenericBinding(TableIndex index, const string &alias, const vector &names, +void BindContext::AddGenericBinding(TableIndex index, const Identifier &alias, const vector &names, const vector &types) { AddBinding(make_uniq(BindingType::BASE, BindingAlias(alias), types, names, index)); } @@ -736,7 +744,7 @@ void 
BindContext::AddCTEBinding(unique_ptr binding) { cte_bindings.push_back(std::move(binding)); } -void BindContext::AddCTEBinding(TableIndex index, BindingAlias alias_p, const vector &names, +void BindContext::AddCTEBinding(TableIndex index, BindingAlias alias_p, const vector &names, const vector &types, CTEType cte_type) { auto binding = make_uniq(std::move(alias_p), types, names, index, cte_type); AddCTEBinding(std::move(binding)); @@ -773,7 +781,7 @@ vector BindContext::GetBindingAliases() { void BindContext::RemoveContext(const vector &aliases) { for (auto &alias : aliases) { // remove the binding from any USING columns - vector removed_using_columns; + vector removed_using_columns; for (auto &using_sets : using_columns) { for (auto &using_set_ref : using_sets.second) { auto &using_set = using_set_ref.get(); diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 6c38bf44e..24fdc8d75 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -32,6 +32,7 @@ #include "duckdb/planner/query_node/list.hpp" #include "duckdb/planner/tableref/list.hpp" #include "duckdb/storage/data_table.hpp" +#include "duckdb/common/types/uuid.hpp" #include @@ -58,31 +59,22 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent_p, BinderType binder_type) : context(context), bind_context(*this), parent(std::move(parent_p)), binder_type(binder_type), global_binder_state(parent ? parent->global_binder_state : make_shared_ptr()), - query_binder_state(parent && binder_type == BinderType::REGULAR_BINDER ? parent->query_binder_state - : make_shared_ptr()), entry_retriever(context), depth(parent ? parent->GetBinderDepth() : 1) { IncreaseDepth(); if (parent) { entry_retriever.Inherit(parent->entry_retriever); + inside_subquery = parent->inside_subquery; // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. 
macro_binding = parent->macro_binding; lambda_bindings = parent->lambda_bindings; + if (binder_type != BinderType::VIEW_BINDER) { + // inherit expression binders from parent + active_binders = parent->active_binders; + } } } -template -BoundStatement Binder::BindWithCTE(T &statement) { - auto &cte_map = statement.cte_map; - if (cte_map.map.empty()) { - return Bind(statement); - } - - auto stmt_node = make_uniq(statement); - stmt_node->cte_map = cte_map.Copy(); - return Bind(*stmt_node); -} - BoundStatement Binder::Bind(SQLStatement &statement) { switch (statement.type) { case StatementType::SELECT_STATEMENT: @@ -136,7 +128,11 @@ BoundStatement Binder::Bind(SQLStatement &statement) { case StatementType::UPDATE_EXTENSIONS_STATEMENT: return Bind(statement.Cast()); case StatementType::MERGE_INTO_STATEMENT: - return BindWithCTE(statement.Cast()); + return Bind(statement.Cast()); + case StatementType::CONNECT_STATEMENT: + return Bind(statement.Cast()); + case StatementType::DISCONNECT_STATEMENT: + return Bind(statement.Cast()); default: // LCOV_EXCL_START throw NotImplementedException("Unimplemented statement type \"%s\" for Bind", StatementTypeToString(statement.type)); @@ -245,18 +241,26 @@ void Binder::SetParameters(BoundParameterMap ¶meters) { global_binder_state->parameters = parameters; } -void Binder::PushExpressionBinder(ExpressionBinder &binder) { - GetActiveBinders().push_back(binder); +void Binder::SetInsideSubquery() { + inside_subquery = true; } -void Binder::PopExpressionBinder() { - D_ASSERT(HasActiveBinder()); - GetActiveBinders().pop_back(); +bool Binder::IsInsideSubquery() const { + return inside_subquery; } -void Binder::SetActiveBinder(ExpressionBinder &binder) { - D_ASSERT(HasActiveBinder()); - GetActiveBinders().back() = binder; +void Binder::BeginSubqueryBind(Binder &parent, ExpressionBinder &binder) { + // push all active expression binders + auto &active_binders = GetActiveBinders(); + for (auto &active_binder : parent.GetActiveBinders()) { + 
active_binders.push_back(active_binder); + } + // finally push this binder + active_binders.push_back(binder); +} + +void Binder::FinishSubqueryBind() { + GetActiveBinders().clear(); } ExpressionBinder &Binder::GetActiveBinder() { @@ -268,7 +272,7 @@ bool Binder::HasActiveBinder() { } vector> &Binder::GetActiveBinders() { - return query_binder_state->active_binders; + return active_binders; } void Binder::AddUsingBindingSet(unique_ptr set) { @@ -293,20 +297,20 @@ void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { } } -optional_ptr Binder::GetMatchingBinding(const string &table_name, const string &column_name, +optional_ptr Binder::GetMatchingBinding(const Identifier &table_name, const Identifier &column_name, ErrorData &error) { - string empty_schema; + Identifier empty_schema; return GetMatchingBinding(empty_schema, table_name, column_name, error); } -optional_ptr Binder::GetMatchingBinding(const string &schema_name, const string &table_name, - const string &column_name, ErrorData &error) { - string empty_catalog; +optional_ptr Binder::GetMatchingBinding(const Identifier &schema_name, const Identifier &table_name, + const Identifier &column_name, ErrorData &error) { + Identifier empty_catalog; return GetMatchingBinding(empty_catalog, schema_name, table_name, column_name, error); } -optional_ptr Binder::GetMatchingBinding(const string &catalog_name, const string &schema_name, - const string &table_name, const string &column_name, +optional_ptr Binder::GetMatchingBinding(const Identifier &catalog_name, const Identifier &schema_name, + const Identifier &table_name, const Identifier &column_name, ErrorData &error) { optional_ptr binding; if (macro_binding && table_name == macro_binding->GetAlias()) { @@ -327,7 +331,15 @@ BindingMode Binder::GetBindingMode() { } void Binder::SetCanContainNulls(bool can_contain_nulls_p) { - can_contain_nulls = can_contain_nulls_p; + legacy_can_contain_nulls = can_contain_nulls_p; +} + +bool Binder::CanContainNulls() 
const { + if (!Settings::Get(context)) { + // if this legacy setting is not enabled nulls are not special - they can just occur anywhere + return true; + } + return legacy_can_contain_nulls; } void Binder::SetAlwaysRequireRebind() { @@ -339,7 +351,7 @@ void Binder::AddTableName(string table_name) { global_binder_state->table_names.insert(std::move(table_name)); } -void Binder::AddReplacementScan(const string &table_name, unique_ptr replacement) { +void Binder::AddReplacementScan(const Identifier &table_name, unique_ptr replacement) { auto it = global_binder_state->replacement_scans.find(table_name); replacement->column_name_alias.clear(); replacement->alias.clear(); @@ -354,7 +366,7 @@ const unordered_set &Binder::GetTableNames() { return global_binder_state->table_names; } -case_insensitive_map_t> &Binder::GetReplacementScans() { +identifier_map_t> &Binder::GetReplacementScans() { return global_binder_state->replacement_scans; } @@ -511,17 +523,17 @@ void Binder::BindDeleteIndexColumns(TableCatalogEntry &table, LogicalGet &get, v } BoundStatement Binder::BindReturning(vector> returning_list, TableCatalogEntry &table, - const string &alias, TableIndex update_table_index, + const Identifier &alias, TableIndex update_table_index, unique_ptr child_operator, virtual_column_map_t virtual_columns) { vector types; - vector names; + vector names; auto binder = Binder::CreateBinder(context); vector bound_columns; idx_t column_count = 0; for (auto &col : table.GetColumns().Logical()) { - names.push_back(col.Name()); + names.emplace_back(col.Name()); types.push_back(col.Type()); if (!col.Generated()) { bound_columns.emplace_back(column_count); @@ -562,14 +574,14 @@ BoundStatement Binder::BindReturning(vector> return return result; } -optional_ptr Binder::GetCatalogEntry(const string &catalog, const string &schema, +optional_ptr Binder::GetCatalogEntry(const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found) { 
return entry_retriever.GetEntry(catalog, schema, lookup_info, on_entry_not_found); } //! Create a binder whose catalog search path is anchored to the table's catalog+schema -shared_ptr Binder::CreateBinderWithSearchPath(const string &catalog_name, const string &schema_name) { +shared_ptr Binder::CreateBinderWithSearchPath(const Identifier &catalog_name, const Identifier &schema_name) { shared_ptr new_binder = Binder::CreateBinder(context, this); vector search_path; @@ -582,11 +594,9 @@ shared_ptr Binder::CreateBinderWithSearchPath(const string &catalog_name return new_binder; } -static constexpr const char *TRIGGER_BASE_CTE_NAME = "__duckdb_trigger_base"; - -unique_ptr Binder::TryExpandAfterTriggers(QueryNode &node, - vector> &returning_list, - TableCatalogEntry &table, TriggerEventType event_type) { +unique_ptr Binder::TryExpandTriggers(QueryNode &node, + vector> &returning_list, + TableCatalogEntry &table, TriggerEventType event_type) { auto &expanded_tables = global_binder_state->trigger_expanded_tables; if (expanded_tables.find(table) != expanded_tables.end()) { if (global_binder_state->trigger_creation_table == &table) { @@ -596,46 +606,136 @@ unique_ptr Binder::TryExpandAfterTriggers(QueryNode &node, } return nullptr; } - auto triggers = table.GetTriggersForEvent(table.ParentCatalog().GetCatalogTransaction(context), - TriggerTiming::AFTER, event_type); - if (triggers.empty()) { + auto txn = table.ParentCatalog().GetCatalogTransaction(context); + auto before_triggers = table.GetTriggersForEvent(txn, TriggerTiming::BEFORE, event_type); + auto after_triggers = table.GetTriggersForEvent(txn, TriggerTiming::AFTER, event_type); + + // UPDATE OF : drop triggers whose OF list is disjoint from the SET list. + // Triggers without an OF list are unrestricted and always fire. 
+ if (event_type == TriggerEventType::UPDATE_EVENT && node.type == QueryNodeType::UPDATE_QUERY_NODE) { + auto &update_node = node.Cast(); + identifier_set_t updated_columns; + if (update_node.set_info) { + updated_columns.insert(update_node.set_info->columns.begin(), update_node.set_info->columns.end()); + } + auto trigger_does_not_fire = [&](const_reference trig) { + const auto &of_cols = trig.get().columns; + return !of_cols.empty() && std::none_of(of_cols.begin(), of_cols.end(), + [&](const Identifier &c) { return updated_columns.count(c) > 0; }); + }; + auto drop_non_firing = [&](vector> &triggers) { + triggers.erase(std::remove_if(triggers.begin(), triggers.end(), trigger_does_not_fire), triggers.end()); + }; + drop_non_firing(before_triggers); + drop_non_firing(after_triggers); + } + + if (before_triggers.empty() && after_triggers.empty()) { return nullptr; } if (!returning_list.empty()) { - throw NotImplementedException("RETURNING is not yet supported on tables with AFTER triggers"); + throw NotImplementedException("RETURNING is not yet supported on tables with triggers"); + } + if (node.type == QueryNodeType::INSERT_QUERY_NODE) { + auto &insert_node = node.Cast(); + if (insert_node.on_conflict_info && insert_node.on_conflict_info->action_type != OnConflictAction::NOTHING) { + // Only AFTER triggers can carry REFERENCING NEW TABLE — BEFORE rejects it at CREATE time. + for (auto &trigger : after_triggers) { + if (!trigger.get().referencing_new_table.empty()) { + throw NotImplementedException( + "ON CONFLICT DO UPDATE is not yet supported with REFERENCING NEW TABLE AS triggers"); + } + } + } } expanded_tables.insert(table); - return make_uniq(ExpandAfterTriggers(node, returning_list, triggers)); + auto bound = ExpandTriggers(node, returning_list, before_triggers, after_triggers); + + // Erasing from the set, so we will track expanded tables only while we're on the same node in the recursive stack, + // meaning we're on the same "trigger" in the trigger chain. 
+ expanded_tables.erase(table); + return make_uniq(std::move(bound)); +} + +static constexpr const char *TRIGGER_BASE_CTE_PREFIX = "__duckdb_trigger_base_"; +static constexpr const char *TRIGGER_BODY_CTE_PREFIX = "__duckdb_trigger_body_"; +static constexpr const char *TRIGGER_BEFORE_BODY_CTE_PREFIX = "__duckdb_trigger_before_body_"; + +static unique_ptr MakeTransitionTableAliasCTE(const Identifier &base_cte_name) { + auto alias_cte = make_uniq(); + auto alias_select = make_uniq(); + alias_select->select_list.push_back(make_uniq()); + auto alias_ref = make_uniq(); + alias_ref->table_name = base_cte_name; + alias_select->from_table = std::move(alias_ref); + alias_cte->query_node = std::move(alias_select); + alias_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; + return alias_cte; } -BoundStatement Binder::ExpandAfterTriggers(QueryNode &node, vector> &returning_list, - const vector> &triggers) { - // multiple triggers per table are not yet supported - D_ASSERT(triggers.size() == 1); +BoundStatement Binder::ExpandTriggers(QueryNode &node, vector> &returning_list, + const vector> &before_triggers, + const vector> &after_triggers) { + D_ASSERT(!before_triggers.empty() || !after_triggers.empty()); D_ASSERT(returning_list.empty()); returning_list.push_back(make_uniq()); - auto base_cte = make_uniq(); - base_cte->query_node = node.Copy(); - base_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; - base_cte->is_trigger_generated = true; - - // Unreferenced DML CTE - auto &trigger = triggers[0].get(); - auto trig_cte = make_uniq(); - trig_cte->query_node = trigger.trigger_action->Copy(); - trig_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - trig_cte->is_trigger_generated = true; + auto uuid_suffix = UUID::ToString(UUID::GenerateRandomUUID()); + Identifier base_cte_name(TRIGGER_BASE_CTE_PREFIX + uuid_suffix); // count(*) over the base CTE gives CHANGED_ROWS ("N rows affected") to the client auto outer = make_uniq(); 
outer->select_list.push_back(make_uniq("count_star", vector>())); auto from_ref = make_uniq(); - from_ref->table_name = TRIGGER_BASE_CTE_NAME; + from_ref->table_name = base_cte_name; outer->from_table = std::move(from_ref); - outer->cte_map.map[TRIGGER_BASE_CTE_NAME] = std::move(base_cte); - outer->cte_map.map["__duckdb_trigger_1"] = std::move(trig_cte); + + // CTE definition order == execution order: BEFORE bodies, then base (DML), then AFTER bodies. + + // BEFORE bodies (no transition tables — REFERENCING is rejected at CREATE TRIGGER time) + for (idx_t i = 0; i < before_triggers.size(); i++) { + auto &trigger = before_triggers[i].get(); + auto body_cte_name = string(TRIGGER_BEFORE_BODY_CTE_PREFIX) + to_string(i + 1) + "_" + uuid_suffix; + + auto trig_cte = make_uniq(); + trig_cte->query_node = trigger.trigger_action->Copy(); + trig_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; + trig_cte->is_trigger_generated = true; + + outer->cte_map.map[Identifier(body_cte_name)] = std::move(trig_cte); + } + + // Base CTE: the original DML, materialized so AFTER bodies can scan it (transition tables) + auto base_cte = make_uniq(); + base_cte->query_node = node.Copy(); + base_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; + base_cte->is_trigger_generated = true; + outer->cte_map.map[base_cte_name] = std::move(base_cte); + + // AFTER bodies — may reference base via REFERENCING NEW/OLD TABLE aliases + for (idx_t i = 0; i < after_triggers.size(); i++) { + auto &trigger = after_triggers[i].get(); + auto body_cte_name = string(TRIGGER_BODY_CTE_PREFIX) + to_string(i + 1) + "_" + uuid_suffix; + + auto trig_cte = make_uniq(); + trig_cte->query_node = trigger.trigger_action->Copy(); + trig_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; + trig_cte->is_trigger_generated = true; + + // Inject the transition-table aliases (REFERENCING NEW/OLD TABLE) into this trigger body's own + // cte_map. 
That scopes the aliases to this trigger only (sibling triggers can't see them), and + // if the body has a local WITH that defines the same name, the local WITH shadows the alias. + auto &body_map = trig_cte->query_node->cte_map.map; + if (!trigger.referencing_new_table.empty() && body_map.find(trigger.referencing_new_table) == body_map.end()) { + body_map[trigger.referencing_new_table] = MakeTransitionTableAliasCTE(base_cte_name); + } + if (!trigger.referencing_old_table.empty() && body_map.find(trigger.referencing_old_table) == body_map.end()) { + body_map[trigger.referencing_old_table] = MakeTransitionTableAliasCTE(base_cte_name); + } + + outer->cte_map.map[Identifier(body_cte_name)] = std::move(trig_cte); + } auto bound = Bind(*outer); auto &properties = GetStatementProperties(); diff --git a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp index 463e13e3d..47128cd9a 100644 --- a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp @@ -14,7 +14,6 @@ #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/planner/expression_binder/base_select_binder.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" @@ -22,7 +21,7 @@ namespace duckdb { -static bool IsFunctionallyDependent(const unique_ptr &expr, const vector> &deps) { +static bool IsFunctionallyDependent(const unique_ptr &expr, const vector> &deps) { // Volatile expressions can't depend on anything else if (expr->IsVolatile()) { return false; @@ -34,7 +33,7 @@ static bool IsFunctionallyDependent(const unique_ptr &expr, const ve // If the 
expression matches ANY of the dependencies, then it is FD on them for (const auto &dep : deps) { // We don't need to check volatility of the dependencies because we checked it for the expression. - if (expr->Equals(*dep)) { + if (expr->Equals(dep.get())) { return true; } } @@ -114,16 +113,28 @@ static void NegatePercentileFractions(ClientContext &context, unique_ptrbound_aggregate = true; unique_ptr bound_filter; - AggregateBinder aggregate_binder(binder, context); ErrorData error; + inside_aggregate_filter = true; + this->inside_aggregate = true; + auto initial_bound_column_count = GetBoundColumns().size(); // Now we bind the filter (if any) - if (aggr.filter) { - aggregate_binder.BindChild(aggr.filter, 0, error); + if (aggr.Filter()) { + BindChild(aggr.FilterMutable(), 0, error); } + inside_aggregate_filter = false; // Handle ordered-set aggregates by moving the single ORDER BY expression to the front of the children. // https://www.postgresql.org/docs/current/functions-aggregate.html#FUNCTIONS-ORDEREDSET-TABLE @@ -131,32 +142,32 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu // and only inject the ordering expression if there are too few. 
idx_t ordered_set_agg = 0; bool negate_fractions = false; - if (aggr.order_bys && aggr.order_bys->orders.size() == 1) { - const auto &func_name = aggr.function_name; + if (aggr.OrderBy() && aggr.OrderBy()->orders.size() == 1) { + const auto &func_name = aggr.FunctionName(); if (func_name == "mode") { ordered_set_agg = 1; } else if (func_name == "quantile_cont" || func_name == "quantile_disc") { ordered_set_agg = 2; auto &config = DBConfig::GetConfig(context); - const auto &order = aggr.order_bys->orders[0]; + const auto &order = aggr.OrderBy()->orders[0]; const auto sense = config.ResolveOrder(context, order.type); negate_fractions = (sense == OrderType::DESCENDING); } } - for (idx_t i = 0; i < aggr.children.size(); ++i) { - auto &child = aggr.children[i]; - aggregate_binder.BindChild(child, 0, error); + for (idx_t i = 0; i < aggr.GetArguments().size(); ++i) { + auto &child = aggr.GetArgumentsMutable()[i]; + BindChild(child.GetExpressionMutable(), 0, error); // We have to negate the fractions for PERCENTILE_XXXX DESC - if (!error.HasError() && ordered_set_agg && i == aggr.children.size() - 1) { - NegatePercentileFractions(context, child, negate_fractions); + if (!error.HasError() && ordered_set_agg && i == aggr.GetArguments().size() - 1) { + NegatePercentileFractions(context, child.GetExpressionMutable(), negate_fractions); } } // Bind the ORDER BYs, if any - if (aggr.order_bys && !aggr.order_bys->orders.empty()) { - for (auto &order : aggr.order_bys->orders) { + if (aggr.OrderBy() && !aggr.OrderBy()->orders.empty()) { + for (auto &order : aggr.OrderByMutable()->orders) { if (order.expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { auto &const_expr = order.expression->Cast(); if (!const_expr.GetValue().type().IsIntegral()) { @@ -170,37 +181,41 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu } } } - aggregate_binder.BindChild(order.expression, 0, error); + BindChild(order.expression, 0, error); } } + 
inside_aggregate = false; + bool bound_columns = GetBoundColumns().size() > initial_bound_column_count; + TruncateBoundColumns(initial_bound_column_count); if (error.HasError()) { // failed to bind child - if (aggregate_binder.HasBoundColumns()) { - for (idx_t i = 0; i < aggr.children.size(); i++) { + if (bound_columns) { + for (auto &child : aggr.GetArgumentsMutable()) { // however, we bound columns! // that means this aggregation belongs to this node // check if we have to resolve any errors by binding with parent binders - auto result = aggregate_binder.BindCorrelatedColumns(aggr.children[i], error); + auto result = BindCorrelatedColumns(child.GetExpressionMutable(), error); // if there is still an error after this, we could not successfully bind the aggregate if (result.HasError()) { result.error.Throw(); } - auto &bound_expr = BoundExpression::GetExpression(*aggr.children[i]); + auto &bound_expr = BoundExpression::GetExpression(*child.GetExpressionMutable()); ExtractCorrelatedExpressions(binder, *bound_expr); } - if (aggr.filter) { - auto result = aggregate_binder.BindCorrelatedColumns(aggr.filter, error); + + if (aggr.Filter()) { + auto result = BindCorrelatedColumns(aggr.FilterMutable(), error); // if there is still an error after this, we could not successfully bind the aggregate if (result.HasError()) { result.error.Throw(); } - auto &bound_expr = BoundExpression::GetExpression(*aggr.filter); + auto &bound_expr = BoundExpression::GetExpression(*aggr.Filter()); ExtractCorrelatedExpressions(binder, *bound_expr); } - if (aggr.order_bys && !aggr.order_bys->orders.empty()) { - for (auto &order : aggr.order_bys->orders) { - auto result = aggregate_binder.BindCorrelatedColumns(order.expression, error); + if (aggr.OrderBy() && !aggr.OrderBy()->orders.empty()) { + for (auto &order : aggr.OrderByMutable()->orders) { + auto result = BindCorrelatedColumns(order.expression, error); if (result.HasError()) { result.error.Throw(); } @@ -212,71 +227,57 @@ BindResult 
BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu // we didn't bind columns, try again in children return BindResult(std::move(error)); } - } else if (depth > 0 && !aggregate_binder.HasBoundColumns()) { + } else if (depth > 0 && !bound_columns) { return BindResult("Aggregate with only constant parameters has to be bound in the root subquery"); } - if (aggr.filter) { - auto &child = BoundExpression::GetExpression(*aggr.filter); + if (aggr.Filter()) { + auto &child = BoundExpression::GetExpression(*aggr.Filter()); bound_filter = BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN); } - // all children bound successfully - // extract the children and types - vector types; - vector arguments; - vector> children; + // all children bound successfully - collect them (with their explicit names, if any) into the full argument list. + // The positional/named split and (for legacy calls) the alias capture are resolved later, per candidate overload. 
+ vector>> arguments; if (ordered_set_agg) { - const bool order_sensitive = (aggr.function_name == "mode"); - // Inject missing ordering arguments - if (aggr.children.size() < ordered_set_agg) { - for (auto &order : aggr.order_bys->orders) { + const bool order_sensitive = (aggr.FunctionName() == "mode"); + // Inject missing ordering arguments as positional arguments + if (aggr.GetArguments().size() < ordered_set_agg) { + for (auto &order : aggr.OrderByMutable()->orders) { auto &child = BoundExpression::GetExpression(*order.expression); - types.push_back(child->GetReturnType()); - arguments.push_back(child->GetReturnType()); if (order_sensitive) { - children.push_back(child->Copy()); + arguments.emplace_back(string(), child->Copy()); } else { - children.push_back(std::move(child)); + arguments.emplace_back(string(), std::move(child)); } } } if (!order_sensitive) { - aggr.order_bys->orders.clear(); + aggr.OrderByMutable()->orders.clear(); } } - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*aggr.children[i]); - types.push_back(child->GetReturnType()); - arguments.push_back(child->GetReturnType()); - children.push_back(std::move(child)); - } - - // bind the aggregate - FunctionBinder function_binder(binder); - auto best_function = function_binder.BindFunction(func.name, func.functions, types, error); - if (!best_function.IsValid()) { - error.AddQueryLocation(aggr); - error.Throw(); - } - // found a matching function! 
- const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); - - if (!bound_function.CanAggregate() && bound_function.CanWindow()) { - auto msg = StringUtil::Format("Function '%s' can only be used as a window function", bound_function.GetName()); - error = BinderException(msg); - error.AddQueryLocation(aggr); - error.Throw(); + for (auto &arg : aggr.GetArgumentsMutable()) { + auto &bound_arg = BoundExpression::GetExpression(*arg.GetExpressionMutable()); + // legacy function calls cannot have named arguments, so we ignore the names of the arguments during binding + // and pass them all positionally, aliasing them by their name (see BindFunction for the rationale) + if (!arg.GetName().empty()) { + bound_arg->SetAlias(arg.GetName()); + } + if (aggr.IsLegacyFunctionCall()) { + arguments.emplace_back(string(), std::move(bound_arg)); + } else { + arguments.emplace_back(arg.GetName(), std::move(bound_arg)); + } } // Bind any sort columns, unless the aggregate is order-insensitive unique_ptr order_bys; - if (!aggr.order_bys->orders.empty()) { + if (!aggr.OrderBy()->orders.empty()) { order_bys = make_uniq(); auto &config = DBConfig::GetConfig(context); - for (auto &order : aggr.order_bys->orders) { + for (auto &order : aggr.OrderByMutable()->orders) { auto &order_expr = BoundExpression::GetExpression(*order.expression); PushCollation(context, order_expr, order_expr->GetReturnType()); const auto sense = config.ResolveOrder(context, order.type); @@ -286,10 +287,15 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu } // If the aggregate is DISTINCT then the ORDER BYs need to be functional dependencies of the arguments. 
- if (aggr.distinct && order_bys) { + if (aggr.Distinct() && order_bys) { + vector> arg_refs; + arg_refs.reserve(arguments.size()); + for (auto &arg : arguments) { + arg_refs.emplace_back(*arg.second); + } bool in_args = true; for (const auto &order_by : order_bys->orders) { - in_args &= IsFunctionallyDependent(order_by.expression, children); + in_args &= IsFunctionallyDependent(order_by.expression, arg_refs); } if (!in_args) { @@ -297,13 +303,30 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu } } + // Bind the function + FunctionBinder function_binder(binder); auto aggregate = - function_binder.BindAggregateFunction(bound_function, std::move(children), std::move(bound_filter), - aggr.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); - if (aggr.export_state) { + function_binder.BindAggregateFunction(func, std::move(arguments), error, std::move(bound_filter), + aggr.Distinct() ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT); + // No function found, throw an error + if (!aggregate) { + error.AddQueryLocation(aggr); + error.Throw(); + } + + // If the function cannot be used as an aggregate, but can be used as a window function, throw a specific error + // message + if (!aggregate->Function().CanAggregate() && aggregate->Function().CanWindow()) { + auto msg = StringUtil::Format("Function '%s' can only be used as a window function", func.name); + error = BinderException(msg); + error.AddQueryLocation(aggr); + error.Throw(); + } + + if (aggr.ExportState()) { aggregate = ExportAggregateFunction::Bind(std::move(aggregate)); } - aggregate->order_bys = std::move(order_bys); + aggregate->GetOrderBysMutable() = std::move(order_bys); // check for all the aggregates if this aggregate already exists ProjectionIndex aggr_index; @@ -320,9 +343,9 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu auto &bound_aggr = *node.aggregates[aggr_index]; // now create a column reference referring to 
the aggregate - auto colref = make_uniq(aggr.GetAlias().empty() ? bound_aggr.ToString() : aggr.GetAlias(), - bound_aggr.GetReturnType(), - ColumnBinding(node.aggregate_index, aggr_index), depth); + auto colref = make_uniq( + aggr.GetAlias().empty() ? Identifier(bound_aggr.ToString()) : aggr.GetAlias(), bound_aggr.GetReturnType(), + ColumnBinding(node.aggregate_index, aggr_index), depth); // move the aggregate expression into the set of bound aggregates return BindResult(std::move(colref)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp index abf55a2bc..3c23dfd2d 100644 --- a/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_between_expression.cpp @@ -12,16 +12,16 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(BetweenExpression &expr, idx_t depth) { // first try to bind the children of the case expression ErrorData error; - BindChild(expr.input, depth, error); - BindChild(expr.lower, depth, error); - BindChild(expr.upper, depth, error); + BindChild(expr.InputMutable(), depth, error); + BindChild(expr.LowerBoundMutable(), depth, error); + BindChild(expr.UpperBoundMutable(), depth, error); if (error.HasError()) { return BindResult(std::move(error)); } // the children have been successfully resolved - auto &input = BoundExpression::GetExpression(*expr.input); - auto &lower = BoundExpression::GetExpression(*expr.lower); - auto &upper = BoundExpression::GetExpression(*expr.upper); + auto &input = BoundExpression::GetExpression(*expr.InputMutable()); + auto &lower = BoundExpression::GetExpression(*expr.LowerBoundMutable()); + auto &upper = BoundExpression::GetExpression(*expr.UpperBoundMutable()); auto input_sql_type = ExpressionBinder::GetExpressionReturnType(*input); auto lower_sql_type = ExpressionBinder::GetExpressionReturnType(*lower); diff --git 
a/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp index a96a2e2bb..7e1742ea0 100644 --- a/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_case_expression.cpp @@ -9,19 +9,19 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(CaseExpression &expr, idx_t depth) { // first try to bind the children of the case expression ErrorData error; - for (auto &check : expr.case_checks) { + for (auto &check : expr.CaseChecksMutable()) { BindChild(check.when_expr, depth, error); BindChild(check.then_expr, depth, error); } - BindChild(expr.else_expr, depth, error); + BindChild(expr.ElseMutable(), depth, error); if (error.HasError()) { return BindResult(std::move(error)); } // the children have been successfully resolved // figure out the result type of the CASE expression - auto &else_expr = BoundExpression::GetExpression(*expr.else_expr); + auto &else_expr = BoundExpression::GetExpression(*expr.ElseMutable()); auto return_type = ExpressionBinder::GetExpressionReturnType(*else_expr); - for (auto &check : expr.case_checks) { + for (auto &check : expr.CaseChecksMutable()) { auto &then_expr = BoundExpression::GetExpression(*check.then_expr); auto then_type = ExpressionBinder::GetExpressionReturnType(*then_expr); if (!LogicalType::TryGetMaxLogicalType(context, return_type, then_type, return_type)) { @@ -33,16 +33,16 @@ BindResult ExpressionBinder::BindExpression(CaseExpression &expr, idx_t depth) { // bind all the individual components of the CASE statement auto result = make_uniq(return_type); - for (auto &check : expr.case_checks) { + for (auto &check : expr.CaseChecksMutable()) { auto &when_expr = BoundExpression::GetExpression(*check.when_expr); auto &then_expr = BoundExpression::GetExpression(*check.then_expr); BoundCaseCheck result_check; result_check.when_expr = BoundCastExpression::AddCastToType(context, 
std::move(when_expr), LogicalType::BOOLEAN); result_check.then_expr = BoundCastExpression::AddCastToType(context, std::move(then_expr), return_type); - result->case_checks.push_back(std::move(result_check)); + result->CaseChecksMutable().push_back(std::move(result_check)); } - result->else_expr = BoundCastExpression::AddCastToType(context, std::move(else_expr), return_type); + result->ElseMutable() = BoundCastExpression::AddCastToType(context, std::move(else_expr), return_type); return BindResult(std::move(result)); } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp index 79d9047e1..2387d72f6 100644 --- a/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_cast_expression.cpp @@ -8,24 +8,24 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(CastExpression &expr, idx_t depth) { // first try to bind the child of the cast expression - auto error = Bind(expr.child, depth); + auto error = Bind(expr.ChildMutable(), depth); if (error.HasError()) { return BindResult(std::move(error)); } // FIXME: We can also implement 'hello'::schema.custom_type; and pass by the schema down here. 
// Right now just considering its DEFAULT_SCHEMA always - binder.BindLogicalType(expr.cast_type); + binder.BindLogicalType(expr.TargetTypeMutable()); // the children have been successfully resolved - auto &child = BoundExpression::GetExpression(*expr.child); - if (expr.try_cast) { - if (ExpressionBinder::GetExpressionReturnType(*child) == expr.cast_type) { + auto &child = BoundExpression::GetExpression(*expr.ChildMutable()); + if (expr.IsTryCast()) { + if (ExpressionBinder::GetExpressionReturnType(*child) == expr.TargetType()) { // no cast required: type matches return BindResult(std::move(child)); } - child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type, true); + child = BoundCastExpression::AddCastToType(context, std::move(child), expr.TargetType(), true); } else { // otherwise add a cast to the target type - child = BoundCastExpression::AddCastToType(context, std::move(child), expr.cast_type); + child = BoundCastExpression::AddCastToType(context, std::move(child), expr.TargetType()); } return BindResult(std::move(child)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp index 5f897834c..9a39e73ec 100644 --- a/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_collate_expression.cpp @@ -7,11 +7,11 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(CollateExpression &expr, idx_t depth) { // first try to bind the child of the cast expression - auto error = Bind(expr.child, depth); + auto error = Bind(expr.ChildMutable(), depth); if (error.HasError()) { return BindResult(std::move(error)); } - auto &child = BoundExpression::GetExpression(*expr.child); + auto &child = BoundExpression::GetExpression(*expr.ChildMutable()); if (child->HasParameter()) { throw ParameterNotResolvedException(); } @@ -20,7 +20,7 @@ BindResult 
ExpressionBinder::BindExpression(CollateExpression &expr, idx_t depth } // Validate the collation, but don't use it auto collation_test = make_uniq_base(Value(child->GetReturnType())); - auto collation_type = LogicalType::VARCHAR_COLLATION(expr.collation); + auto collation_type = LogicalType::VARCHAR_COLLATION(expr.Collation()); PushCollation(context, collation_test, collation_type); child->SetReturnType(collation_type); return BindResult(std::move(child)); diff --git a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp index d4520f8c8..41cba781c 100644 --- a/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_columnref_expression.cpp @@ -9,500 +9,51 @@ #include "duckdb/parser/expression/subquery_expression.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/planner/binder.hpp" +#include "duckdb/planner/column_qualifier.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression_binder.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" +#include "duckdb/planner/column_qualifier.hpp" namespace duckdb { -string GetSQLValueFunctionName(const string &column_name) { - auto lcase = StringUtil::Lower(column_name); - if (lcase == "current_catalog") { - return "current_catalog"; - } else if (lcase == "current_date") { - return "current_date"; - } else if (lcase == "current_schema") { - return "current_schema"; - } else if (lcase == "current_role") { - return "current_role"; - } else if (lcase == "current_time") { - return "get_current_time"; - } else if (lcase == "current_timestamp") { - return "get_current_timestamp"; - } else if (lcase == "current_user") { - return "current_user"; - } else if (lcase == "localtime") { - return "current_localtime"; - } else 
if (lcase == "localtimestamp") { - return "current_localtimestamp"; - } else if (lcase == "session_user") { - return "session_user"; - } else if (lcase == "user") { - return "user"; - } - return string(); -} - -unique_ptr ExpressionBinder::GetSQLValueFunction(const string &column_name) { - auto value_function = GetSQLValueFunctionName(column_name); - if (value_function.empty()) { - return nullptr; - } - - vector> children; - return make_uniq(value_function, std::move(children)); -} - -unique_ptr ExpressionBinder::QualifyColumnName(const ParsedExpression &expr, - const string &column_name, ErrorData &error) { - auto using_binding = binder.bind_context.GetUsingBinding(column_name); - if (using_binding) { - // we are referencing a USING column - // check if we can refer to one of the base columns directly - unique_ptr expression; - if (using_binding->primary_binding.IsSet()) { - // we can! just assign the table name and re-bind - return binder.bind_context.CreateColumnReference(using_binding->primary_binding, column_name); - } else { - // we cannot! 
we need to bind this as COALESCE between all the relevant columns - auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); - coalesce->children.reserve(using_binding->bindings.size()); - for (auto &entry : using_binding->bindings) { - coalesce->children.push_back(make_uniq(column_name, entry)); - } - return std::move(coalesce); - } - } - - // try binding as a lambda parameter - auto lambda_ref = LambdaRefExpression::FindMatchingBinding(lambda_bindings, column_name); - if (lambda_ref) { - return lambda_ref; - } - - // find a table binding that contains this column name - auto table_binding = binder.bind_context.GetMatchingBinding(column_name, expr); - - // throw an error if a macro parameter name conflicts with a column name - auto is_macro_column = false; - if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(column_name)) { - is_macro_column = true; - if (table_binding) { - throw BinderException(expr, "Conflicting column names for column " + column_name + "!"); - } - } - - // bind as a macro column - if (is_macro_column) { - return binder.bind_context.CreateColumnReference(binder.macro_binding->GetBindingAlias(), column_name); - } - - // bind as a regular column - if (table_binding) { - return binder.bind_context.CreateColumnReference(table_binding->GetBindingAlias(), column_name); - } - - // it's not, find candidates and error - auto similar_bindings = binder.bind_context.GetSimilarBindings(column_name); - error = ErrorData(BinderException::ColumnNotFound(column_name, similar_bindings)); - return nullptr; -} - -void ExpressionBinder::QualifyColumnNames(unique_ptr &expr, - vector> &lambda_params, - const bool within_function_expression) { - bool next_within_function_expression = false; - switch (expr->GetExpressionType()) { - case ExpressionType::COLUMN_REF: { - auto &col_ref = expr->Cast(); - - // don't qualify lambda parameters - if (LambdaExpression::IsLambdaParameter(lambda_params, col_ref.GetName())) { - return; - } - - ErrorData error; 
- auto new_expr = QualifyColumnName(col_ref, error); - - if (new_expr) { - if (!expr->GetAlias().empty()) { - // Pre-existing aliases are added to the qualified column reference - new_expr->SetAlias(expr->GetAlias()); - } else if (within_function_expression) { - // Qualifying the column reference may add an alias, but this needs to be removed within function - // expressions, because the alias here means a named parameter instead of a positional parameter - new_expr->ClearAlias(); - } - - // replace the expression with the qualified column reference - new_expr->SetQueryLocation(col_ref.GetQueryLocation()); - expr = std::move(new_expr); - } - return; - } - case ExpressionType::POSITIONAL_REFERENCE: { - auto &ref = expr->Cast(); - if (ref.GetAlias().empty()) { - string table_name, column_name; - auto error = binder.bind_context.BindColumn(ref, table_name, column_name); - if (error.empty()) { - ref.SetAlias(column_name); - } - } - break; - } - case ExpressionType::FUNCTION: { - // Special-handling for lambdas, which are inside function expressions. 
- auto &function = expr->Cast(); - if (!IsUnnestFunction(function.function_name)) { - BindAndQualifyFunction(function, false); - } - if (function.IsLambdaFunction()) { - return QualifyColumnNamesInLambda(function, lambda_params); - } - - next_within_function_expression = true; - break; - } - default: // fall through - break; - } - - // recurse on the child expressions - ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - QualifyColumnNames(child, lambda_params, next_within_function_expression); - }); -} - -void ExpressionBinder::QualifyColumnNamesInLambda(FunctionExpression &function, - vector> &lambda_params) { - for (auto &child : function.children) { - if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { - // not a lambda expression - QualifyColumnNames(child, lambda_params, true); - continue; - } - - // special-handling for LHS lambda parameters - // we do not qualify them, and we add them to the lambda_params vector - auto &lambda_expr = child->Cast(); - string error_message; - auto column_ref_expressions = lambda_expr.ExtractColumnRefExpressions(error_message); - - if (!error_message.empty()) { - // possibly a JSON function, qualify both LHS and RHS - QualifyColumnNames(lambda_expr.lhs, lambda_params, true); - QualifyColumnNames(lambda_expr.expr, lambda_params, true); - continue; - } - - // push this level - lambda_params.emplace_back(); - - // push the lambda parameter names - for (const auto &column_ref_expr : column_ref_expressions) { - const auto &column_ref = column_ref_expr.get().Cast(); - lambda_params.back().emplace(column_ref.GetName()); - } - - // only qualify in RHS - QualifyColumnNames(lambda_expr.expr, lambda_params, true); - - // pop this level - lambda_params.pop_back(); - } -} - -void ExpressionBinder::QualifyColumnNames(Binder &binder, unique_ptr &expr) { - WhereBinder where_binder(binder, binder.context); - vector> lambda_params; - where_binder.QualifyColumnNames(expr, lambda_params); -} - -void 
ExpressionBinder::QualifyColumnNames(ExpressionBinder &expression_binder, unique_ptr &expr) { - vector> lambda_params; - expression_binder.QualifyColumnNames(expr, lambda_params); +unique_ptr ExpressionBinder::GetSQLValueFunction(const Identifier &column_name) { + return binder.GetSQLValueFunction(column_name); } unique_ptr ExpressionBinder::CreateStructExtract(unique_ptr base, - const string &field_name) { - vector> children; - children.push_back(std::move(base)); - children.push_back(make_uniq_base(Value(field_name))); - auto extract_fun = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); - return std::move(extract_fun); + const Identifier &field_name) { + ColumnQualifier qualifier(binder); + return qualifier.CreateStructExtract(std::move(base), field_name); } unique_ptr ExpressionBinder::CreateStructPack(ColumnRefExpression &col_ref) { - if (col_ref.column_names.size() > 3) { - return nullptr; - } - D_ASSERT(!col_ref.column_names.empty()); - - // get a matching binding - ErrorData error; - optional_ptr binding; - switch (col_ref.column_names.size()) { - case 1: { - // single entry - this must be the table name - BindingAlias alias(col_ref.column_names[0]); - binding = binder.bind_context.GetBinding(alias, error); - break; - } - case 2: { - // two entries - this can either be "catalog.table" or "schema.table" - try both - BindingAlias alias(col_ref.column_names[0], col_ref.column_names[1]); - binding = binder.bind_context.GetBinding(alias, error); - if (!binding) { - alias = BindingAlias(col_ref.column_names[0], INVALID_SCHEMA, col_ref.column_names[1]); - binding = binder.bind_context.GetBinding(alias, error); - } - break; - } - case 3: { - // three entries - this must be "catalog.schema.table" - BindingAlias alias(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2]); - binding = binder.bind_context.GetBinding(alias, error); - break; - } - default: - throw InternalException("Expected 1, 2 or 3 column names for 
CreateStructPack"); - } - if (!binding) { - return nullptr; - } - - // We found the table, now create the struct_pack expression - auto &column_names = binding->GetColumnNames(); - vector> child_expressions; - child_expressions.reserve(column_names.size()); - for (const auto &column_name : column_names) { - child_expressions.push_back(binder.bind_context.CreateColumnReference( - binding->GetBindingAlias(), column_name, ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS)); - } - return make_uniq("struct_pack", std::move(child_expressions)); + ColumnQualifier qualifier(binder); + return qualifier.CreateStructPack(col_ref); } -unique_ptr ExpressionBinder::QualifyColumnNameWithManyDotsInternal(ColumnRefExpression &col_ref, - ErrorData &error, - idx_t &struct_extract_start) { - // two or more dots (i.e. "part1.part2.part3.part4...") - // -> part1 is a catalog, part2 is a schema, part3 is a table, part4 is a column name, part 5 and beyond are - // struct fields - // -> part1 is a catalog, part2 is a table, part3 is a column name, part4 and beyond are struct fields - // -> part1 is a schema, part2 is a table, part3 is a column name, part4 and beyond are struct fields - // -> part1 is a table, part2 is a column name, part3 and beyond are struct fields - // -> part1 is a column, part2 and beyond are struct fields - - // we always prefer the most top-level view - // i.e. in case of multiple resolution options, we resolve in order: - // -> 1. resolve "part1" as a catalog - // -> 2. resolve "part1" as a schema - // -> 3. resolve "part1" as a table - // -> 4. 
resolve "part1" as a column - - // first check if part1 is a catalog - ErrorData fully_qualified_error; - optional_ptr binding; - if (col_ref.column_names.size() > 3) { - binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2], - col_ref.column_names[3], fully_qualified_error); - if (binding) { - // part1 is a catalog - the column reference is "catalog.schema.table.column" - struct_extract_start = 4; - return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[3]); - } - } - ErrorData catalog_table_error; - binding = binder.GetMatchingBinding(col_ref.column_names[0], INVALID_SCHEMA, col_ref.column_names[1], - col_ref.column_names[2], catalog_table_error); - if (binding) { - // part1 is a catalog - the column reference is "catalog.table.column" - struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); - } - ErrorData schema_table_error; - binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], col_ref.column_names[2], - schema_table_error); - if (binding) { - // part1 is a schema - the column reference is "schema.table.column" - // any additional fields are turned into struct_extract calls - struct_extract_start = 3; - return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[2]); - } - ErrorData table_column_error; - binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], table_column_error); - if (binding) { - // part1 is a table - // the column reference is "table.column" - // any additional fields are turned into struct_extract calls - struct_extract_start = 2; - return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.column_names[1]); - } - // part1 could be a column - ErrorData unused_error; - auto result_expr = QualifyColumnName(col_ref, col_ref.column_names[0], 
unused_error); - if (result_expr) { - // it is! add the struct extract calls - struct_extract_start = 1; - return result_expr; - } - auto struct_pack = CreateStructPack(col_ref); - if (struct_pack) { - return struct_pack; - } - - // we could not find the column - return an error - // try to find out which error to return - optional_idx catalog_pos; - optional_idx schema_pos; - optional_idx table_pos; - for (const auto &binding_entry : binder.bind_context.GetBindingsList()) { - auto &alias = binding_entry->GetBindingAlias(); - string catalog = alias.GetCatalog(); - string schema = alias.GetSchema(); - string table = alias.GetAlias(); - auto entry = binding_entry->GetStandardEntry(); - if (entry) { - catalog = entry->ParentCatalog().GetName(); - schema = entry->ParentSchema().name; - } - - for (idx_t i = 0; i < 3; i++) { - if (catalog == col_ref.column_names[i]) { - catalog_pos = i; - } - if (schema == col_ref.column_names[i]) { - schema_pos = i; - } - if (table == col_ref.column_names[i]) { - table_pos = i; - } - } - } - - error = std::move(table_column_error); - if (table_pos.IsValid()) { - // we have a valid table - if (table_pos.GetIndex() == 0) { - // the first column is a valid table - return the error as if we were binding the column as the second - - } else if (table_pos.GetIndex() == 1) { - // the first column is a valid table - // this means either the first column is a catalog OR the first column is a schema - if (catalog_pos.IsValid()) { - // catalog.table - error = std::move(catalog_table_error); - } else { - // assume schema.table otherwise - error = std::move(schema_table_error); - } - } else if (col_ref.column_names.size() > 3) { - // the second column is a valid table - // return the fully qualified error if we have enough components - error = std::move(fully_qualified_error); - } - } else if (catalog_pos.IsValid()) { - // the first column is a catalog - if (schema_pos.IsValid()) { - // AND a valid schema - this is likely "catalog.schema.table" - if 
(col_ref.column_names.size() > 3) { - error = std::move(fully_qualified_error); - } else { - error = std::move(catalog_table_error); - } - } else { - // but no valid schema - this could be either "catalog.schema.table" or "catalog.table" - // return "catalog.table" for now - error = std::move(catalog_table_error); - } - } else if (schema_pos.IsValid()) { - // we did not find a catalog or a table, but we did find a schema - if (schema_pos.GetIndex() == 0) { - // "schema.table" - error = std::move(schema_table_error); - } else if (schema_pos.GetIndex() == 1 && col_ref.column_names.size() > 3) { - // catalog.schema.table - error = std::move(fully_qualified_error); - } - } - return nullptr; +void ExpressionBinder::QualifyColumnNames(Binder &binder, unique_ptr &expr, + optional_ptr alias_binder) { + ColumnQualifier qualifier(binder, nullptr, alias_binder); + vector lambda_params; + qualifier.QualifyColumnNames(expr, lambda_params); } -unique_ptr ExpressionBinder::QualifyColumnNameWithManyDots(ColumnRefExpression &col_ref, - ErrorData &error) { - idx_t struct_extract_start = col_ref.column_names.size(); - auto result_expr = QualifyColumnNameWithManyDotsInternal(col_ref, error, struct_extract_start); - if (!result_expr) { - return nullptr; - } - - // create a struct extract with all remaining column names - for (idx_t i = struct_extract_start; i < col_ref.column_names.size(); i++) { - result_expr = CreateStructExtract(std::move(result_expr), col_ref.column_names[i]); - } - - return result_expr; +void ExpressionBinder::QualifyColumnNames(ExpressionBinder &expression_binder, unique_ptr &expr) { + ColumnQualifier qualifier(expression_binder.binder, expression_binder.lambda_bindings); + vector lambda_params; + qualifier.QualifyColumnNames(expr, lambda_params); } -unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error) { - if (!col_ref.IsQualified()) { - // Try binding as a lambda parameter. 
- auto lambda_ref = LambdaRefExpression::FindMatchingBinding(lambda_bindings, col_ref.GetColumnName()); - if (lambda_ref) { - return lambda_ref; - } - } - - idx_t column_parts = col_ref.column_names.size(); - - // column names can have an arbitrary amount of dots - // here is how the resolution works: - if (column_parts == 1) { - // no dots (i.e. "part1") - // -> part1 refers to a column - // check if we can qualify the column name with the table name - auto qualified_col_ref = QualifyColumnName(col_ref, col_ref.GetColumnName(), error); - if (qualified_col_ref) { - // we could: return it - return qualified_col_ref; - } - // we could not! Try creating an implicit struct_pack - return CreateStructPack(col_ref); - } - - if (column_parts == 2) { - // one dot (i.e. "part1.part2") - // EITHER: - // -> part1 is a table, part2 is a column - // -> part1 is a column, part2 is a property of that column (i.e. struct_extract) - - // first check if part1 is a table, and part2 is a standard column name - auto binding = binder.GetMatchingBinding(col_ref.column_names[0], col_ref.column_names[1], error); - if (binding) { - // it is! return the column reference directly - return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.GetColumnName()); - } - - // otherwise check if we can turn this into a struct extract - ErrorData other_error; - auto qualified_col_ref = QualifyColumnName(col_ref, col_ref.column_names[0], other_error); - if (qualified_col_ref) { - // we could: create a struct extract - return CreateStructExtract(std::move(qualified_col_ref), col_ref.column_names[1]); - } - // we could not! 
Try creating an implicit struct_pack - return CreateStructPack(col_ref); - } - - // three or more dots - return QualifyColumnNameWithManyDots(col_ref, error); +BindResult ExpressionBinder::BindExpression(LambdaRefExpression &lambda_ref, idx_t depth) { + D_ASSERT(lambda_bindings && lambda_ref.LambdaIndex() < lambda_bindings->size()); + return (*lambda_bindings)[lambda_ref.LambdaIndex()].Bind(lambda_ref, depth); } -BindResult ExpressionBinder::BindExpression(LambdaRefExpression &lambda_ref, idx_t depth) { - D_ASSERT(lambda_bindings && lambda_ref.lambda_idx < lambda_bindings->size()); - return (*lambda_bindings)[lambda_ref.lambda_idx].Bind(lambda_ref, depth); +unique_ptr ExpressionBinder::QualifyColumnName(ColumnRefExpression &col_ref, ErrorData &error) { + ColumnQualifier qualifier(binder, lambda_bindings); + return qualifier.QualifyColumnName(col_ref, error); } BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_t depth, bool root_expression, @@ -567,7 +118,7 @@ BindResult ExpressionBinder::BindExpression(ColumnRefExpression &col_ref_p, idx_ // we bound the column reference BoundColumnReferenceInfo ref; - ref.name = col_ref.column_names.back(); + ref.name = col_ref.ColumnNames().back(); ref.query_location = col_ref.GetQueryLocation(); bound_columns.push_back(std::move(ref)); return result; diff --git a/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp index da49d53b5..9cb511b76 100644 --- a/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_comparison_expression.cpp @@ -78,7 +78,7 @@ bool BoundComparisonExpression::TryBindComparison(ClientContext &context, const break; } if (is_equality) { - res = LogicalType::ForceMaxLogicalType(left_type, right_type); + res = LogicalType::ForceMaxLogicalType(context, left_type, right_type); } else { if 
(!LogicalType::TryGetMaxLogicalType(context, left_type, right_type, res)) { return false; @@ -147,8 +147,8 @@ LogicalType ExpressionBinder::GetExpressionReturnType(const Expression &expr) { } if (expr.GetReturnType().IsIntegral()) { auto &constant = expr.Cast(); - if (!constant.value.IsNull()) { - return LogicalType::INTEGER_LITERAL(constant.value); + if (!constant.GetValue().IsNull()) { + return LogicalType::INTEGER_LITERAL(constant.GetValue()); } } } @@ -158,15 +158,15 @@ LogicalType ExpressionBinder::GetExpressionReturnType(const Expression &expr) { BindResult ExpressionBinder::BindExpression(ComparisonExpression &expr, idx_t depth) { // first try to bind the children of the case expression ErrorData error; - BindChild(expr.left, depth, error); - BindChild(expr.right, depth, error); + BindChild(expr.LeftMutable(), depth, error); + BindChild(expr.RightMutable(), depth, error); if (error.HasError()) { return BindResult(std::move(error)); } // the children have been successfully resolved - auto &left = BoundExpression::GetExpression(*expr.left); - auto &right = BoundExpression::GetExpression(*expr.right); + auto &left = BoundExpression::GetExpression(*expr.LeftMutable()); + auto &right = BoundExpression::GetExpression(*expr.RightMutable()); auto left_sql_type = ExpressionBinder::GetExpressionReturnType(*left); auto right_sql_type = ExpressionBinder::GetExpressionReturnType(*right); // cast the input types to the same type diff --git a/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp index b93cb5a2d..cb797382b 100644 --- a/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_conjunction_expression.cpp @@ -8,8 +8,8 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(ConjunctionExpression &expr, idx_t depth) { // first try to bind the children of the case expression ErrorData error; - 
for (idx_t i = 0; i < expr.children.size(); i++) { - BindChild(expr.children[i], depth, error); + for (idx_t i = 0; i < expr.GetChildrenMutable().size(); i++) { + BindChild(expr.GetChildrenMutable()[i], depth, error); } if (error.HasError()) { return BindResult(std::move(error)); @@ -18,9 +18,10 @@ BindResult ExpressionBinder::BindExpression(ConjunctionExpression &expr, idx_t d // cast the input types to boolean (if necessary) // and construct the bound conjunction expression auto result = make_uniq(expr.GetExpressionType()); - for (auto &child_expr : expr.children) { + for (auto &child_expr : expr.GetChildrenMutable()) { auto &child = BoundExpression::GetExpression(*child_expr); - result->children.push_back(BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN)); + result->GetChildrenMutable().push_back( + BoundCastExpression::AddCastToType(context, std::move(child), LogicalType::BOOLEAN)); } // now create the bound conjunction expression return BindResult(std::move(result)); diff --git a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp index 8c8103c34..c9b597967 100644 --- a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp @@ -15,6 +15,7 @@ #include "duckdb/planner/expression/bound_lambda_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/planner/column_qualifier.hpp" #include "duckdb/main/settings.hpp" #include @@ -115,7 +116,7 @@ ListReduceRebindResult MaybeRebindListReduceLambda(ClientContext &context, idx_t const bool has_initial = function_child_types.size() == 3; auto &bound_lambda_expr = bind_lambda_result.expression->Cast(); - const auto &lambda_return_type = bound_lambda_expr.lambda_expr->GetReturnType(); + const auto &lambda_return_type = 
bound_lambda_expr.LambdaExpr()->GetReturnType(); auto list_child_type = function_child_types[0]; if (list_child_type.id() != LogicalTypeId::SQLNULL && list_child_type.id() != LogicalTypeId::UNKNOWN) { @@ -159,16 +160,16 @@ ListReduceRebindResult MaybeRebindListReduceLambda(ClientContext &context, idx_t // Avoid repeated rebinds for DECIMAL type widening by forcing the lambda return type to the chosen // accumulator type when decimals are involved. if (TypeContainsDecimal(accumulator_type) || - TypeContainsDecimal(rebound_lambda_expr.lambda_expr->GetReturnType())) { - if (rebound_lambda_expr.lambda_expr->GetReturnType() != accumulator_type) { - const auto old_return_type = rebound_lambda_expr.lambda_expr->GetReturnType(); - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(rebound_lambda_expr.lambda_expr), - accumulator_type); + TypeContainsDecimal(rebound_lambda_expr.LambdaExpr()->GetReturnType())) { + if (rebound_lambda_expr.LambdaExpr()->GetReturnType() != accumulator_type) { + const auto old_return_type = rebound_lambda_expr.LambdaExpr()->GetReturnType(); + auto cast_expr = BoundCastExpression::AddCastToType( + context, std::move(rebound_lambda_expr.LambdaExprMutable()), accumulator_type); if (!cast_expr) { throw BinderException("Could not cast lambda return type %s to accumulator type %s", old_return_type.ToString(), accumulator_type.ToString()); } - rebound_lambda_expr.lambda_expr = std::move(cast_expr); + rebound_lambda_expr.LambdaExprMutable() = std::move(cast_expr); } } result.did_rebind = true; @@ -203,8 +204,7 @@ BindResult ExpressionBinder::TryBindLambdaOrJson(FunctionExpression &function, i } auto setting = Settings::Get(context); - bool invalid_syntax = - setting == LambdaSyntax::DISABLE_SINGLE_ARROW && syntax_type == LambdaSyntaxType::SINGLE_ARROW; + bool invalid_syntax = setting != LambdaSyntax::ENABLE_SINGLE_ARROW && syntax_type == LambdaSyntaxType::SINGLE_ARROW; bool warn_deprecated_syntax = setting == LambdaSyntax::DEFAULT && 
syntax_type == LambdaSyntaxType::SINGLE_ARROW; const string msg = "Deprecated lambda arrow (->) detected. Please transition to the new lambda syntax, " "i.e.., lambda x, i: x + i, before DuckDB's next release.\n" @@ -233,14 +233,7 @@ BindResult ExpressionBinder::TryBindLambdaOrJson(FunctionExpression &function, i } return lambda_bind_result; } - return BindResult(msg); - } - if (StringUtil::Contains(lambda_bind_result.error.RawMessage(), "Deprecated lambda arrow (->) detected.")) { - if (warn_deprecated_syntax) { - DUCKDB_LOG_WARNING(context, msg); - } - - return lambda_bind_result; + throw BinderException(msg); } auto json_bind_result = BindFunction(function, func.Cast(), depth); @@ -254,106 +247,66 @@ BindResult ExpressionBinder::TryBindLambdaOrJson(FunctionExpression &function, i json_bind_result.error.RawMessage()); } -optional_ptr ExpressionBinder::BindAndQualifyFunction(FunctionExpression &function, bool allow_throw) { - D_ASSERT(!IsUnnestFunction(function.function_name)); - // lookup the function in the catalog +CatalogEntry &ExpressionBinder::BindFunction(FunctionExpression &function) { QueryErrorContext error_context(function.GetQueryLocation()); - binder.BindSchemaOrCatalog(function.catalog, function.schema); - EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, function.function_name, error_context); - auto func = GetCatalogEntry(function.catalog, function.schema, function_lookup, OnEntryNotFound::RETURN_NULL); + ColumnQualifier qualifier(binder); + auto func = qualifier.QualifyFunction(function); if (!func) { - // function was not found - check if we this is a table function - EntryLookupInfo table_function_lookup(CatalogType::TABLE_FUNCTION_ENTRY, function.function_name, error_context); + // function was not found - check if we this is a table function (to throw a more helpful error message) + EntryLookupInfo table_function_lookup(CatalogType::TABLE_FUNCTION_ENTRY, function.FunctionName(), + error_context); auto table_func = - 
GetCatalogEntry(function.catalog, function.schema, table_function_lookup, OnEntryNotFound::RETURN_NULL); + GetCatalogEntry(function.Catalog(), function.Schema(), table_function_lookup, OnEntryNotFound::RETURN_NULL); if (table_func) { - if (!allow_throw) { - return func; - } throw BinderException(function, "Function \"%s\" is a table function but it was used as a scalar function. This " "function has to be called in a FROM clause (similar to a table).", - function.function_name); - } - // not a table function - check if the schema is set - if (!function.schema.empty()) { - // the schema is set - check if we can turn this the schema into a column ref - // does this function exist in the system catalog? - func = GetCatalogEntry(INVALID_CATALOG, INVALID_SCHEMA, function_lookup, OnEntryNotFound::RETURN_NULL); - if (func) { - // the function exists in the system catalog - turn this into a dot call - ErrorData error; - unique_ptr colref; - if (function.catalog.empty()) { - colref = make_uniq(function.schema); - } else { - colref = make_uniq(function.schema, function.catalog); - } - auto new_colref = QualifyColumnName(*colref, error); - if (error.HasError()) { - // could not find the column - try to qualify the alias - if (!DoesColumnAliasExist(*colref)) { - if (!allow_throw) { - return func; - } - // no alias found either - throw - error.Throw(); - } - } - // we can! transform this into a function call on the column - // i.e. "x.lower()" becomes "lower(x)" - function.children.insert(function.children.begin(), std::move(colref)); - function.catalog = INVALID_CATALOG; - function.schema = INVALID_SCHEMA; - } - } - // rebind the function - if (!func) { - const auto on_entry_not_found = - allow_throw ? 
OnEntryNotFound::THROW_EXCEPTION : OnEntryNotFound::RETURN_NULL; - func = GetCatalogEntry(function.catalog, function.schema, function_lookup, on_entry_not_found); + function.FunctionName()); } + // not a table function - rebind to throw an error + EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, function.FunctionName(), error_context); + func = + GetCatalogEntry(function.Catalog(), function.Schema(), function_lookup, OnEntryNotFound::THROW_EXCEPTION); } - - return func; + return *func; } BindResult ExpressionBinder::BindExpression(FunctionExpression &function, idx_t depth, unique_ptr &expr_ptr) { - auto func = BindAndQualifyFunction(function, true); + auto &func = BindFunction(function); - switch (func->type) { + switch (func.type) { case CatalogType::AGGREGATE_FUNCTION_ENTRY: case CatalogType::WINDOW_FUNCTION_ENTRY: break; default: - if (function.distinct || function.filter || !function.order_bys->orders.empty()) { + if (function.Distinct() || function.Filter() || !function.OrderBy()->orders.empty()) { throw InvalidInputException("Function \"%s\" is a %s. 
\"DISTINCT\", \"FILTER\", and \"ORDER BY\" are only " "applicable to window and aggregate functions.", - function.function_name, CatalogTypeToString(func->type)); + function.FunctionName(), CatalogTypeToString(func.type)); } break; } - switch (func->type) { + switch (func.type) { case CatalogType::SCALAR_FUNCTION_ENTRY: { auto child = function.IsLambdaFunction(); if (child) { - auto syntax_type = child->Cast().syntax_type; - return TryBindLambdaOrJson(function, depth, *func, syntax_type); + auto syntax_type = child->Cast().GetLambdaSyntaxType(); + return TryBindLambdaOrJson(function, depth, func, syntax_type); } - return BindFunction(function, func->Cast(), depth); + return BindFunction(function, func.Cast(), depth); } case CatalogType::MACRO_ENTRY: // macro function - return BindMacro(function, func->Cast(), depth, expr_ptr); + return BindMacro(function, func.Cast(), depth, expr_ptr); case CatalogType::AGGREGATE_FUNCTION_ENTRY: // aggregate function - return BindAggregate(function, func->Cast(), depth); + return BindAggregate(function, func.Cast(), depth); case CatalogType::WINDOW_FUNCTION_ENTRY: // window function - return BindWindow(function, func->Cast(), depth); + return BindWindow(function, func.Cast(), depth); default: throw InvalidInputException("Unsupported catalog type when binding function"); } @@ -363,9 +316,9 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu // bind the children of the function expression ErrorData error; - // bind of each child - for (idx_t i = 0; i < function.children.size(); i++) { - BindChild(function.children[i], depth, error); + // bind each child + for (idx_t i = 0; i < function.GetArguments().size(); i++) { + BindChild(function.GetArgumentsMutable()[i].GetExpressionMutable(), depth, error); } if (error.HasError()) { @@ -376,23 +329,38 @@ BindResult ExpressionBinder::BindFunction(FunctionExpression &function, ScalarFu return BindResult(make_uniq(Value(LogicalType::SQLNULL))); } - // all children 
bound successfully - // extract the children and types - vector> children; - for (idx_t i = 0; i < function.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*function.children[i]); - children.push_back(std::move(child)); + // all children bound successfully - collect them (with their explicit names, if any) into the full argument list. + // The positional/named split (and, for capturing functions, the alias capture) is resolved later per candidate + // overload. + vector>> arguments; + arguments.reserve(function.GetArguments().size()); + for (auto &arg : function.GetArgumentsMutable()) { + auto &bound_arg = BoundExpression::GetExpression(*arg.GetExpressionMutable()); + + // legacy function calls cannot have named arguments, so we ignore the names of the arguments during binding + // and pass them all positionally. We do alias them by their name though, so that alias-capturing functions + // (e.g. struct_pack) still work and so that re-serializing to the old format can match arguments by name. + // Only override the alias when the argument actually carries a name, otherwise we would clobber the + // display alias the binding assigned (e.g. clearing a column reference's name to its raw binding). 
+ if (!arg.GetName().empty()) { + bound_arg->SetAlias(arg.GetName()); + } + if (function.IsLegacyFunctionCall()) { + arguments.emplace_back(string(), std::move(bound_arg)); + } else { + arguments.emplace_back(arg.GetName(), std::move(bound_arg)); + } } FunctionBinder function_binder(binder); - auto result = function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); + auto result = function_binder.BindScalarFunction(func, std::move(arguments), error, function.IsOperator(), &binder); if (!result) { error.AddQueryLocation(function); error.Throw(); } if (result->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { auto &bound_function = result->Cast(); - if (bound_function.function.GetStability() == FunctionStability::CONSISTENT_WITHIN_QUERY) { + if (bound_function.Function().GetStability() == FunctionStability::CONSISTENT_WITHIN_QUERY) { binder.SetAlwaysRequireRebind(); } } @@ -408,37 +376,40 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc return BindResult("This scalar function does not support lambdas!"); } + auto &args = function.GetArgumentsMutable(); + // the first child is the list, the second child is the lambda expression // constexpr idx_t list_ix = 0; constexpr idx_t list_idx = 0; constexpr idx_t lambda_expr_idx = 1; - D_ASSERT(function.children[lambda_expr_idx]->GetExpressionClass() == ExpressionClass::LAMBDA); + D_ASSERT(args[lambda_expr_idx].GetExpression().GetExpressionClass() == ExpressionClass::LAMBDA); vector function_child_types; // bind the list ErrorData error; - for (idx_t i = 0; i < function.children.size(); i++) { + + for (idx_t i = 0; i < function.GetArguments().size(); i++) { if (i == lambda_expr_idx) { function_child_types.push_back(LogicalType::LAMBDA); continue; } - if (function.children[i]->GetExpressionClass() == ExpressionClass::LAMBDA) { + if (args[i].GetExpression().GetExpressionClass() == ExpressionClass::LAMBDA) { return BindResult("No function 
matches the given name and argument types: '" + function.ToString() + "'. You might need to add explicit type casts."); } - BindChild(function.children[i], depth, error); + BindChild(function.GetArgumentsMutable()[i].GetExpressionMutable(), depth, error); if (error.HasError()) { return BindResult(std::move(error)); } - const auto &child = BoundExpression::GetExpression(*function.children[i]); + const auto &child = BoundExpression::GetExpression(*args[i].GetExpressionMutable()); function_child_types.push_back(child->GetReturnType()); } // get the logical type of the children of the list - auto &list_child = BoundExpression::GetExpression(*function.children[list_idx]); + auto &list_child = BoundExpression::GetExpression(*args[list_idx].GetExpressionMutable()); if (list_child->GetReturnType().id() != LogicalTypeId::LIST && list_child->GetReturnType().id() != LogicalTypeId::ARRAY && list_child->GetReturnType().id() != LogicalTypeId::SQLNULL && @@ -447,7 +418,7 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc } // bind the lambda parameter - auto &lambda_expr = function.children[lambda_expr_idx]->Cast(); + auto &lambda_expr = args[lambda_expr_idx].GetExpressionMutable()->Cast(); unique_ptr lambda_expr_copy; const bool is_list_reduce = func.name == "list_reduce"; @@ -492,12 +463,10 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc } // successfully bound: replace the node with a BoundExpression - auto alias = function.children[lambda_expr_idx]->GetAlias(); + auto alias = args[lambda_expr_idx].GetExpression().GetAlias(); bind_lambda_result.expression->SetAlias(alias); - if (!alias.empty()) { - bind_lambda_result.expression->SetAlias(alias); - } - function.children[lambda_expr_idx] = make_uniq(std::move(bind_lambda_result.expression)); + + args[lambda_expr_idx].GetExpressionMutable() = make_uniq(std::move(bind_lambda_result.expression)); if (binder.GetBindingMode() == BindingMode::EXTRACT_NAMES) { return 
BindResult(make_uniq(Value(LogicalType::SQLNULL))); @@ -506,19 +475,19 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc // all children bound successfully // extract the children and types vector> children; - for (idx_t i = 0; i < function.children.size(); i++) { - auto &child = BoundExpression::GetExpression(*function.children[i]); + for (idx_t i = 0; i < args.size(); i++) { + auto &child = BoundExpression::GetExpression(*args[i].GetExpressionMutable()); children.push_back(std::move(child)); } // capture the (lambda) columns auto &bound_lambda_expr = children[lambda_expr_idx]->Cast(); - CaptureLambdaColumns(bound_lambda_expr, bound_lambda_expr.lambda_expr, capture_bind_lambda, + CaptureLambdaColumns(bound_lambda_expr, bound_lambda_expr.LambdaExprMutable(), capture_bind_lambda, override_bind_lambda_context, capture_child_types); FunctionBinder function_binder(binder); unique_ptr result = - function_binder.BindScalarFunction(func, std::move(children), error, function.is_operator, &binder); + function_binder.BindScalarFunction(func, std::move(children), error, function.IsOperator(), &binder); if (!result) { error.AddQueryLocation(function); error.Throw(); @@ -527,8 +496,8 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc auto &bound_function_expr = result->Cast(); // remove the lambda expression from the children - auto lambda = std::move(bound_function_expr.children[lambda_expr_idx]); - bound_function_expr.children.erase_at(lambda_expr_idx); + auto lambda = std::move(bound_function_expr.GetChildrenMutable()[lambda_expr_idx]); + bound_function_expr.GetChildrenMutable().erase_at(lambda_expr_idx); auto &bound_lambda = lambda->Cast(); // push back (in reverse order) any nested lambda parameters so that we can later use them in the lambda @@ -547,14 +516,14 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc auto bound_lambda_param = make_uniq(column_names[column_idx - 
1], column_types[column_idx - 1], offset); offset++; - bound_function_expr.children.push_back(std::move(bound_lambda_param)); + bound_function_expr.GetChildrenMutable().push_back(std::move(bound_lambda_param)); } } } // push back the captures into the children vector - for (auto &capture : bound_lambda.captures) { - bound_function_expr.children.push_back(std::move(capture)); + for (auto &capture : bound_lambda.CapturesMutable()) { + bound_function_expr.GetChildrenMutable().push_back(std::move(capture)); } return BindResult(std::move(result)); @@ -588,7 +557,7 @@ string ExpressionBinder::UnsupportedUnnestMessage() { return "UNNEST not supported here"; } -optional_ptr ExpressionBinder::GetCatalogEntry(const string &catalog, const string &schema, +optional_ptr ExpressionBinder::GetCatalogEntry(const Identifier &catalog, const Identifier &schema, const EntryLookupInfo &lookup_info, OnEntryNotFound on_entry_not_found) { return binder.GetCatalogEntry(catalog, schema, lookup_info, on_entry_not_found); diff --git a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp index 6c3e0fd64..6196ca309 100644 --- a/src/duckdb/src/planner/binder/expression/bind_lambda.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_lambda.cpp @@ -22,31 +22,32 @@ static idx_t GetLambdaParamCount(vector &lambda_bindings) { static idx_t GetLambdaParamIndex(vector &lambda_bindings, const BoundLambdaExpression &bound_lambda_expr, const BoundLambdaRefExpression &bound_lambda_ref_expr) { - D_ASSERT(bound_lambda_ref_expr.lambda_idx < lambda_bindings.size()); + D_ASSERT(bound_lambda_ref_expr.LambdaIndex() < lambda_bindings.size()); idx_t offset = 0; // count the remaining lambda parameters BEFORE the current lambda parameter, // as these will be in front of the current lambda parameter in the input chunk - for (idx_t i = bound_lambda_ref_expr.lambda_idx + 1; i < lambda_bindings.size(); i++) { + for (idx_t i = 
bound_lambda_ref_expr.LambdaIndex() + 1; i < lambda_bindings.size(); i++) { offset += lambda_bindings[i].GetColumnCount(); } - offset += lambda_bindings[bound_lambda_ref_expr.lambda_idx].GetColumnCount() - - bound_lambda_ref_expr.binding.column_index - 1; - offset += bound_lambda_expr.parameter_count; + offset += lambda_bindings[bound_lambda_ref_expr.LambdaIndex()].GetColumnCount() - + bound_lambda_ref_expr.Binding().column_index - 1; + offset += bound_lambda_expr.ParameterCount(); return offset; } -static void ExtractParameter(const ParsedExpression &expr, vector &column_names, +static void ExtractParameter(const ParsedExpression &expr, vector &column_names, vector &column_aliases) { auto &column_ref = expr.Cast(); if (column_ref.IsQualified()) { throw BinderException(LambdaExpression::InvalidParametersErrorMessage()); } - column_names.push_back(column_ref.GetName()); + column_names.emplace_back(column_ref.GetName()); column_aliases.push_back(column_ref.ToString()); } -static void ExtractParameters(LambdaExpression &expr, vector &column_names, vector &column_aliases) { +static void ExtractParameters(LambdaExpression &expr, vector &column_names, + vector &column_aliases) { // extract the lambda parameters, which are a single column // reference, or a list of column references (ROW function) string error_message; @@ -66,18 +67,19 @@ static bool IsDoubleArrowRHS(const ParsedExpression &expr) { return false; } auto &func = expr.Cast(); - return func.is_operator && func.function_name == "->>" && func.children.size() == 2; + return func.IsOperator() && func.FunctionName() == "->>" && func.GetArguments().size() == 2; } static unique_ptr RestructureArrowChain(LambdaExpression &expr) { - auto &rhs_func = expr.expr->Cast(); - auto inner_lambda = make_uniq(std::move(expr.lhs), std::move(rhs_func.children[0])); - inner_lambda->syntax_type = expr.syntax_type; + auto &rhs_func = expr.RightMutable()->Cast(); + auto inner_lambda = make_uniq( + std::move(expr.LeftMutable()), 
std::move(rhs_func.GetArgumentsMutable()[0].GetExpressionMutable())); + inner_lambda->GetLambdaSyntaxTypeMutable() = expr.GetLambdaSyntaxType(); vector> children; children.push_back(std::move(inner_lambda)); - children.push_back(std::move(rhs_func.children[1])); + children.push_back(std::move(rhs_func.GetArgumentsMutable()[1].GetExpressionMutable())); auto restructured = make_uniq("->>", std::move(children)); - restructured->is_operator = true; + restructured->IsOperatorMutable() = true; return std::move(restructured); } @@ -85,39 +87,40 @@ BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, const vector &function_child_types, optional_ptr bind_lambda_function, optional_ptr bind_lambda_context) { - if (expr.syntax_type == LambdaSyntaxType::LAMBDA_KEYWORD && !bind_lambda_function) { + if (expr.GetLambdaSyntaxType() == LambdaSyntaxType::LAMBDA_KEYWORD && !bind_lambda_function) { return BindResult("invalid lambda expression"); } if (!bind_lambda_function) { // The PEG parser produces A -> (B ->> C) where the standard parser left-associates to (A -> B) ->> C. // Restructure to match standard behavior before binding. - if (IsDoubleArrowRHS(*expr.expr)) { + if (IsDoubleArrowRHS(expr.Right())) { unique_ptr restructured = RestructureArrowChain(expr); return BindExpression(restructured, depth); } // This is not a lambda expression, but the JSON arrow operator. // Remember the original expression in case of a binding error. - if (!expr.copied_expr) { - expr.copied_expr = expr.expr->Copy(); + if (!expr.CopiedExprMutable()) { + expr.CopiedExprMutable() = expr.Right().Copy(); } - OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(expr.lhs), std::move(expr.expr)); + OperatorExpression arrow_expr(ExpressionType::ARROW, std::move(expr.LeftMutable()), + std::move(expr.RightMutable())); auto bind_result = BindExpression(arrow_expr, depth); // The arrow_expr now might contain bound nodes. // Restore the original expression. 
if (bind_result.HasError()) { - D_ASSERT(arrow_expr.children.size() == 2); - expr.lhs = std::move(arrow_expr.children[0]); - expr.expr = std::move(arrow_expr.children[1]); + D_ASSERT(arrow_expr.GetChildrenMutable().size() == 2); + expr.LeftMutable() = std::move(arrow_expr.GetChildrenMutable()[0]); + expr.RightMutable() = std::move(arrow_expr.GetChildrenMutable()[1]); } return bind_result; } // extract and verify lambda parameters to create dummy columns vector column_types; - vector column_names; + vector column_names; vector column_aliases; ExtractParameters(expr, column_names, column_aliases); for (idx_t i = 0; i < column_names.size(); i++) { @@ -138,10 +141,10 @@ BindResult ExpressionBinder::BindExpression(LambdaExpression &expr, idx_t depth, DummyBinding new_lambda_binding(column_types, column_names, table_alias); lambda_bindings->push_back(new_lambda_binding); - if (expr.copied_expr) { - expr.expr = std::move(expr.copied_expr); + if (expr.CopiedExprMutable()) { + expr.RightMutable() = std::move(expr.CopiedExprMutable()); } - auto result = BindExpression(expr.expr, depth, false); + auto result = BindExpression(expr.RightMutable(), depth, false); lambda_bindings->pop_back(); // successfully bound a subtree of nested lambdas, set this to nullptr in case other parts of the @@ -171,15 +174,15 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori // refers to a lambda parameter outside the current lambda function // so the lambda parameter will be inside the lambda_bindings - if (lambda_bindings && bound_lambda_ref.lambda_idx != lambda_bindings->size()) { - auto &binding = (*lambda_bindings)[bound_lambda_ref.lambda_idx]; + if (lambda_bindings && bound_lambda_ref.LambdaIndex() != lambda_bindings->size()) { + auto &binding = (*lambda_bindings)[bound_lambda_ref.LambdaIndex()]; auto &column_names = binding.GetColumnNames(); auto &column_types = binding.GetColumnTypes(); D_ASSERT(column_names.size() == column_types.size()); // find the matching dummy 
column in the lambda binding for (idx_t column_idx = 0; column_idx < binding.GetColumnCount(); column_idx++) { - if (ProjectionIndex(column_idx) == bound_lambda_ref.binding.column_index) { + if (ProjectionIndex(column_idx) == bound_lambda_ref.Binding().column_index) { // now create the replacement auto index = GetLambdaParamIndex(*lambda_bindings, bound_lambda_expr, bound_lambda_ref); replacement = @@ -193,8 +196,8 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori } // refers to a lambda parameter inside the current lambda function auto logical_type = (*bind_lambda_function)(context, function_child_types, - bound_lambda_ref.binding.column_index, bind_lambda_context); - auto index = bound_lambda_expr.parameter_count - bound_lambda_ref.binding.column_index - 1; + bound_lambda_ref.Binding().column_index, bind_lambda_context); + auto index = bound_lambda_expr.ParameterCount() - bound_lambda_ref.Binding().column_index - 1; replacement = make_uniq(alias, logical_type, index); return; } @@ -204,11 +207,11 @@ void ExpressionBinder::TransformCapturedLambdaColumn(unique_ptr &ori if (lambda_bindings) { offset += GetLambdaParamCount(*lambda_bindings); } - offset += bound_lambda_expr.parameter_count; - offset += bound_lambda_expr.captures.size(); + offset += bound_lambda_expr.ParameterCount(); + offset += bound_lambda_expr.CapturesMutable().size(); replacement = make_uniq(original->GetAlias(), original->GetReturnType(), offset); - bound_lambda_expr.captures.push_back(std::move(original)); + bound_lambda_expr.CapturesMutable().push_back(std::move(original)); } void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_expr, unique_ptr &expr, @@ -234,7 +237,7 @@ void ExpressionBinder::CaptureLambdaColumns(BoundLambdaExpression &bound_lambda_ expr->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { // Search for UNNEST. 
- auto &column_binding = expr->Cast().binding; + auto &column_binding = expr->Cast().Binding(); ThrowIfUnnestInLambda(column_binding); } diff --git a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp index 79efe9c86..20d83fdcb 100644 --- a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp @@ -1,33 +1,38 @@ #include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/entry_lookup_info.hpp" #include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/on_entry_not_found.hpp" +#include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/unique_ptr.hpp" #include "duckdb/function/scalar_macro_function.hpp" #include "duckdb/parser/expression/conjunction_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/subquery_expression.hpp" #include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/parsed_expression.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/planner/expression_binder.hpp" namespace duckdb { void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &function, - vector> &lambda_params) { - for (auto &child : function.children) { - if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { - ReplaceMacroParameters(child, lambda_params); + vector &lambda_params) { + for (auto &child : function.GetArgumentsMutable()) { + if (child.GetExpression().GetExpressionClass() != ExpressionClass::LAMBDA) { + ReplaceMacroParameters(child.GetExpressionMutable(), lambda_params); continue; } // Special-handling for LHS lambda parameters. // We do not replace them, and we add them to the lambda_params vector. 
- auto &lambda_expr = child->Cast(); + auto &lambda_expr = child.GetExpressionMutable()->Cast(); string error_message; auto column_ref_expressions = lambda_expr.ExtractColumnRefExpressions(error_message); if (!error_message.empty()) { // Possibly a JSON function, replace both LHS and RHS. - ReplaceMacroParameters(lambda_expr.lhs, lambda_params); - ReplaceMacroParameters(lambda_expr.expr, lambda_params); + ReplaceMacroParameters(lambda_expr.LeftMutable(), lambda_params); + ReplaceMacroParameters(lambda_expr.RightMutable(), lambda_params); continue; } @@ -39,14 +44,14 @@ void ExpressionBinder::ReplaceMacroParametersInLambda(FunctionExpression &functi } // Only replace in the RHS of the expression. - ReplaceMacroParameters(lambda_expr.expr, lambda_params); + ReplaceMacroParameters(lambda_expr.RightMutable(), lambda_params); lambda_params.pop_back(); } } void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr, - vector> &lambda_params) { + vector &lambda_params) { switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { // If the expression is a column reference, we replace it with its argument. 
@@ -57,7 +62,7 @@ void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr bool bind_macro_parameter = false; if (col_ref.IsQualified()) { - if (col_ref.GetTableName().find(DummyBinding::DUMMY_NAME) != string::npos) { + if (col_ref.GetTableName().GetIdentifierName().find(DummyBinding::DUMMY_NAME) != string::npos) { bind_macro_parameter = true; } } else { @@ -79,7 +84,7 @@ void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr break; } case ExpressionClass::SUBQUERY: { - auto &sq = (expr->Cast()).subquery; + auto &sq = (expr->Cast()).Subquery(); ParsedExpressionIterator::EnumerateQueryNodeChildren( *sq->node, [&](unique_ptr &child) { ReplaceMacroParameters(child, lambda_params); }); break; @@ -92,11 +97,77 @@ void ExpressionBinder::ReplaceMacroParameters(unique_ptr &expr *expr, [&](unique_ptr &child) { ReplaceMacroParameters(child, lambda_params); }); } +// Find aggregate expression children +void ExpressionBinder::FindAggregateExprs(unique_ptr &expr, + vector>> &exprs) { + if (expr->GetExpressionType() == ExpressionType::FUNCTION) { + auto &fn_expr = expr->Cast(); + + // Look up the function in the catalog, check to see if it is actually an aggregate function + EntryLookupInfo fn_entry(CatalogType::AGGREGATE_FUNCTION_ENTRY, fn_expr.FunctionName()); + auto entry = GetCatalogEntry(fn_expr.Catalog(), fn_expr.Schema(), fn_entry, OnEntryNotFound::RETURN_NULL); + + if (entry && entry->type == CatalogType::AGGREGATE_FUNCTION_ENTRY) { + exprs.push_back(expr); + return; + } + } + + ParsedExpressionIterator::EnumerateChildren( + *expr, [&](unique_ptr &child_expr) { FindAggregateExprs(child_expr, exprs); }); +} + +void ExpressionBinder::UnfoldWindowMacroExpression(unique_ptr &expr, ScalarMacroFunction ¯o_def) { + auto macro_copy = macro_def.expression->Copy(); + vector>> aggregate_exprs; + FindAggregateExprs(macro_copy, aggregate_exprs); + + // Only allowed if the macro body has a single aggregate expression + if (aggregate_exprs.size() != 1) { + throw 
BinderException("Window function macro bodies must contain exactly one aggregate function"); + } + + // The window spec is pushed down to the aggregate function target within the macro body + unique_ptr &agg_expr_ref = aggregate_exprs[0]; + auto &agg_fn_expr = agg_expr_ref->Cast(); + + // Transfer the macro function attributes + auto &window_expr = expr->Cast(); + window_expr.CatalogMutable() = agg_fn_expr.Catalog(); + window_expr.SchemaMutable() = agg_fn_expr.Schema(); + window_expr.FunctionNameMutable() = agg_fn_expr.FunctionName(); + window_expr.GetArgumentsMutable().clear(); + for (auto &arg : agg_fn_expr.GetArgumentsMutable()) { + window_expr.GetArgumentsMutable().push_back(std::move(arg)); + } + if (!window_expr.Distinct()) { + window_expr.DistinctMutable() = agg_fn_expr.Distinct(); + } + if (window_expr.Filter() && agg_fn_expr.Filter()) { + // Two FILTER clauses: combine + window_expr.FilterMutable() = + make_uniq(ExpressionType::CONJUNCTION_AND, std::move(window_expr.FilterMutable()), + std::move(agg_fn_expr.FilterMutable())); + } else if (agg_fn_expr.Filter()) { + // One FILTER from the MACRO + window_expr.FilterMutable() = std::move(agg_fn_expr.FilterMutable()); + } + // Transfer argument ORDER BYs + if (agg_fn_expr.OrderBy()) { + auto clone = agg_fn_expr.OrderBy()->Copy(); + window_expr.ArgOrdersMutable() = std::move(clone->Cast().orders); + } + + // Replace the aggregate expression with the new window expression + agg_expr_ref = std::move(expr); + expr = std::move(macro_copy); +} + void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, ScalarMacroCatalogEntry ¯o_func, unique_ptr &expr, idx_t depth) { // validate the arguments and separate positional and default arguments vector> positional_arguments; - InsertionOrderPreservingMap> named_arguments; + InsertionOrderPreservingMap, Identifier, identifier_map_t> named_arguments; binder.lambda_bindings = lambda_bindings; auto bind_result = MacroFunction::BindMacroFunction(binder, 
macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); @@ -112,34 +183,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala // replace current expression with stored macro expression // special case: If this is a window function, then we need to return a window expression if (expr->GetExpressionClass() == ExpressionClass::WINDOW) { - // Only allowed if the expression is a function - if (macro_def.expression->GetExpressionType() != ExpressionType::FUNCTION) { - throw BinderException("Window function macros must be functions"); - } - auto macro_copy = macro_def.expression->Copy(); - auto ¯o_expr = macro_copy->Cast(); - // Transfer the macro function attributes - auto &window_expr = expr->Cast(); - window_expr.catalog = macro_expr.catalog; - window_expr.schema = macro_expr.schema; - window_expr.function_name = macro_expr.function_name; - window_expr.children = std::move(macro_expr.children); - if (!window_expr.distinct) { - window_expr.distinct = macro_expr.distinct; - } - if (window_expr.filter_expr && macro_expr.filter) { - // Two FILTER clauses: combine - window_expr.filter_expr = make_uniq( - ExpressionType::CONJUNCTION_AND, std::move(window_expr.filter_expr), std::move(macro_expr.filter)); - } else if (macro_expr.filter) { - // One FILTER from the MACRO - window_expr.filter_expr = std::move(macro_expr.filter); - } - // Transfer argument ORDER BYs - if (macro_expr.order_bys) { - auto clone = macro_expr.order_bys->Copy(); - window_expr.arg_orders = std::move(clone->Cast().orders); - } + UnfoldWindowMacroExpression(expr, macro_def); } else { expr = macro_def.expression->Copy(); } @@ -150,7 +194,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala ExpressionBinder::QualifyColumnNames(*dummy_binder, expr); // now replace the parameters - vector> lambda_params; + vector lambda_params; ReplaceMacroParameters(expr, lambda_params); } diff --git 
a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp index 15338203e..5a146bb70 100644 --- a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp @@ -8,7 +8,6 @@ #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/expression_binder.hpp" -#include "duckdb/planner/expression_binder/try_operator_binder.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" @@ -102,18 +101,13 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) // Only those children that trigger an error are not yet bound. ErrorData error; if (operator_type == ExpressionType::OPERATOR_TRY) { - D_ASSERT(op.children.size() == 1); - //! This binder is used to throw when the child expression is of a type that is not allowed. 
- TryOperatorBinder try_operator_binder(binder, context); - try_operator_binder.BindChild(op.children[0], depth, error); - // Propagate bound columns from TryOperatorBinder back to parent binder - // This ensures that column references inside TRY() are properly tracked for GROUP BY validation - for (const auto &bound_col : try_operator_binder.GetBoundColumns()) { - bound_columns.push_back(bound_col); - } + D_ASSERT(op.GetChildrenMutable().size() == 1); + inside_try = true; + BindChild(op.GetChildrenMutable()[0], depth, error); + inside_try = false; } else { - for (idx_t i = 0; i < op.children.size(); i++) { - BindChild(op.children[i], depth, error); + for (idx_t i = 0; i < op.GetChildrenMutable().size(); i++) { + BindChild(op.GetChildrenMutable()[i], depth, error); } } @@ -127,36 +121,38 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) case ExpressionType::OPERATOR_UNPACK: return BindResult("UNPACK not allowed here, should have been resolved earlier"); case ExpressionType::ARRAY_EXTRACT: { - D_ASSERT(op.children[0]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); - auto &b_exp = BoundExpression::GetExpression(*op.children[0]); + D_ASSERT(op.GetChildrenMutable()[0]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); + auto &b_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[0]); const auto &b_exp_type = b_exp->GetReturnType(); if (b_exp_type.id() == LogicalTypeId::MAP) { function_name = "map_extract_value"; - } else if (b_exp_type.IsJSONType() && op.children.size() == 2) { + } else if (b_exp_type.IsJSONType() && op.GetChildrenMutable().size() == 2) { function_name = "json_extract"; // Make sure we only extract array elements, not fields, by adding the $[] syntax - auto &i_exp = BoundExpression::GetExpression(*op.children[1]); + auto &i_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[1]); if (i_exp->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT && - 
!i_exp->Cast().value.IsNull()) { + !i_exp->Cast().GetValue().IsNull()) { auto &const_exp = i_exp->Cast(); - if (const_exp.value.TryCastAs(context, LogicalType::UINTEGER)) { + if (const_exp.GetValueMutable().TryCastAs(context, LogicalType::UINTEGER)) { // Array extraction: if the cast fails it's definitely out-of-bounds for a JSON array - auto index = UIntegerValue::Get(const_exp.value); - const_exp.value = StringUtil::Format("$[%lld]", index); + auto index = UIntegerValue::Get(const_exp.GetValueMutable()); + const_exp.GetValueMutable() = StringUtil::Format("$[%lld]", index); const_exp.SetReturnType(LogicalType::VARCHAR); } else if (const_exp.GetReturnType().id() == LogicalType::VARCHAR) { // Field extraction - const_exp.value = StringUtil::Format("$.\"%s\"", const_exp.value.ToString()); + const_exp.GetValueMutable() = + StringUtil::Format("$.\"%s\"", const_exp.GetValueMutable().ToString()); const_exp.SetReturnType(LogicalType::VARCHAR); } } - } else if (b_exp_type.id() == LogicalTypeId::VARIANT && op.children.size() == 2) { + } else if (b_exp_type.id() == LogicalTypeId::VARIANT && op.GetChildrenMutable().size() == 2) { function_name = "variant_extract"; - auto &i_exp = BoundExpression::GetExpression(*op.children[1]); + auto &i_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[1]); if (i_exp->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { auto &const_exp = i_exp->Cast(); - if (!const_exp.value.IsNull() && const_exp.GetReturnType().IsNumeric()) { - const_exp.value = const_exp.value.DefaultCastAs(LogicalType::UINTEGER, true); + if (!const_exp.GetValueMutable().IsNull() && const_exp.GetReturnType().IsNumeric()) { + const_exp.GetValueMutable() = + const_exp.GetValueMutable().DefaultCastAs(LogicalType::UINTEGER, true); const_exp.SetReturnType(LogicalType::UINTEGER); } } @@ -169,19 +165,18 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) function_name = "array_slice"; break; case ExpressionType::STRUCT_EXTRACT: { 
- D_ASSERT(op.children.size() == 2); - D_ASSERT(op.children[0]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); - D_ASSERT(op.children[1]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); - auto &extract_exp = BoundExpression::GetExpression(*op.children[0]); + D_ASSERT(op.GetChildrenMutable().size() == 2); + D_ASSERT(op.GetChildrenMutable()[0]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); + D_ASSERT(op.GetChildrenMutable()[1]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); + auto &extract_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[0]); if (extract_exp->HasParameter() || extract_exp->GetReturnType().id() == LogicalTypeId::UNKNOWN) { throw ParameterNotResolvedException(); } - auto &name_exp = BoundExpression::GetExpression(*op.children[1]); + auto &name_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[1]); const auto &extract_expr_type = extract_exp->GetReturnType(); if (extract_expr_type.id() != LogicalTypeId::STRUCT && extract_expr_type.id() != LogicalTypeId::UNION && extract_expr_type.id() != LogicalTypeId::MAP && extract_expr_type.id() != LogicalTypeId::SQLNULL && - !extract_expr_type.IsJSONType() && extract_expr_type.id() != LogicalTypeId::VARIANT && - !extract_expr_type.IsAggregateStateStructType()) { + !extract_expr_type.IsJSONType() && extract_expr_type.id() != LogicalTypeId::VARIANT) { return BindResult(StringUtil::Format( "Cannot extract field %s from expression \"%s\" because it is not a struct, union, map, or json", name_exp->ToString(), extract_exp->ToString())); @@ -192,11 +187,11 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) function_name = "map_extract_value"; } else if (extract_expr_type.id() == LogicalTypeId::VARIANT) { function_name = "variant_extract"; - auto &i_exp = BoundExpression::GetExpression(*op.children[1]); + auto &i_exp = BoundExpression::GetExpression(*op.GetChildrenMutable()[1]); if (i_exp->GetExpressionClass() == 
ExpressionClass::BOUND_CONSTANT) { auto &const_exp = i_exp->Cast(); - if (!const_exp.value.IsNull()) { - const_exp.value = StringUtil::Format("%s", const_exp.value.ToString()); + if (!const_exp.GetValueMutable().IsNull()) { + const_exp.GetValueMutable() = StringUtil::Format("%s", const_exp.GetValueMutable().ToString()); const_exp.SetReturnType(LogicalType::VARCHAR); } } @@ -205,8 +200,9 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) // Make sure we only extract fields, not array elements, by adding $. syntax if (name_exp->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { auto &const_exp = name_exp->Cast(); - if (!const_exp.value.IsNull()) { - const_exp.value = StringUtil::Format("$.\"%s\"", const_exp.value.ToString()); + if (!const_exp.GetValueMutable().IsNull()) { + const_exp.GetValueMutable() = + StringUtil::Format("$.\"%s\"", const_exp.GetValueMutable().ToString()); const_exp.SetReturnType(LogicalType::VARCHAR); } } @@ -222,7 +218,7 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) function_name = "json_extract"; break; case ExpressionType::OPERATOR_TRY: { - auto &expr = BoundExpression::GetExpression(*op.children[0]); + auto &expr = BoundExpression::GetExpression(*op.GetChildrenMutable()[0]); if (expr->HasSubquery()) { throw BinderException("TRY can not be used in combination with a scalar subquery"); } @@ -235,14 +231,15 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) break; } if (!function_name.empty()) { - auto function = make_uniq_base(function_name, std::move(op.children)); + auto function = make_uniq_base(Identifier(function_name), + std::move(op.GetChildrenMutable())); return BindExpression(function, depth, false); } vector> children; - for (idx_t i = 0; i < op.children.size(); i++) { - D_ASSERT(op.children[i]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); - 
children.push_back(std::move(BoundExpression::GetExpression(*op.children[i]))); + for (idx_t i = 0; i < op.GetChildrenMutable().size(); i++) { + D_ASSERT(op.GetChildrenMutable()[i]->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); + children.push_back(std::move(BoundExpression::GetExpression(*op.GetChildrenMutable()[i]))); } // now resolve the types LogicalType result_type = ResolveOperatorType(op, children); @@ -257,7 +254,7 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) auto result = make_uniq(op.GetExpressionType(), result_type); for (auto &child : children) { - result->children.push_back(std::move(child)); + result->GetChildrenMutable().push_back(std::move(child)); } return BindResult(std::move(result)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp index 3fe02467e..9ee9ba320 100644 --- a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -12,7 +12,7 @@ BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t dep if (!parameters) { throw BinderException("Unexpected prepared parameter. 
This type of statement can't be prepared!"); } - auto parameter_id = expr.identifier; + auto parameter_id = expr.Identifier(); // Check if a parameter value has already been supplied auto ¶meter_data = parameters->GetParameterData(); diff --git a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp index 45f8ae07a..137e9b37b 100644 --- a/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_star_expression.cpp @@ -14,7 +14,7 @@ namespace duckdb { string GetColumnsStringValue(ParsedExpression &expr) { if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) { auto &colref = expr.Cast(); - return colref.GetColumnName(); + return colref.GetColumnName().GetIdentifierName(); } else { return expr.ToString(); } @@ -25,7 +25,7 @@ StarExpressionType Binder::FindStarExpression(unique_ptr &expr StarExpressionType has_star = StarExpressionType::NONE; if (expr->GetExpressionType() == ExpressionType::OPERATOR_UNPACK) { auto &operator_expr = expr->Cast(); - auto res = FindStarExpression(operator_expr.children[0], star, is_root, in_columns); + auto res = FindStarExpression(operator_expr.GetChildrenMutable()[0], star, is_root, in_columns); if (res != StarExpressionType::STAR && res != StarExpressionType::COLUMNS) { throw BinderException( "UNPACK can only be used in combination with a STAR (*) expression or COLUMNS expression"); @@ -48,12 +48,12 @@ StarExpressionType Binder::FindStarExpression(unique_ptr &expr "STAR expression is only allowed as the root element of an expression. 
Use COLUMNS(*) instead."); } - if (!current_star.replace_list.empty()) { + if (!current_star.ReplaceList().empty()) { // '*' inside COLUMNS can not have a REPLACE list throw BinderException( "STAR expression with REPLACE list is only allowed as the root element of COLUMNS"); } - if (!current_star.rename_list.empty()) { + if (!current_star.RenameList().empty()) { // '*' inside COLUMNS can not have a REPLACE list throw BinderException( "STAR expression with RENAME list is only allowed as the root element of COLUMNS"); @@ -153,7 +153,7 @@ string Binder::ReplaceColumnsAlias(const string &alias, const string &column_nam void TryTransformStarLike(unique_ptr &root) { // detect "* LIKE [literal]" and similar expressions bool inverse = root->GetExpressionType() == ExpressionType::OPERATOR_NOT; - auto &expr = inverse ? root->Cast().children[0] : root; + auto &expr = inverse ? root->Cast().GetChildrenMutable()[0] : root; if (!expr) { return; } @@ -161,16 +161,16 @@ void TryTransformStarLike(unique_ptr &root) { return; } auto &function = expr->Cast(); - if (function.children.size() < 2 || function.children.size() > 3) { + if (function.GetArguments().size() < 2 || function.GetArguments().size() > 3) { return; } - auto &left = function.children[0]; + auto &left = function.GetArgumentsMutable()[0]; // expression must have a star on the LHS, and a literal on the RHS - if (left->GetExpressionClass() != ExpressionClass::STAR) { + if (left.GetExpression().GetExpressionClass() != ExpressionClass::STAR) { return; } - auto &star = left->Cast(); - if (star.columns) { + auto &star = left.GetExpressionMutable()->Cast(); + if (star.IsColumns()) { // COLUMNS(*) has different semantics return; } @@ -185,34 +185,34 @@ void TryTransformStarLike(unique_ptr &root) { "ilike_escape", "not_ilike_escape", "like_escape"}; - if (supported_ops.count(function.function_name) == 0) { + if (supported_ops.count(function.FunctionName().GetIdentifierName()) == 0) { // unsupported op for * expression - throw 
BinderException(*root, "Function \"%s\" cannot be applied to a star expression", function.function_name); + throw BinderException(*root, "Function \"%s\" cannot be applied to a star expression", function.FunctionName()); } - auto &right = function.children[1]; - if (right->GetExpressionClass() != ExpressionClass::CONSTANT) { + auto &right = function.GetArgumentsMutable()[1]; + if (right.GetExpression().GetExpressionClass() != ExpressionClass::CONSTANT) { throw BinderException(*root, "Pattern applied to a star expression must be a constant"); } - if (!star.rename_list.empty()) { + if (!star.RenameList().empty()) { throw BinderException(*root, "Rename list cannot be combined with a filtering operation"); } - if (!star.replace_list.empty()) { + if (!star.ReplaceList().empty()) { throw BinderException(*root, "Replace list cannot be combined with a filtering operation"); } auto original_alias = root->GetAlias(); auto star_expr = std::move(left); unique_ptr child_expr; - if (!inverse && function.function_name == "regexp_full_match" && star.exclude_list.empty()) { + if (!inverse && function.FunctionName() == "regexp_full_match" && star.ExcludeList().empty()) { // * SIMILAR TO '[regex]' is equivalent to COLUMNS('[regex]') so we can just move the expression directly - child_expr = std::move(right); + child_expr = std::move(right.GetExpressionMutable()); } else { // for other expressions -> generate a columns expression // "* LIKE '%literal%' // -> COLUMNS(list_filter(*, x -> x LIKE '%literal%')) vector named_parameters; named_parameters.push_back("__lambda_col"); - function.children[0] = make_uniq("__lambda_col"); - function.children[1] = std::move(right); + function.GetArgumentsMutable()[0] = FunctionArgument(make_uniq("__lambda_col")); + function.GetArgumentsMutable()[1] = std::move(right); unique_ptr lambda_body = std::move(expr); if (inverse) { @@ -223,14 +223,14 @@ void TryTransformStarLike(unique_ptr &root) { auto lambda = make_uniq(std::move(named_parameters), 
std::move(lambda_body)); vector> filter_children; - filter_children.push_back(std::move(star_expr)); + filter_children.push_back(std::move(star_expr.GetExpressionMutable())); filter_children.push_back(std::move(lambda)); child_expr = make_uniq("list_filter", std::move(filter_children)); } - auto columns_expr = make_uniq(star.relation_name); - columns_expr->columns = true; - columns_expr->expr = std::move(child_expr); + auto columns_expr = make_uniq(star.RelationName()); + columns_expr->IsColumnsMutable() = true; + columns_expr->ExpressionMutable() = std::move(child_expr); columns_expr->SetAlias(std::move(original_alias)); root = std::move(columns_expr); } @@ -242,7 +242,7 @@ optional_ptr Binder::GetResolvedColumnExpression(ParsedExpress break; } if (expr->GetExpressionType() == ExpressionType::OPERATOR_COALESCE) { - expr = expr->Cast().children[0].get(); + expr = expr->Cast().GetChildrenMutable()[0].get(); } else { // unknown expression return nullptr; @@ -272,17 +272,17 @@ void Binder::ExpandStarExpression(unique_ptr expr, bind_context.GenerateAllColumnExpressions(*star, star_list); unique_ptr regex; - if (star->expr) { + if (star->Expression()) { // COLUMNS with an expression // two options: // VARCHAR parameter <- this is a regular expression // LIST of VARCHAR parameters <- this is a set of columns TableFunctionBinder binder(*this, context); - auto child = star->expr->Copy(); + auto child = star->Expression()->Copy(); auto result = binder.Bind(child); if (!result->IsFoldable()) { // cannot resolve parameters here - if (star->expr->HasParameter()) { + if (star->Expression()->HasParameter()) { throw ParameterNotResolvedException(); } else { throw BinderException("Unsupported expression in COLUMNS"); @@ -307,7 +307,7 @@ void Binder::ExpandStarExpression(unique_ptr expr, continue; } auto &colref = child_expr->Cast(); - if (!RE2::PartialMatch(colref.GetColumnName(), *regex)) { + if (!RE2::PartialMatch(colref.GetColumnName().GetIdentifierName(), *regex)) { continue; 
} new_list.push_back(std::move(expanded_expr)); @@ -320,7 +320,7 @@ void Binder::ExpandStarExpression(unique_ptr expr, continue; } auto &colref = child_expr->Cast(); - candidates.push_back(colref.GetColumnName()); + candidates.emplace_back(colref.GetColumnName()); } string candidate_str; if (!candidates.empty()) { @@ -396,7 +396,9 @@ void Binder::ExpandStarExpression(unique_ptr expr, if (new_expr->GetAlias().empty()) { new_expr->SetAlias(colref.GetColumnName()); } else { - new_expr->SetAlias(ReplaceColumnsAlias(new_expr->GetAlias(), colref.GetColumnName(), regex.get())); + new_expr->SetAlias( + Identifier(ReplaceColumnsAlias(new_expr->GetAlias().GetIdentifierName(), + colref.GetColumnName().GetIdentifierName(), regex.get()))); } } } diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp index fc4664909..6158acb40 100644 --- a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -60,13 +60,13 @@ static void ExtractSubqueryChildren(unique_ptr &child, vectorCast(); - if (function.function.GetName() != "row") { + if (function.Function().GetName() != "row") { // not "ROW" return; } // we found (a, b, ...) - we can extract all children of this function // note that we don't always want to do this - if (types.size() == 1 && TypeIsUnnamedStruct(types[0]) && function.children.size() != types.size()) { + if (types.size() == 1 && TypeIsUnnamedStruct(types[0]) && function.GetChildrenMutable().size() != types.size()) { // old case: we have an unnamed struct INSIDE the subquery as well // i.e. (a, b) IN (SELECT (a, b) ...) 
// unnesting the struct is guaranteed to throw an error - match the structs against each-other instead @@ -83,17 +83,25 @@ static void ExtractSubqueryChildren(unique_ptr &child, vectornode->type != QueryNodeType::BOUND_SUBQUERY_NODE) { + if (inside_try) { + throw BinderException("TRY can not be used in combination with a scalar subquery"); + } + if (expr.Subquery()->node->type != QueryNodeType::BOUND_SUBQUERY_NODE) { // first bind the actual subquery in a new binder - auto subquery_binder = Binder::CreateBinder(context, &binder); - subquery_binder->can_contain_nulls = true; - auto bound_node = subquery_binder->BindNode(*expr.subquery->node); + auto subquery_binder = Binder::CreateBinder(context, binder); + subquery_binder->SetCanContainNulls(true); + subquery_binder->SetInsideSubquery(); + + subquery_binder->BeginSubqueryBind(binder, *this); + auto bound_node = subquery_binder->BindNode(*expr.Subquery()->node); + subquery_binder->FinishSubqueryBind(); + // check the correlated columns of the subquery for correlated columns with depth > 1 for (idx_t i = 0; i < subquery_binder->correlated_columns.size(); i++) { CorrelatedColumnInfo corr = subquery_binder->correlated_columns[i]; @@ -104,29 +112,30 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept binder.AddCorrelatedColumn(corr); } } - auto prior_subquery = std::move(expr.subquery); - expr.subquery = make_uniq(); - expr.subquery->node = + auto prior_subquery = std::move(expr.SubqueryMutable()); + expr.SubqueryMutable() = make_uniq(); + expr.SubqueryMutable()->node = make_uniq(std::move(subquery_binder), std::move(bound_node), std::move(prior_subquery)); } // now bind the child node of the subquery - if (expr.child) { + if (expr.GetChild()) { // first bind the children of the subquery, if any - auto error = Bind(expr.child, depth); + auto error = Bind(expr.GetChildMutable(), depth); if (error.HasError()) { return BindResult(std::move(error)); } } - auto &bound_subquery = 
expr.subquery->node->Cast(); + auto &bound_subquery = expr.Subquery()->node->Cast(); vector> child_expressions; - if (expr.subquery_type != SubqueryType::EXISTS) { + if (expr.GetSubqueryType() != SubqueryType::EXISTS) { idx_t expected_columns = 1; bool has_unexpanded_struct = false; - if (expr.child) { - auto &child = BoundExpression::GetExpression(*expr.child); + if (expr.GetChild()) { + auto &child = BoundExpression::GetExpression(*expr.GetChildMutable()); // Check if child is an unexpanded struct before extraction has_unexpanded_struct = TypeIsUnnamedStruct(child->GetReturnType()); - ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node.types, expr.comparison_type); + ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node.types, + expr.GetComparisonType()); if (child_expressions.empty()) { child_expressions.push_back(std::move(child)); } @@ -138,7 +147,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept TypeIsUnnamedStruct(child_expressions[0]->GetReturnType())) { // The child is a struct with N elements, and the subquery returns N columns // This is allowed - the subquery columns will be matched against the struct during execution - expected_columns = bound_subquery.bound_node.types.size(); + expected_columns = StructType::GetChildCount(child_expressions[0]->GetReturnType()); } if (bound_subquery.bound_node.types.size() != expected_columns) { throw BinderException(expr, "Subquery returns %zu columns - expected %d", @@ -146,17 +155,17 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept } } // both binding the child and binding the subquery was successful - D_ASSERT(expr.subquery->node->type == QueryNodeType::BOUND_SUBQUERY_NODE); + D_ASSERT(expr.Subquery()->node->type == QueryNodeType::BOUND_SUBQUERY_NODE); auto subquery_binder = std::move(bound_subquery.subquery_binder); auto bound_node = std::move(bound_subquery.bound_node); LogicalType return_type = - 
expr.subquery_type == SubqueryType::SCALAR ? bound_node.types[0] : LogicalType(LogicalTypeId::BOOLEAN); + expr.GetSubqueryType() == SubqueryType::SCALAR ? bound_node.types[0] : LogicalType(LogicalTypeId::BOOLEAN); if (return_type.id() == LogicalTypeId::UNKNOWN) { return_type = LogicalType::SQLNULL; } auto result = make_uniq(return_type); - if (expr.subquery_type == SubqueryType::ANY) { + if (expr.GetSubqueryType() == SubqueryType::ANY) { // ANY comparison // cast child and subquery child to equivalent types // Special case: if we have a single struct child and multiple subquery types, @@ -164,11 +173,11 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept if (child_expressions.size() == 1 && bound_node.types.size() > 1 && TypeIsUnnamedStruct(child_expressions[0]->GetReturnType())) { // Keep the struct as-is for proper lexicographic row comparison - result->children.push_back(std::move(child_expressions[0])); + result->GetChildrenMutable().push_back(std::move(child_expressions[0])); // Store all the subquery types - they will be used to construct the RHS struct during planning for (auto &subquery_type : bound_node.types) { - result->child_types.push_back(subquery_type); - result->child_targets.push_back(subquery_type); + result->ChildTypesMutable().push_back(subquery_type); + result->ChildTargetsMutable().push_back(subquery_type); } } else { // Standard case: either no struct or struct was extracted into separate expressions @@ -184,16 +193,16 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept child_type.ToString(), subquery_type); } child = BoundCastExpression::AddCastToType(context, std::move(child), compare_type); - result->child_types.push_back(subquery_type); - result->child_targets.push_back(compare_type); - result->children.push_back(std::move(child)); + result->ChildTypesMutable().push_back(subquery_type); + result->ChildTargetsMutable().push_back(compare_type); + 
result->GetChildrenMutable().push_back(std::move(child)); } } } - result->binder = std::move(subquery_binder); - result->subquery = std::move(bound_node); - result->subquery_type = expr.subquery_type; - result->comparison_type = expr.comparison_type; + result->GetBinderMutable() = std::move(subquery_binder); + result->SubqueryMutable() = std::move(bound_node); + result->SubqueryTypeMutable() = expr.GetSubqueryType(); + result->ComparisonTypeMutable() = expr.GetComparisonType(); return BindResult(std::move(result)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_type_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_type_expression.cpp index aa218b65c..3b143a7e1 100644 --- a/src/duckdb/src/planner/binder/expression/bind_type_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_type_expression.cpp @@ -30,7 +30,7 @@ BindResult ExpressionBinder::BindExpression(TypeExpression &type_expr, idx_t dep // Required for WAL lookup to work if (type_catalog.empty() && !DatabaseManager::Get(context).HasDefaultDatabase()) { // Look in the system catalog if no catalog was specified - entry = binder.entry_retriever.GetEntry(SYSTEM_CATALOG, type_schema, type_lookup); + entry = binder.entry_retriever.GetEntry(Identifier::SystemCatalog(), type_schema, type_lookup); } else { // Try to search from most specific to least specific // The search path should already have been set to the correct catalog/schema, @@ -43,16 +43,16 @@ BindResult ExpressionBinder::BindExpression(TypeExpression &type_expr, idx_t dep entry = binder.entry_retriever.GetEntry(type_catalog, type_schema, type_lookup, OnEntryNotFound::THROW_EXCEPTION); } - entry = binder.entry_retriever.GetEntry(type_catalog, INVALID_SCHEMA, type_lookup, + entry = binder.entry_retriever.GetEntry(type_catalog, Identifier::InvalidSchema(), type_lookup, OnEntryNotFound::RETURN_NULL); } if (!IsValidTypeLookup(entry)) { - entry = binder.entry_retriever.GetEntry(INVALID_CATALOG, INVALID_SCHEMA, 
type_lookup, - OnEntryNotFound::RETURN_NULL); + entry = binder.entry_retriever.GetEntry(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), + type_lookup, OnEntryNotFound::RETURN_NULL); } if (!IsValidTypeLookup(entry)) { - entry = binder.entry_retriever.GetEntry(SYSTEM_CATALOG, DEFAULT_SCHEMA, type_lookup, - OnEntryNotFound::THROW_EXCEPTION); + entry = binder.entry_retriever.GetEntry(Identifier::SystemCatalog(), Identifier::DefaultSchema(), + type_lookup, OnEntryNotFound::THROW_EXCEPTION); } } @@ -92,13 +92,13 @@ BindResult ExpressionBinder::BindExpression(TypeExpression &type_expr, idx_t dep // Shortcut for constant expressions if (bound_expr->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { auto &const_expr = bound_expr->Cast(); - bound_parameters.emplace_back(param->GetAlias(), const_expr.value); + bound_parameters.emplace_back(param->GetAlias().GetIdentifierName(), const_expr.GetValue()); continue; } // Otherwise we need to evaluate the expression auto bound_param = ExpressionExecutor::EvaluateScalar(context, *bound_expr); - bound_parameters.emplace_back(param->GetAlias(), bound_param); + bound_parameters.emplace_back(param->GetAlias().GetIdentifierName(), bound_param); }; // Call the bind function diff --git a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp index 81c5df1a2..258a825d5 100644 --- a/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_unnest_expression.cpp @@ -7,7 +7,6 @@ #include "duckdb/planner/expression/bound_expanded_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include 
"duckdb/planner/expression/bound_unnest_expression.hpp" @@ -31,9 +30,9 @@ static unique_ptr CreateBoundStructExtract(ClientContext &context, u if (!alias.empty() && alias[0] == '.') { alias = alias.substr(1); } - result->SetAlias(alias); + result->SetAlias(Identifier(alias)); } else { - result->SetAlias(key_path[0]); + result->SetAlias(Identifier(key_path[0])); } return std::move(result); } @@ -45,7 +44,7 @@ static unique_ptr CreateBoundStructExtractIndex(ClientContext &conte arguments.push_back(make_uniq(Value::BIGINT(int64_t(key)))); auto result = GetIndexExtractFunction().Bind(context, std::move(arguments)); - result->SetAlias("element" + to_string(key)); + result->SetAlias(Identifier("element" + to_string(key))); return std::move(result); } @@ -69,35 +68,37 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } ErrorData error; - if (function.children.empty()) { + if (function.GetArguments().empty()) { return BindResult(BinderException(function, "UNNEST() requires at lease one argument")); } - if (inside_window) { + if (inside_window || inside_aggregate || inside_try) { return BindResult(BinderException(function, UnsupportedUnnestMessage())); } - if (function.distinct || function.filter || !function.order_bys->orders.empty()) { + if (function.Distinct() || function.Filter() || !function.OrderBy()->orders.empty()) { throw InvalidInputException("\"DISTINCT\", \"FILTER\", and \"ORDER BY\" are not " "applicable to \"UNNEST\""); } idx_t max_depth = 1; bool keep_parent_names = false; - if (function.children.size() != 1) { + auto &args = function.GetArgumentsMutable(); + + if (args.size() != 1) { bool supported_argument = false; - for (idx_t i = 1; i < function.children.size(); i++) { - if (function.children[i]->HasParameter()) { + for (idx_t i = 1; i < args.size(); i++) { + if (args[i].GetExpression().HasParameter()) { throw ParameterNotAllowedException("Parameter not allowed in unnest parameter"); } - if 
(!function.children[i]->IsScalar()) { + if (!args[i].GetExpression().IsScalar()) { break; } - auto alias = StringUtil::Lower(function.children[i]->GetAlias()); - BindChild(function.children[i], depth, error); + auto alias = args[i].GetExpression().GetAlias(); + BindChild(args[i].GetExpressionMutable(), depth, error); if (error.HasError()) { return BindResult(std::move(error)); } - auto &const_child = BoundExpression::GetExpression(*function.children[i]); + auto &const_child = BoundExpression::GetExpression(*args[i].GetExpressionMutable()); auto value = ExpressionExecutor::EvaluateScalar(context, *const_child, true); if (alias == "recursive") { auto recursive = value.GetValue(); @@ -112,7 +113,7 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } else if (alias == "keep_parent_names") { keep_parent_names = value.GetValue(); } else if (!alias.empty()) { - throw BinderException("Unsupported parameter \"%s\" for unnest", alias); + throw BinderException("Unsupported parameter \"%s\" for unnest", alias.GetIdentifierName()); } else { break; } @@ -125,18 +126,18 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } } unnest_level++; - BindChild(function.children[0], depth, error); + BindChild(args[0].GetExpressionMutable(), depth, error); if (error.HasError()) { // failed to bind // try to bind correlated columns manually - auto result = BindCorrelatedColumns(function.children[0], error); + auto result = BindCorrelatedColumns(args[0].GetExpressionMutable(), error); if (result.HasError()) { return BindResult(result.error); } - auto &bound_expr = BoundExpression::GetExpression(*function.children[0]); + auto &bound_expr = BoundExpression::GetExpression(*args[0].GetExpressionMutable()); ExtractCorrelatedExpressions(binder, *bound_expr); } - auto &child = BoundExpression::GetExpression(*function.children[0]); + auto &child = BoundExpression::GetExpression(*args[0].GetExpressionMutable()); child = 
BoundCastExpression::AddArrayCastToList(context, std::move(child)); auto &child_type = child->GetReturnType(); unnest_level--; @@ -204,8 +205,8 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b } auto result = make_uniq(return_type); - result->child = std::move(unnest_expr); - auto alias = function.GetAlias().empty() ? result->ToString() : function.GetAlias(); + result->ChildMutable() = std::move(unnest_expr); + Identifier alias = function.GetAlias().empty() ? Identifier(result->ToString()) : function.GetAlias(); auto current_level = unnest_level + list_unnests - current_depth - 1; auto entry = node.unnests.find(current_level); @@ -247,9 +248,9 @@ BindResult SelectBinder::BindUnnest(FunctionExpression &function, idx_t depth, b vector current_key_path; // During recursive expansion, not all expressions are BoundFunctionExpression if (keep_parent_names && expr->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { - current_key_path.push_back(expr->GetAlias()); + current_key_path.emplace_back(expr->GetAlias()); } - current_key_path.push_back(entry.first); + current_key_path.emplace_back(entry.first); new_expressions.push_back( CreateBoundStructExtract(context, expr->Copy(), current_key_path, keep_parent_names)); } diff --git a/src/duckdb/src/planner/binder/expression/bind_unpacked_star_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_unpacked_star_expression.cpp index eeb80d6ad..6096da25d 100644 --- a/src/duckdb/src/planner/binder/expression/bind_unpacked_star_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_unpacked_star_expression.cpp @@ -17,8 +17,8 @@ static void AddChild(unique_ptr &child, expression_list_t &new } auto &unpack = child->Cast(); D_ASSERT(unpack.GetExpressionType() == ExpressionType::OPERATOR_UNPACK); - D_ASSERT(unpack.children.size() == 1); - auto &unpack_child = unpack.children[0]; + D_ASSERT(unpack.GetChildren().size() == 1); + auto &unpack_child = 
unpack.GetChildrenMutable()[0]; // Replace the child with the replacement expression(s) for (auto &replacement : replacements) { @@ -31,8 +31,8 @@ static void AddChild(unique_ptr &child, expression_list_t &new if (new_expr->GetAlias().empty()) { new_expr->SetAlias(colref.GetColumnName()); } else { - new_expr->SetAlias( - Binder::ReplaceColumnsAlias(new_expr->GetAlias(), colref.GetColumnName(), regex)); + new_expr->SetAlias(Identifier(Binder::ReplaceColumnsAlias( + new_expr->GetAlias().GetIdentifierName(), colref.GetColumnName().GetIdentifierName(), regex))); } } } @@ -46,23 +46,27 @@ static void ReplaceInFunction(unique_ptr &expr, expression_lis // Replace children expression_list_t new_children; - for (auto &child : function_expr.children) { - AddChild(child, new_children, star_list, star, regex); + for (auto &child : function_expr.GetArgumentsMutable()) { + AddChild(child.GetExpressionMutable(), new_children, star_list, star, regex); + } + + function_expr.GetArgumentsMutable().clear(); + for (auto &child : new_children) { + function_expr.GetArgumentsMutable().emplace_back(std::move(child)); } - function_expr.children = std::move(new_children); // Replace ORDER_BY - if (function_expr.order_bys) { + if (function_expr.OrderBy()) { expression_list_t new_orders; - for (auto &order : function_expr.order_bys->orders) { + for (auto &order : function_expr.OrderByMutable()->orders) { AddChild(order.expression, new_orders, star_list, star, regex); } - if (new_orders.size() != function_expr.order_bys->orders.size()) { + if (new_orders.size() != function_expr.OrderBy()->orders.size()) { throw NotImplementedException("*COLUMNS(...) 
is not supported in the order expression"); } for (idx_t i = 0; i < new_orders.size(); i++) { auto &new_order = new_orders[i]; - function_expr.order_bys->orders[i].expression = std::move(new_order); + function_expr.OrderByMutable()->orders[i].expression = std::move(new_order); } } } @@ -90,10 +94,10 @@ static void ReplaceInOperator(unique_ptr &expr, expression_lis // Replace children expression_list_t new_children; - for (auto &child : operator_expr.children) { + for (auto &child : operator_expr.GetChildrenMutable()) { AddChild(child, new_children, star_list, star, regex); } - operator_expr.children = std::move(new_children); + operator_expr.GetChildrenMutable() = std::move(new_children); } void Binder::ReplaceUnpackedStarExpression(unique_ptr &expr, expression_list_t &star_list, diff --git a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp index cfa368477..a4ebc6723 100644 --- a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp @@ -80,7 +80,8 @@ static LogicalType BindRangeExpression(ClientContext &context, const string &nam ErrorData error; FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(DEFAULT_SCHEMA, name, std::move(children), error, true); + auto function = function_binder.BindScalarFunction(Identifier::DefaultSchema(), Identifier(name), + std::move(children), error, true); if (!function) { error.Throw(); } @@ -95,53 +96,60 @@ static LogicalType BindRangeExpression(ClientContext &context, const string &nam BindResult BaseSelectBinder::BindWindowExpression(WindowExpression &window, idx_t depth) { QueryErrorContext error_context(window.GetQueryLocation()); - // Check for macros pretending to be aggregates - EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, window.function_name, error_context); - auto entry = 
GetCatalogEntry(window.catalog, window.schema, function_lookup, OnEntryNotFound::RETURN_NULL); + // Check for macros pretending to be aggregates + EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, window.FunctionName(), error_context); + auto entry = GetCatalogEntry(window.Catalog(), window.Schema(), function_lookup, OnEntryNotFound::RETURN_NULL); if (entry && entry->type == CatalogType::MACRO_ENTRY) { auto macro_expr = window.Copy(); - auto macro = make_uniq(window.catalog, window.schema, window.function_name, - std::move(window.children), std::move(window.filter_expr), nullptr, - window.distinct); + auto macro = make_uniq(window.Catalog(), window.Schema(), window.FunctionName(), + std::move(window.GetArgumentsMutable()), + std::move(window.FilterMutable()), nullptr, window.Distinct()); return BindMacro(*macro, entry->Cast(), depth, macro_expr); } auto name = window.GetAlias(); + if (inside_try) { + throw BinderException("window functions are not allowed in try"); + } + if (inside_aggregate) { + throw BinderException(window, "aggregate function calls cannot contain window function calls"); + } if (inside_window) { - throw BinderException(error_context, "window function calls cannot be nested"); + throw BinderException(window, "window function calls cannot be nested"); } if (depth > 0) { - throw BinderException(error_context, "correlated columns in window functions not supported"); + throw BinderException(window, "correlated columns in window functions not supported"); } // Look up the aggregate function in the catalog if (!entry || (entry->type != CatalogType::AGGREGATE_FUNCTION_ENTRY && entry->type != CatalogType::WINDOW_FUNCTION_ENTRY)) { // Not an aggregate or window function: Look it up to generate error - Catalog::GetEntry(context, window.catalog, window.schema, window.function_name, - error_context); + Catalog::GetEntry(context, window.Catalog(), window.Schema(), + window.FunctionName(), error_context); } // If we have range expressions, 
then only one order by clause is allowed. - const auto is_range = - (window.start == WindowBoundary::EXPR_PRECEDING_RANGE || window.start == WindowBoundary::EXPR_FOLLOWING_RANGE || - window.end == WindowBoundary::EXPR_PRECEDING_RANGE || window.end == WindowBoundary::EXPR_FOLLOWING_RANGE); - if (is_range && window.orders.size() != 1) { + const auto is_range = (window.WindowStart() == WindowBoundary::EXPR_PRECEDING_RANGE || + window.WindowStart() == WindowBoundary::EXPR_FOLLOWING_RANGE || + window.WindowEnd() == WindowBoundary::EXPR_PRECEDING_RANGE || + window.WindowEnd() == WindowBoundary::EXPR_FOLLOWING_RANGE); + if (is_range && window.OrderBy().size() != 1) { throw BinderException(error_context, "RANGE frames must have only one ORDER BY expression"); } // bind inside the children of the window function // we set the inside_window flag to true to prevent binding nested window functions inside_window = true; ErrorData error; - for (auto &child : window.children) { - BindChild(child, depth, error); + for (auto &child : window.GetArgumentsMutable()) { + BindChild(child.GetExpressionMutable(), depth, error); } - for (auto &child : window.partitions) { + for (auto &child : window.PartitionsMutable()) { BindChild(child, depth, error); } - for (auto &order : window.orders) { + for (auto &order : window.OrderByMutable()) { BindChild(order.expression, depth, error); // If the frame is a RANGE frame and the type is a time, @@ -159,11 +167,11 @@ BindResult BaseSelectBinder::BindWindowExpression(WindowExpression &window, idx_ BindRangeExpression(context, "+", order.expression, epoch); } } - BindChild(window.filter_expr, depth, error); - BindChild(window.start_expr, depth, error); - BindChild(window.end_expr, depth, error); + BindChild(window.FilterMutable(), depth, error); + BindChild(window.StartExprMutable(), depth, error); + BindChild(window.EndExprMutable(), depth, error); - for (auto &order : window.arg_orders) { + for (auto &order : window.ArgOrdersMutable()) { 
BindChild(order.expression, depth, error); } @@ -174,120 +182,121 @@ BindResult BaseSelectBinder::BindWindowExpression(WindowExpression &window, idx_ } // Restore any collation expressions - for (auto &order : window.arg_orders) { + for (auto &order : window.ArgOrdersMutable()) { auto &order_expr = order.expression; auto &bound_order = BoundExpression::GetExpression(*order_expr); ExpressionBinder::PushCollation(context, bound_order, bound_order->GetReturnType()); } - for (auto &part_expr : window.partitions) { + for (auto &part_expr : window.PartitionsMutable()) { auto &bound_partition = BoundExpression::GetExpression(*part_expr); ExpressionBinder::PushCollation(context, bound_partition, bound_partition->GetReturnType()); } - for (auto &order : window.orders) { + for (auto &order : window.OrderByMutable()) { auto &order_expr = order.expression; auto &bound_order = BoundExpression::GetExpression(*order_expr); ExpressionBinder::PushCollation(context, bound_order, bound_order->GetReturnType()); } - // successfully bound all children: create bound window function - vector types; - vector> children; - for (auto &child : window.children) { - D_ASSERT(child.get()); - D_ASSERT(child->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); - auto &bound = BoundExpression::GetExpression(*child); - types.push_back(bound->GetReturnType()); - children.push_back(std::move(bound)); + + vector>> arguments; + arguments.reserve(window.GetArguments().size()); + for (auto &arg : window.GetArgumentsMutable()) { + auto &bound_arg = BoundExpression::GetExpression(*arg.GetExpressionMutable()); + + // legacy function calls cannot have named arguments, so we ignore the names of the arguments during binding + // and pass them all positionally. We do alias them by their name though, so that alias-capturing functions + // (e.g. struct_pack) still work and so that re-serializing to the old format can match arguments by name. 
+ // Only override the alias when the argument actually carries a name, otherwise we would clobber the + // display alias the binding assigned (e.g. clearing a column reference's name to its raw binding). + if (!arg.GetName().empty()) { + bound_arg->SetAlias(arg.GetName()); + } + + if (window.IsLegacyFunctionCall()) { + arguments.emplace_back(Identifier(), std::move(bound_arg)); + } else { + arguments.emplace_back(arg.GetName(), std::move(bound_arg)); + } } + // Determine the function type. LogicalType sql_type; unique_ptr aggregate; unique_ptr window_func; unique_ptr bind_info; + vector> children; + if (entry->type == CatalogType::AGGREGATE_FUNCTION_ENTRY) { auto &func = entry->Cast(); - if (window.has_ignore_nulls) { + if (window.HasIgnoreNulls()) { throw BinderException(error_context, "RESPECT/IGNORE NULLS is not supported for windowed aggregates"); } // bind the aggregate - ErrorData error_aggr; - FunctionBinder function_binder(context); - auto best_function = function_binder.BindFunction(func.name, func.functions, types, error_aggr); - if (!best_function.IsValid()) { - error_aggr.AddQueryLocation(window); - error_aggr.Throw(); + FunctionBinder function_binder(binder); + auto window_bound_aggregate = function_binder.BindAggregateFunction(func, std::move(arguments), error, nullptr); + // No function found, throw an error + if (!window_bound_aggregate) { + error.AddQueryLocation(window); + error.Throw(); } - // found a matching function! 
bind it as an aggregate - const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); - - auto window_bound_aggregate = function_binder.BindAggregateFunction(bound_function, std::move(children)); // create the aggregate - aggregate = make_uniq(window_bound_aggregate->function); - bind_info = std::move(window_bound_aggregate->bind_info); - children = std::move(window_bound_aggregate->children); + aggregate = make_uniq(window_bound_aggregate->Function()); + bind_info = std::move(window_bound_aggregate->BindInfoMutable()); + children = std::move(window_bound_aggregate->GetChildrenMutable()); sql_type = window_bound_aggregate->GetReturnType(); + } else { auto &func = entry->Cast(); - // bind the aggregate - ErrorData error_aggr; - FunctionBinder function_binder(context); - auto best_function = function_binder.BindFunction(func.name, func.functions, types, error_aggr); - if (!best_function.IsValid()) { - error_aggr.AddQueryLocation(window); - error_aggr.Throw(); - } + FunctionBinder function_binder(binder); + auto window_bound_function = function_binder.BindWindowFunction( + func, std::move(arguments), error, window.OrderByMutable(), window.ArgOrdersMutable()); - // found a matching function! 
bind it as a window - const auto &bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); + if (!window_bound_function) { + error.AddQueryLocation(window); + error.Throw(); + } - if (window.distinct && !bound_function.CanDistinct()) { + // TODO: Move this into the binder + auto &bound_function = *window_bound_function->WindowFunction(); + if (window.Distinct() && !bound_function.CanDistinct()) { throw BinderException(error_context, "DISTINCT is not implemented for the window function \"%s\"", - window.function_name); + window.FunctionName()); } - if (window.filter_expr && !bound_function.CanFilter()) { + if (window.Filter() && !bound_function.CanFilter()) { throw BinderException(error_context, "FILTER is not implemented for the window function \"%s\"", - window.function_name); + window.FunctionName()); } - if (!window.arg_orders.empty() && !bound_function.CanOrderBy()) { + if (!window.ArgOrders().empty() && !bound_function.CanOrderBy()) { throw BinderException(error_context, "ORDER BY is not supported for the window function \"%s\"", - window.function_name); + window.FunctionName()); } - if (window.exclude_clause != WindowExcludeMode::NO_OTHER && !window.arg_orders.empty() && + if (window.WindowExclude() != WindowExcludeMode::NO_OTHER && !window.ArgOrders().empty() && !bound_function.CanExclude()) { throw BinderException(error_context, "EXCLUDE is not supported for the window function \"%s\"", - bound_function.name); + window.FunctionName()); } - if (window.has_ignore_nulls && !bound_function.CanIgnoreNulls()) { + if (window.HasIgnoreNulls() && !bound_function.CanIgnoreNulls()) { throw BinderException(error_context, "RESPECT/IGNORE NULLS is not supported for the window function \"%s\"", - bound_function.name); - } - - // Check for unresolved arguments - for (const auto &type : types) { - if (type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } + window.FunctionName()); } - auto window_bound_function = - 
function_binder.BindWindowFunction(bound_function, std::move(children), window.orders, window.arg_orders); - - window_func = std::move(window_bound_function->window); - bind_info = std::move(window_bound_function->bind_info); - children = std::move(window_bound_function->children); + window_func = std::move(window_bound_function->WindowFunctionMutable()); + bind_info = std::move(window_bound_function->BindInfoMutable()); + children = std::move(window_bound_function->GetChildrenMutable()); sql_type = window_bound_function->GetReturnType(); } + auto result = make_uniq(sql_type, std::move(aggregate), std::move(window_func), std::move(bind_info)); - result->children = std::move(children); - for (auto &child : window.partitions) { - result->partitions.push_back(GetExpression(child)); + result->GetChildrenMutable() = std::move(children); + for (auto &child : window.PartitionsMutable()) { + result->PartitionsMutable().push_back(GetExpression(child)); } - result->ignore_nulls = window.ignore_nulls; - result->distinct = window.distinct; + result->IgnoreNullsMutable() = window.IgnoreNulls(); + result->DistinctMutable() = window.Distinct(); // Convert RANGE boundary expressions to ORDER +/- expressions. // Note that PRECEDING and FOLLOWING refer to the sequential order in the frame, @@ -296,46 +305,50 @@ BindResult BaseSelectBinder::BindWindowExpression(WindowExpression &window, idx_ auto &config = DBConfig::GetConfig(context); auto range_sense = OrderType::INVALID; LogicalType start_type = LogicalType::BIGINT; - if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(context, window.orders[0].type); + if (window.WindowStart() == WindowBoundary::EXPR_PRECEDING_RANGE) { + D_ASSERT(window.OrderBy().size() == 1); + range_sense = config.ResolveOrder(context, window.OrderByMutable()[0].type); const auto range_name = (range_sense == OrderType::ASCENDING) ? 
"-" : "+"; - start_type = BindRangeExpression(context, range_name, window.start_expr, window.orders[0].expression); + start_type = + BindRangeExpression(context, range_name, window.StartExprMutable(), window.OrderByMutable()[0].expression); - } else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(context, window.orders[0].type); + } else if (window.WindowStart() == WindowBoundary::EXPR_FOLLOWING_RANGE) { + D_ASSERT(window.OrderBy().size() == 1); + range_sense = config.ResolveOrder(context, window.OrderByMutable()[0].type); const auto range_name = (range_sense == OrderType::ASCENDING) ? "+" : "-"; - start_type = BindRangeExpression(context, range_name, window.start_expr, window.orders[0].expression); + start_type = + BindRangeExpression(context, range_name, window.StartExprMutable(), window.OrderByMutable()[0].expression); } LogicalType end_type = LogicalType::BIGINT; - if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(context, window.orders[0].type); + if (window.WindowEnd() == WindowBoundary::EXPR_PRECEDING_RANGE) { + D_ASSERT(window.OrderBy().size() == 1); + range_sense = config.ResolveOrder(context, window.OrderByMutable()[0].type); const auto range_name = (range_sense == OrderType::ASCENDING) ? 
"-" : "+"; - end_type = BindRangeExpression(context, range_name, window.end_expr, window.orders[0].expression); + end_type = + BindRangeExpression(context, range_name, window.EndExprMutable(), window.OrderByMutable()[0].expression); - } else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) { - D_ASSERT(window.orders.size() == 1); - range_sense = config.ResolveOrder(context, window.orders[0].type); + } else if (window.WindowEnd() == WindowBoundary::EXPR_FOLLOWING_RANGE) { + D_ASSERT(window.OrderBy().size() == 1); + range_sense = config.ResolveOrder(context, window.OrderByMutable()[0].type); const auto range_name = (range_sense == OrderType::ASCENDING) ? "+" : "-"; - end_type = BindRangeExpression(context, range_name, window.end_expr, window.orders[0].expression); + end_type = + BindRangeExpression(context, range_name, window.EndExprMutable(), window.OrderByMutable()[0].expression); } // Cast ORDER and boundary expressions to the same type if (range_sense != OrderType::INVALID) { - D_ASSERT(window.orders.size() == 1); + D_ASSERT(window.OrderBy().size() == 1); - auto &order_expr = window.orders[0].expression; + auto &order_expr = window.OrderByMutable()[0].expression; D_ASSERT(order_expr.get()); D_ASSERT(order_expr->GetExpressionClass() == ExpressionClass::BOUND_EXPRESSION); auto &bound_order = BoundExpression::GetExpression(*order_expr); auto order_type = bound_order->GetReturnType(); - if (window.start_expr) { + if (window.StartExpr()) { order_type = LogicalType::MaxLogicalType(context, order_type, start_type); } - if (window.end_expr) { + if (window.EndExpr()) { order_type = LogicalType::MaxLogicalType(context, order_type, end_type); } @@ -344,33 +357,33 @@ BindResult BaseSelectBinder::BindWindowExpression(WindowExpression &window, idx_ start_type = end_type = order_type; } - for (auto &order : window.orders) { + for (auto &order : window.OrderByMutable()) { auto type = config.ResolveOrder(context, order.type); auto null_order = 
config.ResolveNullOrder(context, type, order.null_order); auto expression = GetExpression(order.expression); - result->orders.emplace_back(type, null_order, std::move(expression)); + result->OrderByMutable().emplace_back(type, null_order, std::move(expression)); } // Argument orders are just like arguments, not frames - for (auto &order : window.arg_orders) { + for (auto &order : window.ArgOrdersMutable()) { auto type = config.ResolveOrder(context, order.type); auto null_order = config.ResolveNullOrder(context, type, order.null_order); auto expression = GetExpression(order.expression); - result->arg_orders.emplace_back(type, null_order, std::move(expression)); + result->ArgOrdersMutable().emplace_back(type, null_order, std::move(expression)); } - result->filter_expr = CastWindowExpression(window.filter_expr, LogicalType::BOOLEAN); - result->start_expr = CastWindowExpression(window.start_expr, start_type); - result->end_expr = CastWindowExpression(window.end_expr, end_type); - result->start = window.start; - result->end = window.end; - result->exclude_clause = window.exclude_clause; + result->FilterMutable() = CastWindowExpression(window.FilterMutable(), LogicalType::BOOLEAN); + result->StartExprMutable() = CastWindowExpression(window.StartExprMutable(), start_type); + result->EndExprMutable() = CastWindowExpression(window.EndExprMutable(), end_type); + result->WindowStartMutable() = window.WindowStart(); + result->WindowEndMutable() = window.WindowEnd(); + result->WindowExcludeMutable() = window.WindowExclude(); // move the WINDOW expression into the set of bound windows auto &window_type = result->GetReturnType(); auto window_idx = ColumnBinding::PushExpression(node.windows, std::move(result)); - auto colref = make_uniq(std::move(name), window_type, - ColumnBinding(node.window_index, window_idx), depth); + auto colref = + make_uniq(name, window_type, ColumnBinding(node.window_index, window_idx), depth); return BindResult(std::move(colref)); } diff --git 
a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index e6cc001ca..2fd957532 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/query_node/update_query_node.hpp" #include "duckdb/parser/query_node/delete_query_node.hpp" #include "duckdb/parser/query_node/insert_query_node.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_materialized_cte.hpp" #include "duckdb/parser/query_node/list.hpp" @@ -14,20 +15,16 @@ namespace duckdb { struct BoundCTEData { - string ctename; + Identifier ctename; CTEMaterialize materialized; TableIndex setop_index; shared_ptr child_binder; shared_ptr cte_bind_state; }; -static bool IsDMLQueryNode(const CommonTableExpressionInfo &cte) { - if (!cte.query_node) { - return false; - } - auto t = cte.query_node->type; +static bool IsDMLQueryNode(QueryNodeType t) { return t == QueryNodeType::INSERT_QUERY_NODE || t == QueryNodeType::UPDATE_QUERY_NODE || - t == QueryNodeType::DELETE_QUERY_NODE; + t == QueryNodeType::DELETE_QUERY_NODE || t == QueryNodeType::MERGE_QUERY_NODE; } BoundStatement Binder::BindNode(QueryNode &node) { @@ -35,7 +32,7 @@ BoundStatement Binder::BindNode(QueryNode &node) { vector bound_ctes; idx_t dml_cte_count = 0; for (auto &cte : node.cte_map.map) { - if (IsDMLQueryNode(*cte.second)) { + if (IsDMLQueryNode(cte.second->query_node->type)) { if (parent && !cte.second->is_trigger_generated) { throw BinderException("WITH clause containing a data-modifying statement must be at the top level"); } @@ -68,6 +65,9 @@ BoundStatement Binder::BindNode(QueryNode &node) { case QueryNodeType::INSERT_QUERY_NODE: result = current_binder.get().BindNode(node.Cast()); break; + case QueryNodeType::MERGE_QUERY_NODE: + result = 
current_binder.get().BindNode(node.Cast()); + break; default: throw InternalException("Unsupported query node type"); } @@ -90,7 +90,7 @@ BoundStatement Binder::BindNode(QueryNode &node) { return result; } -CTEBindState::CTEBindState(Binder &parent_binder_p, QueryNode &cte_def_p, const vector &aliases_p) +CTEBindState::CTEBindState(Binder &parent_binder_p, QueryNode &cte_def_p, const vector &aliases_p) : parent_binder(parent_binder_p), cte_def(cte_def_p), aliases(aliases_p), active_binder_count(parent_binder.GetActiveBinders().size()) { } @@ -120,7 +120,7 @@ void CTEBindState::Bind(CTEBinding &binding) { // add this CTE to the query binder on the RHS with "CANNOT_BE_REFERENCED" to detect recursive references to // ourselves - query_binder->bind_context.AddCTEBinding(binding.GetIndex(), binding.GetBindingAlias(), vector(), + query_binder->bind_context.AddCTEBinding(binding.GetIndex(), binding.GetBindingAlias(), vector(), vector(), CTEType::CANNOT_BE_REFERENCED); // bind the actual CTE @@ -143,7 +143,7 @@ void CTEBindState::Bind(CTEBinding &binding) { QueryResult::DeduplicateColumns(names); } -BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement) { +BoundCTEData Binder::PrepareCTE(const Identifier &ctename, CommonTableExpressionInfo &statement) { BoundCTEData result; // first recursively visit the materialized CTE operations @@ -171,8 +171,7 @@ BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo BoundStatement Binder::FinishCTE(BoundCTEData &bound_cte, BoundStatement child) { if (!bound_cte.cte_bind_state->IsBound()) { auto node_type = bound_cte.cte_bind_state->cte_def.type; - bool is_dml = node_type == QueryNodeType::INSERT_QUERY_NODE || node_type == QueryNodeType::UPDATE_QUERY_NODE || - node_type == QueryNodeType::DELETE_QUERY_NODE; + bool is_dml = IsDMLQueryNode(node_type); if (is_dml) { // DML CTEs always execute even if not referenced - force bind now auto dummy_binding = diff --git 
a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 192064e4e..fcac0602e 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -97,7 +97,7 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { auto bound_expr = expression_binder.Bind(expr); auto &bound_ref = bound_expr->Cast(); - auto column_index = bound_ref.binding.column_index; + auto column_index = bound_ref.Binding().column_index; if (key_references.find(column_index) != key_references.end()) { continue; } @@ -107,27 +107,27 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { } else if (expr->GetExpressionType() == ExpressionType::FUNCTION) { auto &func_expr = expr->Cast(); - if (func_expr.filter) { - throw BinderException(func_expr.filter->GetQueryLocation(), + if (func_expr.Filter()) { + throw BinderException(func_expr.Filter()->GetQueryLocation(), "FILTER clause is not yet supported for aggregates in USING KEY"); } - if (!func_expr.order_bys->orders.empty()) { + if (!func_expr.OrderBy()->orders.empty()) { throw BinderException(func_expr.GetQueryLocation(), "ORDER BY clause is not yet supported for aggregates in USING KEY"); } - if (func_expr.distinct) { + if (func_expr.Distinct()) { throw BinderException(func_expr.GetQueryLocation(), "DISTINCT is not yet supported for aggregates in USING KEY"); } QueryErrorContext error_context(expr->GetQueryLocation()); - EntryLookupInfo function_lookup(CatalogType::AGGREGATE_FUNCTION_ENTRY, func_expr.function_name, + EntryLookupInfo function_lookup(CatalogType::AGGREGATE_FUNCTION_ENTRY, func_expr.FunctionName(), error_context); - auto entry = - GetCatalogEntry(func_expr.catalog, DEFAULT_SCHEMA, function_lookup, OnEntryNotFound::RETURN_NULL); + auto entry = GetCatalogEntry(func_expr.Catalog(), Identifier::DefaultSchema(), function_lookup, + 
OnEntryNotFound::RETURN_NULL); if (!entry || entry->type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { throw BinderException( @@ -140,9 +140,10 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { vector aggregation_input_types; vector> bound_children; + // Bind the children of the aggregate function - for (auto &child : func_expr.children) { - auto bound_child = expression_binder.Bind(child); + for (auto &child : func_expr.GetArgumentsMutable()) { + auto bound_child = expression_binder.Bind(child.GetExpressionMutable()); aggregation_input_types.push_back(bound_child->GetReturnType()); bound_children.push_back(std::move(bound_child)); } @@ -166,7 +167,7 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { expr->GetQueryLocation(), "In USING KEY, an aggregate must either have a column reference or an alias."); } - aggregate_idx = bound_children[0]->Cast().binding.column_index; + aggregate_idx = bound_children[0]->Cast().Binding().column_index; } // Find the best matching aggregate function diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 1580f2f33..2a17bcf5c 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -30,6 +30,7 @@ #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_sample.hpp" namespace duckdb { @@ -46,9 +47,10 @@ unique_ptr Binder::BindOrderExpression(OrderBinder &order_binder, un } BoundLimitNode Binder::BindLimitValue(OrderBinder &order_binder, unique_ptr limit_val, - bool is_percentage, bool is_offset) { + LimitValueType value_type, bool is_offset) { auto new_binder = Binder::CreateBinder(context, this); ExpressionBinder expr_binder(*new_binder, context); + bool 
is_percentage = value_type == LimitValueType::PERCENTAGE; auto target_type = is_percentage ? LogicalType::DOUBLE : LogicalType::BIGINT; expr_binder.target_type = target_type; auto original_limit = limit_val->Copy(); @@ -107,21 +109,10 @@ BoundLimitNode Binder::BindLimitValue(OrderBinder &order_binder, unique_ptr Binder::BindLimit(OrderBinder &order_binder, LimitModifier &limit_mod) { auto result = make_uniq(); if (limit_mod.limit) { - result->limit_val = BindLimitValue(order_binder, std::move(limit_mod.limit), false, false); + result->limit_val = BindLimitValue(order_binder, std::move(limit_mod.limit), limit_mod.limit_type, false); } if (limit_mod.offset) { - result->offset_val = BindLimitValue(order_binder, std::move(limit_mod.offset), false, true); - } - return std::move(result); -} - -unique_ptr Binder::BindLimitPercent(OrderBinder &order_binder, LimitPercentModifier &limit_mod) { - auto result = make_uniq(); - if (limit_mod.limit) { - result->limit_val = BindLimitValue(order_binder, std::move(limit_mod.limit), true, false); - } - if (limit_mod.offset) { - result->offset_val = BindLimitValue(order_binder, std::move(limit_mod.offset), false, true); + result->offset_val = BindLimitValue(order_binder, std::move(limit_mod.offset), LimitValueType::ROW_COUNT, true); } return std::move(result); } @@ -177,7 +168,7 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B auto &order_binders = order_binder.GetBinders(); if (order.orders.size() == 1 && order.orders[0].expression->GetExpressionType() == ExpressionType::STAR) { auto &star = order.orders[0].expression->Cast(); - if (star.exclude_list.empty() && star.replace_list.empty() && !star.expr) { + if (star.ExcludeList().empty() && star.ReplaceList().empty() && !star.Expression()) { // ORDER BY ALL // replace the order list with the all elements in the SELECT list auto order_type = config.ResolveOrder(context, order.orders[0].type); @@ -240,9 +231,6 @@ void 
Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B case ResultModifierType::LIMIT_MODIFIER: bound_modifier = BindLimit(order_binder, mod->Cast()); break; - case ResultModifierType::LIMIT_PERCENT_MODIFIER: - bound_modifier = BindLimitPercent(order_binder, mod->Cast()); - break; default: throw InternalException("Unsupported result modifier"); } @@ -252,7 +240,7 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B } } -static unique_ptr CreateOrderExpression(unique_ptr expr, const vector &names, +static unique_ptr CreateOrderExpression(unique_ptr expr, const vector &names, const vector &sql_types, TableIndex table_index, ProjectionIndex index) { if (index >= sql_types.size()) { @@ -267,14 +255,14 @@ static unique_ptr CreateOrderExpression(unique_ptr expr, } static unique_ptr FinalizeBindOrderExpression(unique_ptr expr, TableIndex table_index, - const vector &names, + const vector &names, const vector &sql_types, const SelectBindState &bind_state) { auto &constant = expr->Cast(); - switch (constant.value.type().id()) { + switch (constant.GetValue().type().id()) { case LogicalTypeId::UBIGINT: { // index - auto index = UBigIntValue::Get(constant.value); + auto index = UBigIntValue::Get(constant.GetValue()); return CreateOrderExpression(std::move(expr), names, sql_types, table_index, bind_state.GetFinalIndex(index)); } case LogicalTypeId::VARCHAR: { @@ -283,7 +271,7 @@ static unique_ptr FinalizeBindOrderExpression(unique_ptr } case LogicalTypeId::STRUCT: { // collation - auto &struct_values = StructValue::GetChildren(constant.value); + auto &struct_values = StructValue::GetChildren(constant.GetValue()); if (struct_values.size() > 2) { throw InternalException("Expected one or two children: index and optional collation"); } @@ -306,7 +294,7 @@ static unique_ptr FinalizeBindOrderExpression(unique_ptr } } -static void AssignReturnType(unique_ptr &expr, TableIndex table_index, const vector &names, +static void 
AssignReturnType(unique_ptr &expr, TableIndex table_index, const vector &names, const vector &sql_types, const SelectBindState &bind_state) { if (!expr) { return; @@ -318,10 +306,10 @@ static void AssignReturnType(unique_ptr &expr, TableIndex table_inde return; } auto &bound_colref = expr->Cast(); - bound_colref.SetReturnType(sql_types[bound_colref.binding.column_index]); + bound_colref.SetReturnType(sql_types[bound_colref.Binding().column_index]); } -void Binder::BindModifiers(BoundQueryNode &result, TableIndex table_index, const vector &names, +void Binder::BindModifiers(BoundQueryNode &result, TableIndex table_index, const vector &names, const vector &sql_types, const SelectBindState &bind_state) { for (auto &bound_mod : result.modifiers) { switch (bound_mod->type) { @@ -394,14 +382,14 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { // expand any expressions in the upper AND recursively if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { auto &conj = expr->Cast(); - for (auto &child : conj.children) { + for (auto &child : conj.GetChildrenMutable()) { BindWhereStarExpression(child); } return; } if (expr->GetExpressionType() == ExpressionType::STAR) { auto &star = expr->Cast(); - if (!star.columns) { + if (!star.IsColumns()) { throw ParserException("STAR expression is not allowed in the WHERE clause. 
Use COLUMNS(*) instead."); } } @@ -421,6 +409,13 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } +Identifier Binder::GetExpressionName(const ParsedExpression &expr) { + if (!expr.GetAlias().empty()) { + return expr.GetAlias(); + } + return expr.GetName(); +} + BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); @@ -436,7 +431,8 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from result.from_table = std::move(from_table); // bind the sample clause if (statement.sample) { - result.sample_options = std::move(statement.sample); + result.from_table.plan = + make_uniq(std::move(statement.sample), std::move(result.from_table.plan)); } // visit the select list and expand any "*" statements @@ -451,7 +447,7 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from auto &bind_state = result.bind_state; for (idx_t i = 0; i < statement.select_list.size(); i++) { auto &expr = statement.select_list[i]; - result.names.push_back(expr->GetName()); + result.names.emplace_back(GetExpressionName(*expr)); ExpressionBinder::QualifyColumnNames(*this, expr); if (!expr->GetAlias().empty()) { bind_state.alias_map[expr->GetAlias()] = i; @@ -462,40 +458,48 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from } result.column_count = statement.select_list.size(); - // first visit the WHERE clause - // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses if (statement.where_clause) { // bind any star expressions in the WHERE clause BindWhereStarExpression(statement.where_clause); + } + for (idx_t i = 0; i < statement.groups.group_expressions.size(); i++) { + auto &grp = statement.groups.group_expressions[i]; + ExpressionBinder::QualifyColumnNames(*this, grp); - ColumnAliasBinder alias_binder(bind_state); - WhereBinder where_binder(*this, context, &alias_binder); - unique_ptr 
condition = std::move(statement.where_clause); - result.where_clause = where_binder.Bind(condition); + GroupBinder::ReplaceSelectRef(statement, bind_state, ProjectionIndex(i), grp); + + // set up a mapping of expression -> group index so we can map expressions in the select list / having to groups + auto grp_copy = grp->Copy(); + bind_state.group_map[*grp_copy] = ProjectionIndex(i); + bind_state.unbound_groups.push_back(std::move(grp_copy)); + } + + if (statement.qualify) { + ExpressionBinder::QualifyColumnNames(*this, statement.qualify); } - // now bind all the result modifiers; including DISTINCT and ORDER BY targets + // prepare binding of all the result modifiers; including DISTINCT and ORDER BY targets OrderBinder order_binder({*this}, statement, bind_state); PrepareModifiers(order_binder, statement, result); - vector> unbound_groups; - BoundGroupInformation info; + // first visit the WHERE clause + // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses + if (statement.where_clause) { + ColumnAliasBinder alias_binder(bind_state); + WhereBinder where_binder(*this, context, alias_binder); + unique_ptr condition = std::move(statement.where_clause); + auto where_clause = where_binder.Bind(condition); + result.from_table.plan = PlanFilter(std::move(where_clause), std::move(result.from_table.plan)); + } + auto &group_expressions = statement.groups.group_expressions; if (!group_expressions.empty()) { // the statement has a GROUP BY clause, bind it - unbound_groups.resize(group_expressions.size()); - GroupBinder group_binder(*this, context, statement, result.group_index, bind_state, info.alias_map); + GroupBinder group_binder(*this, context, result.group_index, bind_state); // Allow NULL constants in GROUP BY to maintain their SQLNULL type - auto prev_can_contain_nulls = this->can_contain_nulls; - this->can_contain_nulls = true; + auto prev_can_contain_nulls = CanContainNulls(); + SetCanContainNulls(true); for (idx_t i = 0; i < 
group_expressions.size(); i++) { - // we keep a copy of the unbound expression; - // we keep the unbound copy around to check for group references in the SELECT and HAVING clause - // the reason we want the unbound copy is because we want to figure out whether an expression - // is a group reference BEFORE binding in the SELECT/HAVING binder - group_binder.unbound_expression = group_expressions[i]->Copy(); - group_binder.bind_index = ProjectionIndex(i); - // bind the groups LogicalType group_type; auto bound_expr = group_binder.Bind(group_expressions[i], &group_type); @@ -521,25 +525,17 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from function->SetAlias("__collated_group"); auto collated_idx = ColumnBinding::PushExpression(result.aggregates, std::move(function)); - info.collated_groups[ProjectionIndex(i)] = collated_idx; + bind_state.collated_groups[ProjectionIndex(i)] = collated_idx; } result.groups.group_expressions.push_back(std::move(bound_expr)); - - // in the unbound expression we DO bind the table names of any ColumnRefs - // we do this to make sure that "table.a" and "a" are treated the same - // if we wouldn't do this then (SELECT test.a FROM test GROUP BY a) would not work because "test.a" <> "a" - // hence we convert "a" -> "test.a" in the unbound expression - unbound_groups[i] = std::move(group_binder.unbound_expression); - ExpressionBinder::QualifyColumnNames(*this, unbound_groups[i]); - info.map[*unbound_groups[i]] = ProjectionIndex(i); } - this->can_contain_nulls = prev_can_contain_nulls; + SetCanContainNulls(prev_can_contain_nulls); } result.groups.grouping_sets = std::move(statement.groups.grouping_sets); // bind the HAVING clause, if any if (statement.having) { - HavingBinder having_binder(*this, context, result, info, statement.aggregate_handling); + HavingBinder having_binder(*this, context, result, statement.aggregate_handling); ExpressionBinder::QualifyColumnNames(having_binder, statement.having); 
result.having = having_binder.Bind(statement.having); } @@ -550,8 +546,7 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); } - QualifyBinder qualify_binder(*this, context, result, info); - ExpressionBinder::QualifyColumnNames(*this, statement.qualify); + QualifyBinder qualify_binder(*this, context, result); result.qualify = qualify_binder.Bind(statement.qualify); if (qualify_binder.HasBoundColumns()) { if (qualify_binder.BoundAggregates()) { @@ -562,12 +557,12 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from } // after that, we bind to the SELECT list - SelectBinder select_binder(*this, context, result, info); + SelectBinder select_binder(*this, context, result); // if we expand select-list expressions, e.g., via UNNEST, then we need to possibly // adjust the column index of the already bound ORDER BY modifiers, and not only set their types vector group_by_all_indexes; - vector new_names; + vector new_names; vector internal_sql_types; for (idx_t i = 0; i < statement.select_list.size(); i++) { @@ -589,11 +584,11 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from } auto &expanded = expr->Cast(); - auto &struct_expressions = expanded.expanded_expressions; + auto &struct_expressions = expanded.GetChildrenMutable(); D_ASSERT(!struct_expressions.empty()); for (auto &struct_expr : struct_expressions) { - new_names.push_back(struct_expr->GetName()); + new_names.emplace_back(struct_expr->GetName()); result.types.push_back(struct_expr->GetReturnType()); internal_sql_types.push_back(struct_expr->GetReturnType()); result.select_list.push_back(std::move(struct_expr)); diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index 
bb20e6846..d65dacecf 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -21,7 +21,7 @@ struct SetOpAliasGatherer { } void GatherAliases(BoundStatement &stmt, const vector &reorder_idx); - void GatherSetOpAliases(SetOperationType setop_type, const vector &names, + void GatherSetOpAliases(SetOperationType setop_type, const vector &names, vector &bound_children, const vector &reorder_idx); private: @@ -70,13 +70,13 @@ void SetOpAliasGatherer::GatherAliases(BoundStatement &stmt, const vector } } -void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, +void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, vector &bound_children, const vector &reorder_idx) { // create new reorder index if (setop_type == SetOperationType::UNION_BY_NAME) { auto &setop_names = stmt_names; // for UNION BY NAME - create a new re-order index - case_insensitive_map_t reorder_map; + identifier_map_t reorder_map; for (idx_t col_idx = 0; col_idx < setop_names.size(); ++col_idx) { reorder_map[setop_names[col_idx]] = reorder_idx[col_idx]; } @@ -114,15 +114,15 @@ static void GatherAliases(BoundSetOperationNode &root, vector &c void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); - vector> node_name_maps; - case_insensitive_set_t global_name_set; + vector> node_name_maps; + identifier_set_t global_name_set; // Build a name_map to use to check if a name exists // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { auto &child_names = child.names; - case_insensitive_map_t node_name_map; + identifier_map_t node_name_map; for (idx_t i = 0; i < child_names.size(); ++i) { auto &col_name = child_names[i]; if (node_name_map.find(col_name) != node_name_map.end()) { @@ 
-133,7 +133,7 @@ void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { } if (global_name_set.find(col_name) == global_name_set.end()) { // column is not yet present in the result - result.names.push_back(col_name); + result.names.emplace_back(col_name); global_name_set.insert(col_name); } node_name_map[col_name] = ProjectionIndex(i); @@ -162,7 +162,7 @@ void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { if (result_type.id() == LogicalTypeId::INVALID) { result_type = child_col_type; } else { - result_type = LogicalType::ForceMaxLogicalType(result_type, child_col_type); + result_type = LogicalType::ForceMaxLogicalType(context, result_type, child_col_type); } if (i != col_idx_in_child) { // the column exists - but the children are out-of-order, so we need to re-order anyway @@ -171,7 +171,7 @@ void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { } } // compute the final type for each column - if (!can_contain_nulls) { + if (!CanContainNulls()) { if (ExpressionBinder::ContainsNullType(result_type)) { result_type = ExpressionBinder::ExchangeNullType(result_type); } @@ -252,7 +252,7 @@ BoundStatement Binder::BindNode(SetOperationNode &statement) { } for (auto &child : statement.children) { auto child_binder = Binder::CreateBinder(context, this); - child_binder->can_contain_nulls = true; + child_binder->SetCanContainNulls(true); auto child_node = child_binder->BindNode(*child); MoveCorrelatedExpressions(*child_binder); result.bound_children.push_back(std::move(child_node)); @@ -278,9 +278,9 @@ BoundStatement Binder::BindNode(SetOperationNode &statement) { auto result_type = result.bound_children[0].types[i]; for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { auto &child_types = result.bound_children[child_idx].types; - result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); + result_type = LogicalType::ForceMaxLogicalType(context, result_type, child_types[i]); } - if 
(!can_contain_nulls) { + if (!CanContainNulls()) { if (ExpressionBinder::ContainsNullType(result_type)) { result_type = ExpressionBinder::ExchangeNullType(result_type); } diff --git a/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp index 034918070..52387f016 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_table_macro_node.cpp @@ -20,7 +20,7 @@ unique_ptr Binder::BindTableMacro(FunctionExpression &function, Table idx_t depth) { // validate the arguments and separate positional and default arguments vector> positional_arguments; - InsertionOrderPreservingMap> named_arguments; + InsertionOrderPreservingMap, Identifier, identifier_map_t> named_arguments; auto bind_result = MacroFunction::BindMacroFunction(*this, macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); @@ -37,7 +37,7 @@ unique_ptr Binder::BindTableMacro(FunctionExpression &function, Table auto eb = ExpressionBinder(*this, this->context); eb.macro_binding = new_macro_binding.get(); - vector> lambda_params; + vector lambda_params; auto node = macro_def.Cast().query_node->Copy(); ParsedExpressionIterator::EnumerateQueryNodeChildren( diff --git a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp index e85641013..8e5fe772c 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp @@ -5,6 +5,13 @@ #include "duckdb/planner/operator/logical_dummy_scan.hpp" #include "duckdb/planner/operator/logical_limit.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/function/scalar/generic_common.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/function/scalar/system_functions.hpp" 
+#include "duckdb/function/function_binder.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" namespace duckdb { @@ -15,19 +22,113 @@ unique_ptr Binder::PlanFilter(unique_ptr condition, return std::move(filter); } -unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { - D_ASSERT(statement.from_table.plan); - auto root = std::move(statement.from_table.plan); +bool Binder::DebugAggregateStateExportVerify(BoundSelectNode &statement, unique_ptr &root) { + if (!Settings::Get(context)) { + return false; + } + if (statement.groups.grouping_sets.size() > 1 || !statement.grouping_functions.empty()) { + // verification is not supported with multiple grouping sets / grouping functions (yet?) + return false; + } + for (auto &aggr : statement.aggregates) { + if (aggr->GetAlias() == "__collated_group") { + // collated GROUP BY adds first() aggregates to preserve original values - skip verification since + // ExtractPivotAggregates relies on the __collated_group alias to filter them out + return false; + } + auto &res_type = aggr->GetReturnType(); + if (res_type.IsAggregateState()) { + // already an aggregate state export - skip verification + return false; + } + } - // plan the sample clause - if (statement.sample_options) { - root = make_uniq(std::move(statement.sample_options), std::move(root)); + // we transform a query like "SELECT grp, SUM(col) FROM tbl GROUP BY grp" into + // "SELECT grp, first(finalize(sum_state)) FROM (SELECT grp, SUM(col) export_state FROM tbl GROUP BY grp) GROUP BY + // grp the second grouping is technically not necessary - we add this just so that subsequent binding remains + // correct first push an aggregate that actually does the state export + vector> finalize_expressions; + vector> final_group_expressions; + vector> final_aggregates; + auto intermediate_group_index = GenerateTableIndex(); + auto 
intermediate_aggregate_index = GenerateTableIndex(); + auto intermediate_proj_index = GenerateTableIndex(); + // push references to any groups first + for (idx_t grp_idx = 0; grp_idx < statement.groups.group_expressions.size(); grp_idx++) { + auto &group = statement.groups.group_expressions[grp_idx]; + auto group_ref = make_uniq( + group->GetReturnType(), ColumnBinding(intermediate_group_index, ProjectionIndex(grp_idx))); + + auto final_group_ref = make_uniq( + group->GetReturnType(), + ColumnBinding(intermediate_proj_index, ProjectionIndex(finalize_expressions.size()))); + + finalize_expressions.push_back(std::move(group_ref)); + final_group_expressions.push_back(std::move(final_group_ref)); } + auto &system_catalog = Catalog::GetSystemCatalog(context); + ErrorData error; + FunctionBinder function_binder(context); + for (idx_t aggr_idx = 0; aggr_idx < statement.aggregates.size(); aggr_idx++) { + auto &aggr = statement.aggregates[aggr_idx]; + auto aggr_expr = unique_ptr_cast(std::move(aggr)); + aggr = ExportAggregateFunction::Bind(std::move(aggr_expr)); + + // create a reference to the exported aggregate state + auto aggr_state_ref = make_uniq( + aggr->GetReturnType(), ColumnBinding(intermediate_aggregate_index, ProjectionIndex(aggr_idx))); + + // create the "finalize" expression + vector> finalize_children; + finalize_children.push_back(std::move(aggr_state_ref)); + auto &finalize_function = + system_catalog.GetEntry(context, DEFAULT_SCHEMA, "finalize"); + auto result = + function_binder.BindScalarFunction(finalize_function, std::move(finalize_children), error, false, *this); + if (!result) { + throw InternalException("Failed to bind finalize during debug verification - %s", error.Message()); + } + + // create the "first" expression + auto &first_function = system_catalog.GetEntry(context, DEFAULT_SCHEMA, "first"); + auto finalize_ref = make_uniq( + result->GetReturnType(), + ColumnBinding(intermediate_proj_index, ProjectionIndex(finalize_expressions.size()))); - 
if (statement.where_clause) { - root = PlanFilter(std::move(statement.where_clause), std::move(root)); + vector>> first_children; + first_children.emplace_back(Identifier(), std::move(finalize_ref)); + auto first_fun = function_binder.BindAggregateFunction(first_function, std::move(first_children), error); + + finalize_expressions.push_back(std::move(result)); + final_aggregates.push_back(std::move(first_fun)); } + // create the initial aggregate that performs the aggregate state export + auto aggregate = make_uniq(intermediate_group_index, intermediate_aggregate_index, + std::move(statement.aggregates)); + aggregate->groups = std::move(statement.groups.group_expressions); + aggregate->grouping_sets = statement.groups.grouping_sets; + aggregate->AddChild(std::move(root)); + + // now push the projection that finalizes the states + auto proj = make_uniq(intermediate_proj_index, std::move(finalize_expressions)); + proj->AddChild(std::move(aggregate)); + + // in order for bindings after this point to be correct, we need to push another aggregate + // this just does grp1, grp2, FIRST(finalize_result), etc + auto final_aggr = + make_uniq(statement.group_index, statement.aggregate_index, std::move(final_aggregates)); + final_aggr->groups = std::move(final_group_expressions); + final_aggr->grouping_sets = std::move(statement.groups.grouping_sets); + final_aggr->AddChild(std::move(proj)); + root = std::move(final_aggr); + return true; +} + +unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { + D_ASSERT(statement.from_table.plan); + auto root = std::move(statement.from_table.plan); + if (!statement.aggregates.empty() || !statement.groups.group_expressions.empty() || statement.having) { if (!statement.groups.group_expressions.empty()) { // visit the groups @@ -39,16 +140,18 @@ unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { for (auto &expr : statement.aggregates) { PlanSubqueries(expr, root); } - // finally create the aggregate node with the group_index 
and aggregate_index as obtained from the binder - auto aggregate = make_uniq(statement.group_index, statement.aggregate_index, - std::move(statement.aggregates)); - aggregate->groups = std::move(statement.groups.group_expressions); - aggregate->groupings_index = statement.groupings_index; - aggregate->grouping_sets = std::move(statement.groups.grouping_sets); - aggregate->grouping_functions = std::move(statement.grouping_functions); - - aggregate->AddChild(std::move(root)); - root = std::move(aggregate); + + if (!DebugAggregateStateExportVerify(statement, root)) { + // finally create the aggregate node with the group_index and aggregate_index as obtained from the binder + auto aggregate = make_uniq(statement.group_index, statement.aggregate_index, + std::move(statement.aggregates)); + aggregate->groups = std::move(statement.groups.group_expressions); + aggregate->groupings_index = statement.groupings_index; + aggregate->grouping_sets = std::move(statement.groups.grouping_sets); + aggregate->grouping_functions = std::move(statement.grouping_functions); + aggregate->AddChild(std::move(root)); + root = std::move(aggregate); + } } else if (!statement.groups.grouping_sets.empty()) { // edge case: we have grouping sets but no groups or aggregates // this can only happen if we have e.g. 
select 1 from tbl group by (); diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index a1a7f60b0..96001104d 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -35,7 +35,7 @@ unique_ptr Binder::CastLogicalOperatorToTypes(const vectorexpressions.size(); i++) { if (op->expressions[i]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &col_ref = op->expressions[i]->Cast(); - auto column_id = column_ids[col_ref.binding.column_index].GetPrimaryIndex(); + auto column_id = column_ids[col_ref.Binding().column_index].GetPrimaryIndex(); if (new_column_types.find(column_id) != new_column_types.end()) { // Only one reference per column is accepted do_pushdown = false; @@ -61,7 +61,7 @@ unique_ptr Binder::CastLogicalOperatorToTypes(const vectorexpressions[i]->GetAlias(); + auto cur_alias = node->expressions[i]->GetAlias(); node->expressions[i] = BoundCastExpression::AddCastToType(context, std::move(node->expressions[i]), target_types[i]); node->expressions[i]->SetAlias(cur_alias); diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp index 9855663eb..3f6fe0a15 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -28,7 +28,7 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq unique_ptr &root, unique_ptr plan) { D_ASSERT(!expr.IsCorrelated()); - switch (expr.subquery_type) { + switch (expr.GetSubqueryType()) { case SubqueryType::EXISTS: { // uncorrelated EXISTS // we only care about existence, hence we push a LIMIT 1 operator @@ -71,7 +71,7 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq // we replace the original subquery with a ColumnRefExpression referring to the result of the projection (either // 
TRUE or FALSE) - return make_uniq(expr.GetName(), LogicalType::BOOLEAN, + return make_uniq(Identifier(expr.GetName()), LogicalType::BOOLEAN, ColumnBinding(projection_index, ProjectionIndex(0))); } case SubqueryType::SCALAR: { @@ -149,11 +149,11 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq // we replace the original subquery with a BoundColumnRefExpression referring to the first result of the // aggregation - return make_uniq(expr.GetName(), expr.GetReturnType(), + return make_uniq(Identifier(expr.GetName()), expr.GetReturnType(), ColumnBinding(aggr_index, ProjectionIndex(0))); } default: { - D_ASSERT(expr.subquery_type == SubqueryType::ANY); + D_ASSERT(expr.GetSubqueryType() == SubqueryType::ANY); // we generate a MARK join that results in either (TRUE, FALSE or NULL) // subquery has NULL values -> result is (TRUE or NULL) // subquery has no NULL values -> result is (TRUE, FALSE or NULL [if input is NULL]) @@ -171,13 +171,13 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq // Special case: if we have a single struct child and multiple types, // this means we kept the struct intact for ordered comparison (e.g., (a,b) < ANY(...)) // We need to construct a corresponding struct on the RHS from the subquery columns - if (expr.children.size() == 1 && expr.child_types.size() > 1) { + if (expr.GetChildren().size() == 1 && expr.GetChildTypes().size() > 1) { // Construct a struct on the RHS from the subquery columns vector> struct_children; - struct_children.reserve(expr.child_types.size()); - for (idx_t i = 0; i < expr.child_types.size(); i++) { - auto &child_type = expr.child_types[i]; - auto &compare_type = expr.child_targets[i]; + struct_children.reserve(expr.GetChildTypes().size()); + for (idx_t i = 0; i < expr.GetChildTypes().size(); i++) { + auto &child_type = expr.GetChildTypes()[i]; + auto &compare_type = expr.GetChildTargets()[i]; auto colref = BoundCastExpression::AddDefaultCastToType( make_uniq(child_type, 
plan_columns[i]), compare_type); struct_children.push_back(std::move(colref)); @@ -187,7 +187,7 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq FunctionBinder function_binder(binder); auto struct_expr = function_binder.BindScalarFunction(RowFun::GetFunction(), std::move(struct_children)); - JoinCondition cond(std::move(expr.children[0]), std::move(struct_expr), expr.comparison_type); + JoinCondition cond(std::move(expr.GetChildrenMutable()[0]), std::move(struct_expr), expr.ComparisonType()); // push collations ExpressionBinder::PushCollation(binder.context, cond.LeftReference(), cond.GetLHS().GetReturnType()); @@ -196,12 +196,13 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq join->conditions.push_back(std::move(cond)); } else { // Standard case: compare each child separately - for (idx_t child_idx = 0; child_idx < expr.children.size(); child_idx++) { - auto &child_type = expr.child_types[child_idx]; - auto &compare_type = expr.child_targets[child_idx]; + for (idx_t child_idx = 0; child_idx < expr.GetChildren().size(); child_idx++) { + auto &child_type = expr.GetChildTypes()[child_idx]; + auto &compare_type = expr.GetChildTargets()[child_idx]; auto right_expr = BoundCastExpression::AddDefaultCastToType( make_uniq(child_type, plan_columns[child_idx]), compare_type); - JoinCondition cond(std::move(expr.children[child_idx]), std::move(right_expr), expr.comparison_type); + JoinCondition cond(std::move(expr.GetChildrenMutable()[child_idx]), std::move(right_expr), + expr.ComparisonType()); // push collations ExpressionBinder::PushCollation(binder.context, cond.LeftReference(), compare_type); @@ -213,7 +214,7 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq root = std::move(join); // we replace the original subquery with a BoundColumnRefExpression referring to the mark column - return make_uniq(expr.GetName(), expr.GetReturnType(), + return make_uniq(Identifier(expr.GetName()), expr.GetReturnType(), 
ColumnBinding(mark_index, ProjectionIndex(0))); } } @@ -251,7 +252,7 @@ static bool PerformDelimOnType(const LogicalType &type) { } static bool PerformDuplicateElimination(Binder &binder, CorrelatedColumns &correlated_columns) { - if (!ClientConfig::GetConfig(binder.context).enable_optimizer) { + if (!Settings::Get(binder.context)) { // if optimizations are disabled we always do a delim join return true; } @@ -277,17 +278,17 @@ static bool PerformDuplicateElimination(Binder &binder, CorrelatedColumns &corre static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubqueryExpression &expr, unique_ptr &root, unique_ptr plan) { - auto &correlated_columns = expr.binder->correlated_columns; + auto &correlated_columns = expr.GetBinder()->correlated_columns; // FIXME: there should be a way of disabling decorrelation for ANY queries as well, but not for now... bool perform_delim = - expr.subquery_type == SubqueryType::ANY ? true : PerformDuplicateElimination(binder, correlated_columns); + expr.GetSubqueryType() == SubqueryType::ANY ? 
true : PerformDuplicateElimination(binder, correlated_columns); D_ASSERT(expr.IsCorrelated()); // correlated subquery // for a more in-depth explanation of this code, read the paper "Unnesting Arbitrary Subqueries" // also read "Improving Unnesting of Complex Queries" // we handle three types of correlated subqueries: Scalar, EXISTS and ANY // all three cases are very similar with some minor changes (mainly the type of join performed at the end) - switch (expr.subquery_type) { + switch (expr.GetSubqueryType()) { case SubqueryType::SCALAR: { // correlated SCALAR query // first push a DUPLICATE ELIMINATED join @@ -311,7 +312,7 @@ static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubque delim_join->AddChild(std::move(plan)); root = std::move(delim_join); // finally push the BoundColumnRefExpression referring to the data element returned by the join - return make_uniq(expr.GetName(), expr.GetReturnType(), plan_column); + return make_uniq(Identifier(expr.GetName()), expr.GetReturnType(), plan_column); } case SubqueryType::EXISTS: { // correlated EXISTS query @@ -326,11 +327,11 @@ static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubque delim_join->AddChild(std::move(plan)); root = std::move(delim_join); // finally push the BoundColumnRefExpression referring to the marker - return make_uniq(expr.GetName(), expr.GetReturnType(), + return make_uniq(Identifier(expr.GetName()), expr.GetReturnType(), ColumnBinding(mark_index, ProjectionIndex(0))); } default: { - D_ASSERT(expr.subquery_type == SubqueryType::ANY); + D_ASSERT(expr.GetSubqueryType() == SubqueryType::ANY); // correlated ANY query // this query is similar to the correlated SCALAR query // however, in this case we push a correlated MARK join @@ -347,7 +348,7 @@ static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubque delim_join->any_join = true; auto &dependent_join = plan; - if (expr.children.size() > 1) { + if (expr.GetChildren().size() > 1) { // FIXME: the code to generate the 
plan here is actually correct // the problem is in the hash join - specifically PhysicalHashJoin::InitializeHashTable // this contains code that is hard-coded for a single comparison @@ -356,15 +357,15 @@ static unique_ptr PlanCorrelatedSubquery(Binder &binder, BoundSubque throw NotImplementedException("Correlated IN/ANY/ALL with multiple columns not yet supported"); } - delim_join->expression_children = std::move(expr.children); - delim_join->child_types = expr.child_types; - delim_join->child_targets = expr.child_targets; - delim_join->comparison_type = expr.comparison_type; + delim_join->expression_children = std::move(expr.GetChildrenMutable()); + delim_join->child_types = expr.GetChildTypes(); + delim_join->child_targets = expr.GetChildTargets(); + delim_join->comparison_type = expr.ComparisonType(); delim_join->AddChild(std::move(dependent_join)); root = std::move(delim_join); // finally push the BoundColumnRefExpression referring to the marker - return make_uniq(expr.GetName(), expr.GetReturnType(), + return make_uniq(Identifier(expr.GetName()), expr.GetReturnType(), ColumnBinding(mark_index, ProjectionIndex(0))); } } @@ -407,7 +408,7 @@ unique_ptr Binder::PlanSubquery(BoundSubqueryExpression &expr, uniqu // first we translate the QueryNode of the subquery into a logical plan auto sub_binder = Binder::CreateBinder(context, this); sub_binder->is_outside_flattened = false; - auto subquery_root = std::move(expr.subquery.plan); + auto subquery_root = std::move(expr.SubqueryMutable().plan); D_ASSERT(subquery_root); // now we actually flatten the subquery diff --git a/src/duckdb/src/planner/binder/statement/bind_connect.cpp b/src/duckdb/src/planner/binder/statement/bind_connect.cpp new file mode 100644 index 000000000..3f93062ef --- /dev/null +++ b/src/duckdb/src/planner/binder/statement/bind_connect.cpp @@ -0,0 +1,34 @@ +#include "duckdb/parser/statement/connect_statement.hpp" +#include "duckdb/parser/statement/disconnect_statement.hpp" +#include 
"duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_simple.hpp" + +namespace duckdb { + +BoundStatement Binder::Bind(ConnectStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(LogicalOperatorType::LOGICAL_CONNECT, std::move(stmt.info)); + + auto &properties = GetStatementProperties(); + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +BoundStatement Binder::Bind(DisconnectStatement &stmt) { + BoundStatement result; + result.types = {LogicalType::BOOLEAN}; + result.names = {"Success"}; + + result.plan = make_uniq(LogicalOperatorType::LOGICAL_DISCONNECT, std::move(stmt.info)); + + auto &properties = GetStatementProperties(); + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + properties.return_type = StatementReturnType::NOTHING; + return result; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_copy.cpp b/src/duckdb/src/planner/binder/statement/bind_copy.cpp index 2fd93069d..e432f5301 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy.cpp @@ -22,6 +22,7 @@ #include "duckdb/planner/operator/logical_insert.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/expression_binder/table_function_binder.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/common/algorithm.hpp" #include "duckdb/main/extension_entries.hpp" @@ -38,7 +39,7 @@ void IsFormatExtensionKnown(const string &format) { // It's a match, we must throw throw CatalogException( "Copy Function with name \"%s\" is not in the catalog, but it exists in the %s extension.", format, - std::string(file_postfixes.extension)); + file_postfixes.extension); } } } @@ -80,6 +81,7 @@ case_insensitive_map_t 
Binder::GetFullCopyOptionsList(const CopyFunc copy_options["row_groups_per_file"] = CopyOption(LogicalType::UBIGINT, CopyOptionMode::WRITE_ONLY); copy_options["file_size_bytes"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); copy_options["partition_by"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); + copy_options["order_by"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); copy_options["return_files"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::WRITE_ONLY); copy_options["preserve_order"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::WRITE_ONLY); copy_options["return_stats"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::WRITE_ONLY); @@ -101,46 +103,205 @@ static idx_t ParseBytesArg(const string &name, Value &arg) { return arg.GetValue(); } -BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &function, CopyToType copy_to_type) { - if (function.plan) { - // plan rewrite COPY TO - return function.plan(*this, stmt); +struct CopyToParsedOptions { + optional use_tmp_file; + optional overwrite_mode; + optional filename_pattern; + optional per_thread_output; + optional_idx batch_size; + optional_idx batch_size_bytes; + optional_idx file_size_bytes; + optional_idx batches_per_file; + vector partition_cols; + vector order_columns; + optional write_partition_columns; + optional write_empty_file; + optional hive_file_pattern; + optional preserve_order; + optional return_type; + + bool UserSetUseTmpFile() const { + return use_tmp_file.has_value(); } - auto ©_info = *stmt.info; - // bind the select statement - // preserve SQLNULL types from table functions (e.g. JSON reader null columns) - // so file writers that support it can emit the correct null type (e.g. 
parquet UNKNOWN/NullType) - auto prev_can_contain_nulls = can_contain_nulls; - if (function.supports_sql_null) { - can_contain_nulls = true; + bool PerThreadOutput() const { + return per_thread_output.value_or(false); } - auto node_copy = copy_info.select_statement->Copy(); - auto select_node = Bind(*node_copy); - can_contain_nulls = prev_can_contain_nulls; - if (!function.copy_to_bind) { - throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); + bool WritePartitionColumns() const { + return write_partition_columns.value_or(false); + } + + bool WriteEmptyFile() const { + return write_empty_file.value_or(true); + } + + bool HiveFilePattern() const { + return hive_file_pattern.value_or(true); + } + + PreserveOrderType PreserveOrder() const { + return preserve_order.value_or(PreserveOrderType::AUTOMATIC); } + CopyOverwriteMode OverwriteMode() const { + return overwrite_mode.value_or(CopyOverwriteMode::COPY_ERROR_ON_CONFLICT); + } + + FilenamePattern GetFilenamePattern() const { + return filename_pattern.value_or(FilenamePattern()); + } + + CopyFunctionReturnType ReturnType() const { + return return_type.value_or(CopyFunctionReturnType::CHANGED_ROWS); + } + + void SetFilenamePattern(const string &pattern) { + FilenamePattern result; + result.SetFilenamePattern(pattern); + filename_pattern = std::move(result); + } + + void SetReturnType(CopyFunctionReturnType return_type_p) { + if (return_type.has_value() && *return_type != return_type_p) { + throw BinderException("Can only set one of RETURN_FILES or RETURN_STATS for COPY"); + } + return_type = return_type_p; + } + + bool Rotate() const { + return file_size_bytes.IsValid() || batches_per_file.IsValid(); + } + + bool Partitioned() const { + return !partition_cols.empty(); + } +}; + +struct CopyToResolvedOptions { bool use_tmp_file = true; CopyOverwriteMode overwrite_mode = CopyOverwriteMode::COPY_ERROR_ON_CONFLICT; FilenamePattern filename_pattern; - bool user_set_use_tmp_file = 
false; bool per_thread_output = false; optional_idx batch_size; optional_idx batch_size_bytes; optional_idx file_size_bytes; optional_idx batches_per_file; vector partition_cols; - bool seen_overwrite_mode = false; - bool seen_filepattern = false; + vector order_columns; bool write_partition_columns = false; bool write_empty_file = true; bool hive_file_pattern = true; PreserveOrderType preserve_order = PreserveOrderType::AUTOMATIC; CopyFunctionReturnType return_type = CopyFunctionReturnType::CHANGED_ROWS; + bool Rotate() const { + return file_size_bytes.IsValid() || batches_per_file.IsValid(); + } + + bool Partitioned() const { + return !partition_cols.empty(); + } +}; + +static bool ResolveUseTmpFile(ClientContext &context, const string &file_path, const CopyToParsedOptions &options) { + if (FileSystem::IsRemoteFile(file_path)) { + return false; + } + if (options.use_tmp_file.has_value()) { + return *options.use_tmp_file; + } + + auto &fs = FileSystem::GetFileSystem(context); + bool is_file_and_exists = fs.FileExists(file_path); + bool is_stdout = file_path == "/dev/stdout"; + return is_file_and_exists && !options.PerThreadOutput() && !options.Partitioned() && !is_stdout; +} + +static CopyToResolvedOptions ResolveCopyToOptions(ClientContext &context, const string &file_path, + CopyToParsedOptions options) { + CopyToResolvedOptions result; + result.use_tmp_file = ResolveUseTmpFile(context, file_path, options); + result.overwrite_mode = options.OverwriteMode(); + result.filename_pattern = options.GetFilenamePattern(); + result.per_thread_output = options.PerThreadOutput(); + result.batch_size = options.batch_size; + result.batch_size_bytes = options.batch_size_bytes; + result.file_size_bytes = options.file_size_bytes; + result.batches_per_file = options.batches_per_file; + result.partition_cols = std::move(options.partition_cols); + result.order_columns = std::move(options.order_columns); + result.write_partition_columns = options.WritePartitionColumns(); + 
result.write_empty_file = options.WriteEmptyFile(); + result.hive_file_pattern = options.HiveFilePattern(); + result.preserve_order = options.PreserveOrder(); + result.return_type = options.ReturnType(); + return result; +} + +static void ValidateCopyToOptionCombinations(const CopyToParsedOptions &options, const CopyFunction &function, + const string &format) { + if (options.OverwriteMode() == CopyOverwriteMode::COPY_APPEND && !options.GetFilenamePattern().HasUUID()) { + throw BinderException("APPEND mode requires a {uuid} label in filename_pattern"); + } + if (options.UserSetUseTmpFile() && options.PerThreadOutput()) { + throw NotImplementedException("Can't combine USE_TMP_FILE and PER_THREAD_OUTPUT for COPY"); + } + if (options.UserSetUseTmpFile() && options.Rotate()) { + throw NotImplementedException("Can't combine USE_TMP_FILE and FILE_SIZE_BYTES/BATCHES_PER_FILE for COPY"); + } + if (options.UserSetUseTmpFile() && options.Partitioned()) { + throw NotImplementedException("Can't combine USE_TMP_FILE and PARTITION_BY for COPY"); + } + if (options.PerThreadOutput() && options.Partitioned()) { + throw NotImplementedException("Can't combine PER_THREAD_OUTPUT and PARTITION_BY for COPY"); + } + if (options.Rotate() && (!function.prepare_batch || !function.flush_batch)) { + throw NotImplementedException("Can't use file rotation (e.g., ROW_GROUPS_PER_FILE) with FORMAT %s", + function.name); + } + if (!options.WriteEmptyFile()) { + if (options.PerThreadOutput()) { + throw NotImplementedException("Can't combine WRITE_EMPTY_FILE false with PER_THREAD_OUTPUT"); + } + if (options.Partitioned()) { + throw NotImplementedException("Can't combine WRITE_EMPTY_FILE false with PARTITION_BY"); + } + } + if (options.ReturnType() == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS && + !function.copy_to_get_written_statistics) { + throw NotImplementedException("RETURN_STATS is not supported for the \"%s\" copy format", format); + } + if (!options.order_columns.empty() && 
!options.Partitioned()) { + throw NotImplementedException("ORDER_BY is not supported without PARTITION_BY"); + } +} + +static void ValidateCopyToOutputColumns(const CopyToResolvedOptions &options, idx_t column_count) { + if (!options.write_partition_columns && options.partition_cols.size() == column_count) { + throw NotImplementedException("No column to write as all columns are specified as partition columns. " + "WRITE_PARTITION_COLUMNS option can be used to write partition columns."); + } +} + +BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &function, CopyToType copy_to_type) { + if (function.plan) { + // plan rewrite COPY TO + return function.plan(*this, stmt); + } + + auto ©_info = *stmt.info; + // bind the select statement + // preserve SQLNULL types from table functions (e.g. JSON reader null columns) + // so file writers that support it can emit the correct null type (e.g. parquet UNKNOWN/NullType) + auto node_copy = copy_info.select_statement->Copy(); + auto select_node = Bind(*node_copy); + + if (!function.copy_to_bind) { + throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); + } + + CopyToParsedOptions parsed_options; CopyFunctionBindInput bind_input(*stmt.info, function.function_info); bind_input.file_extension = function.extension; @@ -150,51 +311,50 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct for (auto &option : original_options) { auto loption = StringUtil::Lower(option.first); if (loption == "use_tmp_file") { - use_tmp_file = GetBooleanArg(context, option.second); - user_set_use_tmp_file = true; + parsed_options.use_tmp_file = GetBooleanArg(context, option.second); } else if (loption == "overwrite_or_ignore" || loption == "overwrite" || loption == "append") { - if (seen_overwrite_mode) { + if (parsed_options.overwrite_mode.has_value()) { throw BinderException("Can only set one of OVERWRITE_OR_IGNORE, OVERWRITE or APPEND"); } - 
seen_overwrite_mode = true; auto boolean = GetBooleanArg(context, option.second); if (boolean) { if (loption == "overwrite_or_ignore") { - overwrite_mode = CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE; + parsed_options.overwrite_mode = CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE; } else if (loption == "overwrite") { - overwrite_mode = CopyOverwriteMode::COPY_OVERWRITE; + parsed_options.overwrite_mode = CopyOverwriteMode::COPY_OVERWRITE; } else if (loption == "append") { - if (!seen_filepattern) { - filename_pattern.SetFilenamePattern("{uuid}"); + if (!parsed_options.filename_pattern.has_value()) { + parsed_options.SetFilenamePattern("{uuid}"); } - overwrite_mode = CopyOverwriteMode::COPY_APPEND; + parsed_options.overwrite_mode = CopyOverwriteMode::COPY_APPEND; } + } else { + parsed_options.overwrite_mode = CopyOverwriteMode::COPY_ERROR_ON_CONFLICT; } } else if (loption == "filename_pattern") { if (option.second.empty()) { throw IOException("FILENAME_PATTERN cannot be empty"); } - filename_pattern.SetFilenamePattern( + parsed_options.SetFilenamePattern( option.second[0].CastAs(context, LogicalType::VARCHAR).GetValue()); - seen_filepattern = true; } else if (loption == "file_extension") { if (option.second.empty()) { throw IOException("FILE_EXTENSION cannot be empty"); } bind_input.file_extension = option.second[0].CastAs(context, LogicalType::VARCHAR).GetValue(); } else if (loption == "per_thread_output") { - per_thread_output = GetBooleanArg(context, option.second); + parsed_options.per_thread_output = GetBooleanArg(context, option.second); } else if (loption == "batch_size" || loption == "row_group_size") { if (option.second.empty()) { throw BinderException("BATCH_SIZE/ROW_GROUP_SIZE cannot be empty"); } - batch_size = option.second[0].GetValue(); + parsed_options.batch_size = option.second[0].GetValue(); } else if (loption == "batch_size_bytes" || loption == "row_group_size_bytes") { if (option.second.empty()) { throw 
BinderException("BATCH_SIZE_BYTES/ROW_GROUP_SIZE_BYTES cannot be empty"); } - batch_size_bytes = ParseBytesArg(loption, option.second[0]); + parsed_options.batch_size_bytes = ParseBytesArg(loption, option.second[0]); } else if (loption == "file_size_bytes") { if (option.second.empty()) { throw BinderException("FILE_SIZE_BYTES cannot be empty"); @@ -202,74 +362,44 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct if (!function.file_size_bytes) { throw BinderException("FILE_SIZE_BYTES not implemented for %s", function.name); } - file_size_bytes = ParseBytesArg(loption, option.second[0]); + parsed_options.file_size_bytes = ParseBytesArg(loption, option.second[0]); } else if (loption == "batches_per_file" || loption == "row_groups_per_file") { if (option.second.empty()) { throw BinderException("BATCHES_PER_FILE/ROW_GROUPS_PER_FILE cannot be empty"); } - batches_per_file = option.second[0].GetValue(); + parsed_options.batches_per_file = option.second[0].GetValue(); } else if (loption == "partition_by") { auto converted = ConvertVectorToValue(std::move(option.second)); - partition_cols = ParseColumnsOrdered(converted, select_node.names, loption); + parsed_options.partition_cols = ParseColumnsOrdered(converted, select_node.names, loption); + } else if (loption == "order_by") { + parsed_options.order_columns = ParseOrderByColumns(*this, option.second, select_node, loption); } else if (loption == "return_files") { if (GetBooleanArg(context, option.second)) { - return_type = CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST; + parsed_options.SetReturnType(CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST); } } else if (loption == "preserve_order") { if (GetBooleanArg(context, option.second)) { - preserve_order = PreserveOrderType::PRESERVE_ORDER; + parsed_options.preserve_order = PreserveOrderType::PRESERVE_ORDER; } else { - preserve_order = PreserveOrderType::DONT_PRESERVE_ORDER; + parsed_options.preserve_order = 
PreserveOrderType::DONT_PRESERVE_ORDER; } } else if (loption == "return_stats") { if (GetBooleanArg(context, option.second)) { - return_type = CopyFunctionReturnType::WRITTEN_FILE_STATISTICS; + parsed_options.SetReturnType(CopyFunctionReturnType::WRITTEN_FILE_STATISTICS); } } else if (loption == "write_partition_columns") { - write_partition_columns = GetBooleanArg(context, option.second); + parsed_options.write_partition_columns = GetBooleanArg(context, option.second); } else if (loption == "write_empty_file") { - write_empty_file = GetBooleanArg(context, option.second); + parsed_options.write_empty_file = GetBooleanArg(context, option.second); } else if (loption == "hive_file_pattern") { - hive_file_pattern = GetBooleanArg(context, option.second); + parsed_options.hive_file_pattern = GetBooleanArg(context, option.second); } else { stmt.info->options[option.first] = option.second; } } - if (overwrite_mode == CopyOverwriteMode::COPY_APPEND && !filename_pattern.HasUUID()) { - throw BinderException("APPEND mode requires a {uuid} label in filename_pattern"); - } - if (user_set_use_tmp_file && per_thread_output) { - throw NotImplementedException("Can't combine USE_TMP_FILE and PER_THREAD_OUTPUT for COPY"); - } - if (user_set_use_tmp_file && (file_size_bytes.IsValid() || batches_per_file.IsValid())) { - throw NotImplementedException("Can't combine USE_TMP_FILE and FILE_SIZE_BYTES/BATCHES_PER_FILE for COPY"); - } - if (user_set_use_tmp_file && !partition_cols.empty()) { - throw NotImplementedException("Can't combine USE_TMP_FILE and PARTITION_BY for COPY"); - } - if (per_thread_output && !partition_cols.empty()) { - throw NotImplementedException("Can't combine PER_THREAD_OUTPUT and PARTITION_BY for COPY"); - } - if ((file_size_bytes.IsValid() || batches_per_file.IsValid()) && !partition_cols.empty()) { - throw NotImplementedException("Can't combine FILE_SIZE_BYTES and PARTITION_BY for COPY"); - } - if (!write_partition_columns) { - if (partition_cols.size() == 
select_node.names.size()) { - throw NotImplementedException("No column to write as all columns are specified as partition columns. " - "WRITE_PARTITION_COLUMNS option can be used to write partition columns."); - } - } - bool is_remote_file = FileSystem::IsRemoteFile(stmt.info->file_path); - if (is_remote_file) { - use_tmp_file = false; - } else { - auto &fs = FileSystem::GetFileSystem(context); - bool is_file_and_exists = fs.FileExists(stmt.info->file_path); - bool is_stdout = stmt.info->file_path == "/dev/stdout"; - if (!user_set_use_tmp_file) { - use_tmp_file = is_file_and_exists && !per_thread_output && partition_cols.empty() && !is_stdout; - } - } + + ValidateCopyToOptionCombinations(parsed_options, function, stmt.info->format); + auto resolved_options = ResolveCopyToOptions(context, stmt.info->file_path, std::move(parsed_options)); // Allow the copy function to intercept the select list and types and push a new projection on top of the plan if (function.copy_to_select) { @@ -308,59 +438,35 @@ BoundStatement Binder::BindCopyTo(CopyStatement &stmt, const CopyFunction &funct QueryResult::DeduplicateColumns(unique_column_names); auto file_path = stmt.info->file_path; - auto names_to_write = - LogicalCopyToFile::GetNamesWithoutPartitions(unique_column_names, partition_cols, write_partition_columns); - auto types_to_write = - LogicalCopyToFile::GetTypesWithoutPartitions(select_node.types, partition_cols, write_partition_columns); - auto function_data = function.copy_to_bind(context, bind_input, names_to_write, types_to_write); + ValidateCopyToOutputColumns(resolved_options, select_node.names.size()); - const auto rotate = file_size_bytes.IsValid() || batches_per_file.IsValid(); - if (rotate) { - if (user_set_use_tmp_file) { - throw NotImplementedException( - "Can't combine USE_TMP_FILE and file rotation (e.g., ROW_GROUPS_PER_FILE) for COPY"); - } - if (!partition_cols.empty()) { - throw NotImplementedException( - "Can't combine file rotation (e.g., 
ROW_GROUPS_PER_FILE) and PARTITION_BY for COPY"); - } - if (!function.prepare_batch || !function.flush_batch) { - throw NotImplementedException( - "Can't use file rotation (e.g., ROW_GROUPS_PER_FILE) and PARTITION_BY with FORMAT %s", function.name); - } - } - if (!write_empty_file) { - if (per_thread_output) { - throw NotImplementedException("Can't combine WRITE_EMPTY_FILE false with PER_THREAD_OUTPUT"); - } - if (!partition_cols.empty()) { - throw NotImplementedException("Can't combine WRITE_EMPTY_FILE false with PARTITION_BY"); - } - } - if (return_type == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS && !function.copy_to_get_written_statistics) { - throw NotImplementedException("RETURN_STATS is not supported for the \"%s\" copy format", stmt.info->format); - } + auto names_to_write = LogicalCopyToFile::GetNamesWithoutPartitions( + unique_column_names, resolved_options.partition_cols, resolved_options.write_partition_columns); + auto types_to_write = LogicalCopyToFile::GetTypesWithoutPartitions( + select_node.types, resolved_options.partition_cols, resolved_options.write_partition_columns); + auto function_data = function.copy_to_bind(context, bind_input, names_to_write, types_to_write); // now create the copy information auto copy = make_uniq(function, std::move(function_data), std::move(stmt.info)); copy->file_path = file_path; - copy->use_tmp_file = use_tmp_file; - copy->overwrite_mode = overwrite_mode; - copy->filename_pattern = filename_pattern; + copy->use_tmp_file = resolved_options.use_tmp_file; + copy->overwrite_mode = resolved_options.overwrite_mode; + copy->filename_pattern = resolved_options.filename_pattern; copy->file_extension = bind_input.file_extension; - copy->per_thread_output = per_thread_output; - copy->batch_size = batch_size; - copy->batch_size_bytes = batch_size_bytes; - copy->batches_per_file = batches_per_file; - copy->file_size_bytes = file_size_bytes; - copy->rotate = rotate; - copy->partition_output = !partition_cols.empty(); - 
copy->write_partition_columns = write_partition_columns; - copy->partition_columns = std::move(partition_cols); - copy->write_empty_file = write_empty_file; - copy->return_type = return_type; - copy->preserve_order = preserve_order; - copy->hive_file_pattern = hive_file_pattern; + copy->per_thread_output = resolved_options.per_thread_output; + copy->batch_size = resolved_options.batch_size; + copy->batch_size_bytes = resolved_options.batch_size_bytes; + copy->batches_per_file = resolved_options.batches_per_file; + copy->file_size_bytes = resolved_options.file_size_bytes; + copy->rotate = resolved_options.Rotate(); + copy->partition_output = resolved_options.Partitioned(); + copy->write_partition_columns = resolved_options.write_partition_columns; + copy->partition_columns = std::move(resolved_options.partition_cols); + copy->write_empty_file = resolved_options.write_empty_file; + copy->return_type = resolved_options.return_type; + copy->preserve_order = resolved_options.preserve_order; + copy->hive_file_pattern = resolved_options.hive_file_pattern; + copy->order_columns = std::move(resolved_options.order_columns); copy->names = unique_column_names; copy->expected_types = select_node.types; @@ -429,20 +535,20 @@ BoundStatement Binder::BindCopyFrom(CopyStatement &stmt, const CopyFunction &fun for (auto &col : table.GetColumns().Physical()) { auto i = col.Physical(); if (bound_insert.column_index_map[i] != DConstants::INVALID_INDEX) { - expected_names[bound_insert.column_index_map[i]] = col.Name(); + expected_names[bound_insert.column_index_map[i]] = col.Name().GetIdentifierName(); } } } else { expected_names.reserve(bound_insert.expected_types.size()); for (auto &col : table.GetColumns().Physical()) { - expected_names.push_back(col.Name()); + expected_names.emplace_back(col.Name()); } } auto copy_from_function = function.copy_from_function; CopyFromFunctionBindInput input(*stmt.info, copy_from_function); auto function_data = function.copy_from_bind(context, input, 
expected_names, bound_insert.expected_types); auto get = make_uniq(GenerateTableIndex(), std::move(copy_from_function), std::move(function_data), - bound_insert.expected_types, expected_names); + bound_insert.expected_types, StringsToIdentifiers(expected_names)); for (idx_t i = 0; i < bound_insert.expected_types.size(); i++) { get->AddColumnId(i); } @@ -460,8 +566,8 @@ vector BindCopyOption(ClientContext &context, TableFunctionBinder &option if (expr->GetExpressionType() == ExpressionType::STAR) { auto &star = expr->Cast(); // for compatibility with previous copy implementation - turn a raw * into a * string literal - if (star.relation_name.empty() && star.exclude_list.empty() && star.replace_list.empty() && - star.rename_list.empty() && !star.expr && !star.columns) { + if (star.RelationName().empty() && star.ExcludeList().empty() && star.ReplaceList().empty() && + star.RenameList().empty() && !star.Expression() && !star.IsColumns()) { result.push_back("*"); return result; } @@ -574,13 +680,15 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { stmt.info->is_format_auto_detected ? 
OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; CatalogEntryRetriever entry_retriever {context}; auto &catalog = Catalog::GetSystemCatalog(context); - auto entry = catalog.GetEntry(entry_retriever, DEFAULT_SCHEMA, - {CatalogType::COPY_FUNCTION_ENTRY, stmt.info->format}, on_entry_do); + auto entry = + catalog.GetEntry(entry_retriever, Identifier::DefaultSchema(), + EntryLookupInfo(CatalogType::COPY_FUNCTION_ENTRY, Identifier(stmt.info->format)), on_entry_do); if (!entry) { IsFormatExtensionKnown(stmt.info->format); // If we did not find an entry, we default to a CSV - entry = catalog.GetEntry(entry_retriever, DEFAULT_SCHEMA, {CatalogType::COPY_FUNCTION_ENTRY, "csv"}, + entry = catalog.GetEntry(entry_retriever, Identifier::DefaultSchema(), + EntryLookupInfo(CatalogType::COPY_FUNCTION_ENTRY, "csv"), OnEntryNotFound::THROW_EXCEPTION); } auto ©_function = entry->Cast(); @@ -609,8 +717,8 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { // check if this matches the mode if (copy_option.mode != CopyOptionMode::READ_WRITE && copy_option.mode != copy_mode) { throw InvalidInputException("Option \"%s\" is not supported for %s - only for %s", provided_option, - std::string(stmt.info->is_from ? "reading" : "writing"), - std::string(stmt.info->is_from ? "writing" : "reading")); + stmt.info->is_from ? "reading" : "writing", + stmt.info->is_from ? 
"writing" : "reading"); } if (copy_option.type.id() != LogicalTypeId::ANY) { if (provided_entry.second.empty()) { @@ -626,6 +734,10 @@ BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { provided_option); } auto &original_value = provided_entry.second[0]; + if (original_value.IsNull()) { + throw BinderException("NULL is not supported as a valid option for COPY option \"%s\"", + provided_option); + } if (copy_option.type == original_value.type()) { // types match continue; diff --git a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp index 88ea22b23..705b6ede6 100644 --- a/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_copy_database.cpp @@ -21,7 +21,8 @@ namespace duckdb { -unique_ptr Binder::BindCopyDatabaseSchema(Catalog &from_database, const string &target_database_name) { +unique_ptr Binder::BindCopyDatabaseSchema(Catalog &from_database, + const Identifier &target_database_name) { catalog_entry_vector_t catalog_entries; catalog_entries = PhysicalExport::GetNaiveExportOrder(context, from_database); @@ -46,7 +47,8 @@ unique_ptr Binder::BindCopyDatabaseSchema(Catalog &from_databas return make_uniq(std::move(info)); } -unique_ptr Binder::BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name) { +unique_ptr Binder::BindCopyDatabaseData(Catalog &source_catalog, + const Identifier &target_database_name) { auto source_schemas = source_catalog.GetSchemas(context); // We can just use ExtractEntries here because the order doesn't matter diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index 0be43a08d..cdb6f4824 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -26,6 +26,9 @@ #include 
"duckdb/parser/parsed_data/create_view_info.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/statement/create_statement.hpp" +#include "duckdb/parser/common_table_expression_info.hpp" +#include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/tableref/basetableref.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" @@ -51,7 +54,20 @@ namespace duckdb { -void Binder::BindSchemaOrCatalog(CatalogEntryRetriever &retriever, string &catalog, string &schema) { +static unique_ptr MakeTriggerValidationCTE(const TableCatalogEntry &table) { + auto alias_select = make_uniq(); + alias_select->select_list.push_back(make_uniq()); + auto alias_table_ref = make_uniq(); + alias_table_ref->table_name = table.name; + alias_table_ref->schema_name = table.schema.name; + alias_table_ref->catalog_name = table.catalog.GetName(); + alias_select->from_table = std::move(alias_table_ref); + auto alias_cte = make_uniq(); + alias_cte->query_node = std::move(alias_select); + return alias_cte; +} + +void Binder::BindSchemaOrCatalog(CatalogEntryRetriever &retriever, Identifier &catalog, Identifier &schema) { auto &context = retriever.GetContext(); if (schema.empty()) { return; @@ -72,7 +88,7 @@ void Binder::BindSchemaOrCatalog(CatalogEntryRetriever &retriever, string &catal auto &search_path = retriever.GetSearchPath(); auto catalog_names = search_path.GetCatalogsForSchema(schema); if (catalog_names.empty()) { - catalog_names.push_back(DatabaseManager::GetDefaultDatabase(context)); + catalog_names.emplace_back(DatabaseManager::GetDefaultDatabase(context)); } for (auto &catalog_name : catalog_names) { auto catalog_ptr = Catalog::GetCatalogEntry(retriever, catalog_name); @@ -81,24 +97,24 @@ void Binder::BindSchemaOrCatalog(CatalogEntryRetriever &retriever, string &catal } if (catalog_ptr->CheckAmbiguousCatalogOrSchema(context, schema)) { throw 
BinderException( - "Ambiguous reference to catalog or schema \"%s\" - use a fully qualified path like \"%s.%s\"", schema, - catalog_name, schema); + "Ambiguous reference to catalog or schema \"%s\" - use a fully qualified path like \"%s.%s\"", + schema.GetIdentifierName(), catalog_name.GetIdentifierName(), schema.GetIdentifierName()); } } catalog = schema; - schema = string(); + schema = Identifier(); } -void Binder::BindSchemaOrCatalog(ClientContext &context, string &catalog, string &schema) { +void Binder::BindSchemaOrCatalog(ClientContext &context, Identifier &catalog, Identifier &schema) { CatalogEntryRetriever retriever(context); BindSchemaOrCatalog(retriever, catalog, schema); } -void Binder::BindSchemaOrCatalog(string &catalog, string &schema) { +void Binder::BindSchemaOrCatalog(Identifier &catalog, Identifier &schema) { BindSchemaOrCatalog(context, catalog, schema); } -const string Binder::BindCatalog(string &catalog) { +Identifier Binder::BindCatalog(const Identifier &catalog) { auto &db_manager = DatabaseManager::Get(context); optional_ptr database = db_manager.GetDatabase(context, catalog); if (database) { @@ -111,7 +127,7 @@ const string Binder::BindCatalog(string &catalog) { void Binder::SearchSchema(CreateInfo &info) { BindSchemaOrCatalog(info.catalog, info.schema); if (IsInvalidCatalog(info.catalog) && info.temporary) { - info.catalog = TEMP_CATALOG; + info.catalog = Identifier::TempCatalog(); } auto &search_path = ClientData::Get(context).catalog_search_path; if (IsInvalidCatalog(info.catalog) && IsInvalidSchema(info.schema)) { @@ -119,9 +135,9 @@ void Binder::SearchSchema(CreateInfo &info) { info.catalog = default_entry.catalog; info.schema = default_entry.schema; } else if (IsInvalidSchema(info.schema)) { - info.schema = search_path->GetDefaultSchema(context, info.catalog); + info.schema = Identifier(search_path->GetDefaultSchema(context, info.catalog)); } else if (IsInvalidCatalog(info.catalog)) { - info.catalog = 
search_path->GetDefaultCatalog(info.schema); + info.catalog = Identifier(search_path->GetDefaultCatalog(info.schema)); } if (IsInvalidCatalog(info.catalog)) { info.catalog = DatabaseManager::GetDefaultDatabase(context); @@ -129,11 +145,11 @@ void Binder::SearchSchema(CreateInfo &info) { if (!info.temporary) { // non-temporary create: not read only if (info.catalog == TEMP_CATALOG) { - throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", std::string(TEMP_CATALOG)); + throw ParserException("Only TEMPORARY table names can use the \"%s\" catalog", TEMP_CATALOG); } } else { if (info.catalog != TEMP_CATALOG) { - throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", std::string(TEMP_CATALOG)); + throw ParserException("TEMPORARY table names can *only* use the \"%s\" catalog", TEMP_CATALOG); } } } @@ -163,9 +179,10 @@ void Binder::SetCatalogLookupCallback(catalog_entry_callback_t callback) { entry_retriever.SetCallback(std::move(callback)); } -void Binder::BindView(ClientContext &context, const SelectStatement &stmt, const string &catalog_name, - const string &schema_name, optional_ptr dependencies, - const vector &aliases, vector &result_types, vector &result_names) { +void Binder::BindView(ClientContext &context, const SelectStatement &stmt, const Identifier &catalog_name, + const Identifier &schema_name, optional_ptr dependencies, + const vector &aliases, vector &result_types, + vector &result_names) { auto view_binder = Binder::CreateBinder(context); auto &catalog = Catalog::GetCatalog(context, catalog_name); @@ -178,7 +195,7 @@ void Binder::BindView(ClientContext &context, const SelectStatement &stmt, const dependencies->AddDependency(entry); }); } - view_binder->can_contain_nulls = true; + view_binder->SetCanContainNulls(true); auto view_search_path = view_binder->GetSearchPath(catalog, schema_name); view_binder->entry_retriever.SetSearchPath(std::move(view_search_path)); @@ -255,7 +272,7 @@ SchemaCatalogEntry 
&Binder::BindCreateFunctionInfo(CreateInfo &info) { if (attached.HasStorageManager()) { // If DuckDB is used as a storage, we must check the version. auto &storage_manager = attached.GetStorageManager(); - const auto since = SerializationCompatibility::FromString("v1.4.0").serialization_version; + const auto since = StorageCompatibility::FromString("v1.4.0").storage_version; store_types = info.temporary || attached.IsTemporary() || storage_manager.InMemory() || storage_manager.GetStorageVersion() >= since; } @@ -338,20 +355,20 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { } vector dummy_types; - vector dummy_names; + vector dummy_names; // positional parameters for (idx_t param_idx = 0; param_idx < function->parameters.size(); param_idx++) { dummy_types.emplace_back(function->types.empty() ? LogicalType::UNKNOWN : function->types[param_idx]); - dummy_names.push_back(function->parameters[param_idx]->Cast().GetColumnName()); + dummy_names.emplace_back(function->parameters[param_idx]->Cast().GetColumnName()); } if (!type_overloads.insert(dummy_types).second) { throw BinderException( "Ambiguity in macro overloads - macro %s() has multiple definitions with the same parameters", - base.name); + base.name.GetIdentifierName()); } - auto this_macro_binding = make_uniq(dummy_types, dummy_names, base.name); + auto this_macro_binding = make_uniq(dummy_types, dummy_names, base.name.GetIdentifierName()); macro_binding = this_macro_binding.get(); auto &dependencies = base.dependencies; @@ -369,8 +386,7 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { ErrorData error; if (info.type == CatalogType::MACRO_ENTRY) { BoundSelectNode sel_node; - BoundGroupInformation group_info; - SelectBinder binder(*this, context, sel_node, group_info); + SelectBinder binder(*this, context, sel_node); if (should_create_dependencies) { binder.SetCatalogLookupCallback(binder_callback); } @@ -432,7 +448,7 @@ LogicalType 
Binder::BindLogicalTypeInternal(const unique_ptr & // Shortcut for constant expressions if (expr->GetExpressionClass() == ExpressionClass::BOUND_CONSTANT) { auto &const_expr = expr->Cast(); - return TypeValue::GetType(const_expr.value); + return TypeValue::GetType(const_expr.GetValue()); } // Else, evaluate the type expression @@ -488,7 +504,7 @@ SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trig auto &attached = catalog.GetAttached(); if (attached.HasStorageManager()) { auto &storage_manager = attached.GetStorageManager(); - const auto since = SerializationCompatibility::FromString("v2.0.0").serialization_version; + const auto since = StorageVersion::V2_0_0; if (!create_trigger_info.temporary && !attached.IsTemporary() && !storage_manager.InMemory() && storage_manager.GetStorageVersion() < since) { string msg = "CREATE TRIGGER is only supported for storage versions v2.0.0 and higher.\n"; @@ -508,24 +524,41 @@ SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trig if (create_trigger_info.for_each == TriggerForEach::ROW) { throw NotImplementedException("FOR EACH ROW triggers are not yet supported"); } - - if (create_trigger_info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) { - table.ScanTriggers(table.ParentCatalog().GetCatalogTransaction(context), [&](CatalogEntry &entry) { - auto &t = entry.Cast(); - if (t.timing == create_trigger_info.timing && t.event_type == create_trigger_info.event_type) { - throw NotImplementedException("Multiple triggers per table event are not yet supported"); - } - }); + if ((!create_trigger_info.referencing_new_table.empty() || !create_trigger_info.referencing_old_table.empty()) && + create_trigger_info.timing != TriggerTiming::AFTER) { + throw BinderException("Transition tables can only be specified for AFTER triggers"); + } + if ((!create_trigger_info.referencing_new_table.empty() || !create_trigger_info.referencing_old_table.empty()) && + 
!create_trigger_info.columns.empty()) { + throw BinderException("UPDATE OF is not valid with transition tables"); + } + if (!create_trigger_info.referencing_old_table.empty()) { + if (create_trigger_info.event_type == TriggerEventType::INSERT_EVENT) { + throw BinderException("REFERENCING OLD TABLE AS is not valid for AFTER INSERT triggers"); + } + if (create_trigger_info.event_type == TriggerEventType::UPDATE_EVENT) { + throw NotImplementedException("REFERENCING OLD TABLE AS is not yet supported for AFTER UPDATE triggers"); + } + } + if (!create_trigger_info.referencing_new_table.empty() && + create_trigger_info.event_type == TriggerEventType::DELETE_EVENT) { + throw BinderException("REFERENCING NEW TABLE AS is not valid for AFTER DELETE triggers"); } // Validate the trigger body using an isolated binder (own GlobalBinderState). // Set up trigger_expanded_tables to match runtime behavior. - // Set up and trigger_creation_table to detect recursive triggers during the validation. + // Set up trigger_creation_table to detect recursive triggers during the validation. 
auto validation_binder = Binder::CreateBinder(context); validation_binder->global_binder_state->trigger_expanded_tables.insert(table); validation_binder->global_binder_state->trigger_creation_table = &table; validation_binder->global_binder_state->trigger_creation_name = create_trigger_info.trigger_name; auto body_copy = create_trigger_info.trigger_action->Copy(); + + for (const auto &alias : {create_trigger_info.referencing_new_table, create_trigger_info.referencing_old_table}) { + if (!alias.empty() && body_copy->cte_map.map.find(alias) == body_copy->cte_map.map.end()) { + body_copy->cte_map.map[alias] = MakeTriggerValidationCTE(table); + } + } validation_binder->Bind(*body_copy); // Add table dependency @@ -746,8 +779,14 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { bound_options.insert({option.first, ExpressionExecutor::EvaluateScalar(context, *bound_value, true)}); } - CreateSecretInput create_secret_input {type_string, provider_string, info.storage_type, info.name, - scope_strings, bound_options, info.on_conflict, info.persist_type}; + CreateSecretInput create_secret_input {Identifier(type_string), + Identifier(provider_string), + Identifier(info.storage_type), + info.name, + scope_strings, + bound_options, + info.on_conflict, + info.persist_type}; result = SecretManager::Get(context).BindCreateSecret(transaction, create_secret_input); break; diff --git a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp index cc6e5cd6d..33888ff47 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp @@ -58,8 +58,9 @@ static void VerifyCompressionType(ClientContext &context, optional_ptr( - context, INVALID_CATALOG, INVALID_SCHEMA, logical_type.GetAlias(), OnEntryNotFound::RETURN_NULL); + const auto type_entry = + Catalog::GetEntry(context, Identifier::InvalidCatalog(), Identifier::InvalidSchema(), + 
Identifier(logical_type.GetAlias()), OnEntryNotFound::RETURN_NULL); if (type_entry) { logical_type = type_entry->user_type; } @@ -79,7 +80,7 @@ static void VerifyCompressionType(ClientContext &context, optional_ptr> Binder::BindConstraints(ClientContext &context, const vector> &constraints, - const string &table_name, const ColumnList &columns) { + const Identifier &table_name, const ColumnList &columns) { auto binder = Binder::CreateBinder(context); return binder->BindConstraints(constraints, table_name, columns); } @@ -89,7 +90,7 @@ vector> Binder::BindConstraints(const TableCatalogEn } vector> Binder::BindConstraints(const vector> &constraints, - const string &table_name, const ColumnList &columns) { + const Identifier &table_name, const ColumnList &columns) { vector> bound_constraints; for (const auto &constr : constraints) { bound_constraints.push_back(BindConstraint(*constr, table_name, columns)); @@ -98,7 +99,8 @@ vector> Binder::BindConstraints(const vector> Binder::BindNewConstraints(vector> &constraints, - const string &table_name, const ColumnList &columns) { + const Identifier &table_name, + const ColumnList &columns) { auto bound_constraints = BindConstraints(constraints, table_name, columns); // Handle PK and NOT NULL constraints. 
@@ -145,7 +147,7 @@ vector> Binder::BindNewConstraints(vector BindCheckConstraint(Binder &binder, const Constraint &constraint, const string &table, +unique_ptr BindCheckConstraint(Binder &binder, const Constraint &constraint, const Identifier &table, const ColumnList &columns) { auto bound_constraint = make_uniq(); auto &bound_check = bound_constraint->Cast(); @@ -162,7 +164,7 @@ unique_ptr BindCheckConstraint(Binder &binder, const Constraint return std::move(bound_constraint); } -unique_ptr Binder::BindUniqueConstraint(const Constraint &constraint, const string &table, +unique_ptr Binder::BindUniqueConstraint(const Constraint &constraint, const Identifier &table, const ColumnList &columns) { auto &unique = constraint.Cast(); @@ -225,7 +227,7 @@ unique_ptr BindForeignKey(const Constraint &constraint) { return make_uniq(fk.info, std::move(pk_key_set), std::move(fk_key_set)); } -unique_ptr Binder::BindConstraint(const Constraint &constraint, const string &table, +unique_ptr Binder::BindConstraint(const Constraint &constraint, const Identifier &table, const ColumnList &columns) { switch (constraint.type) { case ConstraintType::CHECK: { @@ -250,12 +252,12 @@ unique_ptr Binder::BindConstraint(const Constraint &constraint, void Binder::BindGeneratedColumns(BoundCreateTableInfo &info) { auto &base = info.base->Cast(); - vector names; + vector names; vector types; D_ASSERT(base.type == CatalogType::TABLE_ENTRY); for (auto &col : base.columns.Logical()) { - names.push_back(col.Name()); + names.emplace_back(col.Name()); types.push_back(col.Type()); } auto table_index = GenerateTableIndex(); @@ -325,7 +327,7 @@ void Binder::BindDefaultValues(const ColumnList &columns, vector bound_default; @@ -360,7 +362,7 @@ unique_ptr Binder::BindCreateTableCheckpoint(unique_ptr &gcols, +static void ExpressionContainsGeneratedColumn(const ParsedExpression &root_expr, const identifier_set_t &gcols, bool &contains_gcol) { ParsedExpressionIterator::VisitExpression(root_expr, [&](const 
ColumnRefExpression &column_ref) { @@ -373,7 +375,7 @@ static void ExpressionContainsGeneratedColumn(const ParsedExpression &root_expr, } static bool AnyConstraintReferencesGeneratedColumn(CreateTableInfo &table_info) { - unordered_set generated_columns; + identifier_set_t generated_columns; for (auto &col : table_info.columns.Logical()) { if (!col.Generated()) { continue; @@ -430,13 +432,13 @@ static bool AnyConstraintReferencesGeneratedColumn(CreateTableInfo &table_info) return false; } -static void FindForeignKeyIndexes(const ColumnList &columns, const vector &names, +static void FindForeignKeyIndexes(const ColumnList &columns, const vector &names, vector &indexes) { D_ASSERT(indexes.empty()); D_ASSERT(!names.empty()); for (auto &name : names) { if (!columns.ColumnExists(name)) { - throw BinderException("column \"%s\" named in key does not exist", name); + throw BinderException("column \"%s\" named in key does not exist", name.GetIdentifierName()); } auto &column = columns.GetColumn(name); if (column.Generated()) { @@ -463,9 +465,9 @@ static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vecto } found_constraint = true; - vector pk_names; + vector pk_names; if (unique.HasIndex()) { - pk_names.push_back(columns.GetColumn(LogicalIndex(unique.GetIndex())).Name()); + pk_names.emplace_back(columns.GetColumn(LogicalIndex(unique.GetIndex())).Name()); } else { pk_names = unique.GetColumnNames(); } @@ -487,7 +489,7 @@ static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vecto } bool equals = true; for (idx_t i = 0; i < fk.pk_columns.size(); i++) { - if (!StringUtil::CIEquals(fk.pk_columns[i], pk_names[i])) { + if (fk.pk_columns[i] != pk_names[i]) { equals = false; break; } @@ -503,7 +505,7 @@ static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vecto // no unique constraint or primary key string search_term = find_primary_key ? 
"primary key" : "primary key or unique constraint"; throw BinderException("Failed to create foreign key: there is no %s for referenced table \"%s\"", search_term, - fk.info.table); + fk.info.table.GetIdentifierName()); } // check if all the columns exist for (auto &name : fk.pk_columns) { @@ -511,13 +513,13 @@ static void FindMatchingPrimaryKeyColumns(const ColumnList &columns, const vecto if (!found) { throw BinderException( "Failed to create foreign key: referenced table \"%s\" does not have a column named \"%s\"", - fk.info.table, name); + fk.info.table.GetIdentifierName(), name); } } auto fk_names = StringUtil::Join(fk.pk_columns, ","); throw BinderException("Failed to create foreign key: referenced table \"%s\" does not have a primary key or unique " "constraint on the columns %s", - fk.info.table, fk_names); + fk.info.table.GetIdentifierName(), fk_names); } static void CheckForeignKeyTypes(const ColumnList &pk_columns, const ColumnList &fk_columns, ForeignKeyConstraint &fk) { @@ -554,7 +556,7 @@ static void BindCreateTableConstraints(CreateTableInfo &create_info, CatalogEntr FindForeignKeyIndexes(create_info.columns, fk.fk_columns, fk.info.fk_keys); // Resolve the self-reference. - if (StringUtil::CIEquals(create_info.table, fk.info.table)) { + if (create_info.table == fk.info.table) { fk.info.type = ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE; FindMatchingPrimaryKeyColumns(create_info.columns, create_info.constraints, fk); FindForeignKeyIndexes(create_info.columns, fk.pk_columns, fk.info.pk_keys); @@ -564,10 +566,12 @@ static void BindCreateTableConstraints(CreateTableInfo &create_info, CatalogEntr // Resolve the table reference in the same catalog/schema as the table being // created, so FK references work for external catalogs (not just the default). - string fk_catalog = fk.info.schema.empty() ? schema.ParentCatalog().GetName() : INVALID_CATALOG; - string fk_schema = fk.info.schema.empty() ? 
schema.name : fk.info.schema; + Identifier fk_catalog = + fk.info.schema.empty() ? schema.ParentCatalog().GetName() : Identifier::InvalidCatalog(); + string fk_schema = + fk.info.schema.empty() ? schema.name.GetIdentifierName() : fk.info.schema.GetIdentifierName(); EntryLookupInfo table_lookup(CatalogType::TABLE_ENTRY, fk.info.table); - auto table_entry = entry_retriever.GetEntry(fk_catalog, fk_schema, table_lookup); + auto table_entry = entry_retriever.GetEntry(fk_catalog, Identifier(fk_schema), table_lookup); if (table_entry->type == CatalogType::VIEW_ENTRY) { throw BinderException("cannot reference a VIEW with a FOREIGN KEY"); } @@ -635,12 +639,12 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptr Binder::BindCreateTableInfo(unique_ptrCast(); - string catalog_name = base_table_ref.catalog_name; - string schema_name = base_table_ref.schema_name; + Identifier catalog_name = base_table_ref.catalog_name; + Identifier schema_name = base_table_ref.schema_name; BindSchemaOrCatalog(catalog_name, schema_name); // IF EXISTS only guards the trigger, not the table (PostgreSQL-compatible behavior). 
auto &table_entry = @@ -94,7 +94,7 @@ BoundStatement Binder::Bind(DropStatement &stmt) { break; } if (entry->internal) { - throw CatalogException("Cannot drop internal catalog entry \"%s\"!", entry->name); + throw CatalogException("Cannot drop internal catalog entry \"%s\"!", entry->name.GetIdentifierName()); } stmt.info->catalog = entry->ParentCatalog().GetName(); if (!entry->temporary) { diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp index 36d19eb7c..95001cf84 100644 --- a/src/duckdb/src/planner/binder/statement/bind_execute.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -31,7 +31,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { auto &mapped_named_values = stmt.named_values; // bind any supplied parameters - case_insensitive_map_t bind_values; + identifier_map_t bind_values; auto constant_binder = Binder::CreateBinder(context); constant_binder->SetCanContainNulls(true); for (auto &pair : mapped_named_values) { @@ -47,11 +47,11 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { StringType::GetCollation(constant.GetReturnType()).empty()) { return_type = LogicalTypeId::STRING_LITERAL; } else if (constant.GetReturnType().IsIntegral()) { - return_type = LogicalType::INTEGER_LITERAL(constant.value); + return_type = LogicalType::INTEGER_LITERAL(constant.GetValueMutable()); } else { - return_type = constant.value.type(); + return_type = constant.GetValueMutable().type(); } - parameter_data = BoundParameterData(std::move(constant.value), std::move(return_type)); + parameter_data = BoundParameterData(std::move(constant.GetValueMutable()), std::move(return_type)); } else { auto value = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true); auto value_type = value.type(); diff --git a/src/duckdb/src/planner/binder/statement/bind_export.cpp b/src/duckdb/src/planner/binder/statement/bind_export.cpp index 8fa7e673a..3446be63f 100644 --- 
a/src/duckdb/src/planner/binder/statement/bind_export.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_export.cpp @@ -23,12 +23,12 @@ namespace duckdb { //! Sanitizes a string to have only low case chars and underscores -string SanitizeExportIdentifier(const string &str) { +string SanitizeExportIdentifier(const Identifier &str) { // Copy the original string to result - string result(str); + string result(str.GetIdentifierName()); - for (idx_t i = 0; i < str.length(); ++i) { - auto c = str[i]; + for (idx_t i = 0; i < result.length(); ++i) { + auto c = result[i]; if (c >= 'a' && c <= 'z') { // If it is lower case just continue continue; @@ -46,10 +46,10 @@ string SanitizeExportIdentifier(const string &str) { return result; } -bool ReferencedTableIsOrdered(string &referenced_table, catalog_entry_vector_t &ordered) { +bool ReferencedTableIsOrdered(const Identifier &referenced_table, catalog_entry_vector_t &ordered) { for (auto &entry : ordered) { auto &table_entry = entry.get().Cast(); - if (StringUtil::CIEquals(table_entry.name, referenced_table)) { + if (table_entry.name == referenced_table) { // The referenced table is already ordered return true; } @@ -157,8 +157,8 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { BindCopyOptions(*stmt.info); // lookup the format in the catalog - auto ©_function = - Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); + auto ©_function = Catalog::GetEntry( + context, Identifier::InvalidCatalog(), Identifier::DefaultSchema(), Identifier(stmt.info->format)); if (!copy_function.function.copy_to_bind && !copy_function.function.plan) { throw NotImplementedException("COPY TO is not supported for FORMAT \"%s\"", stmt.info->format); } @@ -225,26 +225,26 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { id++; } info->is_from = false; - info->catalog = catalog; + info->catalog = Identifier(catalog); info->schema = table.schema.name; info->table = table.name; // We can not export generated columns 
child_list_t select_list; // Let's verify if any on these columns have not null constraints - vector not_null_columns; + vector not_null_columns; for (auto &constraint : table.GetConstraints()) { if (constraint->type == ConstraintType::NOT_NULL) { auto ¬_null_constraint = constraint->Cast(); - not_null_columns.push_back(table.GetColumn(not_null_constraint.index).GetName()); + not_null_columns.emplace_back(table.GetColumn(not_null_constraint.index).GetName().GetIdentifierName()); } } for (auto &col : table.GetColumns().Physical()) { - select_list.push_back(std::make_pair(col.Name(), col.Type())); + select_list.emplace_back(std::make_pair(col.Name(), col.Type())); } ExportedTableData exported_data; - exported_data.database_name = catalog; + exported_data.database_name = Identifier(catalog); exported_data.table_name = info->table; exported_data.schema_name = info->schema; @@ -274,7 +274,7 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { fs.CreateDirectory(stmt.info->file_path); } - stmt.info->catalog = catalog; + stmt.info->catalog = Identifier(catalog); // prepare the options for export auto &format = stmt.info->format; auto &options = stmt.info->options; diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index 8617edfbc..3f42d5b5b 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -6,6 +6,7 @@ #include "duckdb/parser/statement/insert_statement.hpp" #include "duckdb/parser/query_node/insert_query_node.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" #include "duckdb/planner/binder.hpp" @@ -32,7 +33,7 @@ namespace duckdb { void Binder::CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool 
columns_provided, - const string &tname) { + const Identifier &tname) { if (result_columns != expected_columns) { string msg = StringUtil::Format(!columns_provided ? "table %s has %lld columns but %lld values were supplied" : "Column name/value mismatch for insert on %s: " @@ -93,30 +94,30 @@ void Binder::ExpandDefaultInValuesList(InsertQueryNode &node, TableCatalogEntry } void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, - vector> &lambda_params); + vector &lambda_params); void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, - vector> &lambda_params) { - for (auto &child : function.children) { - if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { - DoUpdateSetQualify(child, table_name, lambda_params); + vector &lambda_params) { + for (auto &child : function.GetArgumentsMutable()) { + if (child.GetExpression().GetExpressionClass() != ExpressionClass::LAMBDA) { + DoUpdateSetQualify(child.GetExpressionMutable(), table_name, lambda_params); continue; } // Special-handling for LHS lambda parameters. // We do not qualify them, and we add them to the lambda_params vector. - auto &lambda_expr = child->Cast(); + auto &lambda_expr = child.GetExpressionMutable()->Cast(); string error_message; auto column_ref_expressions = lambda_expr.ExtractColumnRefExpressions(error_message); if (!error_message.empty()) { // Possibly a JSON function, qualify both LHS and RHS. 
- ParsedExpressionIterator::EnumerateChildren(*lambda_expr.lhs, [&](unique_ptr &child) { - DoUpdateSetQualify(child, table_name, lambda_params); - }); - ParsedExpressionIterator::EnumerateChildren(*lambda_expr.expr, [&](unique_ptr &child) { - DoUpdateSetQualify(child, table_name, lambda_params); - }); + ParsedExpressionIterator::EnumerateChildren( + *lambda_expr.LeftMutable(), + [&](unique_ptr &child) { DoUpdateSetQualify(child, table_name, lambda_params); }); + ParsedExpressionIterator::EnumerateChildren( + *lambda_expr.RightMutable(), + [&](unique_ptr &child) { DoUpdateSetQualify(child, table_name, lambda_params); }); continue; } @@ -128,16 +129,16 @@ void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &tabl } // Only qualify in the RHS of the expression. - ParsedExpressionIterator::EnumerateChildren(*lambda_expr.expr, [&](unique_ptr &child) { - DoUpdateSetQualify(child, table_name, lambda_params); - }); + ParsedExpressionIterator::EnumerateChildren( + *lambda_expr.RightMutable(), + [&](unique_ptr &child) { DoUpdateSetQualify(child, table_name, lambda_params); }); lambda_params.pop_back(); } } void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, - vector> &lambda_params) { + vector &lambda_params) { // We avoid ambiguity with EXCLUDED columns by qualifying all column references. switch (expr->GetExpressionClass()) { case ExpressionClass::COLUMN_REF: { @@ -152,7 +153,7 @@ void DoUpdateSetQualify(unique_ptr &expr, const string &table_ } // Qualify the column reference. 
- expr = make_uniq(col_ref.GetColumnName(), table_name); + expr = make_uniq(col_ref.GetColumnName(), Identifier(table_name)); return; } case ExpressionClass::FUNCTION: { @@ -198,7 +199,7 @@ unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, Inse if (conflict_columns.count(column.Oid())) { continue; } - columns.push_back(column.Name()); + columns.emplace_back(column.Name()); } } else { // a list of columns was explicitly supplied, only update those @@ -219,14 +220,14 @@ unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, Inse return set_info; } -void Binder::BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, +void Binder::BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, vector &named_column_map, vector &expected_types, IndexVector &column_index_map) { if (!columns.empty() || default_values) { // insertion statement specifies column list // create a mapping of (list index) -> (column index) - case_insensitive_map_t column_name_map; + identifier_map_t column_name_map; for (idx_t i = 0; i < columns.size(); i++) { auto entry = column_name_map.insert(make_pair(columns[i], i)); if (!entry.second) { @@ -263,7 +264,7 @@ void Binder::BindInsertColumnList(TableCatalogEntry &table, vector &colu } } -static unordered_set GetConflictColumnNames(const TableStorageInfo &storage_info, TableCatalogEntry &table) { +static identifier_set_t GetConflictColumnNames(const TableStorageInfo &storage_info, TableCatalogEntry &table) { unordered_set conflict_column_ids; for (auto &index : storage_info.index_info) { if (!index.is_unique) { @@ -273,7 +274,7 @@ static unordered_set GetConflictColumnNames(const TableStorageInfo &stor conflict_column_ids.insert(col_id); } } - unordered_set conflict_column_names; + identifier_set_t conflict_column_names; for (auto &col : table.GetColumns().Physical()) { if (conflict_column_ids.count(col.Physical().index)) { conflict_column_names.insert(col.Name()); @@ -288,13 
+289,14 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, auto &on_conflict_info = *node.on_conflict_info; auto merge_into = make_uniq(); // set up the target table - string table_name = !node.table_ref->alias.empty() ? node.table_ref->alias : node.table; - merge_into->target = std::move(node.table_ref); + string table_name = + !node.table_ref->alias.empty() ? node.table_ref->alias.GetIdentifierName() : node.table.GetIdentifierName(); + merge_into->node->target = std::move(node.table_ref); auto storage_info = table.GetStorageInfo(context); auto &columns = table.GetColumns(); // set up the columns on which to join - vector> all_distinct_on_columns; + vector> all_distinct_on_columns; if (on_conflict_info.indexed_columns.empty()) { // When omitting the conflict target, we derive the join columns from the primary key/unique constraints // traverse the primary key/unique constraints @@ -309,17 +311,17 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, vector> and_children; auto &indexed_columns = index.column_set; - vector distinct_on_columns; + vector distinct_on_columns; for (auto &column : columns.Physical()) { if (!indexed_columns.count(column.Physical().index)) { continue; } - auto lhs = make_uniq(column.Name(), table_name); + auto lhs = make_uniq(column.Name(), Identifier(table_name)); auto rhs = make_uniq(column.Name(), "excluded"); auto new_condition = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(lhs), std::move(rhs)); and_children.push_back(std::move(new_condition)); - distinct_on_columns.push_back(column.Name()); + distinct_on_columns.emplace_back(column.Name()); } all_distinct_on_columns.push_back(std::move(distinct_on_columns)); if (and_children.empty()) { @@ -343,7 +345,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, join_condition = make_uniq(ExpressionType::CONJUNCTION_OR, std::move(join_conditions)); } - merge_into->join_condition = std::move(join_condition); + merge_into->node->join_condition = 
std::move(join_condition); if (!found_matching_indexes) { throw BinderException("There are no UNIQUE/PRIMARY KEY constraints that refer to this table, specify ON " @@ -359,7 +361,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, } else { // when on conflict columns are explicitly provided - use them directly // first figure out if there is an index on the columns or not - case_insensitive_map_t specified_columns; + identifier_map_t specified_columns; for (idx_t i = 0; i < on_conflict_info.indexed_columns.size(); i++) { specified_columns[on_conflict_info.indexed_columns[i]] = i; auto column_index = table.GetColumnIndex(on_conflict_info.indexed_columns[i]); @@ -396,8 +398,8 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, throw BinderException("The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY " "CONSTRAINT or INDEX"); } - all_distinct_on_columns.push_back(on_conflict_info.indexed_columns); - merge_into->using_columns = std::move(on_conflict_info.indexed_columns); + all_distinct_on_columns.emplace_back(on_conflict_info.indexed_columns); + merge_into->node->using_columns = on_conflict_info.indexed_columns; } // expand any default values @@ -433,7 +435,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, // now push another subquery that adds the default columns auto select_stmt = make_uniq(); auto select_node = make_uniq(); - unordered_set set_columns; + identifier_set_t set_columns; for (auto &set_col : node.columns) { set_columns.insert(set_col); } @@ -461,7 +463,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, } // push all columns of the table as an alias for (auto &column : columns.Physical()) { - source->column_name_alias.push_back(column.Name()); + source->column_name_alias.emplace_back(column.Name()); } } // push DISTINCT ON(unique_columns) @@ -479,7 +481,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, source = make_uniq(std::move(distinct_stmt), 
"excluded"); } - merge_into->source = std::move(source); + merge_into->node->source = std::move(source); if (on_conflict_info.action_type == OnConflictAction::REPLACE) { D_ASSERT(!on_conflict_info.set_info); @@ -496,7 +498,7 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, insert_action->action_type = MergeActionType::MERGE_INSERT; insert_action->column_order = node.column_order; - merge_into->actions[MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET].push_back(std::move(insert_action)); + merge_into->node->actions[MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET].push_back(std::move(insert_action)); if (on_conflict_info.condition) { throw BinderException("ON CONFLICT WHERE clause is only supported in DO UPDATE SET ... WHERE ...\nThe WHERE " @@ -510,23 +512,22 @@ unique_ptr Binder::GenerateMergeInto(InsertQueryNode &node, update_action->column_order = node.column_order; if (on_conflict_info.set_info) { for (auto &col : on_conflict_info.set_info->expressions) { - vector> lambda_params; + vector lambda_params; DoUpdateSetQualify(col, table_name, lambda_params); } if (on_conflict_info.set_info->condition) { - vector> lambda_params; + vector lambda_params; DoUpdateSetQualify(on_conflict_info.set_info->condition, table_name, lambda_params); update_action->condition = std::move(on_conflict_info.set_info->condition); } update_action->update_info = std::move(on_conflict_info.set_info); } - merge_into->actions[MergeActionCondition::WHEN_MATCHED].push_back(std::move(update_action)); + merge_into->node->actions[MergeActionCondition::WHEN_MATCHED].push_back(std::move(update_action)); } // move over extra properties - merge_into->cte_map = std::move(node.cte_map); - merge_into->returning_list = std::move(node.returning_list); + merge_into->node->returning_list = std::move(node.returning_list); return merge_into; } @@ -542,7 +543,7 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { BindSchemaOrCatalog(node.catalog, node.schema); auto &table = 
Catalog::GetEntry(context, node.catalog, node.schema, node.table); - if (auto expanded = TryExpandAfterTriggers(node, node.returning_list, table, TriggerEventType::INSERT_EVENT)) { + if (auto expanded = TryExpandTriggers(node, node.returning_list, table, TriggerEventType::INSERT_EVENT)) { return std::move(*expanded); } @@ -592,7 +593,8 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { // bind the default values auto &catalog_name = table.ParentCatalog().GetName(); auto &schema_name = table.ParentSchema().name; - BindDefaultValues(table.GetColumns(), insert->bound_defaults, catalog_name, schema_name); + BindDefaultValues(table.GetColumns(), insert->bound_defaults, catalog_name.GetIdentifierName(), + schema_name.GetIdentifierName()); insert->bound_constraints = BindConstraints(table); if (!node.select_statement && !node.default_values) { result.plan = std::move(insert); @@ -625,8 +627,9 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { insert->table_index = insert_table_index; unique_ptr index_as_logicaloperator = std::move(insert); - return BindReturning(std::move(node.returning_list), table, node.table_ref ? node.table_ref->alias : string(), - insert_table_index, std::move(index_as_logicaloperator)); + return BindReturning(std::move(node.returning_list), table, + node.table_ref ? 
node.table_ref->alias : Identifier(), insert_table_index, + std::move(index_as_logicaloperator)); } D_ASSERT(result.types.size() == result.names.size()); diff --git a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp index a9c491299..ad7dbe43b 100644 --- a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp @@ -21,7 +21,7 @@ BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { BoundStatement result; result.types = stmt.plan->types; for (idx_t i = 0; i < result.types.size(); i++) { - result.names.push_back(StringUtil::Format("col%d", i)); + result.names.emplace_back(StringUtil::Format("col%d", i)); } result.plan = std::move(stmt.plan); diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index 976ae112f..acc3cfdf6 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -1,5 +1,6 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" @@ -20,30 +21,8 @@ namespace duckdb { -//! Validates that column references in MERGE action expressions are valid for the match context. 
-static void ValidateMergeColumns(const Expression &expr, MergeActionCondition condition, TableIndex target_table_index, - const unordered_set &source_table_indices) { - if (condition == MergeActionCondition::WHEN_MATCHED) { - return; - } - - ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { - bool is_target_column = (colref.binding.table_index == target_table_index); - bool is_source_column = source_table_indices.count(colref.binding.table_index.index) > 0; - - if (condition == MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET && is_target_column) { - throw BinderException("Target column '%s' cannot be referenced in a WHEN NOT MATCHED BY TARGET clause", - colref.GetAlias().empty() ? colref.ToString() : colref.GetAlias()); - } - if (condition == MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE && is_source_column) { - throw BinderException("Source column '%s' cannot be referenced in a WHEN NOT MATCHED BY SOURCE clause", - colref.GetAlias().empty() ? colref.ToString() : colref.GetAlias()); - } - }); -} - vector> GenerateColumnReferences(Binder &binder, const vector &aliases, - const vector &names) { + const vector &names) { vector> result; D_ASSERT(aliases.size() == names.size()); @@ -57,28 +36,15 @@ vector> GenerateColumnReferences(Binder &binder, co unique_ptr Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, LogicalGet &get, TableIndex proj_index, - vector> &expressions, unique_ptr &root, - MergeIntoAction &action, const vector &source_aliases, - const vector &source_names, MergeActionCondition condition, - const unordered_set &source_table_indices) { + vector> &expressions, MergeIntoAction &action, + const vector &source_aliases, const vector &source_names) { auto result = make_uniq(); result->action_type = action.action_type; - auto expr_start_idx = expressions.size(); if (action.condition) { - if (action.condition->HasSubquery()) { - // if we have a subquery we need to execute the condition 
outside of the MERGE INTO statement - WhereBinder where_binder(*this, context); - auto cond = where_binder.Bind(action.condition); - PlanSubqueries(cond, root); - auto cond_type = cond->GetReturnType(); - auto cond_idx = ColumnBinding::PushExpression(expressions, std::move(cond)); - result->condition = make_uniq(cond_type, ColumnBinding(proj_index, cond_idx)); - } else { - ProjectionBinder proj_binder(*this, context, proj_index, expressions, "WHERE clause"); - proj_binder.target_type = LogicalType::BOOLEAN; - auto cond = proj_binder.Bind(action.condition); - result->condition = std::move(cond); - } + ProjectionBinder proj_binder(*this, context, proj_index, expressions, "WHERE clause"); + proj_binder.target_type = LogicalType::BOOLEAN; + auto cond = proj_binder.Bind(action.condition); + result->condition = std::move(cond); } switch (action.action_type) { case MergeActionType::MERGE_UPDATE: { @@ -107,7 +73,9 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, action.update_info->expressions = GenerateColumnReferences(*this, source_aliases, source_names); } } - BindUpdateSet(proj_index, root, *action.update_info, table, result->columns, result->expressions, expressions); + unique_ptr fake_root; + BindUpdateSet(proj_index, fake_root, *action.update_info, table, result->columns, result->expressions, + expressions); // bind any additional columns that need to be bound for update constraints // FIXME: this is pretty hacky @@ -164,7 +132,6 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, TryReplaceDefaultExpression(action.expressions[i], column); auto insert_expr = insert_binder.Bind(action.expressions[i]); - PlanSubqueries(insert_expr, root); insert_expressions.push_back(std::move(insert_expr)); } @@ -193,12 +160,6 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, default: throw InternalException("Unsupported merge action type"); } - - for (idx_t i = expr_start_idx; i < 
expressions.size(); i++) { - if (expressions[i]) { - ValidateMergeColumns(*expressions[i], condition, get.table_index, source_table_indices); - } - } return result; } @@ -207,9 +168,9 @@ void RewriteMergeBindings(unique_ptr &expr, const vector( expr, [&](BoundColumnRefExpression &bound_colref, unique_ptr &expr) { for (idx_t i = 0; i < source_bindings.size(); i++) { - if (bound_colref.binding == source_bindings[i]) { - bound_colref.binding.table_index = new_table_index; - bound_colref.binding.column_index = ProjectionIndex(i); + if (bound_colref.Binding() == source_bindings[i]) { + bound_colref.BindingMutable().table_index = new_table_index; + bound_colref.BindingMutable().column_index = ProjectionIndex(i); } } }); @@ -236,10 +197,14 @@ void CheckMergeAction(MergeActionCondition condition, MergeActionType action_typ } BoundStatement Binder::Bind(MergeIntoStatement &stmt) { + return Bind(*stmt.node); +} + +BoundStatement Binder::BindNode(MergeQueryNode &node) { // bind the target table auto target_binder = Binder::CreateBinder(context, this); - string table_alias = stmt.target->alias; - auto bound_table = target_binder->Bind(*stmt.target); + auto table_alias = node.target->alias; + auto bound_table = target_binder->Bind(*node.target); if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("Can only merge into base tables!"); } @@ -262,7 +227,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { auto &properties = GetStatementProperties(); // modification type depends on actions DatabaseModificationType modification; - for (auto &action_condition : stmt.actions) { + for (auto &action_condition : node.actions) { for (auto &action : action_condition.second) { switch (action->action_type) { case MergeActionType::MERGE_UPDATE: @@ -284,21 +249,40 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { // bind the source auto source_binder = Binder::CreateBinder(context, this); - auto source_binding = 
source_binder->Bind(*stmt.source); + auto source_binding = source_binder->Bind(*node.source); // get the source names/types and collect source table indices for validation vector source_aliases; - vector source_names; - unordered_set source_table_indices; + vector source_names; for (auto &binding_entry : source_binder->bind_context.GetBindingsList()) { auto &binding = *binding_entry; - source_table_indices.insert(binding.GetIndex().index); auto &column_names = binding.GetColumnNames(); for (idx_t c = 0; c < column_names.size(); c++) { source_aliases.push_back(binding.GetBindingAlias()); source_names.push_back(column_names[c]); } } + // bind the WHEN NOT MATCHED BY SOURCE / TARGET merge actions + auto &get = bound_table.plan->Cast(); + auto merge_into = make_uniq(table); + merge_into->table_index = GenerateTableIndex(); + auto proj_index = GenerateTableIndex(); + vector> projection_expressions; + + for (auto &entry : node.actions) { + if (entry.first == MergeActionCondition::WHEN_MATCHED) { + continue; + } + auto &action_binder = + entry.first == MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET ? 
*source_binder : *target_binder; + vector> bound_actions; + for (auto &action : entry.second) { + CheckMergeAction(entry.first, action->action_type); + bound_actions.push_back(action_binder.BindMergeAction( + *merge_into, table, get, proj_index, projection_expressions, *action, source_aliases, source_names)); + } + merge_into->actions.emplace(entry.first, std::move(bound_actions)); + } // bind the join between the source and target // our conditions determine the join type we need @@ -307,8 +291,8 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { // if we have both -> FULL join // if we only have WHEN MATCHED we only need matches -> INNER join JoinRef join; - auto has_not_matched_by_source = stmt.actions.count(MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE) > 0; - auto has_not_matched_by_target = stmt.actions.count(MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET) > 0; + auto has_not_matched_by_source = node.actions.count(MergeActionCondition::WHEN_NOT_MATCHED_BY_SOURCE) > 0; + auto has_not_matched_by_target = node.actions.count(MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET) > 0; if (has_not_matched_by_source && has_not_matched_by_target) { join.type = JoinType::OUTER; } else if (has_not_matched_by_source) { @@ -318,13 +302,12 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { } else { join.type = JoinType::INNER; } - auto &get = bound_table.plan->Cast(); join.left = make_uniq(std::move(source_binding), std::move(source_binder)); join.right = make_uniq(std::move(bound_table), std::move(target_binder)); - if (stmt.join_condition) { - join.condition = std::move(stmt.join_condition); + if (node.join_condition) { + join.condition = std::move(node.join_condition); } else { - join.using_columns = std::move(stmt.using_columns); + join.using_columns = std::move(node.using_columns); } auto bound_join_node = Bind(join); @@ -341,34 +324,37 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { bool inverted = join.type == JoinType::RIGHT; auto 
&source = join_ref.get().children[inverted ? 1 : 0]; - auto merge_into = make_uniq(table); - merge_into->table_index = GenerateTableIndex(); - if (!stmt.returning_list.empty()) { + if (!node.returning_list.empty()) { merge_into->return_chunk = true; } // bind table constraints/default values in case these are referenced auto &catalog_name = table.ParentCatalog().GetName(); auto &schema_name = table.ParentSchema().name; - BindDefaultValues(table.GetColumns(), merge_into->bound_defaults, catalog_name, schema_name); + BindDefaultValues(table.GetColumns(), merge_into->bound_defaults, catalog_name.GetIdentifierName(), + schema_name.GetIdentifierName()); merge_into->bound_constraints = BindConstraints(table); - // bind the merge actions - auto proj_index = GenerateTableIndex(); - vector> projection_expressions; - - for (auto &entry : stmt.actions) { + // bind WHEN_MATCHED merge actions (can contain references to both source and target) + for (auto &entry : node.actions) { + if (entry.first != MergeActionCondition::WHEN_MATCHED) { + continue; + } vector> bound_actions; for (auto &action : entry.second) { CheckMergeAction(entry.first, action->action_type); - bound_actions.push_back(BindMergeAction(*merge_into, table, get, proj_index, projection_expressions, root, - *action, source_aliases, source_names, entry.first, - source_table_indices)); + bound_actions.push_back(BindMergeAction(*merge_into, table, get, proj_index, projection_expressions, + *action, source_aliases, source_names)); } merge_into->actions.emplace(entry.first, std::move(bound_actions)); } + // plan merge action subqueries + for (auto &expr : projection_expressions) { + PlanSubqueries(expr, root); + } + if (has_not_matched_by_source) { // if we have "has_not_matched_by_source" we need to push an extra marker into the source // this marker tells us if we have found a source match or not @@ -426,7 +412,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { // so we can pass them through instead of 
fetching by row ID in PhysicalDelete. // Generated columns will be computed in the RETURNING projection by the binder. if (has_delete_action) { - if (!stmt.returning_list.empty()) { + if (!node.returning_list.empty()) { // Use the overloaded helper to add physical columns to the scan and build projection expressions auto &target_binding = join_ref.get().children[inverted ? 0 : 1]; BindDeleteReturningColumns(table, get, merge_into->delete_return_columns, projection_expressions, @@ -451,14 +437,14 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { merge_into->AddChild(std::move(proj)); - if (!stmt.returning_list.empty()) { + if (!node.returning_list.empty()) { auto merge_table_index = merge_into->table_index; unique_ptr index_as_logicaloperator = std::move(merge_into); // add the merge_action virtual column virtual_column_map_t virtual_columns; virtual_columns.insert(make_pair(VIRTUAL_COLUMN_START, TableColumn("merge_action", LogicalType::VARCHAR))); - return BindReturning(std::move(stmt.returning_list), table, table_alias, merge_table_index, + return BindReturning(std::move(node.returning_list), table, table_alias, merge_table_index, std::move(index_as_logicaloperator), std::move(virtual_columns)); } diff --git a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp index 5bd0db038..f100f29e1 100644 --- a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp @@ -29,12 +29,13 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte } // bind the pragma function - auto entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, - OnEntryNotFound::RETURN_NULL); + auto entry = Catalog::GetEntry( + context, Identifier::InvalidCatalog(), Identifier::DefaultSchema(), info.name, OnEntryNotFound::RETURN_NULL); if (!entry) { // try to find whether a table entry might exist - auto table_entry = 
Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, - info.name, OnEntryNotFound::RETURN_NULL); + auto table_entry = Catalog::GetEntry(context, Identifier::InvalidCatalog(), + Identifier::DefaultSchema(), info.name, + OnEntryNotFound::RETURN_NULL); if (table_entry) { // there is a table entry with the same name, now throw more explicit error message throw CatalogException("Pragma Function with name %s does not exist, but a table function with the same " @@ -42,7 +43,8 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte info.name, info.name); } // rebind to throw exception - entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + entry = Catalog::GetEntry(context, Identifier::InvalidCatalog(), + Identifier::DefaultSchema(), info.name, OnEntryNotFound::THROW_EXCEPTION); } diff --git a/src/duckdb/src/planner/binder/statement/bind_simple.cpp b/src/duckdb/src/planner/binder/statement/bind_simple.cpp index 31869ac1b..620c2db7e 100644 --- a/src/duckdb/src/planner/binder/statement/bind_simple.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_simple.cpp @@ -68,7 +68,7 @@ BoundStatement Binder::BindAlterAddIndex(BoundStatement &result, CatalogEntry &e throw BinderException("can only add an index to a base table"); } auto &get = bound_table.plan->Cast(); - get.names = column_list.GetColumnNames(); + get.names = StringsToIdentifiers(column_list.GetColumnNames()); auto alter_table_info = unique_ptr_cast(std::move(alter_info)); result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(bound_table.plan), diff --git a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp index f8a68ae4c..e81457d96 100644 --- a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp @@ -18,24 +18,24 @@ static unique_ptr SummarizeWrapUnnest(vector> unnest_children; 
unnest_children.push_back(std::move(list_function)); auto unnest_function = make_uniq("unnest", std::move(unnest_children)); - unnest_function->SetAlias(alias); + unnest_function->SetAlias(Identifier(alias)); return std::move(unnest_function); } -static unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name) { +static unique_ptr SummarizeCreateAggregate(const string &aggregate, Identifier column_name) { vector> children; children.push_back(make_uniq(std::move(column_name))); - auto aggregate_function = make_uniq(aggregate, std::move(children)); + auto aggregate_function = make_uniq(Identifier(aggregate), std::move(children)); auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); return std::move(cast_function); } -static unique_ptr SummarizeCreateAggregate(const string &aggregate, string column_name, +static unique_ptr SummarizeCreateAggregate(const string &aggregate, Identifier column_name, const Value &modifier) { vector> children; children.push_back(make_uniq(std::move(column_name))); children.push_back(make_uniq(modifier)); - auto aggregate_function = make_uniq(aggregate, std::move(children)); + auto aggregate_function = make_uniq(Identifier(aggregate), std::move(children)); auto cast_function = make_uniq(LogicalType::VARCHAR, std::move(aggregate_function)); return std::move(cast_function); } @@ -51,11 +51,11 @@ static unique_ptr SummarizeCreateBinaryFunction(const string & vector> children; children.push_back(std::move(left)); children.push_back(std::move(right)); - auto binary_function = make_uniq(op, std::move(children)); + auto binary_function = make_uniq(Identifier(op), std::move(children)); return std::move(binary_function); } -static unique_ptr SummarizeCreateNullPercentage(string column_name) { +static unique_ptr SummarizeCreateNullPercentage(Identifier column_name) { auto count_star = make_uniq(LogicalType::DOUBLE, SummarizeCreateCountStar()); auto count = make_uniq(LogicalType::DOUBLE, 
SummarizeCreateAggregate("count", std::move(column_name))); @@ -71,8 +71,8 @@ static unique_ptr SummarizeCreateNullPercentage(string column_ CaseCheck check; check.when_expr = std::move(comp_expr); check.then_expr = std::move(percentage_x); - case_expr->case_checks.push_back(std::move(check)); - case_expr->else_expr = make_uniq(Value()); + case_expr->CaseChecksMutable().push_back(std::move(check)); + case_expr->ElseMutable() = make_uniq(Value()); return make_uniq(LogicalType::DECIMAL(9, 2), std::move(case_expr)); } @@ -82,7 +82,7 @@ BoundStatement Binder::BindSummarize(ShowRef &ref) { if (ref.query) { query = std::move(ref.query); } else { - auto table_name = QualifiedName::Parse(ref.table_name); + auto table_name = QualifiedName::Parse(ref.table_name.GetIdentifierName()); auto node = make_uniq(); node->select_list.push_back(make_uniq()); auto basetableref = make_uniq(); diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 7d9ffeb6c..093559068 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -40,10 +40,12 @@ void Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &r if (!table.ColumnExists(colname)) { vector column_names; for (auto &col : table.GetColumns().Physical()) { - column_names.push_back(col.Name()); + column_names.emplace_back(col.Name().GetIdentifierName()); } - auto candidates = StringUtil::CandidatesErrorMessage(column_names, colname, "Did you mean"); - throw BinderException("Referenced update column %s not found in table!\n%s", colname, candidates); + auto candidates = + StringUtil::CandidatesErrorMessage(column_names, colname.GetIdentifierName(), "Did you mean"); + throw BinderException("Referenced update column %s not found in table!\n%s", colname.GetIdentifierName(), + candidates); } auto &column = table.GetColumn(colname); if (column.Generated()) { @@ -59,7 +61,9 @@ void 
Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &r UpdateBinder binder(*expr_binder_ptr, context); binder.target_type = column.Type(); auto bound_expr = binder.Bind(expr); - PlanSubqueries(bound_expr, root); + if (root) { + PlanSubqueries(bound_expr, root); + } auto bound_type = bound_expr->GetReturnType(); auto expr_index = ColumnBinding::PushExpression(projection_expressions, std::move(bound_expr)); @@ -136,7 +140,7 @@ BoundStatement Binder::BindNode(UpdateQueryNode &node) { } auto &table = *table_ptr; - if (auto expanded = TryExpandAfterTriggers(node, node.returning_list, table, TriggerEventType::UPDATE_EVENT)) { + if (auto expanded = TryExpandTriggers(node, node.returning_list, table, TriggerEventType::UPDATE_EVENT)) { return std::move(*expanded); } @@ -168,7 +172,8 @@ BoundStatement Binder::BindNode(UpdateQueryNode &node) { // bind the default values auto &catalog_name = table.ParentCatalog().GetName(); auto &schema_name = table.ParentSchema().name; - BindDefaultValues(table.GetColumns(), update->bound_defaults, catalog_name, schema_name); + BindDefaultValues(table.GetColumns(), update->bound_defaults, catalog_name.GetIdentifierName(), + schema_name.GetIdentifierName()); update->bound_constraints = BindConstraints(table); // project any additional columns required for the condition/expressions diff --git a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp index 026f682b0..f5303a171 100644 --- a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp @@ -34,12 +34,12 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr if (columns.empty()) { // Empty means ALL columns should be vacuumed/analyzed for (auto &col : table.GetColumns().Physical()) { - columns.push_back(col.GetName()); + columns.emplace_back(col.GetName()); } } - case_insensitive_set_t column_name_set; - vector non_generated_column_names; + 
identifier_set_t column_name_set; + vector non_generated_column_names; for (auto &col_name : columns) { if (column_name_set.count(col_name) > 0) { throw BinderException("cannot vacuum or analyze the same column twice, i.e., there is a duplicate entry in " @@ -56,7 +56,7 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr "cannot vacuum or analyze generated column \"%s\" - specify non-generated columns to vacuum or analyze", col.GetName()); } - non_generated_column_names.push_back(col_name); + non_generated_column_names.emplace_back(col_name); ColumnRefExpression colref(col_name, table.name); auto result = bind_context.BindColumn(colref, 0); if (result.HasError()) { @@ -64,7 +64,7 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } select_list.push_back(std::move(result.expression)); } - info.columns = std::move(non_generated_column_names); + info.columns = non_generated_column_names; auto &column_ids = get.GetColumnIds(); D_ASSERT(select_list.size() == column_ids.size()); diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index 8089cfb53..2a6bb2f0b 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -50,7 +50,8 @@ BoundStatement Binder::BindWithReplacementScan(ClientContext &context, BaseTable return BoundStatement(); } for (auto &scan : config.replacement_scans) { - ReplacementScanInput input(ref.catalog_name, ref.schema_name, ref.table_name); + ReplacementScanInput input(ref.catalog_name.GetIdentifierName(), ref.schema_name.GetIdentifierName(), + ref.table_name.GetIdentifierName()); auto replacement_function = scan.function(context, input, scan.data.get()); if (!replacement_function) { continue; @@ -69,12 +70,16 @@ BoundStatement Binder::BindWithReplacementScan(ClientContext &context, BaseTable auto &subquery = replacement_function->Cast(); 
subquery.column_name_alias = ref.column_name_alias; } else { + // carry the alias to the wrapping SubqueryRef so qualified references + // like `SELECT d.x FROM _ AS d` can resolve against the outer ref + auto inner_alias = replacement_function->alias; auto select_node = make_uniq(); select_node->select_list.push_back(make_uniq()); select_node->from_table = std::move(replacement_function); auto select_stmt = make_uniq(); select_stmt->node = std::move(select_node); auto subquery = make_uniq(std::move(select_stmt)); + subquery->alias = std::move(inner_alias); subquery->column_name_alias = ref.column_name_alias; replacement_function = std::move(subquery); } @@ -96,7 +101,7 @@ unique_ptr Binder::BindAtClause(optional_ptr at_clause) return make_uniq(at_clause->Unit(), std::move(val)); } -vector Binder::GetSearchPath(Catalog &catalog, const string &schema_name) { +vector Binder::GetSearchPath(Catalog &catalog, const Identifier &schema_name) { vector view_search_path; auto &catalog_name = catalog.GetName(); if (!schema_name.empty()) { @@ -104,14 +109,14 @@ vector Binder::GetSearchPath(Catalog &catalog, const string } auto default_schema = catalog.GetDefaultSchema(); if (schema_name.empty() && schema_name != default_schema) { - view_search_path.emplace_back(catalog_name, default_schema); + view_search_path.emplace_back(catalog_name, Identifier(default_schema)); } //! 
Signal that this catalog should be checked, regardless of the schema in the reference view_search_path.emplace_back(catalog_name, INVALID_SCHEMA); return view_search_path; } -void Binder::SetSearchPath(Catalog &catalog, const string &schema) { +void Binder::SetSearchPath(Catalog &catalog, const Identifier &schema) { auto search_path = GetSearchPath(catalog, schema); entry_retriever.SetSearchPath(std::move(search_path)); } @@ -143,8 +148,7 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { BoundStatement result; result.types = ctebinding->GetColumnTypes(); result.names = names; - result.plan = - make_uniq(index, ctebinding->GetIndex(), result.types, std::move(names), is_recurring); + result.plan = make_uniq(index, ctebinding->GetIndex(), result.types, names, is_recurring); return result; } @@ -163,14 +167,14 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { if (GetBindingMode() == BindingMode::EXTRACT_QUALIFIED_NAMES) { AddTableName(ref.ToString()); } else { - AddTableName(ref.table_name); + AddTableName(ref.table_name.GetIdentifierName()); } // add a bind context entry auto table_index = GenerateTableIndex(); auto ref_alias = ref.alias.empty() ? 
ref.table_name : ref.alias; vector types {LogicalType::INTEGER}; - vector names {"__dummy_col" + to_string(table_index.index)}; + vector names {Identifier("__dummy_col" + to_string(table_index.index))}; bind_context.AddGenericBinding(table_index, ref_alias, names, types); BoundStatement result; @@ -189,7 +193,9 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { } // Try autoloading an extension, then retry the replacement scan bind - auto full_path = ReplacementScan::GetFullPath(ref.catalog_name, ref.schema_name, ref.table_name); + auto full_path = + ReplacementScan::GetFullPath(ref.catalog_name.GetIdentifierName(), ref.schema_name.GetIdentifierName(), + ref.table_name.GetIdentifierName()); auto extension_loaded = TryLoadExtensionForReplacementScan(context, full_path); if (extension_loaded) { replacement_scan_bind_result = BindWithReplacementScan(context, ref); @@ -215,7 +221,7 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { throw BinderException(error_context, "Circular reference to CTE \"%s\", use WITH RECURSIVE to " "use recursive CTEs.", - ref.table_name); + ref.table_name.GetIdentifierName()); } // could not find an alternative: bind again to get the error // note: this will always throw when using DuckDB as a catalog, but a second look-up might succeed @@ -239,16 +245,16 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { } // TODO: bundle the type and name vector in a struct (e.g PackedColumnMetadata) vector table_types; - vector table_names; + vector table_names; vector table_categories; vector return_types; - vector return_names; + vector return_names; for (auto &col : table.GetColumns().Logical()) { table_types.push_back(col.Type()); - table_names.push_back(col.Name()); + table_names.emplace_back(col.Name()); return_types.push_back(col.Type()); - return_names.push_back(col.Name()); + return_names.emplace_back(col.Name()); } table_names = BindContext::AliasColumnNames(ref.table_name, table_names, ref.column_name_alias); @@ -281,7 +287,7 @@ 
BoundStatement Binder::Bind(BaseTableRef &ref) { // defined for this binder so there are no collisions between the CTEs defined // for the view and for the current query auto view_binder = Binder::CreateBinder(context, this, BinderType::VIEW_BINDER); - view_binder->can_contain_nulls = true; + view_binder->SetCanContainNulls(true); // The view may contain CTEs, but maybe only in the cte_map, so we need create CTE nodes for them auto query = view_catalog_entry.GetQuery().Copy(); diff --git a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp index d3c5ea4a2..2e9f3557b 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp @@ -11,7 +11,7 @@ BoundStatement Binder::Bind(ColumnDataRef &ref) { BoundStatement result; result.names = std::move(ref.expected_names); for (idx_t i = result.names.size(); i < types.size(); i++) { - result.names.push_back("col" + to_string(i + 1)); + result.names.emplace_back("col" + to_string(i + 1)); } result.types = types; auto bind_index = GenerateTableIndex(); diff --git a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp index 4e7ffe3a6..698cfbffb 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp @@ -11,10 +11,10 @@ BoundStatement Binder::Bind(DelimGetRef &ref) { BoundStatement result; result.types = std::move(ref.types); - result.names = std::move(ref.internal_aliases); + result.names = ref.internal_aliases; result.plan = make_uniq(tbl_idx, result.types); - bind_context.AddGenericBinding(tbl_idx, internal_name, result.names, result.types); + bind_context.AddGenericBinding(tbl_idx, Identifier(internal_name), result.names, result.types); return result; } diff --git 
a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp index 139f94670..dbc0fc7fd 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -14,7 +14,7 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { result.names = expr.expected_names; vector>> values; - auto prev_can_contain_nulls = this->can_contain_nulls; + auto prev_can_contain_nulls = CanContainNulls(); // bind value list InsertBinder binder(*this, context); binder.target_type = LogicalType(LogicalTypeId::INVALID); @@ -23,11 +23,11 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { if (result.names.empty()) { // no names provided, generate them for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - result.names.push_back("col" + to_string(val_idx)); + result.names.emplace_back("col" + to_string(val_idx)); } } - this->can_contain_nulls = true; + SetCanContainNulls(true); vector> list; for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { if (!result.types.empty()) { @@ -38,7 +38,7 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { list.push_back(std::move(bound_expr)); } values.push_back(std::move(list)); - this->can_contain_nulls = prev_can_contain_nulls; + this->SetCanContainNulls(prev_can_contain_nulls); } if (result.types.empty() && !expr.values.empty()) { // there are no types specified diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp index d255a3da0..29d66f8d2 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -10,13 +10,15 @@ #include "duckdb/parser/expression/star_expression.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/case_insensitive_map.hpp" +#include 
"duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression_binder/lateral_binder.hpp" +#include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" namespace duckdb { static unique_ptr BindColumn(Binder &binder, ClientContext &context, const BindingAlias &alias, - const string &column_name) { + const Identifier &column_name) { auto expr = make_uniq_base(column_name, alias); ExpressionBinder expr_binder(binder, context); auto result = expr_binder.Bind(expr); @@ -25,14 +27,14 @@ static unique_ptr BindColumn(Binder &binder, ClientContext &co static unique_ptr AddCondition(ClientContext &context, Binder &left_binder, Binder &right_binder, const BindingAlias &left_alias, const BindingAlias &right_alias, - const string &column_name, ExpressionType type) { + const Identifier &column_name, ExpressionType type) { ExpressionBinder expr_binder(left_binder, context); auto left = BindColumn(left_binder, context, left_alias, column_name); auto right = BindColumn(right_binder, context, right_alias, column_name); return make_uniq(type, std::move(left), std::move(right)); } -bool Binder::TryFindBinding(const string &using_column, const string &join_side, BindingAlias &result) { +bool Binder::TryFindBinding(const Identifier &using_column, const string &join_side, BindingAlias &result) { // for each using column, get the matching binding auto bindings = bind_context.GetMatchingBindings(using_column); if (bindings.empty()) { @@ -61,10 +63,11 @@ bool Binder::TryFindBinding(const string &using_column, const string &join_side, return true; } -BindingAlias Binder::FindBinding(const string &using_column, const string &join_side) { +BindingAlias Binder::FindBinding(const Identifier &using_column, const string &join_side) { BindingAlias result; if (!TryFindBinding(using_column, join_side, result)) { - throw BinderException("Column \"%s\" does not exist on %s side of join!", using_column, join_side); + throw 
BinderException("Column \"%s\" does not exist on %s side of join!", using_column.GetIdentifierName(), + join_side); } return result; } @@ -99,8 +102,35 @@ static void SetPrimaryBinding(UsingColumnSet &set, JoinType join_type, const Bin } } +static bool HasLateralCorrelatedColumns(const Expression &root_expr, Binder &binder) { + if (binder.IsInsideSubquery()) { + return false; + } + unordered_set lateral_bindings; + for (auto &active_binder_ref : binder.GetActiveBinders()) { + auto &active_binder = active_binder_ref.get(); + if (!active_binder.IsLateralBinder()) { + continue; + } + for (auto &binding : active_binder.GetBinder().bind_context.GetBindingsList()) { + lateral_bindings.insert(binding->GetIndex()); + } + } + if (lateral_bindings.empty()) { + return false; + } + bool has_lateral_correlated_columns = false; + ExpressionIterator::VisitExpression( + root_expr, [&](const BoundColumnRefExpression &colref) { + if (colref.Depth() > 0 && lateral_bindings.find(colref.Binding().table_index) != lateral_bindings.end()) { + has_lateral_correlated_columns = true; + } + }); + return has_lateral_correlated_columns; +} + BindingAlias Binder::RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, - const string &using_column, const string &join_side) { + const Identifier &using_column, const string &join_side) { BindingAlias binding; if (!current_set) { binding = current_binder.FindBinding(using_column, join_side); @@ -110,9 +140,9 @@ BindingAlias Binder::RetrieveUsingBinding(Binder ¤t_binder, optional_ptr RemoveDuplicateUsingColumns(const vector &using_columns) { - vector result; - case_insensitive_set_t handled_columns; +static vector RemoveDuplicateUsingColumns(const vector &using_columns) { + vector result; + identifier_set_t handled_columns; for (auto &using_column : using_columns) { if (handled_columns.find(using_column) == handled_columns.end()) { handled_columns.insert(using_column); @@ -141,7 +171,9 @@ BoundStatement Binder::Bind(JoinRef &ref) { 
result->delim_flipped = ref.delim_flipped; { - LateralBinder binder(left_binder, context); + LateralBinder lateral_binder(left_binder, context); + + right_binder.BeginSubqueryBind(left_binder, lateral_binder); result->right = right_binder.BindJoin(*this, *ref.right); if (!ref.duplicate_eliminated_columns.empty()) { if (ref.delim_flipped) { @@ -158,6 +190,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { } } } + right_binder.FinishSubqueryBind(); bool is_lateral = false; // Store the correlated columns in the right binder in bound ref for planning of LATERALs // Ignore the correlated columns in the left binder, flattening handles those correlations @@ -180,12 +213,12 @@ BoundStatement Binder::Bind(JoinRef &ref) { } vector> extra_conditions; - vector extra_using_columns; + vector extra_using_columns; switch (ref.ref_type) { case JoinRefType::NATURAL: { // natural join, figure out which column names are present in both sides of the join // first bind the left hand side and get a list of all the tables and column names - case_insensitive_set_t lhs_columns; + identifier_set_t lhs_columns; auto &lhs_binding_list = left_binder.bind_context.GetBindingsList(); for (auto &binding : lhs_binding_list) { for (auto &column_name : binding->GetColumnNames()) { @@ -335,6 +368,10 @@ BoundStatement Binder::Bind(JoinRef &ref) { if (ref.condition) { WhereBinder binder(*this, context); result->condition = binder.Bind(ref.condition); + if (result->type != JoinType::INNER && result->type != JoinType::LEFT && + HasLateralCorrelatedColumns(*result->condition, *this)) { + throw NotImplementedException("Non-inner join on correlated columns not supported"); + } } if (result->type == JoinType::SEMI || result->type == JoinType::ANTI || result->type == JoinType::MARK) { @@ -342,7 +379,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { if (result->type == JoinType::MARK) { auto mark_join_idx = GenerateTableIndex(); string mark_join_alias = "__internal_mark_join_ref" + to_string(mark_join_idx.index); 
- bind_context.AddGenericBinding(mark_join_idx, mark_join_alias, {"__mark_index_column"}, + bind_context.AddGenericBinding(mark_join_idx, Identifier(mark_join_alias), {"__mark_index_column"}, {LogicalType::BOOLEAN}); result->mark_index = mark_join_idx; } diff --git a/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp b/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp index 80afc9f33..cfb8ac9b6 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_named_parameters.cpp @@ -7,8 +7,17 @@ map order_case_insensitive_map(const case_insensitive_map_t &input return map(input_map.begin(), input_map.end()); } +template +map order_case_insensitive_map(const identifier_map_t &input_map) { + map result; + for (auto &entry : input_map) { + result[entry.first.GetIdentifierName()] = entry.second; + } + return result; +} + void Binder::BindNamedParameters(named_parameter_type_map_t &types, named_parameter_map_t &values, - QueryErrorContext &error_context, string &func_name) { + QueryErrorContext &error_context, const Identifier &func_name) { for (auto &kv : values) { auto entry = types.find(kv.first); if (entry == types.end()) { @@ -29,7 +38,7 @@ void Binder::BindNamedParameters(named_parameter_type_map_t &types, named_parame error_msg = "Candidates:\n" + named_params; } throw BinderException(error_context, "Invalid named parameter \"%s\" for function %s\n%s", kv.first, - func_name, error_msg); + func_name.GetIdentifierName(), error_msg); } if (entry->second.id() != LogicalTypeId::ANY) { kv.second = kv.second.DefaultCastAs(entry->second); diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index 66ffb5ddf..ec7b06f87 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -33,7 +33,7 @@ static void ConstructPivots(PivotRef &ref, vector 
&pivot_valu bool last_pivot = pivot_idx + 1 == ref.pivots.size(); for (auto &entry : pivot.entries) { PivotValueElement new_value = current_value; - string name = entry.alias; + string name = entry.alias.GetIdentifierName(); D_ASSERT(entry.values.size() == pivot.pivot_expressions.size()); for (idx_t v = 0; v < entry.values.size(); v++) { auto &value = entry.values[v]; @@ -60,13 +60,13 @@ static void ConstructPivots(PivotRef &ref, vector &pivot_valu } } -static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns, +static void ExtractPivotExpressions(ParsedExpression &root_expr, identifier_set_t &handled_columns, optional_ptr macro_binding) { ParsedExpressionIterator::VisitExpression( root_expr, [&](const ColumnRefExpression &child_colref) { if (child_colref.IsQualified()) { - if (child_colref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && macro_binding && - macro_binding->HasMatchingBinding(child_colref.GetName())) { + if (child_colref.ColumnNames()[0].GetIdentifierName().find(DummyBinding::DUMMY_NAME) != string::npos && + macro_binding && macro_binding->HasMatchingBinding(Identifier(child_colref.GetName()))) { throw ParameterNotResolvedException(); } throw BinderException(child_colref, "PIVOT expression cannot contain qualified columns"); @@ -105,8 +105,8 @@ void TemplatedHandlePivotAggregate(ClientContext &context, unique_ptrCast(); // check if this is an aggregate to ensure it is an aggregate and not a scalar function - EntryLookupInfo lookup_info(CatalogType::AGGREGATE_FUNCTION_ENTRY, aggr_function.function_name, *expr); - auto &entry = Catalog::GetEntry(context, aggr_function.catalog, aggr_function.schema, lookup_info); + EntryLookupInfo lookup_info(CatalogType::AGGREGATE_FUNCTION_ENTRY, aggr_function.FunctionName(), *expr); + auto &entry = Catalog::GetEntry(context, aggr_function.Catalog(), aggr_function.Schema(), lookup_info); if (entry.type == CatalogType::AGGREGATE_FUNCTION_ENTRY) { // 
aggregate OP::HandleAggregate(expr, aggr_function, aggregates); @@ -132,7 +132,7 @@ void ReplacePivotAggregateExpression(ClientContext &context, unique_ptr ConstructInitialGrouping(PivotRef &ref, vector> all_columns, - const case_insensitive_set_t &handled_columns) { + const identifier_set_t &handled_columns) { auto subquery = make_uniq(); subquery->from_table = std::move(ref.source); if (ref.groups.empty()) { @@ -163,7 +163,7 @@ static unique_ptr ConstructInitialGrouping(PivotRef &ref, vector PivotFilteredAggregate(ClientContext &context, PivotRef &ref, vector> all_columns, - const case_insensitive_set_t &handled_columns, + const identifier_set_t &handled_columns, vector pivot_values) { auto subquery = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); @@ -194,14 +194,15 @@ static unique_ptr PivotFilteredAggregate(ClientContext &context, Piv D_ASSERT(aggregates.size() == 1); auto &aggr = aggregates[0].get().Cast(); - aggr.filter = filter->Copy(); + aggr.FilterMutable() = filter->Copy(); auto &aggr_name = aggregate->GetAlias(); auto name = pivot_value.name; if (ref.aggregates.size() > 1 || !aggr_name.empty()) { // if there are multiple aggregates specified we add the name of the aggregate as well - name += "_" + (aggr_name.empty() ? aggregate->GetName() : aggr_name); + name += "_" + + (aggr_name.empty() ? 
aggregate->GetName().GetIdentifierName() : aggr_name.GetIdentifierName()); } - copied_aggr->SetAlias(name); + copied_aggr->SetAlias(Identifier(name)); subquery->select_list.push_back(std::move(copied_aggr)); } } @@ -209,37 +210,37 @@ static unique_ptr PivotFilteredAggregate(ClientContext &context, Piv } struct PivotBindState { - vector internal_group_names; - vector group_names; - vector aggregate_names; - vector internal_aggregate_names; + vector internal_group_names; + vector group_names; + vector aggregate_names; + vector internal_aggregate_names; }; static unique_ptr PivotInitialAggregate(ClientContext &context, PivotBindState &bind_state, PivotRef &ref, vector> all_columns, - const case_insensitive_set_t &handled_columns) { + const identifier_set_t &handled_columns) { auto subquery_stage1 = ConstructInitialGrouping(ref, std::move(all_columns), handled_columns); idx_t group_count = 0; for (auto &expr : subquery_stage1->select_list) { bind_state.group_names.push_back(expr->GetName()); if (expr->GetAlias().empty()) { - expr->SetAlias("__internal_pivot_group" + std::to_string(++group_count)); + expr->SetAlias(Identifier("__internal_pivot_group" + std::to_string(++group_count))); } - bind_state.internal_group_names.push_back(expr->GetAlias()); + bind_state.internal_group_names.emplace_back(expr->GetAlias()); } // group by all of the pivot values idx_t pivot_count = 0; for (auto &pivot_column : ref.pivots) { for (auto &pivot_expr : pivot_column.pivot_expressions) { if (pivot_expr->GetAlias().empty()) { - pivot_expr->SetAlias("__internal_pivot_ref" + std::to_string(++pivot_count)); + pivot_expr->SetAlias(Identifier("__internal_pivot_ref" + std::to_string(++pivot_count))); } auto pivot_alias = pivot_expr->GetAlias(); subquery_stage1->groups.group_expressions.push_back(make_uniq( Value::INTEGER(UnsafeNumericCast(subquery_stage1->select_list.size() + 1)))); subquery_stage1->select_list.push_back(std::move(pivot_expr)); - pivot_expr = make_uniq(std::move(pivot_alias)); + 
pivot_expr = make_uniq(pivot_alias); } } idx_t aggregate_count = 0; @@ -248,19 +249,19 @@ static unique_ptr PivotInitialAggregate(ClientContext &context, Pivo auto aggregate_alias = "__internal_pivot_aggregate" + std::to_string(++aggregate_count); auto aggr_name = aggregate->GetAlias(); if (aggr_name.empty() && ref.aggregates.size() > 1) { - aggr_name = aggregate->ToString(); + aggr_name = Identifier(aggregate->ToString()); } if (!aggr_name.empty()) { - aggr_name = "_" + aggr_name; + aggr_name = Identifier("_" + aggr_name); } unique_ptr aggregate_ref; - aggregate_ref = make_uniq(aggregate_alias); + aggregate_ref = make_uniq(Identifier(aggregate_alias)); ReplacePivotAggregateExpression(context, aggregate, aggregate_ref); - bind_state.aggregate_names.push_back(std::move(aggr_name)); - bind_state.internal_aggregate_names.push_back(aggregate_alias); - aggregate_ref->SetAlias(std::move(aggregate_alias)); + bind_state.aggregate_names.emplace_back(std::move(aggr_name)); + bind_state.internal_aggregate_names.push_back(Identifier(aggregate_alias)); + aggregate_ref->SetAlias(Identifier(std::move(aggregate_alias))); subquery_stage1->select_list.push_back(std::move(aggregate_ref)); } return subquery_stage1; @@ -326,16 +327,16 @@ static unique_ptr PivotListAggregate(PivotBindState &bind_state, Piv list_children.push_back(std::move(expr)); auto aggregate = make_uniq("list", std::move(list_children)); - aggregate->SetAlias(pivot_name); + aggregate->SetAlias(Identifier(pivot_name)); subquery_stage2->select_list.push_back(std::move(aggregate)); subquery_stage2->from_table = std::move(subquery_ref); return subquery_stage2; } -void ReplacePivotColumnRef(ParsedExpression &root_expr, const string &name) { +void ReplacePivotColumnRef(ParsedExpression &root_expr, const Identifier &name) { ParsedExpressionIterator::VisitExpressionMutable( - root_expr, [&](ColumnRefExpression &colref) { colref.column_names[0] = name; }); + root_expr, [&](ColumnRefExpression &colref) { 
colref.ColumnNamesMutable()[0] = name; }); } static unique_ptr PivotFinalOperator(PivotBindState &bind_state, PivotRef &ref, @@ -357,13 +358,13 @@ static unique_ptr PivotFinalOperator(PivotBindState &bind_state, Piv final_pivot_operator->select_list.push_back(make_uniq(group_name)); } // gather aggregate names - vector aggregate_names; + vector aggregate_names; for (auto &pivot_value : bound_pivot->bound_pivot_values) { for (idx_t aggr_idx = 0; aggr_idx < ref.aggregates.size(); aggr_idx++) { auto aggr = ref.aggregates[aggr_idx]->Copy(); auto &aggr_name = bound_pivot->bound_aggregate_names[aggr_idx]; - auto pivot_aggr_name = pivot_value.name + aggr_name; - aggregate_names.push_back(std::move(pivot_aggr_name)); + auto pivot_aggr_name = pivot_value.name + aggr_name.GetIdentifierName(); + aggregate_names.push_back(Identifier(std::move(pivot_aggr_name))); } } QueryResult::DeduplicateColumns(aggregate_names); @@ -433,10 +434,10 @@ BoundStatement Binder::BindBoundPivot(PivotRef &ref) { if (aggregates.size() < ref.bound_aggregate_names.size()) { vector unique_names; - unordered_set seen; + identifier_set_t seen; for (auto &name : ref.bound_aggregate_names) { if (seen.find(name) == seen.end()) { - unique_names.push_back(name); + unique_names.push_back(name.GetIdentifierName()); seen.insert(name); } } @@ -446,9 +447,9 @@ BoundStatement Binder::BindBoundPivot(PivotRef &ref) { aggregates.size()); } - unordered_map name_to_position; + identifier_map_t name_to_position; for (idx_t i = 0; i < unique_names.size(); i++) { - name_to_position[unique_names[i]] = i; + name_to_position[Identifier(unique_names[i])] = i; } vector> expanded_aggs; @@ -465,11 +466,11 @@ BoundStatement Binder::BindBoundPivot(PivotRef &ref) { ref.bound_aggregate_names.size(), aggregates.size()); } - vector child_names; + vector child_names; vector child_types; result.child_binder->bind_context.GetTypesAndNames(child_names, child_types); - vector names; + vector names; vector types; // emit the groups for 
(idx_t i = 0; i < ref.bound_group_names.size(); i++) { @@ -493,13 +494,13 @@ BoundStatement Binder::BindBoundPivot(PivotRef &ref) { } } result.bound_pivot.pivot_values.push_back(std::move(pivot_str)); - names.push_back(std::move(name)); + names.push_back(Identifier(std::move(name))); types.push_back(aggr->GetReturnType()); } } result.bound_pivot.group_count = ref.bound_group_names.size(); result.bound_pivot.types = types; - auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; + Identifier subquery_alias = ref.alias.empty() ? Identifier("__unnamed_pivot") : ref.alias; QueryResult::DeduplicateColumns(names); bind_context.AddGenericBinding(result.bind_index, subquery_alias, names, types); @@ -511,6 +512,51 @@ BoundStatement Binder::BindBoundPivot(PivotRef &ref) { return result_statement; } +static void BindPivotConstantInList(unique_ptr &expr, vector &values, Binder &binder) { + Value val; + ConstantBinder const_binder(binder, binder.context, "PIVOT IN list"); + auto bound_expr = const_binder.Bind(expr); + if (!bound_expr->IsFoldable()) { + throw BinderException(expr->GetQueryLocation(), "PIVOT IN list must contain constant expressions"); + } + auto folded_value = ExpressionExecutor::EvaluateScalar(binder.context, *bound_expr); + values.push_back(folded_value); +} + +static bool TryExtractUnpivotList(ParsedExpression &expr, vector &column_names) { + auto initial_size = column_names.size(); + switch (expr.GetExpressionType()) { + case ExpressionType::COLUMN_REF: { + auto &colref = expr.Cast(); + if (colref.IsQualified()) { + return false; + } + column_names.push_back(colref.GetColumnName()); + return true; + } + case ExpressionType::VALUE_CONSTANT: { + auto &constant = expr.Cast(); + column_names.push_back(Identifier(constant.GetValue().ToString())); + return true; + } + case ExpressionType::FUNCTION: { + auto &function = expr.Cast(); + if (function.FunctionName() != "row") { + return false; + } + for (auto &child : function.GetArgumentsMutable()) { 
+ if (!TryExtractUnpivotList(*child.GetExpressionMutable(), column_names)) { + column_names.resize(initial_size); + return false; + } + } + return true; + } + default: + return false; + } +} + static void BindPivotInList(unique_ptr &expr, vector &values, Binder &binder) { switch (expr->GetExpressionType()) { case ExpressionType::COLUMN_REF: { @@ -522,22 +568,16 @@ static void BindPivotInList(unique_ptr &expr, vector &v } break; case ExpressionType::FUNCTION: { auto &function = expr->Cast(); - if (function.function_name != "row") { - throw BinderException(expr->GetQueryLocation(), "PIVOT IN list must contain columns or lists of columns"); - } - for (auto &child : function.children) { - BindPivotInList(child, values, binder); + if (function.FunctionName() == "row") { + for (auto &child : function.GetArgumentsMutable()) { + BindPivotInList(child.GetExpressionMutable(), values, binder); + } + } else { + BindPivotConstantInList(expr, values, binder); } } break; default: { - Value val; - ConstantBinder const_binder(binder, binder.context, "PIVOT IN list"); - auto bound_expr = const_binder.Bind(expr); - if (!bound_expr->IsFoldable()) { - throw BinderException(expr->GetQueryLocation(), "PIVOT IN list must contain constant expressions"); - } - auto folded_value = ExpressionExecutor::EvaluateScalar(binder.context, *bound_expr); - values.push_back(folded_value); + BindPivotConstantInList(expr, values, binder); } break; } } @@ -545,7 +585,7 @@ static void BindPivotInList(unique_ptr &expr, vector &v unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { // keep track of the columns by which we pivot/aggregate // any columns which are not pivoted/aggregated on are added to the GROUP BY clause - case_insensitive_set_t handled_columns; + identifier_set_t handled_columns; vector> pivot_aggregates; // parse the aggregate, and extract the referenced columns from the aggregate @@ -599,8 +639,8 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector(context, INVALID_CATALOG, 
INVALID_SCHEMA, pivot.pivot_enum); + auto &type_entry = Catalog::GetEntry(context, Identifier::InvalidCatalog(), + Identifier::InvalidSchema(), pivot.pivot_enum); auto type = type_entry.user_type; if (type.id() != LogicalTypeId::ENUM) { throw BinderException(ref, "Pivot must reference an ENUM type: \"%s\" is of type \"%s\"", @@ -614,7 +654,7 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector Binder::BindPivot(PivotRef &ref, vector column_names; + Identifier alias; + vector column_names; vector> expressions; }; void Binder::ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector &unpivot_entries) { - // Try to bind the entry expression as values - try { - auto expr_copy = entry.expr->Copy(); - BindPivotInList(expr_copy, entry.values, child_binder); - // successfully bound as values - clear the expression - entry.expr = nullptr; - } catch (...) { - // ignore binder exceptions here - we fall back to expression mode - entry.values.clear(); - } - if (!entry.expr) { // pivot entry without an expression - generate one UnpivotEntry unpivot_entry; - unpivot_entry.alias = entry.alias; + unpivot_entry.alias = Identifier(entry.alias.GetIdentifierName()); for (auto &val : entry.values) { auto column_name = val.ToString(); + if (column_name.empty()) { + throw BinderException("UNPIVOT - empty column name not supported"); + } + unpivot_entry.expressions.push_back(make_uniq(Identifier(column_name))); + } + unpivot_entries.push_back(std::move(unpivot_entry)); + return; + } + + vector column_names; + if (TryExtractUnpivotList(*entry.expr, column_names)) { + UnpivotEntry unpivot_entry; + unpivot_entry.alias = Identifier(entry.alias.GetIdentifierName()); + for (auto &column_name : column_names) { if (column_name.empty()) { throw BinderException("UNPIVOT - empty column name not supported"); } @@ -727,6 +770,7 @@ void Binder::ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry unpivot_entries.push_back(std::move(unpivot_entry)); return; } + 
D_ASSERT(entry.values.empty()); // expand star expressions (if any) vector> star_columns; @@ -736,14 +780,14 @@ void Binder::ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry // create one pivot entry per result column UnpivotEntry unpivot_entry; if (!expr->GetAlias().empty()) { - unpivot_entry.alias = expr->GetAlias(); + unpivot_entry.alias = Identifier(expr->GetAlias().GetIdentifierName()); } unpivot_entry.expressions.push_back(std::move(expr)); unpivot_entries.push_back(std::move(unpivot_entry)); } } -void Binder::ExtractUnpivotColumnName(ParsedExpression &expr, vector &result) { +void Binder::ExtractUnpivotColumnName(ParsedExpression &expr, vector &result) { if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) { auto &colref = expr.Cast(); result.push_back(colref.GetColumnName()); @@ -777,11 +821,11 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, throw BinderException(ref, "UNPIVOT clause must unpivot on at least one column - zero were provided"); } - case_insensitive_set_t handled_columns; - case_insensitive_map_t name_map; + identifier_set_t handled_columns; + identifier_map_t name_map; for (auto &entry : unpivot_entries) { for (auto &unpivot_expr : entry.expressions) { - vector result; + vector result; ExtractUnpivotColumnName(*unpivot_expr, result); if (result.empty()) { throw BinderException( @@ -800,7 +844,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, } } - vector select_names; + vector select_names; for (auto &col_expr : all_columns) { if (col_expr->GetExpressionType() != ExpressionType::COLUMN_REF) { throw InternalException("Unexpected child of pivot source - not a ColumnRef"); @@ -837,7 +881,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, } generated_name += name_entry->second; } - unpivot_names.emplace_back(!entry.alias.empty() ? entry.alias : generated_name); + unpivot_names.emplace_back(!entry.alias.empty() ? 
entry.alias : Identifier(generated_name)); } vector>> unpivot_expressions; for (idx_t v_idx = 1; v_idx < unpivot_entries.size(); v_idx++) { @@ -864,9 +908,9 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, select_names.push_back("unpivot_names"); for (idx_t i = 0; i < unpivot_expressions.size(); i++) { if (i > 0) { - select_names.push_back("unpivot_list_" + std::to_string(i + 1)); + select_names.push_back(Identifier("unpivot_list_" + std::to_string(i + 1))); } else { - select_names.push_back("unpivot_list"); + select_names.push_back(Identifier("unpivot_list")); } } @@ -919,7 +963,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, vector> unnest_val_children; unnest_val_children.push_back(std::move(unpivot_list_ref)); auto unnest_val_expr = make_uniq("unnest", std::move(unnest_val_children)); - auto unnest_name = i < ref.column_name_alias.size() ? ref.column_name_alias[i] : ref.unpivot_names[i]; + auto &unnest_name = i < ref.column_name_alias.size() ? 
ref.column_name_alias[i] : ref.unpivot_names[i]; unnest_val_expr->SetAlias(unnest_name); result_node->select_list.push_back(std::move(unnest_val_expr)); if (!ref.include_nulls) { @@ -949,7 +993,7 @@ BoundStatement Binder::Bind(PivotRef &ref) { // bind the source of the pivot // we need to do this to be able to expand star expressions if (ref.source->type == TableReferenceType::SUBQUERY && ref.source->alias.empty()) { - ref.source->alias = "__internal_pivot_alias_" + to_string(GenerateTableIndex().index); + ref.source->alias = Identifier("__internal_pivot_alias_" + to_string(GenerateTableIndex().index)); } auto copied_source = ref.source->Copy(); auto star_binder = Binder::CreateBinder(context, this); @@ -966,13 +1010,12 @@ BoundStatement Binder::Bind(PivotRef &ref) { } else { select_node = BindUnpivot(*star_binder, ref, std::move(all_columns), where_clause); } - // bind the generated select node auto child_binder = Binder::CreateBinder(context, this); auto result = child_binder->BindNode(*select_node); auto root_index = result.plan->GetRootIndex(); MoveCorrelatedExpressions(*child_binder); - auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; + Identifier subquery_alias = ref.alias.empty() ? 
Identifier("__unnamed_pivot") : ref.alias; SubqueryRef subquery_ref(nullptr, subquery_alias); subquery_ref.column_name_alias = std::move(ref.column_name_alias); if (where_clause) { diff --git a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp index 6b1b37e44..c0bba7f33 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp @@ -52,7 +52,7 @@ BaseTableColumnInfo FindBaseTableColumn(LogicalOperator &op, ColumnBinding bindi if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { // if the projection at this index only has a column reference we can directly trace it to the base table auto &bound_colref = expr.Cast(); - return FindBaseTableColumn(*projection.children[0], bound_colref.binding); + return FindBaseTableColumn(*projection.children[0], bound_colref.Binding()); } break; } @@ -93,7 +93,7 @@ BoundStatement Binder::BindShowQuery(ShowRef &ref) { auto plan = child_binder->Bind(*ref.query); // construct a column data collection with the result - vector return_names = {"column_name", "column_type", "null", "key", "default", "extra"}; + vector return_names = {"column_name", "column_type", "null", "key", "default", "extra"}; vector return_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}; DataChunk output; @@ -132,7 +132,7 @@ BoundStatement Binder::BindShowQuery(ShowRef &ref) { output.data[5].Append(Value()); } - output.SetCardinality(output.size() + 1); + // both branches above append exactly one row to the child vectors, growing output.size() accordingly if (output.size() == STANDARD_VECTOR_SIZE) { collection->Append(append_state, output); output.Reset(); @@ -151,7 +151,7 @@ BoundStatement Binder::BindShowQuery(ShowRef &ref) { } BoundStatement Binder::BindShowTable(ShowRef &ref) { - auto lname = 
StringUtil::Lower(ref.table_name); + auto &lname = ref.table_name; string sql; if (lname == "\"databases\"") { @@ -179,8 +179,8 @@ BoundStatement Binder::BindShowTable(ShowRef &ref) { if (!catalog_name.empty() && !schema_name.empty()) { auto schema_entry = Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::RETURN_NULL); if (!schema_entry) { - throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", catalog_name, - schema_name); + throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", + catalog_name.GetIdentifierName(), schema_name.GetIdentifierName()); } } else if (catalog_name.empty() && !schema_name.empty()) { // We have a schema name, use default catalog @@ -189,17 +189,17 @@ BoundStatement Binder::BindShowTable(ShowRef &ref) { catalog_name = default_entry.catalog; auto schema_entry = Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::RETURN_NULL); if (!schema_entry) { - throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", catalog_name, - schema_name); + throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", + catalog_name.GetIdentifierName(), schema_name.GetIdentifierName()); } } - sql = PragmaShowTables(catalog_name, schema_name); + sql = PragmaShowTables(catalog_name.GetIdentifierName(), schema_name.GetIdentifierName()); } else if (lname == "\"variables\"") { sql = PragmaShowVariables(); } else if (lname == "__show_tables_expanded") { sql = PragmaShowTablesExpanded(); } else { - sql = PragmaShow(ref.table_name); + sql = PragmaShow(ref.table_name.GetIdentifierName()); } auto select = CreateViewInfo::ParseSelect(sql); auto subquery = make_uniq(std::move(select)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp index b60bd6e1f..02dec0202 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp 
+++ b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -5,17 +5,17 @@ namespace duckdb { BoundStatement Binder::Bind(SubqueryRef &ref) { auto binder = Binder::CreateBinder(context, this); - binder->can_contain_nulls = true; + binder->SetCanContainNulls(true); + binder->SetInsideSubquery(); auto subquery = binder->BindNode(*ref.subquery->node); binder->alias = ref.alias.empty() ? "unnamed_subquery" : ref.alias; auto bind_index = subquery.plan->GetRootIndex(); - string subquery_alias; + Identifier subquery_alias; if (ref.alias.empty()) { auto index = unnamed_subquery_index++; subquery_alias = "unnamed_subquery"; - ; if (index > 1) { - subquery_alias += to_string(index); + subquery_alias = Identifier(subquery_alias + to_string(index)); } } else { subquery_alias = ref.alias; diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index d19ca939e..f0c2321f7 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -79,7 +79,7 @@ void Binder::BindTableInTableOutFunction(vector> &e auto select_node = make_uniq(); select_node->select_list = std::move(expressions); select_node->from_table = make_uniq(); - binder->can_contain_nulls = true; + binder->SetCanContainNulls(true); subquery = binder->BindNode(*select_node); MoveCorrelatedExpressions(*binder); } @@ -99,17 +99,17 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi } bool seen_subquery = false; for (auto &child : expressions) { - string parameter_name; + Identifier parameter_name; // hack to make named parameters work if (child->GetExpressionType() == ExpressionType::COMPARE_EQUAL) { // comparison, check if the LHS is a columnref auto &comp = child->Cast(); - if (comp.left->GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &colref = comp.left->Cast(); + if (comp.Left().GetExpressionType() == 
ExpressionType::COLUMN_REF) { + auto &colref = comp.Left().Cast(); if (!colref.IsQualified()) { parameter_name = colref.GetColumnName(); - child = std::move(comp.right); + child = std::move(comp.RightMutable()); } } } else if (!child->GetAlias().empty()) { @@ -130,9 +130,9 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi return false; } auto binder = Binder::CreateBinder(this->context, this); - binder->can_contain_nulls = true; + binder->SetCanContainNulls(true); auto &se = child->Cast(); - subquery = binder->BindNode(*se.subquery->node); + subquery = binder->BindNode(*se.Subquery()->node); MoveCorrelatedExpressions(*binder); seen_subquery = true; arguments.emplace_back(LogicalTypeId::TABLE); @@ -140,7 +140,7 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi continue; } - TableFunctionBinder binder(*this, context, table_function.name); + TableFunctionBinder binder(*this, context, table_function.name.GetIdentifierName()); LogicalType sql_type; auto expr = binder.Bind(child, &sql_type); if (expr->HasParameter()) { @@ -168,19 +168,28 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi static string GetAlias(const TableFunctionRef &ref) { if (!ref.alias.empty()) { - return ref.alias; + return ref.alias.GetIdentifierName(); } if (ref.function && ref.function->GetExpressionType() == ExpressionType::FUNCTION) { auto &function_expr = ref.function->Cast(); - return function_expr.function_name; + return function_expr.FunctionName().GetIdentifierName(); } return string(); } +static void ApplyPostgresSetofAliasCompatibility(const TableFunction &table_function, const TableFunctionRef &ref, + vector &return_names) { + if (table_function.return_type != TableFunctionReturnType::SET_RETURNING_FUNCTION || ref.alias.empty() || + !ref.column_name_alias.empty() || return_names.size() != 1) { + return; + } + return_names[0] = ref.alias; +} + BoundStatement 
Binder::BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, vector parameters, named_parameter_map_t named_parameters, vector input_table_types, - vector input_table_names, + vector input_table_names, optional_ptr> input_plan) { auto function_name = GetAlias(ref); auto &column_name_alias = ref.column_name_alias; @@ -188,7 +197,7 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, // perform the binding unique_ptr bind_data; vector return_types; - vector return_names; + vector return_names; auto constexpr ordinality_name = "ordinality"; string ordinality_column_name = ordinality_name; optional_idx ordinality_column_id; @@ -196,7 +205,9 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, TableFunctionBindInput bind_input(parameters, named_parameters, input_table_types, input_table_names, table_function.function_info.get(), this, table_function, ref, input_plan); if (table_function.bind_operator) { - auto new_plan = table_function.bind_operator(context, bind_input, bind_index, return_names); + vector operator_names; + auto new_plan = table_function.bind_operator(context, bind_input, bind_index, operator_names); + return_names = StringsToIdentifiers(operator_names); if (new_plan) { new_plan->ResolveOperatorTypes(); if (new_plan->types.size() != return_names.size()) { @@ -210,8 +221,9 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, table_function.name); } } + ApplyPostgresSetofAliasCompatibility(table_function, ref, return_names); BoundStatement result; - bind_context.AddGenericBinding(bind_index, function_name, return_names, new_plan->types); + bind_context.AddGenericBinding(bind_index, Identifier(function_name), return_names, new_plan->types); result.names = return_names; result.types = new_plan->types; result.plan = std::move(new_plan); @@ -234,11 +246,13 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, 
throw BinderException("Failed to bind \"%s\": nullptr returned from bind_replace without bind function", table_function.name); } - bind_data = table_function.bind(context, bind_input, return_types, return_names); + vector bind_names; + bind_data = table_function.bind(context, bind_input, return_types, bind_names); + return_names = StringsToIdentifiers(bind_names); if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY) { // check if column name 'ordinality' already exists and if so, replace it iteratively until free name is // found - case_insensitive_set_t ci_return_names; + identifier_set_t ci_return_names; idx_t ordinality_name_suffix = 0; for (auto &n : return_names) { ci_return_names.insert(n); @@ -246,7 +260,7 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, for (auto &n : column_name_alias) { ci_return_names.insert(n); } - while (ci_return_names.find(ordinality_column_name) != ci_return_names.end()) { + while (ci_return_names.find(Identifier(ordinality_column_name)) != ci_return_names.end()) { ordinality_column_name = ordinality_name + to_string(ordinality_name_suffix++); } if (!correlated_columns.empty()) { @@ -270,13 +284,14 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, throw InternalException("Failed to bind \"%s\": Table function must return at least one column", table_function.name); } + ApplyPostgresSetofAliasCompatibility(table_function, ref, return_names); // overwrite the names with any supplied aliases for (idx_t i = 0; i < column_name_alias.size() && i < return_names.size(); i++) { return_names[i] = column_name_alias[i]; } for (idx_t i = 0; i < return_names.size(); i++) { if (return_names[i].empty()) { - return_names[i] = "C" + to_string(i); + return_names[i] = Identifier("C" + to_string(i)); } } virtual_column_map_t virtual_columns; @@ -300,27 +315,28 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, } if (ref.with_ordinality == 
OrdinalityType::WITH_ORDINALITY && correlated_columns.empty()) { - bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), - get->GetTable().get(), std::move(virtual_columns)); + bind_context.AddTableFunction(bind_index, Identifier(function_name), return_names, return_types, + get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); auto window_index = GenerateTableIndex(); auto window = make_uniq(window_index); auto row_number = RowNumberFun::GetFunction().Bind(context); - row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; - row_number->end = WindowBoundary::CURRENT_ROW_ROWS; - string ordinality_alias = ordinality_column_name; + row_number->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + row_number->WindowEndMutable() = WindowBoundary::CURRENT_ROW_ROWS; + Identifier ordinality_alias(ordinality_column_name); if (return_names.size() < column_name_alias.size()) { row_number->SetAlias(column_name_alias[return_names.size()]); ordinality_alias = column_name_alias[return_names.size()]; } else { - row_number->SetAlias(ordinality_column_name); + row_number->SetAlias(Identifier(ordinality_column_name)); } return_names.push_back(ordinality_alias); return_types.push_back(LogicalType::BIGINT); window->expressions.push_back(std::move(row_number)); window->types.push_back(LogicalType::BIGINT); window->children.push_back(std::move(get)); - bind_context.AddGenericBinding(window_index, function_name, {ordinality_alias}, {LogicalType::BIGINT}); + bind_context.AddGenericBinding(window_index, Identifier(function_name), {ordinality_alias}, + {LogicalType::BIGINT}); BoundStatement result; result.names = std::move(return_names); @@ -330,8 +346,8 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, } // now add the table function to the bind context so its columns can be bound BoundStatement result; - bind_context.AddTableFunction(bind_index, function_name, 
return_names, return_types, get->GetMutableColumnIds(), - get->GetTable().get(), std::move(virtual_columns)); + bind_context.AddTableFunction(bind_index, Identifier(function_name), return_names, return_types, + get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); result.names = std::move(return_names); result.types = std::move(return_types); result.plan = std::move(get); @@ -341,7 +357,7 @@ BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, BoundStatement Binder::BindTableFunction(TableFunction &function, vector parameters) { named_parameter_map_t named_parameters; vector input_table_types; - vector input_table_names; + vector input_table_names; TableFunctionRef ref; ref.alias = function.name; @@ -356,13 +372,13 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { D_ASSERT(ref.function->GetExpressionType() == ExpressionType::FUNCTION); auto &fexpr = ref.function->Cast(); - string catalog = fexpr.catalog; - string schema = fexpr.schema; + Identifier catalog = fexpr.Catalog(); + Identifier schema = fexpr.Schema(); Binder::BindSchemaOrCatalog(context, catalog, schema); // fetch the function from the catalog - EntryLookupInfo table_function_lookup(CatalogType::TABLE_FUNCTION_ENTRY, fexpr.function_name, error_context); + EntryLookupInfo table_function_lookup(CatalogType::TABLE_FUNCTION_ENTRY, fexpr.FunctionName(), error_context); auto &func_catalog = *GetCatalogEntry(catalog, schema, table_function_lookup, OnEntryNotFound::THROW_EXCEPTION); if (func_catalog.type == CatalogType::TABLE_MACRO_ENTRY) { @@ -371,7 +387,7 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { D_ASSERT(query_node); auto binder = Binder::CreateBinder(context, this); - binder->can_contain_nulls = true; + binder->SetCanContainNulls(true); binder->alias = ref.alias.empty() ? 
"unnamed_query" : ref.alias; BoundStatement query; @@ -385,10 +401,11 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { auto bind_index = query.plan->GetRootIndex(); // string alias; - string alias = (ref.alias.empty() ? "unnamed_query" + to_string(bind_index.index) : ref.alias); + string alias = + (ref.alias.empty() ? "unnamed_query" + to_string(bind_index.index) : ref.alias.GetIdentifierName()); // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, query); + bind_context.AddSubquery(bind_index, Identifier(alias), ref, query); MoveCorrelatedExpressions(*binder); return query; } @@ -401,8 +418,13 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { named_parameter_map_t named_parameters; BoundStatement subquery; ErrorData error; - if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, - error)) { + + vector> children; + for (auto &child : fexpr.GetArgumentsMutable()) { + children.push_back(std::move(child.GetExpressionMutable())); + } + + if (!BindTableFunctionParameters(function, children, arguments, parameters, named_parameters, subquery, error)) { error.AddQueryLocation(ref); error.Throw(); } @@ -420,7 +442,7 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { BindNamedParameters(table_function.named_parameters, named_parameters, error_context, table_function.name); vector input_table_types; - vector input_table_names; + vector input_table_names; if (subquery.plan) { input_table_types = subquery.types; @@ -428,7 +450,7 @@ BoundStatement Binder::Bind(TableFunctionRef &ref) { } else if (table_function.in_out_function) { for (auto ¶m : parameters) { input_table_types.push_back(param.type()); - input_table_names.push_back(string()); + input_table_names.push_back(Identifier()); } } if (!parameters.empty()) { diff --git a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp index 
09c4bb90c..1fb8c2089 100644 --- a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp @@ -341,7 +341,7 @@ static bool HasCorrelatedColumns(const Expression &root_expr) { bool has_correlated_columns = false; ExpressionIterator::VisitExpression(root_expr, [&](const BoundColumnRefExpression &colref) { - if (colref.depth > 0) { + if (colref.Depth() > 0) { has_correlated_columns = true; } }); @@ -379,7 +379,6 @@ unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { } if (ref.type == JoinType::RIGHT && ref.ref_type != JoinRefType::ASOF && - ClientConfig::GetConfig(context).enable_optimizer && !Optimizer::OptimizerDisabled(context, OptimizerType::BUILD_SIDE_PROBE_SIDE)) { // we turn any right outer joins into left outer joins for optimization purposes // they are the same but with sides flipped, so treating them the same simplifies life @@ -416,7 +415,6 @@ unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { filter->AddChild(std::move(root)); return std::move(filter); } - // now create the join operator from the join condition auto result = LogicalComparisonJoin::CreateJoin(context, ref.type, ref.ref_type, std::move(left), std::move(right), std::move(ref.condition)); diff --git a/src/duckdb/src/planner/binding_alias.cpp b/src/duckdb/src/planner/binding_alias.cpp index 0134137b7..4ac245c14 100644 --- a/src/duckdb/src/planner/binding_alias.cpp +++ b/src/duckdb/src/planner/binding_alias.cpp @@ -8,17 +8,18 @@ namespace duckdb { BindingAlias::BindingAlias() { } -BindingAlias::BindingAlias(string alias_p) : alias(std::move(alias_p)) { +BindingAlias::BindingAlias(Identifier alias_p) : alias(std::move(alias_p)) { } -BindingAlias::BindingAlias(string schema_p, string alias_p) : schema(std::move(schema_p)), alias(std::move(alias_p)) { +BindingAlias::BindingAlias(Identifier schema_p, Identifier alias_p) + : schema(std::move(schema_p)), alias(std::move(alias_p)) { } BindingAlias::BindingAlias(const StandardEntry &entry) : 
catalog(entry.ParentCatalog().GetName()), schema(entry.schema.name), alias(entry.name) { } -BindingAlias::BindingAlias(string catalog_p, string schema_p, string alias_p) +BindingAlias::BindingAlias(Identifier catalog_p, Identifier schema_p, Identifier alias_p) : catalog(std::move(catalog_p)), schema(std::move(schema_p)), alias(std::move(alias_p)) { } @@ -26,7 +27,7 @@ bool BindingAlias::IsSet() const { return !alias.empty(); } -const string &BindingAlias::GetAlias() const { +const Identifier &BindingAlias::GetAlias() const { if (!IsSet()) { throw InternalException("Calling BindingAlias::GetAlias on a non-set alias"); } @@ -50,21 +51,20 @@ bool BindingAlias::Matches(const BindingAlias &other) const { // i.e. "tbl" matches "catalog.schema.tbl" // but "schema2.tbl" does not match "schema.tbl" if (!other.catalog.empty()) { - if (!StringUtil::CIEquals(catalog, other.catalog)) { + if (catalog != other.catalog) { return false; } } if (!other.schema.empty()) { - if (!StringUtil::CIEquals(schema, other.schema)) { + if (schema != other.schema) { return false; } } - return StringUtil::CIEquals(alias, other.alias); + return alias == other.alias; } bool BindingAlias::operator==(const BindingAlias &other) const { - return StringUtil::CIEquals(catalog, other.catalog) && StringUtil::CIEquals(schema, other.schema) && - StringUtil::CIEquals(alias, other.alias); + return catalog == other.catalog && schema == other.schema && alias == other.alias; } } // namespace duckdb diff --git a/src/duckdb/src/planner/bound_parameter_map.cpp b/src/duckdb/src/planner/bound_parameter_map.cpp index 489b47f41..650bf968e 100644 --- a/src/duckdb/src/planner/bound_parameter_map.cpp +++ b/src/duckdb/src/planner/bound_parameter_map.cpp @@ -4,11 +4,11 @@ namespace duckdb { -BoundParameterMap::BoundParameterMap(case_insensitive_map_t ¶meter_data) +BoundParameterMap::BoundParameterMap(identifier_map_t ¶meter_data) : parameter_data(parameter_data) { } -LogicalType BoundParameterMap::GetReturnType(const string 
&identifier) { +LogicalType BoundParameterMap::GetReturnType(const Identifier &identifier) { D_ASSERT(!identifier.empty()); auto it = parameter_data.find(identifier); if (it == parameter_data.end()) { @@ -25,25 +25,25 @@ const bound_parameter_map_t &BoundParameterMap::GetParameters() { return parameters; } -const case_insensitive_map_t &BoundParameterMap::GetParameterData() { +const identifier_map_t &BoundParameterMap::GetParameterData() { return parameter_data; } -shared_ptr BoundParameterMap::CreateOrGetData(const string &identifier) { +shared_ptr BoundParameterMap::CreateOrGetData(const Identifier &identifier) { auto entry = parameters.find(identifier); if (entry == parameters.end()) { // no entry yet: create a new one auto data = make_shared_ptr(); data->return_type = GetReturnType(identifier); - CreateNewParameter(identifier, data); + CreateNewParameter(identifier.GetIdentifierName(), data); return data; } return entry->second; } unique_ptr BoundParameterMap::BindParameterExpression(ParameterExpression &expr) { - auto &identifier = expr.identifier; + auto &identifier = expr.Identifier(); D_ASSERT(!parameter_data.count(identifier)); // No value has been supplied yet, @@ -52,7 +52,7 @@ unique_ptr BoundParameterMap::BindParameterExpression( auto param_data = CreateOrGetData(identifier); auto bound_expr = make_uniq(identifier); - bound_expr->parameter_data = param_data; + bound_expr->ParameterDataMutable() = param_data; bound_expr->SetAlias(expr.GetAlias()); auto param_type = param_data->return_type; @@ -70,7 +70,7 @@ unique_ptr BoundParameterMap::BindParameterExpression( } void BoundParameterMap::CreateNewParameter(const string &id, const shared_ptr ¶m_data) { - D_ASSERT(!parameters.count(id)); + D_ASSERT(!parameters.count(Identifier(id))); parameters.emplace(std::make_pair(id, param_data)); } diff --git a/src/duckdb/src/planner/collation_binding.cpp b/src/duckdb/src/planner/collation_binding.cpp index cb3638f4f..8a4b2fc24 100644 --- 
a/src/duckdb/src/planner/collation_binding.cpp +++ b/src/duckdb/src/planner/collation_binding.cpp @@ -39,7 +39,8 @@ bool PushVarcharCollation(ClientContext &context, unique_ptr &source // we already applied this collation continue; } - auto &collation_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, collation_argument); + auto &collation_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), Identifier(collation_argument)); if (collation_entry.combinable) { entries.insert(entries.begin(), collation_entry); } else { @@ -75,7 +76,7 @@ bool PushTimeTZCollation(ClientContext &context, unique_ptr &source, auto &catalog = Catalog::GetSystemCatalog(context); auto &function_entry = - catalog.GetEntry(context, DEFAULT_SCHEMA, "timetz_byte_comparable"); + catalog.GetEntry(context, Identifier::DefaultSchema(), "timetz_byte_comparable"); if (function_entry.functions.Size() != 1) { throw InternalException("timetz_byte_comparable should only have a single overload"); } @@ -89,6 +90,28 @@ bool PushTimeTZCollation(ClientContext &context, unique_ptr &source, return true; } +bool PushBitStringCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, + CollationType) { + if (sql_type.id() != LogicalTypeId::BIT) { + return false; + } + + auto &catalog = Catalog::GetSystemCatalog(context); + auto &function_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), "bitstring_byte_comparable"); + if (function_entry.functions.Size() != 1) { + throw InternalException("bitstring_byte_comparable should only have a single overload"); + } + const auto &scalar_function = function_entry.functions.GetFunctionByOffset(0); + vector> children; + children.push_back(std::move(source)); + + FunctionBinder function_binder(context); + auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); + source = std::move(function); + return true; +} + bool PushIntervalCollation(ClientContext &context, unique_ptr &source, const LogicalType 
&sql_type, CollationType) { if (sql_type.id() != LogicalTypeId::INTERVAL) { @@ -96,7 +119,8 @@ bool PushIntervalCollation(ClientContext &context, unique_ptr &sourc } auto &catalog = Catalog::GetSystemCatalog(context); - auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "normalized_interval"); + auto &function_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), "normalized_interval"); if (function_entry.functions.Size() != 1) { throw InternalException("normalized_interval should only have a single overload"); } @@ -116,7 +140,8 @@ bool PushVariantCollation(ClientContext &context, unique_ptr &source return false; } auto &catalog = Catalog::GetSystemCatalog(context); - auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "variant_normalize"); + auto &function_entry = + catalog.GetEntry(context, Identifier::DefaultSchema(), "variant_normalize"); if (function_entry.functions.Size() != 1) { throw InternalException("variant_normalize should only have a single overload"); } @@ -136,6 +161,7 @@ bool PushVariantCollation(ClientContext &context, unique_ptr &source CollationBinding::CollationBinding() { RegisterCollation(CollationCallback(PushVarcharCollation)); RegisterCollation(CollationCallback(PushTimeTZCollation)); + RegisterCollation(CollationCallback(PushBitStringCollation)); RegisterCollation(CollationCallback(PushIntervalCollation)); RegisterCollation(CollationCallback(PushVariantCollation)); } diff --git a/src/duckdb/src/planner/column_qualifier.cpp b/src/duckdb/src/planner/column_qualifier.cpp new file mode 100644 index 000000000..46046007a --- /dev/null +++ b/src/duckdb/src/planner/column_qualifier.cpp @@ -0,0 +1,583 @@ +#include "duckdb/planner/column_qualifier.hpp" + +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/expression/columnref_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include 
"duckdb/parser/expression/lambda_expression.hpp" +#include "duckdb/parser/expression/operator_expression.hpp" +#include "duckdb/parser/expression/positional_reference_expression.hpp" +#include "duckdb/planner/expression_binder/having_binder.hpp" +#include "duckdb/planner/planner_extension.hpp" + +namespace duckdb { + +ColumnQualifier::ColumnQualifier(Binder &binder_p, optional_ptr> lambda_bindings_p, + optional_ptr alias_binder_p, + optional_ptr having_binder_p) + : binder(binder_p), lambda_bindings(lambda_bindings_p), alias_binder(alias_binder_p), + having_binder(having_binder_p) { +} + +static Identifier GetSQLValueFunctionName(const Identifier &column_name) { + if (column_name == "current_catalog") { + return "current_catalog"; + } + if (column_name == "current_date") { + return "current_date"; + } + if (column_name == "current_schema") { + return "current_schema"; + } + if (column_name == "current_role") { + return "current_role"; + } + if (column_name == "current_time") { + return "get_current_time"; + } + if (column_name == "current_timestamp") { + return "get_current_timestamp"; + } + if (column_name == "current_user") { + return "current_user"; + } + if (column_name == "localtime") { + return "current_localtime"; + } + if (column_name == "localtimestamp") { + return "current_localtimestamp"; + } + if (column_name == "session_user") { + return "session_user"; + } + if (column_name == "user") { + return "user"; + } + return Identifier(); +} + +unique_ptr Binder::GetSQLValueFunction(const Identifier &column_name) { + for (auto &ext : PlannerExtension::Iterate(context)) { + if (ext.get_sql_value_function) { + PlannerExtensionInput input {context, *this, ext.planner_info.get()}; + unique_ptr result; + auto result_type = ext.get_sql_value_function(input, column_name.GetIdentifierName(), result); + if (result_type == GetSQLValueFunctionReturnType::FINISH_BINDING || result) { + return result; + } + } + } + + auto value_function = 
GetSQLValueFunctionName(column_name); + if (value_function.empty()) { + return nullptr; + } + + vector> children; + return make_uniq(value_function, std::move(children)); +} + +unique_ptr ColumnQualifier::CreateStructExtract(unique_ptr base, + const Identifier &field_name) { + vector> children; + children.push_back(std::move(base)); + children.push_back(make_uniq_base(Value(field_name))); + auto extract_fun = make_uniq(ExpressionType::STRUCT_EXTRACT, std::move(children)); + return std::move(extract_fun); +} + +unique_ptr ColumnQualifier::CreateStructPack(ColumnRefExpression &col_ref) { + if (col_ref.ColumnNames().size() > 3) { + return nullptr; + } + D_ASSERT(!col_ref.ColumnNames().empty()); + + // get a matching binding + ErrorData error; + optional_ptr binding; + switch (col_ref.ColumnNames().size()) { + case 1: { + // single entry - this must be the table name + BindingAlias alias(col_ref.ColumnNames()[0]); + binding = binder.bind_context.GetBinding(alias, error); + break; + } + case 2: { + // two entries - this can either be "catalog.table" or "schema.table" - try both + BindingAlias alias(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1]); + binding = binder.bind_context.GetBinding(alias, error); + if (!binding) { + alias = BindingAlias(col_ref.ColumnNames()[0], Identifier::InvalidSchema(), col_ref.ColumnNames()[1]); + binding = binder.bind_context.GetBinding(alias, error); + } + break; + } + case 3: { + // three entries - this must be "catalog.schema.table" + BindingAlias alias(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1], col_ref.ColumnNames()[2]); + binding = binder.bind_context.GetBinding(alias, error); + break; + } + default: + throw InternalException("Expected 1, 2 or 3 column names for CreateStructPack"); + } + if (!binding) { + return nullptr; + } + + // We found the table, now create the struct_pack expression + auto &column_names = binding->GetColumnNames(); + vector child_expressions; + child_expressions.reserve(column_names.size()); + for 
(const auto &column_name : column_names) { + auto ref = binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), column_name, + ColumnBindType::DO_NOT_EXPAND_GENERATED_COLUMNS); + child_expressions.emplace_back(column_name, std::move(ref)); + } + return make_uniq("struct_pack", std::move(child_expressions)); +} + +unique_ptr ColumnQualifier::QualifyColumnName(const ParsedExpression &expr, + const Identifier &column_name, ErrorData &error) { + auto using_binding = binder.bind_context.GetUsingBinding(column_name); + if (using_binding) { + // we are referencing a USING column + // check if we can refer to one of the base columns directly + unique_ptr expression; + if (using_binding->primary_binding.IsSet()) { + // we can! just assign the table name and re-bind + return binder.bind_context.CreateColumnReference(using_binding->primary_binding, column_name); + } else { + // we cannot! we need to bind this as COALESCE between all the relevant columns + auto coalesce = make_uniq(ExpressionType::OPERATOR_COALESCE); + coalesce->GetChildrenMutable().reserve(using_binding->bindings.size()); + for (auto &entry : using_binding->bindings) { + coalesce->GetChildrenMutable().push_back(make_uniq(column_name, entry)); + } + return std::move(coalesce); + } + } + + // try binding as a lambda parameter + auto lambda_ref = LambdaRefExpression::FindMatchingBinding(lambda_bindings, column_name); + if (lambda_ref) { + return lambda_ref; + } + + // find a table binding that contains this column name + auto table_binding = binder.bind_context.GetMatchingBinding(column_name, expr); + + // throw an error if a macro parameter name conflicts with a column name + auto is_macro_column = false; + if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(column_name)) { + is_macro_column = true; + if (table_binding) { + throw BinderException(expr, "Conflicting column names for column " + column_name + "!"); + } + } + + // bind as a macro column + if (is_macro_column) { + 
return binder.bind_context.CreateColumnReference(binder.macro_binding->GetBindingAlias(), column_name); + } + + // bind as a regular column + if (table_binding) { + return binder.bind_context.CreateColumnReference(table_binding->GetBindingAlias(), column_name); + } + + // it's not, find candidates and error + auto similar_bindings = binder.bind_context.GetSimilarBindings(column_name); + error = ErrorData(BinderException::ColumnNotFound(column_name, similar_bindings)); + return nullptr; +} + +void ColumnQualifier::QualifyColumnNames(unique_ptr &expr, vector &lambda_params, + const bool within_function_expression) { + bool next_within_function_expression = false; + switch (expr->GetExpressionType()) { + case ExpressionType::COLUMN_REF: { + auto &col_ref = expr->Cast(); + + // don't qualify lambda parameters + if (LambdaExpression::IsLambdaParameter(lambda_params, col_ref.GetName())) { + return; + } + + ErrorData error; + auto new_expr = QualifyColumnName(col_ref, error); + + if (new_expr) { + if (!expr->GetAlias().empty()) { + // Pre-existing aliases are added to the qualified column reference + new_expr->SetAlias(expr->GetAlias()); + } else if (within_function_expression) { + // Qualifying the column reference may add an alias, but this needs to be removed within function + // expressions, because the alias here means a named parameter instead of a positional parameter + new_expr->ClearAlias(); + } + + // replace the expression with the qualified column reference + new_expr->SetQueryLocation(col_ref.GetQueryLocation()); + expr = std::move(new_expr); + } + return; + } + case ExpressionType::POSITIONAL_REFERENCE: { + auto &ref = expr->Cast(); + if (ref.GetAlias().empty()) { + Identifier table_name, column_name; + auto error = binder.bind_context.BindColumn(ref, table_name, column_name); + if (error.empty()) { + ref.SetAlias(column_name); + } + } + break; + } + case ExpressionType::FUNCTION: { + // Special-handling for lambdas, which are inside function expressions. 
+ auto &function = expr->Cast(); + if (!ExpressionBinder::IsUnnestFunction(function.FunctionName())) { + QualifyFunction(function); + } + if (function.IsLambdaFunction()) { + return QualifyColumnNamesInLambda(function, lambda_params); + } + + next_within_function_expression = true; + break; + } + default: // fall through + break; + } + + // recurse on the child expressions + ParsedExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { + QualifyColumnNames(child, lambda_params, next_within_function_expression); + }); +} + +optional_ptr ColumnQualifier::QualifyFunction(FunctionExpression &function) { + D_ASSERT(!ExpressionBinder::IsUnnestFunction(function.FunctionName())); + // lookup the function in the catalog + QueryErrorContext error_context(function.GetQueryLocation()); + binder.BindSchemaOrCatalog(function.CatalogMutable(), function.SchemaMutable()); + + EntryLookupInfo function_lookup(CatalogType::SCALAR_FUNCTION_ENTRY, function.FunctionName(), error_context); + auto func = + binder.GetCatalogEntry(function.Catalog(), function.Schema(), function_lookup, OnEntryNotFound::RETURN_NULL); + if (func) { + // found the function - we are done + return func; + } + // not a table function - check if the schema is set + if (function.Schema().empty()) { + // schema is not set - leave it as-is + return nullptr; + } + // the schema is set - check if we can turn this the schema into a column ref + // does this function exist in the system catalog? 
+ func = binder.GetCatalogEntry(Identifier::InvalidCatalog(), Identifier::InvalidSchema(), function_lookup, + OnEntryNotFound::RETURN_NULL); + if (!func) { + // we could not find the function - bail + return nullptr; + } + // the function exists in the system catalog - turn this into a dot call + ErrorData error; + unique_ptr colref; + if (function.Catalog().empty()) { + colref = make_uniq(function.Schema()); + } else { + colref = make_uniq(function.Schema(), function.Catalog()); + } + auto new_colref = QualifyColumnName(*colref, error); + if (!new_colref) { + new_colref = std::move(colref); + } + // we can! transform this into a function call on the column + // i.e. "x.lower()" becomes "lower(x)" + function.GetArgumentsMutable().insert(function.GetArgumentsMutable().begin(), std::move(new_colref)); + function.CatalogMutable() = INVALID_CATALOG; + function.SchemaMutable() = INVALID_SCHEMA; + return func; +} + +void ColumnQualifier::QualifyColumnNamesInLambda(FunctionExpression &function, + vector &lambda_params) { + for (auto &child : function.GetArgumentsMutable()) { + if (child.GetExpression().GetExpressionClass() != ExpressionClass::LAMBDA) { + // not a lambda expression + QualifyColumnNames(child.GetExpressionMutable(), lambda_params, true); + continue; + } + + // special-handling for LHS lambda parameters + // we do not qualify them, and we add them to the lambda_params vector + auto &lambda_expr = child.GetExpressionMutable()->Cast(); + string error_message; + auto column_ref_expressions = lambda_expr.ExtractColumnRefExpressions(error_message); + + if (!error_message.empty()) { + // possibly a JSON function, qualify both LHS and RHS + QualifyColumnNames(lambda_expr.LeftMutable(), lambda_params, true); + QualifyColumnNames(lambda_expr.RightMutable(), lambda_params, true); + continue; + } + + // push this level + lambda_params.emplace_back(); + + // push the lambda parameter names + for (const auto &column_ref_expr : column_ref_expressions) { + const auto 
&column_ref = column_ref_expr.get().Cast(); + lambda_params.back().emplace(column_ref.GetName()); + } + + // only qualify in RHS + QualifyColumnNames(lambda_expr.RightMutable(), lambda_params, true); + + // pop this level + lambda_params.pop_back(); + } +} + +unique_ptr ColumnQualifier::QualifyColumnNameWithManyDotsInternal(ColumnRefExpression &col_ref, + ErrorData &error, + idx_t &struct_extract_start) { + // two or more dots (i.e. "part1.part2.part3.part4...") + // -> part1 is a catalog, part2 is a schema, part3 is a table, part4 is a column name, part 5 and beyond are + // struct fields + // -> part1 is a catalog, part2 is a table, part3 is a column name, part4 and beyond are struct fields + // -> part1 is a schema, part2 is a table, part3 is a column name, part4 and beyond are struct fields + // -> part1 is a table, part2 is a column name, part3 and beyond are struct fields + // -> part1 is a column, part2 and beyond are struct fields + + // we always prefer the most top-level view + // i.e. in case of multiple resolution options, we resolve in order: + // -> 1. resolve "part1" as a catalog + // -> 2. resolve "part1" as a schema + // -> 3. resolve "part1" as a table + // -> 4. 
resolve "part1" as a column + + // first check if part1 is a catalog + ErrorData fully_qualified_error; + optional_ptr binding; + if (col_ref.ColumnNames().size() > 3) { + binding = binder.GetMatchingBinding(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1], + col_ref.ColumnNames()[2], col_ref.ColumnNames()[3], fully_qualified_error); + if (binding) { + // part1 is a catalog - the column reference is "catalog.schema.table.column" + struct_extract_start = 4; + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.ColumnNames()[3]); + } + } + ErrorData catalog_table_error; + binding = binder.GetMatchingBinding(col_ref.ColumnNames()[0], Identifier::InvalidSchema(), col_ref.ColumnNames()[1], + col_ref.ColumnNames()[2], catalog_table_error); + if (binding) { + // part1 is a catalog - the column reference is "catalog.table.column" + struct_extract_start = 3; + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.ColumnNames()[2]); + } + ErrorData schema_table_error; + binding = binder.GetMatchingBinding(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1], col_ref.ColumnNames()[2], + schema_table_error); + if (binding) { + // part1 is a schema - the column reference is "schema.table.column" + // any additional fields are turned into struct_extract calls + struct_extract_start = 3; + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.ColumnNames()[2]); + } + ErrorData table_column_error; + binding = binder.GetMatchingBinding(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1], table_column_error); + if (binding) { + // part1 is a table + // the column reference is "table.column" + // any additional fields are turned into struct_extract calls + struct_extract_start = 2; + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.ColumnNames()[1]); + } + // part1 could be a column + ErrorData unused_error; + auto result_expr = QualifyColumnName(col_ref, 
col_ref.ColumnNames()[0], unused_error); + if (result_expr) { + // it is! add the struct extract calls + struct_extract_start = 1; + return result_expr; + } + auto struct_pack = CreateStructPack(col_ref); + if (struct_pack) { + return struct_pack; + } + + // we could not find the column - return an error + // try to find out which error to return + optional_idx catalog_pos; + optional_idx schema_pos; + optional_idx table_pos; + for (const auto &binding_entry : binder.bind_context.GetBindingsList()) { + auto &alias = binding_entry->GetBindingAlias(); + string catalog = alias.GetCatalog().GetIdentifierName(); + string schema = alias.GetSchema().GetIdentifierName(); + string table = alias.GetAlias().GetIdentifierName(); + auto entry = binding_entry->GetStandardEntry(); + if (entry) { + catalog = entry->ParentCatalog().GetName().GetIdentifierName(); + schema = entry->ParentSchema().name.GetIdentifierName(); + } + + for (idx_t i = 0; i < 3; i++) { + if (catalog == col_ref.ColumnNames()[i]) { + catalog_pos = i; + } + if (schema == col_ref.ColumnNames()[i]) { + schema_pos = i; + } + if (table == col_ref.ColumnNames()[i]) { + table_pos = i; + } + } + } + + error = std::move(table_column_error); + if (table_pos.IsValid()) { + // we have a valid table + if (table_pos.GetIndex() == 0) { + // the first column is a valid table - return the error as if we were binding the column as the second + + } else if (table_pos.GetIndex() == 1) { + // the first column is a valid table + // this means either the first column is a catalog OR the first column is a schema + if (catalog_pos.IsValid()) { + // catalog.table + error = std::move(catalog_table_error); + } else { + // assume schema.table otherwise + error = std::move(schema_table_error); + } + } else if (col_ref.ColumnNames().size() > 3) { + // the second column is a valid table + // return the fully qualified error if we have enough components + error = std::move(fully_qualified_error); + } + } else if (catalog_pos.IsValid()) { + // 
the first column is a catalog + if (schema_pos.IsValid()) { + // AND a valid schema - this is likely "catalog.schema.table" + if (col_ref.ColumnNames().size() > 3) { + error = std::move(fully_qualified_error); + } else { + error = std::move(catalog_table_error); + } + } else { + // but no valid schema - this could be either "catalog.schema.table" or "catalog.table" + // return "catalog.table" for now + error = std::move(catalog_table_error); + } + } else if (schema_pos.IsValid()) { + // we did not find a catalog or a table, but we did find a schema + if (schema_pos.GetIndex() == 0) { + // "schema.table" + error = std::move(schema_table_error); + } else if (schema_pos.GetIndex() == 1 && col_ref.ColumnNames().size() > 3) { + // catalog.schema.table + error = std::move(fully_qualified_error); + } + } + return nullptr; +} + +unique_ptr ColumnQualifier::QualifyColumnNameWithManyDots(ColumnRefExpression &col_ref, + ErrorData &error) { + idx_t struct_extract_start = col_ref.ColumnNames().size(); + auto result_expr = QualifyColumnNameWithManyDotsInternal(col_ref, error, struct_extract_start); + if (!result_expr) { + return nullptr; + } + + // create a struct extract with all remaining column names + for (idx_t i = struct_extract_start; i < col_ref.ColumnNames().size(); i++) { + result_expr = CreateStructExtract(std::move(result_expr), col_ref.ColumnNames()[i]); + } + + return result_expr; +} + +unique_ptr ColumnQualifier::QualifyColumnName(ColumnRefExpression &colref, ErrorData &error) { + auto qualified_colref = QualifyColumnNameInternal(colref, error); + if (!qualified_colref) { + if (alias_binder) { + return alias_binder->ResolveAlias(colref); + } + return nullptr; + } + if (!having_binder) { + return qualified_colref; + } + auto group_index = having_binder->TryBindGroup(*qualified_colref); + if (group_index.IsValid()) { + return qualified_colref; + } + if (having_binder->column_alias_binder.DoesColumnAliasExist(colref)) { + return nullptr; + } + return 
qualified_colref; +} + +unique_ptr ColumnQualifier::QualifyColumnNameInternal(ColumnRefExpression &col_ref, + ErrorData &error) { + if (!col_ref.IsQualified()) { + // Try binding as a lambda parameter. + auto lambda_ref = LambdaRefExpression::FindMatchingBinding(lambda_bindings, col_ref.GetColumnName()); + if (lambda_ref) { + return lambda_ref; + } + } + + idx_t column_parts = col_ref.ColumnNames().size(); + + // column names can have an arbitrary amount of dots + // here is how the resolution works: + if (column_parts == 1) { + // no dots (i.e. "part1") + // -> part1 refers to a column + // check if we can qualify the column name with the table name + auto qualified_col_ref = QualifyColumnName(col_ref, col_ref.GetColumnName(), error); + if (qualified_col_ref) { + // we could: return it + return qualified_col_ref; + } + // we could not! Try creating an implicit struct_pack + return CreateStructPack(col_ref); + } + + if (column_parts == 2) { + // one dot (i.e. "part1.part2") + // EITHER: + // -> part1 is a table, part2 is a column + // -> part1 is a column, part2 is a property of that column (i.e. struct_extract) + + // first check if part1 is a table, and part2 is a standard column name + auto binding = binder.GetMatchingBinding(col_ref.ColumnNames()[0], col_ref.ColumnNames()[1], error); + if (binding) { + // it is! return the column reference directly + return binder.bind_context.CreateColumnReference(binding->GetBindingAlias(), col_ref.GetColumnName()); + } + + // otherwise check if we can turn this into a struct extract + ErrorData other_error; + auto qualified_col_ref = QualifyColumnName(col_ref, col_ref.ColumnNames()[0], other_error); + if (qualified_col_ref) { + // we could: create a struct extract + return CreateStructExtract(std::move(qualified_col_ref), col_ref.ColumnNames()[1]); + } + // we could not! 
Try creating an implicit struct_pack + return CreateStructPack(col_ref); + } + + // three or more dots + return QualifyColumnNameWithManyDots(col_ref, error); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp index ec1a938ea..f1b602e8f 100644 --- a/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_aggregate_expression.cpp @@ -5,6 +5,7 @@ #include "duckdb/common/types/hash.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/function/function_serialization.hpp" +#include "duckdb/function/scalar/generic_common.hpp" namespace duckdb { @@ -14,13 +15,48 @@ BoundAggregateExpression::BoundAggregateExpression(BoundAggregateFunction functi AggregateType aggr_type) : Expression(ExpressionType::BOUND_AGGREGATE, ExpressionClass::BOUND_AGGREGATE, function.GetReturnType()), function(std::move(function)), children(std::move(children)), bind_info(std::move(bind_info)), - aggr_type(aggr_type), filter(std::move(filter)) { + aggr_type(aggr_type), state_export_mode(AggregateStateExportMode::NONE), filter(std::move(filter)) { D_ASSERT(!this->function.GetName().empty()); } string BoundAggregateExpression::ToString() const { - return FunctionExpression::ToString( - *this, string(), string(), function.GetName(), false, IsDistinct(), filter.get(), order_bys.get()); + auto distinct = IsDistinct(); + auto &function_name = function.GetName(); + + string result; + result += SQLIdentifier(function_name); + result += "("; + if (distinct) { + result += "DISTINCT "; + } + result += StringUtil::Join(children, children.size(), ", ", + [&](const unique_ptr &child) { return child->ToString(); }); + + // ordered aggregate + if (order_bys && !order_bys->orders.empty()) { + if (children.empty()) { + result += ") WITHIN GROUP ("; + } + result += " ORDER BY "; + for (idx_t i = 0; i < 
order_bys->orders.size(); i++) { + if (i > 0) { + result += ", "; + } + result += order_bys->orders[i].ToString(); + } + } + result += ")"; + + if (state_export_mode == AggregateStateExportMode::STATE_EXPORT) { + result += " EXPORT_STATE"; + } + + // filtered aggregate + if (filter) { + result += " FILTER (WHERE " + filter->ToString() + ")"; + } + + return result; } hash_t BoundAggregateExpression::Hash() const { @@ -52,7 +88,10 @@ bool BoundAggregateExpression::Equals(const BaseExpression &other_p) const { return false; } } - if (!FunctionData::Equals(bind_info.get(), other.bind_info.get())) { + if (state_export_mode != other.state_export_mode) { + return false; + } + if (!FunctionData::Equals(bind_info.get(), other.BindInfo().get())) { return false; } if (!BoundOrderModifier::Equals(order_bys, other.order_bys)) { @@ -78,6 +117,7 @@ unique_ptr BoundAggregateExpression::Copy() const { auto copy = make_uniq(function, std::move(new_children), std::move(new_filter), std::move(new_bind_info), aggr_type); copy->CopyProperties(*this); + copy->state_export_mode = state_export_mode; copy->order_bys = order_bys ? 
order_bys->Copy() : nullptr; return std::move(copy); } @@ -90,6 +130,7 @@ void BoundAggregateExpression::Serialize(Serializer &serializer) const { serializer.WriteProperty(203, "aggregate_type", aggr_type); serializer.WritePropertyWithDefault(204, "filter", filter, unique_ptr()); serializer.WritePropertyWithDefault(205, "order_bys", order_bys, unique_ptr()); + serializer.WritePropertyWithDefault(206, "state_export", state_export_mode, AggregateStateExportMode::NONE); } unique_ptr BoundAggregateExpression::Deserialize(Deserializer &deserializer) { @@ -102,12 +143,19 @@ unique_ptr BoundAggregateExpression::Deserialize(Deserializer &deser deserializer.ReadPropertyWithExplicitDefault>(204, "filter", unique_ptr()); auto result = make_uniq(std::move(entry.first), std::move(children), std::move(filter), std::move(entry.second), aggregate_type); - if (result->return_type != return_type) { + deserializer.ReadPropertyWithExplicitDefault(205, "order_bys", result->order_bys, unique_ptr()); + deserializer.ReadPropertyWithExplicitDefault(206, "state_export", result->state_export_mode, + AggregateStateExportMode::NONE); + if (result->state_export_mode == AggregateStateExportMode::STATE_EXPORT) { + if (!return_type.IsAggregateState()) { + throw SerializationException("Aggregate State export should return an aggregate state type"); + } + ExportAggregateFunction::SetStateExport(*result, std::move(return_type)); + } else if (result->return_type != return_type) { // return type mismatch - push a cast auto &context = deserializer.Get(); return BoundCastExpression::AddCastToType(context, std::move(result), return_type); } - deserializer.ReadPropertyWithExplicitDefault(205, "order_bys", result->order_bys, unique_ptr()); return std::move(result); } diff --git a/src/duckdb/src/planner/expression/bound_cast_expression.cpp b/src/duckdb/src/planner/expression/bound_cast_expression.cpp index 6cbb2b87d..e58942d36 100644 --- a/src/duckdb/src/planner/expression/bound_cast_expression.cpp +++ 
b/src/duckdb/src/planner/expression/bound_cast_expression.cpp @@ -43,7 +43,7 @@ static unique_ptr AddCastExpressionInternal(unique_ptr e } } auto result = make_uniq(std::move(expr), target_type, std::move(bound_cast), try_cast); - result->SetQueryLocation(result->child->GetQueryLocation()); + result->SetQueryLocation(result->Child().GetQueryLocation()); return std::move(result); } @@ -55,36 +55,36 @@ static unique_ptr AddCastToTypeInternal(unique_ptr expr, auto ¶meter = expr->Cast(); if (!target_type.IsValid()) { // invalidate the parameter - parameter.parameter_data->return_type = LogicalType::INVALID; + parameter.ParameterData()->return_type = LogicalType::INVALID; parameter.SetReturnType(target_type); return expr; } - if (parameter.parameter_data->return_type.id() == LogicalTypeId::INVALID) { + if (parameter.ParameterData()->return_type.id() == LogicalTypeId::INVALID) { // we don't know the type of this parameter parameter.SetReturnType(target_type); return expr; } - if (parameter.parameter_data->return_type.id() == LogicalTypeId::UNKNOWN) { + if (parameter.ParameterData()->return_type.id() == LogicalTypeId::UNKNOWN) { // prepared statement parameter cast - but there is no type, convert the type - parameter.parameter_data->return_type = target_type; + parameter.ParameterData()->return_type = target_type; parameter.SetReturnType(target_type); return expr; } // prepared statement parameter already has a type - if (parameter.parameter_data->return_type == target_type) { + if (parameter.ParameterData()->return_type == target_type) { // this type! we are done - parameter.SetReturnType(parameter.parameter_data->return_type); + parameter.SetReturnType(parameter.ParameterData()->return_type); return expr; } // If this occurrence's own return_type still matches parameter_data->return_type, the // parameter was pinned by an inner cast on this same occurrence (e.g. CAST(CAST($1 AS A) AS B)). // Add a regular cast on top instead of invalidating. 
- if (parameter.GetReturnType() == parameter.parameter_data->return_type) { + if (parameter.GetReturnType() == parameter.ParameterData()->return_type) { auto cast_function = cast_functions.GetCastFunction(parameter.GetReturnType(), target_type, get_input); return AddCastExpressionInternal(std::move(expr), target_type, std::move(cast_function), try_cast); } // invalidate the type - parameter.parameter_data->return_type = LogicalType::INVALID; + parameter.ParameterData()->return_type = LogicalType::INVALID; parameter.SetReturnType(target_type); return expr; } else if (expr->GetExpressionClass() == ExpressionClass::BOUND_DEFAULT) { @@ -241,7 +241,7 @@ unique_ptr BoundCastExpression::Copy() const { bool BoundCastExpression::CanThrow() const { const auto child_type = child->GetReturnType(); if (return_type.id() != child_type.id() && - LogicalType::ForceMaxLogicalType(return_type, child_type) == child_type.id()) { + LogicalType::DefaultForceMaxLogicalType(return_type, child_type) == child_type.id()) { return true; } // Casting VARCHAR to JSON involves parsing and validation that can throw on malformed input diff --git a/src/duckdb/src/planner/expression/bound_columnref_expression.cpp b/src/duckdb/src/planner/expression/bound_columnref_expression.cpp index c4ed36d83..081cba658 100644 --- a/src/duckdb/src/planner/expression/bound_columnref_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_columnref_expression.cpp @@ -5,14 +5,15 @@ namespace duckdb { -BoundColumnRefExpression::BoundColumnRefExpression(string alias_p, LogicalType type, ColumnBinding binding, idx_t depth) +BoundColumnRefExpression::BoundColumnRefExpression(Identifier alias_p, LogicalType type, ColumnBinding binding, + idx_t depth) : Expression(ExpressionType::BOUND_COLUMN_REF, ExpressionClass::BOUND_COLUMN_REF, std::move(type)), binding(binding), depth(depth) { this->alias = std::move(alias_p); } BoundColumnRefExpression::BoundColumnRefExpression(LogicalType type, ColumnBinding binding, idx_t 
depth) - : BoundColumnRefExpression(string(), std::move(type), binding, depth) { + : BoundColumnRefExpression(Identifier(), std::move(type), binding, depth) { } unique_ptr BoundColumnRefExpression::Copy() const { @@ -34,10 +35,10 @@ bool BoundColumnRefExpression::Equals(const BaseExpression &other_p) const { return other.binding == binding && other.depth == depth; } -string BoundColumnRefExpression::GetName() const { +Identifier BoundColumnRefExpression::GetName() const { #ifdef DEBUG if (DBConfigOptions::debug_print_bindings) { - return StringUtil::Format("%s (%s)", binding.ToString(), return_type.ToString()); + return Identifier(StringUtil::Format("%s (%s)", binding.ToString(), return_type.ToString())); } #endif return Expression::GetName(); @@ -50,7 +51,7 @@ string BoundColumnRefExpression::ToString() const { } #endif if (!alias.empty()) { - return alias; + return alias.GetIdentifierName(); } return binding.ToString(); } diff --git a/src/duckdb/src/planner/expression/bound_expanded_expression.cpp b/src/duckdb/src/planner/expression/bound_expanded_expression.cpp index 9e730fe42..d57bf4242 100644 --- a/src/duckdb/src/planner/expression/bound_expanded_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_expanded_expression.cpp @@ -4,7 +4,7 @@ namespace duckdb { BoundExpandedExpression::BoundExpandedExpression(vector> expanded_expressions_p) : Expression(ExpressionType::BOUND_EXPANDED, ExpressionClass::BOUND_EXPANDED, LogicalType::INTEGER), - expanded_expressions(std::move(expanded_expressions_p)) { + children(std::move(expanded_expressions_p)) { } string BoundExpandedExpression::ToString() const { diff --git a/src/duckdb/src/planner/expression/bound_function_expression.cpp b/src/duckdb/src/planner/expression/bound_function_expression.cpp index 7f7bb37fa..a33689666 100644 --- a/src/duckdb/src/planner/expression/bound_function_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_function_expression.cpp @@ -62,9 +62,33 @@ string 
BoundFunctionExpression::ToString() const { FunctionToStringInput input(function, bind_info.get(), children); return function.FunctionToString(input); } - return FunctionExpression::ToString(*this, string(), string(), - function.GetName(), is_operator); + auto &function_name = function.GetName().GetIdentifierName(); + + if (is_operator) { + // built-in operator + if (children.size() == 1) { + if (StringUtil::Contains(function_name, "__postfix")) { + return "((" + children[0]->ToString() + ")" + StringUtil::Replace(function_name, "__postfix", "") + ")"; + } + return function_name + "(" + children[0]->ToString() + ")"; + } + if (children.size() == 2) { + return StringUtil::Format("(%s %s %s)", children[0]->ToString(), function_name, children[1]->ToString()); + } + } + + // standard function call + string result; + result += SQLIdentifier(function_name); + result += "("; + + result += StringUtil::Join(children, children.size(), ", ", + [&](const unique_ptr &child) { return child->ToString(); }); + + result += ")"; + return result; } + bool BoundFunctionExpression::PropagatesNullValues() const { return function.GetNullHandling() == FunctionNullHandling::SPECIAL_HANDLING ? 
false : Expression::PropagatesNullValues(); @@ -111,13 +135,14 @@ void BoundFunctionExpression::Verify() const { } void BoundFunctionExpression::Serialize(Serializer &serializer) const { - if (!serializer.ShouldSerialize(8) && function.HasLegacySerializeCallback()) { + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && function.HasLegacySerializeCallback()) { // serialize legacy expression for backwards compatibility FunctionToStringInput input(function, bind_info.get(), children); auto legacy_expr = function.GetLegacySerializeCallback()(input); legacy_expr->Serialize(serializer); return; } + Expression::Serialize(serializer); serializer.WriteProperty(200, "return_type", return_type); serializer.WriteProperty(201, "children", children); diff --git a/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp b/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp index 43d1a8f58..2cb9e106a 100644 --- a/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_lambdaref_expression.cpp @@ -5,7 +5,7 @@ namespace duckdb { -BoundLambdaRefExpression::BoundLambdaRefExpression(string alias_p, LogicalType type, ColumnBinding binding, +BoundLambdaRefExpression::BoundLambdaRefExpression(Identifier alias_p, LogicalType type, ColumnBinding binding, idx_t lambda_idx, idx_t depth) : Expression(ExpressionType::BOUND_LAMBDA_REF, ExpressionClass::BOUND_LAMBDA_REF, std::move(type)), binding(binding), lambda_idx(lambda_idx), depth(depth) { @@ -14,7 +14,7 @@ BoundLambdaRefExpression::BoundLambdaRefExpression(string alias_p, LogicalType t BoundLambdaRefExpression::BoundLambdaRefExpression(LogicalType type, ColumnBinding binding, idx_t lambda_idx, idx_t depth) - : BoundLambdaRefExpression(string(), std::move(type), binding, lambda_idx, depth) { + : BoundLambdaRefExpression(Identifier(), std::move(type), binding, lambda_idx, depth) { } unique_ptr BoundLambdaRefExpression::Copy() const { @@ -39,7 +39,7 @@ bool 
BoundLambdaRefExpression::Equals(const BaseExpression &other_p) const { string BoundLambdaRefExpression::ToString() const { if (!alias.empty()) { - return alias; + return alias.GetIdentifierName(); } return "#[" + to_string(binding.table_index.index) + "." + to_string(binding.column_index) + "." + to_string(lambda_idx) + "]"; diff --git a/src/duckdb/src/planner/expression/bound_parameter_expression.cpp b/src/duckdb/src/planner/expression/bound_parameter_expression.cpp index 82ceca2d4..ab6ec7b47 100644 --- a/src/duckdb/src/planner/expression/bound_parameter_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_parameter_expression.cpp @@ -5,14 +5,14 @@ namespace duckdb { -BoundParameterExpression::BoundParameterExpression(const string &identifier) +BoundParameterExpression::BoundParameterExpression(const duckdb::Identifier &identifier) : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, LogicalType(LogicalTypeId::UNKNOWN)), identifier(identifier) { } -BoundParameterExpression::BoundParameterExpression(bound_parameter_map_t &global_parameter_set, string identifier, - LogicalType return_type, +BoundParameterExpression::BoundParameterExpression(bound_parameter_map_t &global_parameter_set, + duckdb::Identifier identifier, LogicalType return_type, shared_ptr parameter_data) : Expression(ExpressionType::VALUE_PARAMETER, ExpressionClass::BOUND_PARAMETER, std::move(return_type)), identifier(std::move(identifier)) { @@ -64,7 +64,7 @@ bool BoundParameterExpression::Equals(const BaseExpression &other_p) const { return false; } auto &other = other_p.Cast(); - return StringUtil::CIEquals(identifier, other.identifier); + return identifier == other.identifier; } hash_t BoundParameterExpression::Hash() const { diff --git a/src/duckdb/src/planner/expression/bound_reference_expression.cpp b/src/duckdb/src/planner/expression/bound_reference_expression.cpp index a163b7c35..24d8eb398 100644 --- 
a/src/duckdb/src/planner/expression/bound_reference_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_reference_expression.cpp @@ -6,12 +6,12 @@ namespace duckdb { -BoundReferenceExpression::BoundReferenceExpression(string alias, LogicalType type, idx_t index) +BoundReferenceExpression::BoundReferenceExpression(Identifier alias, LogicalType type, idx_t index) : Expression(ExpressionType::BOUND_REF, ExpressionClass::BOUND_REF, std::move(type)), index(index) { this->alias = std::move(alias); } BoundReferenceExpression::BoundReferenceExpression(LogicalType type, storage_t index) - : BoundReferenceExpression(string(), std::move(type), index) { + : BoundReferenceExpression(Identifier(), std::move(type), index) { } string BoundReferenceExpression::ToString() const { @@ -21,7 +21,7 @@ string BoundReferenceExpression::ToString() const { } #endif if (!alias.empty()) { - return alias; + return alias.GetIdentifierName(); } return "#" + to_string(index); } diff --git a/src/duckdb/src/planner/expression/bound_window_expression.cpp b/src/duckdb/src/planner/expression/bound_window_expression.cpp index 0f41864c3..c4f2ec687 100644 --- a/src/duckdb/src/planner/expression/bound_window_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_window_expression.cpp @@ -18,7 +18,8 @@ BoundWindowExpression::BoundWindowExpression(LogicalType return_type, unique_ptr } string BoundWindowExpression::ToString() const { - string function_name = aggregate.get() ? aggregate->GetName() : window->GetName(); + string function_name = + aggregate.get() ? 
aggregate->GetName().GetIdentifierName() : window->GetName().GetIdentifierName(); return WindowExpression::ToString(*this, string(), function_name); } @@ -201,7 +202,7 @@ unique_ptr BoundWindowExpression::Copy() const { vector> BoundWindowExpression::SerializedChildren(Serializer &serializer) const { vector> result; idx_t nargs = children.size(); - if (!serializer.ShouldSerialize(8) && window) { + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && window) { const auto &function_name = window->GetName(); if (function_name == "lead" || function_name == "lag") { nargs = 1; @@ -216,7 +217,7 @@ vector> BoundWindowExpression::SerializedChildren(Seriali } unique_ptr BoundWindowExpression::SerializedOffset(Serializer &serializer) const { - if (!serializer.ShouldSerialize(8) && children.size() > 1 && window) { + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && children.size() > 1 && window) { const auto &function_name = window->GetName(); if (function_name == "lead" || function_name == "lag") { return children[1]->Copy(); @@ -227,7 +228,7 @@ unique_ptr BoundWindowExpression::SerializedOffset(Serializer &seria } unique_ptr BoundWindowExpression::SerializedDefault(Serializer &serializer) const { - if (!serializer.ShouldSerialize(8) && children.size() > 2 && window) { + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && children.size() > 2 && window) { const auto &function_name = window->GetName(); if (function_name == "lead" || function_name == "lag") { return children[2]->Copy(); @@ -320,8 +321,9 @@ unique_ptr BoundWindowExpression::Deserialize(Deserializer &deserial auto &context = deserializer.Get(); auto binder = Binder::CreateBinder(context); - EntryLookupInfo lookup(CatalogType::SCALAR_FUNCTION_ENTRY, name); - auto entry = binder->GetCatalogEntry(SYSTEM_CATALOG, DEFAULT_SCHEMA, lookup, OnEntryNotFound::THROW_EXCEPTION); + EntryLookupInfo lookup(CatalogType::SCALAR_FUNCTION_ENTRY, Identifier(name)); + auto entry = 
binder->GetCatalogEntry(Identifier::SystemCatalog(), Identifier::DefaultSchema(), lookup, + OnEntryNotFound::THROW_EXCEPTION); auto &func = entry->Cast(); FunctionBinder function_binder(*binder); diff --git a/src/duckdb/src/planner/expression_binder.cpp b/src/duckdb/src/planner/expression_binder.cpp index e50760dc1..c15b0cc1f 100644 --- a/src/duckdb/src/planner/expression_binder.cpp +++ b/src/duckdb/src/planner/expression_binder.cpp @@ -16,25 +16,11 @@ void ExpressionBinder::SetCatalogLookupCallback(catalog_entry_callback_t callbac binder.SetCatalogLookupCallback(std::move(callback)); } -ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context, bool replace_binder) - : binder(binder), context(context) { +ExpressionBinder::ExpressionBinder(Binder &binder, ClientContext &context) : binder(binder), context(context) { InitializeStackCheck(); - if (replace_binder) { - stored_binder = &binder.GetActiveBinder(); - binder.SetActiveBinder(*this); - } else { - binder.PushExpressionBinder(*this); - } } ExpressionBinder::~ExpressionBinder() { - if (binder.HasActiveBinder()) { - if (stored_binder) { - binder.SetActiveBinder(*stored_binder); - } else { - binder.PopExpressionBinder(); - } - } } void ExpressionBinder::InitializeStackCheck() { @@ -84,7 +70,7 @@ BindResult ExpressionBinder::BindExpression(unique_ptr &expr, return BindExpression(expr_ref.Cast(), depth); case ExpressionClass::FUNCTION: { auto &function = expr_ref.Cast(); - if (IsUnnestFunction(function.function_name)) { + if (IsUnnestFunction(function.FunctionName())) { // special case, not in catalog return BindUnnest(function, depth, root_expression); } @@ -185,7 +171,7 @@ static bool CombineMissingColumns(ErrorData ¤t, ErrorData new_error) { context = QueryErrorContext(position); } // generate a new (combined) error - current = BinderException::ColumnNotFound(column_name, top_candidates, context); + current = BinderException::ColumnNotFound(Identifier(column_name), 
StringsToIdentifiers(top_candidates), context); return true; } @@ -206,8 +192,6 @@ BindResult ExpressionBinder::BindCorrelatedColumns(unique_ptr // make a copy of the set of binders, so we can restore it later auto binders = active_binders; auto bind_error = std::move(error_message); - // we already failed with the current binder - active_binders.pop_back(); idx_t depth = 1; while (!active_binders.empty()) { auto &next_binder = active_binders.back().get(); @@ -237,7 +221,7 @@ void ExpressionBinder::BindChild(unique_ptr &expr, idx_t depth void ExpressionBinder::ExtractCorrelatedExpressions(Binder &binder, Expression &expr) { if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &bound_colref = expr.Cast(); - if (bound_colref.depth > 0) { + if (bound_colref.Depth() > 0) { binder.AddCorrelatedColumn(CorrelatedColumnInfo(bound_colref)); } } @@ -341,7 +325,7 @@ unique_ptr ExpressionBinder::Bind(unique_ptr &expr // the binder has a specific target type: add a cast to that type result = BoundCastExpression::AddCastToType(context, std::move(result), target_type); } else { - if (!binder.can_contain_nulls) { + if (!binder.CanContainNulls()) { // SQL NULL type is only used internally in the binder // cast to INTEGER if we encounter it outside of the binder if (ContainsNullType(result->GetReturnType())) { @@ -407,7 +391,7 @@ BindResult ExpressionBinder::BindUnsupportedExpression(ParsedExpression &expr, i return BindResult(BinderException::Unsupported(expr, message)); } -bool ExpressionBinder::IsUnnestFunction(const string &function_name) { +bool ExpressionBinder::IsUnnestFunction(const Identifier &function_name) { return function_name == "unnest" || function_name == "unlist"; } @@ -416,8 +400,8 @@ bool ExpressionBinder::IsPotentialAlias(const ColumnRefExpression &colref) { if (!colref.IsQualified()) { return true; } - if (colref.column_names.size() == 2) { - return StringUtil::CIEquals(colref.GetTableName(), "alias"); + if (colref.ColumnNames().size() == 2) 
{ + return colref.GetTableName() == "alias"; } return false; } diff --git a/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp b/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp deleted file mode 100644 index b02cbf970..000000000 --- a/src/duckdb/src/planner/expression_binder/aggregate_binder.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" - -#include "duckdb/planner/binder.hpp" - -namespace duckdb { - -AggregateBinder::AggregateBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context, true) { -} - -BindResult AggregateBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { - auto &expr = *expr_ptr; - switch (expr.GetExpressionClass()) { - case ExpressionClass::WINDOW: - throw BinderException::Unsupported(expr, "aggregate function calls cannot contain window function calls"); - default: - return ExpressionBinder::BindExpression(expr_ptr, depth); - } -} - -string AggregateBinder::UnsupportedAggregateMessage() { - return "aggregate function calls cannot be nested"; -} -} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/alter_binder.cpp b/src/duckdb/src/planner/expression_binder/alter_binder.cpp index 128ba6541..db85268d0 100644 --- a/src/duckdb/src/planner/expression_binder/alter_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/alter_binder.cpp @@ -14,9 +14,9 @@ AlterBinder::AlterBinder(Binder &binder, ClientContext &context, TableCatalogEnt } BindResult AlterBinder::BindLambdaReference(LambdaRefExpression &expr, idx_t depth) { - D_ASSERT(lambda_bindings && expr.lambda_idx < lambda_bindings->size()); + D_ASSERT(lambda_bindings && expr.LambdaIndex() < lambda_bindings->size()); auto &lambda_ref = expr.Cast(); - return (*lambda_bindings)[expr.lambda_idx].Bind(lambda_ref, depth); + return (*lambda_bindings)[expr.LambdaIndex()].Bind(lambda_ref, depth); } BindResult AlterBinder::BindExpression(unique_ptr 
&expr_ptr, idx_t depth, bool root_expression) { @@ -46,14 +46,15 @@ BindResult AlterBinder::BindColumnReference(ColumnRefExpression &col_ref, idx_t } } - if (col_ref.column_names.size() > 1) { + if (col_ref.ColumnNames().size() > 1) { return BindQualifiedColumnName(col_ref, table.name); } - auto idx = table.GetColumnIndex(col_ref.column_names[0], true); + auto col_name = col_ref.ColumnNames()[0]; + auto idx = table.GetColumnIndex(col_name, true); if (!idx.IsValid()) { throw BinderException("Table does not contain column %s referenced in alter statement!", - col_ref.column_names[0]); + col_ref.ColumnNames()[0]); } if (table.GetColumn(idx).Generated()) { throw BinderException("Using generated columns in alter statement not supported"); diff --git a/src/duckdb/src/planner/expression_binder/base_select_binder.cpp b/src/duckdb/src/planner/expression_binder/base_select_binder.cpp index deceb7f12..bbab12bb4 100644 --- a/src/duckdb/src/planner/expression_binder/base_select_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/base_select_binder.cpp @@ -9,15 +9,13 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_case_expression.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" namespace duckdb { -BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, - BoundGroupInformation &info) - : ExpressionBinder(binder, context), inside_window(false), node(node), info(info) { +BaseSelectBinder::BaseSelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node) + : ExpressionBinder(binder, context), node(node) { } BindResult BaseSelectBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { @@ -29,6 +27,9 @@ BindResult 
BaseSelectBinder::BindExpression(unique_ptr &expr_p } switch (expr.GetExpressionClass()) { case ExpressionClass::COLUMN_REF: + if (inside_aggregate) { + return ExpressionBinder::BindExpression(expr_ptr, depth, root_expression); + } return BindColumnRef(expr_ptr, depth, root_expression); case ExpressionClass::DEFAULT: return BindResult(BinderException::Unsupported(expr, "SELECT clause cannot contain DEFAULT clause")); @@ -40,12 +41,16 @@ BindResult BaseSelectBinder::BindExpression(unique_ptr &expr_p } ProjectionIndex BaseSelectBinder::TryBindGroup(ParsedExpression &expr) { + if (inside_aggregate) { + return ProjectionIndex(); + } // first check the group alias map, if expr is a ColumnRefExpression + auto &alias_map = node.bind_state.group_alias_map; if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) { auto &colref = expr.Cast(); if (!colref.IsQualified()) { - auto alias_entry = info.alias_map.find(colref.column_names[0]); - if (alias_entry != info.alias_map.end()) { + auto alias_entry = alias_map.find(colref.ColumnNames()[0]); + if (alias_entry != alias_map.end()) { // found entry! 
return alias_entry->second; } @@ -53,12 +58,13 @@ ProjectionIndex BaseSelectBinder::TryBindGroup(ParsedExpression &expr) { } // no alias reference found // check the list of group columns for a match - auto entry = info.map.find(expr); - if (entry != info.map.end()) { + auto &group_map = node.bind_state.group_map; + auto entry = group_map.find(expr); + if (entry != group_map.end()) { return entry->second; } #ifdef DEBUG - for (auto map_entry : info.map) { + for (auto map_entry : group_map) { D_ASSERT(!map_entry.first.get().Equals(expr)); D_ASSERT(!expr.Equals(map_entry.first.get())); } @@ -75,13 +81,13 @@ BindResult BaseSelectBinder::BindGroupingFunction(OperatorExpression &op, idx_t return BindResult(BinderException(op, "GROUPING statement cannot be used without groups")); } vector group_indexes; - if (op.children.empty()) { + if (op.GetChildren().empty()) { // No arguments provided - use all group columns for (idx_t i = 0; i < node.groups.group_expressions.size(); i++) { group_indexes.push_back(ProjectionIndex(i)); } } else { - for (auto &child : op.children) { + for (auto &child : op.GetChildrenMutable()) { ExpressionBinder::QualifyColumnNames(binder, child); auto idx = TryBindGroup(*child); if (!idx.IsValid()) { @@ -96,18 +102,19 @@ BindResult BaseSelectBinder::BindGroupingFunction(OperatorExpression &op, idx_t } ProjectionIndex col_idx(node.grouping_functions.size()); node.grouping_functions.push_back(std::move(group_indexes)); - return BindResult(make_uniq(op.GetName(), LogicalType::BIGINT, + return BindResult(make_uniq(Identifier(op.GetName()), LogicalType::BIGINT, ColumnBinding(node.groupings_index, col_idx), depth)); } BindResult BaseSelectBinder::BindGroup(ParsedExpression &expr, idx_t depth, ProjectionIndex group_index) { - auto it = info.collated_groups.find(group_index); - if (it != info.collated_groups.end()) { + auto &collated_groups = node.bind_state.collated_groups; + auto it = collated_groups.find(group_index); + if (it != collated_groups.end()) { 
// This is an implicitly collated group, so we need to refer to the first() aggregate const auto &aggr_index = it->second; const auto return_type = node.aggregates[aggr_index]->GetReturnType(); auto uncollated_first_expression = make_uniq( - expr.GetName(), return_type, ColumnBinding(node.aggregate_index, aggr_index), depth); + Identifier(expr.GetName()), return_type, ColumnBinding(node.aggregate_index, aggr_index), depth); if (node.groups.grouping_sets.size() <= 1) { // if there are no more than two grouping sets, you can return the uncollated first expression. @@ -119,11 +126,11 @@ BindResult BaseSelectBinder::BindGroup(ParsedExpression &expr, idx_t depth, Proj // otherwise you can return the "first" of the uncollated expression. auto &group = node.groups.group_expressions[group_index]; auto collated_group_expression = make_uniq( - expr.GetName(), group->GetReturnType(), ColumnBinding(node.group_index, group_index), depth); + Identifier(expr.GetName()), group->GetReturnType(), ColumnBinding(node.group_index, group_index), depth); auto sql_null = make_uniq(Value(return_type)); auto when_expr = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); - when_expr->children.push_back(std::move(collated_group_expression)); + when_expr->GetChildrenMutable().push_back(std::move(collated_group_expression)); auto then_expr = make_uniq(Value(return_type)); auto else_expr = std::move(uncollated_first_expression); auto case_expr = @@ -131,7 +138,7 @@ BindResult BaseSelectBinder::BindGroup(ParsedExpression &expr, idx_t depth, Proj return BindResult(std::move(case_expr)); } else { auto &group = node.groups.group_expressions[group_index]; - return BindResult(make_uniq(expr.GetName(), group->GetReturnType(), + return BindResult(make_uniq(Identifier(expr.GetName()), group->GetReturnType(), ColumnBinding(node.group_index, group_index), depth)); } } diff --git a/src/duckdb/src/planner/expression_binder/check_binder.cpp 
b/src/duckdb/src/planner/expression_binder/check_binder.cpp index c6f1abb5a..be5ceace5 100644 --- a/src/duckdb/src/planner/expression_binder/check_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/check_binder.cpp @@ -6,7 +6,7 @@ namespace duckdb { -CheckBinder::CheckBinder(Binder &binder, ClientContext &context, string table_p, const ColumnList &columns, +CheckBinder::CheckBinder(Binder &binder, ClientContext &context, Identifier table_p, const ColumnList &columns, physical_index_set_t &bound_columns) : ExpressionBinder(binder, context), table(std::move(table_p)), columns(columns), bound_columns(bound_columns) { target_type = LogicalType::INTEGER; @@ -30,14 +30,14 @@ string CheckBinder::UnsupportedAggregateMessage() { return "aggregate functions are not allowed in check constraints"; } -BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref, const string &table_name) { +BindResult ExpressionBinder::BindQualifiedColumnName(ColumnRefExpression &colref, const Identifier &table_name) { idx_t struct_start = 0; - if (colref.column_names[0] == table_name) { + if (colref.ColumnNames()[0] == table_name) { struct_start++; } - auto result = make_uniq_base(colref.column_names.back()); - for (idx_t i = struct_start; i + 1 < colref.column_names.size(); i++) { - result = CreateStructExtract(std::move(result), colref.column_names[i]); + auto result = make_uniq_base(colref.ColumnNames().back()); + for (idx_t i = struct_start; i + 1 < colref.ColumnNames().size(); i++) { + result = CreateStructExtract(std::move(result), colref.ColumnNames()[i]); } return BindExpression(result, 0); } @@ -46,7 +46,7 @@ BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { if (!colref.IsQualified()) { if (lambda_bindings) { for (idx_t i = lambda_bindings->size(); i > 0; i--) { - if ((*lambda_bindings)[i - 1].HasMatchingBinding(colref.GetName())) { + if ((*lambda_bindings)[i - 1].HasMatchingBinding(Identifier(colref.GetName()))) { // FIXME: support 
lambdas in CHECK constraints // FIXME: like so: return (*lambda_bindings)[i - 1].Bind(colref, i, depth); // FIXME: and move this to LambdaRefExpression::FindMatchingBinding @@ -56,14 +56,14 @@ BindResult CheckBinder::BindCheckColumn(ColumnRefExpression &colref) { } } - if (colref.column_names.size() > 1) { + if (colref.ColumnNames().size() > 1) { return BindQualifiedColumnName(colref, table); } - if (!columns.ColumnExists(colref.column_names[0])) { + if (!columns.ColumnExists(colref.ColumnNames()[0])) { throw BinderException("Table does not contain column %s referenced in check constraint!", - colref.column_names[0]); + colref.ColumnNames()[0]); } - auto &col = columns.GetColumn(colref.column_names[0]); + auto &col = columns.GetColumn(colref.ColumnNames()[0]); if (col.Generated()) { auto bound_expression = col.GeneratedExpression().Copy(); return BindExpression(bound_expression, 0, false); diff --git a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp index e7c6c256d..77413c68b 100644 --- a/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/column_alias_binder.cpp @@ -11,6 +11,21 @@ namespace duckdb { ColumnAliasBinder::ColumnAliasBinder(SelectBindState &bind_state) : bind_state(bind_state), visited_select_indexes() { } +unique_ptr ColumnAliasBinder::ResolveAlias(ColumnRefExpression &colref) { + if (!ExpressionBinder::IsPotentialAlias(colref)) { + return nullptr; + } + + // We try to find the alias in the alias_map and return false, if no alias exists. 
+ auto alias_entry = bind_state.alias_map.find(colref.ColumnNames().back()); + if (alias_entry == bind_state.alias_map.end()) { + return nullptr; + } + + // We found an alias - bind it + return bind_state.BindAlias(alias_entry->second); +} + bool ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, unique_ptr &expr_ptr, idx_t depth, bool root_expression, BindResult &result) { D_ASSERT(expr_ptr->GetExpressionClass() == ExpressionClass::COLUMN_REF); @@ -21,7 +36,7 @@ bool ColumnAliasBinder::BindAlias(ExpressionBinder &enclosing_binder, unique_ptr } // We try to find the alias in the alias_map and return false, if no alias exists. - auto alias_entry = bind_state.alias_map.find(expr.column_names.back()); + auto alias_entry = bind_state.alias_map.find(expr.ColumnNames().back()); if (alias_entry == bind_state.alias_map.end()) { return false; } @@ -45,7 +60,7 @@ bool ColumnAliasBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) if (!ExpressionBinder::IsPotentialAlias(colref)) { return false; } - return bind_state.alias_map.find(colref.column_names.back()) != bind_state.alias_map.end(); + return bind_state.alias_map.find(colref.ColumnNames().back()) != bind_state.alias_map.end(); } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/group_binder.cpp b/src/duckdb/src/planner/expression_binder/group_binder.cpp index e6bcc98ba..6a8abf471 100644 --- a/src/duckdb/src/planner/expression_binder/group_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/group_binder.cpp @@ -10,26 +10,12 @@ namespace duckdb { -GroupBinder::GroupBinder(Binder &binder, ClientContext &context, SelectNode &node, TableIndex group_index, - SelectBindState &bind_state, case_insensitive_map_t &group_alias_map) - : ExpressionBinder(binder, context), node(node), bind_state(bind_state), group_alias_map(group_alias_map), - group_index(group_index) { +GroupBinder::GroupBinder(Binder &binder, ClientContext &context, TableIndex group_index, SelectBindState 
&bind_state) + : ExpressionBinder(binder, context), bind_state(bind_state), group_index(group_index) { } BindResult GroupBinder::BindExpression(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { auto &expr = *expr_ptr; - if (root_expression && depth == 0) { - switch (expr.GetExpressionClass()) { - case ExpressionClass::COLUMN_REF: - return BindColumnRef(expr.Cast(), expr_ptr); - case ExpressionClass::CONSTANT: - return BindConstant(expr.Cast()); - case ExpressionClass::PARAMETER: - throw ParameterNotAllowedException("Parameter not supported in GROUP BY clause"); - default: - break; - } - } switch (expr.GetExpressionClass()) { case ExpressionClass::DEFAULT: return BindUnsupportedExpression(expr, depth, "GROUP BY clause cannot contain DEFAULT clause"); @@ -44,40 +30,69 @@ string GroupBinder::UnsupportedAggregateMessage() { return "GROUP BY clause cannot contain aggregates!"; } -BindResult GroupBinder::BindSelectRef(idx_t entry) { - if (used_aliases.find(entry) != used_aliases.end()) { +void GroupBinder::ReplaceSelectRef(SelectNode &node, SelectBindState &bind_state, ProjectionIndex group_index, + unique_ptr &expr_ptr) { + auto &expr = *expr_ptr; + idx_t select_list_idx; + switch (expr.GetExpressionClass()) { + case ExpressionClass::COLUMN_REF: { + // root is a column - check if it refers to an alias in the select list + auto &colref = expr.Cast(); + if (!IsPotentialAlias(colref)) { + return; + } + auto &alias_name = colref.GetColumnName(); + auto entry = bind_state.alias_map.find(alias_name); + if (entry == bind_state.alias_map.end()) { + // no matching alias found + return; + } + select_list_idx = entry->second; + bind_state.group_alias_map[alias_name] = group_index; + break; + } + case ExpressionClass::CONSTANT: { + // root is a constant + auto &constant = expr.Cast(); + if (!constant.GetValue().type().IsIntegral()) { + // non-integral expression, we just leave the constant here. 
+ return; + } + // INTEGER constant: we use the integer as an index into the select list (e.g. GROUP BY 1) + auto index = (idx_t)constant.GetValue().GetValue(); + select_list_idx = index - 1; + break; + } + case ExpressionClass::PARAMETER: + throw ParameterNotAllowedException("Parameter not supported in GROUP BY clause"); + default: + return; + } + // this group entry directly refers to a select list expression + // we need to switch them around - the select list expression needs to become a group + // it will then instead refer to the group entry + + // check if we already have a reference to this group + auto &used_aliases = bind_state.used_group_aliases; + if (used_aliases.find(select_list_idx) != used_aliases.end()) { // the alias has already been bound to before! // this happens if we group on the same alias twice // e.g. GROUP BY k, k or GROUP BY 1, 1 // in this case, we can just replace the grouping with a constant since the second grouping has no effect // (the constant grouping will be optimized out later) - return BindResult(make_uniq(Value::INTEGER(42))); + expr_ptr = make_uniq(Value::INTEGER(42)); + return; } - if (entry >= node.select_list.size()) { + if (select_list_idx >= node.select_list.size()) { throw BinderException("GROUP BY term out of range - should be between 1 and %d", (int)node.select_list.size()); } - // we replace the root expression, also replace the unbound expression - unbound_expression = node.select_list[entry]->Copy(); - // move the expression that this refers to here and bind it - auto select_entry = std::move(node.select_list[entry]); - auto binding = Bind(select_entry, nullptr, false); + // move the expression that this refers to into the group expression + expr_ptr = std::move(node.select_list[select_list_idx]); // now replace the original expression in the select list with a reference to this group - group_alias_map[to_string(entry)] = bind_index; - node.select_list[entry] = make_uniq(to_string(entry)); + 
bind_state.group_alias_map[Identifier(to_string(select_list_idx))] = group_index; + node.select_list[select_list_idx] = make_uniq(Identifier(to_string(select_list_idx))); // insert into the set of used aliases - used_aliases.insert(entry); - return BindResult(std::move(binding)); -} - -BindResult GroupBinder::BindConstant(ConstantExpression &constant) { - // constant as root expression - if (!constant.GetValue().type().IsIntegral()) { - // non-integral expression, we just leave the constant here. - return ExpressionBinder::BindExpression(constant, 0); - } - // INTEGER constant: we use the integer as an index into the select list (e.g. GROUP BY 1) - auto index = (idx_t)constant.GetValue().GetValue(); - return BindSelectRef(index - 1); + used_aliases.insert(select_list_idx); } bool GroupBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, @@ -92,28 +107,10 @@ bool GroupBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t de // no matching alias found return false; } - if (!root_expression) { - result = BindResult(BinderException( - colref, - "Alias with name \"%s\" exists, but aliases cannot be used as part of an expression in the GROUP BY", - alias_name)); - return true; - } - result = BindResult(BindSelectRef(entry->second)); - if (!result.HasError()) { - group_alias_map[alias_name] = bind_index; - } + result = BindResult(BinderException( + colref, "Alias with name \"%s\" exists, but aliases cannot be used as part of an expression in the GROUP BY", + alias_name)); return true; } -BindResult GroupBinder::BindColumnRef(ColumnRefExpression &colref, unique_ptr &expr_ptr) { - // columns in GROUP BY clauses: - // FIRST refer to the original tables, and - // THEN if no match is found refer to aliases in the SELECT list - // THEN if no match is found, refer to outer queries - - // first try to bind to the base columns (original tables) - return ExpressionBinder::BindExpression(colref, 0, true, expr_ptr); -} - } // 
namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/having_binder.cpp b/src/duckdb/src/planner/expression_binder/having_binder.cpp index e792369a0..868fee5a5 100644 --- a/src/duckdb/src/planner/expression_binder/having_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/having_binder.cpp @@ -5,40 +5,21 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/column_qualifier.hpp" namespace duckdb { -HavingBinder::HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info, +HavingBinder::HavingBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, AggregateHandling aggregate_handling) - : BaseSelectBinder(binder, context, node, info), column_alias_binder(node.bind_state), + : BaseSelectBinder(binder, context, node), column_alias_binder(node.bind_state), aggregate_handling(aggregate_handling) { target_type = LogicalType(LogicalTypeId::BOOLEAN); } -bool HavingBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { - return column_alias_binder.DoesColumnAliasExist(colref); -} - BindResult HavingBinder::BindLambdaReference(LambdaRefExpression &expr, idx_t depth) { - D_ASSERT(lambda_bindings && expr.lambda_idx < lambda_bindings->size()); + D_ASSERT(lambda_bindings && expr.LambdaIndex() < lambda_bindings->size()); auto &lambda_ref = expr.Cast(); - return (*lambda_bindings)[expr.lambda_idx].Bind(lambda_ref, depth); -} - -unique_ptr HavingBinder::QualifyColumnName(ColumnRefExpression &colref, ErrorData &error) { - auto qualified_colref = ExpressionBinder::QualifyColumnName(colref, error); - if (!qualified_colref) { - return nullptr; - } - - auto group_index = TryBindGroup(*qualified_colref); - if (group_index.IsValid()) { - return qualified_colref; - } - if (column_alias_binder.DoesColumnAliasExist(colref)) { - return nullptr; - } - return qualified_colref; + return 
(*lambda_bindings)[expr.LambdaIndex()].Bind(lambda_ref, depth); } BindResult HavingBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { @@ -96,4 +77,10 @@ BindResult HavingBinder::BindWindowExpression(WindowExpression &expr, idx_t dept throw BinderException::Unsupported(expr, "HAVING clause cannot contain window functions!"); } +void ExpressionBinder::QualifyColumnNames(HavingBinder &having_binder, unique_ptr &expr) { + ColumnQualifier qualifier(having_binder.binder, having_binder.lambda_bindings, nullptr, having_binder); + vector lambda_params; + qualifier.QualifyColumnNames(expr, lambda_params); +} + } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/index_binder.cpp b/src/duckdb/src/planner/expression_binder/index_binder.cpp index d44bcfb09..6441e9392 100644 --- a/src/duckdb/src/planner/expression_binder/index_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/index_binder.cpp @@ -48,7 +48,7 @@ unique_ptr IndexBinder::BindIndex(const UnboundIndex &unbound_index) return index_type->create_instance(input); } -void IndexBinder::InitCreateIndexInfo(LogicalGet &get, CreateIndexInfo &info, const string &schema) { +void IndexBinder::InitCreateIndexInfo(LogicalGet &get, CreateIndexInfo &info, const Identifier &schema) { auto &column_ids = get.GetColumnIds(); for (auto &column_id : column_ids) { if (column_id.IsRowIdColumn()) { @@ -61,7 +61,7 @@ void IndexBinder::InitCreateIndexInfo(LogicalGet &get, CreateIndexInfo &info, co info.scan_types.emplace_back(LogicalType::ROW_TYPE); info.names = get.names; - info.schema = schema; + info.schema = Identifier(schema); info.catalog = get.GetTable()->catalog.GetName(); get.AddColumnId(COLUMN_IDENTIFIER_ROW_ID); } diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp index e52ecedf8..b810c1215 100644 --- a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp +++ 
b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -13,7 +13,7 @@ LateralBinder::LateralBinder(Binder &binder, ClientContext &context) : Expressio void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &bound_colref = expr.Cast(); - if (bound_colref.depth > 0) { + if (bound_colref.Depth() > 0) { // add the correlated column info CorrelatedColumnInfo info(bound_colref); if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { @@ -56,13 +56,13 @@ string LateralBinder::UnsupportedAggregateMessage() { static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, const CorrelatedColumns &correlated_columns) { // don't need to reduce this - if (expr.depth == 0) { + if (expr.Depth() == 0) { return; } for (auto &correlated : correlated_columns) { - if (correlated.binding == expr.binding) { - D_ASSERT(expr.depth > 1); - expr.depth--; + if (correlated.binding == expr.Binding()) { + D_ASSERT(expr.Depth() > 1); + expr.DepthMutable()--; break; } } @@ -105,9 +105,9 @@ class ExpressionDepthReducerRecursive : public LogicalOperatorVisitor { } static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, const CorrelatedColumns &correlated_columns) { - ReduceColumnDepth(expr.binder->correlated_columns, correlated_columns); + ReduceColumnDepth(expr.GetBinder()->correlated_columns, correlated_columns); ExpressionDepthReducerRecursive recursive(correlated_columns); - recursive.VisitOperator(*expr.subquery.plan); + recursive.VisitOperator(*expr.SubqueryMutable().plan); } private: diff --git a/src/duckdb/src/planner/expression_binder/order_binder.cpp b/src/duckdb/src/planner/expression_binder/order_binder.cpp index 4fcc9f14f..89e944f58 100644 --- a/src/duckdb/src/planner/expression_binder/order_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/order_binder.cpp @@ -27,9 +27,9 @@ OrderBinder::OrderBinder(vector> binders, 
SelectNode &node, Se } unique_ptr OrderBinder::CreateProjectionReference(ParsedExpression &expr, const idx_t index) { - string alias; + Identifier alias; if (extra_list && index < extra_list->size()) { - alias = extra_list->at(index)->ToString(); + alias = Identifier(extra_list->at(index)->ToString()); } else { if (!expr.GetAlias().empty()) { alias = expr.GetAlias(); @@ -80,7 +80,7 @@ optional_idx OrderBinder::TryGetProjectionReference(ParsedExpression &expr) cons break; } - string alias_name = colref.column_names.back(); + auto &alias_name = colref.ColumnNames().back(); // check the alias list auto entry = bind_state.alias_map.find(alias_name); if (entry != bind_state.alias_map.end()) { @@ -108,7 +108,7 @@ optional_idx OrderBinder::TryGetProjectionReference(ParsedExpression &expr) cons } case ExpressionClass::POSITIONAL_REFERENCE: { auto &posref = expr.Cast(); - return posref.index - 1; + return posref.Index() - 1; } default: break; @@ -130,7 +130,7 @@ unique_ptr OrderBinder::BindConstant(ParsedExpression &expr) { return nullptr; } child_list_t values; - values.push_back(make_pair("index", Value::UBIGINT(index.GetIndex()))); + values.emplace_back(make_pair("index", Value::UBIGINT(index.GetIndex()))); auto result = make_uniq(Value::STRUCT(std::move(values))); result->SetAlias(expr.GetAlias()); result->SetQueryLocation(expr.GetQueryLocation()); @@ -164,11 +164,11 @@ unique_ptr OrderBinder::Bind(unique_ptr expr) { } case ExpressionClass::COLLATE: { auto &collation = expr->Cast(); - auto collation_index = TryGetProjectionReference(*collation.child); + auto collation_index = TryGetProjectionReference(*collation.ChildMutable()); if (collation_index.IsValid()) { child_list_t values; - values.push_back(make_pair("index", Value::UBIGINT(collation_index.GetIndex()))); - values.push_back(make_pair("collation", Value(std::move(collation.collation)))); + values.emplace_back(make_pair("index", Value::UBIGINT(collation_index.GetIndex()))); + 
values.emplace_back(make_pair("collation", Value(collation.Collation()))); return make_uniq(Value::STRUCT(std::move(values))); } break; diff --git a/src/duckdb/src/planner/expression_binder/projection_binder.cpp b/src/duckdb/src/planner/expression_binder/projection_binder.cpp index 95b8032d8..b0d349322 100644 --- a/src/duckdb/src/planner/expression_binder/projection_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/projection_binder.cpp @@ -10,7 +10,12 @@ ProjectionBinder::ProjectionBinder(Binder &binder, ClientContext &context, Table } BindResult ProjectionBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { + if (in_child_projection) { + return ExpressionBinder::BindExpression(expr_ptr, depth); + } + in_child_projection = true; auto result = ExpressionBinder::BindExpression(expr_ptr, depth); + in_child_projection = false; if (result.HasError()) { return result; } @@ -34,6 +39,7 @@ BindResult ProjectionBinder::BindExpression(unique_ptr &expr_p case ExpressionClass::WINDOW: return BindUnsupportedExpression(expr, depth, clause + " cannot contain window functions!"); case ExpressionClass::COLUMN_REF: + case ExpressionClass::SUBQUERY: return BindColumnRef(expr_ptr, depth, root_expression); default: return ExpressionBinder::BindExpression(expr_ptr, depth); diff --git a/src/duckdb/src/planner/expression_binder/qualify_binder.cpp b/src/duckdb/src/planner/expression_binder/qualify_binder.cpp index 6b3c704c8..7b511227e 100644 --- a/src/duckdb/src/planner/expression_binder/qualify_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/qualify_binder.cpp @@ -2,22 +2,17 @@ #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression_binder/aggregate_binder.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/parser/expression/window_expression.hpp" namespace duckdb { 
-QualifyBinder::QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info) - : BaseSelectBinder(binder, context, node, info), column_alias_binder(node.bind_state) { +QualifyBinder::QualifyBinder(Binder &binder, ClientContext &context, BoundSelectNode &node) + : BaseSelectBinder(binder, context, node), column_alias_binder(node.bind_state) { target_type = LogicalType(LogicalTypeId::BOOLEAN); } -bool QualifyBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { - return column_alias_binder.DoesColumnAliasExist(colref); -} - BindResult QualifyBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { auto result = duckdb::BaseSelectBinder::BindColumnRef(expr_ptr, depth, root_expression); if (!result.HasError()) { diff --git a/src/duckdb/src/planner/expression_binder/select_binder.cpp b/src/duckdb/src/planner/expression_binder/select_binder.cpp index acbba314b..ae8d4b48d 100644 --- a/src/duckdb/src/planner/expression_binder/select_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/select_binder.cpp @@ -5,8 +5,8 @@ namespace duckdb { -SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node, BoundGroupInformation &info) - : BaseSelectBinder(binder, context, node, info) { +SelectBinder::SelectBinder(Binder &binder, ClientContext &context, BoundSelectNode &node) + : BaseSelectBinder(binder, context, node) { } bool SelectBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, @@ -16,7 +16,7 @@ bool SelectBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t d return false; } - const auto &alias_name = colref.column_names.back(); + const auto &alias_name = colref.ColumnNames().back(); auto entry = node.bind_state.alias_map.find(alias_name); if (entry == node.bind_state.alias_map.end()) { return false; @@ -27,7 +27,7 @@ bool SelectBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t d if 
(alias_index >= node.bound_column_count) { throw BinderException("Column \"%s\" referenced that exists in the SELECT clause - but this column " "cannot be referenced before it is defined", - colref.column_names.back()); + colref.ColumnNames().back()); } if (node.bind_state.AliasHasSubquery(alias_index)) { @@ -41,19 +41,4 @@ bool SelectBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t d return true; } -bool SelectBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { - // Using `back()` to support both qualified and unqualified aliasing - auto alias_name = colref.column_names.back(); - return node.bind_state.alias_map.find(alias_name) != node.bind_state.alias_map.end(); -} - -unique_ptr SelectBinder::GetSQLValueFunction(const string &column_name) { - auto alias_entry = node.bind_state.alias_map.find(column_name); - if (alias_entry != node.bind_state.alias_map.end()) { - // don't replace SQL value functions if they are in the alias map - return nullptr; - } - return ExpressionBinder::GetSQLValueFunction(column_name); -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index f653a129d..b9941868c 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -13,9 +13,9 @@ TableFunctionBinder::TableFunctionBinder(Binder &binder, ClientContext &context, } BindResult TableFunctionBinder::BindLambdaReference(LambdaRefExpression &expr, idx_t depth) { - D_ASSERT(lambda_bindings && expr.lambda_idx < lambda_bindings->size()); + D_ASSERT(lambda_bindings && expr.LambdaIndex() < lambda_bindings->size()); auto &lambda_ref = expr.Cast(); - return (*lambda_bindings)[expr.lambda_idx].Bind(lambda_ref, depth); + return (*lambda_bindings)[expr.LambdaIndex()].Bind(lambda_ref, depth); } BindResult 
TableFunctionBinder::BindColumnReference(unique_ptr &expr_ptr, idx_t depth, @@ -28,16 +28,16 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr return BindLambdaReference(lambda_ref->Cast(), depth); } - if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { + if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(Identifier(col_ref.GetName()))) { throw ParameterNotResolvedException(); } - } else if (col_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && binder.macro_binding && - binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { + } else if (col_ref.ColumnNames()[0].GetIdentifierName().find(DummyBinding::DUMMY_NAME) != string::npos && + binder.macro_binding && binder.macro_binding->HasMatchingBinding(Identifier(col_ref.GetName()))) { throw ParameterNotResolvedException(); } auto query_location = col_ref.GetQueryLocation(); - auto column_names = col_ref.column_names; + auto column_names = col_ref.ColumnNames(); auto result_name = StringUtil::Join(column_names, "."); if (!table_function_name.empty()) { // check if this is a lateral join column/parameter diff --git a/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp b/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp deleted file mode 100644 index 256f9c474..000000000 --- a/src/duckdb/src/planner/expression_binder/try_operator_binder.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include "duckdb/planner/expression_binder/try_operator_binder.hpp" - -#include "duckdb/planner/binder.hpp" - -namespace duckdb { - -TryOperatorBinder::TryOperatorBinder(Binder &binder, ClientContext &context) : ExpressionBinder(binder, context, true) { -} - -BindResult TryOperatorBinder::BindAggregate(FunctionExpression &expr, AggregateFunctionCatalogEntry &function, - idx_t depth) { - throw BinderException("aggregates are not allowed inside the TRY expression"); -} - -bool 
TryOperatorBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t depth, bool root_expression, - BindResult &result, unique_ptr &expr_ptr) { - if (!stored_binder) { - return false; - } - return stored_binder->TryResolveAliasReference(colref, depth, root_expression, result, expr_ptr); -} - -bool TryOperatorBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { - if (!stored_binder) { - return false; - } - return stored_binder->DoesColumnAliasExist(colref); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/where_binder.cpp b/src/duckdb/src/planner/expression_binder/where_binder.cpp index 21ff6c2cd..ee75c2f0b 100644 --- a/src/duckdb/src/planner/expression_binder/where_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/where_binder.cpp @@ -33,11 +33,4 @@ bool WhereBinder::TryResolveAliasReference(ColumnRefExpression &colref, idx_t de return column_alias_binder->BindAlias(*this, expr_ptr, depth, root_expression, result); } -bool WhereBinder::DoesColumnAliasExist(const ColumnRefExpression &colref) { - if (column_alias_binder) { - return column_alias_binder->DoesColumnAliasExist(colref); - } - return false; -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp index 0bc99af2f..b296580f7 100644 --- a/src/duckdb/src/planner/expression_iterator.cpp +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -20,14 +20,14 @@ void ExpressionIterator::EnumerateChildren(Expression &expr, switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_AGGREGATE: { auto &aggr_expr = expr.Cast(); - for (auto &child : aggr_expr.children) { + for (auto &child : aggr_expr.GetChildrenMutable()) { callback(child); } - if (aggr_expr.filter) { - callback(aggr_expr.filter); + if (aggr_expr.GetFilter()) { + callback(aggr_expr.GetFilterMutable()); } - if (aggr_expr.order_bys) { - for (auto &order : aggr_expr.order_bys->orders) { + if 
(aggr_expr.GetOrderBys()) { + for (auto &order : aggr_expr.GetOrderBysMutable()->orders) { callback(order.expression); } } @@ -35,74 +35,74 @@ void ExpressionIterator::EnumerateChildren(Expression &expr, } case ExpressionClass::BOUND_CASE: { auto &case_expr = expr.Cast(); - for (auto &case_check : case_expr.case_checks) { + for (auto &case_check : case_expr.CaseChecksMutable()) { callback(case_check.when_expr); callback(case_check.then_expr); } - callback(case_expr.else_expr); + callback(case_expr.ElseMutable()); break; } case ExpressionClass::BOUND_CAST: { auto &cast_expr = expr.Cast(); - callback(cast_expr.child); + callback(cast_expr.ChildMutable()); break; } case ExpressionClass::BOUND_CONJUNCTION: { auto &conj_expr = expr.Cast(); - for (auto &child : conj_expr.children) { + for (auto &child : conj_expr.GetChildrenMutable()) { callback(child); } break; } case ExpressionClass::BOUND_FUNCTION: { auto &func_expr = expr.Cast(); - for (auto &child : func_expr.children) { + for (auto &child : func_expr.GetChildrenMutable()) { callback(child); } break; } case ExpressionClass::BOUND_OPERATOR: { auto &op_expr = expr.Cast(); - for (auto &child : op_expr.children) { + for (auto &child : op_expr.GetChildrenMutable()) { callback(child); } break; } case ExpressionClass::BOUND_SUBQUERY: { auto &subquery_expr = expr.Cast(); - for (auto &child : subquery_expr.children) { + for (auto &child : subquery_expr.GetChildrenMutable()) { callback(child); } break; } case ExpressionClass::BOUND_WINDOW: { auto &window_expr = expr.Cast(); - for (auto &partition : window_expr.partitions) { + for (auto &partition : window_expr.PartitionsMutable()) { callback(partition); } - for (auto &order : window_expr.orders) { + for (auto &order : window_expr.OrderByMutable()) { callback(order.expression); } - for (auto &child : window_expr.children) { + for (auto &child : window_expr.GetChildrenMutable()) { callback(child); } - if (window_expr.filter_expr) { - callback(window_expr.filter_expr); + if 
(window_expr.FilterMutable()) { + callback(window_expr.FilterMutable()); } - if (window_expr.start_expr) { - callback(window_expr.start_expr); + if (window_expr.StartExprMutable()) { + callback(window_expr.StartExprMutable()); } - if (window_expr.end_expr) { - callback(window_expr.end_expr); + if (window_expr.EndExprMutable()) { + callback(window_expr.EndExprMutable()); } - for (auto &order : window_expr.arg_orders) { + for (auto &order : window_expr.ArgOrdersMutable()) { callback(order.expression); } break; } case ExpressionClass::BOUND_UNNEST: { auto &unnest_expr = expr.Cast(); - callback(unnest_expr.child); + callback(unnest_expr.ChildMutable()); break; } case ExpressionClass::BOUND_COLUMN_REF: diff --git a/src/duckdb/src/planner/filter/bloom_filter.cpp b/src/duckdb/src/planner/filter/bloom_filter.cpp index 531228884..0a8a4bc35 100644 --- a/src/duckdb/src/planner/filter/bloom_filter.cpp +++ b/src/duckdb/src/planner/filter/bloom_filter.cpp @@ -1,237 +1,32 @@ #include "duckdb/planner/filter/bloom_filter.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/planner/table_filter_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { -static constexpr idx_t MAX_NUM_SECTORS = (1ULL << 26); -static constexpr idx_t MIN_NUM_BITS_PER_KEY = 12; -static constexpr idx_t MIN_NUM_BITS = 512; -static constexpr idx_t LOG_SECTOR_SIZE = 6; // a sector is 64 bits, log2(64) = 6 -static constexpr idx_t SHIFT_MASK = 0x3F3F3F3F3F3F3F3F; // 6 bits for 64 positions -static constexpr idx_t N_BITS = 4; // the number of bits to set per hash - -void BloomFilter::Initialize(ClientContext &context_p, idx_t number_of_rows) { - BufferManager &buffer_manager = BufferManager::GetBufferManager(context_p); - - const idx_t min_bits = MaxValue(MIN_NUM_BITS, 
number_of_rows * MIN_NUM_BITS_PER_KEY); - num_sectors = MinValue(NextPowerOfTwo(min_bits) >> LOG_SECTOR_SIZE, MAX_NUM_SECTORS); - bitmask = num_sectors - 1; - - buf_ = buffer_manager.GetBufferAllocator().Allocate(64 + num_sectors * sizeof(uint64_t)); - // make sure blocks is a 64-byte aligned pointer, i.e., cache-line aligned - bf = reinterpret_cast((64ULL + reinterpret_cast(buf_.get())) & ~63ULL); - std::fill_n(bf, num_sectors, 0); - - initialized = true; -} - -inline uint64_t GetMask(const hash_t hash) { - const uint64_t shifts = hash & SHIFT_MASK; - const auto shifts_8 = reinterpret_cast(&shifts); - - uint64_t mask = 0; - - for (idx_t bit_idx = 8 - N_BITS; bit_idx < 8; bit_idx++) { - const uint8_t bit_pos = shifts_8[bit_idx]; - mask |= (1ULL << bit_pos); - } - - return mask; -} - -void BloomFilter::InsertHashes(const Vector &hashes_v, idx_t count) const { - auto hashes = FlatVector::GetData(hashes_v); - for (idx_t i = 0; i < count; i++) { - InsertOne(hashes[i]); - } -} - -idx_t BloomFilter::LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, const idx_t count) const { - D_ASSERT(hashes_v.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(hashes_v.GetType() == LogicalType::HASH); - - const auto hashes = FlatVector::GetData(hashes_v); - idx_t found_count = 0; - for (idx_t i = 0; i < count; i++) { - result_sel.set_index(found_count, i); - found_count += LookupOne(hashes[i]); - } - return found_count; -} - -inline void BloomFilter::InsertOne(const hash_t hash) const { - D_ASSERT(initialized); - const uint64_t bf_offset = hash & bitmask; - const uint64_t mask = GetMask(hash); - atomic &slot = *reinterpret_cast *>(&bf[bf_offset]); - - slot.fetch_or(mask, std::memory_order_relaxed); -} - -inline bool BloomFilter::LookupOne(const uint64_t hash) const { - D_ASSERT(initialized); - const uint64_t bf_offset = hash & bitmask; - const uint64_t mask = GetMask(hash); - atomic &slot = *reinterpret_cast *>(&bf[bf_offset]); - auto bf_entry = 
slot.load(std::memory_order_relaxed); - - return (bf_entry & mask) == mask; -} - -string BFTableFilter::ToString(const string &column_name) const { - return column_name + " IN BF(" + key_column_name + ")"; -} - -idx_t BFTableFilter::Filter(Vector &keys_v, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const { - state.PrepareSlicedKeys(keys_v, sel, approved_tuple_count); - VectorOperations::Hash(state.keys_sliced_v, state.hashes_v, approved_tuple_count); - - idx_t found_count; - if (state.hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - const auto constant_hash = *ConstantVector::GetData(state.hashes_v); - const bool found = this->filter.LookupOne(constant_hash); - found_count = found ? approved_tuple_count : 0; - } else { - state.hashes_v.Flatten(); - found_count = this->filter.LookupHashes(state.hashes_v, state.probe_sel, approved_tuple_count); - } - - // all the elements have been found, we don't need to translate anything - if (found_count == approved_tuple_count) { - return approved_tuple_count; - } - - if (sel.IsSet()) { - for (idx_t idx = 0; idx < found_count; idx++) { - const idx_t flat_sel_idx = state.probe_sel.get_index(idx); - const idx_t original_sel_idx = sel.get_index(flat_sel_idx); - sel.set_index(idx, original_sel_idx); - } - } else { - sel.Initialize(state.probe_sel); - } - - approved_tuple_count = found_count; - return approved_tuple_count; -} - -bool BFTableFilter::FilterValue(const Value &value) const { - const auto hash = value.Hash(); - return filter.LookupOne(hash); -} - -template -static FilterPropagateResult TemplatedCheckStatistics(const BloomFilter &bf, T min, T max) { - if (min > max) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; // Invalid stats - } - T range_typed; - idx_t range; - if (!TrySubtractOperator::Operation(max, min, range_typed) || !TryCast::Operation(range_typed, range) || - range >= DEFAULT_STANDARD_VECTOR_SIZE) { - return 
FilterPropagateResult::NO_PRUNING_POSSIBLE; // Overflow or too wide of a range - } - - T val = min; - idx_t hits = 0; - for (idx_t i = 0; i <= range; i++) { - hits += bf.LookupOne(Hash(val)); - val += i < range; // Avoids potential signed integer overflow on the last iteration - } - - if (hits == 0) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (hits == range + 1) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; -} - -FilterPropagateResult BFTableFilter::CheckStatistics(BaseStatistics &stats) const { - if (!TypeIsInteger(key_type.InternalType()) || !NumericStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - auto min_val = NumericStats::Min(stats); - auto max_val = NumericStats::Max(stats); - - // The filter was built with key_type (condition_type), but stats are in storage_type - if (stats.GetType() != key_type && (!min_val.DefaultTryCastAs(key_type) || !max_val.DefaultTryCastAs(key_type))) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - switch (key_type.InternalType()) { - case PhysicalType::UINT8: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::UINT16: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::UINT32: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::UINT64: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::UINT128: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::INT8: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::INT16: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - 
case PhysicalType::INT32: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::INT64: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - case PhysicalType::INT128: - return TemplatedCheckStatistics(filter, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe()); - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } -} - -bool BFTableFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return RefersToSameObject(filter, other.filter) && filters_null_values == other.filters_null_values && - key_column_name == other.key_column_name && key_type == other.key_type; -} -unique_ptr BFTableFilter::Copy() const { - return make_uniq(this->filter, this->filters_null_values, this->key_column_name, this->key_type); -} - -unique_ptr BFTableFilter::ToExpression(const Expression &column) const { - auto bound_constant = make_uniq(Value(true)); - return std::move(bound_constant); +unique_ptr LegacyBFTableFilter::ToExpression(const Expression &column) const { + auto function = BloomFilterScalarFun::GetFunction(column.GetReturnType()); + auto bind_data = + make_uniq(filter, filters_null_values, key_column_name, key_type, 0.0f, idx_t(0)); + vector> arguments; + arguments.push_back(column.Copy()); + return make_uniq(BoundScalarFunction(function), std::move(arguments), + std::move(bind_data)); } -void BFTableFilter::Serialize(Serializer &serializer) const { +void LegacyBFTableFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WriteProperty(200, "filters_null_values", filters_null_values); serializer.WriteProperty(201, "key_column_name", key_column_name); serializer.WriteProperty(202, "key_type", key_type); } -unique_ptr BFTableFilter::Deserialize(Deserializer &deserializer) { +unique_ptr LegacyBFTableFilter::Deserialize(Deserializer 
&deserializer) { auto filters_null_values = deserializer.ReadProperty(200, "filters_null_values"); auto key_column_name = deserializer.ReadProperty(201, "key_column_name"); auto key_type = deserializer.ReadProperty(202, "key_type"); - BloomFilter filter; - auto result = make_uniq(filter, filters_null_values, key_column_name, key_type); + auto result = make_uniq(nullptr, filters_null_values, key_column_name, key_type); return std::move(result); } diff --git a/src/duckdb/src/planner/filter/conjunction_filter.cpp b/src/duckdb/src/planner/filter/conjunction_filter.cpp index d5e98d47a..7e409816e 100644 --- a/src/duckdb/src/planner/filter/conjunction_filter.cpp +++ b/src/duckdb/src/planner/filter/conjunction_filter.cpp @@ -4,123 +4,26 @@ namespace duckdb { -ConjunctionOrFilter::ConjunctionOrFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_OR) { +LegacyConjunctionOrFilter::LegacyConjunctionOrFilter() + : LegacyConjunctionFilter(TableFilterType::LEGACY_CONJUNCTION_OR) { } -FilterPropagateResult ConjunctionOrFilter::CheckStatistics(BaseStatistics &stats) const { - // the OR filter is true if ANY of the children is true - D_ASSERT(!child_filters.empty()); - for (auto &filter : child_filters) { - auto prune_result = filter->CheckStatistics(stats); - if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - } - return FilterPropagateResult::FILTER_ALWAYS_FALSE; -} - -string ConjunctionOrFilter::ToString(const string &column_name) const { - string result; - for (idx_t i = 0; i < child_filters.size(); i++) { - if (i > 0) { - result += " OR "; - } - result += child_filters[i]->ToString(column_name); - } - return result; -} - -bool ConjunctionOrFilter::Equals(const TableFilter &other_p) const { - if (!ConjunctionFilter::Equals(other_p)) { - return false; - } - auto &other = 
other_p.Cast(); - if (other.child_filters.size() != child_filters.size()) { - return false; - } - for (idx_t i = 0; i < other.child_filters.size(); i++) { - if (!child_filters[i]->Equals(*other.child_filters[i])) { - return false; - } - } - return true; -} - -unique_ptr ConjunctionOrFilter::Copy() const { - auto result = make_uniq(); - for (auto &filter : child_filters) { - result->child_filters.push_back(filter->Copy()); - } - return std::move(result); -} - -unique_ptr ConjunctionOrFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyConjunctionOrFilter::ToExpression(const Expression &column) const { auto conjunction = make_uniq(ExpressionType::CONJUNCTION_OR); for (auto &filter : child_filters) { - conjunction->children.push_back(filter->ToExpression(column)); + conjunction->GetChildrenMutable().push_back(filter->ToExpression(column)); } return std::move(conjunction); } -ConjunctionAndFilter::ConjunctionAndFilter() : ConjunctionFilter(TableFilterType::CONJUNCTION_AND) { -} - -FilterPropagateResult ConjunctionAndFilter::CheckStatistics(BaseStatistics &stats) const { - // the AND filter is true if ALL of the children is true - D_ASSERT(!child_filters.empty()); - auto result = FilterPropagateResult::FILTER_ALWAYS_TRUE; - for (auto &filter : child_filters) { - auto prune_result = filter->CheckStatistics(stats); - if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } else if (prune_result != result) { - result = FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - } - return result; -} - -string ConjunctionAndFilter::ToString(const string &column_name) const { - string result; - for (idx_t i = 0; i < child_filters.size(); i++) { - if (i > 0) { - result += " AND "; - } - result += child_filters[i]->ToString(column_name); - } - return result; -} - -bool ConjunctionAndFilter::Equals(const TableFilter &other_p) const { - if (!ConjunctionFilter::Equals(other_p)) { - return false; - } - 
auto &other = other_p.Cast(); - if (other.child_filters.size() != child_filters.size()) { - return false; - } - for (idx_t i = 0; i < other.child_filters.size(); i++) { - if (!child_filters[i]->Equals(*other.child_filters[i])) { - return false; - } - } - return true; -} - -unique_ptr ConjunctionAndFilter::Copy() const { - auto result = make_uniq(); - for (auto &filter : child_filters) { - result->child_filters.push_back(filter->Copy()); - } - return std::move(result); +LegacyConjunctionAndFilter::LegacyConjunctionAndFilter() + : LegacyConjunctionFilter(TableFilterType::LEGACY_CONJUNCTION_AND) { } -unique_ptr ConjunctionAndFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyConjunctionAndFilter::ToExpression(const Expression &column) const { auto conjunction = make_uniq(ExpressionType::CONJUNCTION_AND); for (auto &filter : child_filters) { - conjunction->children.push_back(filter->ToExpression(column)); + conjunction->GetChildrenMutable().push_back(filter->ToExpression(column)); } return std::move(conjunction); } diff --git a/src/duckdb/src/planner/filter/constant_filter.cpp b/src/duckdb/src/planner/filter/constant_filter.cpp index 27e3ed7e2..c163ab8ab 100644 --- a/src/duckdb/src/planner/filter/constant_filter.cpp +++ b/src/duckdb/src/planner/filter/constant_filter.cpp @@ -7,38 +7,18 @@ namespace duckdb { -ConstantFilter::ConstantFilter(ExpressionType comparison_type_p, Value constant_p) - : TableFilter(TableFilterType::CONSTANT_COMPARISON), comparison_type(comparison_type_p), +LegacyConstantFilter::LegacyConstantFilter(ExpressionType comparison_type_p, Value constant_p) + : TableFilter(TableFilterType::LEGACY_CONSTANT_COMPARISON), comparison_type(comparison_type_p), constant(std::move(constant_p)) { if (constant.IsNull()) { - throw InternalException("ConstantFilter constant cannot be NULL - use IsNullFilter instead"); + throw InternalException("LegacyConstantFilter constant cannot be NULL - use IsNullFilter instead"); } } -FilterPropagateResult 
ConstantFilter::CheckStatistics(BaseStatistics &stats) const { - throw InternalException("ConstantFilter::CheckStatistics should not be called: ConstantFilters should be converted " - "to ExpressionFilters before statistics checking"); -} - -string ConstantFilter::ToString(const string &column_name) const { - throw InternalException("ConstantFilter::ToString should not be called: ConstantFilters should be converted to " - "ExpressionFilters before rendering"); -} - -unique_ptr ConstantFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyConstantFilter::ToExpression(const Expression &column) const { auto bound_constant = make_uniq(constant); auto result = BoundComparisonExpression::Create(comparison_type, column.Copy(), std::move(bound_constant)); return result; } -bool ConstantFilter::Equals(const TableFilter &other_p) const { - throw InternalException("ConstantFilter::Equals should not be called: ConstantFilters should be converted to " - "ExpressionFilters before equality checking"); -} - -unique_ptr ConstantFilter::Copy() const { - throw InternalException("ConstantFilter::Copy should not be called: ConstantFilters should be converted to " - "ExpressionFilters before copying"); -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/dynamic_filter.cpp b/src/duckdb/src/planner/filter/dynamic_filter.cpp index 85f1f444b..58235e9f6 100644 --- a/src/duckdb/src/planner/filter/dynamic_filter.cpp +++ b/src/duckdb/src/planner/filter/dynamic_filter.cpp @@ -1,50 +1,21 @@ #include "duckdb/planner/filter/dynamic_filter.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/common/enum_util.hpp" +#include 
"duckdb/common/value_operations/value_operations.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" namespace duckdb { -DynamicFilter::DynamicFilter() : TableFilter(TableFilterType::DYNAMIC_FILTER) { +LegacyDynamicFilter::LegacyDynamicFilter() : TableFilter(TableFilterType::LEGACY_DYNAMIC_FILTER) { } -DynamicFilter::DynamicFilter(shared_ptr filter_data_p) - : TableFilter(TableFilterType::DYNAMIC_FILTER), filter_data(std::move(filter_data_p)) { +LegacyDynamicFilter::LegacyDynamicFilter(shared_ptr filter_data_p) + : TableFilter(TableFilterType::LEGACY_DYNAMIC_FILTER), filter_data(std::move(filter_data_p)) { } -DynamicFilterData::DynamicFilterData(ExpressionType comparison_type_p, Value constant_p) - : comparison_type(comparison_type_p), constant(std::move(constant_p)) { -} - -unique_ptr DynamicFilterData::ToExpression(const Expression &column) const { - return BoundComparisonExpression::Create(comparison_type, column.Copy(), - make_uniq(constant)); -} - -FilterPropagateResult DynamicFilter::CheckStatistics(BaseStatistics &stats) const { - if (!filter_data) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - lock_guard l(filter_data->lock); - if (!filter_data->initialized) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - auto column = make_uniq(stats.GetType(), 0ULL); - auto expression = filter_data->ToExpression(*column); - return ExpressionFilter::CheckExpressionStatistics(*expression, stats); -} - -string DynamicFilter::ToString(const string &column_name) const { - if (filter_data) { - return "Dynamic Filter (" + column_name + ")"; - } else { - return "Empty Dynamic Filter (" + column_name + ")"; - } -} - -unique_ptr DynamicFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyDynamicFilter::ToExpression(const Expression &column) const { if (!filter_data || !filter_data->initialized) { auto bound_constant 
= make_uniq(Value(true)); return std::move(bound_constant); @@ -53,18 +24,6 @@ unique_ptr DynamicFilter::ToExpression(const Expression &column) con return filter_data->ToExpression(column); } -bool DynamicFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return other.filter_data.get() == filter_data.get(); -} - -unique_ptr DynamicFilter::Copy() const { - return make_uniq(filter_data); -} - void DynamicFilterData::SetValue(Value val) { if (val.IsNull()) { return; diff --git a/src/duckdb/src/planner/filter/expression_filter.cpp b/src/duckdb/src/planner/filter/expression_filter.cpp index b28a1641b..e30c4219b 100644 --- a/src/duckdb/src/planner/filter/expression_filter.cpp +++ b/src/duckdb/src/planner/filter/expression_filter.cpp @@ -1,13 +1,27 @@ #include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/filter/bloom_filter.hpp" +#include "duckdb/planner/filter/dynamic_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" +#include "duckdb/planner/filter/prefix_range_filter.hpp" +#include "duckdb/planner/filter/selectivity_optional_filter.hpp" 
+#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" #include "duckdb/storage/statistics/struct_stats.hpp" #include "duckdb/storage/statistics/string_stats.hpp" @@ -37,9 +51,9 @@ ExpressionFilter &ExpressionFilter::GetExpressionFilter(TableFilter &filter, con unique_ptr ExpressionFilter::CreateInExpression(unique_ptr column, vector values) { auto result = make_uniq(ExpressionType::COMPARE_IN, LogicalType::BOOLEAN); - result->children.push_back(std::move(column)); + result->GetChildrenMutable().push_back(std::move(column)); for (auto &value : values) { - result->children.push_back(make_uniq(std::move(value))); + result->GetChildrenMutable().push_back(make_uniq(std::move(value))); } return std::move(result); } @@ -49,19 +63,92 @@ unique_ptr ExpressionFilter::CreateNullCheckExpression(unique_ptr(expression_type, LogicalType::BOOLEAN); - result->children.push_back(std::move(column)); + result->GetChildrenMutable().push_back(std::move(column)); return std::move(result); } +static bool IsOptionalInternalFunction(const BoundFunctionExpression &func) { + return func.Function().GetName() == OptionalFilterScalarFun::NAME || + func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME; +} + +static bool IsOptionalExpressionInternal(const Expression &expr, bool recurse_through_and) { + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + return IsOptionalInternalFunction(expr.Cast()); + } + if (!recurse_through_and || expr.GetExpressionClass() != ExpressionClass::BOUND_CONJUNCTION || + expr.GetExpressionType() != ExpressionType::CONJUNCTION_AND) { + return false; + } + auto &conj = expr.Cast(); + if (conj.GetChildren().empty()) { + return false; + } + for (auto &child : conj.GetChildren()) { + if (!IsOptionalExpressionInternal(*child, 
true)) { + return false; + } + } + return true; +} + +static bool ContainsInternalTableFilterFunction(const Expression &expr) { + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (TableFilterFunctions::IsTableFilterFunction(func.Function())) { + return true; + } + if (func.Function().GetName() == OptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalTableFilterFunction(*data.child_filter_expr)) { + return true; + } + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalTableFilterFunction(*data.child_filter_expr)) { + return true; + } + } + } + bool found = false; + ExpressionIterator::EnumerateChildren(expr, [&](const Expression &child) { + if (!found) { + found = ContainsInternalTableFilterFunction(child); + } + }); + return found; +} + +static unique_ptr UnwrapOptionalFiltersForConstantEvaluation(const Expression &expr) { + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (func.Function().GetName() == OptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + return data.child_filter_expr ? UnwrapOptionalFiltersForConstantEvaluation(*data.child_filter_expr) + : make_uniq(Value::BOOLEAN(true)); + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + return data.child_filter_expr ? 
UnwrapOptionalFiltersForConstantEvaluation(*data.child_filter_expr) + : make_uniq(Value::BOOLEAN(true)); + } + } + auto result = expr.Copy(); + ExpressionIterator::EnumerateChildren( + *result, [&](unique_ptr &child) { child = UnwrapOptionalFiltersForConstantEvaluation(*child); }); + return result; +} + bool ExpressionFilter::EvaluateWithConstant(ClientContext &context, const Value &val) const { - ExpressionExecutor executor(context, *expr); + auto constant_eval_expr = UnwrapOptionalFiltersForConstantEvaluation(*expr); + ExpressionExecutor executor(context, *constant_eval_expr); return EvaluateWithConstant(executor, val); } bool ExpressionFilter::EvaluateWithConstant(ExpressionExecutor &executor, const Value &val) const { DataChunk input; input.data.emplace_back(val, count_t(1)); - input.SetCardinality(1); SelectionVector sel(1); @@ -69,7 +156,7 @@ bool ExpressionFilter::EvaluateWithConstant(ExpressionExecutor &executor, const return count > 0; } -FilterPropagateResult ExpressionFilter::CheckStatistics(BaseStatistics &stats) const { +FilterPropagateResult ExpressionFilter::CheckStatistics(const BaseStatistics &stats) const { if (stats.GetStatsType() == StatisticsType::GEOMETRY_STATS) { // Delegate to GeometryStats for geometry types return GeometryStats::CheckZonemap(stats, expr); @@ -77,6 +164,14 @@ FilterPropagateResult ExpressionFilter::CheckStatistics(BaseStatistics &stats) c return CheckExpressionStatistics(*expr, stats); } +FilterPropagateResult ExpressionFilter::CheckStatistics(ClientContext &context, const BaseStatistics &stats) const { + if (stats.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + // Delegate to GeometryStats for geometry types + return GeometryStats::CheckZonemap(stats, expr); + } + return CheckExpressionStatistics(&context, *expr, stats); +} + static FilterPropagateResult CheckZonemapAgainstConstants(const BaseStatistics &stats, ExpressionType comparison_type, array_ptr values) { D_ASSERT(values.size() > 0); @@ -104,22 +199,69 @@ 
static FilterPropagateResult CheckZonemapAgainstConstants(const BaseStatistics & } } -static optional_ptr TryGetFilterStats(const Expression &expr, const BaseStatistics &stats, +static optional_ptr TryGetFilterStats(optional_ptr context_p, + const Expression &expr, const BaseStatistics &stats, vector> &owned_stats) { switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_REF: return &stats; + case ExpressionClass::BOUND_CONSTANT: { + auto &constant = expr.Cast().GetValue(); + owned_stats.push_back(BaseStatistics::FromConstant(constant).ToUnique()); + return owned_stats.back().get(); + } + case ExpressionClass::BOUND_CAST: { + auto &cast_expr = expr.Cast(); + auto child_stats = TryGetFilterStats(context_p, cast_expr.Child(), stats, owned_stats); + if (!child_stats) { + return nullptr; + } + auto cast_stats = StatisticsPropagator::TryPropagateCast(*child_stats, cast_expr.Child().GetReturnType(), + cast_expr.GetReturnType()); + if (!cast_stats) { + return nullptr; + } + if (cast_expr.IsTryCast()) { + cast_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + owned_stats.push_back(std::move(cast_stats)); + return owned_stats.back().get(); + } case ExpressionClass::BOUND_FUNCTION: { auto &func = expr.Cast(); - idx_t child_idx; - if (!TryGetStructExtractChildIndex(func, child_idx) || func.children.empty()) { - return nullptr; + + if (!context_p) { + // Since statistics callback need context, so this path is the fallback + idx_t child_idx; + if (TryGetStructExtractChildIndex(func, child_idx) && !func.GetChildren().empty()) { + auto child_stats = TryGetFilterStats(context_p, *func.GetChildren()[0], stats, owned_stats); + if (!child_stats || child_stats->GetType().id() != LogicalTypeId::STRUCT) { + return nullptr; + } + owned_stats.push_back(StructStats::GetChildStats(*child_stats, child_idx).ToUnique()); + return owned_stats.back().get(); + } } - auto child_stats = TryGetFilterStats(*func.children[0], stats, owned_stats); - if (!child_stats || 
child_stats->GetType().id() != LogicalTypeId::STRUCT) { + + if (!context_p || !func.Function().HasStatisticsCallback()) { return nullptr; } - owned_stats.push_back(StructStats::GetChildStats(*child_stats, child_idx).ToUnique()); + + vector child_stats; + child_stats.reserve(func.GetChildren().size()); + for (auto &child_expr : func.GetChildren()) { + auto child_stat = TryGetFilterStats(context_p, *child_expr, stats, owned_stats); + if (!child_stat) { + return nullptr; + } + child_stats.push_back(child_stat->Copy()); + } + + // Use copy to avoid expression rewritten + auto expr_copy = func.Copy(); + auto &func_copy = expr_copy->Cast(); + FunctionStatisticsInput input(func_copy, func_copy.BindInfo(), child_stats, &expr_copy); + owned_stats.push_back(func.Function().GetStatisticsCallback()(*context_p, input)); return owned_stats.back().get(); } default: @@ -127,8 +269,9 @@ static optional_ptr TryGetFilterStats(const Expression &ex } } -static FilterPropagateResult CheckComparisonStatistics(const BoundFunctionExpression &comp_expr, - BaseStatistics &stats) { +static FilterPropagateResult CheckComparisonStatistics(optional_ptr context_p, + const BoundFunctionExpression &comp_expr, + const BaseStatistics &stats) { vector> owned_stats; optional_ptr filter_stats; optional_ptr constant_expr; @@ -136,10 +279,10 @@ static FilterPropagateResult CheckComparisonStatistics(const BoundFunctionExpres auto &left = BoundComparisonExpression::Left(comp_expr); auto &right = BoundComparisonExpression::Right(comp_expr); if (right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - filter_stats = TryGetFilterStats(left, stats, owned_stats); + filter_stats = TryGetFilterStats(context_p, left, stats, owned_stats); constant_expr = &right.Cast(); } else if (left.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - filter_stats = TryGetFilterStats(right, stats, owned_stats); + filter_stats = TryGetFilterStats(context_p, right, stats, owned_stats); constant_expr = &left.Cast(); 
comparison_type = FlipComparisonExpression(comparison_type); } else { @@ -148,38 +291,51 @@ static FilterPropagateResult CheckComparisonStatistics(const BoundFunctionExpres if (!filter_stats || !constant_expr) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - auto &constant = constant_expr->value; + auto &constant = constant_expr->GetValue(); if (constant.IsNull()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } if (!filter_stats->CanHaveNoNull()) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; + return comparison_type == ExpressionType::COMPARE_DISTINCT_FROM ? FilterPropagateResult::FILTER_ALWAYS_TRUE + : FilterPropagateResult::FILTER_ALWAYS_FALSE; } auto result = CheckZonemapAgainstConstants(*filter_stats, comparison_type, array_ptr(&constant, 1)); - if (result == FilterPropagateResult::FILTER_ALWAYS_TRUE && filter_stats->CanHaveNull()) { + if (filter_stats->CanHaveNull()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } return result; } -static FilterPropagateResult CheckFunctionStatistics(const BoundFunctionExpression &func_expr, BaseStatistics &stats) { +static FilterPropagateResult CheckFunctionStatistics(optional_ptr context_p, + const BoundFunctionExpression &func_expr, + const BaseStatistics &stats) { if (BoundComparisonExpression::IsComparison(func_expr.GetExpressionType())) { - return CheckComparisonStatistics(func_expr, stats); + return CheckComparisonStatistics(context_p, func_expr, stats); } - if (!func_expr.function.HasFilterPruneCallback()) { + if (!func_expr.Function().HasFilterPruneCallback()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - FunctionStatisticsPruneInput input(func_expr.bind_info.get(), stats); - return func_expr.function.GetFilterPruneCallback()(input); + vector> owned_stats; + auto filter_stats = &stats; + if (!func_expr.GetChildren().empty()) { + auto child_stats = TryGetFilterStats(context_p, *func_expr.GetChildren()[0], stats, owned_stats); + if (!child_stats) { + return 
FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + filter_stats = child_stats.get(); + } + FunctionStatisticsPruneInput input(func_expr.BindInfo().get(), *filter_stats); + return func_expr.Function().GetFilterPruneCallback()(input); } -static FilterPropagateResult CheckNullOperatorStatistics(const BoundOperatorExpression &op_expr, BaseStatistics &stats, - ExpressionType operator_type) { - if (op_expr.children.empty()) { +static FilterPropagateResult CheckNullOperatorStatistics(optional_ptr context_p, + const BoundOperatorExpression &op_expr, + const BaseStatistics &stats, ExpressionType operator_type) { + if (op_expr.GetChildren().empty()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } vector> owned_stats; - auto filter_stats = TryGetFilterStats(*op_expr.children[0], stats, owned_stats); + auto filter_stats = TryGetFilterStats(context_p, *op_expr.GetChildren()[0], stats, owned_stats); if (!filter_stats) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -201,22 +357,24 @@ static FilterPropagateResult CheckNullOperatorStatistics(const BoundOperatorExpr return FilterPropagateResult::NO_PRUNING_POSSIBLE; } -static FilterPropagateResult CheckInOperatorStatistics(const BoundOperatorExpression &op_expr, BaseStatistics &stats) { - if (op_expr.children.size() <= 1) { +static FilterPropagateResult CheckInOperatorStatistics(optional_ptr context_p, + const BoundOperatorExpression &op_expr, + const BaseStatistics &stats) { + if (op_expr.GetChildren().size() <= 1) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } vector> owned_stats; - auto filter_stats = TryGetFilterStats(*op_expr.children[0], stats, owned_stats); + auto filter_stats = TryGetFilterStats(context_p, *op_expr.GetChildren()[0], stats, owned_stats); if (!filter_stats) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } vector values; - values.reserve(op_expr.children.size() - 1); - for (idx_t i = 1; i < op_expr.children.size(); i++) { - if (op_expr.children[i]->GetExpressionType() != 
ExpressionType::VALUE_CONSTANT) { + values.reserve(op_expr.GetChildren().size() - 1); + for (idx_t i = 1; i < op_expr.GetChildren().size(); i++) { + if (op_expr.GetChildren()[i]->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - auto &value = op_expr.children[i]->Cast().value; + auto &value = op_expr.GetChildren()[i]->Cast().GetValue(); if (!value.IsNull()) { values.push_back(value); } @@ -235,24 +393,28 @@ static FilterPropagateResult CheckInOperatorStatistics(const BoundOperatorExpres return result; } -static FilterPropagateResult CheckOperatorStatistics(const BoundOperatorExpression &op_expr, BaseStatistics &stats) { +static FilterPropagateResult CheckOperatorStatistics(optional_ptr context_p, + const BoundOperatorExpression &op_expr, + const BaseStatistics &stats) { switch (op_expr.GetExpressionType()) { case ExpressionType::OPERATOR_IS_NULL: case ExpressionType::OPERATOR_IS_NOT_NULL: - return CheckNullOperatorStatistics(op_expr, stats, op_expr.GetExpressionType()); + return CheckNullOperatorStatistics(context_p, op_expr, stats, op_expr.GetExpressionType()); case ExpressionType::COMPARE_IN: - return CheckInOperatorStatistics(op_expr, stats); + return CheckInOperatorStatistics(context_p, op_expr, stats); default: return FilterPropagateResult::NO_PRUNING_POSSIBLE; } } -static FilterPropagateResult CheckConjunctionStatistics(const BoundConjunctionExpression &conj, BaseStatistics &stats) { +static FilterPropagateResult CheckConjunctionStatistics(optional_ptr context_p, + const BoundConjunctionExpression &conj, + const BaseStatistics &stats) { switch (conj.GetExpressionType()) { case ExpressionType::CONJUNCTION_AND: { auto result = FilterPropagateResult::FILTER_ALWAYS_TRUE; - for (auto &child : conj.children) { - auto prune_result = ExpressionFilter::CheckExpressionStatistics(*child, stats); + for (auto &child : conj.GetChildren()) { + auto prune_result = 
ExpressionFilter::CheckExpressionStatistics(context_p, *child, stats); if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { return FilterPropagateResult::FILTER_ALWAYS_FALSE; } @@ -263,9 +425,9 @@ static FilterPropagateResult CheckConjunctionStatistics(const BoundConjunctionEx return result; } case ExpressionType::CONJUNCTION_OR: - D_ASSERT(!conj.children.empty()); - for (auto &child : conj.children) { - auto prune_result = ExpressionFilter::CheckExpressionStatistics(*child, stats); + D_ASSERT(!conj.GetChildren().empty()); + for (auto &child : conj.GetChildren()) { + auto prune_result = ExpressionFilter::CheckExpressionStatistics(context_p, *child, stats); if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -279,19 +441,98 @@ static FilterPropagateResult CheckConjunctionStatistics(const BoundConjunctionEx } } -FilterPropagateResult ExpressionFilter::CheckExpressionStatistics(const Expression &expr, BaseStatistics &stats) { +FilterPropagateResult ExpressionFilter::CheckExpressionStatistics(const Expression &expr, const BaseStatistics &stats) { + return CheckExpressionStatistics(nullptr, expr, stats); +} + +FilterPropagateResult ExpressionFilter::CheckExpressionStatistics(optional_ptr context_p, + const Expression &expr, const BaseStatistics &stats) { switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_FUNCTION: - return CheckFunctionStatistics(expr.Cast(), stats); + return CheckFunctionStatistics(context_p, expr.Cast(), stats); case ExpressionClass::BOUND_OPERATOR: - return CheckOperatorStatistics(expr.Cast(), stats); + return CheckOperatorStatistics(context_p, expr.Cast(), stats); case ExpressionClass::BOUND_CONJUNCTION: - return CheckConjunctionStatistics(expr.Cast(), stats); + return CheckConjunctionStatistics(context_p, expr.Cast(), stats); default: return FilterPropagateResult::NO_PRUNING_POSSIBLE; } } +bool ExpressionFilter::ContainsInternalFunction(const Expression 
&expr, const string &func_name) { + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (func.Function().GetName() == func_name) { + return true; + } + if (func.Function().GetName() == OptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalFunction(*data.child_filter_expr, func_name)) { + return true; + } + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalFunction(*data.child_filter_expr, func_name)) { + return true; + } + } + } + bool found = false; + ExpressionIterator::EnumerateChildren(expr, [&](const Expression &child) { + if (!found) { + found = ContainsInternalFunction(child, func_name); + } + }); + return found; +} + +bool ExpressionFilter::IsOptionalExpression(const Expression &expr) { + return IsOptionalExpressionInternal(expr, true); +} + +bool ExpressionFilter::IsRootOptionalExpression(const Expression &expr) { + return IsOptionalExpressionInternal(expr, false); +} + +bool ExpressionFilter::IsOptionalFilter(const TableFilter &filter) { + auto &expr_filter = GetExpressionFilter(filter, "ExpressionFilter::IsOptionalFilter"); + return IsOptionalExpression(*expr_filter.expr); +} + +bool ExpressionFilter::IsRootOptionalFilter(const TableFilter &filter) { + auto &expr_filter = GetExpressionFilter(filter, "ExpressionFilter::IsRootOptionalFilter"); + return IsRootOptionalExpression(*expr_filter.expr); +} + +static shared_ptr TryGetRootDynamicFilterData(const Expression &expr) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return nullptr; + } + auto &func = expr.Cast(); + if (func.Function().GetName() != DynamicFilterScalarFun::NAME || !func.BindInfo()) { + return nullptr; + } + return func.BindInfo()->Cast().filter_data; +} + +shared_ptr 
ExpressionFilter::GetRootOptionalDynamicFilterData(const TableFilter &filter) { + auto &expr_filter = GetExpressionFilter(filter, "ExpressionFilter::GetRootOptionalDynamicFilterData"); + if (expr_filter.expr->GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return nullptr; + } + auto &func = expr_filter.expr->Cast(); + if (func.Function().GetName() == OptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + return data.child_filter_expr ? TryGetRootDynamicFilterData(*data.child_filter_expr) : nullptr; + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + return data.child_filter_expr ? TryGetRootDynamicFilterData(*data.child_filter_expr) : nullptr; + } + return nullptr; +} + unique_ptr ExpressionFilter::FromTableFilter(const TableFilter &filter, const LogicalType &col_type) { if (filter.filter_type == TableFilterType::EXPRESSION_FILTER) { auto &expr_filter = filter.Cast(); @@ -306,10 +547,68 @@ unique_ptr ExpressionFilter::FromTableFilter(const TableFilter return make_uniq(std::move(expr)); } +string ExpressionFilter::InternalFunctionToString(const BoundFunctionExpression &func_expr, const string &column_name) { + auto &func_name = func_expr.Function().GetName(); + if (func_name == BloomFilterScalarFun::NAME) { + auto &data = func_expr.BindInfo()->Cast(); + return BloomFilterScalarFun::ToString(column_name, data.key_column_name); + } else if (func_name == PerfectHashJoinScalarFun::NAME) { + auto &data = func_expr.BindInfo()->Cast(); + return PerfectHashJoinScalarFun::ToString(column_name, data.key_column_name); + } else if (func_name == PrefixRangeScalarFun::NAME) { + auto &data = func_expr.BindInfo()->Cast(); + return PrefixRangeScalarFun::ToString(column_name, data.key_column_name); + } else if (func_name == DynamicFilterScalarFun::NAME) { + const auto has_filter_data = + func_expr.BindInfo() && 
func_expr.BindInfo()->Cast().filter_data; + return DynamicFilterScalarFun::ToString(column_name, has_filter_data); + } else if (func_name == OptionalFilterScalarFun::NAME) { + string child_filter_string; + if (func_expr.BindInfo()) { + auto &data = func_expr.BindInfo()->Cast(); + if (data.child_filter_expr) { + child_filter_string = ExpressionToFriendlyString(*data.child_filter_expr, column_name); + } + } + return OptionalFilterScalarFun::ToString(child_filter_string); + } else if (func_name == SelectivityOptionalFilterScalarFun::NAME) { + string child_filter_string; + if (func_expr.BindInfo()) { + auto &data = func_expr.BindInfo()->Cast(); + if (data.child_filter_expr) { + child_filter_string = ExpressionToFriendlyString(*data.child_filter_expr, column_name); + } + } + return SelectivityOptionalFilterScalarFun::ToString(child_filter_string); + } + return string(); +} + string ExpressionFilter::ExpressionToFriendlyString(const Expression &expression, const string &column_name) { + if (expression.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func_expr = expression.Cast(); + auto result = InternalFunctionToString(func_expr, column_name); + if (!result.empty()) { + return result; + } + } + if (expression.GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { + auto &conj = expression.Cast(); + if (ContainsInternalTableFilterFunction(expression)) { + string result = "("; + for (idx_t i = 0; i < conj.GetChildren().size(); i++) { + if (i > 0) { + result += conj.GetExpressionType() == ExpressionType::CONJUNCTION_AND ? 
" AND " : " OR "; + } + result += ExpressionToFriendlyString(*conj.GetChildren()[i], column_name); + } + result += ")"; + return result; + } + } // Default: use standard expression ToString with column name substitution auto expr_copy = expression.Copy(); - auto name_expr = make_uniq(column_name, LogicalType::INVALID, 0ULL); + auto name_expr = make_uniq(Identifier(column_name), LogicalType::INVALID, 0ULL); ReplaceExpressionRecursive(expr_copy, *name_expr, ExpressionType::BOUND_REF); return expr_copy->ToString(); } @@ -318,6 +617,10 @@ string ExpressionFilter::ToString(const string &column_name) const { return ExpressionToFriendlyString(*expr, column_name); } +string ExpressionFilter::DebugToString() const { + return ExpressionToFriendlyString(*expr, "c1"); +} + void ExpressionFilter::ReplaceExpressionRecursive(unique_ptr &expr, const Expression &column, ExpressionType replace_type) { if (expr->GetExpressionType() == replace_type) { @@ -334,15 +637,12 @@ unique_ptr ExpressionFilter::ToExpression(const Expression &column) return expr_copy; } -bool ExpressionFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } +bool ExpressionFilter::Equals(const ExpressionFilter &other_p) const { auto &other = other_p.Cast(); return other.expr->Equals(*expr); } -unique_ptr ExpressionFilter::Copy() const { +unique_ptr ExpressionFilter::Copy() const { return make_uniq(expr->Copy()); } diff --git a/src/duckdb/src/planner/filter/in_filter.cpp b/src/duckdb/src/planner/filter/in_filter.cpp index 9af7c25d0..f552e364b 100644 --- a/src/duckdb/src/planner/filter/in_filter.cpp +++ b/src/duckdb/src/planner/filter/in_filter.cpp @@ -6,44 +6,25 @@ namespace duckdb { -InFilter::InFilter(vector values_p) : TableFilter(TableFilterType::IN_FILTER), values(std::move(values_p)) { +LegacyInFilter::LegacyInFilter(vector values_p) + : TableFilter(TableFilterType::LEGACY_IN_FILTER), values(std::move(values_p)) { for (auto &val : values) { if 
(val.IsNull()) { - throw InternalException("InFilter constant cannot be NULL - use IsNullFilter instead"); + throw InternalException("LegacyInFilter constant cannot be NULL - use IsNullFilter instead"); } } for (idx_t i = 1; i < values.size(); i++) { if (values[0].type() != values[i].type()) { - throw InternalException("InFilter constants must all have the same type"); + throw InternalException("LegacyInFilter constants must all have the same type"); } } if (values.empty()) { - throw InternalException("InFilter constants cannot be empty"); + throw InternalException("LegacyInFilter constants cannot be empty"); } } -FilterPropagateResult InFilter::CheckStatistics(BaseStatistics &stats) const { - throw InternalException("InFilter::CheckStatistics should not be called: InFilters should be converted to " - "ExpressionFilters before statistics checking"); -} - -string InFilter::ToString(const string &column_name) const { - throw InternalException("InFilter::ToString should not be called: InFilters should be converted to " - "ExpressionFilters before rendering"); -} - -unique_ptr InFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyInFilter::ToExpression(const Expression &column) const { return ExpressionFilter::CreateInExpression(column.Copy(), values); } -bool InFilter::Equals(const TableFilter &other_p) const { - throw InternalException("InFilter::Equals should not be called: InFilters should be converted to " - "ExpressionFilters before equality checking"); -} - -unique_ptr InFilter::Copy() const { - throw InternalException("InFilter::Copy should not be called: InFilters should be converted to " - "ExpressionFilters before copying"); -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/null_filter.cpp b/src/duckdb/src/planner/filter/null_filter.cpp index e4ccd9f88..043d0abd9 100644 --- a/src/duckdb/src/planner/filter/null_filter.cpp +++ b/src/duckdb/src/planner/filter/null_filter.cpp @@ -4,57 +4,17 @@ namespace duckdb { 
-IsNullFilter::IsNullFilter() : TableFilter(TableFilterType::IS_NULL) { +LegacyIsNullFilter::LegacyIsNullFilter() : TableFilter(TableFilterType::LEGACY_IS_NULL) { } -FilterPropagateResult IsNullFilter::CheckStatistics(BaseStatistics &stats) const { - throw InternalException("IsNullFilter::CheckStatistics should not be called: IsNullFilters should be converted " - "to ExpressionFilters before statistics checking"); -} - -string IsNullFilter::ToString(const string &column_name) const { - throw InternalException("IsNullFilter::ToString should not be called: IsNullFilters should be converted to " - "ExpressionFilters before rendering"); -} - -bool IsNullFilter::Equals(const TableFilter &other_p) const { - throw InternalException("IsNullFilter::Equals should not be called: IsNullFilters should be converted to " - "ExpressionFilters before equality checking"); -} - -unique_ptr IsNullFilter::Copy() const { - throw InternalException("IsNullFilter::Copy should not be called: IsNullFilters should be converted to " - "ExpressionFilters before copying"); -} - -unique_ptr IsNullFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyIsNullFilter::ToExpression(const Expression &column) const { return ExpressionFilter::CreateNullCheckExpression(column.Copy(), ExpressionType::OPERATOR_IS_NULL); } -IsNotNullFilter::IsNotNullFilter() : TableFilter(TableFilterType::IS_NOT_NULL) { -} - -FilterPropagateResult IsNotNullFilter::CheckStatistics(BaseStatistics &stats) const { - throw InternalException("IsNotNullFilter::CheckStatistics should not be called: IsNotNullFilters should be " - "converted to ExpressionFilters before statistics checking"); -} - -string IsNotNullFilter::ToString(const string &column_name) const { - throw InternalException("IsNotNullFilter::ToString should not be called: IsNotNullFilters should be converted " - "to ExpressionFilters before rendering"); -} - -bool IsNotNullFilter::Equals(const TableFilter &other_p) const { - throw 
InternalException("IsNotNullFilter::Equals should not be called: IsNotNullFilters should be converted to " - "ExpressionFilters before equality checking"); -} - -unique_ptr IsNotNullFilter::Copy() const { - throw InternalException("IsNotNullFilter::Copy should not be called: IsNotNullFilters should be converted to " - "ExpressionFilters before copying"); +LegacyIsNotNullFilter::LegacyIsNotNullFilter() : TableFilter(TableFilterType::LEGACY_IS_NOT_NULL) { } -unique_ptr IsNotNullFilter::ToExpression(const Expression &column) const { +unique_ptr LegacyIsNotNullFilter::ToExpression(const Expression &column) const { return ExpressionFilter::CreateNullCheckExpression(column.Copy(), ExpressionType::OPERATOR_IS_NOT_NULL); } diff --git a/src/duckdb/src/planner/filter/optional_filter.cpp b/src/duckdb/src/planner/filter/optional_filter.cpp index ac4e3da60..73c6119a0 100644 --- a/src/duckdb/src/planner/filter/optional_filter.cpp +++ b/src/duckdb/src/planner/filter/optional_filter.cpp @@ -1,35 +1,17 @@ #include "duckdb/planner/table_filter.hpp" #include "duckdb/planner/filter/optional_filter.hpp" #include "duckdb/planner/expression.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { -OptionalFilter::OptionalFilter(unique_ptr filter) - : TableFilter(TableFilterType::OPTIONAL_FILTER), child_filter(std::move(filter)) { +LegacyOptionalFilter::LegacyOptionalFilter(unique_ptr filter) + : TableFilter(TableFilterType::LEGACY_OPTIONAL_FILTER), child_filter(std::move(filter)) { } -FilterPropagateResult OptionalFilter::CheckStatistics(BaseStatistics &stats) const { - return child_filter->CheckStatistics(stats); -} - -string OptionalFilter::ToString(const string &column_name) const { - return string("optional: ") + child_filter->ToString(column_name); -} - -unique_ptr OptionalFilter::ToExpression(const Expression &column) const { - return child_filter->ToExpression(column); -} - -idx_t OptionalFilter::FilterSelection(SelectionVector &sel, Vector &vector, 
UnifiedVectorFormat &vdata, - TableFilterState &filter_state, const idx_t scan_count, - idx_t &approved_tuple_count) const { - return scan_count; -} - -unique_ptr OptionalFilter::Copy() const { - auto copy = make_uniq(); - copy->child_filter = child_filter->Copy(); - return duckdb::unique_ptr_cast(std::move(copy)); +unique_ptr LegacyOptionalFilter::ToExpression(const Expression &column) const { + auto child_expr = child_filter ? child_filter->ToExpression(column) : nullptr; + return CreateOptionalFilterExpression(std::move(child_expr), column.GetReturnType()); } } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp b/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp index 881880b4c..65860e65b 100644 --- a/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp +++ b/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp @@ -1,181 +1,36 @@ #include "duckdb/planner/filter/perfect_hash_join_filter.hpp" -#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/planner/table_filter_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { -PerfectHashJoinFilter::PerfectHashJoinFilter(optional_ptr perfect_join_executor_p, - const string &key_column_name_p, const LogicalType &key_type_p) +LegacyPerfectHashJoinFilter::LegacyPerfectHashJoinFilter( + optional_ptr perfect_join_executor_p, const string &key_column_name_p, + const LogicalType &key_type_p) : TableFilter(TYPE), perfect_join_executor(perfect_join_executor_p), key_column_name(key_column_name_p), key_type(key_type_p) { } -template -static FilterPropagateResult TemplatedCheckStatistics(const PerfectHashJoinExecutor &perfect_join_executor, T min, - T max, const 
LogicalType &type) { - if (min > max) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; // Invalid stats - } - - T range_typed; - idx_t range; - if (!TrySubtractOperator::Operation(max, min, range_typed) || !TryCast::Operation(range_typed, range) || - range >= DEFAULT_STANDARD_VECTOR_SIZE) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; // Overflow or too wide of a range - } - - Vector range_vec(type, DEFAULT_STANDARD_VECTOR_SIZE); - auto range_data = FlatVector::GetDataMutable(range_vec); - T val = min; - for (; val < max; val += 1) { - *range_data++ = val; - } - *range_data = val; - - const auto total_count = NumericCast(range_typed) + 1; - idx_t approved_tuple_count = 0; - SelectionVector probe_sel(total_count); - perfect_join_executor.FillSelectionVectorSwitchProbe(range_vec, total_count, probe_sel, approved_tuple_count, - nullptr); - - if (approved_tuple_count == 0) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if (approved_tuple_count == total_count) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; -} - -FilterPropagateResult PerfectHashJoinFilter::CheckStatistics(BaseStatistics &stats) const { - if (!perfect_join_executor || !TypeIsInteger(key_type.InternalType()) || !NumericStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - auto min_val = NumericStats::Min(stats); - auto max_val = NumericStats::Max(stats); - - // The filter was built with key_type (condition_type), but stats are in storage_type - if (stats.GetType() != key_type && (!min_val.DefaultTryCastAs(key_type) || !max_val.DefaultTryCastAs(key_type))) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - switch (key_type.InternalType()) { - case PhysicalType::UINT8: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::UINT16: - return 
TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::UINT32: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::UINT64: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::UINT128: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::INT8: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::INT16: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::INT32: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::INT64: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - case PhysicalType::INT128: - return TemplatedCheckStatistics(*perfect_join_executor, min_val.GetValueUnsafe(), - max_val.GetValueUnsafe(), key_type); - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } -} - -string PerfectHashJoinFilter::ToString(const string &column_name) const { - return column_name + " IN PHJ(" + key_column_name + ")"; -} - -idx_t PerfectHashJoinFilter::Filter(Vector &keys, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const { - if (!perfect_join_executor) { - return approved_tuple_count; - } - - state.PrepareSlicedKeys(keys, sel, approved_tuple_count); - - // Perform the probe - const idx_t approved_before = approved_tuple_count; - approved_tuple_count = 0; - perfect_join_executor->FillSelectionVectorSwitchProbe(state.keys_sliced_v, 
approved_before, state.probe_sel, - approved_tuple_count, nullptr); - - if (approved_tuple_count == approved_before) { - return approved_tuple_count; // Nothing was filtered - } - - if (sel.IsSet()) { - for (idx_t idx = 0; idx < approved_tuple_count; idx++) { - const idx_t sliced_sel_idx = state.probe_sel.get_index_unsafe(idx); - const idx_t original_sel_idx = sel.get_index_unsafe(sliced_sel_idx); - sel.set_index(idx, original_sel_idx); - } - } else { - sel.Initialize(state.probe_sel); - } - - return approved_tuple_count; -} - -bool PerfectHashJoinFilter::FilterValue(const Value &value) const { - auto cast_value = value; - if (!cast_value.DefaultTryCastAs(GetKeyType())) { - return true; - } - - Vector keys(cast_value, count_t(1)); - SelectionVector sel; - const idx_t approved_before = 1; - idx_t approved_tuple_count = 0; - perfect_join_executor->FillSelectionVectorSwitchProbe(keys, approved_before, sel, approved_tuple_count, nullptr); - - return approved_tuple_count == 1; -} - -bool PerfectHashJoinFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } - auto &other = other_p.Cast(); - return perfect_join_executor.get() == other.perfect_join_executor.get() && key_column_name == other.key_column_name; -} -unique_ptr PerfectHashJoinFilter::Copy() const { - return make_uniq(perfect_join_executor, key_column_name, key_type); -} - -unique_ptr PerfectHashJoinFilter::ToExpression(const Expression &column) const { - auto bound_constant = make_uniq(Value(true)); - return std::move(bound_constant); +unique_ptr LegacyPerfectHashJoinFilter::ToExpression(const Expression &column) const { + auto function = PerfectHashJoinScalarFun::GetFunction(column.GetReturnType()); + auto bind_data = make_uniq(perfect_join_executor, key_column_name, 0.0f, idx_t(0)); + vector> arguments; + arguments.push_back(column.Copy()); + return make_uniq(BoundScalarFunction(function), std::move(arguments), + std::move(bind_data)); } -void 
PerfectHashJoinFilter::Serialize(Serializer &serializer) const { +void LegacyPerfectHashJoinFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WriteProperty(200, "key_column_name", key_column_name); serializer.WriteProperty(201, "key_type", key_type); } -unique_ptr PerfectHashJoinFilter::Deserialize(Deserializer &deserializer) { +unique_ptr LegacyPerfectHashJoinFilter::Deserialize(Deserializer &deserializer) { auto key_column_name = deserializer.ReadProperty(200, "key_column_name"); auto key_type = deserializer.ReadProperty(201, "key_type"); - return make_uniq(nullptr, key_column_name, key_type); + return make_uniq(nullptr, key_column_name, key_type); } } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/prefix_range_filter.cpp b/src/duckdb/src/planner/filter/prefix_range_filter.cpp index 80e4e43ae..966f2b596 100644 --- a/src/duckdb/src/planner/filter/prefix_range_filter.cpp +++ b/src/duckdb/src/planner/filter/prefix_range_filter.cpp @@ -1,584 +1,38 @@ #include "duckdb/planner/filter/prefix_range_filter.hpp" -#include "duckdb/common/allocator.hpp" -#include "duckdb/common/array.hpp" -#include "duckdb/common/assert.hpp" -#include "duckdb/common/bit_utils.hpp" -#include "duckdb/common/enums/filter_propagate_result.hpp" -#include "duckdb/common/enums/vector_type.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/operator/subtract.hpp" #include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/typedefs.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/storage/statistics/string_stats.hpp" -#include "duckdb/storage/buffer_manager.hpp" -#include 
"duckdb/planner/table_filter_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" namespace duckdb { -namespace { - -AllocatedData AllocateBitmap(ClientContext &context, const idx_t word_count, uint64_t *&bitmap_begin) { - const idx_t size = word_count * sizeof(uint64_t); - BufferManager &buffer_manager = BufferManager::GetBufferManager(context); - auto buffer = buffer_manager.GetBufferAllocator().Allocate(64ULL + size); - bitmap_begin = reinterpret_cast((64ULL + reinterpret_cast(buffer.get())) & ~63ULL); - std::fill_n(bitmap_begin, word_count, 0); - return buffer; -} - -struct PrefixRangeBitmapBuildState : public PrefixRangeFilter::BuildState { - explicit PrefixRangeBitmapBuildState(AllocatedData data_p, uint64_t *bitmap_p) - : data(std::move(data_p)), bitmap(bitmap_p) { - } - - AllocatedData data; - uint64_t *bitmap; -}; - -template -class PrefixRangeBitmap { -public: - void Initialize(ClientContext &context, U min_p, U span_p) { - min = min_p; - span = span_p; - shift = 0; - - if (span >= CAP_BITS) { - const auto q = UnsafeNumericCast(span >> MAX_PREFIX_LENGTH); - shift = (q <= 1) ? 0 : (64 - CountZeros::Leading(q - 1)); - } - - const idx_t buckets = UnsafeNumericCast((span >> shift) + 1); - word_count = buckets == 0 ? 1 : (buckets + 63) >> WORD_SHIFT; - - buf_ = AllocateBitmap(context, word_count, bitmap); - - // Only mark initialized as true when local bitmaps are merged. 
- initialized = false; - } - - unique_ptr InitializeBuildState(ClientContext &context) const { - D_ASSERT(bitmap); - uint64_t *state_bitmap; - auto state_data = AllocateBitmap(context, word_count, state_bitmap); - return make_uniq(std::move(state_data), state_bitmap); - } - - template - void InsertKeys(Vector &keys, idx_t count, uint64_t *state_bitmap) const { - for (const auto &entry : keys.template ValidValues()) { - const U y = CONVERTER::Convert(entry.GetValue()) - min; - // All keys are in-range by construction, so the range check can be omitted here. - const U idx = y >> shift; - state_bitmap[idx >> WORD_SHIFT] |= 1ULL << (idx & WORD_MASK); - } - } - - void MergeBuildState(PrefixRangeBitmapBuildState &state) { - for (idx_t word_idx = 0; word_idx < word_count; word_idx++) { - bitmap[word_idx] |= state.bitmap[word_idx]; - } - initialized = true; - } - - template - inline bool LookupOne(const Value &value) const { - if (value.IsNull()) { - return false; - } - - const U comparable = CONVERTER::Convert(value.GetValueUnsafe()); - const U y = comparable - min; - const U bit_idx = y >> shift; - const uint8_t in_range = y <= span; - const uint32_t word_idx = (bit_idx >> WORD_SHIFT) & (0U - in_range); - const uint8_t bit = (bitmap[word_idx] >> (bit_idx & WORD_MASK)) & 1ULL; - return bit & in_range; - } - - template - idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const { - idx_t found_count = 0; - for (const auto &entry : keys.template ValidValues()) { - const U comparable = CONVERTER::Convert(entry.GetValue()); - const U y = comparable - min; - const U bit_idx = y >> shift; - const uint8_t in_range = y <= span; - const uint32_t word_idx = (bit_idx >> WORD_SHIFT) & (0U - in_range); - const uint8_t bit = (bitmap[word_idx] >> (bit_idx & WORD_MASK)) & 1ULL; - - result_sel.set_index(found_count, entry.GetIndex()); - found_count += bit & in_range; - } - return found_count; - } - - FilterPropagateResult LookupRange(U lower_bound, U upper_bound) 
const { - const U lb_y = lower_bound - min; - const U lb_bit_idx = lb_y >> shift; - const auto lb_word_idx = lb_bit_idx >> WORD_SHIFT; - - const U ub_y = upper_bound - min; - const U ub_bit_idx = ub_y >> shift; - const auto ub_word_idx = ub_bit_idx >> WORD_SHIFT; - - const idx_t lb_bit_off = UnsafeNumericCast(lb_bit_idx & UnsafeNumericCast(WORD_MASK)); - const idx_t ub_bit_off = UnsafeNumericCast(ub_bit_idx & UnsafeNumericCast(WORD_MASK)); - - // TODO: Count the amount of 1's in the range, compare to a threshold, and make a decision if we want to use the - // per-row filter for this row group. - if (lb_word_idx == ub_word_idx) { - const auto range_mask = ((~0ULL << lb_bit_off) & (~0ULL >> (WORD_MASK - ub_bit_off))); - if (bitmap[lb_word_idx] & range_mask) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - - const auto lb_word_mask = (~0ULL << lb_bit_off); - if (bitmap[lb_word_idx] & lb_word_mask) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - for (idx_t i = UnsafeNumericCast(lb_word_idx) + 1; i < UnsafeNumericCast(ub_word_idx); i++) { - if (bitmap[i]) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - } - - const auto ub_word_mask = ~0ULL >> (WORD_MASK - ub_bit_off); - if (bitmap[ub_word_idx] & ub_word_mask) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - - bool IsInitialized() const { - return initialized; - } - - U Min() const { - return min; - } - - U Span() const { - return span; - } - -private: - static constexpr idx_t MAX_PREFIX_LENGTH = 20; - static constexpr idx_t CAP_BITS = 1ULL << MAX_PREFIX_LENGTH; - static constexpr idx_t WORD_SHIFT = 6; - static constexpr idx_t WORD_MASK = 63; - - bool initialized = false; - U min; - U span; - idx_t shift; - idx_t word_count; - AllocatedData buf_; - uint64_t *bitmap; -}; - -template -struct NumericConverter { - using comparable_type = typename 
MakeUnsigned::type; - - static inline comparable_type Convert(T value) { - // Overflow is explicitly allowed for unsigned to signed cast - return static_cast(value); - } -}; - -struct StringPrefixConverter { - static inline uint32_t Convert(const string_t &value) { - return value.GetPrefixIntegerComparable(); - } -}; - -uint32_t StringMinComparable(const Value &value) { - return StringPrefixConverter::Convert(value.GetValueUnsafe()); -} - -uint32_t StringMaxComparable(const Value &value) { - const auto max_string = value.GetValueUnsafe(); - if (max_string.GetSize() >= string_t::PREFIX_BYTES) { - return max_string.GetPrefixIntegerComparable(); - } - - // Pad string prefix with 0xFF to keep correctness if max is truncated at \0 char, e.g., ab\0c -> ab - array padded_prefix; - padded_prefix.fill(char(0xFF)); - for (idx_t i = 0; i < max_string.GetSize(); i++) { - padded_prefix[i] = max_string.GetData()[i]; - } - return string_t(padded_prefix.data(), string_t::PREFIX_BYTES).GetPrefixIntegerComparable(); -} - -template -class NumericPrefixRangeFilter : public PrefixRangeFilter { -private: - using Comparable = typename MakeUnsigned::type; - -public: - void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { - D_ASSERT(min_val <= max_val); - D_ASSERT(number_of_rows > 0); - const auto min = NumericConverter::Convert(min_val.GetValueUnsafe()); - const auto max = NumericConverter::Convert(max_val.GetValueUnsafe()); - bitmap.Initialize(context, min, max - min); - } - - unique_ptr InitializeBuildState(ClientContext &context) const override { - return bitmap.InitializeBuildState(context); - } - - void InsertKeys(Vector &keys, idx_t count, BuildState &state) const override { - auto &bitmap_state = state.Cast(); - bitmap.template InsertKeys>(keys, count, bitmap_state.bitmap); - } - - void MergeBuildState(BuildState &state) override { - bitmap.MergeBuildState(state.Cast()); - } - - idx_t LookupKeys(Vector &keys, SelectionVector 
&result_sel, idx_t count) const override { - if (keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return bitmap.template LookupOne>(keys.GetValue(0)) ? count : 0; - } - return bitmap.template LookupKeys>(keys, result_sel, count); - } - - FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const override { - const auto lb = lower_bound.GetValueUnsafe(); - const auto ub = upper_bound.GetValueUnsafe(); - - const auto bitmap_min = static_cast(bitmap.Min()); - const auto bitmap_max = static_cast(bitmap.Min() + bitmap.Span()); - if (ub < bitmap_min || lb > bitmap_max) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - - const auto adjusted_lb = NumericConverter::Convert(MaxValue(lb, bitmap_min)); - const auto adjusted_ub = NumericConverter::Convert(MinValue(ub, bitmap_max)); - return bitmap.LookupRange(adjusted_lb, adjusted_ub); - } - - bool IsInitialized() const override { - return bitmap.IsInitialized(); - } - -private: - PrefixRangeBitmap bitmap; -}; - -class StringPrefixRangeFilter : public PrefixRangeFilter { -public: - void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { - D_ASSERT(min_val <= max_val); - D_ASSERT(number_of_rows > 0); - const auto min = StringPrefixConverter::Convert(min_val.GetValueUnsafe()); - const auto max = StringPrefixConverter::Convert(max_val.GetValueUnsafe()); - D_ASSERT(min <= max); - bitmap.Initialize(context, min, max - min); - } - - unique_ptr InitializeBuildState(ClientContext &context) const override { - return bitmap.InitializeBuildState(context); - } - - void InsertKeys(Vector &keys, idx_t count, BuildState &state) const override { - auto &bitmap_state = state.Cast(); - bitmap.template InsertKeys(keys, count, bitmap_state.bitmap); - } - - void MergeBuildState(BuildState &state) override { - bitmap.MergeBuildState(state.Cast()); - } - - idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const override { - if 
(keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return bitmap.template LookupOne(keys.GetValue(0)) ? count : 0; - } - return bitmap.template LookupKeys(keys, result_sel, count); - } - - FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const override { - auto lower_bound_comparable = StringMinComparable(lower_bound); - auto upper_bound_comparable = StringMaxComparable(upper_bound); - if (lower_bound_comparable > upper_bound_comparable) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - const auto bitmap_min = bitmap.Min(); - const auto bitmap_max = bitmap.Min() + bitmap.Span(); - if (upper_bound_comparable < bitmap_min || lower_bound_comparable > bitmap_max) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - - lower_bound_comparable = MaxValue(lower_bound_comparable, bitmap_min); - upper_bound_comparable = MinValue(upper_bound_comparable, bitmap_max); - return bitmap.LookupRange(lower_bound_comparable, upper_bound_comparable); - } - - bool IsInitialized() const override { - return bitmap.IsInitialized(); - } - -private: - PrefixRangeBitmap bitmap; -}; - -template -bool ComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { - T lb_value = lower_bound.GetValueUnsafe(); - T ub_value = upper_bound.GetValueUnsafe(); - T res; - if (TrySubtractOperator::Operation(ub_value, lb_value, res)) { - result = Uhugeint::Convert(res); - return true; - } else { - return false; - } -} - -bool ComputeStringPrefixSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { -#ifdef DUCKDB_DEBUG_NO_INLINE - return false; -#else - auto lb_value = lower_bound.GetValueUnsafe().GetPrefixIntegerComparable(); - auto ub_value = upper_bound.GetValueUnsafe().GetPrefixIntegerComparable(); - uint32_t res; - if (TrySubtractOperator::Operation(ub_value, lb_value, res)) { - result = Uhugeint::Convert(res); - return true; - } else { - return false; - } -#endif -} - -} // namespace - 
-unique_ptr PrefixRangeFilter::CreatePrefixRangeFilter(const LogicalType &key_type) { - switch (key_type.InternalType()) { - case PhysicalType::UINT8: - return make_uniq>(); - case PhysicalType::UINT16: - return make_uniq>(); - case PhysicalType::UINT32: - return make_uniq>(); - case PhysicalType::UINT64: - return make_uniq>(); - case PhysicalType::INT8: - return make_uniq>(); - case PhysicalType::INT16: - return make_uniq>(); - case PhysicalType::INT32: - return make_uniq>(); - case PhysicalType::INT64: - return make_uniq>(); - case PhysicalType::VARCHAR: -#ifdef DUCKDB_DEBUG_NO_INLINE - throw NotImplementedException("Prefix range filter is not implemented for type %s", key_type.ToString()); -#else - return make_uniq(); -#endif - case PhysicalType::INT128: - case PhysicalType::UINT128: - default: - throw NotImplementedException("Prefix range filter is not implemented for type %s", key_type.ToString()); - } -} - -bool PrefixRangeFilter::TryComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { - if (lower_bound.type().InternalType() != upper_bound.type().InternalType()) { - return false; - } - - switch (lower_bound.type().InternalType()) { - case PhysicalType::UINT8: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::UINT16: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::UINT32: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::UINT64: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::INT8: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::INT16: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::INT32: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::INT64: - return ComputeSpan(lower_bound, upper_bound, result); - case PhysicalType::VARCHAR: - return ComputeStringPrefixSpan(lower_bound, upper_bound, result); - case PhysicalType::INT128: - case 
PhysicalType::UINT128: - default: - return false; - } -} - -bool PrefixRangeTableFilter::SupportedType(const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - return true; - case PhysicalType::VARCHAR: -#ifdef DUCKDB_DEBUG_NO_INLINE - return false; -#else - return true; -#endif - case PhysicalType::INT128: - case PhysicalType::UINT128: - default: - return false; - } -} - -PrefixRangeTableFilter::PrefixRangeTableFilter(optional_ptr filter_p, - const string &key_column_name_p, const LogicalType &key_type_p) +LegacyPrefixRangeTableFilter::LegacyPrefixRangeTableFilter(optional_ptr filter_p, + const string &key_column_name_p, + const LogicalType &key_type_p) : TableFilter(TYPE), filter(filter_p), key_column_name(key_column_name_p), key_type(key_type_p) { } -string PrefixRangeTableFilter::ToString(const string &column_name) const { - return column_name + " IN PRF(" + key_column_name + ")"; -} - -idx_t PrefixRangeTableFilter::Filter(Vector &keys, SelectionVector &sel, idx_t &approved_tuple_count, - JoinFilterTableFilterState &state) const { - if (!filter || !filter->IsInitialized()) { - return approved_tuple_count; - } - - state.PrepareSlicedKeys(keys, sel, approved_tuple_count); - - const auto approved_before = approved_tuple_count; - SelectionVector result_sel(approved_before); - approved_tuple_count = filter->LookupKeys(state.keys_sliced_v, result_sel, approved_before); - - if (approved_tuple_count == approved_before) { - // Nothing was filtered - return approved_tuple_count; - } - - if (sel.IsSet()) { - for (idx_t idx = 0; idx < approved_tuple_count; idx++) { - const idx_t sliced_sel_idx = result_sel.get_index_unsafe(idx); - const idx_t original_sel_idx = sel.get_index_unsafe(sliced_sel_idx); - sel.set_index(idx, original_sel_idx); - } - } 
else { - sel.Initialize(result_sel); - } - - return approved_tuple_count; -} - -bool PrefixRangeTableFilter::FilterValue(const Value &value) const { - if (!filter || !filter->IsInitialized()) { - return true; - } - - auto cast_value = value; - if (!cast_value.DefaultTryCastAs(GetKeyType())) { - return true; - } - - Vector keys(cast_value, count_t(1)); - SelectionVector sel; - idx_t approved_tuple_count = 1; - return filter->LookupKeys(keys, sel, approved_tuple_count) == 1; -} - -FilterPropagateResult PrefixRangeTableFilter::CheckStatistics(BaseStatistics &stats) const { - if (!filter || !filter->IsInitialized()) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - Value min; - Value max; - switch (stats.GetStatsType()) { - case StatisticsType::NUMERIC_STATS: - if (!NumericStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - min = NumericStats::Min(stats); - max = NumericStats::Max(stats); - break; - case StatisticsType::STRING_STATS: - if (stats.GetType().id() != LogicalTypeId::VARCHAR || key_type.id() != LogicalTypeId::VARCHAR) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - if (!StringStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - min = Value::BLOB_RAW(StringStats::Min(stats)); - max = Value::BLOB_RAW(StringStats::Max(stats)); - break; - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - if (min > max) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - // The filter was built with key_type (condition_type), but stats are in storage_type. 
- if (stats.GetType() != key_type && (!min.DefaultTryCastAs(key_type) || !max.DefaultTryCastAs(key_type))) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - return filter->LookupRange(min, max); -} - -bool PrefixRangeTableFilter::Equals(const TableFilter &other_p) const { - if (!TableFilter::Equals(other_p)) { - return false; - } - const auto &other = other_p.Cast(); - return key_column_name == other.key_column_name && key_type == other.key_type; -} - -unique_ptr PrefixRangeTableFilter::Copy() const { - return make_uniq(this->filter, this->key_column_name, this->key_type); -} -unique_ptr PrefixRangeTableFilter::ToExpression(const Expression &column) const { - auto bound_constant = make_uniq(Value(true)); - return std::move(bound_constant); +unique_ptr LegacyPrefixRangeTableFilter::ToExpression(const Expression &column) const { + auto function = PrefixRangeScalarFun::GetFunction(column.GetReturnType()); + auto bind_data = make_uniq(filter, key_column_name, key_type, 0.0f, idx_t(0)); + vector> arguments; + arguments.push_back(column.Copy()); + return make_uniq(BoundScalarFunction(function), std::move(arguments), + std::move(bind_data)); } -void PrefixRangeTableFilter::Serialize(Serializer &serializer) const { +void LegacyPrefixRangeTableFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WriteProperty(200, "key_column_name", key_column_name); serializer.WriteProperty(201, "key_type", key_type); } -unique_ptr PrefixRangeTableFilter::Deserialize(Deserializer &deserializer) { +unique_ptr LegacyPrefixRangeTableFilter::Deserialize(Deserializer &deserializer) { auto key_column_name = deserializer.ReadProperty(200, "key_column_name"); auto key_type = deserializer.ReadProperty(201, "key_type"); - auto result = make_uniq(nullptr, key_column_name, key_type); + auto result = make_uniq(nullptr, key_column_name, key_type); return std::move(result); } diff --git 
a/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp b/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp index ca4d314e2..08d36adf5 100644 --- a/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp +++ b/src/duckdb/src/planner/filter/selectivity_optional_filter.cpp @@ -7,153 +7,49 @@ //===----------------------------------------------------------------------===// #include "duckdb/planner/filter/selectivity_optional_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter_state.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/function/compression/compression.hpp" -#include "duckdb/storage/table/column_segment.hpp" namespace duckdb { -SelectivityOptionalFilterState::SelectivityStats::SelectivityStats(const idx_t n_vectors_to_check, - const float selectivity_threshold) - : n_vectors_to_check(n_vectors_to_check), selectivity_threshold(selectivity_threshold), tuples_accepted(0), - tuples_processed(0), vectors_processed(0), status(FilterStatus::ACTIVE), pause_multiplier(0) { -} - -void SelectivityOptionalFilterState::SelectivityStats::Update(idx_t accepted, idx_t processed) { - vectors_processed++; - tuples_accepted += accepted; - tuples_processed += processed; - - static constexpr idx_t VECTOR_PAUSE = 10; - D_ASSERT(n_vectors_to_check < VECTOR_PAUSE); - if (vectors_processed == MaxValue(pause_multiplier, 1) * VECTOR_PAUSE) { - vectors_processed = 0; - tuples_accepted = 0; - tuples_processed = 0; - status = FilterStatus::ACTIVE; - } else if (vectors_processed >= n_vectors_to_check) { - // pause the filter if we processed enough vectors and the selectivity is too high - if (GetSelectivity() >= selectivity_threshold) { - status = FilterStatus::PAUSED_DUE_TO_HIGH_SELECTIVITY; - pause_multiplier++; // increase the pause duration - } else { - pause_multiplier = 0; // selective enough, reset the pause duration - } - 
} -} - -bool SelectivityOptionalFilterState::SelectivityStats::IsActive() const { - return status == FilterStatus::ACTIVE; -} -double SelectivityOptionalFilterState::SelectivityStats::GetSelectivity() const { - if (tuples_processed == 0) { - return 0.0; - } - return static_cast(tuples_accepted) / static_cast(tuples_processed); -} - -static float GetSelectivityThresholdForType(SelectivityOptionalFilterType type) { - static constexpr float MIN_MAX_THRESHOLD = 0.9f; - static constexpr float BF_THRESHOLD = 0.5f; - static constexpr float PHJ_THRESHOLD = 0.3f; - static constexpr float PRF_THRESHOLD = 0.5f; - - switch (type) { - case SelectivityOptionalFilterType::MIN_MAX: - return MIN_MAX_THRESHOLD; - case SelectivityOptionalFilterType::BF: - return BF_THRESHOLD; - case SelectivityOptionalFilterType::PHJ: - return PHJ_THRESHOLD; - case SelectivityOptionalFilterType::PRF: - return PRF_THRESHOLD; - default: - throw NotImplementedException("GetSelectivityThresholdForType"); - } -} - -static idx_t GetCheckNForType(SelectivityOptionalFilterType type) { - static constexpr idx_t MIN_MAX_CHECK_N = 6; - static constexpr idx_t BF_CHECK_N = 6; - static constexpr idx_t PHJ_CHECK_N = 6; - static constexpr idx_t PRF_CHECK_N = 6; - - switch (type) { - case SelectivityOptionalFilterType::MIN_MAX: - return MIN_MAX_CHECK_N; - case SelectivityOptionalFilterType::BF: - return BF_CHECK_N; - case SelectivityOptionalFilterType::PHJ: - return PHJ_CHECK_N; - case SelectivityOptionalFilterType::PRF: - return PRF_CHECK_N; - default: - throw NotImplementedException("GetCheckNForType"); - } -} - -SelectivityOptionalFilter::SelectivityOptionalFilter(unique_ptr filter, float selectivity_threshold, - idx_t n_vectors_to_check) - : OptionalFilter(std::move(filter)), selectivity_threshold(selectivity_threshold), +LegacySelectivityOptionalFilter::LegacySelectivityOptionalFilter(unique_ptr filter, + float selectivity_threshold, idx_t n_vectors_to_check) + : LegacyOptionalFilter(std::move(filter)), 
selectivity_threshold(selectivity_threshold), n_vectors_to_check(n_vectors_to_check) { } -SelectivityOptionalFilter::SelectivityOptionalFilter(unique_ptr filter, SelectivityOptionalFilterType type) - : SelectivityOptionalFilter(std::move(filter), GetSelectivityThresholdForType(type), GetCheckNForType(type)) { -} - -FilterPropagateResult SelectivityOptionalFilter::CheckStatistics(BaseStatistics &stats) const { - return child_filter->CheckStatistics(stats); +LegacySelectivityOptionalFilter::LegacySelectivityOptionalFilter(unique_ptr filter, + SelectivityOptionalFilterType type) { + float threshold; + idx_t vectors_to_check; + GetThresholdAndVectorsToCheck(type, threshold, vectors_to_check); + child_filter = std::move(filter); + selectivity_threshold = threshold; + n_vectors_to_check = vectors_to_check; } -void SelectivityOptionalFilter::Serialize(Serializer &serializer) const { - OptionalFilter::Serialize(serializer); +void LegacySelectivityOptionalFilter::Serialize(Serializer &serializer) const { + LegacyOptionalFilter::Serialize(serializer); serializer.WritePropertyWithDefault(201, "selectivity_threshold", selectivity_threshold); serializer.WritePropertyWithDefault(202, "n_vectors_to_check", n_vectors_to_check); } -unique_ptr SelectivityOptionalFilter::Deserialize(Deserializer &deserializer) { - auto result = unique_ptr(new SelectivityOptionalFilter(nullptr, 0.5f, 100)); +unique_ptr LegacySelectivityOptionalFilter::Deserialize(Deserializer &deserializer) { + auto result = unique_ptr(new LegacySelectivityOptionalFilter(nullptr, 0.5f, 100)); deserializer.ReadPropertyWithDefault>(200, "child_filter", result->child_filter); deserializer.ReadPropertyWithDefault(201, "selectivity_threshold", result->selectivity_threshold); deserializer.ReadPropertyWithDefault(202, "n_vectors_to_check", result->n_vectors_to_check); return std::move(result); } -void SelectivityOptionalFilter::FiltersNullValues(const LogicalType &type, bool &filters_nulls, - bool &filters_valid_values, 
TableFilterState &filter_state) const { - const auto &state = filter_state.Cast(); - return ConstantFun::FiltersNullValues(type, *this->child_filter, filters_nulls, filters_valid_values, - *state.child_state); -} -unique_ptr SelectivityOptionalFilter::InitializeState(ClientContext &context) const { - D_ASSERT(child_filter); - auto child_filter_state = TableFilterState::Initialize(context, *child_filter); - return make_uniq(std::move(child_filter_state), this->n_vectors_to_check, - this->selectivity_threshold); -} - -idx_t SelectivityOptionalFilter::FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, - TableFilterState &filter_state, const idx_t scan_count, - idx_t &approved_tuple_count) const { - auto &state = filter_state.Cast(); - if (state.stats.IsActive()) { - const idx_t approved_before = approved_tuple_count; - const idx_t accepted_count = ColumnSegment::FilterSelection( - sel, vector, vdata, *child_filter, *state.child_state, scan_count, approved_tuple_count); - state.stats.Update(accepted_count, approved_before); - return accepted_count; - } - state.stats.Update(0, 0); - return scan_count; -} -unique_ptr SelectivityOptionalFilter::Copy() const { - auto copy = make_uniq(child_filter->Copy(), selectivity_threshold, n_vectors_to_check); - return unique_ptr_cast(std::move(copy)); +unique_ptr LegacySelectivityOptionalFilter::ToExpression(const Expression &column) const { + auto child_expr = child_filter ? 
child_filter->ToExpression(column) : nullptr; + return CreateSelectivityOptionalFilterExpression(std::move(child_expr), column.GetReturnType(), + selectivity_threshold, n_vectors_to_check); } } // namespace duckdb diff --git a/src/duckdb/src/planner/filter/struct_filter.cpp b/src/duckdb/src/planner/filter/struct_filter.cpp index a403e1e13..83b432bb2 100644 --- a/src/duckdb/src/planner/filter/struct_filter.cpp +++ b/src/duckdb/src/planner/filter/struct_filter.cpp @@ -7,32 +7,13 @@ namespace duckdb { -StructFilter::StructFilter(idx_t child_idx_p, string child_name_p, unique_ptr child_filter_p) - : TableFilter(TableFilterType::STRUCT_EXTRACT), child_idx(child_idx_p), child_name(std::move(child_name_p)), +LegacyStructFilter::LegacyStructFilter(idx_t child_idx_p, Identifier child_name_p, + unique_ptr child_filter_p) + : TableFilter(TableFilterType::LEGACY_STRUCT_EXTRACT), child_idx(child_idx_p), child_name(std::move(child_name_p)), child_filter(std::move(child_filter_p)) { } -FilterPropagateResult StructFilter::CheckStatistics(BaseStatistics &stats) const { - throw InternalException("StructFilter::CheckStatistics should not be called: StructFilters should be converted " - "to ExpressionFilters before statistics checking"); -} - -string StructFilter::ToString(const string &column_name) const { - throw InternalException("StructFilter::ToString should not be called: StructFilters should be converted to " - "ExpressionFilters before rendering"); -} - -bool StructFilter::Equals(const TableFilter &other_p) const { - throw InternalException("StructFilter::Equals should not be called: StructFilters should be converted to " - "ExpressionFilters before equality checking"); -} - -unique_ptr StructFilter::Copy() const { - throw InternalException("StructFilter::Copy should not be called: StructFilters should be converted to " - "ExpressionFilters before copying"); -} - -unique_ptr StructFilter::ToExpression(const Expression &column) const { +unique_ptr 
LegacyStructFilter::ToExpression(const Expression &column) const { vector> arguments; arguments.push_back(column.Copy()); arguments.push_back(make_uniq(Value::BIGINT(NumericCast(child_idx + 1)))); diff --git a/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp b/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp new file mode 100644 index 000000000..d50a3f65a --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp @@ -0,0 +1,290 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_bloom_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +namespace duckdb { + +static constexpr idx_t MAX_NUM_SECTORS = (1ULL << 26); +static constexpr idx_t MIN_NUM_BITS_PER_KEY = 12; +static constexpr idx_t MIN_NUM_BITS = 512; +static constexpr idx_t LOG_SECTOR_SIZE = 6; // a sector is 64 bits, log2(64) = 6 +static constexpr idx_t SHIFT_MASK = 0x3F3F3F3F3F3F3F3F; // 6 bits for 64 positions +static constexpr idx_t N_BITS = 4; // the number of bits to set per hash + +void BloomFilter::Initialize(ClientContext &context_p, idx_t number_of_rows) { + BufferManager &buffer_manager = BufferManager::GetBufferManager(context_p); + + const idx_t min_bits = MaxValue(MIN_NUM_BITS, number_of_rows * MIN_NUM_BITS_PER_KEY); + num_sectors = 
MinValue(NextPowerOfTwo(min_bits) >> LOG_SECTOR_SIZE, MAX_NUM_SECTORS); + bitmask = num_sectors - 1; + + buf_ = buffer_manager.GetBufferAllocator().Allocate(64 + num_sectors * sizeof(uint64_t)); + // make sure blocks is a 64-byte aligned pointer, i.e., cache-line aligned + bf = reinterpret_cast((64ULL + reinterpret_cast(buf_.get())) & ~63ULL); + std::fill_n(bf, num_sectors, 0); + + initialized = true; +} + +void BloomFilter::Merge(const BloomFilter &other) { + D_ASSERT(initialized); + D_ASSERT(other.initialized); + D_ASSERT(num_sectors == other.num_sectors); + D_ASSERT(bitmask == other.bitmask); + for (idx_t i = 0; i < num_sectors; i++) { + bf[i] |= other.bf[i]; + } +} + +void BloomFilter::Reset() { + buf_.Reset(); + num_sectors = 0; + bitmask = 0; + initialized = false; + bf = nullptr; +} + +inline uint64_t GetMask(const hash_t hash) { + const uint64_t shifts = hash & SHIFT_MASK; + const auto shifts_8 = reinterpret_cast(&shifts); + + uint64_t mask = 0; + + for (idx_t bit_idx = 8 - N_BITS; bit_idx < 8; bit_idx++) { + const uint8_t bit_pos = shifts_8[bit_idx]; + mask |= (1ULL << bit_pos); + } + + return mask; +} + +void BloomFilter::InsertHashes(const Vector &hashes_v) const { + for (auto hash : hashes_v.Values()) { + InsertOne(hash.GetValue()); + } +} + +idx_t BloomFilter::LookupHashes(const Vector &hashes_v, SelectionVector &result_sel, const idx_t count) const { + D_ASSERT(hashes_v.GetVectorType() == VectorType::FLAT_VECTOR); + D_ASSERT(hashes_v.GetType() == LogicalType::HASH); + + const auto hashes = FlatVector::GetData(hashes_v); + idx_t found_count = 0; + for (idx_t i = 0; i < count; i++) { + result_sel.set_index(found_count, i); + found_count += LookupOne(hashes[i]); + } + return found_count; +} + +inline void BloomFilter::InsertOne(const hash_t hash) const { + D_ASSERT(initialized); + const uint64_t bf_offset = hash & bitmask; + const uint64_t mask = GetMask(hash); + atomic &slot = *reinterpret_cast *>(&bf[bf_offset]); + + slot.fetch_or(mask, 
std::memory_order_relaxed); +} + +inline bool BloomFilter::LookupOne(const uint64_t hash) const { + D_ASSERT(initialized); + const uint64_t bf_offset = hash & bitmask; + const uint64_t mask = GetMask(hash); + atomic &slot = *reinterpret_cast *>(&bf[bf_offset]); + auto bf_entry = slot.load(std::memory_order_relaxed); + + return (bf_entry & mask) == mask; +} + +BloomFilterFunctionData::BloomFilterFunctionData(optional_ptr filter_p, bool filters_null_values_p, + const string &key_column_name_p, const LogicalType &key_type_p, + float selectivity_threshold_p, idx_t n_vectors_to_check_p) + : filter(filter_p), filters_null_values(filters_null_values_p), key_column_name(key_column_name_p), + key_type(key_type_p), selectivity_threshold(selectivity_threshold_p), n_vectors_to_check(n_vectors_to_check_p) { +} + +unique_ptr BloomFilterFunctionData::Copy() const { + return make_uniq(filter, filters_null_values, key_column_name, key_type, + selectivity_threshold, n_vectors_to_check); +} + +bool BloomFilterFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return filter.get() == other.filter.get() && filters_null_values == other.filters_null_values && + key_column_name == other.key_column_name && key_type == other.key_type; +} + +static idx_t SelectBloomFilter(Vector &input, const BloomFilterFunctionData &func_data, SelectionVector &result_sel, + idx_t count) { + D_ASSERT(func_data.filter); + Vector hashes(LogicalType::HASH, count); + VectorOperations::Hash(input, hashes, count); + hashes.Flatten(); + + UnifiedVectorFormat input_data; + input.ToUnifiedFormat(input_data); + + SelectionVector bloom_sel(count); + const auto bloom_count = func_data.filter->LookupHashes(hashes, bloom_sel, count); + + idx_t result_count = 0; + idx_t bloom_idx = 0; + for (idx_t i = 0; i < count; i++) { + const auto matched = bloom_idx < bloom_count && bloom_sel.get_index_unsafe(bloom_idx) == i; + if (matched) { + bloom_idx++; + } + const auto input_idx = 
input_data.sel->get_index(i); + bool passed; + if (!input_data.validity.RowIsValid(input_idx)) { + passed = !func_data.filters_null_values; + } else { + passed = matched; + } + if (passed) { + result_sel.set_index(result_count++, i); + } + } + return result_count; +} + +static unique_ptr +BloomFilterInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { + auto &data = bind_data->Cast(); + if (!data.filter) { + return nullptr; + } + return InitSelectivityTrackingLocalState(data.n_vectors_to_check, data.selectivity_threshold); +} + +static idx_t BloomFilterSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + auto &func_expr = state.expr.Cast(); + auto &func_data = func_expr.BindInfo()->Cast(); + auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); + auto tracking_state = local_state_ptr ? &local_state_ptr->Cast() : nullptr; + + auto count = args.size(); + if (!func_data.filter) { + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + if (tracking_state && !tracking_state->IsActive()) { + tracking_state->Update(0, 0); + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + + SelectionVector temp_true(count); + auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? 
&temp_true : true_sel.get(); + auto approved_count = SelectBloomFilter(args.data[0], func_data, *result_true_sel, count); + approved_count = TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); + if (tracking_state) { + tracking_state->Update(approved_count, count); + } + return approved_count; +} + +template +static FilterPropagateResult TemplatedBloomFilterPrune(const BloomFilter &bf, const BaseStatistics &stats) { + if (!NumericStats::HasMinMax(stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto min = NumericStats::GetMin(stats); + const auto max = NumericStats::GetMax(stats); + if (min > max) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + T range_typed; + idx_t range; + if (!TrySubtractOperator::Operation(max, min, range_typed) || !TryCast::Operation(range_typed, range) || + range >= DEFAULT_STANDARD_VECTOR_SIZE) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + T val = min; + idx_t hits = 0; + for (idx_t i = 0; i <= range; i++) { + hits += bf.LookupOne(Hash(val)); + val += i < range; + } + + if (hits == 0) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (hits == range + 1) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +ScalarFunction BloomFilterScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, nullptr, TableFilterFunctions::Bind); + func.SetInitStateCallback(BloomFilterInitLocalState); + func.SetSelectCallback(BloomFilterSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(BloomFilterScalarFun::FilterPrune); + func.SetSerializeCallback(TableFilterFunctionSerialize); + func.SetDeserializeCallback(TableFilterFunctionDeserialize); + return func; +} + +string BloomFilterScalarFun::ToString(const string &column_name, const string &key_column_name) { + return column_name + " 
IN BF(" + key_column_name + ")"; +} + +FilterPropagateResult BloomFilterScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if (!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.filter || !data.filter->IsInitialized()) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + switch (data.key_type.InternalType()) { + case PhysicalType::UINT8: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::UINT16: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::UINT32: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::UINT64: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::UINT128: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::INT8: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::INT16: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::INT32: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::INT64: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + case PhysicalType::INT128: + return TemplatedBloomFilterPrune(*data.filter, input.stats); + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +ScalarFunction TableFilterBloomFilterFun::GetFunction() { + return BloomFilterScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_dynamic_function.cpp b/src/duckdb/src/planner/filter/table_filter_dynamic_function.cpp new file mode 100644 index 000000000..8441d22c4 --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_dynamic_function.cpp @@ -0,0 +1,164 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// 
duckdb/planner/filter/table_filter_dynamic_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +DynamicFilterData::DynamicFilterData(ExpressionType comparison_type_p, Value constant_p) + : comparison_type(comparison_type_p), constant(std::move(constant_p)) { +} + +unique_ptr DynamicFilterData::ToExpression(const Expression &column) const { + return BoundComparisonExpression::Create(comparison_type, column.Copy(), + make_uniq(constant)); +} + +bool DynamicFilterData::CompareValue(ExpressionType comparison_type, const Value &constant, const Value &value) { + switch (comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return ValueOperations::Equals(value, constant); + case ExpressionType::COMPARE_NOTEQUAL: + return ValueOperations::NotEquals(value, constant); + case ExpressionType::COMPARE_GREATERTHAN: + return ValueOperations::GreaterThan(value, constant); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return ValueOperations::GreaterThanEquals(value, constant); + case ExpressionType::COMPARE_LESSTHAN: + return ValueOperations::LessThan(value, constant); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return ValueOperations::LessThanEquals(value, constant); + default: + throw InternalException("Unknown comparison type for DynamicFilter: " + EnumUtil::ToString(comparison_type)); + } +} + +FilterPropagateResult DynamicFilterData::CheckStatistics(const BaseStatistics &stats, ExpressionType comparison_type, + const Value &constant) { + switch (constant.type().InternalType()) { + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::UINT128: + 
case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + case PhysicalType::INT128: + case PhysicalType::FLOAT: + case PhysicalType::DOUBLE: + return NumericStats::CheckZonemap(stats, comparison_type, array_ptr(constant)); + case PhysicalType::VARCHAR: + if (stats.GetStatsType() == StatisticsType::STRING_STATS) { + return StringStats::CheckZonemap(stats, comparison_type, array_ptr(constant)); + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +DynamicFilterFunctionData::DynamicFilterFunctionData(shared_ptr filter_data_p) + : filter_data(std::move(filter_data_p)) { +} + +unique_ptr DynamicFilterFunctionData::Copy() const { + return make_uniq(filter_data); +} + +bool DynamicFilterFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return filter_data.get() == other.filter_data.get(); +} + +static idx_t SelectDynamicFilter(Vector &input, ExpressionType comparison_type, const Value &constant, + SelectionVector &result_sel, idx_t count) { + UnifiedVectorFormat input_data; + input.ToUnifiedFormat(input_data); + + idx_t approved_count = 0; + for (idx_t i = 0; i < count; i++) { + auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + continue; + } + auto value = input.GetValue(idx); + if (!DynamicFilterData::CompareValue(comparison_type, constant, value)) { + continue; + } + result_sel.set_index(approved_count++, i); + } + return approved_count; +} + +static idx_t DynamicFilterSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + auto &func_expr = state.expr.Cast(); + auto &func_data = func_expr.BindInfo()->Cast(); + auto count = args.size(); + if (!func_data.filter_data || !func_data.filter_data->initialized.load()) { + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + + ExpressionType 
comparison_type; + Value constant; + { + lock_guard lock(func_data.filter_data->lock); + comparison_type = func_data.filter_data->comparison_type; + constant = func_data.filter_data->constant; + } + + SelectionVector temp_true(count); + auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? &temp_true : true_sel.get(); + auto approved_count = SelectDynamicFilter(args.data[0], comparison_type, constant, *result_true_sel, count); + return TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); +} + +ScalarFunction DynamicFilterScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, nullptr, TableFilterFunctions::Bind); + func.SetSelectCallback(DynamicFilterSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(DynamicFilterScalarFun::FilterPrune); + func.SetSerializeCallback(TableFilterFunctionSerialize); + func.SetDeserializeCallback(TableFilterFunctionDeserialize); + return func; +} + +FilterPropagateResult DynamicFilterScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if (!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.filter_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + lock_guard lock(data.filter_data->lock); + if (!data.filter_data->initialized) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return DynamicFilterData::CheckStatistics(input.stats, data.filter_data->comparison_type, + data.filter_data->constant); +} + +string DynamicFilterScalarFun::ToString(const string &column_name, bool has_filter_data) { + if (has_filter_data) { + return "Dynamic Filter (" + column_name + ")"; + } + return "Dynamic Filter"; +} + +ScalarFunction TableFilterDynamicFun::GetFunction() { + return DynamicFilterScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff 
--git a/src/duckdb/src/planner/filter/table_filter_functions.cpp b/src/duckdb/src/planner/filter/table_filter_functions.cpp new file mode 100644 index 000000000..436a8ba94 --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_functions.cpp @@ -0,0 +1,125 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_functions.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +// LCOV_EXCL_START +unique_ptr TableFilterFunctions::Bind(BindScalarFunctionInput &input) { + throw BinderException("Table filter functions are for internal use only!"); +} + +bool TableFilterFunctions::IsTableFilterFunction(const Identifier &name) { + static const char *const TABLE_FILTER_FUNCTIONS[] = { + BloomFilterScalarFun::NAME, DynamicFilterScalarFun::NAME, OptionalFilterScalarFun::NAME, + PerfectHashJoinScalarFun::NAME, PrefixRangeScalarFun::NAME, SelectivityOptionalFilterScalarFun::NAME}; + for (auto function_name : TABLE_FILTER_FUNCTIONS) { + if (name == function_name) { + return true; + } + } + return false; +} +// LCOV_EXCL_STOP + +void GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType type, float &selectivity_threshold, + idx_t &n_vectors_to_check) { + static constexpr float MIN_MAX_THRESHOLD = 0.9f; + static constexpr float BF_THRESHOLD = 0.5f; + static constexpr float PHJ_THRESHOLD = 0.3f; + static constexpr float PRF_THRESHOLD = 0.5f; + + static constexpr idx_t MIN_MAX_CHECK_N = 6; + static constexpr idx_t BF_CHECK_N = 6; + static constexpr idx_t PHJ_CHECK_N = 6; + static constexpr idx_t 
PRF_CHECK_N = 6; + + switch (type) { + case SelectivityOptionalFilterType::MIN_MAX: + selectivity_threshold = MIN_MAX_THRESHOLD; + n_vectors_to_check = MIN_MAX_CHECK_N; + return; + case SelectivityOptionalFilterType::BF: + selectivity_threshold = BF_THRESHOLD; + n_vectors_to_check = BF_CHECK_N; + return; + case SelectivityOptionalFilterType::PHJ: + selectivity_threshold = PHJ_THRESHOLD; + n_vectors_to_check = PHJ_CHECK_N; + return; + case SelectivityOptionalFilterType::PRF: + selectivity_threshold = PRF_THRESHOLD; + n_vectors_to_check = PRF_CHECK_N; + return; + default: + throw NotImplementedException("GetThresholdAndVectorsToCheck"); + } +} + +static unique_ptr CreateSingleArgumentFunctionExpression(const ScalarFunction &function, + const LogicalType &target_type, + unique_ptr bind_data) { + vector> arguments; + arguments.push_back(make_uniq(target_type, storage_t(0))); + return make_uniq(BoundScalarFunction(function), std::move(arguments), + std::move(bind_data)); +} + +unique_ptr CreateOptionalFilterExpression(unique_ptr child_expr, + const LogicalType &target_type) { + auto function = OptionalFilterScalarFun::GetFunction(target_type); + auto bind_data = make_uniq(std::move(child_expr)); + return CreateSingleArgumentFunctionExpression(function, target_type, std::move(bind_data)); +} + +unique_ptr CreateSelectivityOptionalFilterExpression(unique_ptr child_expr, + const LogicalType &target_type, + float selectivity_threshold, + idx_t n_vectors_to_check) { + auto function = SelectivityOptionalFilterScalarFun::GetFunction(target_type); + auto bind_data = make_uniq(std::move(child_expr), selectivity_threshold, + n_vectors_to_check); + return CreateSingleArgumentFunctionExpression(function, target_type, std::move(bind_data)); +} + +unique_ptr CreateDynamicFilterExpression(shared_ptr filter_data, + const LogicalType &target_type) { + auto function = DynamicFilterScalarFun::GetFunction(target_type); + auto bind_data = make_uniq(std::move(filter_data)); + return 
CreateSingleArgumentFunctionExpression(function, target_type, std::move(bind_data)); +} + +void TableFilterFunctionSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function) { + // Runtime state cannot be serialized - write nothing +} + +unique_ptr TableFilterFunctionDeserialize(Deserializer &deserializer, BoundScalarFunction &function) { + auto key_type = function.GetArguments().empty() ? LogicalType::ANY : function.GetArguments()[0]; + if (function.GetName() == BloomFilterScalarFun::NAME) { + return make_uniq(nullptr, false, string(), key_type, 0.0f, idx_t(0)); + } + if (function.GetName() == PerfectHashJoinScalarFun::NAME) { + return make_uniq(nullptr, string(), 0.0f, idx_t(0)); + } + if (function.GetName() == PrefixRangeScalarFun::NAME) { + return make_uniq(nullptr, string(), key_type, 0.0f, idx_t(0)); + } + if (function.GetName() == DynamicFilterScalarFun::NAME) { + return make_uniq(nullptr); + } + throw InternalException("Unsupported table filter function \"%s\" during deserialization", function.GetName()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_optional_function.cpp b/src/duckdb/src/planner/filter/table_filter_optional_function.cpp new file mode 100644 index 000000000..2769f412a --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_optional_function.cpp @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_optional_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/planner/filter/expression_filter.hpp" + +namespace duckdb { + +OptionalFilterFunctionData::OptionalFilterFunctionData(unique_ptr child_filter_expr_p) + : child_filter_expr(std::move(child_filter_expr_p)) 
{ +} + +unique_ptr OptionalFilterFunctionData::Copy() const { + return make_uniq(child_filter_expr ? child_filter_expr->Copy() : nullptr); +} + +bool OptionalFilterFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + if (!child_filter_expr && !other.child_filter_expr) { + return true; + } + if (!child_filter_expr || !other.child_filter_expr) { + return false; + } + return child_filter_expr->Equals(*other.child_filter_expr); +} + +static void OptionalFilterSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function) { + if (!bind_data) { + return; + } + auto &data = bind_data->Cast(); + serializer.WritePropertyWithDefault>(200, "child_filter_expr", data.child_filter_expr); +} + +static unique_ptr OptionalFilterDeserialize(Deserializer &deserializer, BoundScalarFunction &function) { + auto child_filter_expr = deserializer.ReadPropertyWithDefault>(200, "child_filter_expr"); + return make_uniq(std::move(child_filter_expr)); +} + +static void OptionalFilterFunction(DataChunk &args, ExpressionState &state, Vector &result) { + SetAllTrue(args, result); +} + +static idx_t OptionalFilterSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + return SetAllTrueSelection(args.size(), sel, true_sel, false_sel); +} + +ScalarFunction OptionalFilterScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, OptionalFilterFunction, TableFilterFunctions::Bind); + func.SetSelectCallback(OptionalFilterSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(OptionalFilterScalarFun::FilterPrune); + func.SetSerializeCallback(OptionalFilterSerialize); + func.SetDeserializeCallback(OptionalFilterDeserialize); + return func; +} + +FilterPropagateResult OptionalFilterScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if 
(!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.child_filter_expr) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return ExpressionFilter::CheckExpressionStatistics(*data.child_filter_expr, input.stats); +} + +string OptionalFilterScalarFun::ToString(const string &child_filter_string) { + return FormatOptionalFilterString(child_filter_string); +} + +ScalarFunction TableFilterOptionalFun::GetFunction() { + return OptionalFilterScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp b/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp new file mode 100644 index 000000000..9b5d45743 --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp @@ -0,0 +1,196 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_perfect_hash_join_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/vector_size.hpp" +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" + +namespace duckdb { + +PerfectHashJoinFunctionData::PerfectHashJoinFunctionData(optional_ptr executor_p, + const string &key_column_name_p, float selectivity_threshold_p, + idx_t n_vectors_to_check_p) + : executor(executor_p), key_column_name(key_column_name_p), 
selectivity_threshold(selectivity_threshold_p), + n_vectors_to_check(n_vectors_to_check_p) { +} + +unique_ptr PerfectHashJoinFunctionData::Copy() const { + return make_uniq(executor, key_column_name, selectivity_threshold, n_vectors_to_check); +} + +bool PerfectHashJoinFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return executor.get() == other.executor.get() && key_column_name == other.key_column_name; +} + +static idx_t SelectPerfectHashJoin(Vector &input, const PerfectHashJoinFunctionData &func_data, + SelectionVector &result_sel, idx_t count) { + D_ASSERT(func_data.executor); + idx_t approved_count = 0; + func_data.executor->FillSelectionVectorSwitchProbe(input, count, result_sel, approved_count, nullptr); + return approved_count; +} + +static unique_ptr +PerfectHashJoinInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { + auto &data = bind_data->Cast(); + if (!data.executor) { + return nullptr; + } + return InitSelectivityTrackingLocalState(data.n_vectors_to_check, data.selectivity_threshold); +} + +static void PerfectHashJoinFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &func_data = func_expr.BindInfo()->Cast(); + auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); + auto tracking_state = local_state_ptr ? 
&local_state_ptr->Cast() : nullptr; + + if (!func_data.executor) { + SetAllTrue(args, result); + return; + } + + ExecuteWithSelectivityTracking(args, result, tracking_state, [&] { + SelectionVector probe_sel(args.size()); + auto approved_count = SelectPerfectHashJoin(args.data[0], func_data, probe_sel, args.size()); + SelectionToBooleanResult(args.size(), probe_sel, approved_count, result); + return approved_count; + }); +} + +static idx_t PerfectHashJoinSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + auto &func_expr = state.expr.Cast(); + auto &func_data = func_expr.BindInfo()->Cast(); + auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); + auto tracking_state = local_state_ptr ? &local_state_ptr->Cast() : nullptr; + + auto count = args.size(); + if (!func_data.executor) { + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + if (tracking_state && !tracking_state->IsActive()) { + tracking_state->Update(0, 0); + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + + SelectionVector temp_true(count); + auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? 
&temp_true : true_sel.get(); + auto approved_count = SelectPerfectHashJoin(args.data[0], func_data, *result_true_sel, count); + approved_count = TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); + if (tracking_state) { + tracking_state->Update(approved_count, count); + } + return approved_count; +} + +template +static FilterPropagateResult TemplatedPerfectHashJoinPrune(const PerfectHashJoinExecutor &executor, + const BaseStatistics &stats) { + if (!NumericStats::HasMinMax(stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto min = NumericStats::GetMin(stats); + const auto max = NumericStats::GetMax(stats); + if (min > max) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + T range_typed; + idx_t range; + if (!TrySubtractOperator::Operation(max, min, range_typed) || !TryCast::Operation(range_typed, range) || + range >= DEFAULT_STANDARD_VECTOR_SIZE) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + Vector range_vec(stats.GetType(), DEFAULT_STANDARD_VECTOR_SIZE); + auto range_data = FlatVector::GetDataMutable(range_vec); + T val = min; + for (; val < max; val += 1) { + *range_data++ = val; + } + *range_data = val; + + const auto total_count = NumericCast(range_typed) + 1; + idx_t approved_tuple_count = 0; + SelectionVector probe_sel(total_count); + executor.FillSelectionVectorSwitchProbe(range_vec, total_count, probe_sel, approved_tuple_count, nullptr); + + if (approved_tuple_count == 0) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + if (approved_tuple_count == total_count) { + return FilterPropagateResult::FILTER_ALWAYS_TRUE; + } + return FilterPropagateResult::NO_PRUNING_POSSIBLE; +} + +ScalarFunction PerfectHashJoinScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, PerfectHashJoinFunction, TableFilterFunctions::Bind); + func.SetInitStateCallback(PerfectHashJoinInitLocalState); + 
func.SetSelectCallback(PerfectHashJoinSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(PerfectHashJoinScalarFun::FilterPrune); + func.SetSerializeCallback(TableFilterFunctionSerialize); + func.SetDeserializeCallback(TableFilterFunctionDeserialize); + return func; +} + +string PerfectHashJoinScalarFun::ToString(const string &column_name, const string &key_column_name) { + return column_name + " IN PHJ(" + key_column_name + ")"; +} + +FilterPropagateResult PerfectHashJoinScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if (!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.executor) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + switch (input.stats.GetType().InternalType()) { + case PhysicalType::UINT8: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::UINT16: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::UINT32: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::UINT64: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::UINT128: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::INT8: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::INT16: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::INT32: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::INT64: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + case PhysicalType::INT128: + return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +ScalarFunction TableFilterPerfectHashJoinFun::GetFunction() { + return 
PerfectHashJoinScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp b/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp new file mode 100644 index 000000000..dec6475df --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp @@ -0,0 +1,581 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_prefix_range_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/execution/expression_executor_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/numeric_stats.hpp" + +#include "duckdb/common/allocator.hpp" +#include "duckdb/common/array.hpp" +#include "duckdb/common/assert.hpp" +#include "duckdb/common/bit_utils.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/enums/vector_type.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/selection_vector.hpp" +#include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/uhugeint.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/buffer_manager.hpp" 
+#include "duckdb/planner/table_filter_state.hpp" + +namespace duckdb { + +AllocatedData AllocateBitmap(ClientContext &context, const idx_t word_count, uint64_t *&bitmap_begin) { + const idx_t size = word_count * sizeof(uint64_t); + BufferManager &buffer_manager = BufferManager::GetBufferManager(context); + auto buffer = buffer_manager.GetBufferAllocator().Allocate(64ULL + size); + bitmap_begin = reinterpret_cast((64ULL + reinterpret_cast(buffer.get())) & ~63ULL); + std::fill_n(bitmap_begin, word_count, 0); + return buffer; +} + +struct PrefixRangeBitmapBuildState : public PrefixRangeFilter::BuildState { + explicit PrefixRangeBitmapBuildState(AllocatedData data_p, uint64_t *bitmap_p) + : data(std::move(data_p)), bitmap(bitmap_p) { + } + + AllocatedData data; + uint64_t *bitmap; +}; + +template +class PrefixRangeBitmap { +public: + void Initialize(ClientContext &context, U min_p, U span_p) { + min = min_p; + span = span_p; + shift = 0; + + if (span >= CAP_BITS) { + const auto q = UnsafeNumericCast(span >> MAX_PREFIX_LENGTH); + shift = (q <= 1) ? 0 : (64 - CountZeros::Leading(q - 1)); + } + + const idx_t buckets = UnsafeNumericCast((span >> shift) + 1); + word_count = buckets == 0 ? 1 : (buckets + 63) >> WORD_SHIFT; + + buf_ = AllocateBitmap(context, word_count, bitmap); + + // Only mark initialized as true when local bitmaps are merged. + initialized = false; + } + + unique_ptr InitializeBuildState(ClientContext &context) const { + D_ASSERT(bitmap); + uint64_t *state_bitmap; + auto state_data = AllocateBitmap(context, word_count, state_bitmap); + return make_uniq(std::move(state_data), state_bitmap); + } + + template + void InsertKeys(Vector &keys, uint64_t *state_bitmap) const { + for (const auto &entry : keys.template ValidValues()) { + const U y = CONVERTER::Convert(entry.GetValue()) - min; + // All keys are in-range by construction, so the range check can be omitted here. 
+ const U idx = y >> shift; + state_bitmap[idx >> WORD_SHIFT] |= 1ULL << (idx & WORD_MASK); + } + } + + void MergeBuildState(PrefixRangeBitmapBuildState &state) { + for (idx_t word_idx = 0; word_idx < word_count; word_idx++) { + bitmap[word_idx] |= state.bitmap[word_idx]; + } + initialized = true; + } + + template + inline bool LookupOne(const Value &value) const { + if (value.IsNull()) { + return false; + } + + const U comparable = CONVERTER::Convert(value.GetValueUnsafe()); + const U y = comparable - min; + const U bit_idx = y >> shift; + const uint8_t in_range = y <= span; + const uint32_t word_idx = (bit_idx >> WORD_SHIFT) & (0U - in_range); + const uint8_t bit = (bitmap[word_idx] >> (bit_idx & WORD_MASK)) & 1ULL; + return bit & in_range; + } + + template + idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const { + idx_t found_count = 0; + for (const auto &entry : keys.template ValidValues()) { + const U comparable = CONVERTER::Convert(entry.GetValue()); + const U y = comparable - min; + const U bit_idx = y >> shift; + const uint8_t in_range = y <= span; + const uint32_t word_idx = (bit_idx >> WORD_SHIFT) & (0U - in_range); + const uint8_t bit = (bitmap[word_idx] >> (bit_idx & WORD_MASK)) & 1ULL; + + result_sel.set_index(found_count, entry.GetIndex()); + found_count += bit & in_range; + } + return found_count; + } + + FilterPropagateResult LookupRange(U lower_bound, U upper_bound) const { + const U lb_y = lower_bound - min; + const U lb_bit_idx = lb_y >> shift; + const auto lb_word_idx = lb_bit_idx >> WORD_SHIFT; + + const U ub_y = upper_bound - min; + const U ub_bit_idx = ub_y >> shift; + const auto ub_word_idx = ub_bit_idx >> WORD_SHIFT; + + const idx_t lb_bit_off = UnsafeNumericCast(lb_bit_idx & UnsafeNumericCast(WORD_MASK)); + const idx_t ub_bit_off = UnsafeNumericCast(ub_bit_idx & UnsafeNumericCast(WORD_MASK)); + + // TODO: Count the amount of 1's in the range, compare to a threshold, and make a decision if we want to use the + // 
per-row filter for this row group. + if (lb_word_idx == ub_word_idx) { + const auto range_mask = ((~0ULL << lb_bit_off) & (~0ULL >> (WORD_MASK - ub_bit_off))); + if (bitmap[lb_word_idx] & range_mask) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + const auto lb_word_mask = (~0ULL << lb_bit_off); + if (bitmap[lb_word_idx] & lb_word_mask) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + for (idx_t i = UnsafeNumericCast(lb_word_idx) + 1; i < UnsafeNumericCast(ub_word_idx); i++) { + if (bitmap[i]) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + } + + const auto ub_word_mask = ~0ULL >> (WORD_MASK - ub_bit_off); + if (bitmap[ub_word_idx] & ub_word_mask) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + bool IsInitialized() const { + return initialized; + } + + U Min() const { + return min; + } + + U Span() const { + return span; + } + +private: + static constexpr idx_t MAX_PREFIX_LENGTH = 20; + static constexpr idx_t CAP_BITS = 1ULL << MAX_PREFIX_LENGTH; + static constexpr idx_t WORD_SHIFT = 6; + static constexpr idx_t WORD_MASK = 63; + + bool initialized = false; + U min; + U span; + idx_t shift; + idx_t word_count; + AllocatedData buf_; + uint64_t *bitmap; +}; + +template +struct NumericConverter { + using comparable_type = typename MakeUnsigned::type; + + static inline comparable_type Convert(T value) { + // Overflow is explicitly allowed for unsigned to signed cast + return static_cast(value); + } +}; + +struct StringPrefixConverter { + static inline uint32_t Convert(const string_t &value) { + return value.GetPrefixIntegerComparable(); + } +}; + +uint32_t StringMinComparable(const Value &value) { + return StringPrefixConverter::Convert(value.GetValueUnsafe()); +} + +uint32_t StringMaxComparable(const Value &value) { + const auto max_string = value.GetValueUnsafe(); + if (max_string.GetSize() >= 
string_t::PREFIX_BYTES) { + return max_string.GetPrefixIntegerComparable(); + } + + // Pad string prefix with 0xFF to keep correctness if max is truncated at \0 char, e.g., ab\0c -> ab + array padded_prefix; + padded_prefix.fill(char(0xFF)); + for (idx_t i = 0; i < max_string.GetSize(); i++) { + padded_prefix[i] = max_string.GetData()[i]; + } + return string_t(padded_prefix.data(), string_t::PREFIX_BYTES).GetPrefixIntegerComparable(); +} + +template +class NumericPrefixRangeFilter : public PrefixRangeFilter { +private: + using Comparable = typename MakeUnsigned::type; + +public: + void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { + D_ASSERT(min_val <= max_val); + D_ASSERT(number_of_rows > 0); + const auto min = NumericConverter::Convert(min_val.GetValueUnsafe()); + const auto max = NumericConverter::Convert(max_val.GetValueUnsafe()); + bitmap.Initialize(context, min, max - min); + } + + unique_ptr InitializeBuildState(ClientContext &context) const override { + return bitmap.InitializeBuildState(context); + } + + void InsertKeys(Vector &keys, BuildState &state) const override { + auto &bitmap_state = state.Cast(); + bitmap.template InsertKeys>(keys, bitmap_state.bitmap); + } + + void MergeBuildState(BuildState &state) override { + bitmap.MergeBuildState(state.Cast()); + } + + idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const override { + if (keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return bitmap.template LookupOne>(keys.GetValue(0)) ? 
count : 0; + } + return bitmap.template LookupKeys>(keys, result_sel, count); + } + + FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const override { + const auto lb = lower_bound.GetValueUnsafe(); + const auto ub = upper_bound.GetValueUnsafe(); + + const auto bitmap_min = static_cast(bitmap.Min()); + const auto bitmap_max = static_cast(bitmap.Min() + bitmap.Span()); + if (ub < bitmap_min || lb > bitmap_max) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + const auto adjusted_lb = NumericConverter::Convert(MaxValue(lb, bitmap_min)); + const auto adjusted_ub = NumericConverter::Convert(MinValue(ub, bitmap_max)); + return bitmap.LookupRange(adjusted_lb, adjusted_ub); + } + + bool IsInitialized() const override { + return bitmap.IsInitialized(); + } + +private: + PrefixRangeBitmap bitmap; +}; + +class StringPrefixRangeFilter : public PrefixRangeFilter { +public: + void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { + D_ASSERT(min_val <= max_val); + D_ASSERT(number_of_rows > 0); + const auto min = StringPrefixConverter::Convert(min_val.GetValueUnsafe()); + const auto max = StringPrefixConverter::Convert(max_val.GetValueUnsafe()); + D_ASSERT(min <= max); + bitmap.Initialize(context, min, max - min); + } + + unique_ptr InitializeBuildState(ClientContext &context) const override { + return bitmap.InitializeBuildState(context); + } + + void InsertKeys(Vector &keys, BuildState &state) const override { + auto &bitmap_state = state.Cast(); + bitmap.template InsertKeys(keys, bitmap_state.bitmap); + } + + void MergeBuildState(BuildState &state) override { + bitmap.MergeBuildState(state.Cast()); + } + + idx_t LookupKeys(Vector &keys, SelectionVector &result_sel, idx_t count) const override { + if (keys.GetVectorType() == VectorType::CONSTANT_VECTOR) { + return bitmap.template LookupOne(keys.GetValue(0)) ? 
count : 0; + } + return bitmap.template LookupKeys(keys, result_sel, count); + } + + FilterPropagateResult LookupRange(const Value &lower_bound, const Value &upper_bound) const override { + auto lower_bound_comparable = StringMinComparable(lower_bound); + auto upper_bound_comparable = StringMaxComparable(upper_bound); + if (lower_bound_comparable > upper_bound_comparable) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + + const auto bitmap_min = bitmap.Min(); + const auto bitmap_max = bitmap.Min() + bitmap.Span(); + if (upper_bound_comparable < bitmap_min || lower_bound_comparable > bitmap_max) { + return FilterPropagateResult::FILTER_ALWAYS_FALSE; + } + + lower_bound_comparable = MaxValue(lower_bound_comparable, bitmap_min); + upper_bound_comparable = MinValue(upper_bound_comparable, bitmap_max); + return bitmap.LookupRange(lower_bound_comparable, upper_bound_comparable); + } + + bool IsInitialized() const override { + return bitmap.IsInitialized(); + } + +private: + PrefixRangeBitmap bitmap; +}; + +template +bool ComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { + T lb_value = lower_bound.GetValueUnsafe(); + T ub_value = upper_bound.GetValueUnsafe(); + T res; + if (TrySubtractOperator::Operation(ub_value, lb_value, res)) { + result = Uhugeint::Convert(res); + return true; + } else { + return false; + } +} + +bool ComputeStringPrefixSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { +#ifdef DUCKDB_DEBUG_NO_INLINE + return false; +#else + auto lb_value = lower_bound.GetValueUnsafe().GetPrefixIntegerComparable(); + auto ub_value = upper_bound.GetValueUnsafe().GetPrefixIntegerComparable(); + uint32_t res; + if (TrySubtractOperator::Operation(ub_value, lb_value, res)) { + result = Uhugeint::Convert(res); + return true; + } else { + return false; + } +#endif +} + +unique_ptr PrefixRangeFilter::CreatePrefixRangeFilter(const LogicalType &key_type) { + switch (key_type.InternalType()) { + case 
PhysicalType::UINT8: + return make_uniq>(); + case PhysicalType::UINT16: + return make_uniq>(); + case PhysicalType::UINT32: + return make_uniq>(); + case PhysicalType::UINT64: + return make_uniq>(); + case PhysicalType::INT8: + return make_uniq>(); + case PhysicalType::INT16: + return make_uniq>(); + case PhysicalType::INT32: + return make_uniq>(); + case PhysicalType::INT64: + return make_uniq>(); + case PhysicalType::VARCHAR: +#ifdef DUCKDB_DEBUG_NO_INLINE + throw NotImplementedException("Prefix range filter is not implemented for type %s", key_type.ToString()); +#else + return make_uniq(); +#endif + case PhysicalType::INT128: + case PhysicalType::UINT128: + default: + throw NotImplementedException("Prefix range filter is not implemented for type %s", key_type.ToString()); + } +} + +bool PrefixRangeFilter::TryComputeSpan(const Value &lower_bound, const Value &upper_bound, uhugeint_t &result) { + if (lower_bound.type().InternalType() != upper_bound.type().InternalType()) { + return false; + } + + switch (lower_bound.type().InternalType()) { + case PhysicalType::UINT8: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::UINT16: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::UINT32: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::UINT64: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::INT8: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::INT16: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::INT32: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::INT64: + return ComputeSpan(lower_bound, upper_bound, result); + case PhysicalType::VARCHAR: + return ComputeStringPrefixSpan(lower_bound, upper_bound, result); + case PhysicalType::INT128: + case PhysicalType::UINT128: + default: + return false; + } +} + +bool PrefixRangeFilter::SupportedType(const LogicalType &type) { + 
switch (type.InternalType()) { + case PhysicalType::UINT8: + case PhysicalType::UINT16: + case PhysicalType::UINT32: + case PhysicalType::UINT64: + case PhysicalType::INT8: + case PhysicalType::INT16: + case PhysicalType::INT32: + case PhysicalType::INT64: + return true; + case PhysicalType::VARCHAR: +#ifdef DUCKDB_DEBUG_NO_INLINE + return false; +#else + return true; +#endif + case PhysicalType::INT128: + case PhysicalType::UINT128: + default: + return false; + } +} + +PrefixRangeFunctionData::PrefixRangeFunctionData(optional_ptr filter_p, + const string &key_column_name_p, const LogicalType &key_type_p, + float selectivity_threshold_p, idx_t n_vectors_to_check_p) + : filter(filter_p), key_column_name(key_column_name_p), key_type(key_type_p), + selectivity_threshold(selectivity_threshold_p), n_vectors_to_check(n_vectors_to_check_p) { +} + +unique_ptr PrefixRangeFunctionData::Copy() const { + return make_uniq(filter, key_column_name, key_type, selectivity_threshold, + n_vectors_to_check); +} + +bool PrefixRangeFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + return filter.get() == other.filter.get() && key_column_name == other.key_column_name && key_type == other.key_type; +} + +static idx_t SelectPrefixRange(Vector &input, const PrefixRangeFunctionData &func_data, SelectionVector &result_sel, + idx_t count) { + D_ASSERT(func_data.filter); + return func_data.filter->LookupKeys(input, result_sel, count); +} + +static unique_ptr +PrefixRangeInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { + auto &data = bind_data->Cast(); + if (!data.filter) { + return nullptr; + } + return InitSelectivityTrackingLocalState(data.n_vectors_to_check, data.selectivity_threshold); +} + +static idx_t PrefixRangeSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, + optional_ptr true_sel, optional_ptr false_sel) { + auto &func_expr = state.expr.Cast(); + auto &func_data = 
func_expr.BindInfo()->Cast(); + auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); + auto tracking_state = local_state_ptr ? &local_state_ptr->Cast() : nullptr; + + auto count = args.size(); + if (!func_data.filter || !func_data.filter->IsInitialized()) { + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + if (tracking_state && !tracking_state->IsActive()) { + tracking_state->Update(0, 0); + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + + SelectionVector temp_true(count); + auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? &temp_true : true_sel.get(); + auto approved_count = SelectPrefixRange(args.data[0], func_data, *result_true_sel, count); + approved_count = TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); + if (tracking_state) { + tracking_state->Update(approved_count, count); + } + return approved_count; +} + +ScalarFunction PrefixRangeScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, nullptr, TableFilterFunctions::Bind); + func.SetInitStateCallback(PrefixRangeInitLocalState); + func.SetSelectCallback(PrefixRangeSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(PrefixRangeScalarFun::FilterPrune); + func.SetSerializeCallback(TableFilterFunctionSerialize); + func.SetDeserializeCallback(TableFilterFunctionDeserialize); + return func; +} + +string PrefixRangeScalarFun::ToString(const string &column_name, const string &key_column_name) { + return column_name + " IN PRF(" + key_column_name + ")"; +} + +FilterPropagateResult PrefixRangeScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if (!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.filter || !data.filter->IsInitialized()) { + return 
FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + switch (input.stats.GetStatsType()) { + case StatisticsType::NUMERIC_STATS: { + if (!NumericStats::HasMinMax(input.stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + const auto min = NumericStats::Min(input.stats); + const auto max = NumericStats::Max(input.stats); + if (min > max) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return data.filter->LookupRange(min, max); + } + case StatisticsType::STRING_STATS: { + if (!StringStats::HasMinMax(input.stats)) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + // String stats may contain raw parquet bytes that are not valid UTF-8. Reconstruct them as BLOBs so the + // prefix-range comparable logic can inspect the raw bytes without value-construction validation. + return data.filter->LookupRange(Value::BLOB_RAW(StringStats::Min(input.stats)), + Value::BLOB_RAW(StringStats::Max(input.stats))); + } + default: + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } +} + +ScalarFunction TableFilterPrefixRangeFun::GetFunction() { + return PrefixRangeScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_selectivity_optional_function.cpp b/src/duckdb/src/planner/filter/table_filter_selectivity_optional_function.cpp new file mode 100644 index 000000000..2dfb35f76 --- /dev/null +++ b/src/duckdb/src/planner/filter/table_filter_selectivity_optional_function.cpp @@ -0,0 +1,186 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/filter/table_filter_selectivity_optional_function.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/filter/table_filter_functions.hpp" +#include "duckdb/planner/filter/table_filter_function_helpers.hpp" + +#include "duckdb/execution/expression_executor.hpp" +#include 
"duckdb/execution/expression_executor_state.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" + +namespace duckdb { + +SelectivityOptionalFilterState::SelectivityStats::SelectivityStats(const idx_t n_vectors_to_check, + const float selectivity_threshold) + : n_vectors_to_check(n_vectors_to_check), selectivity_threshold(selectivity_threshold), tuples_accepted(0), + tuples_processed(0), vectors_processed(0), status(FilterStatus::ACTIVE), pause_multiplier(0) { +} + +void SelectivityOptionalFilterState::SelectivityStats::Update(idx_t accepted, idx_t processed) { + vectors_processed++; + tuples_accepted += accepted; + tuples_processed += processed; + + static constexpr idx_t VECTOR_PAUSE = 10; + D_ASSERT(n_vectors_to_check < VECTOR_PAUSE); + if (vectors_processed == MaxValue(pause_multiplier, 1) * VECTOR_PAUSE) { + vectors_processed = 0; + tuples_accepted = 0; + tuples_processed = 0; + status = FilterStatus::ACTIVE; + } else if (vectors_processed >= n_vectors_to_check) { + // pause the filter if we processed enough vectors and the selectivity is too high + if (GetSelectivity() >= selectivity_threshold) { + status = FilterStatus::PAUSED_DUE_TO_HIGH_SELECTIVITY; + pause_multiplier++; // increase the pause duration + } else { + pause_multiplier = 0; // selective enough, reset the pause duration + } + } +} + +bool SelectivityOptionalFilterState::SelectivityStats::IsActive() const { + return status == FilterStatus::ACTIVE; +} +double SelectivityOptionalFilterState::SelectivityStats::GetSelectivity() const { + if (tuples_processed == 0) { + return 0.0; + } + return static_cast(tuples_accepted) / static_cast(tuples_processed); +} + +SelectivityOptionalFilterFunctionData::SelectivityOptionalFilterFunctionData(unique_ptr child_filter_expr_p, + float selectivity_threshold_p, + idx_t n_vectors_to_check_p) + : child_filter_expr(std::move(child_filter_expr_p)), 
selectivity_threshold(selectivity_threshold_p), + n_vectors_to_check(n_vectors_to_check_p) { +} + +unique_ptr SelectivityOptionalFilterFunctionData::Copy() const { + return make_uniq(child_filter_expr ? child_filter_expr->Copy() : nullptr, + selectivity_threshold, n_vectors_to_check); +} + +bool SelectivityOptionalFilterFunctionData::Equals(const FunctionData &other_p) const { + auto &other = other_p.Cast(); + if (selectivity_threshold != other.selectivity_threshold || n_vectors_to_check != other.n_vectors_to_check) { + return false; + } + if (!child_filter_expr && !other.child_filter_expr) { + return true; + } + if (!child_filter_expr || !other.child_filter_expr) { + return false; + } + return child_filter_expr->Equals(*other.child_filter_expr); +} + +static void SelectivityOptionalFilterSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function) { + if (!bind_data) { + return; + } + auto &data = bind_data->Cast(); + serializer.WritePropertyWithDefault>(200, "child_filter_expr", data.child_filter_expr); + serializer.WritePropertyWithDefault(201, "selectivity_threshold", data.selectivity_threshold); + serializer.WritePropertyWithDefault(202, "n_vectors_to_check", data.n_vectors_to_check); +} + +static unique_ptr SelectivityOptionalFilterDeserialize(Deserializer &deserializer, + BoundScalarFunction &function) { + auto child_filter_expr = deserializer.ReadPropertyWithDefault>(200, "child_filter_expr"); + auto selectivity_threshold = + deserializer.ReadPropertyWithExplicitDefault(201, "selectivity_threshold", 0.5f); + auto n_vectors_to_check = deserializer.ReadPropertyWithExplicitDefault(202, "n_vectors_to_check", idx_t(6)); + return make_uniq(std::move(child_filter_expr), selectivity_threshold, + n_vectors_to_check); +} + +struct SelectivityOptionalFilterLocalState : public FunctionLocalState { + SelectivityOptionalFilterLocalState(ClientContext &context, const Expression &child_filter_expr, + idx_t n_vectors_to_check, float 
selectivity_threshold) + : stats(n_vectors_to_check, selectivity_threshold), executor(context, child_filter_expr) { + } + + bool IsActive() const { + return stats.IsActive(); + } + void Update(idx_t accepted, idx_t processed) { + stats.Update(accepted, processed); + } + + SelectivityOptionalFilterState::SelectivityStats stats; + ExpressionExecutor executor; +}; + +static unique_ptr SelectivityOptionalFilterInitLocalState(ExpressionState &state, + const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &data = bind_data->Cast(); + if (!data.child_filter_expr) { + return nullptr; + } + return make_uniq(state.GetContext(), *data.child_filter_expr, + data.n_vectors_to_check, data.selectivity_threshold); +} + +static idx_t SelectivityOptionalFilterSelect(DataChunk &args, ExpressionState &state, + optional_ptr sel, + optional_ptr true_sel, + optional_ptr false_sel) { + auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); + if (!local_state_ptr) { + return SetAllTrueSelection(args.size(), sel, true_sel, false_sel); + } + auto &local_state = local_state_ptr->Cast(); + auto count = args.size(); + if (!local_state.IsActive()) { + local_state.Update(0, 0); + return SetAllTrueSelection(count, sel, true_sel, false_sel); + } + + SelectionVector temp_true(count); + auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? 
&temp_true : true_sel.get(); + auto approved_count = local_state.executor.SelectExpression(args, *result_true_sel); + approved_count = TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); + local_state.Update(approved_count, count); + return approved_count; +} + +ScalarFunction SelectivityOptionalFilterScalarFun::GetFunction(const LogicalType &input_type) { + ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, nullptr, TableFilterFunctions::Bind); + func.SetInitStateCallback(SelectivityOptionalFilterInitLocalState); + func.SetSelectCallback(SelectivityOptionalFilterSelect); + func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + func.SetFilterPruneCallback(SelectivityOptionalFilterScalarFun::FilterPrune); + func.SetSerializeCallback(SelectivityOptionalFilterSerialize); + func.SetDeserializeCallback(SelectivityOptionalFilterDeserialize); + return func; +} + +FilterPropagateResult SelectivityOptionalFilterScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { + if (!input.bind_data) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + auto &data = input.bind_data->Cast(); + if (!data.child_filter_expr) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + return ExpressionFilter::CheckExpressionStatistics(*data.child_filter_expr, input.stats); +} + +string SelectivityOptionalFilterScalarFun::ToString(const string &child_filter_string) { + return FormatOptionalFilterString(child_filter_string); +} + +ScalarFunction TableFilterSelectivityOptionalFun::GetFunction() { + return SelectivityOptionalFilterScalarFun::GetFunction(LogicalType::ANY); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/joinside.cpp b/src/duckdb/src/planner/joinside.cpp index 68f02af2d..77acc9e10 100644 --- a/src/duckdb/src/planner/joinside.cpp +++ b/src/duckdb/src/planner/joinside.cpp @@ -82,22 +82,22 @@ JoinSide JoinSide::GetJoinSide(const Expression &expression, const unordered_set const unordered_set 
&right_bindings) { if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { auto &colref = expression.Cast(); - if (colref.depth > 0) { - throw NotImplementedException("Non-inner join on correlated columns not supported"); + if (colref.Depth() > 0) { + return JoinSide::NONE; } - return GetJoinSide(colref.binding.table_index, left_bindings, right_bindings); + return GetJoinSide(colref.Binding().table_index, left_bindings, right_bindings); } D_ASSERT(expression.GetExpressionType() != ExpressionType::BOUND_REF); if (expression.GetExpressionType() == ExpressionType::SUBQUERY) { D_ASSERT(expression.GetExpressionClass() == ExpressionClass::BOUND_SUBQUERY); auto &subquery = expression.Cast(); JoinSide side = JoinSide::NONE; - for (auto &child : subquery.children) { + for (auto &child : subquery.GetChildren()) { auto child_side = GetJoinSide(*child, left_bindings, right_bindings); side = CombineJoinSide(side, child_side); } // correlated subquery, check the side of each of correlated columns in the subquery - for (auto &corr : subquery.binder->correlated_columns) { + for (auto &corr : subquery.GetBinder()->correlated_columns) { if (corr.depth > 1) { // correlated column has depth > 1 // it does not refer to any table in the current set of bindings diff --git a/src/duckdb/src/planner/logical_operator.cpp b/src/duckdb/src/planner/logical_operator.cpp index d6cfc8769..5f01308f8 100644 --- a/src/duckdb/src/planner/logical_operator.cpp +++ b/src/duckdb/src/planner/logical_operator.cpp @@ -198,10 +198,10 @@ void LogicalOperator::Verify(ClientContext &context) { try { auto &config = DBConfig::GetConfig(context); SerializationOptions options; - if (config.options.serialization_compatibility.manually_set) { - options.serialization_compatibility = config.options.serialization_compatibility; + if (config.options.storage_compatibility.manually_set) { + options.storage_compatibility = config.options.storage_compatibility; } else { - options.serialization_compatibility 
= SerializationCompatibility::Latest(); + options.storage_compatibility = StorageCompatibility::Latest(); } BinarySerializer::Serialize(*expressions[expr_idx], stream, options); @@ -257,7 +257,7 @@ vector LogicalOperator::GetTableIndex() const { unique_ptr LogicalOperator::Copy(ClientContext &context) const { MemoryStream stream(Allocator::Get(context)); SerializationOptions options; - options.serialization_compatibility = SerializationCompatibility::Latest(); + options.storage_compatibility = StorageCompatibility::Latest(); BinarySerializer serializer(stream, options); try { serializer.Begin(); diff --git a/src/duckdb/src/planner/logical_operator_deep_copy.cpp b/src/duckdb/src/planner/logical_operator_deep_copy.cpp index 9223bd622..712c54356 100644 --- a/src/duckdb/src/planner/logical_operator_deep_copy.cpp +++ b/src/duckdb/src/planner/logical_operator_deep_copy.cpp @@ -250,9 +250,9 @@ void TableBindingReplacer::VisitExpression(unique_ptr *expression) { auto &expr = *expression; if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { auto &bound_column_ref = expr->Cast(); - auto entry = table_idx_replacements.find(bound_column_ref.binding.table_index); + auto entry = table_idx_replacements.find(bound_column_ref.Binding().table_index); if (entry != table_idx_replacements.end()) { - bound_column_ref.binding.table_index = entry->second; + bound_column_ref.BindingMutable().table_index = entry->second; } } else if (expr->GetExpressionClass() == ExpressionClass::BOUND_PARAMETER) { // we have to replace the parameter data if it is a bound parameter @@ -260,9 +260,9 @@ void TableBindingReplacer::VisitExpression(unique_ptr *expression) { // correct pointer to the parameter data auto &bound_parameter = expr->Cast(); if (parameter_data) { - auto entry = parameter_data->find(bound_parameter.identifier); + auto entry = parameter_data->find(bound_parameter.Identifier()); if (entry != parameter_data->end()) { - bound_parameter.parameter_data = entry->second; + 
bound_parameter.ParameterDataMutable() = entry->second; } } } diff --git a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp index 3bb4a6aa3..ee79b450c 100644 --- a/src/duckdb/src/planner/operator/logical_copy_to_file.cpp +++ b/src/duckdb/src/planner/operator/logical_copy_to_file.cpp @@ -23,12 +23,13 @@ vector LogicalCopyToFile::GetTypesWithoutPartitions(const vector LogicalCopyToFile::GetNamesWithoutPartitions(const vector &col_names, - const vector &part_cols, bool write_part_cols) { +vector LogicalCopyToFile::GetNamesWithoutPartitions(const vector &col_names, + const vector &part_cols, + bool write_part_cols) { if (write_part_cols || part_cols.empty()) { return col_names; } - vector names; + vector names; set part_col_set(part_cols.begin(), part_cols.end()); for (idx_t col_idx = 0; col_idx < col_names.size(); col_idx++) { if (part_col_set.find(col_idx) == part_col_set.end()) { @@ -73,6 +74,7 @@ void LogicalCopyToFile::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(221, "batch_size", batch_size, optional_idx()); serializer.WritePropertyWithDefault(222, "batch_size_bytes", batch_size_bytes, optional_idx()); serializer.WritePropertyWithDefault(223, "batches_per_file", batches_per_file, optional_idx()); + serializer.WritePropertyWithDefault(224, "order_columns", order_columns); } unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deserializer) { @@ -83,7 +85,7 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria auto per_thread_output = deserializer.ReadProperty(204, "per_thread_output"); auto partition_output = deserializer.ReadProperty(205, "partition_output"); auto partition_columns = deserializer.ReadProperty>(206, "partition_columns"); - auto names = deserializer.ReadProperty>(207, "names"); + auto names = deserializer.ReadProperty>(207, "names"); auto expected_types = deserializer.ReadProperty>(208, "expected_types"); auto copy_info = 
unique_ptr_cast(deserializer.ReadProperty>(209, "copy_info")); @@ -92,8 +94,8 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria auto &context = deserializer.Get(); auto name = deserializer.ReadProperty(210, "function_name"); - auto &func_catalog_entry = - Catalog::GetEntry(context, SYSTEM_CATALOG, DEFAULT_SCHEMA, name); + auto &func_catalog_entry = Catalog::GetEntry( + context, Identifier::SystemCatalog(), Identifier::DefaultSchema(), Identifier(name)); if (func_catalog_entry.type != CatalogType::COPY_FUNCTION_ENTRY) { throw InternalException("DeserializeFunction - cant find catalog entry for function %s", name); } @@ -125,6 +127,7 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria auto batch_size = deserializer.ReadPropertyWithExplicitDefault(221, "batch_size", optional_idx()); auto batch_size_bytes = deserializer.ReadPropertyWithExplicitDefault(222, "batch_size_bytes", optional_idx()); auto batches_per_file = deserializer.ReadPropertyWithExplicitDefault(223, "batches_per_file", optional_idx()); + auto order_columns = deserializer.ReadPropertyWithExplicitDefault(224, "order_columns", vector()); if (!has_serialize) { // If not serialized, re-bind with the copy info @@ -159,6 +162,7 @@ unique_ptr LogicalCopyToFile::Deserialize(Deserializer &deseria result->batch_size = batch_size; result->batch_size_bytes = batch_size_bytes; result->batches_per_file = batches_per_file; + result->order_columns = std::move(order_columns); return std::move(result); } diff --git a/src/duckdb/src/planner/operator/logical_export.cpp b/src/duckdb/src/planner/operator/logical_export.cpp index 046aff3e0..74a66be1c 100644 --- a/src/duckdb/src/planner/operator/logical_export.cpp +++ b/src/duckdb/src/planner/operator/logical_export.cpp @@ -20,8 +20,8 @@ LogicalExport::LogicalExport(ClientContext &context, unique_ptr copy_ } CopyFunction LogicalExport::GetCopyFunction(ClientContext &context, CopyInfo &info) { - auto ©_entry = - Catalog::GetEntry(context, 
INVALID_CATALOG, DEFAULT_SCHEMA, info.format); + auto ©_entry = Catalog::GetEntry( + context, Identifier::InvalidCatalog(), Identifier::DefaultSchema(), Identifier(info.format)); return copy_entry.function; } diff --git a/src/duckdb/src/planner/operator/logical_filter.cpp b/src/duckdb/src/planner/operator/logical_filter.cpp index aa9b21642..cd434ac46 100644 --- a/src/duckdb/src/planner/operator/logical_filter.cpp +++ b/src/duckdb/src/planner/operator/logical_filter.cpp @@ -29,11 +29,11 @@ bool LogicalFilter::SplitPredicates(vector> &expressions) auto &conjunction = expressions[i]->Cast(); found_conjunction = true; // AND expression, append the other children - for (idx_t k = 1; k < conjunction.children.size(); k++) { - expressions.push_back(std::move(conjunction.children[k])); + for (idx_t k = 1; k < conjunction.GetChildrenMutable().size(); k++) { + expressions.push_back(std::move(conjunction.GetChildrenMutable()[k])); } // replace this expression with the first child of the conjunction - expressions[i] = std::move(conjunction.children[0]); + expressions[i] = std::move(conjunction.GetChildrenMutable()[0]); // we move back by one so the right child is checked again // in case it is an AND expression as well i--; diff --git a/src/duckdb/src/planner/operator/logical_get.cpp b/src/duckdb/src/planner/operator/logical_get.cpp index 86123ff3c..6c3512ba4 100644 --- a/src/duckdb/src/planner/operator/logical_get.cpp +++ b/src/duckdb/src/planner/operator/logical_get.cpp @@ -10,10 +10,7 @@ #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" namespace duckdb { @@ -42,7 +39,7 @@ LogicalGet::LogicalGet() : LogicalOperator(LogicalOperatorType::LOGICAL_GET) { } 
LogicalGet::LogicalGet(TableIndex table_index, TableFunction function, unique_ptr bind_data, - vector returned_types, vector returned_names, + vector returned_types, vector returned_names, virtual_column_map_t virtual_columns_p) : LogicalOperator(LogicalOperatorType::LOGICAL_GET), table_index(table_index), function(std::move(function)), bind_data(std::move(bind_data)), returned_types(std::move(returned_types)), names(std::move(returned_names)), @@ -63,7 +60,7 @@ InsertionOrderPreservingMap LogicalGet::ParamsToString() const { bool first_item = true; for (auto &kv : table_filters) { auto filter_idx = kv.GetIndex(); - auto &filter = kv.Filter(); + auto &filter = kv.Filter().Cast(); auto &col_id_entry = column_ids[filter_idx]; const auto col_id = col_id_entry.GetPrimaryIndex(); if (col_id_entry.IsVirtualColumn()) { @@ -73,13 +70,13 @@ InsertionOrderPreservingMap LogicalGet::ParamsToString() const { filters_info += "\n"; } first_item = false; - filters_info += filter.ToString(entry->second.name); + filters_info += filter.ToString(entry->second.name.GetIdentifierName()); } } else if (col_id < names.size()) { if (!first_item) { filters_info += "\n"; } - auto column_name = col_id_entry.GetName(names[col_id]); + auto column_name = col_id_entry.GetName(names[col_id].GetIdentifierName()); first_item = false; filters_info += filter.ToString(column_name); } @@ -192,7 +189,7 @@ const LogicalType &LogicalGet::GetColumnType(const ColumnIndex &index) const { } } -const string &LogicalGet::GetColumnName(const ColumnIndex &index) const { +const Identifier &LogicalGet::GetColumnName(const ColumnIndex &index) const { if (index.IsVirtualColumn()) { auto entry = virtual_columns.find(index.GetPrimaryIndex()); if (entry == virtual_columns.end()) { @@ -295,6 +292,14 @@ void LogicalGet::SetScanOrder(unique_ptr options) { function.set_scan_order(std::move(options), bind_data.get()); } +void LogicalGet::SetPartitionsToScan(vector partition_indices) { + if (!function.set_partitions_to_scan) { 
+ throw InternalException("LogicalGet::SetPartitionsToScan called but function is not defined"); + } + scan_partition_indices = partition_indices; + function.set_partitions_to_scan(std::move(partition_indices), bind_data.get()); +} + void LogicalGet::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); serializer.WriteProperty(200, "table_index", table_index); @@ -318,6 +323,7 @@ void LogicalGet::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(213, "ordinality_idx", ordinality_idx); serializer.WritePropertyWithDefault>(214, "row_group_order_options", row_group_order_options); + serializer.WritePropertyWithDefault(215, "scan_partition_indices", scan_partition_indices, vector()); } unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) { @@ -350,6 +356,8 @@ unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) deserializer.ReadPropertyWithDefault(213, "ordinality_idx", result->ordinality_idx); auto row_group_order_options = deserializer.ReadPropertyWithDefault>(214, "row_group_order_options"); + auto scan_partition_indices = + deserializer.ReadPropertyWithExplicitDefault>(215, "scan_partition_indices", vector()); if (!legacy_column_ids.empty()) { if (!result->column_ids.empty()) { throw SerializationException( @@ -410,6 +418,9 @@ unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) if (row_group_order_options) { result->SetScanOrder(std::move(row_group_order_options)); } + if (!scan_partition_indices.empty()) { + result->SetPartitionsToScan(std::move(scan_partition_indices)); + } return std::move(result); } @@ -420,10 +431,10 @@ vector LogicalGet::GetTableIndex() const { string LogicalGet::GetName() const { #ifdef DEBUG if (DBConfigOptions::debug_print_bindings) { - return StringUtil::Upper(function.name) + StringUtil::Format(" #%llu", table_index.index); + return StringUtil::Upper(function.name.GetIdentifierName()) + StringUtil::Format(" #%llu", table_index.index); } #endif - 
return StringUtil::Upper(function.name); + return StringUtil::Upper(function.name.GetIdentifierName()); } } // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_join.cpp b/src/duckdb/src/planner/operator/logical_join.cpp index 6b172f99c..92f2f993d 100644 --- a/src/duckdb/src/planner/operator/logical_join.cpp +++ b/src/duckdb/src/planner/operator/logical_join.cpp @@ -78,8 +78,8 @@ void LogicalJoin::GetTableReferences(LogicalOperator &op, unordered_set &bindings) { ExpressionIterator::VisitExpression(root_expr, [&](const BoundColumnRefExpression &colref) { - D_ASSERT(colref.depth == 0); - bindings.insert(colref.binding.table_index); + D_ASSERT(colref.Depth() == 0); + bindings.insert(colref.Binding().table_index); }); } diff --git a/src/duckdb/src/planner/operator/logical_materialized_cte.cpp b/src/duckdb/src/planner/operator/logical_materialized_cte.cpp index 505f7083d..1a04417df 100644 --- a/src/duckdb/src/planner/operator/logical_materialized_cte.cpp +++ b/src/duckdb/src/planner/operator/logical_materialized_cte.cpp @@ -4,7 +4,7 @@ namespace duckdb { InsertionOrderPreservingMap LogicalMaterializedCTE::ParamsToString() const { InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; + result["CTE Name"] = ctename.GetIdentifierName(); result["Table Index"] = StringUtil::Format("%llu", table_index.index); SetParamsEstimatedCardinality(result); return result; diff --git a/src/duckdb/src/planner/operator/logical_recursive_cte.cpp b/src/duckdb/src/planner/operator/logical_recursive_cte.cpp index 8d917e58d..ea9ab5f35 100644 --- a/src/duckdb/src/planner/operator/logical_recursive_cte.cpp +++ b/src/duckdb/src/planner/operator/logical_recursive_cte.cpp @@ -7,9 +7,9 @@ namespace duckdb { LogicalRecursiveCTE::LogicalRecursiveCTE() : LogicalCTE(LogicalOperatorType::LOGICAL_RECURSIVE_CTE) { } -LogicalRecursiveCTE::LogicalRecursiveCTE(string ctename_p, TableIndex table_index, idx_t column_count, bool union_all, - vector> key_targets, unique_ptr top, 
- unique_ptr bottom) +LogicalRecursiveCTE::LogicalRecursiveCTE(Identifier ctename_p, TableIndex table_index, idx_t column_count, + bool union_all, vector> key_targets, + unique_ptr top, unique_ptr bottom) : LogicalCTE(std::move(ctename_p), table_index, column_count, std::move(top), std::move(bottom), LogicalOperatorType::LOGICAL_RECURSIVE_CTE), union_all(union_all), key_targets(std::move(key_targets)) { @@ -17,7 +17,7 @@ LogicalRecursiveCTE::LogicalRecursiveCTE(string ctename_p, TableIndex table_inde InsertionOrderPreservingMap LogicalRecursiveCTE::ParamsToString() const { InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; + result["CTE Name"] = ctename.GetIdentifierName(); result["Table Index"] = StringUtil::Format("%llu", table_index.index); SetParamsEstimatedCardinality(result); return result; @@ -47,7 +47,7 @@ void LogicalRecursiveCTE::ResolveTypes() { for (auto &key_target : key_targets) { D_ASSERT(key_target->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF); auto &bound_ref = key_target->Cast(); - key_idx.insert(bound_ref.binding.column_index); + key_idx.insert(bound_ref.Binding().column_index); } idx_t pay_idx = 0; diff --git a/src/duckdb/src/planner/operator/logical_update.cpp b/src/duckdb/src/planner/operator/logical_update.cpp index 27563caf0..a457d0754 100644 --- a/src/duckdb/src/planner/operator/logical_update.cpp +++ b/src/duckdb/src/planner/operator/logical_update.cpp @@ -100,7 +100,7 @@ void LogicalUpdate::RewriteInPlaceUpdates(LogicalOperator &update_op) { // Update the target binding. target_binding = - proj.expressions[target_binding.column_index]->Cast().binding; + proj.expressions[target_binding.column_index]->Cast().Binding(); // Traverse the child. 
stack.push_back(proj.children.back()); diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index 7bf69afdf..3ef0da32f 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -19,6 +19,7 @@ #include "duckdb/planner/subquery/flatten_dependent_join.hpp" #include "duckdb/planner/operator_extension.hpp" #include "duckdb/planner/planner_extension.hpp" +#include "duckdb/optimizer/optimizer.hpp" namespace duckdb { @@ -52,10 +53,10 @@ void Planner::CreatePlan(SQLStatement &statement) { // first bind the tables and columns to the catalog bool parameters_resolved = true; try { - profiler.StartPhase(MetricType::PLANNER_BINDING); + auto binding_timer = profiler.StartTimer(); binder->SetParameters(bound_parameters); auto bound_statement = binder->Bind(statement); - profiler.EndPhase(); + binding_timer.EndTimer(); RunPostBindExtensions(context, *binder, bound_statement); @@ -133,6 +134,9 @@ shared_ptr Planner::PrepareSQLStatement(unique_ptr statement) { D_ASSERT(statement); + Optimizer optimizer(*binder, context); + optimizer.OptimizeStatement(statement); + switch (statement->type) { case StatementType::SELECT_STATEMENT: case StatementType::INSERT_STATEMENT: @@ -160,6 +164,8 @@ void Planner::CreatePlan(unique_ptr statement) { case StatementType::COPY_DATABASE_STATEMENT: case StatementType::UPDATE_EXTENSIONS_STATEMENT: case StatementType::MERGE_INTO_STATEMENT: + case StatementType::CONNECT_STATEMENT: + case StatementType::DISCONNECT_STATEMENT: CreatePlan(*statement); break; default: @@ -197,11 +203,11 @@ void Planner::VerifyPlan(ClientContext &context, unique_ptr &op MemoryStream stream(Allocator::Get(context)); SerializationOptions options; - if (config.options.serialization_compatibility.manually_set) { + if (config.options.storage_compatibility.manually_set) { // Override the default of 'latest' if this was manually set (for testing, mostly) - options.serialization_compatibility = 
config.options.serialization_compatibility; + options.storage_compatibility = config.options.storage_compatibility; } else { - options.serialization_compatibility = SerializationCompatibility::Latest(); + options.storage_compatibility = StorageCompatibility::Latest(); } BinarySerializer::Serialize(*op, stream, options); diff --git a/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp new file mode 100644 index 000000000..f6577271c --- /dev/null +++ b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp @@ -0,0 +1,1907 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/subquery/delim_join_cte_rewriter.cpp +// +// +//===----------------------------------------------------------------------===// + +#include "duckdb/planner/subquery/delim_join_cte_rewriter.hpp" + +#include "duckdb/common/enums/optimizer_type.hpp" +#include "duckdb/main/config.hpp" +#include "duckdb/main/query_profiler.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_operator_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/operator/list.hpp" + +#include + +namespace duckdb { + +static constexpr const char *CTE_DELIMINATOR_PROFILER_KEY = "optimizer.deliminator"; + +static void VerifyNoDelim(LogicalOperator &op) { + // Verify that there are no delim joins or delim scans in the plan, as these should have been rewritten to CTEs at + // this point. 
+ if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + throw InternalException("Found DELIM_JOIN after flattening dependent joins"); + } + if (op.type == LogicalOperatorType::LOGICAL_DELIM_GET) { + throw InternalException("Found DELIM_GET after flattening dependent joins"); + } + for (auto &child : op.children) { + VerifyNoDelim(*child); + } +} + +static vector GenerateCTEColumnNames(idx_t column_count, const string &prefix) { + vector result; + result.reserve(column_count); + for (idx_t i = 0; i < column_count; i++) { + result.push_back(Identifier(prefix + to_string(i))); + } + return result; +} + +static unique_ptr CreateIdentityProjection(Binder &binder, unique_ptr child) { + child->ResolveOperatorTypes(); + auto bindings = child->GetColumnBindings(); + vector> expressions; + expressions.reserve(bindings.size()); + for (idx_t i = 0; i < bindings.size(); i++) { + expressions.push_back(make_uniq(child->types[i], bindings[i])); + } + auto projection = make_uniq(binder.GenerateTableIndex(), std::move(expressions)); + projection->children.push_back(std::move(child)); + projection->ResolveOperatorTypes(); + return std::move(projection); +} + +static idx_t RewriteDelimScanReferences(unique_ptr &op, TableIndex delim_scan_index) { + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + if (!op->children.empty()) { + return RewriteDelimScanReferences(op->children[0], delim_scan_index); + } + return 0; + } + idx_t rewritten_count = 0; + for (auto &child : op->children) { + rewritten_count += RewriteDelimScanReferences(child, delim_scan_index); + } + if (op->type == LogicalOperatorType::LOGICAL_DELIM_GET) { + auto &delim_get = op->Cast(); + auto delim_scan_names = GenerateCTEColumnNames(delim_get.chunk_types.size(), "__duckdb_delim_scan_"); + auto cte_scan = + make_uniq(delim_get.table_index, delim_scan_index, delim_get.chunk_types, delim_scan_names); + op = std::move(cte_scan); + rewritten_count++; + } + return rewritten_count; +} + +static vector 
CreateBindingReplacements(const vector &old_bindings, + const vector &new_bindings) { + vector result; + auto count = MinValue(old_bindings.size(), new_bindings.size()); + for (idx_t i = 0; i < count; i++) { + if (old_bindings[i] != new_bindings[i] && + std::find(new_bindings.begin(), new_bindings.end(), old_bindings[i]) == new_bindings.end()) { + result.emplace_back(old_bindings[i], new_bindings[i]); + } + } + return result; +} + +static vector RewriteChangedChildBindings(LogicalOperator &op, LogicalOperator &child, + const vector &old_child_bindings) { + auto new_child_bindings = child.GetColumnBindings(); + auto replacements = CreateBindingReplacements(old_child_bindings, new_child_bindings); + if (replacements.empty()) { + return replacements; + } + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { + auto &dependent_join = op.Cast(); + for (auto &col : dependent_join.correlated_columns) { + for (const auto &replacement : replacements) { + if (col.binding == replacement.old_binding) { + col.binding = replacement.new_binding; + break; + } + } + } + } + ColumnBindingReplacer replacer; + replacer.replacement_bindings = replacements; + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN && op.children[0].get() == &child) { + replacer.stop_operator = child; + replacer.VisitOperator(op); + } else { + LogicalOperatorVisitor::EnumerateExpressions( + op, [&](unique_ptr *expr) { replacer.VisitExpression(expr); }); + } + return replacements; +} + +static optional_idx FindBindingIndex(const vector &bindings, const ColumnBinding &binding) { + auto entry = std::find(bindings.begin(), bindings.end(), binding); + if (entry == bindings.end()) { + return optional_idx(); + } + return NumericCast(entry - bindings.begin()); +} + +static void AddFilterToOperator(unique_ptr &child, unique_ptr filter) { + if (child->type == LogicalOperatorType::LOGICAL_FILTER && !child->HasProjectionMap()) { + child->Cast().expressions.push_back(std::move(filter)); + return; + } + + 
auto new_filter = make_uniq(); + new_filter->expressions.push_back(std::move(filter)); + new_filter->children.push_back(std::move(child)); + child = std::move(new_filter); +} + +static bool GetExpressionColumnBindings(Expression &expr, column_binding_set_t &bindings) { + bool depth_zero = true; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { + if (colref.Depth() == 0) { + bindings.insert(colref.Binding()); + } else { + depth_zero = false; + } + }); + return depth_zero; +} + +static bool ChildContainsBindings(LogicalOperator &child, const column_binding_set_t &bindings) { + column_binding_set_t child_bindings; + for (auto &binding : child.GetColumnBindings()) { + child_bindings.insert(binding); + } + for (auto &binding : bindings) { + if (child_bindings.find(binding) == child_bindings.end()) { + return false; + } + } + return true; +} + +static bool FilterReferencesDelimInput(LogicalComparisonJoin &delim_join, Expression &filter) { + D_ASSERT(delim_join.type == LogicalOperatorType::LOGICAL_DELIM_JOIN); + column_binding_set_t filter_bindings; + if (!GetExpressionColumnBindings(filter, filter_bindings)) { + return false; + } + if (filter_bindings.empty()) { + return false; + } + return ChildContainsBindings(*delim_join.children[0], filter_bindings); +} + +static bool ExpressionReferencesChild(Expression &expr, LogicalOperator &child) { + column_binding_set_t expr_bindings; + if (!GetExpressionColumnBindings(expr, expr_bindings)) { + return false; + } + if (expr_bindings.empty()) { + return false; + } + column_binding_set_t child_bindings; + for (auto &binding : child.GetColumnBindings()) { + child_bindings.insert(binding); + } + for (auto &binding : expr_bindings) { + if (child_bindings.find(binding) != child_bindings.end()) { + return true; + } + } + return false; +} + +static bool ExpressionNullPropagatesForChild(Expression &expr, LogicalOperator &child) { + return ExpressionReferencesChild(expr, child) && 
expr.PropagatesNullValues(); +} + +static bool ExpressionNullRejectsDelimJoinRHS(Expression &expr, LogicalComparisonJoin &delim_join) { + auto &rhs = *delim_join.children[1]; + if (!ExpressionReferencesChild(expr, rhs)) { + return false; + } + if (BoundComparisonExpression::IsComparison(expr)) { + return expr.PropagatesNullValues(); + } + if (expr.GetExpressionClass() == ExpressionClass::BOUND_OPERATOR && + expr.GetExpressionType() == ExpressionType::OPERATOR_IS_NOT_NULL) { + bool null_propagating_child = false; + ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { + null_propagating_child = null_propagating_child || ExpressionNullPropagatesForChild(child, rhs); + }); + return null_propagating_child; + } + return false; +} + +static bool FilterNullRejectsDelimJoinRHS(LogicalFilter &filter, LogicalComparisonJoin &delim_join) { + if (filter.HasProjectionMap() || delim_join.join_type != JoinType::SINGLE) { + return false; + } + for (auto &expr : filter.expressions) { + if (ExpressionNullRejectsDelimJoinRHS(*expr, delim_join)) { + return true; + } + } + return false; +} + +static bool PushEligibleFilterExpressionsIntoDelimJoinInputs(unique_ptr &plan) { + auto &filter = plan->Cast(); + if (filter.HasProjectionMap()) { + return false; + } + if (filter.children[0]->type != LogicalOperatorType::LOGICAL_DELIM_JOIN) { + return false; + } + + bool changed = false; + auto &delim_join = filter.children[0]->Cast(); + vector> remaining_expressions; + auto expressions = std::move(filter.expressions); + LogicalFilter::SplitPredicates(expressions); + for (auto &expr : expressions) { + if (FilterReferencesDelimInput(delim_join, *expr)) { + AddFilterToOperator(delim_join.children[0], std::move(expr)); + changed = true; + continue; + } + remaining_expressions.push_back(std::move(expr)); + } + + if (remaining_expressions.empty()) { + plan = std::move(filter.children[0]); + } else { + filter.expressions = std::move(remaining_expressions); + } + return changed; +} + 
+static bool PushEligibleFiltersIntoDelimJoinInputs(unique_ptr &plan) { + bool changed = false; + for (auto &child : plan->children) { + changed = PushEligibleFiltersIntoDelimJoinInputs(child) || changed; + } + if (plan->type == LogicalOperatorType::LOGICAL_FILTER) { + changed = PushEligibleFilterExpressionsIntoDelimJoinInputs(plan) || changed; + } + return changed; +} + +static bool HasSelection(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: { + auto &get = op.Cast(); + for (const auto &entry : get.table_filters) { + auto &expr_filter = + ExpressionFilter::GetExpressionFilter(entry.Filter(), "DelimJoinCTERewriter::HasSelection"); + auto &expr = *expr_filter.expr; + if (expr.GetExpressionClass() != ExpressionClass::BOUND_OPERATOR || + expr.GetExpressionType() != ExpressionType::OPERATOR_IS_NOT_NULL) { + return true; + } + } + break; + } + case LogicalOperatorType::LOGICAL_FILTER: + return true; + default: + break; + } + + for (auto &child : op.children) { + if (HasSelection(*child)) { + return true; + } + } + return false; +} + +struct JoinWithGeneratedDedupRef { + JoinWithGeneratedDedupRef(unique_ptr &join_p, idx_t depth_p, bool filter_cross_product_p = false) + : join(join_p), depth(depth_p), filter_cross_product(filter_cross_product_p) { + } + + reference> join; + idx_t depth; + bool filter_cross_product; +}; + +struct GeneratedDedupRef { + optional_ptr cte_ref; + vector output_bindings; + vector> output_expressions; + vector> filters; + bool has_projection = false; +}; + +static bool IsEqualityJoinCondition(const JoinCondition &cond) { + if (!cond.IsComparison()) { + return false; + } + switch (cond.GetComparisonType()) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + return true; + default: + return false; + } +} + +static bool FindAndReplaceBindings(vector &traced_bindings, + const vector> &expressions, + const vector ¤t_bindings) { + for (auto &binding : traced_bindings) { 
+ idx_t current_idx; + for (current_idx = 0; current_idx < expressions.size(); current_idx++) { + if (binding == current_bindings[current_idx]) { + break; + } + } + + if (current_idx == expressions.size() || + expressions[current_idx]->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + + auto &colref = expressions[current_idx]->Cast(); + binding = colref.Binding(); + } + return true; +} + +class ExpressionBindingReplacer : public LogicalOperatorVisitor { +public: + ExpressionBindingReplacer(const vector &bindings, const vector> &expressions) + : bindings(bindings), expressions(expressions) { + D_ASSERT(bindings.size() == expressions.size()); + } + + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override { + if (expr.Depth() != 0) { + return nullptr; + } + for (idx_t idx = 0; idx < bindings.size(); idx++) { + if (expr.Binding() == bindings[idx]) { + return expressions[idx]->Copy(); + } + } + return nullptr; + } + +private: + const vector &bindings; + const vector> &expressions; +}; + +static void ReplaceExpressionBindings(unique_ptr &expr, const vector &bindings, + const vector> &expressions) { + if (bindings.empty()) { + return; + } + ExpressionBindingReplacer replacer(bindings, expressions); + replacer.VisitExpression(&expr); +} + +static bool GetBoundColumnRefBinding(Expression &expr, ColumnBinding &binding) { + if (expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = expr.Cast(); + if (colref.Depth() != 0) { + return false; + } + binding = colref.Binding(); + return true; +} + +static bool ExpressionReferencesBinding(Expression &expr, const vector &bindings) { + bool found = false; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { + if (colref.Depth() == 0 && FindBindingIndex(bindings, colref.Binding()).IsValid()) { + found = true; + } + }); + return found; +} + +static bool CoversAllBindings(const vector 
&all_bindings, + const vector &covered_bindings) { + if (all_bindings.size() != covered_bindings.size()) { + return false; + } + for (auto &binding : all_bindings) { + idx_t match_count = 0; + for (auto &covered_binding : covered_bindings) { + if (binding == covered_binding) { + match_count++; + } + } + if (match_count != 1) { + return false; + } + } + return true; +} + +static bool AddExpressionReplacement(vector &bindings, vector> &expressions, + ColumnBinding binding, unique_ptr expression) { + for (idx_t binding_idx = 0; binding_idx < bindings.size(); binding_idx++) { + if (bindings[binding_idx] != binding) { + continue; + } + return expressions[binding_idx]->Equals(*expression); + } + bindings.push_back(binding); + expressions.push_back(std::move(expression)); + return true; +} + +static void ReplaceOperatorBindings(LogicalOperator &op, const vector &bindings, + const vector> &expressions) { + if (bindings.empty()) { + return; + } + ExpressionBindingReplacer replacer(bindings, expressions); + replacer.VisitOperator(op); +} + +class GeneratedDedupRefEliminator { +public: + GeneratedDedupRefEliminator(LogicalComparisonJoin &delim_join, TableIndex dedup_cte_index, idx_t dedup_ref_count, + LogicalOperator &rewrite_root); + + idx_t Remove(); + +private: + unique_ptr GetGeneratedDedupRef(LogicalOperator &op, bool collect_filters = false, + bool allow_projection = false) const; + bool ExpressionReferencesGeneratedDedupRef(Expression &expr, const GeneratedDedupRef &dedup_ref) const; + bool AddReplacement(vector &replacements, ColumnBinding old_binding, + ColumnBinding new_binding) const; + bool AddReplacement(vector &replacements, ColumnBinding old_binding, ColumnBinding new_binding, + const LogicalType &new_type) const; + bool CoversAllDedupColumns(const GeneratedDedupRef &dedup_ref, const vector &bindings) const; + optional_idx FindGeneratedOutputBinding(Expression &expr, const GeneratedDedupRef &dedup_ref) const; + bool ExpressionReferencesGeneratedSide(Expression 
&expr, const GeneratedDedupRef &dedup_ref) const; + bool FilterIsGeneratedDedupCrossProduct(LogicalOperator &op) const; + void FindJoinsWithGeneratedDedupRefs(unique_ptr &op, vector &joins, + idx_t depth = 0) const; + idx_t CountGeneratedDedupRefs(LogicalOperator &op) const; + bool RemoveInequalityJoinConditions(LogicalOperator &target_op, const vector &join_conditions, + idx_t dedup_idx); + bool RemoveJoin(unique_ptr &join); + bool RemoveFilterCrossProduct(unique_ptr &filter_op); + +private: + LogicalComparisonJoin &delim_join; + TableIndex dedup_cte_index; + idx_t dedup_ref_count; + LogicalOperator &rewrite_root; +}; + +GeneratedDedupRefEliminator::GeneratedDedupRefEliminator(LogicalComparisonJoin &delim_join, TableIndex dedup_cte_index, + idx_t dedup_ref_count, LogicalOperator &rewrite_root) + : delim_join(delim_join), dedup_cte_index(dedup_cte_index), dedup_ref_count(dedup_ref_count), + rewrite_root(rewrite_root) { +} + +unique_ptr GeneratedDedupRefEliminator::GetGeneratedDedupRef(LogicalOperator &op, + bool collect_filters, + bool allow_projection) const { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + if (cteref.cte_index == dedup_cte_index) { + auto result = make_uniq(); + result->cte_ref = cteref; + result->output_bindings = cteref.GetColumnBindings(); + result->output_expressions.reserve(result->output_bindings.size()); + for (idx_t col_idx = 0; col_idx < result->output_bindings.size(); col_idx++) { + result->output_expressions.push_back( + make_uniq(cteref.chunk_types[col_idx], result->output_bindings[col_idx])); + } + return result; + } + return nullptr; + } + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op.Cast(); + if (filter.HasProjectionMap() || filter.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDedupRef(*filter.children[0], collect_filters, allow_projection); + if (!result) { + return nullptr; + } + if (collect_filters) { + for (auto &expr : 
filter.expressions) { + auto filter_expr = expr->Copy(); + ReplaceExpressionBindings(filter_expr, result->output_bindings, result->output_expressions); + result->filters.push_back(std::move(filter_expr)); + } + } + return result; + } + if (allow_projection && op.type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.Cast(); + if (projection.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDedupRef(*projection.children[0], collect_filters, allow_projection); + if (!result) { + return nullptr; + } + + auto child_bindings = result->output_bindings; + auto child_expressions = std::move(result->output_expressions); + result->output_bindings = projection.GetColumnBindings(); + result->output_expressions.clear(); + result->output_expressions.reserve(projection.expressions.size()); + for (auto &expr : projection.expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, child_bindings, child_expressions); + result->output_expressions.push_back(std::move(rewritten_expr)); + } + result->has_projection = true; + return result; + } + return nullptr; +} + +bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedDedupRef(Expression &expr, + const GeneratedDedupRef &dedup_ref) const { + bool found = false; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { + if (colref.Depth() == 0 && colref.Binding().table_index == dedup_ref.cte_ref->table_index) { + found = true; + } + }); + return found; +} + +bool GeneratedDedupRefEliminator::AddReplacement(vector &replacements, ColumnBinding old_binding, + ColumnBinding new_binding) const { + if (old_binding == new_binding) { + return true; + } + for (auto &replacement : replacements) { + if (replacement.old_binding == old_binding) { + return replacement.new_binding == new_binding; + } + } + replacements.emplace_back(old_binding, new_binding); + return true; +} + +bool 
GeneratedDedupRefEliminator::AddReplacement(vector &replacements, ColumnBinding old_binding, + ColumnBinding new_binding, const LogicalType &new_type) const { + if (old_binding == new_binding) { + return true; + } + for (auto &replacement : replacements) { + if (replacement.old_binding == old_binding) { + return replacement.new_binding == new_binding; + } + } + replacements.emplace_back(old_binding, new_binding, new_type); + return true; +} + +bool GeneratedDedupRefEliminator::CoversAllDedupColumns(const GeneratedDedupRef &dedup_ref, + const vector &bindings) const { + auto cte_bindings = + LogicalOperator::GenerateColumnBindings(dedup_ref.cte_ref->table_index, dedup_ref.cte_ref->chunk_types.size()); + if (bindings.size() != cte_bindings.size()) { + return false; + } + for (auto &cte_binding : cte_bindings) { + idx_t match_count = 0; + for (auto &binding : bindings) { + if (binding == cte_binding) { + match_count++; + } + } + if (match_count != 1) { + return false; + } + } + return true; +} + +optional_idx GeneratedDedupRefEliminator::FindGeneratedOutputBinding(Expression &expr, + const GeneratedDedupRef &dedup_ref) const { + optional_idx result; + bool unsupported = false; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { + if (unsupported || colref.Depth() != 0) { + return; + } + auto binding_idx = FindBindingIndex(dedup_ref.output_bindings, colref.Binding()); + if (!binding_idx.IsValid()) { + return; + } + if (result.IsValid() && result.GetIndex() != binding_idx.GetIndex()) { + unsupported = true; + return; + } + result = binding_idx; + }); + return unsupported ? 
optional_idx() : result; +} + +bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedSide(Expression &expr, + const GeneratedDedupRef &dedup_ref) const { + if (ExpressionReferencesGeneratedDedupRef(expr, dedup_ref)) { + return true; + } + bool found = false; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { + if (colref.Depth() != 0) { + return; + } + if (FindBindingIndex(dedup_ref.output_bindings, colref.Binding()).IsValid()) { + found = true; + } + }); + return found; +} + +bool GeneratedDedupRefEliminator::FilterIsGeneratedDedupCrossProduct(LogicalOperator &op) const { + if (op.type != LogicalOperatorType::LOGICAL_FILTER || op.HasProjectionMap()) { + return false; + } + auto &filter = op.Cast(); + if (filter.children.size() != 1 || filter.children[0]->type != LogicalOperatorType::LOGICAL_CROSS_PRODUCT) { + return false; + } + auto &cross_product = *filter.children[0]; + return cross_product.children.size() == 2 && + (GetGeneratedDedupRef(*cross_product.children[0]) || GetGeneratedDedupRef(*cross_product.children[1])); +} + +void GeneratedDedupRefEliminator::FindJoinsWithGeneratedDedupRefs(unique_ptr &op, + vector &joins, + idx_t depth) const { + if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + if (!op->children.empty()) { + FindJoinsWithGeneratedDedupRefs(op->children[0], joins, depth + 1); + } + return; + } + + for (auto &child : op->children) { + FindJoinsWithGeneratedDedupRefs(child, joins, depth + 1); + } + + if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN && + (GetGeneratedDedupRef(*op->children[0], false, true) || GetGeneratedDedupRef(*op->children[1], false, true))) { + joins.emplace_back(op, depth); + } else if (FilterIsGeneratedDedupCrossProduct(*op)) { + joins.emplace_back(op, depth, true); + } +} + +idx_t GeneratedDedupRefEliminator::CountGeneratedDedupRefs(LogicalOperator &op) const { + if (op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + if (!op.children.empty()) { + 
return CountGeneratedDedupRefs(*op.children[0]); + } + return 0; + } + + idx_t count = 0; + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF && op.Cast().cte_index == dedup_cte_index) { + count++; + } + for (auto &child : op.children) { + count += CountGeneratedDedupRefs(*child); + } + return count; +} + +bool GeneratedDedupRefEliminator::RemoveInequalityJoinConditions(LogicalOperator &target_op, + const vector &join_conditions, + idx_t dedup_idx) { + auto &delim_conditions = delim_join.conditions; + if (dedup_ref_count != 1 || delim_conditions.size() != join_conditions.size()) { + return false; + } + if (delim_join.join_type != JoinType::ANTI && delim_join.join_type != JoinType::MARK && + delim_join.join_type != JoinType::SEMI && delim_join.join_type != JoinType::SINGLE) { + return false; + } + + if (delim_join.join_type == JoinType::SINGLE || delim_join.join_type == JoinType::MARK) { + bool has_one_equality = false; + for (auto &cond : join_conditions) { + has_one_equality = has_one_equality || IsEqualityJoinCondition(cond); + } + if (!has_one_equality) { + return false; + } + } + + vector traced_bindings; + for (const auto &cond : delim_conditions) { + if (!cond.IsComparison() || cond.GetRHS().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = cond.GetRHS().Cast(); + traced_bindings.emplace_back(colref.Binding()); + } + + reference current_op = *delim_join.children[1]; + while (¤t_op.get() != &target_op) { + if (current_op.get().children.size() != 1) { + return false; + } + + switch (current_op.get().type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + if (!FindAndReplaceBindings(traced_bindings, current_op.get().expressions, + current_op.get().GetColumnBindings())) { + return false; + } + break; + case LogicalOperatorType::LOGICAL_FILTER: + break; + default: + return false; + } + current_op = *current_op.get().children[0]; + } + + bool found_all = true; + for (idx_t cond_idx = 0; cond_idx < 
delim_conditions.size(); cond_idx++) { + auto &delim_condition = delim_conditions[cond_idx]; + if (!delim_condition.IsComparison()) { + continue; + } + const auto &traced_binding = traced_bindings[cond_idx]; + + bool found = false; + for (auto &join_condition : join_conditions) { + if (!join_condition.IsComparison()) { + continue; + } + auto &dedup_side = dedup_idx == 0 ? join_condition.GetLHS() : join_condition.GetRHS(); + if (dedup_side.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + auto &colref = dedup_side.Cast(); + if (colref.Binding() == traced_binding) { + auto join_comparison = join_condition.GetComparisonType(); + auto original_join_comparison = join_condition.GetComparisonType(); + if (delim_condition.GetComparisonType() == ExpressionType::COMPARE_DISTINCT_FROM || + delim_condition.GetComparisonType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + if (join_comparison == ExpressionType::COMPARE_EQUAL) { + join_comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; + } else if (join_comparison == ExpressionType::COMPARE_NOTEQUAL) { + join_comparison = ExpressionType::COMPARE_DISTINCT_FROM; + } else if (join_comparison != ExpressionType::COMPARE_DISTINCT_FROM && + join_comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + found = false; + break; + } + } + auto left_copy = delim_condition.LeftReference()->Copy(); + auto right_copy = delim_condition.RightReference()->Copy(); + + delim_conditions[cond_idx] = JoinCondition(std::move(left_copy), std::move(right_copy), + FlipComparisonExpression(join_comparison)); + if (delim_join.join_type != JoinType::MARK && + original_join_comparison != ExpressionType::COMPARE_DISTINCT_FROM && + original_join_comparison != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + auto final_comparison = delim_conditions[cond_idx].GetComparisonType(); + if (final_comparison == ExpressionType::COMPARE_DISTINCT_FROM) { + final_comparison = ExpressionType::COMPARE_NOTEQUAL; + } else if 
(final_comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + final_comparison = ExpressionType::COMPARE_EQUAL; + } + delim_conditions[cond_idx] = + JoinCondition(delim_conditions[cond_idx].LeftReference()->Copy(), + delim_conditions[cond_idx].RightReference()->Copy(), final_comparison); + } + found = true; + break; + } + } + found_all = found_all && found; + } + + return found_all; +} + +bool GeneratedDedupRefEliminator::RemoveJoin(unique_ptr &join) { + auto &comparison_join = join->Cast(); + if (comparison_join.join_type != JoinType::INNER && comparison_join.join_type != JoinType::SEMI) { + return false; + } + + auto left_is_generated = GetGeneratedDedupRef(*join->children[0], false, true) != nullptr; + auto right_is_generated = GetGeneratedDedupRef(*join->children[1], false, true) != nullptr; + if (left_is_generated == right_is_generated) { + return false; + } + const idx_t dedup_idx = left_is_generated ? 0 : 1; + + auto dedup_ref = GetGeneratedDedupRef(*join->children[dedup_idx], true, true); + if (!dedup_ref) { + return false; + } + + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + bool all_equality_conditions = true; + vector covered_dedup_bindings; + covered_dedup_bindings.reserve(comparison_join.conditions.size()); + vector base_replacement_bindings; + vector> base_replacement_expressions; + vector> join_filter_expressions; + vector> not_null_filter_expressions; + + for (auto &cond : comparison_join.conditions) { + if (!cond.IsComparison()) { + return false; + } + all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(cond); + + auto lhs_generated_idx = FindGeneratedOutputBinding(cond.GetLHS(), *dedup_ref); + auto rhs_generated_idx = FindGeneratedOutputBinding(cond.GetRHS(), *dedup_ref); + if (lhs_generated_idx.IsValid() == rhs_generated_idx.IsValid()) { + return false; + } + auto generated_idx = lhs_generated_idx.IsValid() ? 
lhs_generated_idx.GetIndex() : rhs_generated_idx.GetIndex(); + auto &generated_binding = dedup_ref->output_bindings[generated_idx]; + auto &generated_expression = *dedup_ref->output_expressions[generated_idx]; + auto &other_side = lhs_generated_idx.IsValid() ? cond.GetRHS() : cond.GetLHS(); + + ColumnBinding other_binding; + if (!GetBoundColumnRefBinding(other_side, other_binding)) { + return false; + } + if (!AddReplacement(replacement_bindings, generated_binding, other_binding, + generated_expression.GetReturnType())) { + return false; + } + + ColumnBinding base_binding; + if (GetBoundColumnRefBinding(generated_expression, base_binding) && + base_binding.table_index == dedup_ref->cte_ref->table_index) { + if (!AddReplacement(replacement_bindings, base_binding, other_binding)) { + return false; + } + covered_dedup_bindings.emplace_back(base_binding); + base_replacement_bindings.push_back(base_binding); + base_replacement_expressions.push_back(other_side.Copy()); + } + + join_filter_expressions.push_back( + BoundComparisonExpression::Create(cond.GetComparisonType(), cond.GetLHS().Copy(), cond.GetRHS().Copy())); + if (cond.GetComparisonType() != ExpressionType::COMPARE_NOT_DISTINCT_FROM && + cond.GetComparisonType() != ExpressionType::COMPARE_DISTINCT_FROM) { + auto is_not_null_expr = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + is_not_null_expr->GetChildrenMutable().push_back(other_side.Copy()); + not_null_filter_expressions.push_back(std::move(is_not_null_expr)); + } + } + if (!CoversAllDedupColumns(*dedup_ref, covered_dedup_bindings)) { + return false; + } + + vector> generated_output_replacements; + generated_output_replacements.reserve(dedup_ref->output_expressions.size()); + for (auto &expr : dedup_ref->output_expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesGeneratedSide(*rewritten_expr, *dedup_ref)) 
{ + return false; + } + generated_output_replacements.push_back(std::move(rewritten_expr)); + } + + vector> filter_expressions; + if (all_equality_conditions) { + for (auto &expr : dedup_ref->filters) { + ReplaceExpressionBindings(expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesGeneratedSide(*expr, *dedup_ref)) { + return false; + } + filter_expressions.push_back(std::move(expr)); + } + for (auto &expr : join_filter_expressions) { + ReplaceExpressionBindings(expr, dedup_ref->output_bindings, generated_output_replacements); + ReplaceExpressionBindings(expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesGeneratedSide(*expr, *dedup_ref)) { + return false; + } + filter_expressions.push_back(std::move(expr)); + } + } else { + if (dedup_ref->has_projection || + !RemoveInequalityJoinConditions(*join, comparison_join.conditions, dedup_idx)) { + return false; + } + filter_expressions = std::move(dedup_ref->filters); + for (auto &expr : not_null_filter_expressions) { + filter_expressions.push_back(std::move(expr)); + } + } + + unique_ptr replacement_op = std::move(comparison_join.children[1 - dedup_idx]); + if (!filter_expressions.empty()) { + auto new_filter = make_uniq(); + new_filter->expressions = std::move(filter_expressions); + new_filter->children.emplace_back(std::move(replacement_op)); + replacement_op = std::move(new_filter); + } + + join = std::move(replacement_op); + + replacer.VisitOperator(rewrite_root); + ReplaceOperatorBindings(rewrite_root, dedup_ref->output_bindings, generated_output_replacements); + return true; +} + +bool GeneratedDedupRefEliminator::RemoveFilterCrossProduct(unique_ptr &filter_op) { + auto &filter = filter_op->Cast(); + D_ASSERT(filter.children.size() == 1); + auto &cross_product = *filter.children[0]; + D_ASSERT(cross_product.type == LogicalOperatorType::LOGICAL_CROSS_PRODUCT); + + const idx_t dedup_idx = GetGeneratedDedupRef(*cross_product.children[0]) ? 
0 : 1; + auto dedup_ref = GetGeneratedDedupRef(*cross_product.children[dedup_idx], true); + if (!dedup_ref) { + return false; + } + + vector> generated_filter_expressions = std::move(dedup_ref->filters); + + filter.SplitPredicates(); + vector consumed(filter.expressions.size(), false); + vector join_conditions; + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + bool all_equality_conditions = true; + vector covered_dedup_bindings; + + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + auto &expr = *filter.expressions[expr_idx]; + if (!BoundComparisonExpression::IsComparison(expr)) { + continue; + } + auto &comparison = expr.Cast(); + auto &lhs = BoundComparisonExpression::Left(comparison); + auto &rhs = BoundComparisonExpression::Right(comparison); + if (lhs.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF || + rhs.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + continue; + } + + auto &lhs_colref = lhs.Cast(); + auto &rhs_colref = rhs.Cast(); + auto lhs_dedup = lhs_colref.Binding().table_index == dedup_ref->cte_ref->table_index; + auto rhs_dedup = rhs_colref.Binding().table_index == dedup_ref->cte_ref->table_index; + if (lhs_dedup == rhs_dedup) { + continue; + } + + auto comparison_type = expr.GetExpressionType(); + if (lhs_dedup) { + if (!AddReplacement(replacement_bindings, lhs_colref.Binding(), rhs_colref.Binding())) { + return false; + } + covered_dedup_bindings.emplace_back(lhs_colref.Binding()); + if (dedup_idx == 0) { + join_conditions.emplace_back(lhs.Copy(), rhs.Copy(), comparison_type); + } else { + join_conditions.emplace_back(rhs.Copy(), lhs.Copy(), FlipComparisonExpression(comparison_type)); + } + if (comparison_type != ExpressionType::COMPARE_NOT_DISTINCT_FROM && + comparison_type != ExpressionType::COMPARE_DISTINCT_FROM) { + auto is_not_null_expr = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + 
is_not_null_expr->GetChildrenMutable().push_back(rhs.Copy()); + generated_filter_expressions.push_back(std::move(is_not_null_expr)); + } + } else { + if (!AddReplacement(replacement_bindings, rhs_colref.Binding(), lhs_colref.Binding())) { + return false; + } + covered_dedup_bindings.emplace_back(rhs_colref.Binding()); + if (dedup_idx == 0) { + join_conditions.emplace_back(rhs.Copy(), lhs.Copy(), FlipComparisonExpression(comparison_type)); + } else { + join_conditions.emplace_back(lhs.Copy(), rhs.Copy(), comparison_type); + } + if (comparison_type != ExpressionType::COMPARE_NOT_DISTINCT_FROM && + comparison_type != ExpressionType::COMPARE_DISTINCT_FROM) { + auto is_not_null_expr = + make_uniq(ExpressionType::OPERATOR_IS_NOT_NULL, LogicalType::BOOLEAN); + is_not_null_expr->GetChildrenMutable().push_back(lhs.Copy()); + generated_filter_expressions.push_back(std::move(is_not_null_expr)); + } + } + all_equality_conditions = all_equality_conditions && IsEqualityJoinCondition(join_conditions.back()); + consumed[expr_idx] = true; + } + + if (join_conditions.size() != dedup_ref->cte_ref->chunk_types.size() || + !CoversAllDedupColumns(*dedup_ref, covered_dedup_bindings)) { + return false; + } + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + if (!consumed[expr_idx] && ExpressionReferencesGeneratedDedupRef(*filter.expressions[expr_idx], *dedup_ref)) { + return false; + } + } + + if (!all_equality_conditions && !RemoveInequalityJoinConditions(*filter_op, join_conditions, dedup_idx)) { + return false; + } + + unique_ptr replacement_op = std::move(cross_product.children[1 - dedup_idx]); + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + if (!consumed[expr_idx]) { + generated_filter_expressions.push_back(std::move(filter.expressions[expr_idx])); + } + } + if (!generated_filter_expressions.empty()) { + auto new_filter = make_uniq(); + new_filter->expressions = std::move(generated_filter_expressions); + 
new_filter->children.emplace_back(std::move(replacement_op)); + replacement_op = std::move(new_filter); + } + + filter_op = std::move(replacement_op); + + replacer.VisitOperator(rewrite_root); + return true; +} + +idx_t GeneratedDedupRefEliminator::Remove() { + vector joins; + FindJoinsWithGeneratedDedupRefs(delim_join.children[1], joins); + if (joins.empty()) { + return dedup_ref_count; + } + + std::sort(joins.begin(), joins.end(), + [](const JoinWithGeneratedDedupRef &lhs, const JoinWithGeneratedDedupRef &rhs) { + return lhs.depth > rhs.depth; + }); + + if (!joins.empty() && HasSelection(*delim_join.children[0])) { + joins.erase(joins.begin()); + } + + for (auto &join : joins) { + if (join.filter_cross_product) { + RemoveFilterCrossProduct(join.join.get()); + } else { + RemoveJoin(join.join.get()); + } + } + dedup_ref_count = CountGeneratedDedupRefs(*delim_join.children[1]); + return dedup_ref_count; +} + +struct GeneratedDomainRef { + optional_ptr cte_ref; + vector source_bindings; + vector output_bindings; + vector> output_expressions; + vector> filters; +}; + +class GeneratedDomainJoinEliminator { +public: + GeneratedDomainJoinEliminator(unique_ptr &rewrite_root, + const vector &generated_dedup_cte_indexes); + + bool Rewrite(); + +private: + void CollectCTEs(LogicalOperator &op); + optional_ptr FindCTE(TableIndex cte_index) const; + bool TryRewriteOnce(unique_ptr &op); + + unique_ptr GetGeneratedDedupRef(LogicalOperator &op, bool collect_filters = false, + bool allow_projection = false) const; + unique_ptr GetGeneratedDomainDefinition(LogicalOperator &op) const; + unique_ptr GetGeneratedDomainRef(LogicalOperator &op, bool collect_filters = false, + bool allow_projection = false) const; + + optional_idx FindOutputBinding(Expression &expr, const vector &bindings) const; + bool ContainsRecursiveCTERef(LogicalOperator &op) const; + + bool AddReplacement(vector &replacements, ColumnBinding old_binding, + ColumnBinding new_binding) const; + bool 
RemoveGeneratedDedupJoin(unique_ptr &join); + bool RemoveGeneratedDomainJoin(unique_ptr &join); + +private: + unique_ptr &rewrite_root; + const vector &generated_dedup_cte_indexes; + vector> ctes; +}; + +GeneratedDomainJoinEliminator::GeneratedDomainJoinEliminator(unique_ptr &rewrite_root, + const vector &generated_dedup_cte_indexes) + : rewrite_root(rewrite_root), generated_dedup_cte_indexes(generated_dedup_cte_indexes) { +} + +void GeneratedDomainJoinEliminator::CollectCTEs(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_MATERIALIZED_CTE || + op.type == LogicalOperatorType::LOGICAL_RECURSIVE_CTE) { + ctes.push_back(op.Cast()); + } + for (auto &child : op.children) { + CollectCTEs(*child); + } +} + +optional_ptr GeneratedDomainJoinEliminator::FindCTE(TableIndex cte_index) const { + for (auto &cte : ctes) { + if (cte.get().table_index == cte_index) { + return cte.get(); + } + } + return nullptr; +} + +unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDedupRef(LogicalOperator &op, + bool collect_filters, + bool allow_projection) const { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + if (std::find(generated_dedup_cte_indexes.begin(), generated_dedup_cte_indexes.end(), cteref.cte_index) == + generated_dedup_cte_indexes.end()) { + return nullptr; + } + + auto result = make_uniq(); + result->cte_ref = cteref; + result->output_bindings = cteref.GetColumnBindings(); + result->output_expressions.reserve(result->output_bindings.size()); + for (idx_t col_idx = 0; col_idx < result->output_bindings.size(); col_idx++) { + result->output_expressions.push_back( + make_uniq(cteref.chunk_types[col_idx], result->output_bindings[col_idx])); + } + return result; + } + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op.Cast(); + if (filter.HasProjectionMap() || filter.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDedupRef(*filter.children[0], collect_filters, 
allow_projection); + if (!result) { + return nullptr; + } + if (collect_filters) { + for (auto &expr : filter.expressions) { + auto filter_expr = expr->Copy(); + ReplaceExpressionBindings(filter_expr, result->output_bindings, result->output_expressions); + result->filters.push_back(std::move(filter_expr)); + } + } + return result; + } + if (allow_projection && op.type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.Cast(); + if (projection.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDedupRef(*projection.children[0], collect_filters, allow_projection); + if (!result) { + return nullptr; + } + + auto child_bindings = result->output_bindings; + auto child_expressions = std::move(result->output_expressions); + result->output_bindings = projection.GetColumnBindings(); + result->output_expressions.clear(); + result->output_expressions.reserve(projection.expressions.size()); + for (auto &expr : projection.expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, child_bindings, child_expressions); + result->output_expressions.push_back(std::move(rewritten_expr)); + } + result->has_projection = true; + return result; + } + return nullptr; +} + +unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDomainDefinition(LogicalOperator &op) const { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + if (std::find(generated_dedup_cte_indexes.begin(), generated_dedup_cte_indexes.end(), cteref.cte_index) == + generated_dedup_cte_indexes.end()) { + return nullptr; + } + + auto result = make_uniq(); + result->source_bindings = cteref.GetColumnBindings(); + result->output_bindings = result->source_bindings; + result->output_expressions.reserve(result->output_bindings.size()); + for (idx_t col_idx = 0; col_idx < result->output_bindings.size(); col_idx++) { + result->output_expressions.push_back( + make_uniq(cteref.chunk_types[col_idx], 
result->output_bindings[col_idx])); + } + return result; + } + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op.Cast(); + if (filter.HasProjectionMap() || filter.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDomainDefinition(*filter.children[0]); + if (!result) { + return nullptr; + } + for (auto &expr : filter.expressions) { + auto filter_expr = expr->Copy(); + ReplaceExpressionBindings(filter_expr, result->output_bindings, result->output_expressions); + result->filters.push_back(std::move(filter_expr)); + } + return result; + } + if (op.type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.Cast(); + if (projection.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDomainDefinition(*projection.children[0]); + if (!result) { + return nullptr; + } + + auto child_bindings = result->output_bindings; + auto child_expressions = std::move(result->output_expressions); + result->output_bindings = projection.GetColumnBindings(); + result->output_expressions.clear(); + result->output_expressions.reserve(projection.expressions.size()); + for (auto &expr : projection.expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, child_bindings, child_expressions); + result->output_expressions.push_back(std::move(rewritten_expr)); + } + return result; + } + return nullptr; +} + +unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDomainRef(LogicalOperator &op, + bool collect_filters, + bool allow_projection) const { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + if (std::find(generated_dedup_cte_indexes.begin(), generated_dedup_cte_indexes.end(), cteref.cte_index) != + generated_dedup_cte_indexes.end()) { + return nullptr; + } + + auto cte = FindCTE(cteref.cte_index); + if (!cte || cte->children.empty()) { + return nullptr; + } + if (cte->type == LogicalOperatorType::LOGICAL_RECURSIVE_CTE) { + return 
nullptr; + } + auto result = GetGeneratedDomainDefinition(*cte->children[0]); + if (!result || result->output_expressions.size() != cteref.chunk_types.size()) { + return nullptr; + } + + result->cte_ref = cteref; + result->output_bindings = cteref.GetColumnBindings(); + return result; + } + if (op.type == LogicalOperatorType::LOGICAL_FILTER) { + auto &filter = op.Cast(); + if (filter.HasProjectionMap() || filter.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDomainRef(*filter.children[0], collect_filters, allow_projection); + if (!result) { + return nullptr; + } + if (collect_filters) { + for (auto &expr : filter.expressions) { + auto filter_expr = expr->Copy(); + ReplaceExpressionBindings(filter_expr, result->output_bindings, result->output_expressions); + result->filters.push_back(std::move(filter_expr)); + } + } + return result; + } + if (allow_projection && op.type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op.Cast(); + if (projection.children.size() != 1) { + return nullptr; + } + auto result = GetGeneratedDomainRef(*projection.children[0], collect_filters, allow_projection); + if (!result) { + return nullptr; + } + + auto child_bindings = result->output_bindings; + auto child_expressions = std::move(result->output_expressions); + result->output_bindings = projection.GetColumnBindings(); + result->output_expressions.clear(); + result->output_expressions.reserve(projection.expressions.size()); + for (auto &expr : projection.expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, child_bindings, child_expressions); + result->output_expressions.push_back(std::move(rewritten_expr)); + } + return result; + } + return nullptr; +} + +optional_idx GeneratedDomainJoinEliminator::FindOutputBinding(Expression &expr, + const vector &bindings) const { + optional_idx result; + bool unsupported = false; + ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression 
&colref) { + if (unsupported || colref.Depth() != 0) { + return; + } + auto binding_idx = FindBindingIndex(bindings, colref.Binding()); + if (!binding_idx.IsValid()) { + return; + } + if (result.IsValid() && result.GetIndex() != binding_idx.GetIndex()) { + unsupported = true; + return; + } + result = binding_idx; + }); + return unsupported ? optional_idx() : result; +} + +bool GeneratedDomainJoinEliminator::ContainsRecursiveCTERef(LogicalOperator &op) const { + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + auto &cteref = op.Cast(); + for (auto &cte : ctes) { + if (cte.get().table_index == cteref.cte_index) { + return cte.get().type == LogicalOperatorType::LOGICAL_RECURSIVE_CTE; + } + } + } + for (auto &child : op.children) { + if (ContainsRecursiveCTERef(*child)) { + return true; + } + } + return false; +} + +bool GeneratedDomainJoinEliminator::AddReplacement(vector &replacements, ColumnBinding old_binding, + ColumnBinding new_binding) const { + if (old_binding == new_binding) { + return true; + } + for (auto &replacement : replacements) { + if (replacement.old_binding == old_binding) { + return replacement.new_binding == new_binding; + } + } + replacements.emplace_back(old_binding, new_binding); + return true; +} + +bool GeneratedDomainJoinEliminator::RemoveGeneratedDedupJoin(unique_ptr &join) { + auto &comparison_join = join->Cast(); + if (comparison_join.join_type != JoinType::INNER && + (comparison_join.join_type != JoinType::SEMI || !GetGeneratedDedupRef(*join->children[1], false, true))) { + return false; + } + + auto left_is_generated = GetGeneratedDedupRef(*join->children[0], false, true) != nullptr; + auto right_is_generated = GetGeneratedDedupRef(*join->children[1], false, true) != nullptr; + if (left_is_generated == right_is_generated) { + return false; + } + const idx_t dedup_idx = left_is_generated ? 
0 : 1; + if (ContainsRecursiveCTERef(*join->children[1 - dedup_idx])) { + return false; + } + auto dedup_ref = GetGeneratedDedupRef(*join->children[dedup_idx], true, true); + if (!dedup_ref) { + return false; + } + + vector covered_dedup_bindings; + vector base_replacement_bindings; + vector> base_replacement_expressions; + vector> join_filter_expressions; + ColumnBindingReplacer column_replacer; + auto &replacement_bindings = column_replacer.replacement_bindings; + for (auto &cond : comparison_join.conditions) { + if (!cond.IsComparison() || !IsEqualityJoinCondition(cond)) { + return false; + } + auto lhs_generated_idx = FindOutputBinding(cond.GetLHS(), dedup_ref->output_bindings); + auto rhs_generated_idx = FindOutputBinding(cond.GetRHS(), dedup_ref->output_bindings); + if (lhs_generated_idx.IsValid() == rhs_generated_idx.IsValid()) { + return false; + } + auto generated_idx = lhs_generated_idx.IsValid() ? lhs_generated_idx.GetIndex() : rhs_generated_idx.GetIndex(); + auto &generated_binding = dedup_ref->output_bindings[generated_idx]; + auto &generated_expression = *dedup_ref->output_expressions[generated_idx]; + auto &other_side = lhs_generated_idx.IsValid() ? 
cond.GetRHS() : cond.GetLHS(); + + ColumnBinding other_binding; + if (!GetBoundColumnRefBinding(other_side, other_binding)) { + return false; + } + if (!AddReplacement(replacement_bindings, generated_binding, other_binding)) { + return false; + } + + ColumnBinding base_binding; + if (GetBoundColumnRefBinding(generated_expression, base_binding) && + base_binding.table_index == dedup_ref->cte_ref->table_index) { + if (!AddReplacement(replacement_bindings, base_binding, other_binding)) { + return false; + } + covered_dedup_bindings.emplace_back(base_binding); + if (!AddExpressionReplacement(base_replacement_bindings, base_replacement_expressions, base_binding, + other_side.Copy())) { + return false; + } + } + + join_filter_expressions.push_back( + BoundComparisonExpression::Create(cond.GetComparisonType(), cond.GetLHS().Copy(), cond.GetRHS().Copy())); + } + + auto cte_bindings = LogicalOperator::GenerateColumnBindings(dedup_ref->cte_ref->table_index, + dedup_ref->cte_ref->chunk_types.size()); + if (!CoversAllBindings(cte_bindings, covered_dedup_bindings)) { + return false; + } + + vector> generated_output_replacements; + generated_output_replacements.reserve(dedup_ref->output_expressions.size()); + for (auto &expr : dedup_ref->output_expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesBinding(*rewritten_expr, dedup_ref->output_bindings) || + ExpressionReferencesBinding(*rewritten_expr, cte_bindings)) { + return false; + } + generated_output_replacements.push_back(std::move(rewritten_expr)); + } + + vector> filter_expressions; + for (auto &expr : dedup_ref->filters) { + ReplaceExpressionBindings(expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesBinding(*expr, dedup_ref->output_bindings) || + ExpressionReferencesBinding(*expr, cte_bindings)) { + return false; + } + 
filter_expressions.push_back(std::move(expr)); + } + for (auto &expr : join_filter_expressions) { + ReplaceExpressionBindings(expr, dedup_ref->output_bindings, generated_output_replacements); + ReplaceExpressionBindings(expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesBinding(*expr, dedup_ref->output_bindings) || + ExpressionReferencesBinding(*expr, cte_bindings)) { + return false; + } + filter_expressions.push_back(std::move(expr)); + } + + unique_ptr replacement_op = std::move(comparison_join.children[1 - dedup_idx]); + for (auto &expr : filter_expressions) { + AddFilterToOperator(replacement_op, std::move(expr)); + } + join = std::move(replacement_op); + + column_replacer.VisitOperator(*rewrite_root); + ReplaceOperatorBindings(*rewrite_root, dedup_ref->output_bindings, generated_output_replacements); + return true; +} + +bool GeneratedDomainJoinEliminator::RemoveGeneratedDomainJoin(unique_ptr &join) { + auto &comparison_join = join->Cast(); + if (comparison_join.join_type != JoinType::INNER) { + return false; + } + + auto left_generated = GetGeneratedDedupRef(*join->children[0], false, true); + auto right_generated = GetGeneratedDedupRef(*join->children[1], false, true); + auto left_domain = GetGeneratedDomainRef(*join->children[0], false, true); + auto right_domain = GetGeneratedDomainRef(*join->children[1], false, true); + auto left_is_generated = left_generated != nullptr; + auto right_is_generated = right_generated != nullptr; + auto left_is_domain = left_domain != nullptr; + auto right_is_domain = right_domain != nullptr; + if (left_is_generated == right_is_generated || left_is_domain == right_is_domain) { + return false; + } + if (left_is_generated == left_is_domain) { + return false; + } + + const idx_t generated_idx = left_is_generated ? 
0 : 1; + const idx_t domain_idx = 1 - generated_idx; + auto generated_ref = GetGeneratedDedupRef(*join->children[generated_idx], true, true); + auto domain_ref = GetGeneratedDomainRef(*join->children[domain_idx], true, true); + if (!generated_ref || !domain_ref) { + return false; + } + + vector source_replacement_bindings; + vector> source_replacement_expressions; + vector covered_source_bindings; + vector> join_filter_expressions; + for (auto &cond : comparison_join.conditions) { + if (!cond.IsComparison() || !IsEqualityJoinCondition(cond)) { + return false; + } + + auto lhs_generated_idx = FindOutputBinding(cond.GetLHS(), generated_ref->output_bindings); + auto rhs_generated_idx = FindOutputBinding(cond.GetRHS(), generated_ref->output_bindings); + auto lhs_domain_idx = FindOutputBinding(cond.GetLHS(), domain_ref->output_bindings); + auto rhs_domain_idx = FindOutputBinding(cond.GetRHS(), domain_ref->output_bindings); + + const bool lhs_generated = lhs_generated_idx.IsValid(); + const bool rhs_generated = rhs_generated_idx.IsValid(); + const bool lhs_domain = lhs_domain_idx.IsValid(); + const bool rhs_domain = rhs_domain_idx.IsValid(); + if (lhs_generated == rhs_generated || lhs_domain == rhs_domain || lhs_generated == lhs_domain) { + return false; + } + + auto domain_output_idx = lhs_domain ? lhs_domain_idx.GetIndex() : rhs_domain_idx.GetIndex(); + auto &domain_expression = *domain_ref->output_expressions[domain_output_idx]; + ColumnBinding source_binding; + if (GetBoundColumnRefBinding(domain_expression, source_binding) && + FindBindingIndex(domain_ref->source_bindings, source_binding).IsValid()) { + auto generated_expression = lhs_generated ? 
cond.GetLHS().Copy() : cond.GetRHS().Copy(); + if (!AddExpressionReplacement(source_replacement_bindings, source_replacement_expressions, source_binding, + std::move(generated_expression))) { + return false; + } + covered_source_bindings.emplace_back(source_binding); + } + + join_filter_expressions.push_back( + BoundComparisonExpression::Create(cond.GetComparisonType(), cond.GetLHS().Copy(), cond.GetRHS().Copy())); + } + if (!CoversAllBindings(domain_ref->source_bindings, covered_source_bindings)) { + return false; + } + + vector> domain_output_replacements; + domain_output_replacements.reserve(domain_ref->output_expressions.size()); + for (auto &expr : domain_ref->output_expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, source_replacement_bindings, source_replacement_expressions); + if (ExpressionReferencesBinding(*rewritten_expr, domain_ref->output_bindings) || + ExpressionReferencesBinding(*rewritten_expr, domain_ref->source_bindings)) { + return false; + } + domain_output_replacements.push_back(std::move(rewritten_expr)); + } + + vector> filter_expressions; + for (auto &expr : domain_ref->filters) { + ReplaceExpressionBindings(expr, source_replacement_bindings, source_replacement_expressions); + if (ExpressionReferencesBinding(*expr, domain_ref->output_bindings) || + ExpressionReferencesBinding(*expr, domain_ref->source_bindings)) { + return false; + } + filter_expressions.push_back(std::move(expr)); + } + for (auto &expr : join_filter_expressions) { + ReplaceExpressionBindings(expr, domain_ref->output_bindings, domain_output_replacements); + ReplaceExpressionBindings(expr, source_replacement_bindings, source_replacement_expressions); + if (ExpressionReferencesBinding(*expr, domain_ref->output_bindings) || + ExpressionReferencesBinding(*expr, domain_ref->source_bindings)) { + return false; + } + filter_expressions.push_back(std::move(expr)); + } + + unique_ptr replacement_op = 
std::move(comparison_join.children[generated_idx]); + for (auto &expr : filter_expressions) { + AddFilterToOperator(replacement_op, std::move(expr)); + } + join = std::move(replacement_op); + + ReplaceOperatorBindings(*rewrite_root, domain_ref->output_bindings, domain_output_replacements); + return true; +} + +bool GeneratedDomainJoinEliminator::TryRewriteOnce(unique_ptr &op) { + for (auto &child : op->children) { + if (TryRewriteOnce(child)) { + return true; + } + } + if (op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return false; + } + return RemoveGeneratedDomainJoin(op) || RemoveGeneratedDedupJoin(op); +} + +bool GeneratedDomainJoinEliminator::Rewrite() { + if (generated_dedup_cte_indexes.empty()) { + return false; + } + + bool changed = false; + while (true) { + ctes.clear(); + CollectCTEs(*rewrite_root); + if (!TryRewriteOnce(rewrite_root)) { + break; + } + changed = true; + } + return changed; +} + +static bool SingleJoinRHSIsDeduplicated(LogicalComparisonJoin &join) { + if (join.join_type != JoinType::SINGLE) { + return false; + } + + vector join_bindings; + for (const auto &cond : join.conditions) { + if (!IsEqualityJoinCondition(cond)) { + return false; + } + if (!cond.IsComparison() || cond.GetRHS().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = cond.GetRHS().Cast(); + join_bindings.emplace_back(colref.Binding()); + } + + reference current_op = *join.children[1]; + while (current_op.get().type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + if (current_op.get().children.size() != 1) { + return false; + } + + switch (current_op.get().type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + if (!FindAndReplaceBindings(join_bindings, current_op.get().expressions, + current_op.get().GetColumnBindings())) { + return false; + } + break; + case LogicalOperatorType::LOGICAL_FILTER: + break; + default: + return false; + } + current_op = *current_op.get().children[0]; + } + + auto &aggr 
= current_op.get().Cast(); + if (!aggr.grouping_functions.empty()) { + return false; + } + + for (idx_t group_idx = 0; group_idx < aggr.groups.size(); group_idx++) { + if (std::find(join_bindings.begin(), join_bindings.end(), + ColumnBinding(aggr.group_index, ProjectionIndex(group_idx))) == join_bindings.end()) { + return false; + } + } + + return true; +} + +DelimJoinCTERewriter::DelimJoinCTERewriter(Binder &binder) : binder(binder) { + auto &config = DBConfig::GetConfig(binder.context); + cte_deliminator_enabled = + Settings::Get(binder.context) && + config.options.disabled_optimizers.find(OptimizerType::DELIMINATOR) == config.options.disabled_optimizers.end(); +} + +void DelimJoinCTERewriter::MaterializeDelimJoinAsCTE(unique_ptr &plan, LogicalOperator &rewrite_root, + bool null_rejecting_filter_above) { + auto &join = plan->Cast(); + if (join.delim_flipped) { + throw InternalException("Flatten dependent joins - flipped delim join CTE rewrite not supported"); + } + + plan->type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; + if (join.join_type == JoinType::MARK) { + // Match the LOGICAL_DELIM_JOIN filter-pushdown semantics: the mark column can still be required above + // this join, so pushing NOT(mark) into the join must not drop it by rewriting to ANTI. + join.convert_mark_to_semi = false; + } + + auto dedup_cte_index = binder.GenerateTableIndex(); + auto dedup_ref_count = RewriteDelimScanReferences(plan->children[1], dedup_cte_index); + if (cte_deliminator_enabled) { + auto cte_deliminator_timer = + QueryProfiler::Get(binder.context).StartTimerInternal(CTE_DELIMINATOR_PROFILER_KEY); + GeneratedDedupRefEliminator eliminator(join, dedup_cte_index, dedup_ref_count, rewrite_root); + dedup_ref_count = eliminator.Remove(); + if (SingleJoinRHSIsDeduplicated(join)) { + join.join_type = null_rejecting_filter_above ? 
JoinType::INNER : JoinType::LEFT; + } + } + if (dedup_ref_count == 0) { + join.duplicate_eliminated_columns.clear(); + return; + } + generated_dedup_cte_indexes.push_back(dedup_cte_index); + + plan->children[0]->ResolveOperatorTypes(); + auto left_bindings = plan->children[0]->GetColumnBindings(); + auto left_types = plan->children[0]->types; + auto visible_left_column_count = left_bindings.size(); + + vector dedup_column_indices; + vector dedup_types; + vector dedup_names; + vector> extra_left_expressions; + dedup_column_indices.reserve(join.duplicate_eliminated_columns.size()); + dedup_types.reserve(join.duplicate_eliminated_columns.size()); + dedup_names.reserve(join.duplicate_eliminated_columns.size()); + for (auto &expr : join.duplicate_eliminated_columns) { + optional_idx binding_index; + if (expr->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { + auto &colref_expr = expr->Cast(); + binding_index = FindBindingIndex(left_bindings, colref_expr.Binding()); + } + if (binding_index.IsValid()) { + dedup_column_indices.push_back(binding_index.GetIndex()); + } else { + dedup_column_indices.push_back(left_bindings.size() + extra_left_expressions.size()); + extra_left_expressions.push_back(expr->Copy()); + } + dedup_types.push_back(expr->GetReturnType()); + dedup_names.push_back(expr->GetName()); + } + + if (!extra_left_expressions.empty()) { + vector> expressions; + expressions.reserve(left_bindings.size() + extra_left_expressions.size()); + for (idx_t i = 0; i < left_bindings.size(); i++) { + expressions.push_back(make_uniq(left_types[i], left_bindings[i])); + } + for (auto &expr : extra_left_expressions) { + expressions.push_back(std::move(expr)); + } + auto projection = make_uniq(binder.GenerateTableIndex(), std::move(expressions)); + projection->children.push_back(std::move(plan->children[0])); + plan->children[0] = std::move(projection); + plan->children[0]->ResolveOperatorTypes(); + left_bindings = plan->children[0]->GetColumnBindings(); + 
left_types = plan->children[0]->types; + if (join.left_projection_map.empty()) { + join.left_projection_map.reserve(visible_left_column_count); + for (idx_t i = 0; i < visible_left_column_count; i++) { + join.left_projection_map.emplace_back(i); + } + } + } + + auto left_column_count = left_bindings.size(); + auto cte_source_bindings = left_bindings; + vector> cte_source_expressions; + cte_source_expressions.reserve(left_column_count); + for (idx_t i = 0; i < left_column_count; i++) { + cte_source_expressions.push_back(make_uniq(left_types[i], left_bindings[i])); + } + auto cte_source = make_uniq(binder.GenerateTableIndex(), std::move(cte_source_expressions)); + cte_source->children.push_back(std::move(plan->children[0])); + cte_source->ResolveOperatorTypes(); + left_types = cte_source->types; + + auto cte_index = binder.GenerateTableIndex(); + auto cte_name = Identifier("__duckdb_delim_" + to_string(cte_index.index)); + + auto left_cte_ref_index = binder.GenerateTableIndex(); + auto left_cte_ref = make_uniq(left_cte_ref_index, cte_index, left_types, + GenerateCTEColumnNames(left_column_count, "__duckdb_delim_col_")); + auto new_left_bindings = left_cte_ref->GetColumnBindings(); + auto binding_replacements = CreateBindingReplacements(cte_source_bindings, new_left_bindings); + + plan->children[0] = std::move(left_cte_ref); + ColumnBindingReplacer replacer; + replacer.replacement_bindings = binding_replacements; + replacer.stop_operator = plan->children[1]; + replacer.VisitOperator(*plan); + + join.duplicate_eliminated_columns.clear(); + + auto dedup_group_index = binder.GenerateTableIndex(); + auto dedup_aggregate_index = binder.GenerateTableIndex(); + vector> dedup_aggrs; + auto dedup = make_uniq(dedup_group_index, dedup_aggregate_index, std::move(dedup_aggrs)); + auto dedup_child_index = binder.GenerateTableIndex(); + auto dedup_child = make_uniq(dedup_child_index, cte_index, left_types, + GenerateCTEColumnNames(left_column_count, "__duckdb_delim_col_")); + auto 
dedup_child_bindings = dedup_child->GetColumnBindings(); + for (idx_t i = 0; i < dedup_column_indices.size(); i++) { + auto colref = make_uniq(dedup_names[i], dedup_types[i], + dedup_child_bindings[dedup_column_indices[i]]); + auto new_group_index = ColumnBinding::PushExpression(dedup->groups, std::move(colref)); + for (auto &set : dedup->grouping_sets) { + set.insert(new_group_index); + } + } + dedup->children.push_back(std::move(dedup_child)); + + auto dedup_cte_name = Identifier("__duckdb_delim_dedup_" + to_string(dedup_cte_index.index)); + auto dedup_cte_child = CreateIdentityProjection(binder, std::move(plan)); + auto dedup_cte = + make_uniq(dedup_cte_name, dedup_cte_index, dedup_types.size(), std::move(dedup), + std::move(dedup_cte_child), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + auto cte_child = CreateIdentityProjection(binder, std::move(dedup_cte)); + auto cte = make_uniq(cte_name, cte_index, left_column_count, std::move(cte_source), + std::move(cte_child), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + plan = std::move(cte); +} + +void DelimJoinCTERewriter::RewriteDelimJoinsToCTEs(unique_ptr &plan, LogicalOperator &rewrite_root, + bool null_rejecting_filter_above) { + for (auto &child : plan->children) { + auto old_child_bindings = child->GetColumnBindings(); + bool child_null_rejecting_filter_above = false; + if (plan->type == LogicalOperatorType::LOGICAL_FILTER && + child->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + auto &filter = plan->Cast(); + auto &delim_join = child->Cast(); + child_null_rejecting_filter_above = FilterNullRejectsDelimJoinRHS(filter, delim_join); + } + RewriteDelimJoinsToCTEs(child, rewrite_root, child_null_rejecting_filter_above); + RewriteChangedChildBindings(*plan, *child, old_child_bindings); + } + if (plan->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + MaterializeDelimJoinAsCTE(plan, rewrite_root, null_rejecting_filter_above); + } +} + +void DelimJoinCTERewriter::Rewrite(Binder &binder, unique_ptr &plan) { + 
DelimJoinCTERewriter rewriter(binder); + rewriter.Rewrite(plan); +} + +void DelimJoinCTERewriter::Rewrite(unique_ptr &plan) { + bool filters_pushed; + do { + filters_pushed = PushEligibleFiltersIntoDelimJoinInputs(plan); + } while (filters_pushed); + RewriteDelimJoinsToCTEs(plan, *plan); + if (cte_deliminator_enabled) { + auto cte_deliminator_timer = + QueryProfiler::Get(binder.context).StartTimerInternal(CTE_DELIMINATOR_PROFILER_KEY); + GeneratedDomainJoinEliminator generated_domain_join_eliminator(plan, generated_dedup_cte_indexes); + generated_domain_join_eliminator.Rewrite(); + } + VerifyNoDelim(*plan); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index 26f77906c..0388e83b3 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -1,4 +1,5 @@ #include "duckdb/planner/subquery/flatten_dependent_join.hpp" +#include "duckdb/planner/subquery/delim_join_cte_rewriter.hpp" #include "duckdb/common/operator/add.hpp" #include "duckdb/common/exception/parser_exception.hpp" @@ -6,6 +7,7 @@ #include "duckdb/function/aggregate/distributive_functions.hpp" #include "duckdb/function/aggregate/distributive_function_utils.hpp" #include "duckdb/function/window/rows_functions.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" @@ -15,6 +17,8 @@ #include "duckdb/planner/subquery/rewrite_correlated_expressions.hpp" #include "duckdb/planner/operator/logical_dependent_join.hpp" +#include + namespace duckdb { static bool HasCTEAccessor(LogicalOperator &op, TableIndex table_index) { @@ -64,8 +68,8 @@ static bool DependsOnCorrelatedWalk(LogicalOperator &op, Binder &binder, } ExpressionIterator::VisitExpression( **expr_ptr, [&](const 
BoundColumnRefExpression &bound_colref) { - result |= bound_colref.depth > 0 && - correlated_aliases.find(bound_colref.binding) != correlated_aliases.end(); + result |= bound_colref.Depth() > 0 && + correlated_aliases.find(bound_colref.Binding()) != correlated_aliases.end(); }); }); } @@ -176,7 +180,7 @@ void FlattenDependentJoins::PatchAccessingOperators(LogicalOperator &subtree_roo if (reader.cte_index == table_index && reader.correlated_columns == 0) { for (auto &column : correlated_columns) { reader.chunk_types.push_back(column.type); - reader.bound_columns.push_back(column.name); + reader.bound_columns.emplace_back(column.name); } reader.correlated_columns += correlated_columns.size(); } @@ -228,6 +232,9 @@ unique_ptr FlattenDependentJoins::DecorrelateIndependent(Binder CorrelatedColumns correlated; FlattenDependentJoins flatten(binder, correlated); flatten.DecorrelateSubtree(plan, true, {}); + if (Settings::Get(binder.context)) { + DelimJoinCTERewriter::Rewrite(binder, plan); + } return plan; } @@ -293,10 +300,10 @@ static unique_ptr CreateRowNumberWindow(Binder &binder, unique_pt auto window = make_uniq(table_index); auto row_number = RowNumberFun::GetFunction().Bind(binder.context); - row_number->partitions = std::move(partitions); - row_number->orders = std::move(orders); - row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; - row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + row_number->PartitionsMutable() = std::move(partitions); + row_number->OrderByMutable() = std::move(orders); + row_number->WindowStartMutable() = WindowBoundary::UNBOUNDED_PRECEDING; + row_number->WindowEndMutable() = WindowBoundary::CURRENT_ROW_ROWS; row_number->SetAlias("limit_rownum"); window->expressions.push_back(std::move(row_number)); @@ -606,7 +613,7 @@ vector FlattenDependentJoins::PushDownAggregate(unique_ptrGetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &bound_func = aggr.expressions[i]->Cast().function; + auto &bound_func = 
aggr.expressions[i]->Cast().Function(); auto count_fun = CountFunctionBase::GetFunction(); auto count_star_fun = CountStarFun::GetFunction(); @@ -781,7 +788,7 @@ vector FlattenDependentJoins::PushDownWindow(unique_ptrGetExpressionClass() == ExpressionClass::BOUND_WINDOW); auto &w = expr->Cast(); - AppendCorrelatedColumns(w.partitions, state, false); + AppendCorrelatedColumns(w.PartitionsMutable(), state, false); } return state; } diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp index 00dd73f62..747176d34 100644 --- a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -52,20 +52,20 @@ void RewriteCorrelatedExpressions::VisitOperator(LogicalOperator &op) { unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { - if (expr.depth == 0) { + if (expr.Depth() == 0) { return nullptr; } - auto alias_entry = correlated_aliases.find(expr.binding); + auto alias_entry = correlated_aliases.find(expr.Binding()); if (alias_entry == correlated_aliases.end()) { return nullptr; } auto current_entry = current_binding_map.find(alias_entry->second); D_ASSERT(current_entry != current_binding_map.end()); - auto original_binding = expr.binding; - expr.binding = current_entry->second; - RegisterCorrelatedBinding(original_binding, expr.binding); - D_ASSERT(expr.depth > 0); - expr.depth--; + auto original_binding = expr.Binding(); + expr.BindingMutable() = current_entry->second; + RegisterCorrelatedBinding(original_binding, expr.Binding()); + D_ASSERT(expr.Depth() > 0); + expr.DepthMutable()--; return nullptr; } @@ -80,12 +80,12 @@ void RewriteCountAggregates::Rewrite(LogicalOperator &op, column_binding_map_t RewriteCountAggregates::VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) { - auto entry = replacement_map.find(expr.binding); 
+ auto entry = replacement_map.find(expr.Binding()); if (entry != replacement_map.end()) { // reference to a COUNT(*) aggregate // replace this with CASE WHEN COUNT(*) IS NULL THEN 0 ELSE COUNT(*) END auto is_null = make_uniq(ExpressionType::OPERATOR_IS_NULL, LogicalType::BOOLEAN); - is_null->children.push_back(expr.Copy()); + is_null->GetChildrenMutable().push_back(expr.Copy()); auto check = std::move(is_null); auto result_if_true = make_uniq(Value::Numeric(expr.GetReturnType(), 0)); auto result_if_false = std::move(*expr_ptr); diff --git a/src/duckdb/src/planner/table_binding.cpp b/src/duckdb/src/planner/table_binding.cpp index 7a61ad250..385437c2d 100644 --- a/src/duckdb/src/planner/table_binding.cpp +++ b/src/duckdb/src/planner/table_binding.cpp @@ -15,10 +15,13 @@ namespace duckdb { -Binding::Binding(BindingType binding_type, BindingAlias alias_p, vector coltypes, vector colnames, - TableIndex index) - : binding_type(binding_type), alias(std::move(alias_p)), index(index), types(std::move(coltypes)), - names(std::move(colnames)) { +Binding::Binding(BindingType binding_type, BindingAlias alias_p, vector coltypes, + vector colnames, TableIndex index) + : binding_type(binding_type), alias(std::move(alias_p)), index(index), types(std::move(coltypes)) { + names.reserve(colnames.size()); + for (auto &colname : colnames) { + names.emplace_back(std::move(colname)); + } Initialize(); } @@ -50,7 +53,7 @@ const vector &Binding::GetColumnTypes() { return types; } -const vector &Binding::GetColumnNames() { +const vector &Binding::GetColumnNames() { return names; } @@ -62,11 +65,11 @@ void Binding::SetColumnType(idx_t col_idx, LogicalType type_p) { types[col_idx] = std::move(type_p); } -string Binding::GetAlias() const { +const Identifier &Binding::GetAlias() const { return alias.GetAlias(); } -bool Binding::TryGetBindingIndex(const string &column_name, column_t &result) { +bool Binding::TryGetBindingIndex(const Identifier &column_name, column_t &result) { auto entry = 
name_map.find(column_name); if (entry == name_map.end()) { return false; @@ -76,7 +79,7 @@ bool Binding::TryGetBindingIndex(const string &column_name, column_t &result) { return true; } -column_t Binding::GetBindingIndex(const string &column_name) { +column_t Binding::GetBindingIndex(const Identifier &column_name) { column_t result; if (!TryGetBindingIndex(column_name, result)) { throw InternalException("Binding index for column \"%s\" not found", column_name); @@ -84,12 +87,12 @@ column_t Binding::GetBindingIndex(const string &column_name) { return result; } -bool Binding::HasMatchingBinding(const string &column_name) { +bool Binding::HasMatchingBinding(const Identifier &column_name) { column_t result; return TryGetBindingIndex(column_name, result); } -ErrorData Binding::ColumnNotFoundError(const string &column_name) const { +ErrorData Binding::ColumnNotFoundError(const Identifier &column_name) const { return ErrorData(ExceptionType::BINDER, StringUtil::Format("Values list \"%s\" does not have a column named \"%s\"", GetAlias(), column_name)); } @@ -108,14 +111,14 @@ BindResult Binding::Bind(ColumnRefExpression &colref, idx_t depth) { if (colref.GetAlias().empty()) { colref.SetAlias(names[column_index]); } - return BindResult(make_uniq(colref.GetName(), sql_type, binding, depth)); + return BindResult(make_uniq(Identifier(colref.GetName()), sql_type, binding, depth)); } optional_ptr Binding::GetStandardEntry() { return nullptr; } -BindingAlias Binding::GetAlias(const string &explicit_alias, const StandardEntry &entry) { +BindingAlias Binding::GetAlias(const Identifier &explicit_alias, const StandardEntry &entry) { if (!explicit_alias.empty()) { return BindingAlias(explicit_alias); } @@ -123,7 +126,7 @@ BindingAlias Binding::GetAlias(const string &explicit_alias, const StandardEntry return BindingAlias(entry); } -BindingAlias Binding::GetAlias(const string &explicit_alias, optional_ptr entry) { +BindingAlias Binding::GetAlias(const Identifier &explicit_alias, 
optional_ptr entry) { if (!explicit_alias.empty()) { return BindingAlias(explicit_alias); } @@ -134,8 +137,8 @@ BindingAlias Binding::GetAlias(const string &explicit_alias, optional_ptr types_p, vector names_p, TableIndex index, - StandardEntry &entry) +EntryBinding::EntryBinding(const Identifier &alias, vector types_p, vector names_p, + TableIndex index, StandardEntry &entry) : Binding(BindingType::CATALOG_ENTRY, GetAlias(alias, entry), std::move(types_p), std::move(names_p), index), entry(entry) { } @@ -144,7 +147,7 @@ optional_ptr EntryBinding::GetStandardEntry() { return &entry; } -TableBinding::TableBinding(const string &alias, vector types_p, vector names_p, +TableBinding::TableBinding(const Identifier &alias, vector types_p, vector names_p, vector &bound_column_ids, optional_ptr entry, TableIndex index, virtual_column_map_t virtual_columns_p) : Binding(BindingType::TABLE, GetAlias(alias, entry), std::move(types_p), std::move(names_p), index), @@ -172,21 +175,21 @@ TableBinding::TableBinding(const string &alias, vector types_p, vec } static void ReplaceAliases(ParsedExpression &root_expr, const ColumnList &list, - const unordered_map &alias_map) { + const unordered_map &alias_map) { ParsedExpressionIterator::VisitExpressionMutable(root_expr, [&](ColumnRefExpression &colref) { D_ASSERT(!colref.IsQualified()); - auto &col_names = colref.column_names; + auto &col_names = colref.ColumnNamesMutable(); D_ASSERT(col_names.size() == 1); auto idx_entry = list.GetColumnIndex(col_names[0]); auto &alias = alias_map.at(idx_entry.index); - col_names = {alias}; + col_names = vector {alias}; }); } static void BakeTableName(ParsedExpression &root_expr, const BindingAlias &binding_alias) { ParsedExpressionIterator::VisitExpressionMutable(root_expr, [&](ColumnRefExpression &colref) { D_ASSERT(!colref.IsQualified()); - auto &col_names = colref.column_names; + auto &col_names = colref.ColumnNamesMutable(); col_names.insert(col_names.begin(), binding_alias.GetAlias()); if 
(!binding_alias.GetSchema().empty()) { col_names.insert(col_names.begin(), binding_alias.GetSchema()); @@ -197,7 +200,7 @@ static void BakeTableName(ParsedExpression &root_expr, const BindingAlias &bindi }); } -unique_ptr TableBinding::ExpandGeneratedColumn(const string &column_name) { +unique_ptr TableBinding::ExpandGeneratedColumn(const Identifier &column_name) { auto catalog_entry = GetStandardEntry(); D_ASSERT(catalog_entry); // Should only be called on a TableBinding @@ -209,7 +212,7 @@ unique_ptr TableBinding::ExpandGeneratedColumn(const string &c D_ASSERT(table_entry.GetColumn(LogicalIndex(column_index)).Generated()); // Get a copy of the generated column auto expression = table_entry.GetColumn(LogicalIndex(column_index)).GeneratedExpression().Copy(); - unordered_map alias_map; + unordered_map alias_map; for (auto &entry : name_map) { alias_map[entry.second] = entry.first; } @@ -227,7 +230,7 @@ const vector &TableBinding::GetBoundColumnIds() const { // assert that all entries in the bound_column_ids are unique D_ASSERT(result.second); auto it = std::find_if(name_map.begin(), name_map.end(), - [&](const std::pair &it) { return it.second == id; }); + [&](const std::pair &it) { return it.second == id; }); // assert that every id appears in the name_map D_ASSERT(it != name_map.end()); // the order that they appear in is not guaranteed to be sequential @@ -290,21 +293,22 @@ BindResult TableBinding::Bind(ColumnRefExpression &colref, idx_t depth) { } } ColumnBinding binding = GetColumnBinding(column_index); - return BindResult(make_uniq(colref.GetName(), col_type, binding, depth)); + return BindResult(make_uniq(Identifier(colref.GetName()), col_type, binding, depth)); } optional_ptr TableBinding::GetStandardEntry() { return entry; } -ErrorData TableBinding::ColumnNotFoundError(const string &column_name) const { - auto candidate_message = StringUtil::CandidatesErrorMessage(names, column_name, "Candidate bindings: "); +ErrorData 
TableBinding::ColumnNotFoundError(const Identifier &column_name) const { + auto candidate_message = StringUtil::CandidatesErrorMessage( + IdentifiersToStrings(names), column_name.GetIdentifierName(), "Candidate bindings: "); return ErrorData(ExceptionType::BINDER, StringUtil::Format("Table \"%s\" does not have a column named \"%s\"\n%s", alias.GetAlias(), column_name, candidate_message)); } -DummyBinding::DummyBinding(vector types, vector names, string dummy_name) - : Binding(BindingType::DUMMY, BindingAlias(DummyBinding::DUMMY_NAME + dummy_name), std::move(types), +DummyBinding::DummyBinding(vector types, vector names, string dummy_name) + : Binding(BindingType::DUMMY, BindingAlias(Identifier(DummyBinding::DUMMY_NAME + dummy_name)), std::move(types), std::move(names), TableIndex()), dummy_name(std::move(dummy_name)) { } @@ -317,17 +321,18 @@ BindResult DummyBinding::Bind(ColumnRefExpression &colref, idx_t depth) { ColumnBinding binding(index, ProjectionIndex(column_index)); // we are binding a parameter to create the dummy binding, no arguments are supplied - return BindResult(make_uniq(colref.GetName(), types[column_index], binding, depth)); + return BindResult( + make_uniq(Identifier(colref.GetName()), types[column_index], binding, depth)); } BindResult DummyBinding::Bind(LambdaRefExpression &lambdaref, idx_t depth) { column_t column_index; - if (!TryGetBindingIndex(lambdaref.GetName(), column_index)) { + if (!TryGetBindingIndex(Identifier(lambdaref.GetName()), column_index)) { throw InternalException("Column %s not found in bindings", lambdaref.GetName()); } ColumnBinding binding(index, ProjectionIndex(column_index)); return BindResult(make_uniq(lambdaref.GetName(), types[column_index], binding, - lambdaref.lambda_idx, depth)); + lambdaref.LambdaIndex(), depth)); } unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colref) { @@ -340,14 +345,14 @@ unique_ptr DummyBinding::ParamToArg(ColumnRefExpression &colre return arg; } 
-CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, TableIndex index, +CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, TableIndex index, CTEType cte_type) : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), cte_type(cte_type), reference_count(0) { } CTEBinding::CTEBinding(BindingAlias alias_p, shared_ptr bind_state_p, TableIndex index) - : Binding(BindingType::CTE, std::move(alias_p), vector(), vector(), index), + : Binding(BindingType::CTE, std::move(alias_p), vector(), vector(), index), cte_type(CTEType::CAN_BE_REFERENCED), reference_count(0), bind_state(std::move(bind_state_p)) { } diff --git a/src/duckdb/src/planner/table_filter.cpp b/src/duckdb/src/planner/table_filter.cpp deleted file mode 100644 index cb22f52da..000000000 --- a/src/duckdb/src/planner/table_filter.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "duckdb/planner/table_filter.hpp" - -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/null_filter.hpp" -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" - -namespace duckdb { - -string TableFilter::DebugToString() const { - return ToString("c0"); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/table_filter_set.cpp b/src/duckdb/src/planner/table_filter_set.cpp index b0ec087aa..bdb3a32d2 100644 --- a/src/duckdb/src/planner/table_filter_set.cpp +++ b/src/duckdb/src/planner/table_filter_set.cpp @@ -15,6 +15,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/serializer/serializer.hpp" @@ -25,10 +26,28 @@ namespace duckdb { struct LegacyStructPathEntry { 
idx_t child_idx; - string child_name; + Identifier child_name; }; static bool ContainsInternalTableFilterFunction(const Expression &expr) { + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (TableFilterFunctions::IsTableFilterFunction(func.Function())) { + return true; + } + if (func.Function().GetName() == OptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalTableFilterFunction(*data.child_filter_expr)) { + return true; + } + } + if (func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME && func.BindInfo()) { + auto &data = func.BindInfo()->Cast(); + if (data.child_filter_expr && ContainsInternalTableFilterFunction(*data.child_filter_expr)) { + return true; + } + } + } bool found = false; ExpressionIterator::EnumerateChildren(expr, [&](const Expression &child) { if (found) { @@ -78,16 +97,16 @@ static bool TryExtractLegacySubject(const Expression &expr, vector(); idx_t child_idx; - if (!TryGetStructExtractChildIndex(func, child_idx) || func.children.empty()) { + if (!TryGetStructExtractChildIndex(func, child_idx) || func.GetChildren().empty()) { return false; } - if (!TryExtractLegacySubject(*func.children[0], struct_path)) { + if (!TryExtractLegacySubject(*func.GetChildren()[0], struct_path)) { return false; } - string child_name; - if (func.children[0]->GetReturnType().id() == LogicalTypeId::STRUCT && - !StructType::IsUnnamed(func.children[0]->GetReturnType())) { - child_name = StructType::GetChildName(func.children[0]->GetReturnType(), child_idx); + Identifier child_name; + if (func.GetChildren()[0]->GetReturnType().id() == LogicalTypeId::STRUCT && + !StructType::IsUnnamed(func.GetChildren()[0]->GetReturnType())) { + child_name = StructType::GetChildName(func.GetChildren()[0]->GetReturnType(), child_idx); } struct_path.push_back({child_idx, std::move(child_name)}); return true; @@ -100,7 +119,7 @@ static bool 
TryExtractLegacySubject(const Expression &expr, vector WrapStructFilterPath(unique_ptr filter, const vector &struct_path) { for (auto it = struct_path.rbegin(); it != struct_path.rend(); ++it) { - filter = make_uniq(it->child_idx, it->child_name, std::move(filter)); + filter = make_uniq(it->child_idx, it->child_name, std::move(filter)); } return filter; } @@ -111,7 +130,7 @@ static void NormalizeLegacyExpression(unique_ptr &expr) { owned_expr = make_uniq(col_ref.GetAlias(), col_ref.GetReturnType(), 0ULL); }); ExpressionIterator::VisitExpressionMutable( - expr, [](BoundReferenceExpression &ref, unique_ptr &owned_expr) { ref.index = 0; }); + expr, [](BoundReferenceExpression &ref, unique_ptr &owned_expr) { ref.IndexMutable() = 0; }); } static unique_ptr TrySerializeComparisonToLegacyFilter(const BoundFunctionExpression &comparison) { @@ -125,7 +144,7 @@ static unique_ptr TrySerializeComparisonToLegacyFilter(const BoundF } auto &subject = rhs_constant ? left : right; auto &constant_expr = rhs_constant ? 
right : left; - auto &constant = constant_expr.Cast().value; + const auto &constant = constant_expr.Cast().GetValue(); if (!rhs_constant) { comparison_type = FlipComparisonType(comparison_type); } @@ -137,9 +156,9 @@ static unique_ptr TrySerializeComparisonToLegacyFilter(const BoundF if (constant.IsNull()) { switch (comparison_type) { case ExpressionType::COMPARE_DISTINCT_FROM: - return WrapStructFilterPath(make_uniq(), struct_path); + return WrapStructFilterPath(make_uniq(), struct_path); case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return WrapStructFilterPath(make_uniq(), struct_path); + return WrapStructFilterPath(make_uniq(), struct_path); default: return nullptr; } @@ -147,40 +166,40 @@ static unique_ptr TrySerializeComparisonToLegacyFilter(const BoundF if (!IsSupportedConstantComparison(comparison_type)) { return nullptr; } - return WrapStructFilterPath(make_uniq(comparison_type, constant), struct_path); + return WrapStructFilterPath(make_uniq(comparison_type, constant), struct_path); } static unique_ptr TrySerializeOperatorToLegacyFilter(const BoundOperatorExpression &op) { switch (op.GetExpressionType()) { case ExpressionType::OPERATOR_IS_NULL: case ExpressionType::OPERATOR_IS_NOT_NULL: { - if (op.children.size() != 1) { + if (op.GetChildren().size() != 1) { return nullptr; } vector struct_path; - if (!TryExtractLegacySubject(*op.children[0], struct_path)) { + if (!TryExtractLegacySubject(*op.GetChildren()[0], struct_path)) { return nullptr; } if (op.GetExpressionType() == ExpressionType::OPERATOR_IS_NULL) { - return WrapStructFilterPath(make_uniq(), struct_path); + return WrapStructFilterPath(make_uniq(), struct_path); } - return WrapStructFilterPath(make_uniq(), struct_path); + return WrapStructFilterPath(make_uniq(), struct_path); } case ExpressionType::COMPARE_IN: { - if (op.children.empty()) { + if (op.GetChildren().empty()) { return nullptr; } vector struct_path; - if (!TryExtractLegacySubject(*op.children[0], struct_path)) { + if 
(!TryExtractLegacySubject(*op.GetChildren()[0], struct_path)) { return nullptr; } vector values; - values.reserve(op.children.size() - 1); - for (idx_t i = 1; i < op.children.size(); i++) { - if (op.children[i]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + values.reserve(op.GetChildren().size() - 1); + for (idx_t i = 1; i < op.GetChildren().size(); i++) { + if (op.GetChildren()[i]->GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { return nullptr; } - auto value = op.children[i]->Cast().value; + auto value = op.GetChildren()[i]->Cast().GetValue(); if (value.IsNull()) { return nullptr; } @@ -189,24 +208,62 @@ static unique_ptr TrySerializeOperatorToLegacyFilter(const BoundOpe if (values.empty()) { return nullptr; } - return WrapStructFilterPath(make_uniq(std::move(values)), struct_path); + return WrapStructFilterPath(make_uniq(std::move(values)), struct_path); } default: return nullptr; } } +static unique_ptr SerializeOptionalChild(const optional_ptr child_expr) { + if (!child_expr) { + return nullptr; + } + return SerializeExpressionToLegacyFilter(*child_expr); +} + +static unique_ptr SerializeInternalFunctionToLegacyFilter(const BoundFunctionExpression &func_expr) { + auto &func_name = func_expr.Function().GetName(); + if (func_name == OptionalFilterScalarFun::NAME) { + unique_ptr child_filter; + if (func_expr.BindInfo()) { + auto &data = func_expr.BindInfo()->Cast(); + child_filter = SerializeOptionalChild(data.child_filter_expr.get()); + } + return make_uniq(std::move(child_filter)); + } + if (func_name == SelectivityOptionalFilterScalarFun::NAME) { + unique_ptr child_filter; + if (func_expr.BindInfo()) { + auto &data = func_expr.BindInfo()->Cast(); + child_filter = SerializeOptionalChild(data.child_filter_expr.get()); + } + return make_uniq(std::move(child_filter)); + } + if (func_name == DynamicFilterScalarFun::NAME) { + if (!func_expr.BindInfo()) { + return make_uniq(); + } + auto &data = func_expr.BindInfo()->Cast(); + return 
make_uniq(data.filter_data); + } + if (func_name == BloomFilterScalarFun::NAME || func_name == PerfectHashJoinScalarFun::NAME || + func_name == PrefixRangeScalarFun::NAME) { + return make_uniq(); + } + throw SerializationException("Unsupported internal tablefilter function \"%s\" during serialization", func_name); +} static unique_ptr SerializeConjunctionToLegacyFilter(const BoundConjunctionExpression &conjunction) { - unique_ptr result; + unique_ptr result; if (conjunction.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { - result = make_uniq(); + result = make_uniq(); } else if (conjunction.GetExpressionType() == ExpressionType::CONJUNCTION_OR) { - result = make_uniq(); + result = make_uniq(); } else { throw SerializationException("Unsupported conjunction type %s during table-filter serialization", EnumUtil::ToString(conjunction.GetExpressionType())); } - for (auto &child : conjunction.children) { + for (auto &child : conjunction.GetChildren()) { auto child_filter = SerializeExpressionToLegacyFilter(*child); if (!child_filter) { return nullptr; @@ -232,6 +289,12 @@ static unique_ptr SerializeExpressionToLegacyFilter(const Expressio return result; } } + if (expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &func = expr.Cast(); + if (TableFilterFunctions::IsTableFilterFunction(func.Function())) { + return SerializeInternalFunctionToLegacyFilter(func); + } + } if (ContainsInternalTableFilterFunction(expr)) { return nullptr; } @@ -340,7 +403,7 @@ bool TableFilterSet::Equals(TableFilterSet &other) { if (other_entry == other.filters.end()) { return false; } - if (!entry.second->Equals(*other_entry->second)) { + if (!entry.second->Cast().Equals(other_entry->second->Cast())) { return false; } } @@ -360,7 +423,7 @@ bool TableFilterSet::Equals(TableFilterSet *left, TableFilterSet *right) { unique_ptr TableFilterSet::Copy() const { auto copy = make_uniq(); for (auto &it : filters) { - copy->filters.emplace(it.first, it.second->Copy()); + 
copy->filters.emplace(it.first, it.second->Cast().Copy()); } return copy; } @@ -369,21 +432,18 @@ void TableFilterSet::PushFilter(ProjectionIndex col_idx, unique_ptr if (!col_idx.IsValid()) { throw InternalException("Cannot push a filter over an invalid ProjectionIndex"); } + auto &new_filter = ExpressionFilter::GetExpressionFilter(*filter, "TableFilterSet::PushFilter"); auto entry = filters.find(col_idx); if (entry == filters.end()) { // no filter yet: push the filter directly filters[col_idx] = std::move(filter); } else { // there is already a filter: AND it together - if (entry->second->filter_type == TableFilterType::CONJUNCTION_AND) { - auto &and_filter = entry->second->Cast(); - and_filter.child_filters.push_back(std::move(filter)); - } else { - auto and_filter = make_uniq(); - and_filter->child_filters.push_back(std::move(entry->second)); - and_filter->child_filters.push_back(std::move(filter)); - filters[col_idx] = std::move(and_filter); - } + auto &existing = ExpressionFilter::GetExpressionFilter(*entry->second, "TableFilterSet::PushFilter"); + auto and_expr = make_uniq(ExpressionType::CONJUNCTION_AND); + and_expr->GetChildrenMutable().push_back(std::move(existing.expr)); + and_expr->GetChildrenMutable().push_back(std::move(new_filter.expr)); + filters[col_idx] = make_uniq(std::move(and_expr)); } } @@ -420,12 +480,12 @@ DynamicTableFilterSet::GetFinalTableFilters(const PhysicalTableScan &scan, auto result = make_uniq(); if (existing_filters) { for (auto &filter_entry : *existing_filters) { - result->PushFilter(filter_entry.GetIndex(), filter_entry.Filter().Copy()); + result->PushFilter(filter_entry.GetIndex(), filter_entry.Filter().Cast().Copy()); } } for (auto &entry : filters) { for (auto &filter_entry : *entry.second) { - result->PushFilter(filter_entry.GetIndex(), filter_entry.Filter().Copy()); + result->PushFilter(filter_entry.GetIndex(), filter_entry.Filter().Cast().Copy()); } } if (!result->HasFilters()) { @@ -438,10 +498,6 @@ map> 
TableFilterSet::GetTableFiltersForSerialization(Serializer &serializer) const { map> result; for (auto &entry : filters) { - if (entry.second->filter_type != TableFilterType::EXPRESSION_FILTER) { - result.emplace(entry.first, entry.second->Copy()); - continue; - } auto &expr_filter = ExpressionFilter::GetExpressionFilter(*entry.second, "TableFilterSet::GetTableFiltersForSerialization"); auto serialized_filter = SerializeExpressionToLegacyFilter(*expr_filter.expr); diff --git a/src/duckdb/src/planner/table_filter_state.cpp b/src/duckdb/src/planner/table_filter_state.cpp index 538a4c493..daca7e529 100644 --- a/src/duckdb/src/planner/table_filter_state.cpp +++ b/src/duckdb/src/planner/table_filter_state.cpp @@ -1,12 +1,5 @@ #include "duckdb/planner/table_filter_state.hpp" -#include "duckdb/planner/filter/bloom_filter.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/selectivity_optional_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" namespace duckdb { @@ -19,79 +12,9 @@ ExpressionFilterState::ExpressionFilterState(ClientContext &context, const Expre InitializeExecutor(context, expression, *this); } -JoinFilterTableFilterState::JoinFilterTableFilterState(const LogicalType &key_logical_type) - : current_capacity(0), hashes_v(LogicalType::HASH, current_capacity), - keys_sliced_v(key_logical_type, current_capacity), probe_sel(current_capacity) { -} - -void JoinFilterTableFilterState::PrepareSlicedKeys(Vector &keys_v, SelectionVector &sel, - const idx_t approved_tuple_count) { - if (current_capacity < approved_tuple_count) { - hashes_v.Initialize(VectorDataInitialization::UNINITIALIZED, approved_tuple_count); - keys_sliced_v.Initialize(VectorDataInitialization::UNINITIALIZED, 
approved_tuple_count); - probe_sel.Initialize(approved_tuple_count); - current_capacity = approved_tuple_count; - } - - if (keys_sliced_v.GetType() == keys_v.GetType()) { - keys_sliced_v.Slice(keys_v, sel, approved_tuple_count); - } else { - // Apply sel first (into keys_v space), then cast to key_type - Vector sliced_src(keys_v.GetType()); - sliced_src.Slice(keys_v, sel, approved_tuple_count); - sliced_src.Flatten(); - VectorOperations::DefaultCast(sliced_src, keys_sliced_v, approved_tuple_count); - } -} - -const LogicalType &GetTableFilterKeyType(const TableFilter &filter) { - if (filter.filter_type == TableFilterType::BLOOM_FILTER) { - return filter.Cast().GetKeyType(); - } - if (filter.filter_type == TableFilterType::PERFECT_HASH_JOIN_FILTER) { - return filter.Cast().GetKeyType(); - } - if (filter.filter_type == TableFilterType::PREFIX_RANGE_FILTER) { - return filter.Cast().GetKeyType(); - } - throw NotImplementedException("Unknown filter type"); -} - unique_ptr TableFilterState::Initialize(ClientContext &context, const TableFilter &filter) { - switch (filter.filter_type) { - case TableFilterType::BLOOM_FILTER: - case TableFilterType::PERFECT_HASH_JOIN_FILTER: - case TableFilterType::PREFIX_RANGE_FILTER: { - return make_uniq(GetTableFilterKeyType(filter)); - } - case TableFilterType::OPTIONAL_FILTER: { - // the optional filter may be executed if it is a SelectivityOptionalFilter - auto &optional_filter = filter.Cast(); - return optional_filter.InitializeState(context); - } - case TableFilterType::CONJUNCTION_OR: { - auto &conj_filter = filter.Cast(); - auto result = make_uniq(); - for (auto &child_filter : conj_filter.child_filters) { - result->child_states.push_back(Initialize(context, *child_filter)); - } - return std::move(result); - } - case TableFilterType::CONJUNCTION_AND: { - auto &conj_filter = filter.Cast(); - auto result = make_uniq(); - for (auto &child_filter : conj_filter.child_filters) { - result->child_states.push_back(Initialize(context, 
*child_filter)); - } - return std::move(result); - } - case TableFilterType::EXPRESSION_FILTER: { - auto &expr_filter = filter.Cast(); - return make_uniq(context, *expr_filter.expr); - } - default: - throw InternalException("Unsupported filter type for TableFilterState::Initialize"); - } + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "TableFilterState::Initialize"); + return make_uniq(context, *expr_filter.expr); } } // namespace duckdb diff --git a/src/duckdb/src/storage/block_allocator.cpp b/src/duckdb/src/storage/block_allocator.cpp index 0e7842b83..7cf54028b 100644 --- a/src/duckdb/src/storage/block_allocator.cpp +++ b/src/duckdb/src/storage/block_allocator.cpp @@ -137,21 +137,24 @@ class BlockAllocatorThreadLocalState { } void Clear() { - // Return all local blocks back to global - if (!touched.empty()) { - block_allocator->touched->q.enqueue_bulk(touched.begin(), touched.size()); - touched.clear(); - } - if (!untouched.empty()) { - block_allocator->untouched->q.enqueue_bulk(untouched.begin(), untouched.size()); - untouched.clear(); + if (alive_token && alive_token->load()) { + // Allocator is still alive — return local blocks to global queues + if (!touched.empty()) { + block_allocator->touched->q.enqueue_bulk(touched.begin(), touched.size()); + } + if (!untouched.empty()) { + block_allocator->untouched->q.enqueue_bulk(untouched.begin(), untouched.size()); + } } + touched.clear(); + untouched.clear(); } private: void Initialize(const BlockAllocator &block_allocator_p) { cached_uuid = block_allocator_p.uuid; block_allocator = block_allocator_p; + alive_token = block_allocator_p.alive_token; untouched.clear(); touched.clear(); untouched.reserve(BATCH_SIZE); @@ -185,6 +188,8 @@ class BlockAllocatorThreadLocalState { private: hugeint_t cached_uuid; optional_ptr block_allocator; + // Whether the BlockAllocator is still alive. 
+ shared_ptr> alive_token; static constexpr idx_t BATCH_SIZE = 128; static constexpr idx_t FREE_THRESHOLD = BATCH_SIZE * 2; @@ -215,12 +220,14 @@ BlockAllocator::BlockAllocator(Allocator &allocator_p, const idx_t block_size_p, : uuid(UUID::GenerateRandomUUID()), allocator(allocator_p), block_size(block_size_p), block_size_div_shift(CountZeros::Trailing(block_size)), virtual_memory_size(AlignValue(virtual_memory_size_p, block_size)), virtual_memory_space(nullptr), - physical_memory_size(0), untouched(make_unsafe_uniq()), touched(make_unsafe_uniq()) { + physical_memory_size(0), untouched(make_unsafe_uniq()), touched(make_unsafe_uniq()), + alive_token(make_shared_ptr>(true)) { D_ASSERT(IsPowerOfTwo(block_size)); Resize(physical_memory_size_p); } BlockAllocator::~BlockAllocator() { + alive_token->store(false); GetBlockAllocatorThreadLocalState(*this).Clear(); if (IsActive()) { try { diff --git a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp index a652940a6..c05ab53ae 100644 --- a/src/duckdb/src/storage/buffer/block_handle.cpp +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -34,8 +34,9 @@ BlockMemory::~BlockMemory() { // NOLINT: allow internal exceptions // The block memory is being destroyed, meaning that any unswizzled pointers are now binary junk. SetSwizzling(nullptr); D_ASSERT(!GetBuffer() || GetBuffer()->GetBufferType() == GetBufferType()); - if (GetBuffer() && GetBufferType() != FileBufferType::TINY_BUFFER) { - // Kill the latest version in the eviction queue. + if (GetEvictionSequenceNumber() > 0 && GetBufferType() != FileBufferType::TINY_BUFFER) { + // eviction_seq_num > 0 means there is a live queue entry for this block (it's reset + // to 0 on unload/evict). That entry is now dead — account for it. 
GetBufferManager().GetBufferPool().IncrementDeadNodes(*this); } @@ -177,7 +178,11 @@ BlockHandle::~BlockHandle() { // NOLINT: allow internal exceptions unique_ptr AllocateBlock(BlockManager &block_manager, unique_ptr reusable_buffer, block_id_t block_id) { - if (reusable_buffer && reusable_buffer->GetHeaderSize() == block_manager.GetBlockHeaderSize()) { + // A buffer that doesn't own its memory (e.g. it adopted a pointer into a memory-mapped + // region) cannot be reused for a different block: rewriting its bytes would clobber the + // original block on disk through the mapping. + if (reusable_buffer && reusable_buffer->OwnsInternalBuffer() && + reusable_buffer->GetHeaderSize() == block_manager.GetBlockHeaderSize()) { // Reusable buffer: reuse it. if (reusable_buffer->GetBufferType() == FileBufferType::BLOCK) { // Reuse the entire buffer. diff --git a/src/duckdb/src/storage/buffer/buffer_pool.cpp b/src/duckdb/src/storage/buffer/buffer_pool.cpp index 104d3fba6..fb4ad93c8 100644 --- a/src/duckdb/src/storage/buffer/buffer_pool.cpp +++ b/src/duckdb/src/storage/buffer/buffer_pool.cpp @@ -66,13 +66,27 @@ shared_ptr BufferEvictionNode::TryGetBlockMemory() { return shared_memory_p; } +bool BufferEvictionNode::IsDeadNode(optional_idx debug_sleep_micros) { + auto shared_memory_p = memory_p.lock(); + if (debug_sleep_micros.IsValid()) { + ThreadUtil::SleepMicroSeconds(debug_sleep_micros.GetIndex()); + } + if (!shared_memory_p) { + return true; + } + if (handle_sequence_number != shared_memory_p->GetEvictionSequenceNumber()) { + return true; + } + return false; +} + typedef duckdb_moodycamel::ConcurrentQueue eviction_queue_t; struct EvictionQueue { public: explicit EvictionQueue(const vector &file_buffer_types_p) : file_buffer_types(file_buffer_types_p), debug_eviction_queue_sleep(0), evict_queue_insertions(0), - total_dead_nodes(0) { + total_dead_nodes(0), purge_consumer_token(q), purge_producer_token(q) { } public: @@ -108,7 +122,7 @@ struct EvictionQueue { } private: - //! 
Bulk purge dead nodes from the eviction queue. Then, enqueue those that are still alive. + //! Bulk purge dead nodes from the eviction queue. Then, re-enqueue those that are still alive. void PurgeIteration(const idx_t purge_size); public: @@ -143,6 +157,10 @@ struct EvictionQueue { mutex purge_lock; //! A pre-allocated vector of eviction nodes. We reuse this to keep the allocation overhead of purges small. vector purge_nodes; + //! Consumer token for purge dequeuing — progresses through sub-queues sequentially. + duckdb_moodycamel::ConsumerToken purge_consumer_token; + //! Producer token for re-enqueuing alive nodes into a dedicated sub-queue. + duckdb_moodycamel::ProducerToken purge_producer_token; }; bool EvictionQueue::AddToEvictionQueue(BufferEvictionNode &&node) { @@ -191,6 +209,7 @@ void EvictionQueue::Purge() { // guaranteeing that we always exit the loop. idx_t max_purges = approx_q_size / purge_size; + while (max_purges != 0) { PurgeIteration(purge_size); @@ -224,28 +243,33 @@ void EvictionQueue::PurgeIteration(const idx_t purge_size) { purge_nodes.resize(purge_size); } - // bulk purge - const idx_t actually_dequeued = q.try_dequeue_bulk(purge_nodes.begin(), purge_size); + // Dequeue using consumer token — progresses through sub-queues sequentially + const idx_t actually_dequeued = q.try_dequeue_bulk(purge_consumer_token, purge_nodes.begin(), purge_size); + if (actually_dequeued == 0) { + return; + } - // retrieve all alive nodes that have been wrongly dequeued - idx_t alive_nodes = 0; - auto debug_sleep_micros = debug_eviction_queue_sleep.load(std::memory_order_relaxed); + idx_t dead_count = 0; + idx_t alive_count = 0; + auto raw_sleep_micros = debug_eviction_queue_sleep.load(std::memory_order_relaxed); + optional_idx debug_sleep_micros = raw_sleep_micros > 0 ? 
optional_idx(raw_sleep_micros) : optional_idx(); for (idx_t i = 0; i < actually_dequeued; i++) { auto &node = purge_nodes[i]; - auto handle = node.TryGetBlockMemory(); - if (debug_sleep_micros > 0) { - // Debug race conditions regarding the ownership of the BlockMemory. - ThreadUtil::SleepMicroSeconds(debug_sleep_micros); - } - if (handle) { - purge_nodes[alive_nodes++] = std::move(node); + if (node.IsDeadNode(debug_sleep_micros)) { + dead_count++; + } else { + // Move alive nodes to the front for bulk re-enqueue + purge_nodes[alive_count++] = std::move(node); } } - // bulk re-add (TODO order them by timestamp to better retain the LRU behavior) - q.enqueue_bulk(purge_nodes.begin(), alive_nodes); + total_dead_nodes -= dead_count; - total_dead_nodes -= actually_dequeued - alive_nodes; + // Re-enqueue alive nodes via producer token — goes into a dedicated sub-queue + // that the consumer token has already passed + if (alive_count > 0) { + q.enqueue_bulk(purge_producer_token, purge_nodes.begin(), alive_count); + } } BufferPool::BufferPool(BlockAllocator &block_allocator, idx_t maximum_memory, bool track_eviction_timestamps, diff --git a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp index ab1e2bb5a..7f94213d6 100644 --- a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp +++ b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp @@ -4,6 +4,7 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/parallel/task_scheduler.hpp" @@ -21,12 +22,12 @@ TableDataWriter::TableDataWriter(TableCatalogEntry &table_p, QueryContext contex : table(table_p.Cast()), context(context.GetClientContext()) { D_ASSERT(table_p.IsDuckTable()); - auto 
serialization_version = SerializationCompatibility::FromDatabase(table_p.ParentCatalog().GetAttached()); - if (serialization_version.serialization_version < - SerializationCompatibility::FromString("v1.4.4").serialization_version) { + auto storage_compatibility = StorageCompatibility::FromDatabase(table_p.ParentCatalog().GetAttached()); + if (storage_compatibility.storage_version < StorageVersion::V1_4_4) { // older storage versions require legacy start row to be written require_legacy_start_row = true; } + can_persist_rowid_gaps = storage_compatibility.storage_version >= StorageVersion::V2_0_0; } TableDataWriter::~TableDataWriter() { @@ -56,6 +57,10 @@ unique_ptr TableDataWriter::CreateTaskExecutor() { return make_uniq(TaskScheduler::GetScheduler(GetDatabase())); } +optional_ptr TableDataWriter::TryGetClientContext() const { + return context; +} + SingleFileTableDataWriter::SingleFileTableDataWriter(SingleFileCheckpointWriter &checkpoint_manager, TableCatalogEntry &table, MetadataWriter &table_data_writer) : TableDataWriter(table, checkpoint_manager.GetClientContext()), checkpoint_manager(checkpoint_manager), @@ -76,11 +81,12 @@ MetadataManager &SingleFileTableDataWriter::GetMetadataManager() { } void SingleFileTableDataWriter::WriteUnchangedTable(MetaBlockPointer pointer, - const vector &metadata_pointers, - idx_t total_rows) { + const vector &metadata_pointers, idx_t total_rows, + idx_t next_row_id) { existing_pointer = pointer; existing_pointers = metadata_pointers; existing_rows = total_rows; + existing_next_row_id = next_row_id; } void SingleFileTableDataWriter::FlushPartialBlocks() { @@ -91,8 +97,10 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat RowGroupCollection &collection, Serializer &serializer) { MetaBlockPointer pointer; idx_t total_rows; + idx_t next_row_id; auto debug_verify_blocks = Settings::Get(GetDatabase()); if (!existing_pointer.IsValid()) { + auto supports_per_column_writes = 
collection.SupportsPerColumnWrites(); // write the metadata // store the current position in the metadata writer // this is where the row groups for this table start @@ -109,16 +117,18 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat // now start writing the row group pointers to disk table_data_writer.Write(row_group_pointers.size()); total_rows = 0; + next_row_id = 0; for (auto &row_group_pointer : row_group_pointers) { - auto row_group_count = row_group_pointer.row_start + row_group_pointer.tuple_count; - if (row_group_count > total_rows) { - total_rows = row_group_count; + total_rows += row_group_pointer.tuple_count; + auto row_group_end = row_group_pointer.row_start + row_group_pointer.tuple_count; + if (row_group_end > next_row_id) { + next_row_id = row_group_end; } // Each RowGroup is its own unit BinarySerializer row_group_serializer(table_data_writer, serializer.GetOptions()); row_group_serializer.Begin(); - RowGroup::Serialize(row_group_pointer, row_group_serializer); + RowGroup::Serialize(row_group_pointer, row_group_serializer, supports_per_column_writes); row_group_serializer.End(); } table_data_writer.SetWrittenPointers(nullptr); @@ -127,6 +137,8 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat // we have existing metadata and the table is unchanged - write a pointer to the existing metadata pointer = existing_pointer; total_rows = existing_rows.GetIndex(); + next_row_id = existing_next_row_id.GetIndex(); + D_ASSERT(next_row_id >= total_rows); // label the blocks as used again to prevent them from being freed auto &metadata_manager = checkpoint_manager.GetMetadataManager(); @@ -170,12 +182,15 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat } } + D_ASSERT(next_row_id >= total_rows); // Now begin the metadata as a unit // Pointer to the table itself goes to the metadata stream. 
serializer.WriteProperty(101, "table_pointer", pointer); serializer.WriteProperty(102, "total_rows", total_rows); - auto v1_0_0_storage = serializer.GetOptions().serialization_compatibility.serialization_version < 3; + // prior: ser version 3 + auto v1_0_0_storage = StorageManager::IsPriorToVersion( + StorageVersion::V1_2_0, serializer.GetOptions().storage_compatibility.storage_version); IndexSerializationInfo serialization_info; if (!v1_0_0_storage) { serialization_info.options.emplace("v1_0_0_storage", v1_0_0_storage); @@ -200,6 +215,11 @@ void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stat serializer.WriteList( 104, "index_storage_infos", index_storage_infos.ordered_infos.size(), [&](Serializer::List &list, idx_t i) { list.WriteElement(index_storage_infos.ordered_infos[i].get()); }); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WriteProperty(105, "next_row_id", next_row_id); + } + // ¬serializer.ShouldSerialize(StorageVersion::V2_0_0) ==> (next_row_id == total_rows) + D_ASSERT(serializer.ShouldSerialize(StorageVersion::V2_0_0) || (next_row_id == total_rows)); } } // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp index a33f4e05c..1a369af45 100644 --- a/src/duckdb/src/storage/checkpoint_manager.cpp +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -727,6 +727,10 @@ void CheckpointReader::ReadTableData(CatalogTransaction transaction, Deserialize // Cover reading new storage files. auto index_storage_infos = deserializer.ReadPropertyWithExplicitDefault>(104, "index_storage_infos", {}); + // Read next_row_id as total_rows for backwards compatibility. Older storage versions do not allow for gaps in + // row_id numbering, in which case next_row_id = total_rows. 
+ auto next_row_id = deserializer.ReadPropertyWithExplicitDefault(105, "next_row_id", total_rows); + D_ASSERT(next_row_id >= total_rows); if (!index_storage_infos.empty()) { bound_info.indexes = std::move(index_storage_infos); @@ -751,6 +755,7 @@ void CheckpointReader::ReadTableData(CatalogTransaction transaction, Deserialize data_reader.ReadTableData(); bound_info.data->total_rows = total_rows; + bound_info.data->next_row_id = next_row_id; bound_info.data->read_metadata_pointers = read_pointers; } diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp index 3b31a9c36..f0368c2d9 100644 --- a/src/duckdb/src/storage/compression/bitpacking.cpp +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -12,6 +12,7 @@ #include "duckdb/main/settings.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/compression/bitpacking.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/scan_state.hpp" @@ -301,21 +302,20 @@ struct BitpackingState { //===--------------------------------------------------------------------===// template struct BitpackingAnalyzeState : public AnalyzeState { - explicit BitpackingAnalyzeState(const CompressionInfo &info) : AnalyzeState(info) {}; + explicit BitpackingAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager) {}; BitpackingState state; }; template unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - auto state = make_uniq>(info); + auto state = make_uniq>(col_data.GetBlockManager()); state->state.mode = Settings::Get(col_data.GetDatabase()); return std::move(state); } template -bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool BitpackingAnalyze(AnalyzeState &state, const Vector 
&input) { // We use BITPACKING_METADATA_GROUP_SIZE tuples, which can exceed the block size. // In that case, we disable bitpacking. // we are conservative here by multiplying by 2 @@ -347,22 +347,17 @@ idx_t BitpackingFinalAnalyze(AnalyzeState &state) { // Compress //===--------------------------------------------------------------------===// template ::type> -struct BitpackingCompressionState : public CompressionState { +struct BitpackingCompressionState : public StandardCompressionState { public: - explicit BitpackingCompressionState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info) - : CompressionState(info), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_BITPACKING)) { + explicit BitpackingCompressionState(ColumnDataCheckpointData &checkpoint_data) + : StandardCompressionState(checkpoint_data, CompressionType::COMPRESSION_BITPACKING) { CreateEmptySegment(); state.data_ptr = reinterpret_cast(this); state.mode = Settings::Get(checkpoint_data.GetDatabase()); } - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - + StatsWriter stats_writer; // Ptr to next free spot in segment; data_ptr_t data_ptr; // Ptr to next free spot for storing bitwidths and frame-of-references (growing downwards). 
@@ -450,16 +445,14 @@ struct BitpackingCompressionState : public CompressionState { state->current_segment->count += count; if (WRITE_STATISTICS) { + auto &stats_writer = state->stats_writer; if (state->state.has_valid) { - state->current_segment->stats.statistics.SetHasNoNullFast(); + stats_writer.SetHasValid(); + stats_writer.UpdateMinMax(state->state.minimum); + stats_writer.UpdateMinMax(state->state.maximum); } if (state->state.has_invalid) { - state->current_segment->stats.statistics.SetHasNullFast(); - } - - if (!state->state.all_invalid) { - state->current_segment->stats.statistics.template UpdateNumericStats(state->state.maximum); - state->current_segment->stats.statistics.template UpdateNumericStats(state->state.minimum); + stats_writer.SetHasNull(); } } } @@ -474,21 +467,13 @@ struct BitpackingCompressionState : public CompressionState { } void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); - - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); data_ptr = handle.GetDataMutable() + BitpackingPrimitives::BITPACKING_HEADER_SIZE; metadata_ptr = handle.GetDataMutable() + info.GetBlockSize(); } - void Append(Vector &input, idx_t count) { + void Append(const Vector &input) { for (auto entry : input.Values()) { state.template Update(entry); } @@ -502,7 +487,6 @@ struct BitpackingCompressionState : public CompressionState { } void FlushSegment() { - auto &state = checkpoint_data.GetCheckpointState(); auto base_ptr = handle.GetDataMutable(); // Compact the segment by moving the metadata next to the data. 
@@ -525,8 +509,7 @@ struct BitpackingCompressionState : public CompressionState { // Store the offset of the metadata of the first group (which is at the highest address). Store(metadata_offset + metadata_size, base_ptr); - - state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); + FlushCurrentSegment(stats_writer, total_segment_size); } void Finalize() { @@ -539,13 +522,13 @@ struct BitpackingCompressionState : public CompressionState { template unique_ptr BitpackingInitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state) { - return make_uniq>(checkpoint_data, state->info); + return make_uniq>(checkpoint_data); } template -void BitpackingCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void BitpackingCompress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast>(); - state.Append(scan_vector, count); + state.Append(scan_vector); } template @@ -599,8 +582,8 @@ template ::type> struct BitpackingScanState : public SegmentScanState { public: explicit BitpackingScanState(const QueryContext &context, ColumnSegment &segment) : current_segment(segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(context, segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + handle = buffer_manager.Pin(context, segment.GetBlockHandle()); auto data_ptr = handle.GetDataMutable(); // load offset to bitpacking widths pointer @@ -609,7 +592,7 @@ struct BitpackingScanState : public SegmentScanState { data_ptr + segment.GetBlockOffset() + bitpacking_metadata_offset - sizeof(bitpacking_metadata_encoded_t); if (bitpacking_metadata_ptr >= handle.GetDataMutable() + current_segment.GetBlockSize()) { throw InternalException("Bitpacking offset is out of range at block \"%llu\" - corrupt database file", - segment.block->BlockId()); + segment.GetBlockHandle()->BlockId()); } // load the first group 
diff --git a/src/duckdb/src/storage/compression/dict_fsst.cpp b/src/duckdb/src/storage/compression/dict_fsst.cpp index e07edccad..5ad837e4d 100644 --- a/src/duckdb/src/storage/compression/dict_fsst.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst.cpp @@ -48,12 +48,12 @@ namespace dict_fsst { struct DictFSSTCompressionStorage { static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static bool StringAnalyze(AnalyzeState &state_p, const Vector &input); static idx_t StringFinalAnalyze(AnalyzeState &state_p); static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void Compress(CompressionState &state_p, const Vector &scan_vector); static void FinalizeCompress(CompressionState &state_p); static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); @@ -70,18 +70,17 @@ struct DictFSSTCompressionStorage { //===--------------------------------------------------------------------===// unique_ptr DictFSSTCompressionStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { auto &storage_manager = col_data.GetStorageManager(); - if (storage_manager.GetStorageVersion() < 5) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_3_0, storage_manager.GetStorageVersion())) { // dict_fsst not introduced yet, disable it return nullptr; } - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool DictFSSTCompressionStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { +bool DictFSSTCompressionStorage::StringAnalyze(AnalyzeState &state_p, const Vector &input) { auto &analyze_state = state_p.Cast(); - return analyze_state.Analyze(input, count); + return analyze_state.Analyze(input); } idx_t 
DictFSSTCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { @@ -98,9 +97,9 @@ unique_ptr DictFSSTCompressionStorage::InitCompression(ColumnD unique_ptr_cast(std::move(state))); } -void DictFSSTCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void DictFSSTCompressionStorage::Compress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast(); - state.Compress(scan_vector, count); + state.Compress(scan_vector); } void DictFSSTCompressionStorage::FinalizeCompress(CompressionState &state_p) { @@ -113,11 +112,11 @@ void DictFSSTCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// unique_ptr DictFSSTCompressionStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto state = make_uniq(segment, buffer_manager.Pin(segment.block)); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto state = make_uniq(segment, buffer_manager.Pin(segment.GetBlockHandle())); state->Initialize(true); - const auto &stats = segment.stats.statistics; + const auto &stats = segment.GetStats(); if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } diff --git a/src/duckdb/src/storage/compression/dict_fsst/analyze.cpp b/src/duckdb/src/storage/compression/dict_fsst/analyze.cpp index 9b0b310b9..07215b6ca 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/analyze.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/analyze.cpp @@ -3,10 +3,10 @@ namespace duckdb { namespace dict_fsst { -DictFSSTAnalyzeState::DictFSSTAnalyzeState(const CompressionInfo &info) : AnalyzeState(info) { +DictFSSTAnalyzeState::DictFSSTAnalyzeState(BlockManager &block_manager) : 
AnalyzeState(block_manager) { } -bool DictFSSTAnalyzeState::Analyze(Vector &input, idx_t count) { +bool DictFSSTAnalyzeState::Analyze(const Vector &input) { for (auto entry : input.Values()) { if (!entry.IsValid()) { contains_nulls = true; @@ -23,7 +23,7 @@ bool DictFSSTAnalyzeState::Analyze(Vector &input, idx_t count) { return false; } } - total_count += count; + total_count += input.size(); return true; } diff --git a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp index d1d314181..114723b24 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp @@ -12,8 +12,7 @@ namespace dict_fsst { DictFSSTCompressionState::DictFSSTCompressionState(ColumnDataCheckpointData &checkpoint_data_p, unique_ptr &&analyze_p) - : CompressionState(analyze_p->info), checkpoint_data(checkpoint_data_p), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICT_FSST)), + : StandardCompressionState(checkpoint_data_p, CompressionType::COMPRESSION_DICT_FSST), stats_writer(GetType()), current_string_map( info.GetBlockManager().buffer_manager.GetBufferAllocator(), MinValue(analyze_p.get()->total_count, info.GetBlockSize()) / 2, // maximum_size_p (amount of elements) @@ -92,7 +91,7 @@ idx_t DictFSSTCompressionState::Finalize() { D_ASSERT(info.GetBlockSize() >= required_space); // calculate ptr and offsets - auto base_ptr = current_handle.GetDataMutable(); + auto base_ptr = handle.GetDataMutable(); auto header_ptr = reinterpret_cast(base_ptr); auto dictionary_dest = AlignValue(DictFSSTCompression::DICTIONARY_HEADER_SIZE); auto symbol_table_dest = AlignValue(dictionary_dest + dictionary_offset); @@ -158,7 +157,7 @@ void DictFSSTCompressionState::FlushEncodingBuffer() { vector fsst_string_ptrs; data_ptr_t dictionary_start = - AlignPointer(current_handle.GetDataMutable() + sizeof(dict_fsst_compression_header_t)); + 
AlignPointer(handle.GetDataMutable() + sizeof(dict_fsst_compression_header_t)); D_ASSERT(dictionary_encoding_buffer.size() == dict_count - string_lengths.size()); auto string_count = dictionary_encoding_buffer.size(); idx_t sum = 0; @@ -237,16 +236,7 @@ void DictFSSTCompressionState::FlushEncodingBuffer() { } void DictFSSTCompressionState::CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); - - // Reset the pointers into the current segment. - auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); - current_handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); append_state = DictionaryAppendState::REGULAR; string_lengths_width = 0; @@ -277,8 +267,7 @@ void DictFSSTCompressionState::Flush(bool final) { current_segment->count = tuple_count; auto segment_size = Finalize(); - auto &state = checkpoint_data.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); + FlushCurrentSegment(stats_writer, segment_size); // Reset the state uncompressed_dictionary_copy.Destroy(); @@ -446,8 +435,8 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string state.current_string_map.Insert(uncompressed_string); } else { state.string_lengths.push_back(str_len); - auto baseptr = AlignPointer(state.current_handle.GetDataMutable() + - sizeof(dict_fsst_compression_header_t)); + auto baseptr = + AlignPointer(state.handle.GetDataMutable() + sizeof(dict_fsst_compression_header_t)); memcpy(baseptr + state.dictionary_offset, str.GetData(), str_len); string_t dictionary_string((const char *)(baseptr + state.dictionary_offset), str_len); // NOLINT state.dictionary_offset += str_len; @@ -674,7 +663,7 @@ 
DictionaryAppendState DictFSSTCompressionState::TryEncode() { uint32_t offset = 0; data_ptr_t dictionary_start = - AlignPointer(current_handle.GetDataMutable() + sizeof(dict_fsst_compression_header_t)); + AlignPointer(handle.GetDataMutable() + sizeof(dict_fsst_compression_header_t)); D_ASSERT(dictionary_offset > string_t::INLINE_BYTES && dictionary_offset <= string_t::MAX_STRING_SIZE); auto dict_copy = uncompressed_dictionary_copy.EmptyString(dictionary_offset); memcpy((void *)dict_copy.GetData(), (void *)dictionary_start, dictionary_offset); @@ -830,11 +819,12 @@ DictionaryAppendState DictFSSTCompressionState::TryEncode() { return new_state; } -void DictFSSTCompressionState::Compress(Vector &scan_vector, idx_t count) { +void DictFSSTCompressionState::Compress(const Vector &scan_vector) { UnifiedVectorFormat vector_format; scan_vector.ToUnifiedFormat(vector_format); auto strings = UnifiedVectorFormat::GetData(vector_format); + const auto count = scan_vector.size(); EncodedInput encoded_input; for (idx_t i = 0; i < count; i++) { auto idx = vector_format.sel->get_index(i); @@ -860,9 +850,9 @@ void DictFSSTCompressionState::Compress(Vector &scan_vector, idx_t count) { } } while (false); if (!is_null) { - UncompressedStringStorage::UpdateStringStats(current_segment->stats, str); + stats_writer.Update(str); } else { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } tuple_count++; } diff --git a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp index 3662d9dc2..5c45df5f9 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp @@ -101,7 +101,7 @@ void CompressedStringScanState::Initialize(bool initialize_dictionary) { return; } - dictionary = DictionaryVector::CreateReusableDictionary(segment.type, dict_count); + dictionary = 
DictionaryVector::CreateReusableDictionary(segment.GetType(), dict_count); auto &dict_data = dictionary->data; auto dict_child_data = FlatVector::GetDataMutable(dict_data); auto &validity = FlatVector::ValidityMutable(dict_data); diff --git a/src/duckdb/src/storage/compression/dictionary/analyze.cpp b/src/duckdb/src/storage/compression/dictionary/analyze.cpp index 492141537..352940f3d 100644 --- a/src/duckdb/src/storage/compression/dictionary/analyze.cpp +++ b/src/duckdb/src/storage/compression/dictionary/analyze.cpp @@ -2,8 +2,8 @@ namespace duckdb { -DictionaryAnalyzeState::DictionaryAnalyzeState(const CompressionInfo &info) - : DictionaryCompressionState(info), segment_count(0), current_tuple_count(0), current_unique_count(0), +DictionaryAnalyzeState::DictionaryAnalyzeState(BlockManager &block_manager) + : AnalyzeState(block_manager), segment_count(0), current_tuple_count(0), current_unique_count(0), current_dict_size(0), current_width(0), next_width(0) { } diff --git a/src/duckdb/src/storage/compression/dictionary/common.cpp b/src/duckdb/src/storage/compression/dictionary/common.cpp index 16ee752c1..f0621ef4b 100644 --- a/src/duckdb/src/storage/compression/dictionary/common.cpp +++ b/src/duckdb/src/storage/compression/dictionary/common.cpp @@ -38,52 +38,4 @@ void DictionaryCompression::SetDictionary(ColumnSegment &segment, BufferHandle & Store(container.end, data_ptr_cast(&header_ptr->dict_end)); } -DictionaryCompressionState::DictionaryCompressionState(const CompressionInfo &info) : CompressionState(info) { -} -DictionaryCompressionState::~DictionaryCompressionState() { -} - -bool DictionaryCompressionState::UpdateState(Vector &scan_vector, idx_t count) { - Verify(); - - for (auto entry : scan_vector.Values()) { - idx_t string_size = 0; - bool new_string = false; - auto row_is_valid = entry.IsValid(); - - if (row_is_valid) { - auto &str = entry.GetValue(); - string_size = str.GetSize(); - if (string_size >= 
StringUncompressed::GetStringBlockLimit(info.GetBlockSize())) { - // Big strings not implemented for dictionary compression - return false; - } - new_string = !LookupString(str); - } - - bool fits = CalculateSpaceRequirements(new_string, string_size); - if (!fits) { - Flush(); - new_string = true; - - fits = CalculateSpaceRequirements(new_string, string_size); - if (!fits) { - throw InternalException("Dictionary compression could not write to new segment"); - } - } - - if (!row_is_valid) { - AddNull(); - } else if (new_string) { - AddNewString(entry.GetValue()); - } else { - AddLastLookup(); - } - - Verify(); - } - - return true; -} - } // namespace duckdb diff --git a/src/duckdb/src/storage/compression/dictionary/compression.cpp b/src/duckdb/src/storage/compression/dictionary/compression.cpp index 26264048f..75b1ea794 100644 --- a/src/duckdb/src/storage/compression/dictionary/compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/compression.cpp @@ -3,10 +3,9 @@ namespace duckdb { DictionaryCompressionCompressState::DictionaryCompressionCompressState(ColumnDataCheckpointData &checkpoint_data_p, - const CompressionInfo &info, const idx_t max_unique_count_across_all_segments) - : DictionaryCompressionState(info), checkpoint_data(checkpoint_data_p), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_DICTIONARY)), + : StandardCompressionState(checkpoint_data_p, CompressionType::COMPRESSION_DICTIONARY), + stats_writer(checkpoint_data.GetType()), current_string_map( info.GetBlockManager().buffer_manager.GetBufferAllocator(), max_unique_count_across_all_segments * 2, // * 2 results in less linear probing, improving performance @@ -17,14 +16,10 @@ DictionaryCompressionCompressState::DictionaryCompressionCompressState(ColumnDat } void DictionaryCompressionCompressState::CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - 
ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); + CreateAndPinNewSegment(); // Reset the buffers and the string map. + stats_writer.Clear(); current_string_map.Clear(); index_buffer.clear(); @@ -36,10 +31,8 @@ void DictionaryCompressionCompressState::CreateEmptySegment() { next_width = 0; // Reset the pointers into the current segment. - auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); - current_handle = buffer_manager.Pin(current_segment->block); - current_dictionary = DictionaryCompression::GetDictionary(*current_segment, current_handle); - current_end_ptr = current_handle.GetDataMutable() + current_dictionary.end; + current_dictionary = DictionaryCompression::GetDictionary(*current_segment, handle); + current_end_ptr = handle.GetDataMutable() + current_dictionary.end; } void DictionaryCompressionCompressState::Verify() { @@ -61,7 +54,7 @@ bool DictionaryCompressionCompressState::LookupString(string_t str) { } void DictionaryCompressionCompressState::AddNewString(string_t str) { - UncompressedStringStorage::UpdateStringStats(current_segment->stats, str); + stats_writer.Update(str); // Copy string to dict current_dictionary.size += str.GetSize(); @@ -80,14 +73,14 @@ void DictionaryCompressionCompressState::AddNewString(string_t str) { D_ASSERT(!dictionary_string.IsInlined()); current_string_map.Insert(dictionary_string); } - DictionaryCompression::SetDictionary(*current_segment, current_handle, current_dictionary); + DictionaryCompression::SetDictionary(*current_segment, handle, current_dictionary); current_width = next_width; current_segment->count++; } void DictionaryCompressionCompressState::AddNull() { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); selection_buffer.push_back(0); current_segment->count++; } @@ -110,8 +103,7 @@ bool 
DictionaryCompressionCompressState::CalculateSpaceRequirements(bool new_str void DictionaryCompressionCompressState::Flush(bool final) { auto segment_size = Finalize(); - auto &state = checkpoint_data.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); + FlushCurrentSegment(stats_writer, segment_size); if (!final) { CreateEmptySegment(); @@ -120,7 +112,7 @@ void DictionaryCompressionCompressState::Flush(bool final) { idx_t DictionaryCompressionCompressState::Finalize() { auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); - auto handle = buffer_manager.Pin(current_segment->block); + auto handle = buffer_manager.Pin(current_segment->GetBlockHandle()); D_ASSERT(current_dictionary.end == info.GetBlockSize()); // calculate sizes diff --git a/src/duckdb/src/storage/compression/dictionary/decompression.cpp b/src/duckdb/src/storage/compression/dictionary/decompression.cpp index 87e117761..6db176a5a 100644 --- a/src/duckdb/src/storage/compression/dictionary/decompression.cpp +++ b/src/duckdb/src/storage/compression/dictionary/decompression.cpp @@ -48,7 +48,7 @@ void CompressedStringScanState::Initialize(ColumnSegment &segment, bool initiali return; } - dictionary = DictionaryVector::CreateReusableDictionary(segment.type, index_buffer_count); + dictionary = DictionaryVector::CreateReusableDictionary(segment.GetType(), index_buffer_count); dictionary_size = index_buffer_count; auto dict_child_data = FlatVector::Writer(dictionary->data, index_buffer_count); dict_child_data.WriteNull(); diff --git a/src/duckdb/src/storage/compression/dictionary_compression.cpp b/src/duckdb/src/storage/compression/dictionary_compression.cpp index 6f10acb60..a8c2ac947 100644 --- a/src/duckdb/src/storage/compression/dictionary_compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary_compression.cpp @@ -47,12 +47,12 @@ namespace duckdb { struct DictionaryCompressionStorage { static unique_ptr 
StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static bool StringAnalyze(AnalyzeState &state_p, const Vector &input); static idx_t StringFinalAnalyze(AnalyzeState &state_p); static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void Compress(CompressionState &state_p, const Vector &scan_vector); static void FinalizeCompress(CompressionState &state_p); static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); @@ -69,23 +69,21 @@ struct DictionaryCompressionStorage { //===--------------------------------------------------------------------===// unique_ptr DictionaryCompressionStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { auto &storage_manager = col_data.GetStorageManager(); - if (storage_manager.GetStorageVersion() >= 5) { + if (StorageManager::TargetAtLeastVersion(StorageVersion::V1_3_0, storage_manager.GetStorageVersion())) { // dict_fsst introduced - disable dictionary return nullptr; } - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool DictionaryCompressionStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { - auto &state = state_p.Cast(); - return state.analyze_state->UpdateState(input, count); +bool DictionaryCompressionStorage::StringAnalyze(AnalyzeState &state_p, const Vector &input) { + auto &state = state_p.Cast(); + return DictionaryCompression::UpdateState(state, input); } idx_t DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { - auto &analyze_state = state_p.Cast(); - auto &state = *analyze_state.analyze_state; + auto &state = state_p.Cast(); if (state.current_tuple_count != 0) { state.UpdateMaxUniqueCount(); @@ -103,16 +101,14 @@ idx_t 
DictionaryCompressionStorage::StringFinalAnalyze(AnalyzeState &state_p) { // Compress //===--------------------------------------------------------------------===// unique_ptr DictionaryCompressionStorage::InitCompression(ColumnDataCheckpointData &checkpoint_data, - unique_ptr state) { - const auto &analyze_state = state->Cast(); - auto &actual_state = *analyze_state.analyze_state; - return make_uniq(checkpoint_data, state->info, - actual_state.max_unique_count_across_segments); + unique_ptr state_p) { + const auto &state = state_p->Cast(); + return make_uniq(checkpoint_data, state.max_unique_count_across_segments); } -void DictionaryCompressionStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void DictionaryCompressionStorage::Compress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast(); - state.UpdateState(scan_vector, count); + DictionaryCompression::UpdateState(state, scan_vector); } void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { @@ -125,8 +121,8 @@ void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// unique_ptr DictionaryCompressionStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto state = make_uniq(buffer_manager.Pin(segment.block)); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto state = make_uniq(buffer_manager.Pin(segment.GetBlockHandle())); state->Initialize(segment, true); return std::move(state); } diff --git a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp index fbf679cfe..28019a09c 100644 --- a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ 
-15,20 +15,19 @@ namespace duckdb { // Analyze //===--------------------------------------------------------------------===// struct FixedSizeAnalyzeState : public AnalyzeState { - explicit FixedSizeAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), count(0) { + explicit FixedSizeAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager), count(0) { } idx_t count; }; unique_ptr FixedSizeInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool FixedSizeAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { +bool FixedSizeAnalyze(AnalyzeState &state_p, const Vector &input) { auto &state = state_p.Cast(); - state.count += count; + state.count += input.size(); return true; } @@ -43,33 +42,28 @@ idx_t FixedSizeFinalAnalyze(AnalyzeState &state_p) { //===--------------------------------------------------------------------===// struct UncompressedCompressState : public CompressionState { public: - UncompressedCompressState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info); + explicit UncompressedCompressState(ColumnDataCheckpointData &checkpoint_data); public: virtual void CreateEmptySegment(); void FlushSegment(idx_t segment_size); void Finalize(idx_t segment_size); + idx_t FinalizeAppend(); public: - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; unique_ptr current_segment; ColumnAppendState append_state; }; -UncompressedCompressState::UncompressedCompressState(ColumnDataCheckpointData &checkpoint_data, - const CompressionInfo &info) - : CompressionState(info), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED)) { +UncompressedCompressState::UncompressedCompressState(ColumnDataCheckpointData &checkpoint_data) + : CompressionState(checkpoint_data, 
CompressionType::COMPRESSION_UNCOMPRESSED) { UncompressedCompressState::CreateEmptySegment(); } void UncompressedCompressState::CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); auto &type = checkpoint_data.GetType(); - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = CreateNewSegment(); if (type.InternalType() == PhysicalType::VARCHAR) { auto &state = compressed_segment->GetSegmentState()->Cast(); auto &storage_manager = checkpoint_data.GetStorageManager(); @@ -81,11 +75,12 @@ void UncompressedCompressState::CreateEmptySegment() { } current_segment = std::move(compressed_segment); current_segment->InitializeAppend(append_state); + append_state.InitializeStats(type); } void UncompressedCompressState::FlushSegment(idx_t segment_size) { auto &state = checkpoint_data.GetCheckpointState(); - if (current_segment->type.InternalType() == PhysicalType::VARCHAR) { + if (current_segment->GetType().InternalType() == PhysicalType::VARCHAR) { auto &segment_state = current_segment->GetSegmentState()->Cast(); if (segment_state.overflow_writer) { segment_state.overflow_writer->Flush(); @@ -103,36 +98,42 @@ void UncompressedCompressState::Finalize(idx_t segment_size) { current_segment.reset(); } +idx_t UncompressedCompressState::FinalizeAppend() { + current_segment->GetStatsMutable().Merge(*append_state.append_stats); + return current_segment->FinalizeAppend(append_state); +} + unique_ptr UncompressedFunctions::InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state) { - return make_uniq(checkpoint_data, state->info); + return make_uniq(checkpoint_data); } -void UncompressedFunctions::Compress(CompressionState &state_p, Vector &data, idx_t count) { +void UncompressedFunctions::Compress(CompressionState &state_p, const Vector &data) { auto &state = state_p.Cast(); UnifiedVectorFormat vdata; data.ToUnifiedFormat(vdata); idx_t offset = 0; 
- while (count > 0) { - idx_t appended = state.current_segment->Append(state.append_state, vdata, offset, count); - if (appended == count) { + idx_t remaining = data.size(); + while (remaining > 0) { + idx_t appended = state.current_segment->Append(state.append_state, vdata, offset, remaining); + if (appended == remaining) { // appended everything: finished return; } // the segment is full: flush it to disk - state.FlushSegment(state.current_segment->FinalizeAppend(state.append_state)); + state.FlushSegment(state.FinalizeAppend()); // now create a new segment and continue appending state.CreateEmptySegment(); offset += appended; - count -= appended; + remaining -= appended; } } void UncompressedFunctions::FinalizeCompress(CompressionState &state_p) { auto &state = state_p.Cast(); - state.Finalize(state.current_segment->FinalizeAppend(state.append_state)); + state.Finalize(state.FinalizeAppend()); } //===--------------------------------------------------------------------===// @@ -144,8 +145,8 @@ struct FixedSizeScanState : public SegmentScanState { unique_ptr FixedSizeInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(context, segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + result->handle = buffer_manager.Pin(context, segment.GetBlockHandle()); return std::move(result); } @@ -184,8 +185,8 @@ void FixedSizeScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_co template void FixedSizeFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); // first 
fetch the data from the base table auto data_ptr = handle.GetDataMutable() + segment.GetBlockOffset() + NumericCast(row_id) * sizeof(T); @@ -197,14 +198,14 @@ void FixedSizeFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t ro // Append //===--------------------------------------------------------------------===// static unique_ptr FixedSizeInitAppend(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); return make_uniq(std::move(handle)); } struct StandardFixedSizeAppend { template - static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, + static void Append(BaseStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, idx_t offset, idx_t count) { auto sdata = UnifiedVectorFormat::GetData(adata); auto tdata = reinterpret_cast(target); @@ -214,22 +215,22 @@ struct StandardFixedSizeAppend { auto target_idx = target_offset + i; bool is_null = !adata.validity.RowIsValid(source_idx); if (!is_null) { - stats.statistics.SetHasNoNullFast(); - stats.statistics.UpdateNumericStats(sdata[source_idx]); + stats.SetHasNoNullFast(); + stats.UpdateNumericStats(sdata[source_idx]); tdata[target_idx] = sdata[source_idx]; } else { - stats.statistics.SetHasNullFast(); + stats.SetHasNullFast(); // we insert a NullValue in the null gap for debuggability // this value should never be used or read anywhere tdata[target_idx] = NullValue(); } } } else { - stats.statistics.SetHasNoNullFast(); + stats.SetHasNoNullFast(); for (idx_t i = 0; i < count; i++) { auto source_idx = adata.sel->get_index(offset + i); auto target_idx = target_offset + i; - stats.statistics.UpdateNumericStats(sdata[source_idx]); + stats.UpdateNumericStats(sdata[source_idx]); 
tdata[target_idx] = sdata[source_idx]; } } @@ -238,7 +239,7 @@ struct StandardFixedSizeAppend { struct ListFixedSizeAppend { template - static void Append(SegmentStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, + static void Append(BaseStatistics &stats, data_ptr_t target, idx_t target_offset, UnifiedVectorFormat &adata, idx_t offset, idx_t count) { auto sdata = UnifiedVectorFormat::GetData(adata); auto tdata = reinterpret_cast(target); @@ -251,7 +252,7 @@ struct ListFixedSizeAppend { }; template -idx_t FixedSizeAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, +idx_t FixedSizeAppend(CompressionAppendState &append_state, ColumnSegment &segment, BaseStatistics &stats, UnifiedVectorFormat &data, idx_t offset, idx_t count) { D_ASSERT(segment.GetBlockOffset() == 0); @@ -265,7 +266,7 @@ idx_t FixedSizeAppend(CompressionAppendState &append_state, ColumnSegment &segme } template -idx_t FixedSizeFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { +idx_t FixedSizeFinalizeAppend(ColumnSegment &segment, BaseStatistics &stats) { return segment.count * sizeof(T); } diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index 48a68434b..46fac2934 100644 --- a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -10,6 +10,7 @@ #include "duckdb/main/config.hpp" #include "duckdb/storage/string_uncompressed.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/main/settings.hpp" #include "fsst.h" @@ -19,13 +20,6 @@ namespace duckdb { struct FSSTScanState; -typedef struct { - uint32_t dict_size; - uint32_t dict_end; - uint32_t bitpacking_width; - uint32_t fsst_symbol_table_offset; -} fsst_compression_header_t; - // Counts and offsets used during scanning/fetching // | ColumnSegment to be scanned / 
fetched from | // | untouched | bp align | unused d-values | to scan | bp align | untouched | @@ -44,12 +38,12 @@ struct FSSTStorage { static constexpr double ANALYSIS_SAMPLE_SIZE = 0.25; static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static bool StringAnalyze(AnalyzeState &state_p, const Vector &input); static idx_t StringFinalAnalyze(AnalyzeState &state_p); static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr analyze_state_p); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void Compress(CompressionState &state_p, const Vector &scan_vector); static void FinalizeCompress(CompressionState &state_p); static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); @@ -78,8 +72,8 @@ struct FSSTStorage { // Analyze //===--------------------------------------------------------------------===// struct FSSTAnalyzeState : public AnalyzeState { - explicit FSSTAnalyzeState(const CompressionInfo &info) - : AnalyzeState(info), count(0), fsst_string_total_size(0), empty_strings(0) { + explicit FSSTAnalyzeState(BlockManager &block_manager) + : AnalyzeState(block_manager), count(0), fsst_string_total_size(0), empty_strings(0) { } ~FSSTAnalyzeState() override { @@ -103,20 +97,20 @@ struct FSSTAnalyzeState : public AnalyzeState { unique_ptr FSSTStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { auto &storage_manager = col_data.GetStorageManager(); - if (storage_manager.GetStorageVersion() >= 5) { + if (StorageManager::TargetAtLeastVersion(StorageVersion::V1_3_0, storage_manager.GetStorageVersion())) { // dict_fsst introduced - disable fsst return nullptr; } - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool FSSTStorage::StringAnalyze(AnalyzeState &state_p, Vector 
&input, idx_t count) { +bool FSSTStorage::StringAnalyze(AnalyzeState &state_p, const Vector &input) { auto &state = state_p.Cast(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(vdata); + const auto count = input.size(); state.count += count; auto data = UnifiedVectorFormat::GetData(vdata); @@ -212,11 +206,11 @@ idx_t FSSTStorage::StringFinalAnalyze(AnalyzeState &state_p) { // Compress //===--------------------------------------------------------------------===// -class FSSTCompressionState : public CompressionState { +class FSSTCompressionState : public StandardCompressionState { public: - FSSTCompressionState(ColumnDataCheckpointData &checkpoint_data, const CompressionInfo &info) - : CompressionState(info), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_FSST)) { + explicit FSSTCompressionState(ColumnDataCheckpointData &checkpoint_data) + : StandardCompressionState(checkpoint_data, CompressionType::COMPRESSION_FSST), + stats_writer(checkpoint_data.GetType()) { CreateEmptySegment(); } @@ -233,19 +227,12 @@ class FSSTCompressionState : public CompressionState { last_fitting_size = 0; // Reset the pointers into the current segment - auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); - current_handle = buffer_manager.Pin(current_segment->block); - current_dictionary = FSSTStorage::GetDictionary(*current_segment, current_handle); - current_end_ptr = current_handle.GetDataMutable() + current_dictionary.end; + current_dictionary = FSSTStorage::GetDictionary(*current_segment, handle); + current_end_ptr = handle.GetDataMutable() + current_dictionary.end; } void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); + CreateAndPinNewSegment(); 
Reset(); } @@ -257,7 +244,7 @@ class FSSTCompressionState : public CompressionState { }; } - UncompressedStringStorage::UpdateStringStats(current_segment->stats, uncompressed_string); + stats_writer.Update(uncompressed_string); // Write string into dictionary current_dictionary.size += compressed_string_len; @@ -287,12 +274,12 @@ class FSSTCompressionState : public CompressionState { void AddNull() { AddEmptyStringInternal(); - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } void AddEmptyString() { AddEmptyStringInternal(); - UncompressedStringStorage::UpdateStringStats(current_segment->stats, ""); + stats_writer.Update(string_t("")); } size_t GetRequiredSize(size_t string_len) { @@ -327,8 +314,7 @@ class FSSTCompressionState : public CompressionState { void Flush(bool final = false) { auto segment_size = Finalize(); - auto &state = checkpoint_data.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), std::move(current_handle), segment_size); + FlushCurrentSegment(stats_writer, segment_size); if (!final) { CreateEmptySegment(); @@ -336,8 +322,8 @@ class FSSTCompressionState : public CompressionState { } idx_t Finalize() { - auto &buffer_manager = BufferManager::GetBufferManager(current_segment->db); - auto handle = buffer_manager.Pin(current_segment->block); + auto &buffer_manager = BufferManager::GetBufferManager(current_segment->GetDatabase()); + auto handle = buffer_manager.Pin(current_segment->GetBlockHandle()); if (current_dictionary.end != info.GetBlockSize()) { throw InternalException("dictionary end does not match the block size in FSSTCompressionState::Finalize"); } @@ -396,14 +382,10 @@ class FSSTCompressionState : public CompressionState { return total_size; } - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - // State regarding current segment - unique_ptr current_segment; - BufferHandle current_handle; StringDictionaryContainer current_dictionary; data_ptr_t 
current_end_ptr; + StatsWriter stats_writer; // Buffers and map for current segment vector index_buffer; @@ -420,7 +402,7 @@ class FSSTCompressionState : public CompressionState { unique_ptr FSSTStorage::InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr analyze_state_p) { auto &analyze_state = analyze_state_p->Cast(); - auto compression_state = make_uniq(checkpoint_data, analyze_state.info); + auto compression_state = make_uniq(checkpoint_data); if (analyze_state.fsst_encoder == nullptr) { throw InternalException("No encoder found during FSST compression"); @@ -434,7 +416,7 @@ unique_ptr FSSTStorage::InitCompression(ColumnDataCheckpointDa return std::move(compression_state); } -void FSSTStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void FSSTStorage::Compress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast(); // Get vector data @@ -447,6 +429,7 @@ void FSSTStorage::Compress(CompressionState &state_p, Vector &scan_vector, idx_t vector strings_in; size_t total_size = 0; idx_t total_count = 0; + const auto count = scan_vector.size(); for (idx_t i = 0; i < count; i++) { auto idx = vdata.sel->get_index(i); @@ -575,8 +558,8 @@ unique_ptr FSSTStorage::StringInitScan(const QueryContext &con auto block_size = segment.GetBlockSize(); auto string_block_limit = StringUncompressed::GetStringBlockLimit(block_size); auto state = make_uniq(string_block_limit); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - state->handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + state->handle = buffer_manager.Pin(segment.GetBlockHandle()); auto base_ptr = state->handle.GetDataMutable() + segment.GetBlockOffset(); state->duckdb_fsst_decoder = make_buffer(); @@ -587,7 +570,7 @@ unique_ptr FSSTStorage::StringInitScan(const QueryContext &con } state->duckdb_fsst_decoder_ptr = state->duckdb_fsst_decoder.get(); - 
const auto &stats = segment.stats.statistics; + const auto &stats = segment.GetStats(); if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } @@ -648,7 +631,7 @@ void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta bool enable_fsst_vectors; if (ALLOW_FSST_VECTORS) { - enable_fsst_vectors = Settings::Get(segment.db); + enable_fsst_vectors = Settings::Get(segment.GetDatabase()); } else { enable_fsst_vectors = false; } @@ -735,8 +718,8 @@ void FSSTStorage::Select(ColumnSegment &segment, ColumnScanState &state, idx_t v //===--------------------------------------------------------------------===// void FSSTStorage::StringFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); auto base_ptr = handle.GetDataMutable() + segment.GetBlockOffset(); auto base_data = data_ptr_cast(base_ptr + sizeof(fsst_compression_header_t)); auto dict = GetDictionary(segment, handle); diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp index fe1165786..854059058 100644 --- a/src/duckdb/src/storage/compression/numeric_constant.cpp +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -1,21 +1,197 @@ +#include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" -#include "duckdb/planner/filter/bloom_filter.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include 
"duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_conjunction_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/segment/uncompressed.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/selectivity_optional_filter.hpp" namespace duckdb { +static optional_ptr TryGetFunctionExpression(const Expression &expression) { + if (expression.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return nullptr; + } + return expression.Cast(); +} + +static bool IsSimpleFilterColumnRef(const Expression &expression) { + return expression.GetExpressionType() == ExpressionType::BOUND_REF || + expression.GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF; +} + +static bool TryComparisonFiltersNullValues(const BoundFunctionExpression &comparison, bool &filters_nulls, + bool &filters_valid_values) { + optional_ptr constant_expr; + auto &left = BoundComparisonExpression::Left(comparison); + auto &right = BoundComparisonExpression::Right(comparison); + auto comparison_type = comparison.GetExpressionType(); + if (IsSimpleFilterColumnRef(left) && right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { + constant_expr = right.Cast(); + } else if (IsSimpleFilterColumnRef(right) && left.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { + constant_expr = left.Cast(); + comparison_type = FlipComparisonExpression(comparison_type); + } else { + return false; + } + if (constant_expr->GetValue().IsNull()) { + switch (comparison_type) { + case ExpressionType::COMPARE_NOT_DISTINCT_FROM: + 
filters_valid_values = true; + break; + case ExpressionType::COMPARE_DISTINCT_FROM: + filters_nulls = true; + break; + default: + filters_nulls = true; + filters_valid_values = true; + break; + } + } else { + switch (comparison_type) { + case ExpressionType::COMPARE_DISTINCT_FROM: + filters_nulls = false; + break; + default: + filters_nulls = true; + break; + } + } + return true; +} + +static bool TryExpressionFiltersNullValues(const Expression &expression, bool &filters_nulls, + bool &filters_valid_values) { + filters_nulls = false; + filters_valid_values = false; + + if (expression.GetExpressionClass() == ExpressionClass::BOUND_CONJUNCTION) { + auto &conjunction = expression.Cast(); + if (conjunction.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { + for (auto &child : conjunction.GetChildren()) { + bool child_filters_nulls = false; + bool child_filters_valid_values = false; + if (!TryExpressionFiltersNullValues(*child, child_filters_nulls, child_filters_valid_values)) { + return false; + } + filters_nulls = filters_nulls || child_filters_nulls; + filters_valid_values = filters_valid_values || child_filters_valid_values; + } + return true; + } + if (conjunction.GetExpressionType() == ExpressionType::CONJUNCTION_OR) { + filters_nulls = true; + filters_valid_values = true; + for (auto &child : conjunction.GetChildren()) { + bool child_filters_nulls = false; + bool child_filters_valid_values = false; + if (!TryExpressionFiltersNullValues(*child, child_filters_nulls, child_filters_valid_values)) { + return false; + } + filters_nulls = filters_nulls && child_filters_nulls; + filters_valid_values = filters_valid_values && child_filters_valid_values; + } + return true; + } + return false; + } + + if (BoundComparisonExpression::IsComparison(expression.GetExpressionType())) { + auto &comparison = expression.Cast(); + return TryComparisonFiltersNullValues(comparison, filters_nulls, filters_valid_values); + } + + if (expression.GetExpressionClass() == 
ExpressionClass::BOUND_OPERATOR) { + auto &op = expression.Cast(); + if (op.GetChildren().size() != 1 || !IsSimpleFilterColumnRef(*op.GetChildren()[0])) { + return false; + } + switch (expression.GetExpressionType()) { + case ExpressionType::OPERATOR_IS_NULL: + filters_valid_values = true; + return true; + case ExpressionType::OPERATOR_IS_NOT_NULL: + filters_nulls = true; + return true; + default: + return false; + } + } + + auto func_expr = TryGetFunctionExpression(expression); + if (!func_expr) { + return false; + } + + auto &function_name = func_expr->Function().GetName(); + if (function_name == OptionalFilterScalarFun::NAME) { + return true; + } + if (function_name == BloomFilterScalarFun::NAME) { + if (!func_expr->BindInfo()) { + return true; + } + auto &data = func_expr->BindInfo()->Cast(); + if (!data.filter) { + return true; + } + filters_nulls = data.filters_null_values; + return true; + } + if (function_name == SelectivityOptionalFilterScalarFun::NAME) { + if (!func_expr->BindInfo()) { + return false; + } + auto &data = func_expr->BindInfo()->Cast(); + if (!data.child_filter_expr) { + return false; + } + return TryExpressionFiltersNullValues(*data.child_filter_expr, filters_nulls, filters_valid_values); + } + if (function_name == PerfectHashJoinScalarFun::NAME) { + if (!func_expr->BindInfo()) { + return true; + } + auto &data = func_expr->BindInfo()->Cast(); + if (!data.executor) { + return true; + } + filters_nulls = true; + return true; + } + if (function_name == PrefixRangeScalarFun::NAME) { + if (!func_expr->BindInfo()) { + return true; + } + auto &data = func_expr->BindInfo()->Cast(); + if (!data.filter || !data.filter->IsInitialized()) { + return true; + } + filters_nulls = true; + return true; + } + if (function_name == DynamicFilterScalarFun::NAME) { + if (!func_expr->BindInfo()) { + return true; + } + auto &data = func_expr->BindInfo()->Cast(); + if (!data.filter_data || !data.filter_data->initialized.load()) { + return true; + } + filters_nulls 
= true; + return true; + } + return false; +} + //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// @@ -27,7 +203,7 @@ unique_ptr ConstantInitScan(const QueryContext &context, Colum // Scan Partial //===--------------------------------------------------------------------===// void ConstantFillFunctionValidity(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { - auto &stats = segment.stats.statistics; + const auto &stats = segment.GetStats(); if (stats.CanHaveNull()) { auto &mask = FlatVector::ValidityMutable(result); for (idx_t i = 0; i < count; i++) { @@ -38,7 +214,7 @@ void ConstantFillFunctionValidity(ColumnSegment &segment, Vector &result, idx_t template void ConstantFillFunction(ColumnSegment &segment, Vector &result, idx_t start_idx, idx_t count) { - auto &nstats = segment.stats.statistics; + const auto &nstats = segment.GetStats(); auto data = FlatVector::GetDataMutable(result); auto constant_value = NumericStats::GetMin(nstats); @@ -62,7 +238,7 @@ void ConstantScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t s // Scan base data //===--------------------------------------------------------------------===// void ConstantScanFunctionValidity(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - auto &stats = segment.stats.statistics; + const auto &stats = segment.GetStats(); if (stats.CanHaveNull()) { if (result.GetType().InternalType() == PhysicalType::STRUCT || result.GetVectorType() == VectorType::CONSTANT_VECTOR) { @@ -76,7 +252,7 @@ void ConstantScanFunctionValidity(ColumnSegment &segment, ColumnScanState &state template void ConstantScanFunction(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - auto &nstats = segment.stats.statistics; + const auto &nstats = segment.GetStats(); auto data = FlatVector::GetDataMutable(result); data[0] = 
NumericStats::GetMin(nstats); @@ -119,61 +295,12 @@ void ConstantFun::FiltersNullValues(const LogicalType &type, const TableFilter & filters_nulls = false; filters_valid_values = false; - switch (filter.filter_type) { - case TableFilterType::OPTIONAL_FILTER: { - auto &opt_filter = filter.Cast(); - return opt_filter.FiltersNullValues(type, filters_nulls, filters_valid_values, filter_state); - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_or = filter.Cast(); - auto &state = filter_state.Cast(); - filters_nulls = true; - filters_valid_values = true; - for (idx_t child_idx = 0; child_idx < conjunction_or.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction_or.child_filters[child_idx]; - auto &child_state = *state.child_states[child_idx]; - bool child_filters_nulls, child_filters_valid_values; - FiltersNullValues(type, child_filter, child_filters_nulls, child_filters_valid_values, child_state); - filters_nulls = filters_nulls && child_filters_nulls; - filters_valid_values = filters_valid_values && child_filters_valid_values; - } - break; - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - auto &state = filter_state.Cast(); - filters_nulls = false; - filters_valid_values = false; - for (idx_t child_idx = 0; child_idx < conjunction_and.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction_and.child_filters[child_idx]; - auto &child_state = *state.child_states[child_idx]; - bool child_filters_nulls, child_filters_valid_values; - FiltersNullValues(type, child_filter, child_filters_nulls, child_filters_valid_values, child_state); - filters_nulls = filters_nulls || child_filters_nulls; - filters_valid_values = filters_valid_values || child_filters_valid_values; - } - break; - } - case TableFilterType::EXPRESSION_FILTER: { - auto &expr_filter = filter.Cast(); - auto &state = filter_state.Cast(); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, 
"ConstantFun::FiltersNullValues"); + auto &state = filter_state.Cast(); + if (!TryExpressionFiltersNullValues(*expr_filter.expr, filters_nulls, filters_valid_values)) { Value val(type); filters_nulls = !expr_filter.EvaluateWithConstant(*state.executor, val); filters_valid_values = false; - break; - } - case TableFilterType::BLOOM_FILTER: { - auto &bf = filter.Cast(); - filters_nulls = bf.FiltersNullValues(); - break; - } - case TableFilterType::PERFECT_HASH_JOIN_FILTER: - case TableFilterType::PREFIX_RANGE_FILTER: { - filters_nulls = true; - break; - } - default: - throw InternalException("FIXME: unsupported type for filter selection in validity select"); } } @@ -184,7 +311,7 @@ void ConstantFilterValidity(ColumnSegment &segment, ColumnScanState &state, idx_ bool filters_nulls, filters_valid_values; ConstantFun::FiltersNullValues(result.GetType(), filter, filters_nulls, filters_valid_values, filter_state); - auto &stats = segment.stats.statistics; + const auto &stats = segment.GetStats(); if (stats.CanHaveNull()) { // all values are NULL if (filters_nulls) { diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index 5be44a7d3..da7bde835 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -2,6 +2,7 @@ #include "duckdb/function/compression/compression.hpp" #include "duckdb/function/compression_function.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" #include "duckdb/storage/table/column_segment.hpp" #include "duckdb/storage/table/scan_state.hpp" @@ -83,7 +84,7 @@ struct RLEState { template struct RLEAnalyzeState : public AnalyzeState { - explicit RLEAnalyzeState(const CompressionInfo &info) : AnalyzeState(info) { + explicit RLEAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager) { } RLEState state; @@ -91,18 +92,17 
@@ struct RLEAnalyzeState : public AnalyzeState { template unique_ptr RLEInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq>(info); + return make_uniq>(col_data.GetBlockManager()); } template -bool RLEAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool RLEAnalyze(AnalyzeState &state, const Vector &input) { auto &rle_state = state.template Cast>(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(vdata); auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < input.size(); i++) { auto idx = vdata.sel->get_index(i); rle_state.state.Update(data, vdata.validity, idx); } @@ -123,7 +123,15 @@ struct RLEConstants { }; template -struct RLECompressState : public CompressionState { +struct RLECompressState : public StandardCompressionState { + explicit RLECompressState(ColumnDataCheckpointData &checkpoint_data_p) + : StandardCompressionState(checkpoint_data_p, CompressionType::COMPRESSION_RLE) { + CreateEmptySegment(); + + state.dataptr = (void *)this; + max_rle_count = MaxRLECount(); + } + struct RLEWriter { template static void Operation(VALUE_TYPE value, rle_count_t count, void *dataptr, bool is_null) { @@ -137,31 +145,17 @@ struct RLECompressState : public CompressionState { return AlignValueFloor((info.GetBlockSize() - RLEConstants::RLE_HEADER_SIZE) / entry_size); } - RLECompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info) - : CompressionState(info), checkpoint_data(checkpoint_data_p), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_RLE)) { - CreateEmptySegment(); - - state.dataptr = (void *)this; - max_rle_count = MaxRLECount(); - } - void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto column_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), 
info.GetBlockManager()); - current_segment = std::move(column_segment); - - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); } void Append(UnifiedVectorFormat &vdata, idx_t count) { auto data = UnifiedVectorFormat::GetData(vdata); for (idx_t i = 0; i < count; i++) { auto idx = vdata.sel->get_index(i); + if (WRITE_STATISTICS && !vdata.validity.RowIsValid(idx)) { + stats_writer.SetHasNull(); + } state.template Update::RLEWriter>(data, vdata.validity, idx); } } @@ -178,10 +172,9 @@ struct RLECompressState : public CompressionState { // update meta data if (WRITE_STATISTICS) { if (!is_null) { - current_segment->stats.statistics.SetHasNoNullFast(); - current_segment->stats.statistics.UpdateNumericStats(value); + stats_writer.Update(value); } else { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } } current_segment->count += count; @@ -209,10 +202,8 @@ struct RLECompressState : public CompressionState { memmove(data_ptr + aligned_rle_offset, data_ptr + original_rle_offset, counts_size); // store the final RLE offset within the segment Store(aligned_rle_offset, data_ptr); - handle.Destroy(); - auto &state = checkpoint_data.GetCheckpointState(); - state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); + FlushCurrentSegment(stats_writer, total_segment_size); } void Finalize() { @@ -222,12 +213,8 @@ struct RLECompressState : public CompressionState { current_segment.reset(); } - ColumnDataCheckpointData &checkpoint_data; - const CompressionFunction &function; - unique_ptr current_segment; - BufferHandle handle; - RLEState state; + StatsWriter stats_writer; idx_t entry_count = 0; idx_t max_rle_count; }; @@ -235,16 +222,16 @@ struct RLECompressState : public CompressionState { template unique_ptr RLEInitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr state) { - return make_uniq>(checkpoint_data, 
state->info); + return make_uniq>(checkpoint_data); } template -void RLECompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void RLECompress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast>(); UnifiedVectorFormat vdata; scan_vector.ToUnifiedFormat(vdata); - state.Append(vdata, count); + state.Append(vdata, scan_vector.size()); } template @@ -259,8 +246,8 @@ void RLEFinalizeCompress(CompressionState &state_p) { template struct RLEScanState : public SegmentScanState { explicit RLEScanState(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + handle = buffer_manager.Pin(segment.GetBlockHandle()); entry_pos = 0; position_in_entry = 0; rle_count_offset = diff --git a/src/duckdb/src/storage/compression/roaring/analyze.cpp b/src/duckdb/src/storage/compression/roaring/analyze.cpp index 15e3a1dda..7080c4251 100644 --- a/src/duckdb/src/storage/compression/roaring/analyze.cpp +++ b/src/duckdb/src/storage/compression/roaring/analyze.cpp @@ -39,8 +39,8 @@ static unsafe_unique_array CreateBitmaskTable() { //===--------------------------------------------------------------------===// // Analyze //===--------------------------------------------------------------------===// -RoaringAnalyzeState::RoaringAnalyzeState(const CompressionInfo &info) - : AnalyzeState(info), bitmask_table(CreateBitmaskTable()) { +RoaringAnalyzeState::RoaringAnalyzeState(BlockManager &block_manager) + : AnalyzeState(block_manager), bitmask_table(CreateBitmaskTable()) { } void RoaringAnalyzeState::HandleByte(RoaringAnalyzeState &state, uint8_t array_index) { @@ -157,17 +157,19 @@ void RoaringAnalyzeState::FlushContainer() { } template <> -void RoaringAnalyzeState::Analyze(Vector &input, idx_t count) { +void RoaringAnalyzeState::Analyze(const Vector &input) { auto &self = *this; - 
RoaringStateAppender::AppendVector(self, input, count); - total_count += count; + RoaringStateAppender::AppendVector(self, input); + total_count += input.size(); } template <> -void RoaringAnalyzeState::Analyze(Vector &input, idx_t count) { +void RoaringAnalyzeState::Analyze(const Vector &input) { auto &self = *this; + const auto count = input.size(); input.Flatten(); Vector bitpacked_vector(LogicalType::UBIGINT, count); + FlatVector::SetSize(bitpacked_vector, count); auto &bitpacked_vector_validity = FlatVector::ValidityMutable(bitpacked_vector); bitpacked_vector_validity.EnsureWritable(); auto dst = data_ptr_cast(bitpacked_vector_validity.GetData()); @@ -181,7 +183,7 @@ void RoaringAnalyzeState::Analyze(Vector &input, idx_t count // Bitpack the booleans, so they can be fed through the current compression code, with the same format as a validity // mask. - RoaringStateAppender::AppendVector(self, bitpacked_vector, count); + RoaringStateAppender::AppendVector(self, bitpacked_vector); total_count += count; } diff --git a/src/duckdb/src/storage/compression/roaring/common.cpp b/src/duckdb/src/storage/compression/roaring/common.cpp index f2cbbadcb..e73d0e24c 100644 --- a/src/duckdb/src/storage/compression/roaring/common.cpp +++ b/src/duckdb/src/storage/compression/roaring/common.cpp @@ -168,18 +168,19 @@ void SetInvalidRange(ValidityMask &result, idx_t start, idx_t end) { unique_ptr RoaringInitAnalyze(ColumnData &col_data, PhysicalType type) { // check if the storage version we are writing to supports roaring const auto storage_version = col_data.GetStorageManager().GetStorageVersion(); - if (storage_version < 4 || (type == PhysicalType::BOOL && storage_version < 7)) { + // before: serialization version 4 and ser version 7 + if (StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage_version) || + (type == PhysicalType::BOOL && StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, storage_version))) { // compatibility mode with old versions - disable 
roaring return nullptr; } - CompressionInfo info(col_data.GetBlockManager()); - auto state = make_uniq(info); + auto state = make_uniq(col_data.GetBlockManager()); return std::move(state); } template -bool RoaringAnalyze(AnalyzeState &state, Vector &input, idx_t count) { +bool RoaringAnalyze(AnalyzeState &state, const Vector &input) { auto &analyze_state = state.Cast(); - analyze_state.Analyze(input, count); + analyze_state.Analyze(input); return true; } @@ -198,9 +199,9 @@ unique_ptr RoaringInitCompression(ColumnDataCheckpointData &ch } template -void RoaringCompress(CompressionState &state_p, Vector &scan_vector, idx_t count) { +void RoaringCompress(CompressionState &state_p, const Vector &scan_vector) { auto &state = state_p.Cast(); - state.Compress(scan_vector, count); + state.Compress(scan_vector); } void RoaringFinalizeCompress(CompressionState &state_p) { diff --git a/src/duckdb/src/storage/compression/roaring/compress.cpp b/src/duckdb/src/storage/compression/roaring/compress.cpp index 06a433538..33b11673b 100644 --- a/src/duckdb/src/storage/compression/roaring/compress.cpp +++ b/src/duckdb/src/storage/compression/roaring/compress.cpp @@ -193,10 +193,9 @@ void ContainerCompressionState::Reset() { //===--------------------------------------------------------------------===// RoaringCompressState::RoaringCompressState(ColumnDataCheckpointData &checkpoint_data, unique_ptr analyze_state_p) - : CompressionState(analyze_state_p->info), owned_analyze_state(std::move(analyze_state_p)), - analyze_state(owned_analyze_state->Cast()), container_state(), - container_metadata(analyze_state.container_metadata), checkpoint_data(checkpoint_data), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ROARING)) { + : StandardCompressionState(checkpoint_data, CompressionType::COMPRESSION_ROARING), + owned_analyze_state(std::move(analyze_state_p)), analyze_state(owned_analyze_state->Cast()), + container_state(), 
container_metadata(analyze_state.container_metadata) { CreateEmptySegment(); total_count = 0; InitializeContainer(); @@ -288,22 +287,13 @@ void RoaringCompressState::InitializeContainer() { } void RoaringCompressState::CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); - current_segment = std::move(compressed_segment); - - auto &buffer_manager = BufferManager::GetBufferManager(db); - handle = buffer_manager.Pin(current_segment->block); + CreateAndPinNewSegment(); data_ptr = handle.GetDataMutable(); data_ptr += sizeof(idx_t); metadata_ptr = handle.GetDataMutable() + info.GetBlockSize(); } void RoaringCompressState::FlushSegment() { - auto &state = checkpoint_data.GetCheckpointState(); auto base_ptr = handle.GetDataMutable(); // +======================================+ // |x|ddddddddddddddd||mmm| | @@ -341,7 +331,11 @@ void RoaringCompressState::FlushSegment() { Store(metadata_start, handle.GetDataMutable()); auto total_segment_size = sizeof(idx_t) + data_size + metadata_size; - state.FlushSegment(std::move(current_segment), std::move(handle), total_segment_size); + if (GetType().id() == LogicalTypeId::BOOLEAN) { + FlushCurrentSegment(bool_stats_writer, total_segment_size); + } else { + FlushCurrentSegment(stats_writer, total_segment_size); + } } void RoaringCompressState::Finalize() { @@ -414,10 +408,10 @@ void RoaringCompressState::FlushContainer() { bool has_nulls = container_state.null_count != 0; bool has_non_nulls = container_state.null_count != container_state.appended_count; if (has_nulls || container_state.uncompressed) { - current_segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); } if (has_non_nulls || container_state.uncompressed) { - current_segment->stats.statistics.SetHasNoNullFast(); + stats_writer.SetHasValid(); } current_segment->count 
+= container_state.appended_count; container_state.Reset(); @@ -486,17 +480,19 @@ void RoaringCompressState::Flush(RoaringCompressState &state) { state.NextContainer(); } template <> -void RoaringCompressState::Compress(Vector &input, idx_t count) { +void RoaringCompressState::Compress(const Vector &input) { auto &self = *this; - RoaringStateAppender::AppendVector(self, input, count); + RoaringStateAppender::AppendVector(self, input); } template <> -void RoaringCompressState::Compress(Vector &input, idx_t count) { +void RoaringCompressState::Compress(const Vector &input) { auto &self = *this; + const auto count = input.size(); input.Flatten(); const bool *src = FlatVector::GetData(input); Vector bitpacked_vector(LogicalType::UBIGINT, count); + FlatVector::SetSize(bitpacked_vector, count); auto &bitpacked_vector_validity = FlatVector::ValidityMutable(bitpacked_vector); bitpacked_vector_validity.EnsureWritable(); const auto dst = data_ptr_cast(bitpacked_vector_validity.GetData()); @@ -505,11 +501,11 @@ void RoaringCompressState::Compress(Vector &input, idx_t cou // Bitpack the booleans, so they can be fed through the current compression code, with the same format as a validity // mask. 
if (validity.CannotHaveNull()) { - BitPackBooleans(dst, src, count, &validity, &this->current_segment->stats.statistics); + BitPackBooleans(dst, src, count, &validity, &bool_stats_writer); } else { - BitPackBooleans(dst, src, count, &validity, &this->current_segment->stats.statistics); + BitPackBooleans(dst, src, count, &validity, &bool_stats_writer); } - RoaringStateAppender::AppendVector(self, bitpacked_vector, count); + RoaringStateAppender::AppendVector(self, bitpacked_vector); } } // namespace roaring diff --git a/src/duckdb/src/storage/compression/roaring/scan.cpp b/src/duckdb/src/storage/compression/roaring/scan.cpp index 7bbf46d9a..09f7f6c7f 100644 --- a/src/duckdb/src/storage/compression/roaring/scan.cpp +++ b/src/duckdb/src/storage/compression/roaring/scan.cpp @@ -190,8 +190,8 @@ void BitsetContainerScanState::Verify() const { } RoaringScanState::RoaringScanState(ColumnSegment &segment) : segment(segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + handle = buffer_manager.Pin(segment.GetBlockHandle()); auto segment_size = segment.SegmentSize(); auto segment_block_offset = segment.GetBlockOffset(); if (segment_block_offset >= segment_size) { diff --git a/src/duckdb/src/storage/compression/standard_compression_state.cpp b/src/duckdb/src/storage/compression/standard_compression_state.cpp new file mode 100644 index 000000000..f187d18f0 --- /dev/null +++ b/src/duckdb/src/storage/compression/standard_compression_state.cpp @@ -0,0 +1,36 @@ +#include "duckdb/storage/compression/standard_compression_state.hpp" +#include "duckdb/storage/table/column_data_checkpointer.hpp" + +namespace duckdb { + +CompressionState::CompressionState(ColumnDataCheckpointData &checkpoint_data_p, CompressionType compression_type) + : checkpoint_data(checkpoint_data_p), 
function(checkpoint_data.GetCompressionFunction(compression_type)), + block_manager(checkpoint_data.GetStorageManager().GetBlockManager()), info(block_manager) { +} + +unique_ptr CompressionState::CreateNewSegment() { + return ColumnSegment::CreateTransientSegment(checkpoint_data.GetDatabase(), function, checkpoint_data.GetType(), + info.GetBlockSize(), info.GetBlockManager()); +} + +const LogicalType &CompressionState::GetType() { + return checkpoint_data.GetType(); +} + +StandardCompressionState::~StandardCompressionState() { +} + +void StandardCompressionState::CreateAndPinNewSegment() { + auto compressed_segment = CreateNewSegment(); + current_segment = std::move(compressed_segment); + + auto &buffer_manager = BufferManager::GetBufferManager(current_segment->GetDatabase()); + handle = buffer_manager.Pin(current_segment->GetBlockHandle()); +} + +void StandardCompressionState::FlushCurrentSegment(idx_t segment_size) { + auto &state = checkpoint_data.GetCheckpointState(); + state.FlushSegment(std::move(current_segment), std::move(handle), segment_size); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/compression/string_uncompressed.cpp b/src/duckdb/src/storage/compression/string_uncompressed.cpp index 825a93416..7e32dafb6 100644 --- a/src/duckdb/src/storage/compression/string_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/string_uncompressed.cpp @@ -23,8 +23,8 @@ UncompressedStringSegmentState::~UncompressedStringSegmentState() { // Analyze //===--------------------------------------------------------------------===// struct StringAnalyzeState : public AnalyzeState { - explicit StringAnalyzeState(const CompressionInfo &info) - : AnalyzeState(info), count(0), total_string_size(0), overflow_strings(0) { + explicit StringAnalyzeState(BlockManager &block_manager) + : AnalyzeState(block_manager), count(0), total_string_size(0), overflow_strings(0) { } idx_t count; @@ -33,15 +33,15 @@ struct StringAnalyzeState : public AnalyzeState { }; 
unique_ptr UncompressedStringStorage::StringInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool UncompressedStringStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { +bool UncompressedStringStorage::StringAnalyze(AnalyzeState &state_p, const Vector &input) { auto &state = state_p.Cast(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(vdata); + const auto count = input.size(); state.count += count; auto data = UnifiedVectorFormat::GetData(vdata); for (idx_t i = 0; i < count; i++) { @@ -66,11 +66,11 @@ idx_t UncompressedStringStorage::StringFinalAnalyze(AnalyzeState &state_p) { // Scan //===--------------------------------------------------------------------===// void UncompressedStringInitPrefetch(ColumnSegment &segment, PrefetchState &prefetch_state) { - prefetch_state.AddBlock(segment.block); + prefetch_state.AddBlock(segment.GetBlockHandle()); auto segment_state = segment.GetSegmentState(); if (segment_state) { auto &state = segment_state->Cast(); - auto &block_manager = segment.block->GetBlockManager(); + auto &block_manager = segment.GetBlockHandle()->GetBlockManager(); for (auto &block_id : state.on_disk_blocks) { auto block_handle = state.GetHandle(block_manager, block_id); prefetch_state.AddBlock(block_handle); @@ -81,8 +81,8 @@ void UncompressedStringInitPrefetch(ColumnSegment &segment, PrefetchState &prefe unique_ptr UncompressedStringStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + result->handle = buffer_manager.Pin(segment.GetBlockHandle()); return std::move(result); } @@ -144,13 +144,13 @@ void 
UncompressedStringStorage::Select(ColumnSegment &segment, ColumnScanState & // Fetch //===--------------------------------------------------------------------===// BufferHandle &ColumnFetchState::GetOrInsertHandle(ColumnSegment &segment) { - auto primary_id = segment.block->BlockId(); + auto primary_id = segment.GetBlockHandle()->BlockId(); auto entry = handles.find(primary_id); if (entry == handles.end()) { // not pinned yet: pin it - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); auto pinned_entry = handles.insert(make_pair(primary_id, std::move(handle))); return pinned_entry.first->second; } else { @@ -198,9 +198,9 @@ void SerializedStringSegmentState::Serialize(Serializer &serializer) const { unique_ptr UncompressedStringStorage::StringInitSegment(ColumnSegment &segment, block_id_t block_id, optional_ptr segment_state) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); if (block_id == INVALID_BLOCK) { - auto handle = buffer_manager.Pin(segment.block); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); StringDictionaryContainer dictionary; dictionary.size = 0; dictionary.end = UnsafeNumericCast(segment.SegmentSize()); @@ -214,16 +214,16 @@ UncompressedStringStorage::StringInitSegment(ColumnSegment &segment, block_id_t return std::move(result); } -idx_t UncompressedStringStorage::FinalizeAppend(ColumnSegment &segment, SegmentStatistics &) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); +idx_t UncompressedStringStorage::FinalizeAppend(ColumnSegment &segment, BaseStatistics &) { + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto 
handle = buffer_manager.Pin(segment.GetBlockHandle()); auto dict = GetDictionary(segment, handle); D_ASSERT(dict.end == segment.SegmentSize()); // compute the total size required to store this segment auto offset_size = DICTIONARY_HEADER_SIZE + segment.count * sizeof(int32_t); auto total_size = offset_size + dict.size; - CompressionInfo info(segment.block->GetBlockManager()); + CompressionInfo info(segment.GetBlockHandle()->GetBlockManager()); if (total_size >= info.GetCompactionFlushLimit()) { // the block is full enough, don't bother moving around the dictionary return segment.SegmentSize(); @@ -334,7 +334,7 @@ void UncompressedStringStorage::WriteStringMemory(ColumnSegment &segment, string shared_ptr block; BufferHandle handle; - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); auto &state = segment.GetSegmentState()->Cast(); // check if the string fits in the current block if (!state.head || state.head->offset + total_length >= state.head->size) { @@ -369,7 +369,7 @@ void UncompressedStringStorage::WriteStringMemory(ColumnSegment &segment, string string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, Vector &result, block_id_t block, int32_t offset) { - auto &buffer_manager = segment.block->GetMemory().GetBufferManager(); + auto &buffer_manager = segment.GetBlockHandle()->GetMemory().GetBufferManager(); auto &state = segment.GetSegmentState()->Cast(); D_ASSERT(block != INVALID_BLOCK); @@ -378,7 +378,7 @@ string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, V if (block < MAXIMUM_BLOCK) { // read the overflow string from disk // pin the initial handle and read the length - auto block_handle = state.GetHandle(segment.block->GetBlockManager(), block); + auto block_handle = state.GetHandle(segment.GetBlockHandle()->GetBlockManager(), block); auto handle = buffer_manager.Pin(block_handle); // read header @@ -411,7 
+411,7 @@ string_t UncompressedStringStorage::ReadOverflowString(ColumnSegment &segment, V if (remaining > 0) { // read the next block block_id_t next_block = Load(handle.GetDataMutable() + offset); - block_handle = state.GetHandle(segment.block->GetBlockManager(), next_block); + block_handle = state.GetHandle(segment.GetBlockHandle()->GetBlockManager(), next_block); handle = buffer_manager.Pin(block_handle); offset = 0; } diff --git a/src/duckdb/src/storage/compression/validity_uncompressed.cpp b/src/duckdb/src/storage/compression/validity_uncompressed.cpp index 7546731c7..4c374b1ed 100644 --- a/src/duckdb/src/storage/compression/validity_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/validity_uncompressed.cpp @@ -175,20 +175,19 @@ const validity_t ValidityUncompressed::UPPER_MASKS[] = {0x0, // Analyze //===--------------------------------------------------------------------===// struct ValidityAnalyzeState : public AnalyzeState { - explicit ValidityAnalyzeState(const CompressionInfo &info) : AnalyzeState(info), count(0) { + explicit ValidityAnalyzeState(BlockManager &block_manager) : AnalyzeState(block_manager), count(0) { } idx_t count; }; unique_ptr ValidityInitAnalyze(ColumnData &col_data, PhysicalType type) { - CompressionInfo info(col_data.GetBlockManager()); - return make_uniq(info); + return make_uniq(col_data.GetBlockManager()); } -bool ValidityAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { +bool ValidityAnalyze(AnalyzeState &state_p, const Vector &input) { auto &state = state_p.Cast(); - state.count += count; + state.count += input.size(); return true; } @@ -207,9 +206,9 @@ struct ValidityScanState : public SegmentScanState { unique_ptr ValidityInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); - result->block_id = segment.block->BlockId(); + auto &buffer_manager = 
BufferManager::GetBufferManager(segment.GetDatabase()); + result->handle = buffer_manager.Pin(segment.GetBlockHandle()); + result->block_id = segment.GetBlockHandle()->BlockId(); return std::move(result); } @@ -462,7 +461,7 @@ void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t s auto &scan_state = state.scan_state->Cast(); auto buffer_ptr = scan_state.handle.GetDataMutable() + segment.GetBlockOffset(); - D_ASSERT(scan_state.block_id == segment.block->BlockId()); + D_ASSERT(scan_state.block_id == segment.GetBlockHandle()->BlockId()); auto &result_mask = FlatVector::ValidityMutable(result); ValidityUncompressed::UnalignedScan(buffer_ptr, segment.count, start, result_mask, result_offset, scan_count); } @@ -475,7 +474,7 @@ void ValidityScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_cou auto &scan_state = state.scan_state->Cast(); auto buffer_ptr = scan_state.handle.GetDataMutable() + segment.GetBlockOffset(); - D_ASSERT(scan_state.block_id == segment.block->BlockId()); + D_ASSERT(scan_state.block_id == segment.GetBlockHandle()->BlockId()); auto &result_mask = FlatVector::ValidityMutable(result); ValidityUncompressed::AlignedScan(buffer_ptr, start, result_mask, scan_count); } else { @@ -511,8 +510,8 @@ void ValiditySelect(ColumnSegment &segment, ColumnScanState &state, idx_t, Vecto //===--------------------------------------------------------------------===// void ValidityFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { D_ASSERT(row_id >= 0 && row_id < row_t(segment.count)); - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); auto dataptr = handle.GetDataMutable() + segment.GetBlockOffset(); ValidityMask mask(reinterpret_cast(dataptr), segment.count); auto 
&result_mask = FlatVector::ValidityMutable(result); @@ -525,25 +524,24 @@ void ValidityFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row // Append //===--------------------------------------------------------------------===// static unique_ptr ValidityInitAppend(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); return make_uniq(std::move(handle)); } unique_ptr ValidityInitSegment(ColumnSegment &segment, block_id_t block_id, optional_ptr segment_state) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); if (block_id == INVALID_BLOCK) { - auto handle = buffer_manager.Pin(segment.block); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); memset(handle.GetDataMutable(), 0xFF, segment.SegmentSize()); } return nullptr; } -idx_t ValidityAppend(CompressionAppendState &append_state, ColumnSegment &segment, SegmentStatistics &stats, +idx_t ValidityAppend(CompressionAppendState &append_state, ColumnSegment &segment, BaseStatistics &validity_stats, UnifiedVectorFormat &data, idx_t offset, idx_t vcount) { D_ASSERT(segment.GetBlockOffset() == 0); - auto &validity_stats = stats.statistics; auto max_tuples = segment.SegmentSize() / ValidityMask::STANDARD_MASK_SIZE * STANDARD_VECTOR_SIZE; idx_t append_count = MinValue(vcount, max_tuples - segment.count); @@ -568,15 +566,15 @@ idx_t ValidityAppend(CompressionAppendState &append_state, ColumnSegment &segmen return append_count; } -idx_t ValidityFinalizeAppend(ColumnSegment &segment, SegmentStatistics &stats) { +idx_t ValidityFinalizeAppend(ColumnSegment &segment, BaseStatistics &stats) { return ((segment.count + STANDARD_VECTOR_SIZE - 1) / STANDARD_VECTOR_SIZE) * 
ValidityMask::STANDARD_MASK_SIZE; } void ValidityRevertAppend(ColumnSegment &segment, idx_t new_count) { idx_t start_bit = new_count; - auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - auto handle = buffer_manager.Pin(segment.block); + auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); + auto handle = buffer_manager.Pin(segment.GetBlockHandle()); idx_t revert_start; if (start_bit % 8 != 0) { // handle sub-bit stuff (yay) diff --git a/src/duckdb/src/storage/compression/zstd.cpp b/src/duckdb/src/storage/compression/zstd.cpp index 439974466..a7ec8a4cc 100644 --- a/src/duckdb/src/storage/compression/zstd.cpp +++ b/src/duckdb/src/storage/compression/zstd.cpp @@ -1,3 +1,4 @@ +#include "duckdb.hpp" #include "duckdb/common/constants.hpp" #include "duckdb/common/allocator.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -69,12 +70,12 @@ static idx_t GetVectorMetadataSize(idx_t vector_count) { struct ZSTDStorage { static unique_ptr StringInitAnalyze(ColumnData &col_data, PhysicalType type); - static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); + static bool StringAnalyze(AnalyzeState &state_p, const Vector &input); static idx_t StringFinalAnalyze(AnalyzeState &state_p); static unique_ptr InitCompression(ColumnDataCheckpointData &checkpoint_data, unique_ptr analyze_state_p); - static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); + static void Compress(CompressionState &state_p, const Vector &scan_vector); static void FinalizeCompress(CompressionState &state_p); static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); @@ -103,7 +104,8 @@ struct ZSTDStorage { struct ZSTDAnalyzeState : public AnalyzeState { public: - ZSTDAnalyzeState(CompressionInfo &info, DBConfig &config) : AnalyzeState(info), config(config), context(nullptr) { + ZSTDAnalyzeState(BlockManager &block_manager, DBConfig &config) + : AnalyzeState(block_manager), 
config(config), context(nullptr) { context = duckdb_zstd::ZSTD_createCCtx(); } ~ZSTDAnalyzeState() override { @@ -143,23 +145,23 @@ unique_ptr ZSTDStorage::StringInitAnalyze(ColumnData &col_data, Ph //! Can't use ZSTD in in-memory environment return nullptr; } - if (storage.GetStorageVersion() < 4) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage.GetStorageVersion())) { // compatibility mode with old versions - disable zstd return nullptr; } - CompressionInfo info(col_data.GetBlockManager()); auto &data_table_info = col_data.info; auto &attached_db = data_table_info.GetDB(); auto &config = DBConfig::Get(attached_db); - return make_uniq(info, config); + return make_uniq(col_data.GetBlockManager(), config); } -bool ZSTDStorage::StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count) { +bool ZSTDStorage::StringAnalyze(AnalyzeState &state_p, const Vector &input) { auto &state = state_p.Cast(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(vdata); + const auto count = input.size(); auto data = UnifiedVectorFormat::GetData(vdata); for (idx_t i = 0; i < count; i++) { auto idx = vdata.sel->get_index(i); @@ -231,12 +233,12 @@ class ZSTDCompressionState : public CompressionState { public: explicit ZSTDCompressionState(ColumnDataCheckpointData &checkpoint_data, unique_ptr &&analyze_state_p) - : CompressionState(analyze_state_p->info), analyze_state(std::move(analyze_state_p)), - checkpoint_data(checkpoint_data), + : CompressionState(checkpoint_data, CompressionType::COMPRESSION_ZSTD), + analyze_state(std::move(analyze_state_p)), partial_block_manager(checkpoint_data.GetCheckpointState().GetPartialBlockManager()), - function(checkpoint_data.GetCompressionFunction(CompressionType::COMPRESSION_ZSTD)), - total_tuple_count(analyze_state->count), total_vector_count(GetVectorCount(total_tuple_count)), - total_segment_count(analyze_state->segment_count), vectors_per_segment(analyze_state->vectors_per_segment) { + 
stats_writer(checkpoint_data.GetType()), total_tuple_count(analyze_state->count), + total_vector_count(GetVectorCount(total_tuple_count)), total_segment_count(analyze_state->segment_count), + vectors_per_segment(analyze_state->vectors_per_segment) { segment_count = 0; vector_count = 0; vector_state.tuple_count = 0; @@ -435,7 +437,7 @@ class ZSTDCompressionState : public CompressionState { void AddString(const string_t &string) { AddStringInternal(string); - UncompressedStringStorage::UpdateStringStats(buffer_collection.segment->stats, string); + stats_writer.Update(string); } void NewPage(bool additional_data_page = false) { @@ -534,14 +536,12 @@ class ZSTDCompressionState : public CompressionState { } void CreateEmptySegment() { - auto &db = checkpoint_data.GetDatabase(); - auto &type = checkpoint_data.GetType(); - auto compressed_segment = - ColumnSegment::CreateTransientSegment(db, function, type, info.GetBlockSize(), info.GetBlockManager()); + auto compressed_segment = CreateNewSegment(); buffer_collection.segment = std::move(compressed_segment); + stats_writer.Clear(); auto &buffer_manager = BufferManager::GetBufferManager(checkpoint_data.GetDatabase()); - buffer_collection.segment_handle = buffer_manager.Pin(buffer_collection.segment->block); + buffer_collection.segment_handle = buffer_manager.Pin(buffer_collection.segment->GetBlockHandle()); } void FlushSegment() { @@ -581,6 +581,7 @@ class ZSTDCompressionState : public CompressionState { } auto &state = checkpoint_data.GetCheckpointState(); + stats_writer.Merge(buffer_collection.segment->GetStatsMutable()); state.FlushSegment(std::move(buffer_collection.segment), std::move(buffer_collection.segment_handle), segment_block_size); segment_buffer_state.flags.Clear(); @@ -596,16 +597,15 @@ class ZSTDCompressionState : public CompressionState { } void AddNull() { - buffer_collection.segment->stats.statistics.SetHasNullFast(); + stats_writer.SetHasNull(); string_t empty(static_cast(0)); AddStringInternal(empty); } 
public: unique_ptr analyze_state; - ColumnDataCheckpointData &checkpoint_data; PartialBlockManager &partial_block_manager; - const CompressionFunction &function; + StatsWriter stats_writer; //! --- Analyzed Data --- //! The amount of tuples we're writing @@ -639,7 +639,7 @@ unique_ptr ZSTDStorage::InitCompression(ColumnDataCheckpointDa unique_ptr_cast(std::move(analyze_state_p))); } -void ZSTDStorage::Compress(CompressionState &state_p, Vector &input, idx_t count) { +void ZSTDStorage::Compress(CompressionState &state_p, const Vector &input) { auto &state = state_p.Cast(); // Get vector data @@ -647,7 +647,7 @@ void ZSTDStorage::Compress(CompressionState &state_p, Vector &input, idx_t count input.ToUnifiedFormat(vdata); auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { + for (idx_t i = 0; i < input.size(); i++) { auto idx = vdata.sel->get_index(i); // Note: we treat nulls and empty strings the same if (!vdata.validity.RowIsValid(idx)) { @@ -707,10 +707,11 @@ struct ZSTDScanState : public SegmentScanState { public: explicit ZSTDScanState(ColumnSegment &segment) : state(segment.GetSegmentState()->Cast()), - block_manager(segment.block->GetBlockManager()), buffer_manager(BufferManager::GetBufferManager(segment.db)), + block_manager(segment.GetBlockHandle()->GetBlockManager()), + buffer_manager(BufferManager::GetBufferManager(segment.GetDatabase())), segment_block_offset(segment.GetBlockOffset()), segment(segment) { decompression_context = duckdb_zstd::ZSTD_createDCtx(); - segment_handle = buffer_manager.Pin(segment.block); + segment_handle = buffer_manager.Pin(segment.GetBlockHandle()); auto data = segment_handle.GetDataMutable() + segment.GetBlockOffset(); idx_t offset = 0; diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 9b5e5afc7..6be939bbc 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -14,6 +14,8 @@ #include 
"duckdb/execution/index/unbound_index.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/profiling_utils.hpp" +#include "duckdb/main/query_profiler.hpp" #include "duckdb/parser/constraints/list.hpp" #include "duckdb/planner/constraints/list.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" @@ -37,8 +39,8 @@ namespace duckdb { -DataTableInfo::DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, - string table) +DataTableInfo::DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, Identifier schema, + Identifier table) : db(db), table_io_manager(std::move(table_io_manager_p)), schema(std::move(schema)), table(std::move(table)) { } @@ -50,7 +52,7 @@ bool DataTableInfo::IsTemporary() const { return db.IsTemporary(); } -IndexStorageInfo DataTableInfo::ExtractIndexStorageInfo(const string &name) { +IndexStorageInfo DataTableInfo::ExtractIndexStorageInfo(const Identifier &name) { for (idx_t i = 0; i < index_storage_infos.size(); i++) { if (index_storage_infos[i].name == name) { auto result = std::move(index_storage_infos[i]); @@ -58,13 +60,15 @@ IndexStorageInfo DataTableInfo::ExtractIndexStorageInfo(const string &name) { return result; } } - throw InternalException("ExtractIndexStorageInfo: index storage info with name '%s' not found", name); + throw InternalException("ExtractIndexStorageInfo: index storage info with name '%s' not found", + name.GetIdentifierName()); } DataTable::DataTable(AttachedDatabase &db, shared_ptr table_io_manager_p, const string &schema, const string &table, vector column_definitions_p, unique_ptr data) - : db(db), info(make_shared_ptr(db, std::move(table_io_manager_p), schema, table)), + : db(db), + info(make_shared_ptr(db, std::move(table_io_manager_p), Identifier(schema), Identifier(table))), column_definitions(std::move(column_definitions_p)), version(DataTableVersion::MAIN_TABLE) { // initialize the table with the 
existing data from disk, if any auto types = GetTypes(); @@ -410,7 +414,6 @@ void DataTable::RebuildIndexes() { for (idx_t i = 0; i < col_ids.size(); i++) { table_chunk.data[col_ids[i]].Reference(scan_chunk.data[i]); } - table_chunk.SetCardinality(scan_chunk); Vector &row_ids = scan_chunk.data[col_ids.size()]; auto error = bound_index.Append(table_chunk, row_ids); @@ -447,25 +450,25 @@ bool DataTable::IndexNameIsUnique(const string &name) { return info->indexes.NameIsUnique(name); } -string DataTableInfo::GetSchemaName() { +Identifier DataTableInfo::GetSchemaName() { return schema; } -string DataTableInfo::GetTableName() { +Identifier DataTableInfo::GetTableName() { lock_guard l(name_lock); return table; } -void DataTableInfo::SetTableName(string name) { +void DataTableInfo::SetTableName(Identifier name) { lock_guard l(name_lock); table = std::move(name); } -string DataTable::GetTableName() const { +Identifier DataTable::GetTableName() const { return info->GetTableName(); } -void DataTable::SetTableName(string new_name) { +void DataTable::SetTableName(Identifier new_name) { info->SetTableName(std::move(new_name)); } @@ -544,7 +547,7 @@ void DataTable::Fetch(DuckTransaction &transaction, DataChunk &result, const vec for (idx_t col = 0; col < result.ColumnCount(); col++) { VectorOperations::Copy(local_chunk.data[col], result.data[col], local_count, 0, committed_count); } - result.SetCardinality(committed_count + local_count); + result.SetChildCardinality(committed_count + local_count); // Build inverse permutation to restore original row order. // Current layout: [committed rows in relative order | local rows in relative order]. 
@@ -573,8 +576,8 @@ bool DataTable::CanFetch(DuckTransaction &transaction, const row_t row_id) { //===--------------------------------------------------------------------===// // Append //===--------------------------------------------------------------------===// -static void VerifyNotNullConstraint(TableCatalogEntry &table, Vector &vector, idx_t count, const string &col_name) { - if (!VectorOperations::HasNull(vector, count)) { +static void VerifyNotNullConstraint(TableCatalogEntry &table, const Vector &vector, const Identifier &col_name) { + if (!VectorOperations::HasNull(vector)) { return; } @@ -688,12 +691,11 @@ void DataTable::VerifyForeignKeyConstraint(optional_ptr stora DataChunk dst_chunk; dst_chunk.InitializeEmpty(types); for (idx_t i = 0; i < src_keys_ptr.get().size(); i++) { - auto &src_chunk = chunk.data[src_keys_ptr.get()[i].index]; + const auto &src_chunk = chunk.data[src_keys_ptr.get()[i].index]; dst_chunk.data[dst_keys_ptr.get()[i].index].Reference(src_chunk); } auto count = chunk.size(); - dst_chunk.SetCardinality(count); if (count <= 0) { return; } @@ -930,7 +932,7 @@ void DataTable::VerifyAppendConstraints(ConstraintState &constraint_state, Clien auto &bound_not_null = constraint->Cast(); auto ¬_null = base_constraint->Cast(); auto &col = table.GetColumns().GetColumn(LogicalIndex(not_null.index)); - VerifyNotNullConstraint(table, chunk.data[bound_not_null.index.index], chunk.size(), col.Name()); + VerifyNotNullConstraint(table, chunk.data[bound_not_null.index.index], col.Name()); break; } case ConstraintType::CHECK: { @@ -1152,7 +1154,7 @@ void DataTable::AppendLock(DuckTransaction &transaction, TableAppendState &state GetTableName(), TableModification()); } state.table_lock = transaction.SharedLockTable(*info); - state.row_start = NumericCast(row_groups->GetTotalRows()); + state.row_start = NumericCast(row_groups->GetNextRowId()); state.current_row = state.row_start; auto &transaction_manager = transaction.GetTransactionManager(); auto 
active_checkpoint = transaction_manager.GetActiveCheckpoint(); @@ -1321,10 +1323,13 @@ void DataTable::RevertAppend(DuckTransaction &transaction, idx_t start_row, idx_ idx_t current_row_base = start_row; row_t row_data[STANDARD_VECTOR_SIZE]; Vector row_identifiers(LogicalType::ROW_TYPE, data_ptr_cast(row_data), STANDARD_VECTOR_SIZE); - idx_t scan_count = MinValue(count, row_groups->GetTotalRows() - start_row); + auto next_row_id = row_groups->GetNextRowId(); + D_ASSERT(start_row <= next_row_id); + idx_t scan_count = MinValue(count, next_row_id - start_row); ScanTableSegment(transaction, start_row, scan_count, [&](DataChunk &chunk) { + auto row_id_writer = FlatVector::Writer(row_identifiers, chunk.size()); for (idx_t i = 0; i < chunk.size(); i++) { - row_data[i] = NumericCast(current_row_base + i); + row_id_writer.WriteValue(NumericCast(current_row_base + i)); } for (auto &entry : info->indexes.IndexEntries()) { lock_guard guard(entry.lock); @@ -1488,8 +1493,6 @@ void DataTable::RevertIndexAppend(TableAppendState &state, DataChunk &chunk, Vec void DataTable::RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count, IndexRemovalType removal_type, optional_idx active_checkpoint) { - D_ASSERT(IsMainTable() || removal_type == IndexRemovalType::REVERT_MAIN_INDEX || - removal_type == IndexRemovalType::REVERT_MAIN_INDEX_ONLY); row_groups->RemoveFromIndexes(context, info->indexes, row_identifiers, count, removal_type, active_checkpoint); } @@ -1627,7 +1630,6 @@ static void CreateMockChunk(vector &types, const vector &column_ids, @@ -1671,7 +1673,7 @@ void DataTable::VerifyUpdateConstraints(ConstraintState &state, ClientContext &c if (column_ids[col_idx] == bound_not_null.index) { // found the column id: check the data in auto &col = table.GetColumn(LogicalIndex(not_null.index)); - VerifyNotNullConstraint(table, chunk.data[col_idx], chunk.size(), col.Name()); + VerifyNotNullConstraint(table, chunk.data[col_idx], col.Name()); break; } } @@ -1824,7 
+1826,13 @@ void DataTable::Checkpoint(TableDataWriter &writer, Serializer &serializer) { row_groups->Checkpoint(writer, global_stats); row_groups->SetRowGroupAppendMode(RowGroupAppendMode::SUGGEST_NEW); if (writer.GetRebuildIndexes()) { + MetricsTimer timer; + auto context = writer.TryGetClientContext(); + if (context) { + timer = QueryProfiler::Get(*context).StartTimer(); + } RebuildIndexes(); + timer.EndTimer(); } // The row group payload data has been written. Now write: // sample @@ -1852,15 +1860,42 @@ idx_t DataTable::GetTotalRows() const { return row_groups->GetTotalRows(); } +idx_t DataTable::GetNextRowId() const { + return row_groups->GetNextRowId(); +} + void DataTable::CommitDropTable(CommitDropState &drop_state) { row_groups->CommitDropTable(drop_state); } +idx_t DataTable::GetRowGroupCount() const { + return row_groups->GetRowGroupCount(); +} + +idx_t DataTable::GetRowGroupCountWithLocalStorage(ClientContext &context) { + auto &local_storage = LocalStorage::Get(context, db); + auto storage = local_storage.GetStorage(*this); + if (!storage) { + return GetRowGroupCount(); + } + return GetRowGroupCount() + storage->GetCollection().GetRowGroupCount(); +} + //===--------------------------------------------------------------------===// // Column Segment Info //===--------------------------------------------------------------------===// -vector DataTable::GetColumnSegmentInfo(const QueryContext &context) { - return row_groups->GetColumnSegmentInfo(context); +vector DataTable::GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options) { + return row_groups->GetColumnSegmentInfo(context, options); +} + +void DataTable::InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) { + row_groups->InitializeColumnSegmentInfoScan(state); +} + +bool DataTable::ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) { + return row_groups->ScanColumnSegmentInfo(context, state, 
result); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/database_handle.cpp b/src/duckdb/src/storage/database_handle.cpp new file mode 100644 index 000000000..80ef817e3 --- /dev/null +++ b/src/duckdb/src/storage/database_handle.cpp @@ -0,0 +1,152 @@ +#include "duckdb/storage/database_handle.hpp" +#include "duckdb/storage/single_file_block_manager.hpp" + +namespace duckdb { + +DatabaseHandle::DatabaseHandle(unique_ptr handle_p) : handle(std::move(handle_p)) { +} +DatabaseHandle::DatabaseHandle(unique_ptr mmap_handle_p) : mmap_handle(std::move(mmap_handle_p)) { +} + +unique_ptr DatabaseHandle::Open(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, DatabaseOpenMode open_mode) { + if (options.io_mode == FileIOMode::MMAP) { + return OpenMemoryMap(db, path, options, open_mode); + } else { + return OpenFile(db, path, options, open_mode); + } +} + +unique_ptr DatabaseHandle::OpenFile(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, DatabaseOpenMode open_mode) { + FileOpenFlags file_flags; + if (options.read_only) { + D_ASSERT(open_mode == DatabaseOpenMode::OPEN_EXISTING_FILE); + file_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS | FileLockType::READ_LOCK; + } else { + file_flags = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_READ | FileLockType::WRITE_LOCK; + if (open_mode == DatabaseOpenMode::CREATE_NEW_FILE) { + file_flags |= FileFlags::FILE_FLAGS_FILE_CREATE; + } + } + if (options.io_mode == FileIOMode::DIRECT_IO) { + file_flags |= FileFlags::FILE_FLAGS_DIRECT_IO; + } + // database files can be read from in parallel + file_flags |= FileFlags::FILE_FLAGS_PARALLEL_ACCESS; + file_flags |= FileFlags::FILE_FLAGS_MULTI_CLIENT_ACCESS; + + auto &fs = FileSystem::Get(db); + auto file_handle = fs.OpenFile(path, file_flags); + if (!file_handle) { + // this can only happen in read-only mode - as that is when we 
set FILE_FLAGS_NULL_IF_NOT_EXISTS + throw IOException("Cannot open database \"%s\" in read-only mode: database does not exist", path); + } + return make_uniq(std::move(file_handle)); +} + +unique_ptr DatabaseHandle::OpenMemoryMap(AttachedDatabase &db, const string &path, + const StorageManagerOptions &options, + DatabaseOpenMode open_mode) { + if (options.encryption_options.encryption_enabled) { + // In-place decryption would write decrypted bytes back through the mapping. + throw InvalidInputException("MMAP mode is not supported for encrypted databases"); + } + // Default reserve covers the bulk of analytical databases; users override via MMAP_RESERVE_SIZE. + static constexpr idx_t MMAP_DEFAULT_RESERVE_SIZE = idx_t(256) * 1024 * 1024 * 1024; // 256 GiB + MMapOptions mmap_options; + mmap_options.reserve_size = + options.mmap_reserve_size.IsValid() ? options.mmap_reserve_size.GetIndex() : MMAP_DEFAULT_RESERVE_SIZE; + FileOpenFlags mmap_flags; + if (options.read_only) { + D_ASSERT(open_mode == DatabaseOpenMode::OPEN_EXISTING_FILE); + mmap_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS | FileLockType::READ_LOCK; + } else { + mmap_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | FileLockType::WRITE_LOCK; + if (open_mode == DatabaseOpenMode::CREATE_NEW_FILE) { + mmap_flags |= FileFlags::FILE_FLAGS_FILE_CREATE; + } + } + auto &fs = FileSystem::Get(db); + auto mmap_handle = fs.MemoryMapFile(path, mmap_flags, mmap_options); + if (!mmap_handle) { + // Only happens in read-only mode, where FILE_FLAGS_NULL_IF_NOT_EXISTS is set. 
+ throw IOException("Cannot open database \"%s\" in read-only mode: database does not exist", path); + } + return make_uniq(std::move(mmap_handle)); +} + +bool DatabaseHandle::OnDiskFile() const { + if (mmap_handle) { + return mmap_handle->OnDiskFile(); + } + return handle->OnDiskFile(); +} + +void DatabaseHandle::CheckMagicBytes(QueryContext context) { + if (mmap_handle) { + MainHeader::CheckMagicBytes(*mmap_handle); + } else { + MainHeader::CheckMagicBytes(context, *handle); + } +} + +void DatabaseHandle::Read(QueryContext context, FileBuffer &block, uint64_t location) const { + if (mmap_handle) { + block.Read(context, *mmap_handle, location); + } else { + block.Read(context, *handle, location); + } +} + +void DatabaseHandle::Write(QueryContext context, FileBuffer &buffer, uint64_t location) { + if (mmap_handle) { + EnsureMappedSize(location + buffer.AllocSize()); + buffer.Write(context, *mmap_handle, location); + } else { + buffer.Write(context, *handle, location); + } +} + +void DatabaseHandle::Sync() { + if (mmap_handle) { + mmap_handle->Sync(); + } else { + handle->Sync(); + } +} + +void DatabaseHandle::Truncate(idx_t new_size) { + if (mmap_handle) { + // File stays at reserve size in MAP mode; punch a hole to release the tail. + mmap_handle->Trim(new_size, mmap_handle->Size() - new_size); + } else { + handle->Truncate(NumericCast(new_size)); + } +} + +void DatabaseHandle::Trim(idx_t offset, idx_t length) { + if (mmap_handle) { + mmap_handle->Trim(offset, length); + } else { + handle->Trim(offset, length); + } +} + +FileHandle &DatabaseHandle::GetFileHandle() { + return *handle; +} + +void DatabaseHandle::EnsureMappedSize(idx_t required_size) const { + if (!mmap_handle) { + return; + } + if (required_size > mmap_handle->Size()) { + throw IOException("Database file \"%s\" would grow to %llu bytes, which exceeds the memory-mapped reserved " + "size of %llu bytes. 
Re-attach with a larger MMAP_RESERVE_SIZE, or without " + "IO_MODE='MMAP', to allow further growth.", + mmap_handle->GetPath(), required_size, mmap_handle->Size()); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp b/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp index c0fbbcb4e..966ffa731 100644 --- a/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp +++ b/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp @@ -126,28 +126,70 @@ CachingFileSystem CachingFileSystem::Get(ClientContext &context) { unique_ptr CachingFileSystem::OpenFile(const OpenFileInfo &path, FileOpenFlags flags, optional_ptr opener) { - return make_uniq(QueryContext(), *this, path, flags, opener, - external_file_cache.GetOrCreateCachedFile(path.path)); + return make_uniq(QueryContext(), *this, path, flags, opener); } unique_ptr CachingFileSystem::OpenFile(QueryContext context, const OpenFileInfo &path, FileOpenFlags flags, optional_ptr opener) { - return make_uniq(context, *this, path, flags, opener, - external_file_cache.GetOrCreateCachedFile(path.path)); + return make_uniq(context, *this, path, flags, opener); } //===----------------------------------------------------------------------===// // CachingFileHandle //===----------------------------------------------------------------------===// +bool CachingFileHandle::StripForceFullDownloadIfPresent() { + auto &extended_info_p = path.extended_info; + if (!extended_info_p) { + return false; + } + + auto &extended_info = *extended_info_p; + const bool contains_force_full_download = extended_info.options.count("force_full_download"); + if (!contains_force_full_download) { + return false; + } + + //! 
We do have 'force_full_download' - strip it + auto new_extended_info = make_shared_ptr(); + *new_extended_info = extended_info; + new_extended_info->options.erase("force_full_download"); + if (!new_extended_info->options.count("file_size")) { + new_extended_info->options["file_size"] = Value::UBIGINT(GetFileSize()); + } + path.extended_info = new_extended_info; + return true; +} + +shared_ptr CachingFileHandle::EnsureCachedFileCurrent() { + bool needs_reopen = false; + { + annotated_lock_guard guard(file_handle_mutex); + if (cached_file && cached_file->generation == external_file_cache.GetGeneration()) { + return cached_file; + } + needs_reopen = file_handle != nullptr; + if (needs_reopen) { + file_handle.reset(); + } + cached_file = external_file_cache.GetOrCreateCachedFile(path.path); + } + + if (needs_reopen) { + GetFileHandle(); + } + return cached_file; +} + CachingFileHandle::CachingFileHandle(QueryContext context, CachingFileSystem &caching_file_system_p, const OpenFileInfo &path_p, FileOpenFlags flags_p, - optional_ptr opener_p, CachedFile &cached_file_p) + optional_ptr opener_p) : context(context), caching_file_system(caching_file_system_p), external_file_cache(caching_file_system.external_file_cache), path(path_p), flags(flags_p), opener(opener_p), validate( ExternalFileCacheUtil::GetCacheValidationMode(path_p, context.GetClientContext(), caching_file_system_p.db)), - cached_file(cached_file_p), position(0) { + cached_file(nullptr), position(0) { + cached_file = external_file_cache.GetOrCreateCachedFile(path_p.path); if (!external_file_cache.IsEnabled() || Validate()) { // If caching is disabled, or if we must validate cache entries, we always have to open the file GetFileHandle(); @@ -156,12 +198,16 @@ CachingFileHandle::CachingFileHandle(QueryContext context, CachingFileSystem &ca // If we don't have any cached blocks, we must also open the file. 
bool needs_open = false; { - annotated_lock_guard guard(cached_file.map_lock); - needs_open = cached_file.blocks.empty(); + annotated_lock_guard guard(cached_file->map_lock); + needs_open = cached_file->blocks.empty(); } if (needs_open) { GetFileHandle(); } + auto needs_full_download = StripForceFullDownloadIfPresent(); + if (needs_full_download) { + Read(GetFileSize(), 0); + } } CachingFileHandle::~CachingFileHandle() { @@ -173,30 +219,38 @@ FileHandle &CachingFileHandle::GetFileHandle() { return *file_handle; } - file_handle = caching_file_system.file_system.OpenFile(path, flags, opener); + // The caching file handle can service parallel block fetch tasks against a single shared FileHandle. + // Request parallel access on the underlying filesystem handle to avoid races in implementations that + // require explicit opt-in for concurrent pread-style access (e.g., HTTPFS). + auto internal_flags = flags | FileFlags::FILE_FLAGS_PARALLEL_ACCESS; + file_handle = caching_file_system.file_system.OpenFile(path, internal_flags, opener); last_modified = caching_file_system.file_system.GetLastModifiedTime(*file_handle); version_tag = caching_file_system.file_system.GetVersionTag(*file_handle); { - annotated_lock_guard meta_guard(cached_file.meta_lock); - const bool first_access = (cached_file.file_size == 0); + annotated_lock_guard meta_guard(cached_file->meta_lock); + const bool first_access = (cached_file->file_size == 0); if (first_access || Validate()) { - if (!ExternalFileCache::IsValid(Validate(), cached_file.version_tag, cached_file.last_modified, version_tag, - last_modified)) { - annotated_lock_guard map_guard(cached_file.map_lock); - cached_file.blocks.clear(); - cached_file.cached_block_size.SetInvalid(); + if (!ExternalFileCache::IsValid(Validate(), cached_file->version_tag, cached_file->last_modified, + version_tag, last_modified)) { + annotated_lock_guard map_guard(cached_file->map_lock); + cached_file->blocks.clear(); + 
cached_file->cached_block_size.SetInvalid(); } - cached_file.file_size = file_handle->GetFileSize(); - cached_file.last_modified = last_modified; - cached_file.version_tag = version_tag; - cached_file.can_seek = file_handle->CanSeek(); - cached_file.on_disk_file = file_handle->OnDiskFile(); + cached_file->file_size = file_handle->GetFileSize(); + cached_file->last_modified = last_modified; + cached_file->version_tag = version_tag; + cached_file->can_seek = file_handle->CanSeek(); + cached_file->on_disk_file = file_handle->OnDiskFile(); } } return *file_handle; } +Allocator &CachingFileHandle::GetBufferAllocator() const { + return external_file_cache.GetBufferManager().GetBufferAllocator(); +} + FileBufferHandleGroup CachingFileHandle::Read(const idx_t nr_bytes, const idx_t location) { if (nr_bytes == 0) { return FileBufferHandleGroup(); @@ -210,18 +264,20 @@ FileBufferHandleGroup CachingFileHandle::Read(const idx_t nr_bytes, const idx_t return FileBufferHandleGroup(std::move(mem_handles)); } - const idx_t block_size = external_file_cache.GetCacheBlockSize(cached_file.path); + auto current_cached_file = EnsureCachedFileCurrent(); + const idx_t block_size = external_file_cache.GetCacheBlockSize(current_cached_file->path); const idx_t first_block = location / block_size; const idx_t last_block = (location + nr_bytes - 1) / block_size; const idx_t num_blocks = last_block - first_block + 1; // Atomically reindex (if needed) and acquire the block range. - auto blocks = external_file_cache.ReindexAndAcquireBlocks(cached_file, block_size, first_block, num_blocks); + auto blocks = + external_file_cache.ReindexAndAcquireBlocks(*current_cached_file, block_size, first_block, num_blocks); // Schedule block fetch tasks for all blocks. 
vector pins(num_blocks); auto &scheduler = TaskScheduler::GetScheduler(caching_file_system.db); - TaskExecutor executor(scheduler); + TaskExecutor executor(scheduler, TaskSchedulerType::ASYNC); for (idx_t idx = 0; idx < num_blocks; idx++) { executor.ScheduleTask(make_uniq(*this, executor, context, @@ -252,9 +308,9 @@ FileBufferHandleGroup CachingFileHandle::Read(const idx_t nr_bytes, const idx_t // After all tasks complete, check if the cache was invalidated by another thread. if (Validate()) { - const annotated_lock_guard meta_guard(cached_file.meta_lock); - if (!ExternalFileCache::IsValid(true, cached_file.version_tag, cached_file.last_modified, version_tag, - last_modified)) { + const annotated_lock_guard meta_guard(current_cached_file->meta_lock); + if (!ExternalFileCache::IsValid(true, current_cached_file->version_tag, current_cached_file->last_modified, + version_tag, last_modified)) { for (auto &block : blocks) { block->Reinit(); } @@ -287,21 +343,23 @@ FileBufferHandleGroup CachingFileHandle::Read(idx_t &nr_bytes) { } string CachingFileHandle::GetPath() const { - return cached_file.path; + return path.path; } idx_t CachingFileHandle::GetFileSize() { if (!Validate()) { - annotated_lock_guard guard(cached_file.meta_lock); - return cached_file.file_size; + auto current_cached_file = EnsureCachedFileCurrent(); + annotated_lock_guard guard(current_cached_file->meta_lock); + return current_cached_file->file_size; } return GetFileHandle().GetFileSize(); } timestamp_t CachingFileHandle::GetLastModifiedTime() { if (!Validate()) { - annotated_lock_guard guard(cached_file.meta_lock); - return cached_file.last_modified; + auto current_cached_file = EnsureCachedFileCurrent(); + annotated_lock_guard guard(current_cached_file->meta_lock); + return current_cached_file->last_modified; } GetFileHandle(); return last_modified; @@ -309,8 +367,9 @@ timestamp_t CachingFileHandle::GetLastModifiedTime() { string CachingFileHandle::GetVersionTag() { if (!Validate()) { - 
annotated_lock_guard guard(cached_file.meta_lock); - return cached_file.version_tag; + auto current_cached_file = EnsureCachedFileCurrent(); + annotated_lock_guard guard(current_cached_file->meta_lock); + return current_cached_file->version_tag; } GetFileHandle(); return version_tag; @@ -321,7 +380,7 @@ bool CachingFileHandle::Validate() const { case CacheValidationMode::VALIDATE_ALL: return true; case CacheValidationMode::VALIDATE_REMOTE: - return FileSystem::IsRemoteFile(cached_file.path); + return FileSystem::IsRemoteFile(path.path); case CacheValidationMode::NO_VALIDATION: return false; default: @@ -331,24 +390,30 @@ bool CachingFileHandle::Validate() const { bool CachingFileHandle::CanSeek() { if (!Validate()) { - annotated_lock_guard guard(cached_file.meta_lock); - return cached_file.can_seek; + auto current_cached_file = EnsureCachedFileCurrent(); + annotated_lock_guard guard(current_cached_file->meta_lock); + return current_cached_file->can_seek; } return GetFileHandle().CanSeek(); } bool CachingFileHandle::IsRemoteFile() const { - return FileSystem::IsRemoteFile(cached_file.path); + return FileSystem::IsRemoteFile(path.path); } bool CachingFileHandle::OnDiskFile() { if (!Validate()) { - annotated_lock_guard guard(cached_file.meta_lock); - return cached_file.on_disk_file; + auto current_cached_file = EnsureCachedFileCurrent(); + annotated_lock_guard guard(current_cached_file->meta_lock); + return current_cached_file->on_disk_file; } return GetFileHandle().OnDiskFile(); } +bool CachingFileHandle::TryGetNetworkThroughput(NetworkThroughputEstimate &result) { + return GetFileHandle().TryGetNetworkThroughput(result); +} + idx_t CachingFileHandle::SeekPosition() { return position; } diff --git a/src/duckdb/src/storage/external_file_cache/external_file_cache.cpp b/src/duckdb/src/storage/external_file_cache/external_file_cache.cpp index f93a09ba9..87cf77882 100644 --- a/src/duckdb/src/storage/external_file_cache/external_file_cache.cpp +++ 
b/src/duckdb/src/storage/external_file_cache/external_file_cache.cpp @@ -10,9 +10,42 @@ #include "duckdb/common/operator/subtract.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/buffer/block_handle.hpp" +#include "duckdb/storage/object_cache.hpp" namespace duckdb { +class ExternalFileCache::ExternalFileCacheObjectCacheEntry : public ObjectCacheEntry { +public: + ExternalFileCacheObjectCacheEntry(ExternalFileCache &cache_p, string path_p, idx_t generation_p) + : cache(cache_p), cached_file(make_shared_ptr(std::move(path_p), generation_p)) { + cache.InsertCachedFileKey(cached_file->path); + } + + ~ExternalFileCacheObjectCacheEntry() override { + cache.EraseCachedFileKey(cached_file->path); + } + + static string ObjectType() { + return "external_file_cache"; + } + + string GetObjectType() override { + return ObjectType(); + } + + optional_idx GetEstimatedCacheMemory() const override { + return cached_file->path.size() * 2; + } + + shared_ptr GetCachedFile() const { + return cached_file; + } + +private: + ExternalFileCache &cache; + shared_ptr cached_file; +}; + idx_t ExternalFileCache::GetCacheBlockSize(const string &path) const { auto &db = buffer_manager.GetDatabase(); if (FileSystem::IsRemoteFile(path)) { @@ -146,7 +179,8 @@ vector> ExternalFileCache::ReindexAndAcquireBlocks(Cached return blocks; } -ExternalFileCache::CachedFile::CachedFile(string path_p) : path(std::move(path_p)) { +ExternalFileCache::CachedFile::CachedFile(string path_p, idx_t generation_p) + : path(std::move(path_p)), generation(generation_p) { } bool ExternalFileCache::CachedFile::IsValid(bool validate, const string ¤t_version_tag, @@ -192,7 +226,7 @@ bool ExternalFileCache::IsValid(bool validate, const string &cached_version_tag, } ExternalFileCache::ExternalFileCache(DatabaseInstance &db, bool enable_p) - : buffer_manager(BufferManager::GetBufferManager(db)), enable(enable_p) { + : buffer_manager(BufferManager::GetBufferManager(db)), enable(enable_p), generation(0) 
{ } bool ExternalFileCache::IsEnabled() const { @@ -200,21 +234,50 @@ bool ExternalFileCache::IsEnabled() const { } void ExternalFileCache::SetEnabled(bool enable_p) { - lock_guard guard(lock); - enable = enable_p; - if (!enable) { - cached_files.clear(); + vector keys_to_delete; + { + const annotated_lock_guard guard(lock); + if (enable == enable_p) { + return; + } + enable = enable_p; + generation++; + if (!enable) { + keys_to_delete.reserve(cached_file_keys.size()); + for (auto &key : cached_file_keys) { + keys_to_delete.emplace_back(key); + } + } } + DeleteObjectCacheEntries(keys_to_delete); +} + +idx_t ExternalFileCache::GetGeneration() const { + return generation; } vector ExternalFileCache::GetCachedFileInformation() const { - unique_lock files_guard(lock); + vector keys; + { + const annotated_lock_guard files_guard(lock); + keys.reserve(cached_file_keys.size()); + for (auto &key : cached_file_keys) { + keys.emplace_back(key); + } + } + + auto &object_cache = buffer_manager.GetDatabase().GetObjectCache(); vector result; - for (const auto &file : cached_files) { - annotated_lock_guard map_guard(file.second->map_lock); - const idx_t block_size = file.second->cached_block_size.IsValid() ? file.second->cached_block_size.GetIndex() - : GetCacheBlockSize(file.first); - for (const auto &block_entry : file.second->blocks) { + for (const auto &key : keys) { + auto entry = object_cache.GetWithTypePrefix(key); + if (!entry) { + continue; + } + auto file = entry->GetCachedFile(); + const annotated_lock_guard map_guard(file->map_lock); + const idx_t block_size = + file->cached_block_size.IsValid() ? 
file->cached_block_size.GetIndex() : GetCacheBlockSize(file->path); + for (const auto &block_entry : file->blocks) { const idx_t block_idx = block_entry.first; const auto &block = *block_entry.second; @@ -224,12 +287,17 @@ vector ExternalFileCache::GetCachedFileInformation() cons } const idx_t location = block_idx * block_size; const bool loaded = !block.block_handle->GetMemory().IsUnloaded(); - result.push_back({file.first, block.nr_bytes, location, loaded}); + result.push_back({file->path, block.nr_bytes, location, loaded}); } } return result; } +idx_t ExternalFileCache::GetCachedFileCount() const { + const annotated_lock_guard files_guard(lock); + return cached_file_keys.size(); +} + ExternalFileCache &ExternalFileCache::Get(DatabaseInstance &db) { return db.GetExternalFileCache(); } @@ -242,13 +310,48 @@ BufferManager &ExternalFileCache::GetBufferManager() const { return buffer_manager; } -ExternalFileCache::CachedFile &ExternalFileCache::GetOrCreateCachedFile(const string &path) { - lock_guard guard(lock); - auto &entry = cached_files[path]; - if (!entry) { - entry = make_uniq(path); +void ExternalFileCache::DeleteObjectCacheEntries(const vector &paths) { + auto &object_cache = buffer_manager.GetDatabase().GetObjectCache(); + for (auto &path : paths) { + object_cache.DeleteWithTypePrefix(path); + } +} + +shared_ptr ExternalFileCache::GetOrCreateCachedFile(const string &path) { + auto &object_cache = buffer_manager.GetDatabase().GetObjectCache(); + while (true) { + const auto current_generation = generation.load(); + if (!enable) { + return make_shared_ptr(path, current_generation); + } + + auto entry = object_cache.GetOrCreateWithTypePrefix(path, *this, path, + current_generation); + auto cached_file = entry->GetCachedFile(); + + if (!enable) { + object_cache.DeleteWithTypePrefix(path); + return make_shared_ptr(path, current_generation); + } + if (cached_file->generation != current_generation) { + object_cache.DeleteWithTypePrefix(path); + continue; + } + 
return cached_file; } - return *entry; +} + +void ExternalFileCache::InsertCachedFileKey(const string &path) { + const annotated_lock_guard guard(lock); + auto inserted = cached_file_keys.insert(path); + ALWAYS_ASSERT(inserted.second); +} + +void ExternalFileCache::EraseCachedFileKey(const string &path) { + const annotated_lock_guard guard(lock); + auto entry = cached_file_keys.find(path); + ALWAYS_ASSERT(entry != cached_file_keys.end()); + cached_file_keys.erase(entry); } } // namespace duckdb diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index 2222900cb..889376068 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -107,7 +107,7 @@ idx_t LocalTableStorage::EstimatedSize() { if (collection.GetTotalRows() >= collection.GetRowGroupSize() && deleted_rows == 0) { // Optimistic insertion does not generate many WAL logs, so we estimate the size of the Data Block Pointers here - idx_t row_group_count = row_groups->complete_row_groups + 1; + idx_t row_group_count = row_groups->collection->GetRowGroupCount(); idx_t column_count = collection.GetTypes().size(); data_size = row_group_count * (sizeof(PersistentRowGroupData) + @@ -138,12 +138,12 @@ idx_t LocalTableStorage::EstimatedSize() { return data_size + index_sizes; } -void LocalTableStorage::WriteNewRowGroup() { +void LocalTableStorage::WriteNewRowGroup(idx_t flushed_row_group_idx) { if (deleted_rows != 0) { // we have deletes - we cannot merge row groups return; } - optimistic_writer.WriteNewRowGroup(*row_groups); + optimistic_writer.WriteNewRowGroup(*row_groups, flushed_row_group_idx); } void LocalTableStorage::FlushBlocks() { @@ -197,7 +197,6 @@ ErrorData LocalTableStorage::AppendToIndexes(DuckTransaction &transaction, RowGr auto col_id = mapped_column_ids[i].GetPrimaryIndex(); table_chunk.data[col_id].Reference(index_chunk.data[i]); } - table_chunk.SetCardinality(index_chunk); // Pass both the table and the index 
chunk. // We need the table chunk for the bound indexes, @@ -480,7 +479,7 @@ void LocalStorage::Append(LocalAppendState &state, DuckTableEntry &table_entry, DataTableInfo &data_table_info) { auto storage = state.storage; storage->table_entry = &table_entry; - auto offset = NumericCast(MAX_ROW_ID) + storage->GetCollection().GetTotalRows(); + auto offset = NumericCast(MAX_ROW_ID) + storage->GetCollection().GetNextRowId(); idx_t base_id = offset + state.append_state.total_append_count; if (!storage->append_indexes.Empty()) { @@ -503,11 +502,11 @@ void LocalStorage::Append(LocalAppendState &state, DuckTableEntry &table_entry, } // Append the chunk to the local storage. - auto new_row_group = storage->GetCollection().Append(table_chunk, state.append_state); + auto flushed_row_group_idx = storage->GetCollection().Append(table_chunk, state.append_state); // Check if we should pre-emptively flush blocks to disk. - if (new_row_group) { - storage->WriteNewRowGroup(); + if (flushed_row_group_idx.IsValid()) { + storage->WriteNewRowGroup(flushed_row_group_idx.GetIndex()); } } @@ -520,7 +519,7 @@ void LocalStorage::LocalMerge(DataTable &table, DuckTableEntry &table_entry, Opt storage.table_entry = &table_entry; if (!storage.append_indexes.Empty()) { // append data to indexes if required - row_t base_id = MAX_ROW_ID + NumericCast(storage.GetCollection().GetTotalRows()); + row_t base_id = MAX_ROW_ID + NumericCast(storage.GetCollection().GetNextRowId()); auto error = storage.AppendToIndexes(transaction, *collection.collection, storage.append_indexes, table.GetTypes(), base_id); if (error.HasError()) { diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp index 21414bf38..e3346f0bd 100644 --- a/src/duckdb/src/storage/metadata/metadata_manager.cpp +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -425,6 +425,17 @@ void MetadataManager::ClearModifiedBlocks(const vector &pointe } } +bool 
MetadataManager::BlockIsModified(const MetaBlockPointer &pointer) { + unique_lock guard(block_lock); + auto block_id = pointer.GetBlockId(); + auto entry = blocks.find(block_id); + if (entry == blocks.end()) { + throw InternalException("BlockIsNotModified - Block id %llu not found in blocks", block_id); + } + auto entry2 = modified_blocks.find(block_id); + return entry2 != modified_blocks.end(); +} + bool MetadataManager::BlockHasBeenCleared(const MetaBlockPointer &pointer) { unique_lock guard(block_lock); auto block_id = pointer.GetBlockId(); diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp index c563df304..b47ca7805 100644 --- a/src/duckdb/src/storage/metadata/metadata_reader.cpp +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -53,18 +53,6 @@ MetaBlockPointer MetadataReader::GetMetaBlockPointer() { return manager.GetDiskPointer(block.pointer, UnsafeNumericCast(offset)); } -vector MetadataReader::GetRemainingBlocks(MetaBlockPointer last_block) { - vector result; - while (has_next_block) { - if (last_block.IsValid() && next_pointer.block_pointer == last_block.block_pointer) { - break; - } - result.push_back(next_pointer); - ReadNextBlock(); - } - return result; -} - void MetadataReader::ReadNextBlock() { ReadNextBlock(QueryContext()); } diff --git a/src/duckdb/src/storage/open_file_storage_extension.cpp b/src/duckdb/src/storage/open_file_storage_extension.cpp index 25dbacbe9..818ef36a5 100644 --- a/src/duckdb/src/storage/open_file_storage_extension.cpp +++ b/src/duckdb/src/storage/open_file_storage_extension.cpp @@ -9,20 +9,20 @@ namespace duckdb { class OpenFileDefaultGenerator : public DefaultGenerator { public: - OpenFileDefaultGenerator(Catalog &catalog, SchemaCatalogEntry &schema, const case_insensitive_set_t &view_names_p, + OpenFileDefaultGenerator(Catalog &catalog, SchemaCatalogEntry &schema, const identifier_set_t &view_names_p, string file_p) : DefaultGenerator(catalog), 
schema(schema), file(std::move(file_p)) { for (auto &view_name : view_names_p) { - view_names.push_back(view_name); + view_names.emplace_back(view_name); } } public: - unique_ptr CreateDefaultEntry(ClientContext &context, const string &entry_name) override { + unique_ptr CreateDefaultEntry(ClientContext &context, const Identifier &entry_name) override { for (auto &entry : view_names) { - if (StringUtil::CIEquals(entry_name, entry)) { + if (entry_name == entry) { auto result = make_uniq(); - result->schema = DEFAULT_SCHEMA; + result->schema = Identifier::DefaultSchema(); result->view_name = entry; result->sql = StringUtil::Format("SELECT * FROM %s", SQLString(file)); auto view_info = CreateViewInfo::FromSelect(context, std::move(result)); @@ -32,13 +32,13 @@ class OpenFileDefaultGenerator : public DefaultGenerator { return nullptr; } - vector GetDefaultEntries() override { + vector GetDefaultEntries() override { return view_names; } private: SchemaCatalogEntry &schema; - vector view_names; + vector view_names; string file; }; @@ -51,13 +51,13 @@ unique_ptr OpenFileStorageAttach(optional_ptr sto auto catalog = make_uniq(db); catalog->Initialize(false); - case_insensitive_set_t view_names; + identifier_set_t view_names; view_names.insert("file"); - view_names.insert(name); + view_names.insert(Identifier(name)); // set up the default view generator for "file" and the derived name of the file auto system_transaction = CatalogTransaction::GetSystemTransaction(db.GetDatabase()); - auto &schema = catalog->GetSchema(system_transaction, DEFAULT_SCHEMA); + auto &schema = catalog->GetSchema(system_transaction, Identifier::DefaultSchema()); auto &duck_schema = schema.Cast(); auto &catalog_set = duck_schema.GetCatalogSet(CatalogType::VIEW_ENTRY); auto default_generator = make_uniq(*catalog, schema, view_names, std::move(file)); diff --git a/src/duckdb/src/storage/optimistic_data_writer.cpp b/src/duckdb/src/storage/optimistic_data_writer.cpp index eb47b3e5e..bec7d92a5 100644 --- 
a/src/duckdb/src/storage/optimistic_data_writer.cpp +++ b/src/duckdb/src/storage/optimistic_data_writer.cpp @@ -61,16 +61,37 @@ unique_ptr OptimisticDataWriter::CreateCollection(Dat return result; } -void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_groups) { +void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_groups, idx_t flushed_row_group_idx) { // we finished writing a complete row group if (!PrepareWrite()) { return; } - row_groups.unflushed_row_groups.insert(row_groups.complete_row_groups); - row_groups.complete_row_groups++; + row_groups.unflushed_row_groups.insert(flushed_row_group_idx); + auto allocated_size = row_groups.collection->GetAllocationSize(); + if (row_groups.prev_allocated_size > allocated_size) { + throw InternalException("Row group prev allocated size is larger than currently allocated size"); + } + row_groups.unflushed_data_size += allocated_size - row_groups.prev_allocated_size; + row_groups.prev_allocated_size = allocated_size; auto unflushed_row_groups = row_groups.unflushed_row_groups.size(); - if (unflushed_row_groups >= Settings::Get(context)) { + // check if we should flush the row groups + // first check the amount of row groups + bool need_to_flush = unflushed_row_groups >= Settings::Get(context); + if (!need_to_flush) { + // we don't need to flush based on the amount of row groups - but we still might need to flush based on the + // amount of + auto &config = DBConfig::GetConfig(context); + auto memory_limit = config.options.write_buffer_row_group_memory_limit; + if (!memory_limit.IsValid()) { + memory_limit = config.options.maximum_memory / 5 / (config.options.maximum_threads + 1); + } + if (row_groups.unflushed_data_size >= memory_limit.GetIndex()) { + // we exhausted our memory available for buffering - flush + need_to_flush = true; + } + } + if (need_to_flush) { // we have crossed our flush threshold - flush any unwritten row groups to disk vector> to_flush; vector 
segment_indexes; @@ -80,10 +101,16 @@ void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_group segment_indexes.push_back(segment_index); } FlushToDisk(row_groups, to_flush, segment_indexes); - row_groups.unflushed_row_groups.clear(); + row_groups.FinalizeFlush(); } } +void OptimisticWriteCollection::FinalizeFlush() { + flushed_row_groups.insert(unflushed_row_groups.begin(), unflushed_row_groups.end()); + unflushed_row_groups.clear(); + unflushed_data_size = 0; +} + void OptimisticDataWriter::WriteUnflushedRowGroups(OptimisticWriteCollection &row_groups) { // we finished writing a complete row group if (!PrepareWrite()) { @@ -91,12 +118,12 @@ void OptimisticDataWriter::WriteUnflushedRowGroups(OptimisticWriteCollection &ro } // add any incomplete row groups to the set of unflushed row groups auto total_row_groups = row_groups.collection->GetRowGroupCount(); - if (row_groups.complete_row_groups > total_row_groups) { - throw InternalException("WriteUnflushedRowGroups - complete row groups > total_row_groups"); - } - for (idx_t i = row_groups.complete_row_groups; i < total_row_groups; i++) { - row_groups.unflushed_row_groups.insert(i); - row_groups.complete_row_groups++; + for (idx_t i = 0; i < total_row_groups; i++) { + // check if this row group was flushed + auto entry = row_groups.flushed_row_groups.find(i); + if (entry == row_groups.flushed_row_groups.end()) { + row_groups.unflushed_row_groups.insert(i); + } } if (!row_groups.unflushed_row_groups.empty()) { // flush the last batch of row groups @@ -114,8 +141,11 @@ void OptimisticDataWriter::WriteUnflushedRowGroups(OptimisticWriteCollection &ro for (auto &partial_manager : row_groups.partial_block_managers) { Merge(partial_manager); } - row_groups.unflushed_row_groups.clear(); + row_groups.FinalizeFlush(); row_groups.partial_block_managers.clear(); + // any new append to the row group collection needs to append a new row group + // otherwise we append to an already flushed row group + 
row_groups.collection->SetRowGroupAppendMode(RowGroupAppendMode::REQUIRE_NEW); } void OptimisticWriteCollection::MergeStorage(OptimisticWriteCollection &merge_collection) { @@ -124,31 +154,17 @@ void OptimisticWriteCollection::MergeStorage(OptimisticWriteCollection &merge_co // no rows to merge - done return; } - // when merging the other row group is appended to the END of this row group - // that means any trailing row groups that are not yet complete are now complete (even if they are half empty) - // add them to the unflushed set idx_t current_row_group_count = collection->GetRowGroupCount(); - if (complete_row_groups > current_row_group_count) { - throw InternalException("MergeStorage - complete row groups > total_row_groups"); - } - for (idx_t i = complete_row_groups; i < current_row_group_count; i++) { - unflushed_row_groups.insert(i); - complete_row_groups++; - } - // now we merge the target collection into this one - take over any unflushed row groups but adjust their index for (auto &unflushed_idx : merge_collection.unflushed_row_groups) { unflushed_row_groups.insert(current_row_group_count + unflushed_idx); } - complete_row_groups += merge_collection.complete_row_groups; + for (auto &flushed_idx : merge_collection.flushed_row_groups) { + flushed_row_groups.insert(current_row_group_count + flushed_idx); + } + unflushed_data_size += merge_collection.unflushed_data_size; // finally perform the actual merge collection->MergeStorage(merge_row_groups, nullptr, nullptr); - // check if all row groups have been flushed - // we cannot append into a row group that has been flushed - if (complete_row_groups == collection->GetRowGroupCount()) { - // if the last row group has been flushed move any new appends to a new row group - collection->SetRowGroupAppendMode(RowGroupAppendMode::SUGGEST_NEW); - } } void OptimisticDataWriter::FlushToDisk(OptimisticWriteCollection &collection, diff --git a/src/duckdb/src/storage/serialization/serialize_constraint.cpp 
b/src/duckdb/src/storage/serialization/serialize_constraint.cpp index 9993a59ed..251b9c274 100644 --- a/src/duckdb/src/storage/serialization/serialize_constraint.cpp +++ b/src/duckdb/src/storage/serialization/serialize_constraint.cpp @@ -48,22 +48,22 @@ unique_ptr CheckConstraint::Deserialize(Deserializer &deserializer) void ForeignKeyConstraint::Serialize(Serializer &serializer) const { Constraint::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "pk_columns", pk_columns); - serializer.WritePropertyWithDefault>(201, "fk_columns", fk_columns); + serializer.WritePropertyWithDefault>(200, "pk_columns", pk_columns); + serializer.WritePropertyWithDefault>(201, "fk_columns", fk_columns); serializer.WriteProperty(202, "fk_type", info.type); - serializer.WritePropertyWithDefault(203, "schema", info.schema); - serializer.WritePropertyWithDefault(204, "table", info.table); + serializer.WritePropertyWithDefault(203, "schema", info.schema); + serializer.WritePropertyWithDefault(204, "table", info.table); serializer.WritePropertyWithDefault>(205, "pk_keys", info.pk_keys); serializer.WritePropertyWithDefault>(206, "fk_keys", info.fk_keys); } unique_ptr ForeignKeyConstraint::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ForeignKeyConstraint()); - deserializer.ReadPropertyWithDefault>(200, "pk_columns", result->pk_columns); - deserializer.ReadPropertyWithDefault>(201, "fk_columns", result->fk_columns); + deserializer.ReadPropertyWithDefault>(200, "pk_columns", result->pk_columns); + deserializer.ReadPropertyWithDefault>(201, "fk_columns", result->fk_columns); deserializer.ReadProperty(202, "fk_type", result->info.type); - deserializer.ReadPropertyWithDefault(203, "schema", result->info.schema); - deserializer.ReadPropertyWithDefault(204, "table", result->info.table); + deserializer.ReadPropertyWithDefault(203, "schema", result->info.schema); + deserializer.ReadPropertyWithDefault(204, "table", result->info.table); 
deserializer.ReadPropertyWithDefault>(205, "pk_keys", result->info.pk_keys); deserializer.ReadPropertyWithDefault>(206, "fk_keys", result->info.fk_keys); return std::move(result); @@ -84,14 +84,14 @@ void UniqueConstraint::Serialize(Serializer &serializer) const { Constraint::Serialize(serializer); serializer.WritePropertyWithDefault(200, "is_primary_key", is_primary_key); serializer.WriteProperty(201, "index", index); - serializer.WritePropertyWithDefault>(202, "columns", columns); + serializer.WritePropertyWithDefault>(202, "columns", columns); } unique_ptr UniqueConstraint::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new UniqueConstraint()); deserializer.ReadPropertyWithDefault(200, "is_primary_key", result->is_primary_key); deserializer.ReadProperty(201, "index", result->index); - deserializer.ReadPropertyWithDefault>(202, "columns", result->columns); + deserializer.ReadPropertyWithDefault>(202, "columns", result->columns); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_create_info.cpp b/src/duckdb/src/storage/serialization/serialize_create_info.cpp index b90d121a5..13de98d50 100644 --- a/src/duckdb/src/storage/serialization/serialize_create_info.cpp +++ b/src/duckdb/src/storage/serialization/serialize_create_info.cpp @@ -19,24 +19,24 @@ namespace duckdb { void CreateInfo::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "catalog", catalog); - serializer.WritePropertyWithDefault(102, "schema", schema); + serializer.WritePropertyWithDefault(101, "catalog", catalog); + serializer.WritePropertyWithDefault(102, "schema", schema); serializer.WritePropertyWithDefault(103, "temporary", temporary); serializer.WritePropertyWithDefault(104, "internal", internal); serializer.WriteProperty(105, "on_conflict", on_conflict); serializer.WritePropertyWithDefault(106, "sql", sql); serializer.WritePropertyWithDefault(107, 
"comment", comment, Value()); serializer.WritePropertyWithDefault>(108, "tags", tags, InsertionOrderPreservingMap()); - if (serializer.ShouldSerialize(2)) { + if (serializer.ShouldSerialize(StorageVersion::V0_10_3)) { serializer.WritePropertyWithDefault(109, "dependencies", dependencies, LogicalDependencyList()); } - serializer.WritePropertyWithDefault(110, "extension_name", extension_name); + serializer.WritePropertyWithDefault(110, "extension_name", extension_name); } unique_ptr CreateInfo::Deserialize(Deserializer &deserializer) { auto type = deserializer.ReadProperty(100, "type"); - auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); - auto schema = deserializer.ReadPropertyWithDefault(102, "schema"); + auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); + auto schema = deserializer.ReadPropertyWithDefault(102, "schema"); auto temporary = deserializer.ReadPropertyWithDefault(103, "temporary"); auto internal = deserializer.ReadPropertyWithDefault(104, "internal"); auto on_conflict = deserializer.ReadProperty(105, "on_conflict"); @@ -44,7 +44,7 @@ unique_ptr CreateInfo::Deserialize(Deserializer &deserializer) { auto comment = deserializer.ReadPropertyWithExplicitDefault(107, "comment", Value()); auto tags = deserializer.ReadPropertyWithExplicitDefault>(108, "tags", InsertionOrderPreservingMap()); auto dependencies = deserializer.ReadPropertyWithExplicitDefault(109, "dependencies", LogicalDependencyList()); - auto extension_name = deserializer.ReadPropertyWithDefault(110, "extension_name"); + auto extension_name = deserializer.ReadPropertyWithDefault(110, "extension_name"); deserializer.Set(type); unique_ptr result; switch (type) { @@ -94,13 +94,13 @@ unique_ptr CreateInfo::Deserialize(Deserializer &deserializer) { void CreateIndexInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", index_name); - serializer.WritePropertyWithDefault(201, "table", 
table); + serializer.WritePropertyWithDefault(200, "name", index_name); + serializer.WritePropertyWithDefault(201, "table", table); /* [Deleted] (DeprecatedIndexType) "index_type" */ serializer.WriteProperty(203, "constraint_type", constraint_type); serializer.WritePropertyWithDefault>>(204, "parsed_expressions", parsed_expressions); serializer.WritePropertyWithDefault>(205, "scan_types", scan_types); - serializer.WritePropertyWithDefault>(206, "names", names); + serializer.WritePropertyWithDefault>(206, "names", names); serializer.WritePropertyWithDefault>(207, "column_ids", column_ids); serializer.WritePropertyWithDefault>(208, "options", options); serializer.WritePropertyWithDefault(209, "index_type_name", index_type); @@ -108,13 +108,13 @@ void CreateIndexInfo::Serialize(Serializer &serializer) const { unique_ptr CreateIndexInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CreateIndexInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->index_name); - deserializer.ReadPropertyWithDefault(201, "table", result->table); + deserializer.ReadPropertyWithDefault(200, "name", result->index_name); + deserializer.ReadPropertyWithDefault(201, "table", result->table); deserializer.ReadDeletedProperty(202, "index_type"); deserializer.ReadProperty(203, "constraint_type", result->constraint_type); deserializer.ReadPropertyWithDefault>>(204, "parsed_expressions", result->parsed_expressions); deserializer.ReadPropertyWithDefault>(205, "scan_types", result->scan_types); - deserializer.ReadPropertyWithDefault>(206, "names", result->names); + deserializer.ReadPropertyWithDefault>(206, "names", result->names); deserializer.ReadPropertyWithDefault>(207, "column_ids", result->column_ids); deserializer.ReadPropertyWithDefault>(208, "options", result->options); deserializer.ReadPropertyWithDefault(209, "index_type_name", result->index_type); @@ -123,13 +123,13 @@ unique_ptr CreateIndexInfo::Deserialize(Deserializer &deserializer) void 
CreateMacroInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WritePropertyWithDefault>(201, "function", macros[0]); serializer.WritePropertyWithDefault>>(202, "extra_functions", GetAllButFirstFunction()); } unique_ptr CreateMacroInfo::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(200, "name"); + auto name = deserializer.ReadPropertyWithDefault(200, "name"); auto function = deserializer.ReadPropertyWithDefault>(201, "function"); auto extra_functions = deserializer.ReadPropertyWithDefault>>(202, "extra_functions"); auto result = duckdb::unique_ptr(new CreateMacroInfo(deserializer.Get(), std::move(function), std::move(extra_functions))); @@ -148,21 +148,21 @@ unique_ptr CreateSchemaInfo::Deserialize(Deserializer &deserializer) void CreateSequenceInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WritePropertyWithDefault(201, "usage_count", usage_count); serializer.WritePropertyWithDefault(202, "increment", increment); serializer.WritePropertyWithDefault(203, "min_value", min_value); serializer.WritePropertyWithDefault(204, "max_value", max_value); serializer.WritePropertyWithDefault(205, "start_value", start_value); serializer.WritePropertyWithDefault(206, "cycle", cycle); - if (serializer.ShouldSerialize(8)) { + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { serializer.WritePropertyWithDefault>(207, "last_value", last_value); } } unique_ptr CreateSequenceInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CreateSequenceInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(200, "name", result->name); 
deserializer.ReadPropertyWithDefault(201, "usage_count", result->usage_count); deserializer.ReadPropertyWithDefault(202, "increment", result->increment); deserializer.ReadPropertyWithDefault(203, "min_value", result->min_value); @@ -175,7 +175,7 @@ unique_ptr CreateSequenceInfo::Deserialize(Deserializer &deserialize void CreateTableInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table", table); + serializer.WritePropertyWithDefault(200, "table", table); serializer.WriteProperty(201, "columns", columns); serializer.WritePropertyWithDefault>>(202, "constraints", constraints); serializer.WritePropertyWithDefault>(203, "query", query); @@ -186,7 +186,7 @@ void CreateTableInfo::Serialize(Serializer &serializer) const { unique_ptr CreateTableInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CreateTableInfo()); - deserializer.ReadPropertyWithDefault(200, "table", result->table); + deserializer.ReadPropertyWithDefault(200, "table", result->table); deserializer.ReadProperty(201, "columns", result->columns); deserializer.ReadPropertyWithDefault>>(202, "constraints", result->constraints); deserializer.ReadPropertyWithDefault>(203, "query", result->query); @@ -198,64 +198,68 @@ unique_ptr CreateTableInfo::Deserialize(Deserializer &deserializer) void CreateTriggerInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "trigger_name", trigger_name); + serializer.WritePropertyWithDefault(200, "trigger_name", trigger_name); serializer.WritePropertyWithDefault>(201, "base_table", base_table); serializer.WriteProperty(204, "timing", timing); serializer.WriteProperty(205, "event_type", event_type); - serializer.WritePropertyWithDefault>(206, "columns", columns); + serializer.WritePropertyWithDefault>(206, "columns", columns); serializer.WriteProperty(207, "for_each", for_each); 
serializer.WritePropertyWithDefault>(208, "trigger_action", trigger_action); + serializer.WritePropertyWithDefault(209, "referencing_new_table", referencing_new_table); + serializer.WritePropertyWithDefault(210, "referencing_old_table", referencing_old_table); } unique_ptr CreateTriggerInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CreateTriggerInfo()); - deserializer.ReadPropertyWithDefault(200, "trigger_name", result->trigger_name); + deserializer.ReadPropertyWithDefault(200, "trigger_name", result->trigger_name); auto base_table = deserializer.ReadPropertyWithDefault>(201, "base_table"); result->base_table = unique_ptr_cast(std::move(base_table)); deserializer.ReadProperty(204, "timing", result->timing); deserializer.ReadProperty(205, "event_type", result->event_type); - deserializer.ReadPropertyWithDefault>(206, "columns", result->columns); + deserializer.ReadPropertyWithDefault>(206, "columns", result->columns); deserializer.ReadProperty(207, "for_each", result->for_each); deserializer.ReadPropertyWithDefault>(208, "trigger_action", result->trigger_action); + deserializer.ReadPropertyWithDefault(209, "referencing_new_table", result->referencing_new_table); + deserializer.ReadPropertyWithDefault(210, "referencing_old_table", result->referencing_old_table); return std::move(result); } void CreateTypeInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WriteProperty(201, "logical_type", type); } unique_ptr CreateTypeInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CreateTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(200, "name", result->name); deserializer.ReadProperty(201, "logical_type", result->type); return std::move(result); } void 
CreateViewInfo::Serialize(Serializer &serializer) const { CreateInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "view_name", view_name); - serializer.WritePropertyWithDefault>(201, "aliases", aliases); + serializer.WritePropertyWithDefault(200, "view_name", view_name); + serializer.WritePropertyWithDefault>(201, "aliases", aliases); serializer.WritePropertyWithDefault>(202, "types", types); serializer.WritePropertyWithDefault>(203, "query", query); - serializer.WritePropertyWithDefault>(204, "names", names); - if (!serializer.ShouldSerialize(7)) { + serializer.WritePropertyWithDefault>(204, "names", names); + if (!serializer.ShouldSerialize(StorageVersion::V1_5_0)) { serializer.WritePropertyWithDefault>(205, "column_comments", GetColumnCommentsList()); } - if (serializer.ShouldSerialize(7)) { - serializer.WritePropertyWithDefault>(206, "column_comments_map", column_comments_map, unordered_map()); + if (serializer.ShouldSerialize(StorageVersion::V1_5_0)) { + serializer.WritePropertyWithDefault>(206, "column_comments_map", column_comments_map, identifier_map_t()); } } unique_ptr CreateViewInfo::Deserialize(Deserializer &deserializer) { - auto view_name = deserializer.ReadPropertyWithDefault(200, "view_name"); - auto aliases = deserializer.ReadPropertyWithDefault>(201, "aliases"); + auto view_name = deserializer.ReadPropertyWithDefault(200, "view_name"); + auto aliases = deserializer.ReadPropertyWithDefault>(201, "aliases"); auto types = deserializer.ReadPropertyWithDefault>(202, "types"); auto query = deserializer.ReadPropertyWithDefault>(203, "query"); - auto names = deserializer.ReadPropertyWithDefault>(204, "names"); + auto names = deserializer.ReadPropertyWithDefault>(204, "names"); auto column_comments = deserializer.ReadPropertyWithDefault>(205, "column_comments"); - auto column_comments_map = deserializer.ReadPropertyWithExplicitDefault>(206, "column_comments_map", unordered_map()); + auto column_comments_map = 
deserializer.ReadPropertyWithExplicitDefault>(206, "column_comments_map", identifier_map_t()); auto result = duckdb::unique_ptr(new CreateViewInfo(std::move(names), std::move(column_comments), std::move(column_comments_map))); result->view_name = std::move(view_name); result->aliases = std::move(aliases); diff --git a/src/duckdb/src/storage/serialization/serialize_dependency.cpp b/src/duckdb/src/storage/serialization/serialize_dependency.cpp index f265b6e6f..9d9300f26 100644 --- a/src/duckdb/src/storage/serialization/serialize_dependency.cpp +++ b/src/duckdb/src/storage/serialization/serialize_dependency.cpp @@ -12,26 +12,26 @@ namespace duckdb { void CatalogEntryInfo::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "schema", schema); - serializer.WritePropertyWithDefault(102, "name", name); + serializer.WritePropertyWithDefault(101, "schema", schema); + serializer.WritePropertyWithDefault(102, "name", name); } CatalogEntryInfo CatalogEntryInfo::Deserialize(Deserializer &deserializer) { CatalogEntryInfo result; deserializer.ReadProperty(100, "type", result.type); - deserializer.ReadPropertyWithDefault(101, "schema", result.schema); - deserializer.ReadPropertyWithDefault(102, "name", result.name); + deserializer.ReadPropertyWithDefault(101, "schema", result.schema); + deserializer.ReadPropertyWithDefault(102, "name", result.name); return result; } void LogicalDependency::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "entry", entry); - serializer.WritePropertyWithDefault(101, "catalog", catalog); + serializer.WritePropertyWithDefault(101, "catalog", catalog); } LogicalDependency LogicalDependency::Deserialize(Deserializer &deserializer) { auto entry = deserializer.ReadProperty(100, "entry"); - auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); + auto catalog = deserializer.ReadPropertyWithDefault(101, "catalog"); LogicalDependency 
result(deserializer.TryGet(), entry, std::move(catalog)); return result; } diff --git a/src/duckdb/src/storage/serialization/serialize_expression.cpp b/src/duckdb/src/storage/serialization/serialize_expression.cpp index 3e6d8694f..79cb87542 100644 --- a/src/duckdb/src/storage/serialization/serialize_expression.cpp +++ b/src/duckdb/src/storage/serialization/serialize_expression.cpp @@ -14,14 +14,14 @@ namespace duckdb { void Expression::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "expression_class", expression_class); serializer.WriteProperty(101, "type", type); - serializer.WritePropertyWithDefault(102, "alias", alias); + serializer.WritePropertyWithDefault(102, "alias", alias); serializer.WritePropertyWithDefault(103, "query_location", query_location, optional_idx()); } unique_ptr Expression::Deserialize(Deserializer &deserializer) { auto expression_class = deserializer.ReadProperty(100, "expression_class"); auto type = deserializer.ReadProperty(101, "type"); - auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); + auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); auto query_location = deserializer.ReadPropertyWithExplicitDefault(103, "query_location", optional_idx()); deserializer.Set(type); unique_ptr result; @@ -214,16 +214,16 @@ unique_ptr BoundOperatorExpression::Deserialize(Deserializer &deseri void BoundParameterExpression::Serialize(Serializer &serializer) const { Expression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "identifier", identifier); + serializer.WritePropertyWithDefault(200, "identifier", identifier); serializer.WriteProperty(201, "return_type", return_type); serializer.WritePropertyWithDefault>(202, "parameter_data", parameter_data); } unique_ptr BoundParameterExpression::Deserialize(Deserializer &deserializer) { - auto identifier = deserializer.ReadPropertyWithDefault(200, "identifier"); + auto identifier = deserializer.ReadPropertyWithDefault(200, "identifier"); 
auto return_type = deserializer.ReadProperty(201, "return_type"); auto parameter_data = deserializer.ReadPropertyWithDefault>(202, "parameter_data"); - auto result = duckdb::unique_ptr(new BoundParameterExpression(deserializer.Get(), std::move(identifier), std::move(return_type), std::move(parameter_data))); + auto result = duckdb::unique_ptr(new BoundParameterExpression(deserializer.Get(), identifier, std::move(return_type), std::move(parameter_data))); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp index bbdc03a7b..eab655f37 100644 --- a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp +++ b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp @@ -43,6 +43,9 @@ unique_ptr LogicalOperator::Deserialize(Deserializer &deseriali case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: result = LogicalComparisonJoin::Deserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_CONNECT: + result = LogicalSimple::Deserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_COPY_DATABASE: result = LogicalCopyDatabase::Deserialize(deserializer); break; @@ -91,6 +94,9 @@ unique_ptr LogicalOperator::Deserialize(Deserializer &deseriali case LogicalOperatorType::LOGICAL_DETACH: result = LogicalSimple::Deserialize(deserializer); break; + case LogicalOperatorType::LOGICAL_DISCONNECT: + result = LogicalSimple::Deserialize(deserializer); + break; case LogicalOperatorType::LOGICAL_DISTINCT: result = LogicalDistinct::Deserialize(deserializer); break; @@ -295,7 +301,7 @@ void LogicalCTERef::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(200, "table_index", table_index); serializer.WritePropertyWithDefault(201, "cte_index", cte_index); serializer.WritePropertyWithDefault>(202, "chunk_types", chunk_types); - serializer.WritePropertyWithDefault>(203, "bound_columns", 
bound_columns); + serializer.WritePropertyWithDefault>(203, "bound_columns", bound_columns); /* [Deleted] (CTEMaterialize) "materialized_cte" */ serializer.WritePropertyWithDefault(205, "is_recurring", is_recurring); } @@ -304,7 +310,7 @@ unique_ptr LogicalCTERef::Deserialize(Deserializer &deserialize auto table_index = deserializer.ReadPropertyWithDefault(200, "table_index"); auto cte_index = deserializer.ReadPropertyWithDefault(201, "cte_index"); auto chunk_types = deserializer.ReadPropertyWithDefault>(202, "chunk_types"); - auto bound_columns = deserializer.ReadPropertyWithDefault>(203, "bound_columns"); + auto bound_columns = deserializer.ReadPropertyWithDefault>(203, "bound_columns"); auto result = duckdb::unique_ptr(new LogicalCTERef(table_index, cte_index, std::move(chunk_types), std::move(bound_columns))); deserializer.ReadDeletedProperty(204, "materialized_cte"); deserializer.ReadPropertyWithDefault(205, "is_recurring", result->is_recurring); @@ -603,7 +609,7 @@ void LogicalMaterializedCTE::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); serializer.WritePropertyWithDefault(200, "table_index", table_index); serializer.WritePropertyWithDefault(201, "column_count", column_count); - serializer.WritePropertyWithDefault(202, "ctename", ctename); + serializer.WritePropertyWithDefault(202, "ctename", ctename); serializer.WritePropertyWithDefault(203, "materialize", materialize, CTEMaterialize::CTE_MATERIALIZE_DEFAULT); } @@ -611,7 +617,7 @@ unique_ptr LogicalMaterializedCTE::Deserialize(Deserializer &de auto result = duckdb::unique_ptr(new LogicalMaterializedCTE()); deserializer.ReadPropertyWithDefault(200, "table_index", result->table_index); deserializer.ReadPropertyWithDefault(201, "column_count", result->column_count); - deserializer.ReadPropertyWithDefault(202, "ctename", result->ctename); + deserializer.ReadPropertyWithDefault(202, "ctename", result->ctename); deserializer.ReadPropertyWithExplicitDefault(203, "materialize", 
result->materialize, CTEMaterialize::CTE_MATERIALIZE_DEFAULT); return std::move(result); } @@ -692,7 +698,7 @@ unique_ptr LogicalProjection::Deserialize(Deserializer &deseria void LogicalRecursiveCTE::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); serializer.WritePropertyWithDefault(200, "union_all", union_all); - serializer.WritePropertyWithDefault(201, "ctename", ctename); + serializer.WritePropertyWithDefault(201, "ctename", ctename); serializer.WritePropertyWithDefault(202, "table_index", table_index); serializer.WritePropertyWithDefault(203, "column_count", column_count); serializer.WritePropertyWithDefault>>(204, "key_targets", key_targets); @@ -704,7 +710,7 @@ void LogicalRecursiveCTE::Serialize(Serializer &serializer) const { unique_ptr LogicalRecursiveCTE::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new LogicalRecursiveCTE()); deserializer.ReadPropertyWithDefault(200, "union_all", result->union_all); - deserializer.ReadPropertyWithDefault(201, "ctename", result->ctename); + deserializer.ReadPropertyWithDefault(201, "ctename", result->ctename); deserializer.ReadPropertyWithDefault(202, "table_index", result->table_index); deserializer.ReadPropertyWithDefault(203, "column_count", result->column_count); deserializer.ReadPropertyWithDefault>>(204, "key_targets", result->key_targets); @@ -716,12 +722,12 @@ unique_ptr LogicalRecursiveCTE::Deserialize(Deserializer &deser void LogicalReset::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WriteProperty(201, "scope", scope); } unique_ptr LogicalReset::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(200, "name"); + auto name = deserializer.ReadPropertyWithDefault(200, "name"); auto scope = deserializer.ReadProperty(201, "scope"); auto result = 
duckdb::unique_ptr(new LogicalReset(std::move(name), scope)); return std::move(result); @@ -740,13 +746,13 @@ unique_ptr LogicalSample::Deserialize(Deserializer &deserialize void LogicalSet::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WriteProperty(201, "value", value); serializer.WriteProperty(202, "scope", scope); } unique_ptr LogicalSet::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(200, "name"); + auto name = deserializer.ReadPropertyWithDefault(200, "name"); auto value = deserializer.ReadProperty(201, "value"); auto scope = deserializer.ReadProperty(202, "scope"); auto result = duckdb::unique_ptr(new LogicalSet(std::move(name), value, scope)); diff --git a/src/duckdb/src/storage/serialization/serialize_macro_function.cpp b/src/duckdb/src/storage/serialization/serialize_macro_function.cpp index 39be0d59a..54117eeef 100644 --- a/src/duckdb/src/storage/serialization/serialize_macro_function.cpp +++ b/src/duckdb/src/storage/serialization/serialize_macro_function.cpp @@ -14,8 +14,8 @@ namespace duckdb { void MacroFunction::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "type", type); serializer.WritePropertyWithDefault>>(101, "parameters", GetPositionalParametersForSerialization(serializer)); - serializer.WritePropertyWithDefault>>(102, "default_parameters", default_parameters); - if (serializer.ShouldSerialize(6)) { + serializer.WritePropertyWithDefault, Identifier, identifier_map_t>>(102, "default_parameters", default_parameters); + if (serializer.ShouldSerialize(StorageVersion::V1_4_0)) { serializer.WritePropertyWithDefault>(103, "types", types, vector()); } } @@ -23,7 +23,7 @@ void MacroFunction::Serialize(Serializer &serializer) const { unique_ptr MacroFunction::Deserialize(Deserializer &deserializer) { auto type = 
deserializer.ReadProperty(100, "type"); auto parameters = deserializer.ReadPropertyWithDefault>>(101, "parameters"); - auto default_parameters = deserializer.ReadPropertyWithDefault>>(102, "default_parameters"); + auto default_parameters = deserializer.ReadPropertyWithDefault, Identifier, identifier_map_t>>(102, "default_parameters"); auto types = deserializer.ReadPropertyWithExplicitDefault>(103, "types", vector()); unique_ptr result; switch (type) { diff --git a/src/duckdb/src/storage/serialization/serialize_nodes.cpp b/src/duckdb/src/storage/serialization/serialize_nodes.cpp index b095bfd57..85d4197f4 100644 --- a/src/duckdb/src/storage/serialization/serialize_nodes.cpp +++ b/src/duckdb/src/storage/serialization/serialize_nodes.cpp @@ -11,6 +11,7 @@ #include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/parser/expression/case_expression.hpp" #include "duckdb/planner/expression/bound_case_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/execution/reservoir_sample.hpp" #include "duckdb/common/queue.hpp" @@ -195,7 +196,7 @@ ColumnBinding ColumnBinding::Deserialize(Deserializer &deserializer) { } void ColumnDefinition::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "name", name); + serializer.WritePropertyWithDefault(100, "name", name); serializer.WriteProperty(101, "type", type); serializer.WritePropertyWithDefault>(102, "expression", expression); serializer.WriteProperty(103, "category", category); @@ -205,7 +206,7 @@ void ColumnDefinition::Serialize(Serializer &serializer) const { } ColumnDefinition ColumnDefinition::Deserialize(Deserializer &deserializer) { - auto name = deserializer.ReadPropertyWithDefault(100, "name"); + auto name = deserializer.ReadPropertyWithDefault(100, "name"); auto type = deserializer.ReadProperty(101, "type"); auto expression = deserializer.ReadPropertyWithDefault>(102, "expression"); 
auto category = deserializer.ReadProperty(103, "category"); @@ -259,20 +260,20 @@ ColumnList ColumnList::Deserialize(Deserializer &deserializer) { } void CommonTableExpressionInfo::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "aliases", aliases); - if (!serializer.ShouldSerialize(8)) { + serializer.WritePropertyWithDefault>(100, "aliases", aliases); + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0)) { serializer.WritePropertyWithDefault>(101, "query", GetQueryForSerialization(serializer)); } serializer.WriteProperty(102, "materialized", GetMaterializedForSerialization(serializer)); serializer.WritePropertyWithDefault>>(103, "key_targets", key_targets); serializer.WritePropertyWithDefault>>(104, "payload_aggregates", payload_aggregates); - if (serializer.ShouldSerialize(8)) { + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { serializer.WritePropertyWithDefault>(105, "query_node", query_node); } } unique_ptr CommonTableExpressionInfo::Deserialize(Deserializer &deserializer) { - auto aliases = deserializer.ReadPropertyWithDefault>(100, "aliases"); + auto aliases = deserializer.ReadPropertyWithDefault>(100, "aliases"); auto query = deserializer.ReadPropertyWithDefault>(101, "query"); auto materialized = deserializer.ReadProperty(102, "materialized"); auto key_targets = deserializer.ReadPropertyWithDefault>>(103, "key_targets"); @@ -287,30 +288,30 @@ unique_ptr CommonTableExpressionInfo::Deserialize(Des } void CommonTableExpressionMap::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>>(100, "map", map); + serializer.WritePropertyWithDefault, Identifier, identifier_map_t>>(100, "map", map); } CommonTableExpressionMap CommonTableExpressionMap::Deserialize(Deserializer &deserializer) { CommonTableExpressionMap result; - deserializer.ReadPropertyWithDefault>>(100, "map", result.map); + deserializer.ReadPropertyWithDefault, Identifier, identifier_map_t>>(100, "map", result.map); return 
result; } void ExportedTableData::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(1, "table_name", table_name); - serializer.WritePropertyWithDefault(2, "schema_name", schema_name); - serializer.WritePropertyWithDefault(3, "database_name", database_name); + serializer.WritePropertyWithDefault(1, "table_name", table_name); + serializer.WritePropertyWithDefault(2, "schema_name", schema_name); + serializer.WritePropertyWithDefault(3, "database_name", database_name); serializer.WritePropertyWithDefault(4, "file_path", file_path); - serializer.WritePropertyWithDefault>(5, "not_null_columns", not_null_columns); + serializer.WritePropertyWithDefault>(5, "not_null_columns", not_null_columns); } ExportedTableData ExportedTableData::Deserialize(Deserializer &deserializer) { ExportedTableData result; - deserializer.ReadPropertyWithDefault(1, "table_name", result.table_name); - deserializer.ReadPropertyWithDefault(2, "schema_name", result.schema_name); - deserializer.ReadPropertyWithDefault(3, "database_name", result.database_name); + deserializer.ReadPropertyWithDefault(1, "table_name", result.table_name); + deserializer.ReadPropertyWithDefault(2, "schema_name", result.schema_name); + deserializer.ReadPropertyWithDefault(3, "database_name", result.database_name); deserializer.ReadPropertyWithDefault(4, "file_path", result.file_path); - deserializer.ReadPropertyWithDefault>(5, "not_null_columns", result.not_null_columns); + deserializer.ReadPropertyWithDefault>(5, "not_null_columns", result.not_null_columns); return result; } @@ -340,6 +341,18 @@ ExtraOperatorInfo ExtraOperatorInfo::Deserialize(Deserializer &deserializer) { return result; } +void FunctionArgument::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "name", name); + serializer.WritePropertyWithDefault>(101, "expression", expression); +} + +FunctionArgument FunctionArgument::Deserialize(Deserializer &deserializer) { + auto name = 
deserializer.ReadPropertyWithDefault(100, "name"); + auto expression = deserializer.ReadPropertyWithDefault>(101, "expression"); + FunctionArgument result(std::move(name), std::move(expression)); + return result; +} + void HivePartitioningIndex::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(100, "value", value); serializer.WritePropertyWithDefault(101, "index", index); @@ -374,6 +387,7 @@ void MultiFileOptions::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(104, "hive_types_autocast", hive_types_autocast); serializer.WritePropertyWithDefault>(105, "hive_types_schema", hive_types_schema); serializer.WritePropertyWithDefault(106, "filename_column", filename_column, MultiFileOptions::DEFAULT_FILENAME_COLUMN); + serializer.WritePropertyWithDefault(107, "allow_empty", allow_empty); } MultiFileOptions MultiFileOptions::Deserialize(Deserializer &deserializer) { @@ -385,6 +399,7 @@ MultiFileOptions MultiFileOptions::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault(104, "hive_types_autocast", result.hive_types_autocast); deserializer.ReadPropertyWithDefault>(105, "hive_types_schema", result.hive_types_schema); deserializer.ReadPropertyWithExplicitDefault(106, "filename_column", result.filename_column, MultiFileOptions::DEFAULT_FILENAME_COLUMN); + deserializer.ReadPropertyWithDefault(107, "allow_empty", result.allow_empty); return result; } @@ -416,47 +431,47 @@ OrderByNode OrderByNode::Deserialize(Deserializer &deserializer) { void PivotColumn::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>>(100, "pivot_expressions", pivot_expressions); - serializer.WritePropertyWithDefault>(101, "unpivot_names", unpivot_names); + serializer.WritePropertyWithDefault>(101, "unpivot_names", unpivot_names); serializer.WritePropertyWithDefault>(102, "entries", GetEntriesForSerialization(serializer)); - serializer.WritePropertyWithDefault(103, "pivot_enum", pivot_enum); + 
serializer.WritePropertyWithDefault(103, "pivot_enum", pivot_enum); } PivotColumn PivotColumn::Deserialize(Deserializer &deserializer) { PivotColumn result; deserializer.ReadPropertyWithDefault>>(100, "pivot_expressions", result.pivot_expressions); - deserializer.ReadPropertyWithDefault>(101, "unpivot_names", result.unpivot_names); + deserializer.ReadPropertyWithDefault>(101, "unpivot_names", result.unpivot_names); deserializer.ReadPropertyWithDefault>(102, "entries", result.entries); - deserializer.ReadPropertyWithDefault(103, "pivot_enum", result.pivot_enum); + deserializer.ReadPropertyWithDefault(103, "pivot_enum", result.pivot_enum); return result; } void PivotColumnEntry::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(100, "values", values); serializer.WritePropertyWithDefault>(101, "star_expr", expr); - serializer.WritePropertyWithDefault(102, "alias", alias); + serializer.WritePropertyWithDefault(102, "alias", alias); } PivotColumnEntry PivotColumnEntry::Deserialize(Deserializer &deserializer) { PivotColumnEntry result; deserializer.ReadPropertyWithDefault>(100, "values", result.values); deserializer.ReadPropertyWithDefault>(101, "star_expr", result.expr); - deserializer.ReadPropertyWithDefault(102, "alias", result.alias); + deserializer.ReadPropertyWithDefault(102, "alias", result.alias); return result; } void QualifiedColumnName::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "catalog", catalog); - serializer.WritePropertyWithDefault(101, "schema", schema); - serializer.WritePropertyWithDefault(102, "table", table); - serializer.WritePropertyWithDefault(103, "column", column); + serializer.WritePropertyWithDefault(100, "catalog", catalog); + serializer.WritePropertyWithDefault(101, "schema", schema); + serializer.WritePropertyWithDefault(102, "table", table); + serializer.WritePropertyWithDefault(103, "column", column); } QualifiedColumnName QualifiedColumnName::Deserialize(Deserializer 
&deserializer) { QualifiedColumnName result; - deserializer.ReadPropertyWithDefault(100, "catalog", result.catalog); - deserializer.ReadPropertyWithDefault(101, "schema", result.schema); - deserializer.ReadPropertyWithDefault(102, "table", result.table); - deserializer.ReadPropertyWithDefault(103, "column", result.column); + deserializer.ReadPropertyWithDefault(100, "catalog", result.catalog); + deserializer.ReadPropertyWithDefault(101, "schema", result.schema); + deserializer.ReadPropertyWithDefault(102, "table", result.table); + deserializer.ReadPropertyWithDefault(103, "column", result.column); return result; } @@ -567,7 +582,7 @@ void SerializedCSVReaderOptions::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(132, "rejects_scan_name", options.rejects_scan_name, {"reject_scans"}); serializer.WritePropertyWithDefault>(133, "name_list", options.name_list); serializer.WritePropertyWithDefault>(134, "sql_type_list", options.sql_type_list); - serializer.WritePropertyWithDefault>(135, "sql_types_per_column", options.sql_types_per_column); + serializer.WritePropertyWithDefault>(135, "sql_types_per_column", options.sql_types_per_column); serializer.WritePropertyWithDefault(136, "columns_set", options.columns_set, false); serializer.WritePropertyWithDefault>(137, "comment", options.dialect_options.state_machine_options.comment, CSVOption('\0')); serializer.WritePropertyWithDefault(138, "rows_until_header", options.dialect_options.rows_until_header); @@ -576,6 +591,7 @@ void SerializedCSVReaderOptions::Serialize(Serializer &serializer) const { serializer.WriteProperty>(141, "multi_byte_delimiter", options.GetMultiByteDelimiter()); serializer.WritePropertyWithDefault(142, "multi_file_reader", options.multi_file_reader); serializer.WriteProperty>(143, "buffer_size_option", options.buffer_size_option); + serializer.WritePropertyWithDefault(144, "rejects_line_size_limit", options.rejects_line_size_limit, 10000); } SerializedCSVReaderOptions 
SerializedCSVReaderOptions::Deserialize(Deserializer &deserializer) { @@ -614,7 +630,7 @@ SerializedCSVReaderOptions SerializedCSVReaderOptions::Deserialize(Deserializer auto options_rejects_scan_name = deserializer.ReadPropertyWithExplicitDefault>(132, "rejects_scan_name", {"reject_scans"}); auto options_name_list = deserializer.ReadPropertyWithDefault>(133, "name_list"); auto options_sql_type_list = deserializer.ReadPropertyWithDefault>(134, "sql_type_list"); - auto options_sql_types_per_column = deserializer.ReadPropertyWithDefault>(135, "sql_types_per_column"); + auto options_sql_types_per_column = deserializer.ReadPropertyWithDefault>(135, "sql_types_per_column"); auto options_columns_set = deserializer.ReadPropertyWithExplicitDefault(136, "columns_set", false); auto options_dialect_options_state_machine_options_comment = deserializer.ReadPropertyWithExplicitDefault>(137, "comment", CSVOption('\0')); auto options_dialect_options_rows_until_header = deserializer.ReadPropertyWithDefault(138, "rows_until_header"); @@ -661,6 +677,7 @@ SerializedCSVReaderOptions SerializedCSVReaderOptions::Deserialize(Deserializer result.options.dialect_options.state_machine_options.strict_mode = options_dialect_options_state_machine_options_strict_mode; deserializer.ReadPropertyWithDefault(142, "multi_file_reader", result.options.multi_file_reader); deserializer.ReadProperty>(143, "buffer_size_option", result.options.buffer_size_option); + deserializer.ReadPropertyWithExplicitDefault(144, "rejects_line_size_limit", result.options.rejects_line_size_limit, 10000); return result; } @@ -721,12 +738,12 @@ StrpTimeFormat StrpTimeFormat::Deserialize(Deserializer &deserializer) { } void TableColumn::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "name", name); + serializer.WritePropertyWithDefault(100, "name", name); serializer.WriteProperty(101, "type", type); } TableColumn TableColumn::Deserialize(Deserializer &deserializer) { - auto name = 
deserializer.ReadPropertyWithDefault(100, "name"); + auto name = deserializer.ReadPropertyWithDefault(100, "name"); auto type = deserializer.ReadProperty(101, "type"); TableColumn result(std::move(name), std::move(type)); return result; diff --git a/src/duckdb/src/storage/serialization/serialize_parse_info.cpp b/src/duckdb/src/storage/serialization/serialize_parse_info.cpp index a0f5edd30..4cc985893 100644 --- a/src/duckdb/src/storage/serialization/serialize_parse_info.cpp +++ b/src/duckdb/src/storage/serialization/serialize_parse_info.cpp @@ -14,6 +14,8 @@ #include "duckdb/parser/parsed_data/copy_database_info.hpp" #include "duckdb/parser/parsed_data/copy_info.hpp" #include "duckdb/parser/parsed_data/detach_info.hpp" +#include "duckdb/parser/parsed_data/connect_info.hpp" +#include "duckdb/parser/parsed_data/disconnect_info.hpp" #include "duckdb/parser/parsed_data/drop_info.hpp" #include "duckdb/parser/parsed_data/load_info.hpp" #include "duckdb/parser/parsed_data/update_extensions_info.hpp" @@ -41,6 +43,9 @@ unique_ptr ParseInfo::Deserialize(Deserializer &deserializer) { case ParseInfoType::BOUND_EXPORT_DATA: result = BoundExportData::Deserialize(deserializer); break; + case ParseInfoType::CONNECT_INFO: + result = ConnectInfo::Deserialize(deserializer); + break; case ParseInfoType::COPY_DATABASE_INFO: result = CopyDatabaseInfo::Deserialize(deserializer); break; @@ -50,6 +55,9 @@ unique_ptr ParseInfo::Deserialize(Deserializer &deserializer) { case ParseInfoType::DETACH_INFO: result = DetachInfo::Deserialize(deserializer); break; + case ParseInfoType::DISCONNECT_INFO: + result = DisconnectInfo::Deserialize(deserializer); + break; case ParseInfoType::DROP_INFO: result = DropInfo::Deserialize(deserializer); break; @@ -77,18 +85,18 @@ unique_ptr ParseInfo::Deserialize(Deserializer &deserializer) { void AlterInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); serializer.WriteProperty(200, "type", type); - 
serializer.WritePropertyWithDefault(201, "catalog", catalog); - serializer.WritePropertyWithDefault(202, "schema", schema); - serializer.WritePropertyWithDefault(203, "name", name); + serializer.WritePropertyWithDefault(201, "catalog", catalog); + serializer.WritePropertyWithDefault(202, "schema", schema); + serializer.WritePropertyWithDefault(203, "name", name); serializer.WriteProperty(204, "if_not_found", if_not_found); serializer.WritePropertyWithDefault(205, "allow_internal", allow_internal); } unique_ptr AlterInfo::Deserialize(Deserializer &deserializer) { auto type = deserializer.ReadProperty(200, "type"); - auto catalog = deserializer.ReadPropertyWithDefault(201, "catalog"); - auto schema = deserializer.ReadPropertyWithDefault(202, "schema"); - auto name = deserializer.ReadPropertyWithDefault(203, "name"); + auto catalog = deserializer.ReadPropertyWithDefault(201, "catalog"); + auto schema = deserializer.ReadPropertyWithDefault(202, "schema"); + auto name = deserializer.ReadPropertyWithDefault(203, "name"); auto if_not_found = deserializer.ReadProperty(204, "if_not_found"); auto allow_internal = deserializer.ReadPropertyWithDefault(205, "allow_internal"); unique_ptr result; @@ -252,22 +260,22 @@ void AddFieldInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); serializer.WriteProperty(400, "new_field", new_field); serializer.WritePropertyWithDefault(401, "if_field_not_exists", if_field_not_exists); - serializer.WritePropertyWithDefault>(402, "column_path", column_path); + serializer.WritePropertyWithDefault>(402, "column_path", column_path); } unique_ptr AddFieldInfo::Deserialize(Deserializer &deserializer) { auto new_field = deserializer.ReadProperty(400, "new_field"); auto result = duckdb::unique_ptr(new AddFieldInfo(std::move(new_field))); deserializer.ReadPropertyWithDefault(401, "if_field_not_exists", result->if_field_not_exists); - deserializer.ReadPropertyWithDefault>(402, "column_path", result->column_path); + 
deserializer.ReadPropertyWithDefault>(402, "column_path", result->column_path); return std::move(result); } void AlterForeignKeyInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "fk_table", fk_table); - serializer.WritePropertyWithDefault>(401, "pk_columns", pk_columns); - serializer.WritePropertyWithDefault>(402, "fk_columns", fk_columns); + serializer.WritePropertyWithDefault(400, "fk_table", fk_table); + serializer.WritePropertyWithDefault>(401, "pk_columns", pk_columns); + serializer.WritePropertyWithDefault>(402, "fk_columns", fk_columns); serializer.WritePropertyWithDefault>(403, "pk_keys", pk_keys); serializer.WritePropertyWithDefault>(404, "fk_keys", fk_keys); serializer.WriteProperty(405, "alter_fk_type", type); @@ -275,9 +283,9 @@ void AlterForeignKeyInfo::Serialize(Serializer &serializer) const { unique_ptr AlterForeignKeyInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new AlterForeignKeyInfo()); - deserializer.ReadPropertyWithDefault(400, "fk_table", result->fk_table); - deserializer.ReadPropertyWithDefault>(401, "pk_columns", result->pk_columns); - deserializer.ReadPropertyWithDefault>(402, "fk_columns", result->fk_columns); + deserializer.ReadPropertyWithDefault(400, "fk_table", result->fk_table); + deserializer.ReadPropertyWithDefault>(401, "pk_columns", result->pk_columns); + deserializer.ReadPropertyWithDefault>(402, "fk_columns", result->fk_columns); deserializer.ReadPropertyWithDefault>(403, "pk_keys", result->pk_keys); deserializer.ReadPropertyWithDefault>(404, "fk_keys", result->fk_keys); deserializer.ReadProperty(405, "alter_fk_type", result->type); @@ -286,7 +294,7 @@ unique_ptr AlterForeignKeyInfo::Deserialize(Deserializer &deseri void AttachInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", 
name); serializer.WritePropertyWithDefault(201, "path", path); serializer.WritePropertyWithDefault>(202, "options", options); serializer.WritePropertyWithDefault(203, "on_conflict", on_conflict, OnCreateConflict::ERROR_ON_CONFLICT); @@ -294,7 +302,7 @@ void AttachInfo::Serialize(Serializer &serializer) const { unique_ptr AttachInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new AttachInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(200, "name", result->name); deserializer.ReadPropertyWithDefault(201, "path", result->path); deserializer.ReadPropertyWithDefault>(202, "options", result->options); deserializer.ReadPropertyWithExplicitDefault(203, "on_conflict", result->on_conflict, OnCreateConflict::ERROR_ON_CONFLICT); @@ -314,14 +322,14 @@ unique_ptr BoundExportData::Deserialize(Deserializer &deserializer) { void ChangeColumnTypeInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WritePropertyWithDefault(400, "column_name", column_name); serializer.WriteProperty(401, "target_type", target_type); serializer.WritePropertyWithDefault>(402, "expression", expression); } unique_ptr ChangeColumnTypeInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ChangeColumnTypeInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); deserializer.ReadProperty(401, "target_type", result->target_type); deserializer.ReadPropertyWithDefault>(402, "expression", result->expression); return std::move(result); @@ -330,37 +338,52 @@ unique_ptr ChangeColumnTypeInfo::Deserialize(Deserializer &deser void ChangeOwnershipInfo::Serialize(Serializer &serializer) const { AlterInfo::Serialize(serializer); serializer.WriteProperty(300, 
"entry_catalog_type", entry_catalog_type); - serializer.WritePropertyWithDefault(301, "owner_schema", owner_schema); - serializer.WritePropertyWithDefault(302, "owner_name", owner_name); + serializer.WritePropertyWithDefault(301, "owner_schema", owner_schema); + serializer.WritePropertyWithDefault(302, "owner_name", owner_name); } unique_ptr ChangeOwnershipInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ChangeOwnershipInfo()); deserializer.ReadProperty(300, "entry_catalog_type", result->entry_catalog_type); - deserializer.ReadPropertyWithDefault(301, "owner_schema", result->owner_schema); - deserializer.ReadPropertyWithDefault(302, "owner_name", result->owner_name); + deserializer.ReadPropertyWithDefault(301, "owner_schema", result->owner_schema); + deserializer.ReadPropertyWithDefault(302, "owner_name", result->owner_name); + return std::move(result); +} + +void ConnectInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(201, "name_is_string_literal", name_is_string_literal); + serializer.WritePropertyWithDefault(202, "target_is_local", target_is_local); +} + +unique_ptr ConnectInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new ConnectInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(201, "name_is_string_literal", result->name_is_string_literal); + deserializer.ReadPropertyWithDefault(202, "target_is_local", result->target_is_local); return std::move(result); } void CopyDatabaseInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "target_database", target_database); + serializer.WritePropertyWithDefault(200, "target_database", target_database); serializer.WritePropertyWithDefault>>(201, "entries", entries); } unique_ptr 
CopyDatabaseInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CopyDatabaseInfo()); - deserializer.ReadPropertyWithDefault(200, "target_database", result->target_database); + deserializer.ReadPropertyWithDefault(200, "target_database", result->target_database); deserializer.ReadPropertyWithDefault>>(201, "entries", result->entries); return std::move(result); } void CopyInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "catalog", catalog); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault(202, "table", table); - serializer.WritePropertyWithDefault>(203, "select_list", select_list); + serializer.WritePropertyWithDefault(200, "catalog", catalog); + serializer.WritePropertyWithDefault(201, "schema", schema); + serializer.WritePropertyWithDefault(202, "table", table); + serializer.WritePropertyWithDefault>(203, "select_list", select_list); serializer.WritePropertyWithDefault(204, "is_from", is_from); serializer.WritePropertyWithDefault(205, "format", format); serializer.WritePropertyWithDefault(206, "file_path", file_path); @@ -371,10 +394,10 @@ void CopyInfo::Serialize(Serializer &serializer) const { unique_ptr CopyInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CopyInfo()); - deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault(202, "table", result->table); - deserializer.ReadPropertyWithDefault>(203, "select_list", result->select_list); + deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault(202, "table", result->table); + deserializer.ReadPropertyWithDefault>(203, "select_list", result->select_list); 
deserializer.ReadPropertyWithDefault(204, "is_from", result->is_from); deserializer.ReadPropertyWithDefault(205, "format", result->format); deserializer.ReadPropertyWithDefault(206, "file_path", result->file_path); @@ -386,23 +409,32 @@ unique_ptr CopyInfo::Deserialize(Deserializer &deserializer) { void DetachInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WriteProperty(201, "if_not_found", if_not_found); } unique_ptr DetachInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new DetachInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(200, "name", result->name); deserializer.ReadProperty(201, "if_not_found", result->if_not_found); return std::move(result); } +void DisconnectInfo::Serialize(Serializer &serializer) const { + ParseInfo::Serialize(serializer); +} + +unique_ptr DisconnectInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new DisconnectInfo()); + return std::move(result); +} + void DropInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); serializer.WriteProperty(200, "type", type); - serializer.WritePropertyWithDefault(201, "catalog", catalog); - serializer.WritePropertyWithDefault(202, "schema", schema); - serializer.WritePropertyWithDefault(203, "name", name); + serializer.WritePropertyWithDefault(201, "catalog", catalog); + serializer.WritePropertyWithDefault(202, "schema", schema); + serializer.WritePropertyWithDefault(203, "name", name); serializer.WriteProperty(204, "if_not_found", if_not_found); serializer.WritePropertyWithDefault(205, "cascade", cascade); serializer.WritePropertyWithDefault(206, "allow_drop_internal", allow_drop_internal); @@ -412,9 +444,9 @@ void DropInfo::Serialize(Serializer &serializer) const { unique_ptr 
DropInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new DropInfo()); deserializer.ReadProperty(200, "type", result->type); - deserializer.ReadPropertyWithDefault(201, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault(202, "schema", result->schema); - deserializer.ReadPropertyWithDefault(203, "name", result->name); + deserializer.ReadPropertyWithDefault(201, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault(202, "schema", result->schema); + deserializer.ReadPropertyWithDefault(203, "name", result->name); deserializer.ReadProperty(204, "if_not_found", result->if_not_found); deserializer.ReadPropertyWithDefault(205, "cascade", result->cascade); deserializer.ReadPropertyWithDefault(206, "allow_drop_internal", result->allow_drop_internal); @@ -424,12 +456,12 @@ unique_ptr DropInfo::Deserialize(Deserializer &deserializer) { void DropNotNullInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WritePropertyWithDefault(400, "column_name", column_name); } unique_ptr DropNotNullInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new DropNotNullInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); return std::move(result); } @@ -440,6 +472,7 @@ void LoadInfo::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(202, "repository", repository); serializer.WritePropertyWithDefault(203, "version", version); serializer.WritePropertyWithDefault(204, "repo_is_alias", repo_is_alias); + serializer.WritePropertyWithDefault(205, "alias", alias); } unique_ptr LoadInfo::Deserialize(Deserializer &deserializer) { @@ -449,19 +482,20 @@ unique_ptr LoadInfo::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault(202, 
"repository", result->repository); deserializer.ReadPropertyWithDefault(203, "version", result->version); deserializer.ReadPropertyWithDefault(204, "repo_is_alias", result->repo_is_alias); + deserializer.ReadPropertyWithDefault(205, "alias", result->alias); return std::move(result); } void PragmaInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "name", name); + serializer.WritePropertyWithDefault(200, "name", name); serializer.WritePropertyWithDefault>>(201, "parameters", parameters); serializer.WritePropertyWithDefault>>(202, "named_parameters", named_parameters); } unique_ptr PragmaInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new PragmaInfo()); - deserializer.ReadPropertyWithDefault(200, "name", result->name); + deserializer.ReadPropertyWithDefault(200, "name", result->name); deserializer.ReadPropertyWithDefault>>(201, "parameters", result->parameters); deserializer.ReadPropertyWithDefault>>(202, "named_parameters", result->named_parameters); return std::move(result); @@ -469,14 +503,14 @@ unique_ptr PragmaInfo::Deserialize(Deserializer &deserializer) { void RemoveColumnInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "removed_column", removed_column); + serializer.WritePropertyWithDefault(400, "removed_column", removed_column); serializer.WritePropertyWithDefault(401, "if_column_exists", if_column_exists); serializer.WritePropertyWithDefault(402, "cascade", cascade); } unique_ptr RemoveColumnInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RemoveColumnInfo()); - deserializer.ReadPropertyWithDefault(400, "removed_column", result->removed_column); + deserializer.ReadPropertyWithDefault(400, "removed_column", result->removed_column); deserializer.ReadPropertyWithDefault(401, "if_column_exists", result->if_column_exists); 
deserializer.ReadPropertyWithDefault(402, "cascade", result->cascade); return std::move(result); @@ -484,14 +518,14 @@ unique_ptr RemoveColumnInfo::Deserialize(Deserializer &deseriali void RemoveFieldInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault>(400, "column_path", column_path); + serializer.WritePropertyWithDefault>(400, "column_path", column_path); serializer.WritePropertyWithDefault(401, "if_column_exists", if_column_exists); serializer.WritePropertyWithDefault(402, "cascade", cascade); } unique_ptr RemoveFieldInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RemoveFieldInfo()); - deserializer.ReadPropertyWithDefault>(400, "column_path", result->column_path); + deserializer.ReadPropertyWithDefault>(400, "column_path", result->column_path); deserializer.ReadPropertyWithDefault(401, "if_column_exists", result->if_column_exists); deserializer.ReadPropertyWithDefault(402, "cascade", result->cascade); return std::move(result); @@ -499,71 +533,75 @@ unique_ptr RemoveFieldInfo::Deserialize(Deserializer &deserializ void RenameColumnInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "old_name", old_name); - serializer.WritePropertyWithDefault(401, "new_name", new_name); + serializer.WritePropertyWithDefault(400, "old_name", old_name); + serializer.WritePropertyWithDefault(401, "new_name", new_name); } unique_ptr RenameColumnInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RenameColumnInfo()); - deserializer.ReadPropertyWithDefault(400, "old_name", result->old_name); - deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); + deserializer.ReadPropertyWithDefault(400, "old_name", result->old_name); + deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); return std::move(result); } void 
RenameDatabaseInfo::Serialize(Serializer &serializer) const { AlterDatabaseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "new_name", new_name); + serializer.WritePropertyWithDefault(400, "new_name", new_name); } unique_ptr RenameDatabaseInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RenameDatabaseInfo()); - deserializer.ReadPropertyWithDefault(400, "new_name", result->new_name); + deserializer.ReadPropertyWithDefault(400, "new_name", result->new_name); return std::move(result); } void RenameFieldInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault>(400, "column_path", column_path); - serializer.WritePropertyWithDefault(401, "new_name", new_name); + serializer.WritePropertyWithDefault>(400, "column_path", column_path); + serializer.WritePropertyWithDefault(401, "new_name", new_name); } unique_ptr RenameFieldInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RenameFieldInfo()); - deserializer.ReadPropertyWithDefault>(400, "column_path", result->column_path); - deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); + deserializer.ReadPropertyWithDefault>(400, "column_path", result->column_path); + deserializer.ReadPropertyWithDefault(401, "new_name", result->new_name); return std::move(result); } void RenameTableInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "new_table_name", new_table_name); + serializer.WritePropertyWithDefault(400, "new_table_name", new_table_name); } unique_ptr RenameTableInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RenameTableInfo()); - deserializer.ReadPropertyWithDefault(400, "new_table_name", result->new_table_name); + deserializer.ReadPropertyWithDefault(400, "new_table_name", result->new_table_name); return std::move(result); } void 
RenameViewInfo::Serialize(Serializer &serializer) const { AlterViewInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "new_view_name", new_view_name); + serializer.WritePropertyWithDefault(400, "new_view_name", new_view_name); } unique_ptr RenameViewInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RenameViewInfo()); - deserializer.ReadPropertyWithDefault(400, "new_view_name", result->new_view_name); + deserializer.ReadPropertyWithDefault(400, "new_view_name", result->new_view_name); return std::move(result); } void ResetTableOptionsInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WriteProperty(400, "table_options", table_options); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault(400, "table_options", table_options); + } else { + serializer.WriteProperty(400, "table_options", table_options); + } } unique_ptr ResetTableOptionsInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ResetTableOptionsInfo()); - deserializer.ReadProperty(400, "table_options", result->table_options); + deserializer.ReadPropertyWithDefault(400, "table_options", result->table_options); return std::move(result); } @@ -571,14 +609,14 @@ void SetColumnCommentInfo::Serialize(Serializer &serializer) const { AlterInfo::Serialize(serializer); serializer.WriteProperty(300, "catalog_entry_type", catalog_entry_type); serializer.WriteProperty(301, "comment_value", comment_value); - serializer.WritePropertyWithDefault(302, "column_name", column_name); + serializer.WritePropertyWithDefault(302, "column_name", column_name); } unique_ptr SetColumnCommentInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new SetColumnCommentInfo()); deserializer.ReadProperty(300, "catalog_entry_type", result->catalog_entry_type); deserializer.ReadProperty(301, "comment_value", result->comment_value); - 
deserializer.ReadPropertyWithDefault(302, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault(302, "column_name", result->column_name); return std::move(result); } @@ -597,25 +635,25 @@ unique_ptr SetCommentInfo::Deserialize(Deserializer &deserializer) { void SetDefaultInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WritePropertyWithDefault(400, "column_name", column_name); serializer.WritePropertyWithDefault>(401, "expression", expression); } unique_ptr SetDefaultInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new SetDefaultInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); deserializer.ReadPropertyWithDefault>(401, "expression", result->expression); return std::move(result); } void SetNotNullInfo::Serialize(Serializer &serializer) const { AlterTableInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(400, "column_name", column_name); + serializer.WritePropertyWithDefault(400, "column_name", column_name); } unique_ptr SetNotNullInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new SetNotNullInfo()); - deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); + deserializer.ReadPropertyWithDefault(400, "column_name", result->column_name); return std::move(result); } @@ -671,12 +709,12 @@ unique_ptr TransactionInfo::Deserialize(Deserializer &deserializer) { void UpdateExtensionsInfo::Serialize(Serializer &serializer) const { ParseInfo::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "extensions_to_update", extensions_to_update); + serializer.WritePropertyWithDefault>(200, "extensions_to_update", extensions_to_update); } unique_ptr UpdateExtensionsInfo::Deserialize(Deserializer 
&deserializer) { auto result = duckdb::unique_ptr(new UpdateExtensionsInfo()); - deserializer.ReadPropertyWithDefault>(200, "extensions_to_update", result->extensions_to_update); + deserializer.ReadPropertyWithDefault>(200, "extensions_to_update", result->extensions_to_update); return std::move(result); } @@ -685,7 +723,7 @@ void VacuumInfo::Serialize(Serializer &serializer) const { serializer.WriteProperty(200, "options", options); serializer.WritePropertyWithDefault(201, "has_table", has_table); serializer.WritePropertyWithDefault>(202, "ref", ref); - serializer.WritePropertyWithDefault>(203, "columns", columns); + serializer.WritePropertyWithDefault>(203, "columns", columns); } unique_ptr VacuumInfo::Deserialize(Deserializer &deserializer) { @@ -693,7 +731,7 @@ unique_ptr VacuumInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new VacuumInfo(options)); deserializer.ReadPropertyWithDefault(201, "has_table", result->has_table); deserializer.ReadPropertyWithDefault>(202, "ref", result->ref); - deserializer.ReadPropertyWithDefault>(203, "columns", result->columns); + deserializer.ReadPropertyWithDefault>(203, "columns", result->columns); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp b/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp index 1ddbe07a7..167775b1b 100644 --- a/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp +++ b/src/duckdb/src/storage/serialization/serialize_parsed_expression.cpp @@ -12,14 +12,14 @@ namespace duckdb { void ParsedExpression::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "class", expression_class); serializer.WriteProperty(101, "type", type); - serializer.WritePropertyWithDefault(102, "alias", alias); + serializer.WritePropertyWithDefault(102, "alias", alias); serializer.WritePropertyWithDefault(103, "query_location", query_location, optional_idx()); } unique_ptr 
ParsedExpression::Deserialize(Deserializer &deserializer) { auto expression_class = deserializer.ReadProperty(100, "class"); auto type = deserializer.ReadProperty(101, "type"); - auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); + auto alias = deserializer.ReadPropertyWithDefault(102, "alias"); auto query_location = deserializer.ReadPropertyWithExplicitDefault(103, "query_location", optional_idx()); deserializer.Set(type); unique_ptr result; @@ -148,12 +148,12 @@ unique_ptr CollateExpression::Deserialize(Deserializer &deseri void ColumnRefExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "column_names", column_names); + serializer.WritePropertyWithDefault>(200, "column_names", column_names); } unique_ptr ColumnRefExpression::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ColumnRefExpression()); - deserializer.ReadPropertyWithDefault>(200, "column_names", result->column_names); + deserializer.ReadPropertyWithDefault>(200, "column_names", result->column_names); return std::move(result); } @@ -201,39 +201,11 @@ unique_ptr DefaultExpression::Deserialize(Deserializer &deseri return std::move(result); } -void FunctionExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", function_name); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault>>(202, "children", children); - serializer.WritePropertyWithDefault>(203, "filter", filter); - serializer.WritePropertyWithDefault>(204, "order_bys", order_bys); - serializer.WritePropertyWithDefault(205, "distinct", distinct); - serializer.WritePropertyWithDefault(206, "is_operator", is_operator); - serializer.WritePropertyWithDefault(207, "export_state", export_state); - serializer.WritePropertyWithDefault(208, "catalog", catalog); -} - -unique_ptr 
FunctionExpression::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new FunctionExpression()); - deserializer.ReadPropertyWithDefault(200, "function_name", result->function_name); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault>>(202, "children", result->children); - deserializer.ReadPropertyWithDefault>(203, "filter", result->filter); - auto order_bys = deserializer.ReadPropertyWithDefault>(204, "order_bys"); - result->order_bys = unique_ptr_cast(std::move(order_bys)); - deserializer.ReadPropertyWithDefault(205, "distinct", result->distinct); - deserializer.ReadPropertyWithDefault(206, "is_operator", result->is_operator); - deserializer.ReadPropertyWithDefault(207, "export_state", result->export_state); - deserializer.ReadPropertyWithDefault(208, "catalog", result->catalog); - return std::move(result); -} - void LambdaExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "lhs", lhs); serializer.WritePropertyWithDefault>(201, "expr", expr); - if (serializer.ShouldSerialize(5)) { + if (serializer.ShouldSerialize(StorageVersion::V1_3_0)) { serializer.WritePropertyWithDefault(202, "syntax_type", syntax_type, LambdaSyntaxType::SINGLE_ARROW_STORAGE); } } @@ -249,12 +221,12 @@ unique_ptr LambdaExpression::Deserialize(Deserializer &deseria void LambdaRefExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); serializer.WritePropertyWithDefault(200, "lambda_idx", lambda_idx); - serializer.WritePropertyWithDefault(201, "column_name", column_name); + serializer.WritePropertyWithDefault(201, "column_name", column_name); } unique_ptr LambdaRefExpression::Deserialize(Deserializer &deserializer) { auto lambda_idx = deserializer.ReadPropertyWithDefault(200, "lambda_idx"); - auto column_name = deserializer.ReadPropertyWithDefault(201, "column_name"); + auto column_name = 
deserializer.ReadPropertyWithDefault(201, "column_name"); auto result = duckdb::unique_ptr(new LambdaRefExpression(lambda_idx, std::move(column_name))); return std::move(result); } @@ -272,12 +244,12 @@ unique_ptr OperatorExpression::Deserialize(Deserializer &deser void ParameterExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "identifier", identifier); + serializer.WritePropertyWithDefault(200, "identifier", identifier); } unique_ptr ParameterExpression::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ParameterExpression()); - deserializer.ReadPropertyWithDefault(200, "identifier", result->identifier); + deserializer.ReadPropertyWithDefault(200, "identifier", result->identifier); return std::move(result); } @@ -294,25 +266,29 @@ unique_ptr PositionalReferenceExpression::Deserialize(Deserial void StarExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "relation_name", relation_name); - serializer.WriteProperty(201, "exclude_list", SerializedExcludeList()); - serializer.WritePropertyWithDefault>>(202, "replace_list", replace_list); + serializer.WritePropertyWithDefault(200, "relation_name", relation_name); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault(201, "exclude_list", SerializedExcludeList()); + } else { + serializer.WriteProperty(201, "exclude_list", SerializedExcludeList()); + } + serializer.WritePropertyWithDefault>>(202, "replace_list", replace_list); serializer.WritePropertyWithDefault(203, "columns", columns); serializer.WritePropertyWithDefault>(204, "expr", expr); /* [Deleted] (bool) "unpacked" */ serializer.WritePropertyWithDefault(206, "qualified_exclude_list", SerializedQualifiedExcludeList(), qualified_column_set_t()); - serializer.WritePropertyWithDefault>(207, "rename_list", rename_list, 
qualified_column_map_t()); + serializer.WritePropertyWithDefault>(207, "rename_list", rename_list, qualified_column_map_t()); } unique_ptr StarExpression::Deserialize(Deserializer &deserializer) { - auto relation_name = deserializer.ReadPropertyWithDefault(200, "relation_name"); - auto exclude_list = deserializer.ReadProperty(201, "exclude_list"); - auto replace_list = deserializer.ReadPropertyWithDefault>>(202, "replace_list"); + auto relation_name = deserializer.ReadPropertyWithDefault(200, "relation_name"); + auto exclude_list = deserializer.ReadPropertyWithDefault(201, "exclude_list"); + auto replace_list = deserializer.ReadPropertyWithDefault>>(202, "replace_list"); auto columns = deserializer.ReadPropertyWithDefault(203, "columns"); auto expr = deserializer.ReadPropertyWithDefault>(204, "expr"); auto unpacked = deserializer.ReadPropertyWithExplicitDefault(205, "unpacked", false); auto qualified_exclude_list = deserializer.ReadPropertyWithExplicitDefault(206, "qualified_exclude_list", qualified_column_set_t()); - auto rename_list = deserializer.ReadPropertyWithExplicitDefault>(207, "rename_list", qualified_column_map_t()); + auto rename_list = deserializer.ReadPropertyWithExplicitDefault>(207, "rename_list", qualified_column_map_t()); auto result = StarExpression::DeserializeStarExpression(std::move(relation_name), exclude_list, std::move(replace_list), columns, std::move(expr), unpacked, qualified_exclude_list, std::move(rename_list)); return result; } @@ -336,75 +312,19 @@ unique_ptr SubqueryExpression::Deserialize(Deserializer &deser void TypeExpression::Serialize(Serializer &serializer) const { ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "catalog", catalog); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault(202, "type_name", type_name); + serializer.WritePropertyWithDefault(200, "catalog", catalog); + serializer.WritePropertyWithDefault(201, "schema", schema); + 
serializer.WritePropertyWithDefault(202, "type_name", type_name); serializer.WritePropertyWithDefault>>(203, "children", children); } unique_ptr TypeExpression::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new TypeExpression()); - deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); - deserializer.ReadPropertyWithDefault(201, "schema", result->schema); - deserializer.ReadPropertyWithDefault(202, "type_name", result->type_name); + deserializer.ReadPropertyWithDefault(200, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault(201, "schema", result->schema); + deserializer.ReadPropertyWithDefault(202, "type_name", result->type_name); deserializer.ReadPropertyWithDefault>>(203, "children", result->children); return std::move(result); } -void WindowExpression::Serialize(Serializer &serializer) const { - ParsedExpression::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", function_name); - serializer.WritePropertyWithDefault(201, "schema", schema); - serializer.WritePropertyWithDefault(202, "catalog", catalog); - serializer.WritePropertyWithDefault>>(203, "children", SerializedChildren(serializer)); - serializer.WritePropertyWithDefault>>(204, "partitions", partitions); - serializer.WritePropertyWithDefault>(205, "orders", orders); - serializer.WriteProperty(206, "start", start); - serializer.WriteProperty(207, "end", end); - serializer.WritePropertyWithDefault>(208, "start_expr", start_expr); - serializer.WritePropertyWithDefault>(209, "end_expr", end_expr); - serializer.WritePropertyWithDefault>(210, "offset_expr", SerializedOffset(serializer)); - serializer.WritePropertyWithDefault>(211, "default_expr", SerializedDefault(serializer)); - serializer.WritePropertyWithDefault(212, "ignore_nulls", ignore_nulls); - serializer.WritePropertyWithDefault>(213, "filter_expr", filter_expr); - serializer.WritePropertyWithDefault(214, "exclude_clause", exclude_clause, 
WindowExcludeMode::NO_OTHER); - serializer.WritePropertyWithDefault(215, "distinct", distinct); - serializer.WritePropertyWithDefault>(216, "arg_orders", arg_orders); - if (serializer.ShouldSerialize(8)) { - serializer.WritePropertyWithDefault(217, "has_ignore_nulls", has_ignore_nulls); - } -} - -unique_ptr WindowExpression::Deserialize(Deserializer &deserializer) { - auto function_name = deserializer.ReadPropertyWithDefault(200, "function_name"); - auto schema = deserializer.ReadPropertyWithDefault(201, "schema"); - auto catalog = deserializer.ReadPropertyWithDefault(202, "catalog"); - auto children = deserializer.ReadPropertyWithDefault>>(203, "children"); - auto partitions = deserializer.ReadPropertyWithDefault>>(204, "partitions"); - auto orders = deserializer.ReadPropertyWithDefault>(205, "orders"); - auto start = deserializer.ReadProperty(206, "start"); - auto end = deserializer.ReadProperty(207, "end"); - auto start_expr = deserializer.ReadPropertyWithDefault>(208, "start_expr"); - auto end_expr = deserializer.ReadPropertyWithDefault>(209, "end_expr"); - auto offset_expr = deserializer.ReadPropertyWithDefault>(210, "offset_expr"); - auto default_expr = deserializer.ReadPropertyWithDefault>(211, "default_expr"); - auto result = duckdb::unique_ptr(new WindowExpression(deserializer.Get(), std::move(children), std::move(offset_expr), std::move(default_expr))); - result->function_name = std::move(function_name); - result->schema = std::move(schema); - result->catalog = std::move(catalog); - result->partitions = std::move(partitions); - result->orders = std::move(orders); - result->start = start; - result->end = end; - result->start_expr = std::move(start_expr); - result->end_expr = std::move(end_expr); - deserializer.ReadPropertyWithDefault(212, "ignore_nulls", result->ignore_nulls); - deserializer.ReadPropertyWithDefault>(213, "filter_expr", result->filter_expr); - deserializer.ReadPropertyWithExplicitDefault(214, "exclude_clause", result->exclude_clause, 
WindowExcludeMode::NO_OTHER); - deserializer.ReadPropertyWithDefault(215, "distinct", result->distinct); - deserializer.ReadPropertyWithDefault>(216, "arg_orders", result->arg_orders); - deserializer.ReadPropertyWithDefault(217, "has_ignore_nulls", result->has_ignore_nulls); - return std::move(result); -} - } // namespace duckdb diff --git a/src/duckdb/src/storage/serialization/serialize_query_node.cpp b/src/duckdb/src/storage/serialization/serialize_query_node.cpp index 1370e697f..fd2a7d9d4 100644 --- a/src/duckdb/src/storage/serialization/serialize_query_node.cpp +++ b/src/duckdb/src/storage/serialization/serialize_query_node.cpp @@ -11,6 +11,8 @@ #include "duckdb/parser/query_node/insert_query_node.hpp" #include "duckdb/parser/statement/insert_statement.hpp" #include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/query_node/merge_query_node.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" namespace duckdb { @@ -35,6 +37,9 @@ unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { case QueryNodeType::INSERT_QUERY_NODE: result = InsertQueryNode::Deserialize(deserializer); break; + case QueryNodeType::MERGE_QUERY_NODE: + result = MergeQueryNode::Deserialize(deserializer); + break; case QueryNodeType::RECURSIVE_CTE_NODE: result = RecursiveCTENode::Deserialize(deserializer); break; @@ -60,19 +65,19 @@ unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { void CTENode::Serialize(Serializer &serializer) const { QueryNode::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "cte_name", ctename); + serializer.WritePropertyWithDefault(200, "cte_name", ctename); serializer.WritePropertyWithDefault>(201, "query", query); serializer.WritePropertyWithDefault>(202, "child", child); - serializer.WritePropertyWithDefault>(203, "aliases", aliases); + serializer.WritePropertyWithDefault>(203, "aliases", aliases); serializer.WritePropertyWithDefault(204, "materialized", materialized, 
CTEMaterialize::CTE_MATERIALIZE_DEFAULT); } unique_ptr CTENode::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new CTENode()); - deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); + deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); deserializer.ReadPropertyWithDefault>(201, "query", result->query); deserializer.ReadPropertyWithDefault>(202, "child", result->child); - deserializer.ReadPropertyWithDefault>(203, "aliases", result->aliases); + deserializer.ReadPropertyWithDefault>(203, "aliases", result->aliases); deserializer.ReadPropertyWithExplicitDefault(204, "materialized", result->materialized, CTEMaterialize::CTE_MATERIALIZE_DEFAULT); return std::move(result); } @@ -97,10 +102,10 @@ unique_ptr DeleteQueryNode::Deserialize(Deserializer &deserializer) { void InsertQueryNode::Serialize(Serializer &serializer) const { QueryNode::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "select_statement", select_statement); - serializer.WritePropertyWithDefault>(201, "columns", columns); - serializer.WritePropertyWithDefault(202, "table", table); - serializer.WritePropertyWithDefault(203, "schema", schema); - serializer.WritePropertyWithDefault(204, "catalog", catalog); + serializer.WritePropertyWithDefault>(201, "columns", columns); + serializer.WritePropertyWithDefault(202, "table", table); + serializer.WritePropertyWithDefault(203, "schema", schema); + serializer.WritePropertyWithDefault(204, "catalog", catalog); serializer.WritePropertyWithDefault>>(205, "returning_list", returning_list); serializer.WritePropertyWithDefault>(206, "on_conflict_info", on_conflict_info); serializer.WritePropertyWithDefault>(207, "table_ref", table_ref); @@ -111,10 +116,10 @@ void InsertQueryNode::Serialize(Serializer &serializer) const { unique_ptr InsertQueryNode::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new InsertQueryNode()); deserializer.ReadPropertyWithDefault>(200, 
"select_statement", result->select_statement); - deserializer.ReadPropertyWithDefault>(201, "columns", result->columns); - deserializer.ReadPropertyWithDefault(202, "table", result->table); - deserializer.ReadPropertyWithDefault(203, "schema", result->schema); - deserializer.ReadPropertyWithDefault(204, "catalog", result->catalog); + deserializer.ReadPropertyWithDefault>(201, "columns", result->columns); + deserializer.ReadPropertyWithDefault(202, "table", result->table); + deserializer.ReadPropertyWithDefault(203, "schema", result->schema); + deserializer.ReadPropertyWithDefault(204, "catalog", result->catalog); deserializer.ReadPropertyWithDefault>>(205, "returning_list", result->returning_list); deserializer.ReadPropertyWithDefault>(206, "on_conflict_info", result->on_conflict_info); deserializer.ReadPropertyWithDefault>(207, "table_ref", result->table_ref); @@ -123,23 +128,44 @@ unique_ptr InsertQueryNode::Deserialize(Deserializer &deserializer) { return std::move(result); } +void MergeQueryNode::Serialize(Serializer &serializer) const { + QueryNode::Serialize(serializer); + serializer.WritePropertyWithDefault>(200, "target", target); + serializer.WritePropertyWithDefault>(201, "source", source); + serializer.WritePropertyWithDefault>(202, "join_condition", join_condition); + serializer.WritePropertyWithDefault>(203, "using_columns", using_columns); + serializer.WritePropertyWithDefault>>>(204, "actions", actions); + serializer.WritePropertyWithDefault>>(205, "returning_list", returning_list); +} + +unique_ptr MergeQueryNode::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new MergeQueryNode()); + deserializer.ReadPropertyWithDefault>(200, "target", result->target); + deserializer.ReadPropertyWithDefault>(201, "source", result->source); + deserializer.ReadPropertyWithDefault>(202, "join_condition", result->join_condition); + deserializer.ReadPropertyWithDefault>(203, "using_columns", result->using_columns); + 
deserializer.ReadPropertyWithDefault>>>(204, "actions", result->actions); + deserializer.ReadPropertyWithDefault>>(205, "returning_list", result->returning_list); + return std::move(result); +} + void RecursiveCTENode::Serialize(Serializer &serializer) const { QueryNode::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "cte_name", ctename); + serializer.WritePropertyWithDefault(200, "cte_name", ctename); serializer.WritePropertyWithDefault(201, "union_all", union_all, false); serializer.WritePropertyWithDefault>(202, "left", left); serializer.WritePropertyWithDefault>(203, "right", right); - serializer.WritePropertyWithDefault>(204, "aliases", aliases); + serializer.WritePropertyWithDefault>(204, "aliases", aliases); serializer.WritePropertyWithDefault>>(205, "key_targets", key_targets); } unique_ptr RecursiveCTENode::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new RecursiveCTENode()); - deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); + deserializer.ReadPropertyWithDefault(200, "cte_name", result->ctename); deserializer.ReadPropertyWithExplicitDefault(201, "union_all", result->union_all, false); deserializer.ReadPropertyWithDefault>(202, "left", result->left); deserializer.ReadPropertyWithDefault>(203, "right", result->right); - deserializer.ReadPropertyWithDefault>(204, "aliases", result->aliases); + deserializer.ReadPropertyWithDefault>(204, "aliases", result->aliases); deserializer.ReadPropertyWithDefault>>(205, "key_targets", result->key_targets); return std::move(result); } @@ -177,7 +203,7 @@ void SetOperationNode::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(201, "left", SerializeChildNode(serializer, 0)); serializer.WritePropertyWithDefault>(202, "right", SerializeChildNode(serializer, 1)); serializer.WritePropertyWithDefault(203, "setop_all", setop_all, true); - if (serializer.ShouldSerialize(7)) { + if 
(serializer.ShouldSerialize(StorageVersion::V1_5_0)) { serializer.WritePropertyWithDefault>>(204, "children", children); } } diff --git a/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp b/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp index 261f106c6..35c009da7 100644 --- a/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp +++ b/src/duckdb/src/storage/serialization/serialize_result_modifier.cpp @@ -21,12 +21,12 @@ unique_ptr ResultModifier::Deserialize(Deserializer &deserialize case ResultModifierType::DISTINCT_MODIFIER: result = DistinctModifier::Deserialize(deserializer); break; + case ResultModifierType::LEGACY_LIMIT_PERCENT_MODIFIER: + result = LegacyLimitPercentModifier::Deserialize(deserializer); + break; case ResultModifierType::LIMIT_MODIFIER: result = LimitModifier::Deserialize(deserializer); break; - case ResultModifierType::LIMIT_PERCENT_MODIFIER: - result = LimitPercentModifier::Deserialize(deserializer); - break; case ResultModifierType::ORDER_MODIFIER: result = OrderModifier::Deserialize(deserializer); break; @@ -57,29 +57,35 @@ unique_ptr DistinctModifier::Deserialize(Deserializer &deseriali return std::move(result); } -void LimitModifier::Serialize(Serializer &serializer) const { +void LegacyLimitPercentModifier::Serialize(Serializer &serializer) const { ResultModifier::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "limit", limit); serializer.WritePropertyWithDefault>(201, "offset", offset); } -unique_ptr LimitModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LimitModifier()); - deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); - deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); - return std::move(result); +unique_ptr LegacyLimitPercentModifier::Deserialize(Deserializer &deserializer) { + auto limit = deserializer.ReadPropertyWithDefault>(200, "limit"); + auto offset = 
deserializer.ReadPropertyWithDefault>(201, "offset"); + auto result = LegacyLimitPercentModifier::DeserializeLegacyLimitPercentModifier(std::move(limit), std::move(offset)); + return result; } -void LimitPercentModifier::Serialize(Serializer &serializer) const { +void LimitModifier::Serialize(Serializer &serializer) const { + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0) && UseLegacySerialization()) { + LegacySerialize(serializer); + return; + } ResultModifier::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "limit", limit); serializer.WritePropertyWithDefault>(201, "offset", offset); + serializer.WritePropertyWithDefault(202, "limit_type", limit_type, LimitValueType::ROW_COUNT); } -unique_ptr LimitPercentModifier::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new LimitPercentModifier()); +unique_ptr LimitModifier::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LimitModifier()); deserializer.ReadPropertyWithDefault>(200, "limit", result->limit); deserializer.ReadPropertyWithDefault>(201, "offset", result->offset); + deserializer.ReadPropertyWithExplicitDefault(202, "limit_type", result->limit_type, LimitValueType::ROW_COUNT); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_statement.cpp b/src/duckdb/src/storage/serialization/serialize_statement.cpp index c09a121a6..cd11ffad7 100644 --- a/src/duckdb/src/storage/serialization/serialize_statement.cpp +++ b/src/duckdb/src/storage/serialization/serialize_statement.cpp @@ -8,12 +8,37 @@ #include "duckdb/parser/statement/select_statement.hpp" #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" namespace duckdb { +void MergeIntoAction::Serialize(Serializer &serializer) const { + serializer.WriteProperty(100, "action_type", action_type); + 
serializer.WritePropertyWithDefault>(101, "condition", condition); + serializer.WritePropertyWithDefault>(102, "update_info", update_info); + serializer.WritePropertyWithDefault>(103, "insert_columns", insert_columns); + serializer.WritePropertyWithDefault>>(104, "expressions", expressions); + serializer.WritePropertyWithDefault(105, "column_order", column_order, InsertColumnOrder::INSERT_BY_POSITION); + serializer.WritePropertyWithDefault(106, "default_values", default_values, false); + serializer.WritePropertyWithDefault(107, "exclude_columns", exclude_columns); +} + +unique_ptr MergeIntoAction::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new MergeIntoAction()); + deserializer.ReadProperty(100, "action_type", result->action_type); + deserializer.ReadPropertyWithDefault>(101, "condition", result->condition); + deserializer.ReadPropertyWithDefault>(102, "update_info", result->update_info); + deserializer.ReadPropertyWithDefault>(103, "insert_columns", result->insert_columns); + deserializer.ReadPropertyWithDefault>>(104, "expressions", result->expressions); + deserializer.ReadPropertyWithExplicitDefault(105, "column_order", result->column_order, InsertColumnOrder::INSERT_BY_POSITION); + deserializer.ReadPropertyWithExplicitDefault(106, "default_values", result->default_values, false); + deserializer.ReadPropertyWithDefault(107, "exclude_columns", result->exclude_columns); + return result; +} + void OnConflictInfo::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "action_type", action_type); - serializer.WritePropertyWithDefault>(101, "indexed_columns", indexed_columns); + serializer.WritePropertyWithDefault>(101, "indexed_columns", indexed_columns); serializer.WritePropertyWithDefault>(102, "set_info", set_info); serializer.WritePropertyWithDefault>(103, "condition", condition); } @@ -21,7 +46,7 @@ void OnConflictInfo::Serialize(Serializer &serializer) const { unique_ptr 
OnConflictInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new OnConflictInfo()); deserializer.ReadProperty(100, "action_type", result->action_type); - deserializer.ReadPropertyWithDefault>(101, "indexed_columns", result->indexed_columns); + deserializer.ReadPropertyWithDefault>(101, "indexed_columns", result->indexed_columns); deserializer.ReadPropertyWithDefault>(102, "set_info", result->set_info); deserializer.ReadPropertyWithDefault>(103, "condition", result->condition); return result; @@ -29,26 +54,26 @@ unique_ptr OnConflictInfo::Deserialize(Deserializer &deserialize void SelectStatement::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(100, "node", node); - serializer.WritePropertyWithDefault>(101, "named_param_map", named_param_map); + serializer.WritePropertyWithDefault>(101, "named_param_map", named_param_map); } unique_ptr SelectStatement::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new SelectStatement()); deserializer.ReadPropertyWithDefault>(100, "node", result->node); - deserializer.ReadPropertyWithDefault>(101, "named_param_map", result->named_param_map); + deserializer.ReadPropertyWithDefault>(101, "named_param_map", result->named_param_map); return result; } void UpdateSetInfo::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(100, "condition", condition); - serializer.WritePropertyWithDefault>(101, "columns", columns); + serializer.WritePropertyWithDefault>(101, "columns", columns); serializer.WritePropertyWithDefault>>(102, "expressions", expressions); } unique_ptr UpdateSetInfo::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new UpdateSetInfo()); deserializer.ReadPropertyWithDefault>(100, "condition", result->condition); - deserializer.ReadPropertyWithDefault>(101, "columns", result->columns); + deserializer.ReadPropertyWithDefault>(101, "columns", result->columns); 
deserializer.ReadPropertyWithDefault>>(102, "expressions", result->expressions); return result; } diff --git a/src/duckdb/src/storage/serialization/serialize_storage.cpp b/src/duckdb/src/storage/serialization/serialize_storage.cpp index a6964ed2b..55e962209 100644 --- a/src/duckdb/src/storage/serialization/serialize_storage.cpp +++ b/src/duckdb/src/storage/serialization/serialize_storage.cpp @@ -25,7 +25,7 @@ BlockPointer BlockPointer::Deserialize(Deserializer &deserializer) { } void DataPointer::Serialize(Serializer &serializer) const { - if (!serializer.ShouldSerialize(7)) { + if (!serializer.ShouldSerialize(StorageVersion::V1_5_0)) { serializer.WritePropertyWithDefault(100, "row_start", row_start); } serializer.WritePropertyWithDefault(101, "tuple_count", tuple_count); @@ -86,7 +86,7 @@ FixedSizeAllocatorInfo FixedSizeAllocatorInfo::Deserialize(Deserializer &deseria } void IndexStorageInfo::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "name", name); + serializer.WritePropertyWithDefault(100, "name", name); serializer.WritePropertyWithDefault(101, "root", root); serializer.WritePropertyWithDefault>(102, "allocator_infos", allocator_infos); serializer.WritePropertyWithDefault>(103, "options", options, case_insensitive_map_t()); @@ -94,7 +94,7 @@ void IndexStorageInfo::Serialize(Serializer &serializer) const { IndexStorageInfo IndexStorageInfo::Deserialize(Deserializer &deserializer) { IndexStorageInfo result; - deserializer.ReadPropertyWithDefault(100, "name", result.name); + deserializer.ReadPropertyWithDefault(100, "name", result.name); deserializer.ReadPropertyWithDefault(101, "root", result.root); deserializer.ReadPropertyWithDefault>(102, "allocator_infos", result.allocator_infos); deserializer.ReadPropertyWithExplicitDefault>(103, "options", result.options, case_insensitive_map_t()); diff --git a/src/duckdb/src/storage/serialization/serialize_table_filter.cpp 
b/src/duckdb/src/storage/serialization/serialize_table_filter.cpp index 296b91eea..7d9e20715 100644 --- a/src/duckdb/src/storage/serialization/serialize_table_filter.cpp +++ b/src/duckdb/src/storage/serialization/serialize_table_filter.cpp @@ -25,35 +25,35 @@ unique_ptr TableFilter::Deserialize(Deserializer &deserializer) { auto filter_type = deserializer.ReadProperty(100, "filter_type"); unique_ptr result; switch (filter_type) { - case TableFilterType::CONJUNCTION_AND: - result = ConjunctionAndFilter::Deserialize(deserializer); + case TableFilterType::EXPRESSION_FILTER: + result = ExpressionFilter::Deserialize(deserializer); break; - case TableFilterType::CONJUNCTION_OR: - result = ConjunctionOrFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_CONJUNCTION_AND: + result = LegacyConjunctionAndFilter::Deserialize(deserializer); break; - case TableFilterType::CONSTANT_COMPARISON: - result = ConstantFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_CONJUNCTION_OR: + result = LegacyConjunctionOrFilter::Deserialize(deserializer); break; - case TableFilterType::DYNAMIC_FILTER: - result = DynamicFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_CONSTANT_COMPARISON: + result = LegacyConstantFilter::Deserialize(deserializer); break; - case TableFilterType::EXPRESSION_FILTER: - result = ExpressionFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_DYNAMIC_FILTER: + result = LegacyDynamicFilter::Deserialize(deserializer); break; - case TableFilterType::IN_FILTER: - result = InFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_IN_FILTER: + result = LegacyInFilter::Deserialize(deserializer); break; - case TableFilterType::IS_NOT_NULL: - result = IsNotNullFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_IS_NOT_NULL: + result = LegacyIsNotNullFilter::Deserialize(deserializer); break; - case TableFilterType::IS_NULL: - result = IsNullFilter::Deserialize(deserializer); + case 
TableFilterType::LEGACY_IS_NULL: + result = LegacyIsNullFilter::Deserialize(deserializer); break; - case TableFilterType::OPTIONAL_FILTER: - result = OptionalFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_OPTIONAL_FILTER: + result = LegacyOptionalFilter::Deserialize(deserializer); break; - case TableFilterType::STRUCT_EXTRACT: - result = StructFilter::Deserialize(deserializer); + case TableFilterType::LEGACY_STRUCT_EXTRACT: + result = LegacyStructFilter::Deserialize(deserializer); break; default: throw SerializationException("Unsupported type for deserialization of TableFilter!"); @@ -61,113 +61,113 @@ unique_ptr TableFilter::Deserialize(Deserializer &deserializer) { return result; } -void ConjunctionAndFilter::Serialize(Serializer &serializer) const { +void ExpressionFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); - serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); + serializer.WritePropertyWithDefault>(200, "expr", expr); } -unique_ptr ConjunctionAndFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConjunctionAndFilter()); - deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); +unique_ptr ExpressionFilter::Deserialize(Deserializer &deserializer) { + auto expr = deserializer.ReadPropertyWithDefault>(200, "expr"); + auto result = duckdb::unique_ptr(new ExpressionFilter(std::move(expr))); return std::move(result); } -void ConjunctionOrFilter::Serialize(Serializer &serializer) const { +void LegacyConjunctionAndFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); } -unique_ptr ConjunctionOrFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new ConjunctionOrFilter()); +unique_ptr LegacyConjunctionAndFilter::Deserialize(Deserializer &deserializer) { + auto result = 
duckdb::unique_ptr(new LegacyConjunctionAndFilter()); deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); return std::move(result); } -void ConstantFilter::Serialize(Serializer &serializer) const { +void LegacyConjunctionOrFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); - serializer.WriteProperty(200, "comparison_type", comparison_type); - serializer.WriteProperty(201, "constant", constant); + serializer.WritePropertyWithDefault>>(200, "child_filters", child_filters); } -unique_ptr ConstantFilter::Deserialize(Deserializer &deserializer) { - auto comparison_type = deserializer.ReadProperty(200, "comparison_type"); - auto constant = deserializer.ReadProperty(201, "constant"); - auto result = duckdb::unique_ptr(new ConstantFilter(comparison_type, constant)); +unique_ptr LegacyConjunctionOrFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LegacyConjunctionOrFilter()); + deserializer.ReadPropertyWithDefault>>(200, "child_filters", result->child_filters); return std::move(result); } -void DynamicFilter::Serialize(Serializer &serializer) const { +void LegacyConstantFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); + serializer.WriteProperty(200, "comparison_type", comparison_type); + serializer.WriteProperty(201, "constant", constant); } -unique_ptr DynamicFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new DynamicFilter()); +unique_ptr LegacyConstantFilter::Deserialize(Deserializer &deserializer) { + auto comparison_type = deserializer.ReadProperty(200, "comparison_type"); + auto constant = deserializer.ReadProperty(201, "constant"); + auto result = duckdb::unique_ptr(new LegacyConstantFilter(comparison_type, constant)); return std::move(result); } -void ExpressionFilter::Serialize(Serializer &serializer) const { +void LegacyDynamicFilter::Serialize(Serializer &serializer) const { 
TableFilter::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expr", expr); } -unique_ptr ExpressionFilter::Deserialize(Deserializer &deserializer) { - auto expr = deserializer.ReadPropertyWithDefault>(200, "expr"); - auto result = duckdb::unique_ptr(new ExpressionFilter(std::move(expr))); +unique_ptr LegacyDynamicFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LegacyDynamicFilter()); return std::move(result); } -void InFilter::Serialize(Serializer &serializer) const { +void LegacyInFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "values", values); } -unique_ptr InFilter::Deserialize(Deserializer &deserializer) { +unique_ptr LegacyInFilter::Deserialize(Deserializer &deserializer) { auto values = deserializer.ReadPropertyWithDefault>(200, "values"); - auto result = duckdb::unique_ptr(new InFilter(std::move(values))); + auto result = duckdb::unique_ptr(new LegacyInFilter(std::move(values))); return std::move(result); } -void IsNotNullFilter::Serialize(Serializer &serializer) const { +void LegacyIsNotNullFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); } -unique_ptr IsNotNullFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new IsNotNullFilter()); +unique_ptr LegacyIsNotNullFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LegacyIsNotNullFilter()); return std::move(result); } -void IsNullFilter::Serialize(Serializer &serializer) const { +void LegacyIsNullFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); } -unique_ptr IsNullFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new IsNullFilter()); +unique_ptr LegacyIsNullFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LegacyIsNullFilter()); return std::move(result); 
} -void OptionalFilter::Serialize(Serializer &serializer) const { +void LegacyOptionalFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "child_filter", child_filter); } -unique_ptr OptionalFilter::Deserialize(Deserializer &deserializer) { - auto result = duckdb::unique_ptr(new OptionalFilter()); +unique_ptr LegacyOptionalFilter::Deserialize(Deserializer &deserializer) { + auto result = duckdb::unique_ptr(new LegacyOptionalFilter()); deserializer.ReadPropertyWithDefault>(200, "child_filter", result->child_filter); return std::move(result); } -void StructFilter::Serialize(Serializer &serializer) const { +void LegacyStructFilter::Serialize(Serializer &serializer) const { TableFilter::Serialize(serializer); serializer.WritePropertyWithDefault(200, "child_idx", child_idx); - serializer.WritePropertyWithDefault(201, "child_name", child_name); + serializer.WritePropertyWithDefault(201, "child_name", child_name); serializer.WritePropertyWithDefault>(202, "child_filter", child_filter); } -unique_ptr StructFilter::Deserialize(Deserializer &deserializer) { +unique_ptr LegacyStructFilter::Deserialize(Deserializer &deserializer) { auto child_idx = deserializer.ReadPropertyWithDefault(200, "child_idx"); - auto child_name = deserializer.ReadPropertyWithDefault(201, "child_name"); + auto child_name = deserializer.ReadPropertyWithDefault(201, "child_name"); auto child_filter = deserializer.ReadPropertyWithDefault>(202, "child_filter"); - auto result = duckdb::unique_ptr(new StructFilter(child_idx, std::move(child_name), std::move(child_filter))); + auto result = duckdb::unique_ptr(new LegacyStructFilter(child_idx, std::move(child_name), std::move(child_filter))); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_tableref.cpp b/src/duckdb/src/storage/serialization/serialize_tableref.cpp index 2dc033314..9d21f6f10 100644 --- 
a/src/duckdb/src/storage/serialization/serialize_tableref.cpp +++ b/src/duckdb/src/storage/serialization/serialize_tableref.cpp @@ -12,14 +12,14 @@ namespace duckdb { void TableRef::Serialize(Serializer &serializer) const { serializer.WriteProperty(100, "type", type); - serializer.WritePropertyWithDefault(101, "alias", alias); + serializer.WritePropertyWithDefault(101, "alias", alias); serializer.WritePropertyWithDefault>(102, "sample", sample); serializer.WritePropertyWithDefault(103, "query_location", query_location, optional_idx()); } unique_ptr TableRef::Deserialize(Deserializer &deserializer) { auto type = deserializer.ReadProperty(100, "type"); - auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); + auto alias = deserializer.ReadPropertyWithDefault(101, "alias"); auto sample = deserializer.ReadPropertyWithDefault>(102, "sample"); auto query_location = deserializer.ReadPropertyWithExplicitDefault(103, "query_location", optional_idx()); unique_ptr result; @@ -74,31 +74,31 @@ unique_ptr AtClause::Deserialize(Deserializer &deserializer) { void BaseTableRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "schema_name", schema_name); - serializer.WritePropertyWithDefault(201, "table_name", table_name); - serializer.WritePropertyWithDefault>(202, "column_name_alias", column_name_alias); - serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); + serializer.WritePropertyWithDefault(200, "schema_name", schema_name); + serializer.WritePropertyWithDefault(201, "table_name", table_name); + serializer.WritePropertyWithDefault>(202, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); serializer.WritePropertyWithDefault>(204, "at_clause", at_clause); } unique_ptr BaseTableRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new BaseTableRef()); - deserializer.ReadPropertyWithDefault(200, 
"schema_name", result->schema_name); - deserializer.ReadPropertyWithDefault(201, "table_name", result->table_name); - deserializer.ReadPropertyWithDefault>(202, "column_name_alias", result->column_name_alias); - deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); + deserializer.ReadPropertyWithDefault(200, "schema_name", result->schema_name); + deserializer.ReadPropertyWithDefault(201, "table_name", result->table_name); + deserializer.ReadPropertyWithDefault>(202, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); deserializer.ReadPropertyWithDefault>(204, "at_clause", result->at_clause); return std::move(result); } void ColumnDataRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); + serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); serializer.WritePropertyWithDefault>(202, "collection", collection); } unique_ptr ColumnDataRef::Deserialize(Deserializer &deserializer) { - auto expected_names = deserializer.ReadPropertyWithDefault>(200, "expected_names"); + auto expected_names = deserializer.ReadPropertyWithDefault>(200, "expected_names"); auto collection = deserializer.ReadPropertyWithDefault>(202, "collection"); auto result = duckdb::unique_ptr(new ColumnDataRef(std::move(collection), std::move(expected_names))); return std::move(result); @@ -115,14 +115,14 @@ unique_ptr EmptyTableRef::Deserialize(Deserializer &deserializer) { void ExpressionListRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); + serializer.WritePropertyWithDefault>(200, "expected_names", expected_names); serializer.WritePropertyWithDefault>(201, "expected_types", expected_types); serializer.WritePropertyWithDefault>>>(202, "values", values); } 
unique_ptr ExpressionListRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ExpressionListRef()); - deserializer.ReadPropertyWithDefault>(200, "expected_names", result->expected_names); + deserializer.ReadPropertyWithDefault>(200, "expected_names", result->expected_names); deserializer.ReadPropertyWithDefault>(201, "expected_types", result->expected_types); deserializer.ReadPropertyWithDefault>>>(202, "values", result->values); return std::move(result); @@ -135,10 +135,10 @@ void JoinRef::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(202, "condition", condition); serializer.WriteProperty(203, "join_type", type); serializer.WriteProperty(204, "ref_type", ref_type); - serializer.WritePropertyWithDefault>(205, "using_columns", using_columns); + serializer.WritePropertyWithDefault>(205, "using_columns", using_columns); serializer.WritePropertyWithDefault(206, "delim_flipped", delim_flipped); serializer.WritePropertyWithDefault>>(207, "duplicate_eliminated_columns", duplicate_eliminated_columns); - if (serializer.ShouldSerialize(6)) { + if (serializer.ShouldSerialize(StorageVersion::V1_4_0)) { serializer.WritePropertyWithDefault(208, "is_implicit", is_implicit, true); } } @@ -150,7 +150,7 @@ unique_ptr JoinRef::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault>(202, "condition", result->condition); deserializer.ReadProperty(203, "join_type", result->type); deserializer.ReadProperty(204, "ref_type", result->ref_type); - deserializer.ReadPropertyWithDefault>(205, "using_columns", result->using_columns); + deserializer.ReadPropertyWithDefault>(205, "using_columns", result->using_columns); deserializer.ReadPropertyWithDefault(206, "delim_flipped", result->delim_flipped); deserializer.ReadPropertyWithDefault>>(207, "duplicate_eliminated_columns", result->duplicate_eliminated_columns); deserializer.ReadPropertyWithExplicitDefault(208, "is_implicit", result->is_implicit, true); @@ 
-161,10 +161,10 @@ void PivotRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "source", source); serializer.WritePropertyWithDefault>>(201, "aggregates", aggregates); - serializer.WritePropertyWithDefault>(202, "unpivot_names", unpivot_names); + serializer.WritePropertyWithDefault>(202, "unpivot_names", unpivot_names); serializer.WritePropertyWithDefault>(203, "pivots", pivots); - serializer.WritePropertyWithDefault>(204, "groups", groups); - serializer.WritePropertyWithDefault>(205, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault>(204, "groups", groups); + serializer.WritePropertyWithDefault>(205, "column_name_alias", column_name_alias); serializer.WritePropertyWithDefault(206, "include_nulls", include_nulls); } @@ -172,57 +172,57 @@ unique_ptr PivotRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new PivotRef()); deserializer.ReadPropertyWithDefault>(200, "source", result->source); deserializer.ReadPropertyWithDefault>>(201, "aggregates", result->aggregates); - deserializer.ReadPropertyWithDefault>(202, "unpivot_names", result->unpivot_names); + deserializer.ReadPropertyWithDefault>(202, "unpivot_names", result->unpivot_names); deserializer.ReadPropertyWithDefault>(203, "pivots", result->pivots); - deserializer.ReadPropertyWithDefault>(204, "groups", result->groups); - deserializer.ReadPropertyWithDefault>(205, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault>(204, "groups", result->groups); + deserializer.ReadPropertyWithDefault>(205, "column_name_alias", result->column_name_alias); deserializer.ReadPropertyWithDefault(206, "include_nulls", result->include_nulls); return std::move(result); } void ShowRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "table_name", table_name); + serializer.WritePropertyWithDefault(200, 
"table_name", table_name); serializer.WritePropertyWithDefault>(201, "query", query); serializer.WriteProperty(202, "show_type", show_type); - serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); - serializer.WritePropertyWithDefault(204, "schema_name", schema_name); + serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); + serializer.WritePropertyWithDefault(204, "schema_name", schema_name); } unique_ptr ShowRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new ShowRef()); - deserializer.ReadPropertyWithDefault(200, "table_name", result->table_name); + deserializer.ReadPropertyWithDefault(200, "table_name", result->table_name); deserializer.ReadPropertyWithDefault>(201, "query", result->query); deserializer.ReadProperty(202, "show_type", result->show_type); - deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); - deserializer.ReadPropertyWithDefault(204, "schema_name", result->schema_name); + deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); + deserializer.ReadPropertyWithDefault(204, "schema_name", result->schema_name); return std::move(result); } void SubqueryRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "subquery", subquery); - serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); } unique_ptr SubqueryRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new SubqueryRef()); deserializer.ReadPropertyWithDefault>(200, "subquery", result->subquery); - deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); return std::move(result); } void TableFunctionRef::Serialize(Serializer &serializer) const { 
TableRef::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "function", function); - serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); serializer.WritePropertyWithDefault(202, "with_ordinality", with_ordinality, OrdinalityType::WITHOUT_ORDINALITY); } unique_ptr TableFunctionRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new TableFunctionRef()); deserializer.ReadPropertyWithDefault>(200, "function", result->function); - deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); deserializer.ReadPropertyWithExplicitDefault(202, "with_ordinality", result->with_ordinality, OrdinalityType::WITHOUT_ORDINALITY); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp index 21278c0cf..4147596b0 100644 --- a/src/duckdb/src/storage/serialization/serialize_types.cpp +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -24,9 +24,6 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) auto extension_info = deserializer.ReadPropertyWithDefault>(103, "extension_info"); shared_ptr result; switch (type) { - case ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO: - result = AggregateStateTypeInfo::Deserialize(deserializer); - break; case ExtraTypeInfoType::ANY_TYPE_INFO: result = AnyTypeInfo::Deserialize(deserializer); break; @@ -75,27 +72,6 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) result->extension_info = std::move(extension_info); return result; } -// NOLINTBEGIN(bugprone-parent-virtual-call) -// reasons: Multi-level inheritance is not supported in the generation tool - -void AggregateStateTypeInfo::Serialize(Serializer &serializer) const { - 
ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(300, "function_name", state_type.function_name); - serializer.WriteProperty(301, "return_type", state_type.return_type); - serializer.WritePropertyWithDefault>(302, "bound_argument_types", state_type.bound_argument_types); - serializer.WritePropertyWithDefault>(303, "child_types", child_types); -} - -shared_ptr AggregateStateTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new AggregateStateTypeInfo()); - deserializer.ReadPropertyWithDefault(300, "function_name", result->state_type.function_name); - deserializer.ReadProperty(301, "return_type", result->state_type.return_type); - deserializer.ReadPropertyWithDefault>(302, "bound_argument_types", result->state_type.bound_argument_types); - deserializer.ReadPropertyWithDefault>(303, "child_types", result->child_types); - return std::move(result); -} - -// NOLINTEND(bugprone-parent-virtual-call) void AnyTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); @@ -172,17 +148,17 @@ shared_ptr IntegerLiteralTypeInfo::Deserialize(Deserializer &dese void LegacyAggregateStateTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); - serializer.WritePropertyWithDefault(200, "function_name", state_type.function_name); - serializer.WriteProperty(201, "return_type", state_type.return_type); - serializer.WritePropertyWithDefault>(202, "bound_argument_types", state_type.bound_argument_types); + /* [Deleted] (string) "function_name" */ + /* [Deleted] (LogicalType) "return_type" */ + /* [Deleted] (vector) "bound_argument_types" */ } shared_ptr LegacyAggregateStateTypeInfo::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new LegacyAggregateStateTypeInfo()); - deserializer.ReadPropertyWithDefault(200, "function_name", result->state_type.function_name); - deserializer.ReadProperty(201, "return_type", 
result->state_type.return_type); - deserializer.ReadPropertyWithDefault>(202, "bound_argument_types", result->state_type.bound_argument_types); - return std::move(result); + deserializer.ReadDeletedProperty(200, "function_name"); + deserializer.ReadDeletedProperty(201, "return_type"); + deserializer.ReadDeletedProperty>(202, "bound_argument_types"); + auto result = LegacyAggregateStateTypeInfo::LegacyDeserialize(); + return result; } void ListTypeInfo::Serialize(Serializer &serializer) const { diff --git a/src/duckdb/src/storage/single_file_block_manager.cpp b/src/duckdb/src/storage/single_file_block_manager.cpp index a9a672579..25c8c36b8 100644 --- a/src/duckdb/src/storage/single_file_block_manager.cpp +++ b/src/duckdb/src/storage/single_file_block_manager.cpp @@ -175,7 +175,11 @@ bool DecryptCanary(MainHeader &main_header, const shared_ptr &e void MainHeader::Write(WriteStream &ser) { ser.WriteData(const_data_ptr_cast(MAGIC_BYTES), MAGIC_BYTE_SIZE); - ser.Write(version_number); + if (static_cast(version_number) >= StorageVersion::V2_0_0) { + // from v2.0.0 we write 999, to indicate that the version number is deprecated + version_number = static_cast(StorageVersion::DEPRECATED); + } + ser.Write(version_number); for (idx_t i = 0; i < FLAG_COUNT; i++) { ser.Write(flags[i]); } @@ -203,6 +207,13 @@ void MainHeader::CheckMagicBytes(QueryContext context, FileHandle &handle) { } } +void MainHeader::CheckMagicBytes(MemoryMappedFile &handle) { + auto magic_bytes = handle.GetData(MainHeader::MAGIC_BYTE_OFFSET, MainHeader::MAGIC_BYTE_SIZE); + if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { + throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.GetPath()); + } +} + MainHeader MainHeader::Read(ReadStream &source) { data_t magic_bytes[MAGIC_BYTE_SIZE]; @@ -212,11 +223,14 @@ MainHeader MainHeader::Read(ReadStream &source) { throw IOException("The file is not a valid DuckDB database file!"); } - 
header.version_number = source.Read(); + header.version_number = source.Read(); - // Check the version number to determine if we can read this file. - if (header.version_number < VERSION_NUMBER_LOWER || header.version_number > VERSION_NUMBER_UPPER) { - auto version = GetDuckDBVersions(header.version_number); + if (static_cast(header.version_number) == DEPRECATED_VERSION_NUMBER) { + // if the version number in the main header is deprecated, then we just ignore the main header version number + // TODO: if we are confident, we can remove the check below + } else if (header.version_number < VERSION_NUMBER_LOWER || header.version_number > VERSION_NUMBER_UPPER) { + // Check the version number to determine if we can read this file. + auto version = GetDuckDBVersions(static_cast(header.version_number)); string version_text; if (!version.empty()) { // Known version. @@ -260,7 +274,67 @@ void DatabaseHeader::Write(WriteStream &ser) { ser.Write(block_count); ser.Write(block_alloc_size); ser.Write(vector_size); - ser.Write(serialization_compatibility); + + if (storage_compatibility < StorageVersion::V2_0_0) { + string storage_version_string = StorageVersionInfo::GetStorageVersionString(storage_compatibility); + auto ser_version = GetSerializationVersionDeprecated(storage_version_string.c_str()); + ser.Write(ser_version); + } else { + ser.Write(static_cast(storage_compatibility)); + } +} + +void DatabaseHeader::SetStorageVersionInDatabaseHeader(DatabaseHeader &header, StorageVersion main_version, + StorageVersion read_version) { + if ((main_version == MainHeader::DEPRECATED_VERSION_NUMBER) || (read_version >= StorageVersion::V2_0_0)) { + // From v2.0.0 onwards, we use and store only the storage version number + switch (read_version) { + case StorageVersion::V2_0_0: + header.storage_compatibility = StorageVersion::V2_0_0; + break; + // new versions should be added here + default: + throw InvalidInputException("Storage Version '%d' is not found!", static_cast(read_version)); + } 
+ } else { + // Before V2.0.0 the Storage Version in the main header could be written in two different ways + // 1) When the DB is created from scratch -- with e.g. ATTACH (STORAGE_VERSION "v1.4.0") + // 2) if the db file got bumped to a higher version + // (e.g. "ATTACH 'bump.dp' (STORAGE_VERSION 'v.1.5.0'), when bump.db already exists") + // in case 1, the explicit storage version is serialized in the MAIN header (e.g. 1.4.0 = 67, in the example) + // in case 2, the version number in the MAIN header is bumped to at most v65 (v1.2.0) + // thus, in case 2, the main header storage version is often lower then the actual storage version + // that's also why we need the logic below for backwards compatibility + // if the main header version and db header version are < v2.0.0 + // then we fall back to the serialization version + switch (static_cast(read_version)) { + // In some old duckdb versions, storage version (64) + // is (by mistake) serialized instead of serialization version + case static_cast(StorageVersion::INVALID): + // If read version is 0 + case static_cast(SerializationVersionDeprecated::V0_10_2): + case static_cast(StorageVersion::V0_10_2): + case static_cast(SerializationVersionDeprecated::V1_0_0): + case static_cast(SerializationVersionDeprecated::V1_1_0): + header.storage_compatibility = StorageVersion::V0_10_2; + break; + case static_cast(SerializationVersionDeprecated::V1_2_0): + header.storage_compatibility = StorageVersion::V1_2_0; + break; + case static_cast(SerializationVersionDeprecated::V1_3_0): + header.storage_compatibility = StorageVersion::V1_3_0; + break; + case static_cast(SerializationVersionDeprecated::V1_4_0): + header.storage_compatibility = StorageVersion::V1_4_0; + break; + case static_cast(SerializationVersionDeprecated::V1_5_0): + header.storage_compatibility = StorageVersion::V1_5_0; + break; + default: + throw InvalidInputException("Deprecated Serialization Version '%d' is not found!", + static_cast(read_version)); + } + } } 
DatabaseHeader DatabaseHeader::Read(const MainHeader &main_header, ReadStream &source) { @@ -287,8 +361,11 @@ DatabaseHeader DatabaseHeader::Read(const MainHeader &main_header, ReadStream &s STANDARD_VECTOR_SIZE, header.vector_size); } - // Default to 1 for version 64, else read from file. - header.serialization_compatibility = main_header.version_number == 64 ? 1 : source.Read(); + // storage version from the database header + auto h_storage_version = source.Read(); + auto database_header_storage_version = static_cast(h_storage_version); + SetStorageVersionInDatabaseHeader(header, static_cast(main_header.version_number), + database_header_storage_version); return header; } @@ -323,43 +400,23 @@ SingleFileBlockManager::~SingleFileBlockManager() { this->in_destruction = true; } -FileOpenFlags SingleFileBlockManager::GetFileFlags(bool create_new) const { - FileOpenFlags result; - if (options.read_only) { - D_ASSERT(!create_new); - result = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS | FileLockType::READ_LOCK; - } else { - result = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_READ | FileLockType::WRITE_LOCK; - if (create_new) { - result |= FileFlags::FILE_FLAGS_FILE_CREATE; - } - } - if (options.use_direct_io) { - result |= FileFlags::FILE_FLAGS_DIRECT_IO; - } - // database files can be read from in parallel - result |= FileFlags::FILE_FLAGS_PARALLEL_ACCESS; - result |= FileFlags::FILE_FLAGS_MULTI_CLIENT_ACCESS; - return result; -} - void SingleFileBlockManager::AddStorageVersionTag() { - db.tags["storage_version"] = GetStorageVersionName(options.storage_version.GetIndex(), true); + db.tags["storage_version"] = GetStorageVersionName(options.storage_version, true); } -uint64_t SingleFileBlockManager::GetVersionNumber() const { - auto storage_version = options.storage_version.GetIndex(); - if (storage_version < 4) { - return VERSION_NUMBER; +StorageVersion SingleFileBlockManager::GetVersionNumber() const { + auto storage_version = 
options.storage_version; + if (StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage_version)) { + return StorageVersion::V0_10_2; } // Look up the matching version number. auto version_name = GetStorageVersionName(storage_version, false); - return GetStorageVersion(version_name.c_str()).GetIndex(); + return GetStorageVersion(version_name.c_str()); } -MainHeader ConstructMainHeader(idx_t version_number) { +MainHeader ConstructMainHeader(StorageVersion version_number) { MainHeader header; - header.version_number = version_number; + header.version_number = static_cast(version_number); memset(header.flags, 0, sizeof(uint64_t) * MainHeader::FLAG_COUNT); return header; } @@ -457,24 +514,25 @@ void SingleFileBlockManager::CheckAndAddEncryptionKey(MainHeader &main_header) { } void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { - auto flags = GetFileFlags(true); - auto encryption_enabled = options.encryption_options.encryption_enabled; if (encryption_enabled) { // Check if we can read/write the encrypted database db.GetDatabase().GetEncryptionUtil(options.read_only); } - // open the RDBMS handle - auto &fs = FileSystem::Get(db); - handle = fs.OpenFile(path, flags); + // MAP mode opens only the mmap; other modes open the FileHandle. + handle = DatabaseHandle::Open(db, path, options, DatabaseOpenMode::CREATE_NEW_FILE); header_buffer.Clear(); + if (options.storage_version == StorageVersion::INVALID) { + options.storage_version = StorageCompatibility::Latest().storage_version; + } + options.version_number = GetVersionNumber(); - db.GetStorageManager().SetStorageVersion(options.storage_version.GetIndex()); + db.GetStorageManager().SetStorageVersion(options.storage_version); AddStorageVersionTag(); - MainHeader main_header = ConstructMainHeader(options.version_number.GetIndex()); + MainHeader main_header = ConstructMainHeader(options.version_number); // Derive the encryption key and add it to the cache. // Not used for plain databases. 
@@ -483,7 +541,7 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { // We need the unique database identifier, if the storage version is new enough. // If encryption is enabled, we also use it as the salt. memset(options.db_identifier, 0, MainHeader::DB_IDENTIFIER_LEN); - if (encryption_enabled || options.version_number.GetIndex() >= 67) { + if (encryption_enabled || StorageManager::TargetAtLeastVersion(StorageVersion::V1_4_0, options.version_number)) { GenerateDBIdentifier(options.db_identifier); } @@ -539,7 +597,7 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { // We create the SingleFileBlockManager with the desired block allocation size before calling CreateNewDatabase. h1.block_alloc_size = GetBlockAllocSize(); h1.vector_size = STANDARD_VECTOR_SIZE; - h1.serialization_compatibility = options.storage_version.GetIndex(); + h1.storage_compatibility = options.storage_version; SerializeHeaderStructure(h1, header_buffer.GetDataMutable()); ChecksumAndWrite(context, header_buffer, Storage::FILE_HEADER_SIZE); @@ -552,7 +610,7 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { // We create the SingleFileBlockManager with the desired block allocation size before calling CreateNewDatabase. 
h2.block_alloc_size = GetBlockAllocSize(); h2.vector_size = STANDARD_VECTOR_SIZE; - h2.serialization_compatibility = options.storage_version.GetIndex(); + h2.storage_compatibility = options.storage_version; SerializeHeaderStructure(h2, header_buffer.GetDataMutable()); ChecksumAndWrite(context, header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); @@ -565,17 +623,9 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { } void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) { - auto flags = GetFileFlags(false); + handle = DatabaseHandle::Open(db, path, options, DatabaseOpenMode::OPEN_EXISTING_FILE); + handle->CheckMagicBytes(context); - // open the RDBMS handle - auto &fs = FileSystem::Get(db); - handle = fs.OpenFile(path, flags); - if (!handle) { - // this can only happen in read-only mode - as that is when we set FILE_FLAGS_NULL_IF_NOT_EXISTS - throw IOException("Cannot open database \"%s\" in read-only mode: database does not exist", path); - } - - MainHeader::CheckMagicBytes(context, *handle); // otherwise, we check the metadata of the file ReadAndChecksum(context, header_buffer, 0, true); @@ -643,7 +693,7 @@ void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) { static_cast(main_header.GetEncryptionVersion())); } - options.version_number = main_header.version_number; + options.version_number = static_cast(main_header.version_number); // read the database headers from disk DatabaseHeader h1; @@ -716,7 +766,7 @@ void SingleFileBlockManager::CheckChecksum(FileBuffer &block, uint64_t location, void SingleFileBlockManager::ReadAndChecksum(QueryContext context, FileBuffer &block, uint64_t location, bool skip_block_header) const { // read the buffer from disk - block.Read(context, *handle, location); + handle->Read(context, block, location); //! 
calculate delta header bytes (if any) uint64_t delta = GetBlockHeaderSize() - Storage::DEFAULT_BLOCK_HEADER_SIZE; @@ -755,9 +805,9 @@ void SingleFileBlockManager::ChecksumAndWrite(QueryContext context, FileBuffer & temp_buffer_manager = make_uniq(BlockAllocator::Get(db), block.GetBufferType(), block.Size(), GetBlockHeaderSize()); EncryptionEngine::EncryptBlock(db, key_id, block, *temp_buffer_manager, delta); - temp_buffer_manager->Write(context, *handle, location); + temp_buffer_manager->Write(context, handle->GetFileHandle(), location); } else { - block.Write(context, *handle, location); + handle->Write(context, block, location); } } @@ -766,27 +816,27 @@ void SingleFileBlockManager::Initialize(const DatabaseHeader &header, const opti meta_block = header.meta_block; iteration_count = header.iteration; max_block = NumericCast(header.block_count); - if (options.storage_version.IsValid()) { + if (options.storage_version != StorageVersion::INVALID) { // storage version specified explicitly - use requested storage version - auto requested_compat_version = options.storage_version.GetIndex(); - if (requested_compat_version < header.serialization_compatibility) { + auto requested_compat_version = options.storage_version; + if (requested_compat_version < header.storage_compatibility) { throw InvalidInputException( "Error opening \"%s\": cannot initialize database with storage version %d - which is lower than what " "the database itself uses (%d). 
The storage version of an existing database cannot be lowered.", - path, requested_compat_version, header.serialization_compatibility); + path, requested_compat_version, header.storage_compatibility); } } else { // load storage version from header - options.storage_version = header.serialization_compatibility; + options.storage_version = header.storage_compatibility; } - if (header.serialization_compatibility > SerializationCompatibility::Latest().serialization_version) { + if (header.storage_compatibility > StorageCompatibility::Latest().storage_version) { throw InvalidInputException( "Error opening \"%s\": file was written with a storage version greater than the latest version supported " "by this DuckDB instance. Try opening the file with a newer version of DuckDB.", path); } - db.GetStorageManager().SetStorageVersion(options.storage_version.GetIndex()); + db.GetStorageManager().SetStorageVersion(options.storage_version); if (block_alloc_size.IsValid() && block_alloc_size.GetIndex() != header.block_alloc_size) { throw InvalidInputException( @@ -798,6 +848,16 @@ void SingleFileBlockManager::Initialize(const DatabaseHeader &header, const opti SetBlockAllocSize(header.block_alloc_size); } +void SingleFileBlockManager::RewriteMainHeader(QueryContext &context, StorageVersion version_number) { + MainHeader main_header = ConstructMainHeader(version_number); + SerializeHeaderStructure(main_header, header_buffer.GetDataMutable()); + // now write the header to the file + ChecksumAndWrite(context, header_buffer, 0); + header_buffer.Clear(); + // avoid having the Sync at the end write two blocks + handle->Sync(); +} + void SingleFileBlockManager::LoadFreeList(QueryContext context) { MetaBlockPointer free_pointer(free_list_id, 0); if (!free_pointer.IsValid()) { @@ -1105,7 +1165,7 @@ void SingleFileBlockManager::ReadBlock(data_ptr_t internal_buffer, uint64_t bloc void SingleFileBlockManager::ReadBlock(Block &block, bool skip_block_header) const { // read the buffer from disk 
auto location = GetBlockLocation(block.id); - block.Read(QueryContext(), *handle, location); + handle->Read(QueryContext(), block, location); //! calculate delta header bytes (if any) uint64_t delta = GetBlockHeaderSize() - Storage::DEFAULT_BLOCK_HEADER_SIZE; @@ -1130,7 +1190,7 @@ void SingleFileBlockManager::ReadBlocks(FileBuffer &buffer, block_id_t start_blo // read the buffer from disk auto location = GetBlockLocation(start_block); - buffer.Read(QueryContext(), *handle, location); + handle->Read(QueryContext(), buffer, location); // for each of the blocks - verify the checksum auto ptr = buffer.InternalBuffer(); @@ -1169,7 +1229,8 @@ void SingleFileBlockManager::Truncate() { } // truncate the file free_list.erase(free_list.lower_bound(max_block), free_list.end()); - handle->Truncate(NumericCast(BLOCK_START + NumericCast(max_block) * GetBlockAllocSize())); + auto new_size = NumericCast(BLOCK_START + NumericCast(max_block) * GetBlockAllocSize()); + handle->Truncate(new_size); } vector SingleFileBlockManager::GetFreeListBlocks() { @@ -1306,7 +1367,7 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he header.block_count = NumericCast(max_block); lock.unlock(); - header.serialization_compatibility = options.storage_version.GetIndex(); + header.storage_compatibility = options.storage_version; auto debug_checkpoint_abort = Settings::Get(db.GetDatabase()); if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { @@ -1318,14 +1379,15 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he header_buffer.Clear(); // if we are upgrading the database from version 64 -> version 65, we need to re-write the main header - if (options.version_number.GetIndex() == 64 && options.storage_version.GetIndex() >= 4) { - // rewrite the main header - options.version_number = 65; - MainHeader main_header = ConstructMainHeader(options.version_number.GetIndex()); - SerializeHeaderStructure(main_header, 
header_buffer.GetDataMutable()); - // now write the header to the file - ChecksumAndWrite(context, header_buffer, 0); - header_buffer.Clear(); + if (options.version_number == StorageVersion::V0_10_2 && options.storage_version >= StorageVersion::V1_2_0) { + RewriteMainHeader(context, StorageVersion::V1_2_0); + } + + // if we are downgrading the database from v2.0.0 to an earlier version + // we also need to re-write the main header + if (options.version_number == MainHeader::DEPRECATED_VERSION_NUMBER && + options.storage_version < StorageVersion::V2_0_0) { + RewriteMainHeader(context, options.storage_version); } // set the header inside the buffer @@ -1363,7 +1425,9 @@ void SingleFileBlockManager::UnregisterBlock(block_id_t id) { void SingleFileBlockManager::TrimFreeBlockRange(block_id_t start, block_id_t end) { auto block_count = NumericCast(end + 1 - start); - handle->Trim(BLOCK_START + (NumericCast(start) * GetBlockAllocSize()), block_count * GetBlockAllocSize()); + auto offset = BLOCK_START + (NumericCast(start) * GetBlockAllocSize()); + auto length = block_count * GetBlockAllocSize(); + handle->Trim(offset, length); } void SingleFileBlockManager::TrimFreeBlocks(const set &blocks) { diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp index 4c31fafa1..66245acd6 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -20,6 +20,11 @@ namespace duckdb { #ifdef DUCKDB_DEBUG_DESTROY_BLOCKS static void WriteGarbageIntoBuffer(BlockLock &lock, BlockHandle &block) { auto &buffer = block.GetMemory().GetBuffer(lock); + if (!buffer->OwnsInternalBuffer()) { + // don't write garbage into mmap buffers + // this would directly be written back into the file + return; + } memset(buffer->GetDataMutable(), 0xa5, buffer->Size()); // 0xa5 is default memory in debug mode } @@ -43,12 +48,12 @@ unique_ptr StandardBufferManager::ConstructManagedBuffer(idx_t 
size, if (type == FileBufferType::BLOCK) { throw InternalException("ConstructManagedBuffer cannot be used to construct blocks"); } - if (source) { + if (source && source->OwnsInternalBuffer()) { auto tmp = std::move(source); D_ASSERT(tmp->AllocSize() == BufferManager::GetAllocSize(size + block_header_size)); result = make_uniq(*tmp, type, block_header_size); } else { - // non re-usable buffer: allocate a new buffer + // non re-usable buffer (or mmap-backed, which we cannot rewrite): allocate a new buffer result = make_uniq(BlockAllocator::Get(db), type, size, block_header_size); } result->Initialize(DBConfig::GetConfig(db).options.debug_initialize); @@ -566,6 +571,8 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(QueryContext c handle.reset(); // Delete the file and return the buffer. + // DeleteTemporaryFile already decrements evicted_data_per_tag for the .block path; do not + // decrement again here or the counter underflows on every read-back. DeleteTemporaryFile(block.GetMemory()); return buffer; diff --git a/src/duckdb/src/storage/statistics/array_stats.cpp b/src/duckdb/src/storage/statistics/array_stats.cpp index 9e821bab5..e128ed695 100644 --- a/src/duckdb/src/storage/statistics/array_stats.cpp +++ b/src/duckdb/src/storage/statistics/array_stats.cpp @@ -58,14 +58,14 @@ void ArrayStats::SetChildStats(BaseStatistics &stats, unique_ptr } } -void ArrayStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { +void ArrayStats::Merge(BaseStatistics &stats, const BaseStatistics &other, StatsMergeType merge_type) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; } auto &child_stats = ArrayStats::GetChildStats(stats); auto &other_child_stats = ArrayStats::GetChildStats(other); - child_stats.Merge(other_child_stats); + child_stats.Merge(other_child_stats, merge_type); } void ArrayStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { @@ -91,9 +91,9 @@ child_list_t ArrayStats::ToStruct(const BaseStatistics &stats) { 
return result; } -void ArrayStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { +void ArrayStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { auto &child_stats = ArrayStats::GetChildStats(stats); - auto &child_entry = ArrayVector::GetChildMutable(vector); + const auto &child_entry = ArrayVector::GetChild(vector); auto array_size = ArrayType::GetSize(vector.GetType()); UnifiedVectorFormat vdata; diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp index 1e69e822a..2cdd92fda 100644 --- a/src/duckdb/src/storage/statistics/base_statistics.cpp +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -49,6 +49,7 @@ BaseStatistics::BaseStatistics(BaseStatistics &&other) noexcept { has_no_null = other.has_no_null; distinct_count = other.distinct_count; stats_union = other.stats_union; + std::swap(extra_data, other.extra_data); std::swap(child_stats, other.child_stats); } @@ -58,6 +59,7 @@ BaseStatistics &BaseStatistics::operator=(BaseStatistics &&other) noexcept { has_no_null = other.has_no_null; distinct_count = other.distinct_count; stats_union = other.stats_union; + std::swap(extra_data, other.extra_data); std::swap(child_stats, other.child_stats); return *this; } @@ -144,7 +146,7 @@ bool BaseStatistics::IsConstant() const { return false; } -void BaseStatistics::Merge(const BaseStatistics &other) { +void BaseStatistics::Merge(const BaseStatistics &other, StatsMergeType merge_type) { has_null = has_null || other.has_null; has_no_null = has_no_null || other.has_no_null; switch (GetStatsType()) { @@ -152,16 +154,16 @@ void BaseStatistics::Merge(const BaseStatistics &other) { NumericStats::Merge(*this, other); break; case StatisticsType::STRING_STATS: - StringStats::Merge(*this, other); + StringStats::Merge(*this, other, merge_type); break; case StatisticsType::LIST_STATS: - 
ListStats::Merge(*this, other); + ListStats::Merge(*this, other, merge_type); break; case StatisticsType::STRUCT_STATS: - StructStats::Merge(*this, other); + StructStats::Merge(*this, other, merge_type); break; case StatisticsType::ARRAY_STATS: - ArrayStats::Merge(*this, other); + ArrayStats::Merge(*this, other, merge_type); break; case StatisticsType::GEOMETRY_STATS: GeometryStats::Merge(*this, other); @@ -259,6 +261,9 @@ void BaseStatistics::Copy(const BaseStatistics &other) { case StatisticsType::VARIANT_STATS: VariantStats::Copy(*this, other); break; + case StatisticsType::STRING_STATS: + StringStats::Copy(*this, other); + break; default: break; } @@ -475,7 +480,8 @@ string BaseStatistics::ToString() const { return ToStruct().ToString(); } -void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t count, const bool ignore_has_null) const { +void BaseStatistics::Verify(const Vector &vector, const SelectionVector &sel, idx_t count, + const bool ignore_has_null) const { D_ASSERT(vector.GetType() == this->type); switch (GetStatsType()) { case StatisticsType::NUMERIC_STATS: @@ -523,7 +529,7 @@ void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t co } } -void BaseStatistics::Verify(Vector &vector, idx_t count) const { +void BaseStatistics::Verify(const Vector &vector, idx_t count) const { auto sel = FlatVector::IncrementalSelectionVector(); Verify(vector, *sel, count, false); } @@ -540,7 +546,7 @@ BaseStatistics BaseStatistics::FromConstantType(const Value &input) { auto result = StringStats::CreateEmpty(input.type()); if (!input.IsNull()) { auto &string_value = StringValue::Get(input); - StringStats::Update(result, string_t(string_value)); + StringStats::FromConstant(result, string_t(string_value)); } return result; } diff --git a/src/duckdb/src/storage/statistics/column_statistics.cpp b/src/duckdb/src/storage/statistics/column_statistics.cpp index 19784bc91..5c1755d24 100644 --- 
a/src/duckdb/src/storage/statistics/column_statistics.cpp +++ b/src/duckdb/src/storage/statistics/column_statistics.cpp @@ -44,7 +44,7 @@ void ColumnStatistics::SetDistinct(unique_ptr distinct) { this->distinct_stats = std::move(distinct); } -void ColumnStatistics::UpdateDistinctStatistics(Vector &v, idx_t count, Vector &hashes) { +void ColumnStatistics::UpdateDistinctStatistics(const Vector &v, idx_t count, Vector &hashes) { if (!distinct_stats) { return; } diff --git a/src/duckdb/src/storage/statistics/distinct_statistics.cpp b/src/duckdb/src/storage/statistics/distinct_statistics.cpp index f952b4290..d1ed53b81 100644 --- a/src/duckdb/src/storage/statistics/distinct_statistics.cpp +++ b/src/duckdb/src/storage/statistics/distinct_statistics.cpp @@ -22,7 +22,7 @@ void DistinctStatistics::Merge(const DistinctStatistics &other) { total_count += other.total_count; } -void DistinctStatistics::UpdateSample(Vector &new_data, idx_t count, Vector &hashes) { +void DistinctStatistics::UpdateSample(const Vector &new_data, idx_t count, Vector &hashes) { total_count += count; const auto original_count = count; const auto sample_rate = new_data.GetType().IsIntegral() ? 
INTEGRAL_SAMPLE_RATE : BASE_SAMPLE_RATE; @@ -34,16 +34,16 @@ void DistinctStatistics::UpdateSample(Vector &new_data, idx_t count, Vector &has UpdateInternal(new_data, count, hashes); } -void DistinctStatistics::Update(Vector &new_data, idx_t count, Vector &hashes) { +void DistinctStatistics::Update(const Vector &new_data, idx_t count, Vector &hashes) { total_count += count; UpdateInternal(new_data, count, hashes); } -void DistinctStatistics::UpdateInternal(Vector &new_data, idx_t count, Vector &hashes) { +void DistinctStatistics::UpdateInternal(const Vector &new_data, idx_t count, Vector &hashes) { sample_count += count; VectorOperations::Hash(new_data, hashes, count); - log->Update(new_data, hashes, count); + log->Update(new_data, hashes); } string DistinctStatistics::ToString() const { diff --git a/src/duckdb/src/storage/statistics/geometry_stats.cpp b/src/duckdb/src/storage/statistics/geometry_stats.cpp index 0f266164d..f4143c2e8 100644 --- a/src/duckdb/src/storage/statistics/geometry_stats.cpp +++ b/src/duckdb/src/storage/statistics/geometry_stats.cpp @@ -79,7 +79,7 @@ BaseStatistics GeometryStats::CreateEmpty(LogicalType type) { void GeometryStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { // Should we serialize as old extension geometry type for backwards compatibility? 
// (in that case, write unknown string stats) - if (!serializer.ShouldSerialize(7)) { + if (!serializer.ShouldSerialize(StorageVersion::V1_5_0)) { auto string_stats = StringStats::CreateUnknown(LogicalType::VARCHAR); StringStats::Serialize(string_stats, serializer); return; @@ -195,7 +195,7 @@ void GeometryStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { target.Merge(source); } -void GeometryStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { +void GeometryStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { // TODO: Verify stats } @@ -273,12 +273,12 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c return FilterPropagateResult::NO_PRUNING_POSSIBLE; } const auto &func = expr->Cast(); - if (func.children.size() != 2) { + if (func.GetChildren().size() != 2) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - if (func.children[0]->GetReturnType().id() != LogicalTypeId::GEOMETRY || - func.children[1]->GetReturnType().id() != LogicalTypeId::GEOMETRY) { + if (func.GetChildren()[0]->GetReturnType().id() != LogicalTypeId::GEOMETRY || + func.GetChildren()[1]->GetReturnType().id() != LogicalTypeId::GEOMETRY) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -287,7 +287,7 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c auto found = false; for (const auto &name : geometry_predicates) { - if (StringUtil::CIEquals(func.function.GetName().c_str(), name)) { + if (func.Function().GetName() == name) { found = true; break; } @@ -297,8 +297,8 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - const auto lhs_kind = func.children[0]->GetExpressionType(); - const auto rhs_kind = func.children[1]->GetExpressionType(); + const auto lhs_kind = func.GetChildren()[0]->GetExpressionType(); + const auto 
rhs_kind = func.GetChildren()[1]->GetExpressionType(); const auto lhs_is_const = lhs_kind == ExpressionType::VALUE_CONSTANT && rhs_kind == ExpressionType::BOUND_REF; const auto rhs_is_const = rhs_kind == ExpressionType::VALUE_CONSTANT && lhs_kind == ExpressionType::BOUND_REF; @@ -315,10 +315,10 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c } if (lhs_is_const) { - return CheckIntersectionFilter(data, func.children[0]->Cast().value); + return CheckIntersectionFilter(data, func.GetChildren()[0]->Cast().GetValue()); } if (rhs_is_const) { - return CheckIntersectionFilter(data, func.children[1]->Cast().value); + return CheckIntersectionFilter(data, func.GetChildren()[1]->Cast().GetValue()); } // Else, no constant argument return FilterPropagateResult::NO_PRUNING_POSSIBLE; diff --git a/src/duckdb/src/storage/statistics/list_stats.cpp b/src/duckdb/src/storage/statistics/list_stats.cpp index 1e105deff..4fcde047d 100644 --- a/src/duckdb/src/storage/statistics/list_stats.cpp +++ b/src/duckdb/src/storage/statistics/list_stats.cpp @@ -59,14 +59,14 @@ void ListStats::SetChildStats(BaseStatistics &stats, unique_ptr } } -void ListStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { +void ListStats::Merge(BaseStatistics &stats, const BaseStatistics &other, StatsMergeType merge_type) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; } auto &child_stats = ListStats::GetChildStats(stats); auto &other_child_stats = ListStats::GetChildStats(other); - child_stats.Merge(other_child_stats); + child_stats.Merge(other_child_stats, merge_type); } void ListStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { @@ -92,9 +92,9 @@ child_list_t ListStats::ToStruct(const BaseStatistics &stats) { return result; } -void ListStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { +void ListStats::Verify(const BaseStatistics &stats, const Vector &vector, const 
SelectionVector &sel, idx_t count) { auto &child_stats = ListStats::GetChildStats(stats); - auto &child_entry = ListVector::GetChildMutable(vector); + const auto &child_entry = ListVector::GetChild(vector); auto entries = vector.Values(); idx_t total_list_count = 0; diff --git a/src/duckdb/src/storage/statistics/numeric_stats.cpp b/src/duckdb/src/storage/statistics/numeric_stats.cpp index e5a2b283b..28aa2cfa6 100644 --- a/src/duckdb/src/storage/statistics/numeric_stats.cpp +++ b/src/duckdb/src/storage/statistics/numeric_stats.cpp @@ -528,10 +528,10 @@ static void DeserializeNumericStatsValue(const LogicalType &type, NumericValueUn void NumericStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { auto &numeric_stats = NumericStats::GetDataUnsafe(stats); - serializer.WriteObject(200, "max", [&](Serializer &object) { + serializer.WriteObject(200, "min", [&](Serializer &object) { SerializeNumericStatsValue(stats.GetType(), numeric_stats.min, numeric_stats.has_min, object); }); - serializer.WriteObject(201, "min", [&](Serializer &object) { + serializer.WriteObject(201, "max", [&](Serializer &object) { SerializeNumericStatsValue(stats.GetType(), numeric_stats.max, numeric_stats.has_max, object); }); } @@ -539,10 +539,10 @@ void NumericStats::Serialize(const BaseStatistics &stats, Serializer &serializer void NumericStats::Deserialize(Deserializer &deserializer, BaseStatistics &result) { auto &numeric_stats = NumericStats::GetDataUnsafe(result); - deserializer.ReadObject(200, "max", [&](Deserializer &object) { + deserializer.ReadObject(200, "min", [&](Deserializer &object) { DeserializeNumericStatsValue(result.GetType(), numeric_stats.min, numeric_stats.has_min, object); }); - deserializer.ReadObject(201, "min", [&](Deserializer &object) { + deserializer.ReadObject(201, "max", [&](Deserializer &object) { DeserializeNumericStatsValue(result.GetType(), numeric_stats.max, numeric_stats.has_max, object); }); } @@ -557,7 +557,7 @@ child_list_t 
NumericStats::ToStruct(const BaseStatistics &stats) { } template -void NumericStats::TemplatedVerify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, +void NumericStats::TemplatedVerify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { auto entries = vector.Values(); auto min_value = NumericStats::MinOrNull(stats); @@ -580,7 +580,7 @@ void NumericStats::TemplatedVerify(const BaseStatistics &stats, Vector &vector, } } -void NumericStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { +void NumericStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { auto &type = stats.GetType(); switch (type.InternalType()) { case PhysicalType::BOOL: diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp index 575dc2960..59ad51a7d 100644 --- a/src/duckdb/src/storage/statistics/string_stats.cpp +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -1,4 +1,5 @@ #include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/common/types/value.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" @@ -9,66 +10,46 @@ #include "duckdb/storage/statistics/base_statistics.hpp" #include "utf8proc_wrapper.hpp" #include "duckdb/common/types/blob.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" namespace duckdb { -namespace { - -bool AllCharsEqualTo(const data_t data[], data_t comp) { - for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - if (data[i] != comp) { - return false; - } - } - return true; -} - -int StringValueComparison(const_data_ptr_t data, idx_t len, const_data_ptr_t comparison) { - for (idx_t i = 0; i < len; i++) { - if (data[i] < comparison[i]) { - return -1; - } else if (data[i] > comparison[i]) { - return 1; - } - } - return 0; -} - -void 
ConstructValue(const_data_ptr_t data, idx_t size, data_t target[]) { - idx_t value_size = size > StringStatsData::MAX_STRING_MINMAX_SIZE ? StringStatsData::MAX_STRING_MINMAX_SIZE : size; - memcpy(target, data, value_size); - for (idx_t i = value_size; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - target[i] = '\0'; +bool StringStatsData::HasMinStringLength() const { + if (!has_min_string_length) { + return false; } + return min_string_length != NumericLimits::Maximum(); } -} // namespace - BaseStatistics StringStats::CreateUnknown(LogicalType type) { BaseStatistics result(std::move(type)); result.InitializeUnknown(); - auto &string_data = StringStats::GetDataUnsafe(result); - for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - string_data.min[i] = 0; - string_data.max[i] = 0xFF; - } + auto &string_data = GetDataUnsafe(result); string_data.max_string_length = 0; string_data.has_max_string_length = false; + string_data.min_string_length = 0; + string_data.has_min_string_length = false; + string_data.total_string_length = 0; + string_data.has_total_string_length = false; string_data.has_unicode = true; + string_data.min_type = StringStatsType::NO_STATS; + string_data.max_type = StringStatsType::NO_STATS; return result; } BaseStatistics StringStats::CreateEmpty(LogicalType type) { BaseStatistics result(std::move(type)); result.InitializeEmpty(); - auto &string_data = StringStats::GetDataUnsafe(result); - for (idx_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - string_data.min[i] = 0xFF; - string_data.max[i] = 0; - } + auto &string_data = GetDataUnsafe(result); string_data.max_string_length = 0; string_data.has_max_string_length = true; + string_data.min_string_length = NumericLimits::Maximum(); + string_data.has_min_string_length = true; + string_data.total_string_length = 0; + string_data.has_total_string_length = true; string_data.has_unicode = false; + string_data.min_type = StringStatsType::EMPTY_STATS; + 
string_data.max_type = StringStatsType::EMPTY_STATS; return result; } @@ -82,163 +63,653 @@ const StringStatsData &StringStats::GetDataUnsafe(const BaseStatistics &stats) { return stats.stats_union.string_data; } +bool StatsIsSet(StringStatsType type) { + return type == StringStatsType::TRUNCATED_STATS || type == StringStatsType::EXACT_STATS; +} + bool StringStats::HasMinMax(const BaseStatistics &stats) { if (stats.GetType().id() == LogicalTypeId::SQLNULL) { return false; } - auto &string_data = StringStats::GetDataUnsafe(stats); - if (StringValueComparison(string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { - return false; + auto &string_data = GetDataUnsafe(stats); + return StatsIsSet(string_data.min_type) && StatsIsSet(string_data.max_type); +} + +bool StringStats::HasMin(const BaseStatistics &stats) { + return StatsIsSet(GetMinType(stats)); +} + +bool StringStats::HasMax(const BaseStatistics &stats) { + return StatsIsSet(GetMaxType(stats)); +} + +StringStatsType StringStats::GetMinType(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return StringStatsType::NO_STATS; } - if (stats.GetType().id() == LogicalTypeId::BLOB) { - // Empty stats are allowed and usable - return true; + auto &string_data = GetDataUnsafe(stats); + return string_data.min_type; +} + +StringStatsType StringStats::GetMaxType(const BaseStatistics &stats) { + if (stats.GetType().id() == LogicalTypeId::SQLNULL) { + return StringStatsType::NO_STATS; } - // The initial min value means either "empty string" or not set. Both effectively represent no lower bound. Thus, - // return true if at least max is set. 
- return !AllCharsEqualTo(string_data.max, 0xFF); + auto &string_data = GetDataUnsafe(stats); + return string_data.max_type; } bool StringStats::HasMaxStringLength(const BaseStatistics &stats) { if (stats.GetType().id() == LogicalTypeId::SQLNULL) { return false; } - return StringStats::GetDataUnsafe(stats).has_max_string_length; + return GetDataUnsafe(stats).has_max_string_length; } uint32_t StringStats::MaxStringLength(const BaseStatistics &stats) { if (!HasMaxStringLength(stats)) { throw InternalException("MaxStringLength called on statistics that does not have a max string length"); } - return StringStats::GetDataUnsafe(stats).max_string_length; + return GetDataUnsafe(stats).max_string_length; +} + +optional_idx StringStats::MinStringLength(const BaseStatistics &stats) { + auto &data = GetDataUnsafe(stats); + if (!data.HasMinStringLength()) { + return optional_idx(); + } + return data.min_string_length; +} + +optional_idx StringStats::TotalStringLength(const BaseStatistics &stats) { + auto &string_stats = GetDataUnsafe(stats); + return string_stats.has_total_string_length ? 
string_stats.total_string_length : optional_idx(); } bool StringStats::CanContainUnicode(const BaseStatistics &stats) { if (stats.GetType().id() == LogicalTypeId::SQLNULL) { return true; } - return StringStats::GetDataUnsafe(stats).has_unicode; + return GetDataUnsafe(stats).has_unicode; } -string GetStringMinMaxValue(const data_t data[]) { - idx_t len; - for (len = 0; len < StringStatsData::MAX_STRING_MINMAX_SIZE; len++) { - if (!data[len]) { - break; - } +string StringStats::Min(const BaseStatistics &stats) { + auto &string_data = GetDataUnsafe(stats); + if (!StatsIsSet(string_data.min_type)) { + throw InternalException("StringStats::Min called but no string stats were found - call StringStats::HasMin or " + "StringStats::HasMinMax first"); } - return string(const_char_ptr_cast(data), len); + return string_data.min.GetString(); } -string StringStats::Min(const BaseStatistics &stats) { - return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).min); +string StringStats::Max(const BaseStatistics &stats) { + auto &string_data = GetDataUnsafe(stats); + if (!StatsIsSet(string_data.max_type)) { + throw InternalException("StringStats::Max called but no string stats were found - call StringStats::HasMax or " + "StringStats::HasMinMax first"); + } + return string_data.max.GetString(); } -string StringStats::Max(const BaseStatistics &stats) { - return GetStringMinMaxValue(StringStats::GetDataUnsafe(stats).max); +Value StringStats::TryGetValidMin(const BaseStatistics &stats) { + auto min = Min(stats); + if (GetMinType(stats) == StringStatsType::EXACT_STATS) { + return Value(std::move(min)); + } + string result; + if (!Utf8Proc::ValidLowerBound(min, result)) { + return Value(); + } + return Value(std::move(result)); +} + +Value StringStats::TryGetValidMax(const BaseStatistics &stats) { + auto max = Max(stats); + if (GetMaxType(stats) == StringStatsType::EXACT_STATS) { + return Value(std::move(max)); + } + string result; + if (!Utf8Proc::ValidUpperBound(max, result)) { + 
return Value(); + } + return Value(std::move(result)); } void StringStats::ResetMaxStringLength(BaseStatistics &stats) { - StringStats::GetDataUnsafe(stats).has_max_string_length = false; + GetDataUnsafe(stats).has_max_string_length = false; } void StringStats::SetMaxStringLength(BaseStatistics &stats, uint32_t length) { - auto &data = StringStats::GetDataUnsafe(stats); + auto &data = GetDataUnsafe(stats); data.has_max_string_length = true; data.max_string_length = length; } void StringStats::SetContainsUnicode(BaseStatistics &stats) { - StringStats::GetDataUnsafe(stats).has_unicode = true; + GetDataUnsafe(stats).has_unicode = true; +} + +static void LegacyConstructValue(const_data_ptr_t data, idx_t size, data_t target[]) { + constexpr auto max_length = StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE; + idx_t value_size = size > max_length ? max_length : size; + memcpy(target, data, value_size); + for (idx_t i = value_size; i < max_length; i++) { + target[i] = '\0'; + } +} + +void LegacyConstructMinMax(string_t input, StringStatsType type, data_t result[], bool is_min) { + auto input_data = const_data_ptr_cast(input.GetData()); + if (type == StringStatsType::EMPTY_STATS) { + // for empty stats we serialize min as the maximum value, and max as the minimum value + data_t empty_byte = is_min ? 0xFF : 0x00; + memset(result, empty_byte, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + } else if (type == StringStatsType::NO_STATS) { + // for no stats we serialize min as 0x00..., and max as 0xFF... + data_t no_stats_byte = is_min ? 
0x00 : 0xFF; + memset(result, no_stats_byte, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + } else { + LegacyConstructValue(input_data, input.GetSize(), result); + } +} + +string_t LegacyReadMinMax(const data_t result[]) { + // truncate any trailing 0-bytes - there's no reason to keep them around + uint32_t len = StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE; + for (; len > 0; len--) { + if (result[len - 1] != '\0') { + break; + } + } + return string_t(const_char_ptr_cast(result), len); +} + +static bool LegacyAllCharsEqualTo(const data_t data[], data_t comp) { + for (idx_t i = 0; i < StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE; i++) { + if (data[i] != comp) { + return false; + } + } + return true; +} + +StringStatsType LegacyGetMinMaxType(const data_t min[], const data_t max[], bool has_max_length, idx_t max_length) { + if (min[0] > max[0]) { + // if min > max then we have empty stats + return StringStatsType::EMPTY_STATS; + } + if (LegacyAllCharsEqualTo(min, 0x00) && LegacyAllCharsEqualTo(max, 0xFF)) { + // if min is 0x00... and max is 0xFF... then we have no min/max (nothing can ever be pruned) + return StringStatsType::NO_STATS; + } + if (has_max_length && max_length <= StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE) { + // when: + // (1) max length is known, and + // (2) max length is <= 8, and + // (3) min/max are equivalent to the max length + // We know the min/max are exact because they have not been truncated + // if min/max are not equal to the max length - we can't distinguish between e.g. 
`hello\0` and `hello` + if (max_length == 0 || (min[max_length - 1] != '\0' && max[max_length - 1] != '\0')) { + return StringStatsType::EXACT_STATS; + } + } + return StringStatsType::TRUNCATED_STATS; +} + +struct StringStatsField { + explicit StringStatsField(uint32_t data_p = 0U) : data(data_p) { + } + +public: + static constexpr uint32_t NO_UNICODE_FLAG = 1U << 0U; + static constexpr uint32_t MIN_EXACT_OR_EMPTY_FLAG = 1U << 1U; + static constexpr uint32_t MAX_EXACT_OR_EMPTY_FLAG = 1U << 2U; + static constexpr uint32_t LENGTH_MASK = 0xFU; + static constexpr uint32_t MIN_LENGTH_POSITION = 3U; + static constexpr uint32_t MAX_LENGTH_POSITION = 7U; + static constexpr uint32_t MIN_LENGTH_MASK = LENGTH_MASK << MIN_LENGTH_POSITION; + static constexpr uint32_t MAX_LENGTH_MASK = LENGTH_MASK << MAX_LENGTH_POSITION; + static constexpr uint32_t MIN_STRING_LENGTH_POSITION = 11U; + static constexpr uint32_t MIN_STRING_LENGTH_BIT_SIZE = 21U; + static constexpr uint32_t MIN_STRING_LENGTH_MASK = ((1U << MIN_STRING_LENGTH_BIT_SIZE) - 1) + << MIN_STRING_LENGTH_POSITION; + + void SetHasUnicode(bool has_unicode) { + // default is has_unicode = true + if (!has_unicode) { + data |= NO_UNICODE_FLAG; + } else { + data &= ~NO_UNICODE_FLAG; + } + } + bool GetHasUnicode() const { + return !(data & NO_UNICODE_FLAG); + } + void SetMinType(StringStatsType min_type) { + // we use one bit to encode these 4 values + // there are 4 possibilities + // min_max missing - bit set to 1 - EMPTY_STATS + // min_max missing - bit set to 0 - NO_STATS + // min_max present - bit set to 1 - EXACT_STATS + // min_max present - bit set to 0 - TRUNCATED_STATS + if (min_type == StringStatsType::EXACT_STATS || min_type == StringStatsType::EMPTY_STATS) { + data |= MIN_EXACT_OR_EMPTY_FLAG; + } else { + data &= ~MIN_EXACT_OR_EMPTY_FLAG; + } + } + StringStatsType GetMinType(bool has_min) const { + if (data & MIN_EXACT_OR_EMPTY_FLAG) { + return has_min ? 
StringStatsType::EXACT_STATS : StringStatsType::EMPTY_STATS; + } else { + return has_min ? StringStatsType::TRUNCATED_STATS : StringStatsType::NO_STATS; + } + } + void SetMaxType(StringStatsType max_type) { + if (max_type == StringStatsType::EXACT_STATS || max_type == StringStatsType::EMPTY_STATS) { + data |= MAX_EXACT_OR_EMPTY_FLAG; + } else { + data &= ~MAX_EXACT_OR_EMPTY_FLAG; + } + } + StringStatsType GetMaxType(bool has_max) const { + if (data & MAX_EXACT_OR_EMPTY_FLAG) { + return has_max ? StringStatsType::EXACT_STATS : StringStatsType::EMPTY_STATS; + } else { + return has_max ? StringStatsType::TRUNCATED_STATS : StringStatsType::NO_STATS; + } + } + void SetMinLength(idx_t min_length) { + if (min_length > 12) { + throw InternalException("StringStatsField::SetMinLength should be at most 12"); + } + data = + (data & ~MIN_LENGTH_MASK) | ((static_cast(min_length) << MIN_LENGTH_POSITION) & MIN_LENGTH_MASK); + } + uint32_t GetMinLength() const { + auto result = (data >> MIN_LENGTH_POSITION) & LENGTH_MASK; + if (result > 12) { + throw SerializationException("Invalid string stats - min length cannot be more than 12"); + } + return result; + } + void SetMaxLength(idx_t max_length) { + if (max_length > 12) { + throw InternalException("StringStatsField::SetMaxLength should be at most 12"); + } + data = + (data & ~MAX_LENGTH_MASK) | ((static_cast(max_length) << MAX_LENGTH_POSITION) & MAX_LENGTH_MASK); + } + uint32_t GetMaxLength() const { + auto result = (data >> MAX_LENGTH_POSITION) & LENGTH_MASK; + if (result > 12) { + throw SerializationException("Invalid string stats - min length cannot be more than 12"); + } + return result; + } + + uint32_t GetData() const { + return data; + } + + void SetMinStringLength(uint32_t min_str_length) { + if (min_str_length >= StringStatsData::MAXIMUM_MIN_STRING_LENGTH) { + // exceeds max length - set to max value to indicate "we know the min string length, it is >= the max) + min_str_length = StringStatsData::MAXIMUM_MIN_STRING_LENGTH; 
+ } else { + // increment by 1 (0 is reserved for not set) + min_str_length++; + } + data = (data & ~MIN_STRING_LENGTH_MASK) | + ((min_str_length << MIN_STRING_LENGTH_POSITION) & MIN_STRING_LENGTH_MASK); + } + bool HasMinStringLength() const { + return GetMinStringLengthInternal() > 0; + } + uint32_t GetMinStringLength() const { + auto min_str_len = GetMinStringLengthInternal(); + if (min_str_len == StringStatsData::MAXIMUM_MIN_STRING_LENGTH) { + // we know the min string length is >= MAXIMUM_MIN_STRING_LENGTH + // return UINT_MAX to indicate this + return NumericLimits::Maximum(); + } + if (min_str_len == 0) { + throw InternalException("GetMinStringLength() called but min string length is not set"); + } + return min_str_len - 1; + } + +private: + uint32_t GetMinStringLengthInternal() const { + return (data & MIN_STRING_LENGTH_MASK) >> MIN_STRING_LENGTH_POSITION; + } + +private: + uint32_t data; +}; + +void TruncateStatsIfRequired(idx_t &len, StringStatsType &type) { + if (len > StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE) { + len = StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE; + type = StringStatsType::TRUNCATED_STATS; + } } void StringStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { - auto &string_data = StringStats::GetDataUnsafe(stats); - serializer.WriteProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); - serializer.WriteProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); - serializer.WriteProperty(202, "has_unicode", string_data.has_unicode); - serializer.WriteProperty(203, "has_max_string_length", string_data.has_max_string_length); - serializer.WriteProperty(204, "max_string_length", string_data.max_string_length); + auto &string_data = GetDataUnsafe(stats); + if (!serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + // targeting old storage: use legacy serialize + data_t min_data[StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE]; + data_t 
max_data[StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE]; + LegacyConstructMinMax(string_data.min, string_data.min_type, min_data, true); + LegacyConstructMinMax(string_data.max, string_data.max_type, max_data, false); + serializer.WriteProperty(200, "min", min_data, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + serializer.WriteProperty(201, "max", max_data, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + serializer.WriteProperty(202, "has_unicode", string_data.has_unicode); + serializer.WriteProperty(203, "has_max_string_length", string_data.has_max_string_length); + serializer.WriteProperty(204, "max_string_length", string_data.max_string_length); + } else { + // new serialize - we try to pack elements into a single integer to save on fields + StringStatsField field; + field.SetHasUnicode(string_data.has_unicode); + if (string_data.has_max_string_length) { + // use the same field as before for max string length + serializer.WriteProperty(204, "max_string_length", string_data.max_string_length); + } + auto min_type = string_data.min_type; + auto max_type = string_data.max_type; + if (StatsIsSet(min_type) && StatsIsSet(max_type)) { + // write the min/max together as a single blob field, if they are set + constexpr uint32_t min_max_size = StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE * 2; + data_t min_max_data[min_max_size] = {}; + auto min_size = string_data.min.GetSize(); + auto max_size = string_data.max.GetSize(); + TruncateStatsIfRequired(min_size, min_type); + TruncateStatsIfRequired(max_size, max_type); + field.SetMinLength(min_size); + field.SetMaxLength(max_size); + + memcpy(min_max_data, string_data.min.GetData(), min_size); + memcpy(min_max_data + StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE, string_data.max.GetData(), max_size); + + serializer.WriteProperty(205, "min_max", min_max_data, min_max_size); + } + field.SetMinType(min_type); + field.SetMaxType(max_type); + if (string_data.has_min_string_length) { + 
field.SetMinStringLength(string_data.min_string_length); + } + // write the stats as a single packed field + serializer.WriteProperty(206, "packed_field", field.GetData()); + serializer.WritePropertyWithDefault(207, "total_string_length", + string_data.has_total_string_length ? string_data.total_string_length + : optional_idx()); + } } void StringStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { - auto &string_data = StringStats::GetDataUnsafe(base); - deserializer.ReadProperty(200, "min", string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); - deserializer.ReadProperty(201, "max", string_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); - deserializer.ReadProperty(202, "has_unicode", string_data.has_unicode); - deserializer.ReadProperty(203, "has_max_string_length", string_data.has_max_string_length); - deserializer.ReadProperty(204, "max_string_length", string_data.max_string_length); + auto &string_data = GetDataUnsafe(base); + // try to deserialize legacy stats first + data_t legacy_min_data[StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE]; + data_t legacy_max_data[StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE]; + auto has_min = + deserializer.ReadOptionalProperty(200, "min", legacy_min_data, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + auto has_max = + deserializer.ReadOptionalProperty(201, "max", legacy_max_data, StringStatsData::LEGACY_MAX_STRING_MINMAX_SIZE); + if (has_min && has_max) { + // legacy stats + deserializer.ReadProperty(202, "has_unicode", string_data.has_unicode); + deserializer.ReadProperty(203, "has_max_string_length", string_data.has_max_string_length); + deserializer.ReadProperty(204, "max_string_length", string_data.max_string_length); + string_data.min = AssignString(base, LegacyReadMinMax(legacy_min_data), true); + string_data.max = AssignString(base, LegacyReadMinMax(legacy_max_data), false); + string_data.min_type = LegacyGetMinMaxType(legacy_min_data, legacy_max_data, 
string_data.has_max_string_length, + string_data.max_string_length); + string_data.max_type = string_data.min_type; + string_data.has_total_string_length = false; + string_data.has_min_string_length = false; + } else { + auto max_string_length = + deserializer.ReadPropertyWithExplicitDefault(204, "max_string_length", NumericLimits::Maximum()); + if (max_string_length == NumericLimits::Maximum()) { + string_data.has_max_string_length = false; + } else { + string_data.has_max_string_length = true; + string_data.max_string_length = max_string_length; + } + constexpr uint32_t min_max_size = StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE * 2; + data_t min_max_data[min_max_size]; + auto has_min_max = deserializer.ReadOptionalProperty(205, "min_max", min_max_data, min_max_size); + auto packed_field = deserializer.ReadProperty(206, "packed_field"); + StringStatsField field(packed_field); + string_data.min_type = field.GetMinType(has_min_max); + string_data.max_type = field.GetMaxType(has_min_max); + if (has_min_max) { + auto min = string_t(const_char_ptr_cast(min_max_data), field.GetMinLength()); + auto max = string_t(const_char_ptr_cast(min_max_data + StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE), + field.GetMaxLength()); + string_data.min = AssignString(base, min, true); + string_data.max = AssignString(base, max, false); + } + string_data.has_unicode = field.GetHasUnicode(); + if (field.HasMinStringLength()) { + string_data.has_min_string_length = true; + string_data.min_string_length = field.GetMinStringLength(); + } else { + string_data.has_min_string_length = false; + } + auto total_string_length = deserializer.ReadPropertyWithDefault(207, "total_string_length"); + if (total_string_length.IsValid()) { + string_data.has_total_string_length = true; + string_data.total_string_length = total_string_length.GetIndex(); + } else { + string_data.has_total_string_length = false; + } + } } -void StringStats::Update(BaseStatistics &stats, const string_t &value) { - auto 
data = const_data_ptr_cast(value.GetData()); - auto size = value.GetSize(); +void StringStats::FromConstant(BaseStatistics &stats, string_t value) { + static constexpr const idx_t CONSTANT_STATS_BOUND = 100000; - //! we can only fit 8 bytes, so we might need to trim our string - // construct the value - data_t target[StringStatsData::MAX_STRING_MINMAX_SIZE]; - ConstructValue(data, size, target); + // use the string stats writer for setting stats + StatsWriter writer(stats.GetType()); + writer.Update(value); + writer.Merge(stats); - // update the min and max - auto &string_data = StringStats::GetDataUnsafe(stats); - if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { - memcpy(string_data.min, target, StringStatsData::MAX_STRING_MINMAX_SIZE); + // just directly assign the min/max, since the stats writer truncates + // ... except if it is REALLY long + if (value.GetSize() <= CONSTANT_STATS_BOUND) { + SetMin(stats, value, StringStatsType::EXACT_STATS); + SetMax(stats, value, StringStatsType::EXACT_STATS); } - if (StringValueComparison(target, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { - memcpy(string_data.max, target, StringStatsData::MAX_STRING_MINMAX_SIZE); + + // constants don't have a "total string length" + auto &string_stats = GetDataUnsafe(stats); + string_stats.has_total_string_length = false; +} + +void StringStats::MergeInConstant(BaseStatistics &stats, string_t input) { + auto constant_stats = CreateEmpty(stats.GetType()); + FromConstant(constant_stats, input); + Merge(stats, constant_stats); +} + +void StringStats::Update(BaseStatistics &stats, const string_t &value) { + MergeInConstant(stats, value); +} + +struct StringData { + unique_ptr data; + idx_t capacity = 0; + + string_t AssignString(const string_t &value) { + if (value.GetSize() > capacity) { + auto next_capacity = NextPowerOfTwo(value.GetSize()); + data = make_uniq_array(next_capacity); + capacity = next_capacity; + } + 
memcpy(data.get(), value.GetData(), value.GetSize()); + return string_t(const_char_ptr_cast(data.get()), static_cast(value.GetSize())); } - if (size > string_data.max_string_length) { - string_data.max_string_length = UnsafeNumericCast(size); +}; + +struct StringStatsExtraData : public ExtraStatsData { + StringData string_data[2]; +}; + +string_t StringStats::AssignString(BaseStatistics &stats, const string_t &input, bool is_min) { + if (input.IsInlined()) { + return input; } - if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { - auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(data), size); - if (unicode == UnicodeType::UTF8) { - string_data.has_unicode = true; - } else if (unicode == UnicodeType::INVALID) { - throw ErrorManager::InvalidUnicodeError(string(const_char_ptr_cast(data), size), - "segment statistics update"); - } + if (!stats.extra_data) { + stats.extra_data = make_uniq(); } + auto &extra_data = stats.extra_data->Cast(); + auto data_idx = is_min ? 
0ULL : 1ULL; + return extra_data.string_data[data_idx].AssignString(input); +} + +void StringStats::SetMin(BaseStatistics &stats, const string_t &value, StringStatsType type) { + auto &stats_data = GetDataUnsafe(stats); + stats_data.min = AssignString(stats, value, true); + stats_data.min_type = type; +} + +void StringStats::SetMax(BaseStatistics &stats, const string_t &value, StringStatsType type) { + auto &stats_data = GetDataUnsafe(stats); + stats_data.max = AssignString(stats, value, false); + stats_data.max_type = type; } void StringStats::SetMin(BaseStatistics &stats, const string_t &value) { - ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).min); + SetMin(stats, value, StringStatsType::TRUNCATED_STATS); } void StringStats::SetMax(BaseStatistics &stats, const string_t &value) { - ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).max); + SetMax(stats, value, StringStatsType::TRUNCATED_STATS); } -void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { - if (other.GetType().id() == LogicalTypeId::VALIDITY) { +void StringStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { + auto &string_data = GetDataUnsafe(stats); + auto &other_data = GetDataUnsafe(other); + if (StatsIsSet(other_data.min_type)) { + string_data.min = AssignString(stats, other_data.min, true); + } + if (StatsIsSet(other_data.max_type)) { + string_data.max = AssignString(stats, other_data.max, false); + } +} + +void StringStats::MergeStats(BaseStatistics &stats, string_t &target, StringStatsType &target_type, + const string_t &source, StringStatsType source_type, bool is_min) { + if (target_type == StringStatsType::NO_STATS || source_type == StringStatsType::NO_STATS) { + // no min/max available - result is no min/max + target_type = StringStatsType::NO_STATS; return; } - if (other.GetType().id() == LogicalTypeId::SQLNULL) { + if (source_type == StringStatsType::EMPTY_STATS) { + 
// source is empty - nothing to update + return; + } + if (target_type == StringStatsType::EMPTY_STATS) { + // we don't have min/max - copy them from the target + target = AssignString(stats, source, is_min); + target_type = source_type; return; } - auto &string_data = StringStats::GetDataUnsafe(stats); - auto &other_data = StringStats::GetDataUnsafe(other); - if (StringValueComparison(other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.min) < 0) { - memcpy(string_data.min, other_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE); + // both min/max stats are there - compare + bool new_is_more_extreme; + if (is_min) { + new_is_more_extreme = LessThan::Operation(source, target); + } else { + new_is_more_extreme = GreaterThan::Operation(source, target); } - if (StringValueComparison(other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max) > 0) { - memcpy(string_data.max, other_data.max, StringStatsData::MAX_STRING_MINMAX_SIZE); + if (!new_is_more_extreme) { + // old value is more extreme - bail + return; } + // assign the new value + target = AssignString(stats, source, is_min); + target_type = source_type; +} + +void StringStats::Merge(BaseStatistics &stats, const StringStatsData &other_data, StatsMergeType merge_type) { + auto &string_data = GetDataUnsafe(stats); + + // merge min/max + MergeStats(stats, string_data.min, string_data.min_type, other_data.min, other_data.min_type, true); + MergeStats(stats, string_data.max, string_data.max_type, other_data.max, other_data.max_type, false); string_data.has_unicode = string_data.has_unicode || other_data.has_unicode; string_data.has_max_string_length = string_data.has_max_string_length && other_data.has_max_string_length; string_data.max_string_length = MaxValue(string_data.max_string_length, other_data.max_string_length); + string_data.has_min_string_length = string_data.has_min_string_length && other_data.has_min_string_length; + string_data.min_string_length = 
MinValue(string_data.min_string_length, other_data.min_string_length); + if (merge_type == StatsMergeType::EXPAND_BOUNDS) { + // when expanding bounds we cannot sum total_string_length — the other side may overlap with our data + string_data.has_total_string_length = false; + } else { + string_data.has_total_string_length = string_data.has_total_string_length && other_data.has_total_string_length; + if (string_data.has_total_string_length) { + string_data.total_string_length += other_data.total_string_length; + } + } +} + +void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other, StatsMergeType merge_type) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + if (other.GetType().id() == LogicalTypeId::SQLNULL) { + return; + } + auto &other_data = GetDataUnsafe(other); + Merge(stats, other_data, merge_type); +} + +string_t ReadWriterStats(const data_t data[], idx_t size, StringStatsType &type) { + if (size <= StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE) { + // exact length + type = StringStatsType::EXACT_STATS; + return string_t(const_char_ptr_cast(data), static_cast(size)); + } + // truncated + type = StringStatsType::TRUNCATED_STATS; + return string_t(const_char_ptr_cast(data), StringStatsData::CURRENT_MAX_STRING_MINMAX_SIZE); +} + +void StringStats::Merge(BaseStatistics &stats, const StatsWriter &stats_writer) { + if (!stats_writer.AnyValid()) { + return; + } + // construct string stats data from the writer + StringStatsData other_data; + other_data.min = ReadWriterStats(stats_writer.min, stats_writer.min_size, other_data.min_type); + other_data.max = ReadWriterStats(stats_writer.max, stats_writer.max_size, other_data.max_type); + other_data.has_unicode = stats_writer.has_unicode; + other_data.has_max_string_length = true; + other_data.max_string_length = stats_writer.max_string_length; + other_data.has_min_string_length = true; + other_data.min_string_length = stats_writer.min_string_length; + 
other_data.has_total_string_length = true; + other_data.total_string_length = stats_writer.total_string_length; + Merge(stats, other_data); } FilterPropagateResult StringStats::CheckZonemap(const BaseStatistics &stats, ExpressionType comparison_type, array_ptr constants) { - auto &string_data = StringStats::GetDataUnsafe(stats); + auto &string_data = GetDataUnsafe(stats); D_ASSERT(stats.CanHaveNoNull()); for (auto &constant_value : constants) { D_ASSERT(constant_value.type() == stats.GetType()); D_ASSERT(!constant_value.IsNull()); auto &constant = StringValue::Get(constant_value); - auto prune_result = CheckZonemap(string_data.min, StringStatsData::MAX_STRING_MINMAX_SIZE, string_data.max, - StringStatsData::MAX_STRING_MINMAX_SIZE, comparison_type, constant); + FilterPropagateResult prune_result; + if (HasMinMax(stats)) { + prune_result = CheckZonemap(string_data.min, string_data.min_type, string_data.max, string_data.max_type, + comparison_type, constant); + } else { + prune_result = FilterPropagateResult::NO_PRUNING_POSSIBLE; + } if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { @@ -248,13 +719,19 @@ FilterPropagateResult StringStats::CheckZonemap(const BaseStatistics &stats, Exp return FilterPropagateResult::FILTER_ALWAYS_FALSE; } -FilterPropagateResult StringStats::CheckZonemap(const_data_ptr_t min_data, idx_t min_len, const_data_ptr_t max_data, - idx_t max_len, ExpressionType comparison_type, const string &constant) { - auto data = const_data_ptr_cast(constant.c_str()); - idx_t size = constant.size(); +int8_t CompareStringStats(string_t input, string_t stats, StringStatsType type) { + if (type == StringStatsType::TRUNCATED_STATS && input.GetSize() > stats.GetSize()) { + // if the stats are truncated we can only compare at most the bytes as are present in the stats + return Comparator::Operation(string_t(input.GetData(), 
static_cast(stats.GetSize())), stats); + } + return Comparator::Operation(input, stats); +} - int min_comp = StringValueComparison(data, MinValue(min_len, size), min_data); - int max_comp = StringValueComparison(data, MinValue(max_len, size), max_data); +FilterPropagateResult StringStats::CheckZonemap(string_t min, StringStatsType min_type, string_t max, + StringStatsType max_type, ExpressionType comparison_type, + string_t constant) { + auto min_comp = CompareStringStats(constant, min, min_type); + auto max_comp = CompareStringStats(constant, max, max_type); switch (comparison_type) { case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOT_DISTINCT_FROM: @@ -288,37 +765,31 @@ FilterPropagateResult StringStats::CheckZonemap(const_data_ptr_t min_data, idx_t } } -static uint32_t GetValidMinMaxSubstring(const_data_ptr_t data) { - for (uint32_t i = 0; i < StringStatsData::MAX_STRING_MINMAX_SIZE; i++) { - if (data[i] == '\0') { - return i; - } - } - return StringStatsData::MAX_STRING_MINMAX_SIZE; -} - child_list_t StringStats::ToStruct(const BaseStatistics &stats) { child_list_t result; - auto &string_data = StringStats::GetDataUnsafe(stats); - auto min_len = GetValidMinMaxSubstring(string_data.min); - auto max_len = GetValidMinMaxSubstring(string_data.max); - string_t min_str(const_char_ptr_cast(string_data.min), min_len); - string_t max_str(const_char_ptr_cast(string_data.max), max_len); - if (StringStats::HasMinMax(stats)) { - result.emplace_back("min", Blob::ToString(min_str)); - result.emplace_back("max", Blob::ToString(max_str)); + auto &string_data = GetDataUnsafe(stats); + if (HasMinMax(stats)) { + result.emplace_back("min", Blob::ToString(string_data.min)); + result.emplace_back("max", Blob::ToString(string_data.max)); } result.emplace_back("has_unicode", Value::BOOLEAN(string_data.has_unicode)); - if (StringStats::HasMaxStringLength(stats)) { + if (string_data.HasMinStringLength()) { + result.emplace_back("min_string_length", 
Value::UBIGINT(string_data.min_string_length)); + } + if (HasMaxStringLength(stats)) { result.emplace_back("max_string_length", Value::UBIGINT(string_data.max_string_length)); } + if (string_data.has_total_string_length) { + result.emplace_back("total_string_length", Value::UBIGINT(string_data.total_string_length)); + } return result; } -void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { - auto &string_data = StringStats::GetDataUnsafe(stats); +void StringStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { + auto &string_data = GetDataUnsafe(stats); auto entries = vector.Values(); + idx_t str_len_in_vector = 0; for (idx_t i = 0; i < count; i++) { auto idx = sel.get_index(i); auto entry = entries[idx]; @@ -334,6 +805,11 @@ void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const Sele "Statistics mismatch: string value exceeds maximum string length.\nStatistics: %s\nVector: %s", stats.ToString(), vector.ToString()); } + if (string_data.HasMinStringLength() && len < string_data.min_string_length) { + throw InternalException( + "Statistics mismatch: string value exceeds minimum string length.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString()); + } if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { auto unicode = Utf8Proc::Analyze(data, len); if (unicode == UnicodeType::UTF8) { @@ -344,18 +820,29 @@ void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const Sele throw InternalException("Invalid unicode detected in vector: %s", vector.ToString()); } } - if (StringValueComparison(const_data_ptr_cast(data), - MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.min) < 0) { - throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s", - stats.ToString(), vector.ToString()); + if (StatsIsSet(string_data.min_type)) { + 
auto min_cmp = CompareStringStats(value, string_data.min, string_data.min_type); + if (min_cmp < 0) { + throw InternalException("Statistics mismatch: value is smaller than min.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString()); + } } - if (StringValueComparison(const_data_ptr_cast(data), - MinValue(len, StringStatsData::MAX_STRING_MINMAX_SIZE), string_data.max) > 0) { - throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s", - stats.ToString(), vector.ToString()); + if (StatsIsSet(string_data.max_type)) { + auto max_cmp = CompareStringStats(value, string_data.max, string_data.max_type); + if (max_cmp > 0) { + throw InternalException("Statistics mismatch: value is bigger than max.\nStatistics: %s\nVector: %s", + stats.ToString(), vector.ToString()); + } } + str_len_in_vector += len; // LCOV_EXCL_STOP } + // this is not exhaustive but it's better than nothing + if (string_data.has_total_string_length && str_len_in_vector > string_data.total_string_length) { + throw InternalException( + "Statistics mismatch: string length in vector exceeds total string length reported in stats", + stats.ToString(), vector.ToString()); + } } } // namespace duckdb diff --git a/src/duckdb/src/storage/statistics/struct_stats.cpp b/src/duckdb/src/storage/statistics/struct_stats.cpp index 59f468cec..b81b59a31 100644 --- a/src/duckdb/src/storage/statistics/struct_stats.cpp +++ b/src/duckdb/src/storage/statistics/struct_stats.cpp @@ -85,7 +85,7 @@ void StructStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { } } -void StructStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { +void StructStats::Merge(BaseStatistics &stats, const BaseStatistics &other, StatsMergeType merge_type) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; } @@ -93,7 +93,7 @@ void StructStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { D_ASSERT(StructType::GetChildCount(stats.GetType()) == 
StructType::GetChildCount(other.GetType())); auto child_count = StructType::GetChildCount(stats.GetType()); for (idx_t i = 0; i < child_count; i++) { - stats.child_stats[i].Merge(other.child_stats[i]); + stats.child_stats[i].Merge(other.child_stats[i], merge_type); } } @@ -130,8 +130,8 @@ child_list_t StructStats::ToStruct(const BaseStatistics &stats) { return result; } -void StructStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { - auto &child_entries = StructVector::GetEntries(vector); +void StructStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { + const auto &child_entries = StructVector::GetEntries(vector); for (idx_t i = 0; i < child_entries.size(); i++) { stats.child_stats[i].Verify(child_entries[i], sel, count, true); } diff --git a/src/duckdb/src/storage/statistics/variant_stats.cpp b/src/duckdb/src/storage/statistics/variant_stats.cpp index 6f3ece448..71b5a2dd5 100644 --- a/src/duckdb/src/storage/statistics/variant_stats.cpp +++ b/src/duckdb/src/storage/statistics/variant_stats.cpp @@ -165,7 +165,7 @@ optional_ptr VariantShreddedStats::FindChildStats(const Ba auto &object_fields = StructType::GetChildTypes(typed_value_type); for (idx_t i = 0; i < object_fields.size(); i++) { auto &object_field = object_fields[i]; - if (StringUtil::CIEquals(object_field.first, component.key)) { + if (object_field.first == component.key) { return StructStats::GetChildStats(typed_value_stats, i); } } @@ -388,7 +388,8 @@ static Value GetShreddedStatsStruct(const BaseStatistics &stats, bool fully_shre std::sort(indices.begin(), indices.end(), [&](const idx_t &lhs, const idx_t &rhs) { auto &a = fields[lhs].first; auto &b = fields[rhs].first; - return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); + return std::lexicographical_compare(a.GetIdentifierName().begin(), a.GetIdentifierName().end(), + b.GetIdentifierName().begin(), 
b.GetIdentifierName().end()); }); for (idx_t i = 0; i < indices.size(); i++) { auto &child_stats = StructStats::GetChildStats(typed_value, indices[i]); @@ -520,7 +521,7 @@ bool VariantStats::MergeShredding(const BaseStatistics &stats, const BaseStatist for (idx_t i = 0; i < stats_object_children.size(); i++) { auto &stats_object_child = stats_object_children[i]; - auto other_it = key_to_index.find(stats_object_child.first); + auto other_it = key_to_index.find(stats_object_child.first.GetIdentifierName()); if (other_it == key_to_index.end()) { continue; } @@ -672,7 +673,7 @@ void VariantStats::Copy(BaseStatistics &stats, const BaseStatistics &other) { } } -void VariantStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { +void VariantStats::Verify(const BaseStatistics &stats, const Vector &vector, const SelectionVector &sel, idx_t count) { // TODO: Verify stats } diff --git a/src/duckdb/src/storage/storage_info.cpp b/src/duckdb/src/storage/storage_info.cpp index 382ad9502..8059b1f04 100644 --- a/src/duckdb/src/storage/storage_info.cpp +++ b/src/duckdb/src/storage/storage_info.cpp @@ -2,6 +2,7 @@ #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/optional_idx.hpp" +#include "duckdb/storage/storage_manager.hpp" namespace duckdb { constexpr idx_t Storage::MAX_ROW_GROUP_SIZE; @@ -10,190 +11,227 @@ constexpr idx_t Storage::MIN_BLOCK_ALLOC_SIZE; constexpr idx_t Storage::DEFAULT_BLOCK_HEADER_SIZE; constexpr uint64_t MainHeader::DEFAULT_ENCRYPTION_KEY_LENGTH; -const uint64_t VERSION_NUMBER = 64; +const uint64_t VERSION_NUMBER = 69; const uint64_t VERSION_NUMBER_LOWER = 64; const uint64_t VERSION_NUMBER_UPPER = 69; static_assert(VERSION_NUMBER_LOWER <= VERSION_NUMBER, "Check on VERSION_NUMBER lower bound"); static_assert(VERSION_NUMBER <= VERSION_NUMBER_UPPER, "Check on VERSION_NUMBER upper bound"); -struct StorageVersionInfo { - const char *version_name; - idx_t storage_version; -}; - -struct 
SerializationVersionInfo { - const char *version_name; - idx_t serialization_version; -}; - // These sections are automatically generated by scripts/generate_storage_info.py // Do not edit them manually, your changes will be overwritten // clang-format off -// START OF STORAGE VERSION INFO -const uint64_t DEFAULT_STORAGE_VERSION_INFO = 64; +// START OF STORAGE_ARRAY VERSION INFO static const StorageVersionInfo storage_version_info[] = { - {"v0.0.4", 1}, - {"v0.1.0", 1}, - {"v0.1.1", 1}, - {"v0.1.2", 1}, - {"v0.1.3", 1}, - {"v0.1.4", 1}, - {"v0.1.5", 1}, - {"v0.1.6", 1}, - {"v0.1.7", 1}, - {"v0.1.8", 1}, - {"v0.1.9", 1}, - {"v0.2.0", 1}, - {"v0.2.1", 1}, - {"v0.2.2", 4}, - {"v0.2.3", 6}, - {"v0.2.4", 11}, - {"v0.2.5", 13}, - {"v0.2.6", 15}, - {"v0.2.7", 17}, - {"v0.2.8", 18}, - {"v0.2.9", 21}, - {"v0.3.0", 25}, - {"v0.3.1", 27}, - {"v0.3.2", 31}, - {"v0.3.3", 33}, - {"v0.3.4", 33}, - {"v0.3.5", 33}, - {"v0.4.0", 33}, - {"v0.5.0", 38}, - {"v0.5.1", 38}, - {"v0.6.0", 39}, - {"v0.6.1", 39}, - {"v0.7.0", 43}, - {"v0.7.1", 43}, - {"v0.8.0", 51}, - {"v0.8.1", 51}, - {"v0.9.0", 64}, - {"v0.9.1", 64}, - {"v0.9.2", 64}, - {"v0.10.0", 64}, - {"v0.10.1", 64}, - {"v0.10.2", 64}, - {"v0.10.3", 64}, - {"v1.0.0", 64}, - {"v1.1.0", 64}, - {"v1.1.1", 64}, - {"v1.1.2", 64}, - {"v1.1.3", 64}, - {"v1.2.0", 65}, - {"v1.2.1", 65}, - {"v1.2.2", 65}, - {"v1.3.0", 66}, - {"v1.3.1", 66}, - {"v1.3.2", 66}, - {"v1.4.0", 67}, - {"v1.4.1", 67}, - {"v1.4.2", 67}, - {"v1.4.3", 67}, - {"v1.4.4", 67}, - {"v1.5.0", 68}, - {"v1.5.1", 68}, - {"v1.5.2", 68}, - {"v2.0.0", 69}, - {nullptr, 0} + {"v0.0.4", StorageVersion::V0_0_4}, + {"v0.1.0", StorageVersion::V0_1_0}, + {"v0.1.1", StorageVersion::V0_1_1}, + {"v0.1.2", StorageVersion::V0_1_2}, + {"v0.1.3", StorageVersion::V0_1_3}, + {"v0.1.4", StorageVersion::V0_1_4}, + {"v0.1.5", StorageVersion::V0_1_5}, + {"v0.1.6", StorageVersion::V0_1_6}, + {"v0.1.7", StorageVersion::V0_1_7}, + {"v0.1.8", StorageVersion::V0_1_8}, + {"v0.1.9", StorageVersion::V0_1_9}, + 
{"v0.2.0", StorageVersion::V0_2_0}, + {"v0.2.1", StorageVersion::V0_2_1}, + {"v0.2.2", StorageVersion::V0_2_2}, + {"v0.2.3", StorageVersion::V0_2_3}, + {"v0.2.4", StorageVersion::V0_2_4}, + {"v0.2.5", StorageVersion::V0_2_5}, + {"v0.2.6", StorageVersion::V0_2_6}, + {"v0.2.7", StorageVersion::V0_2_7}, + {"v0.2.8", StorageVersion::V0_2_8}, + {"v0.2.9", StorageVersion::V0_2_9}, + {"v0.3.0", StorageVersion::V0_3_0}, + {"v0.3.1", StorageVersion::V0_3_1}, + {"v0.3.2", StorageVersion::V0_3_2}, + {"v0.3.3", StorageVersion::V0_3_3}, + {"v0.3.4", StorageVersion::V0_3_4}, + {"v0.3.5", StorageVersion::V0_3_5}, + {"v0.4.0", StorageVersion::V0_4_0}, + {"v0.5.0", StorageVersion::V0_5_0}, + {"v0.5.1", StorageVersion::V0_5_1}, + {"v0.6.0", StorageVersion::V0_6_0}, + {"v0.6.1", StorageVersion::V0_6_1}, + {"v0.7.0", StorageVersion::V0_7_0}, + {"v0.7.1", StorageVersion::V0_7_1}, + {"v0.8.0", StorageVersion::V0_8_0}, + {"v0.8.1", StorageVersion::V0_8_1}, + {"v0.9.0", StorageVersion::V0_9_0}, + {"v0.9.1", StorageVersion::V0_9_1}, + {"v0.9.2", StorageVersion::V0_9_2}, + {"v0.10.0", StorageVersion::V0_10_0}, + {"v0.10.1", StorageVersion::V0_10_1}, + {"v0.10.2", StorageVersion::V0_10_2}, + {"v0.10.3", StorageVersion::V0_10_3}, + {"v1.0.0", StorageVersion::V1_0_0}, + {"v1.1.0", StorageVersion::V1_1_0}, + {"v1.1.1", StorageVersion::V1_1_1}, + {"v1.1.2", StorageVersion::V1_1_2}, + {"v1.1.3", StorageVersion::V1_1_3}, + {"v1.2.0", StorageVersion::V1_2_0}, + {"v1.2.1", StorageVersion::V1_2_1}, + {"v1.2.2", StorageVersion::V1_2_2}, + {"v1.3.0", StorageVersion::V1_3_0}, + {"v1.3.1", StorageVersion::V1_3_1}, + {"v1.3.2", StorageVersion::V1_3_2}, + {"v1.4.0", StorageVersion::V1_4_0}, + {"v1.4.1", StorageVersion::V1_4_1}, + {"v1.4.2", StorageVersion::V1_4_2}, + {"v1.4.3", StorageVersion::V1_4_3}, + {"v1.4.4", StorageVersion::V1_4_4}, + {"v1.5.0", StorageVersion::V1_5_0}, + {"v1.5.1", StorageVersion::V1_5_1}, + {"v1.5.2", StorageVersion::V1_5_2}, + {"v1.5.3", StorageVersion::V1_5_3}, + {"v2.0.0", 
StorageVersion::V2_0_0}, + {"latest", StorageVersion::V2_0_0}, + {nullptr, StorageVersion::INVALID} }; -// END OF STORAGE VERSION INFO -static_assert(DEFAULT_STORAGE_VERSION_INFO == VERSION_NUMBER, "Check on VERSION_INFO"); +// END OF STORAGE_ARRAY VERSION INFO +// clang-format on -// START OF SERIALIZATION VERSION INFO -const uint64_t LATEST_SERIALIZATION_VERSION_INFO = 8; -const uint64_t DEFAULT_SERIALIZATION_VERSION_INFO = 1; +// These sections are automatically generated by scripts/generate_storage_info.py +// Do not edit them manually, your changes will be overwritten +// clang-format off +// START OF SER_ARRAY VERSION INFO static const SerializationVersionInfo serialization_version_info[] = { - {"v0.10.0", 1}, - {"v0.10.1", 1}, - {"v0.10.2", 1}, - {"v0.10.3", 2}, - {"v1.0.0", 2}, - {"v1.1.0", 3}, - {"v1.1.1", 3}, - {"v1.1.2", 3}, - {"v1.1.3", 3}, - {"v1.2.0", 4}, - {"v1.2.1", 4}, - {"v1.2.2", 4}, - {"v1.3.0", 5}, - {"v1.3.1", 5}, - {"v1.3.2", 5}, - {"v1.4.0", 6}, - {"v1.4.1", 6}, - {"v1.4.2", 6}, - {"v1.4.3", 6}, - {"v1.4.4", 6}, - {"v1.5.0", 7}, - {"v1.5.1", 7}, - {"v1.5.2", 7}, - {"v2.0.0", 8}, - {"latest", 8}, - {nullptr, 0} + {"v0.10.0", SerializationVersionDeprecated::V0_10_0}, + {"v0.10.1", SerializationVersionDeprecated::V0_10_1}, + {"v0.10.2", SerializationVersionDeprecated::V0_10_2}, + {"v0.10.3", SerializationVersionDeprecated::V0_10_3}, + {"v1.0.0", SerializationVersionDeprecated::V1_0_0}, + {"v1.1.0", SerializationVersionDeprecated::V1_1_0}, + {"v1.1.1", SerializationVersionDeprecated::V1_1_1}, + {"v1.1.2", SerializationVersionDeprecated::V1_1_2}, + {"v1.1.3", SerializationVersionDeprecated::V1_1_3}, + {"v1.2.0", SerializationVersionDeprecated::V1_2_0}, + {"v1.2.1", SerializationVersionDeprecated::V1_2_1}, + {"v1.2.2", SerializationVersionDeprecated::V1_2_2}, + {"v1.3.0", SerializationVersionDeprecated::V1_3_0}, + {"v1.3.1", SerializationVersionDeprecated::V1_3_1}, + {"v1.3.2", SerializationVersionDeprecated::V1_3_2}, + {"v1.4.0", 
SerializationVersionDeprecated::V1_4_0}, + {"v1.4.1", SerializationVersionDeprecated::V1_4_1}, + {"v1.4.2", SerializationVersionDeprecated::V1_4_2}, + {"v1.4.3", SerializationVersionDeprecated::V1_4_3}, + {"v1.4.4", SerializationVersionDeprecated::V1_4_4}, + {"v1.5.0", SerializationVersionDeprecated::V1_5_0}, + {"v1.5.1", SerializationVersionDeprecated::V1_5_1}, + {"v1.5.2", SerializationVersionDeprecated::V1_5_2}, + {"v1.5.3", SerializationVersionDeprecated::V1_5_3}, + {"v2.0.0", SerializationVersionDeprecated::V2_0_0}, + {"latest", SerializationVersionDeprecated::V2_0_0}, + {nullptr, SerializationVersionDeprecated::INVALID} }; -// END OF SERIALIZATION VERSION INFO +// END OF SER_ARRAY VERSION INFO // clang-format on -static_assert(DEFAULT_SERIALIZATION_VERSION_INFO <= LATEST_SERIALIZATION_VERSION_INFO, - "Check on SERIALIZATION_VERSION_INFO"); +static constexpr StorageVersion DEFAULT_STORAGE_VERSION_INFO = StorageVersion::V2_0_0; +static_assert(static_cast(DEFAULT_STORAGE_VERSION_INFO) == VERSION_NUMBER, "Check on VERSION_INFO"); -string GetStorageVersionName(const idx_t serialization_version, const bool add_suffix) { - if (serialization_version < 4) { - // special handling for lower serialization versions - return "v1.0.0+"; +const StorageVersionInfo *GetStorageVersionInfo() { + return storage_version_info; +} + +string GetStorageVersionNameInternal(const StorageVersion storage_version) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage_version)) { + return "v0.10.2"; } - optional_idx min_idx; - for (idx_t i = 0; serialization_version_info[i].version_name; i++) { - if (strcmp(serialization_version_info[i].version_name, "latest") == 0) { + + StorageVersion min_version = StorageVersion::INVALID; + for (idx_t i = 0; storage_version_info[i].version_name; i++) { + if (strcmp(storage_version_info[i].version_name, "latest") == 0) { continue; } - if (serialization_version_info[i].serialization_version != serialization_version) { + if 
(storage_version_info[i].storage_version != storage_version) { continue; } - if (!min_idx.IsValid()) { - min_idx = i; + if (min_version == StorageVersion::INVALID) { + min_version = storage_version_info[i].storage_version; } } - if (!min_idx.IsValid()) { + if (min_version == StorageVersion::INVALID) { D_ASSERT(0); + return ""; + } + + return StorageVersionInfo::GetStorageVersionString(storage_version); +} + +string GetStorageVersionName(const StorageVersion storage_version, const bool add_suffix) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage_version)) { + // special handling for lower storage versions + return "v1.0.0+"; + } + + auto name = GetStorageVersionNameInternal(storage_version); + if (name.empty()) { return "--UNKNOWN--"; } - auto min_name = string(serialization_version_info[min_idx.GetIndex()].version_name); if (add_suffix) { - min_name += "+"; + name += "+"; + } + + return name; +} + +idx_t GetSerializationVersionDeprecated(const char *version_string) { + for (idx_t i = 0; serialization_version_info[i].version_name; i++) { + if (!strcmp(serialization_version_info[i].version_name, version_string)) { + return SerializationVersionInfo::GetSerializationVersionValue( + serialization_version_info[i].storage_version); + } } - return min_name; + if (GetStorageVersion(version_string) != StorageVersion::INVALID) { + return SerializationVersionInfo::GetSerializationVersionValue(SerializationVersionDeprecated::V0_10_2); + } + // Unknown/future version + return SerializationVersionInfo::Invalid(); } -optional_idx GetStorageVersion(const char *version_string) { +StorageVersion GetStorageVersionValue(const char *version_string) { for (idx_t i = 0; storage_version_info[i].version_name; i++) { if (!strcmp(storage_version_info[i].version_name, version_string)) { return storage_version_info[i].storage_version; } } - return optional_idx(); + return StorageVersionInfo::Invalid(); } -optional_idx GetSerializationVersion(const char *version_string) { 
- for (idx_t i = 0; serialization_version_info[i].version_name; i++) { - if (!strcmp(serialization_version_info[i].version_name, version_string)) { - return serialization_version_info[i].serialization_version; +StorageVersion GetStorageVersion(const char *version_string) { + for (idx_t i = 0; storage_version_info[i].version_name; i++) { + if (!strcmp(storage_version_info[i].version_name, version_string)) { + auto version = GetStorageVersionValue(version_string); + if (version != StorageVersion::INVALID) { + return version; + } + } + } + return StorageVersion::INVALID; +} + +string StorageVersionInfo::GetStorageVersionString(const StorageVersion &version) { + for (idx_t i = 0; storage_version_info[i].version_name; i++) { + if (storage_version_info[i].storage_version == version) { + return storage_version_info[i].version_name; } } - return optional_idx(); + throw InvalidInputException("Invalid storage version %d", static_cast(version)); } -vector GetSerializationCandidates() { +vector GetStorageCandidates() { vector candidates; - for (idx_t i = 0; serialization_version_info[i].version_name; i++) { - candidates.push_back(serialization_version_info[i].version_name); + for (idx_t i = 0; storage_version_info[i].version_name; i++) { + candidates.push_back(storage_version_info[i].version_name); } return candidates; } -string GetDuckDBVersions(idx_t version_number) { +string GetDuckDBVersions(const StorageVersion version_number) { vector versions; for (idx_t i = 0; storage_version_info[i].version_name; i++) { if (version_number == storage_version_info[i].storage_version) { diff --git a/src/duckdb/src/storage/storage_manager.cpp b/src/duckdb/src/storage/storage_manager.cpp index 5d9d48ea3..9f8799b8b 100644 --- a/src/duckdb/src/storage/storage_manager.cpp +++ b/src/duckdb/src/storage/storage_manager.cpp @@ -2,9 +2,12 @@ #include "duckdb/catalog/catalog.hpp" #include "duckdb/common/file_system.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/logging/logger.hpp" 
#include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/main/query_profiler.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/main/database.hpp" #include "duckdb/storage/checkpoint_manager.hpp" @@ -28,8 +31,8 @@ using SHA256State = duckdb_mbedtls::MbedTlsWrapper::SHA256State; void StorageOptions::SetEncryptionVersion(string &storage_version_user_provided) { // storage version < v1.4.0 - if (!storage_version.IsValid() || - storage_version.GetIndex() < SerializationCompatibility::FromString("v1.4.0").serialization_version) { + if (storage_version == StorageVersion::INVALID || + StorageManager::IsPriorToVersion(StorageVersion::V1_4_0, storage_version)) { if (!storage_version_user_provided.empty()) { throw InvalidInputException("Explicit provided STORAGE_VERSION (\"%s\") and ENCRYPTION_KEY (storage >= " "v1.4.0) are not compatible", @@ -46,12 +49,12 @@ void StorageOptions::SetEncryptionVersion(string &storage_version_user_provided) switch (target_encryption_version) { case EncryptionTypes::V0_1: // storage version not explicitly set - if (!storage_version.IsValid() && storage_version_user_provided.empty()) { - storage_version = SerializationCompatibility::FromString("v1.5.0").serialization_version; + if (storage_version == StorageVersion::INVALID && storage_version_user_provided.empty()) { + storage_version = StorageVersion::V1_5_0; break; } // storage version set, but v1.4.0 =< storage < v1.5.0 - if (storage_version.GetIndex() < SerializationCompatibility::FromString("v1.5.0").serialization_version) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, storage_version)) { if (!storage_version_user_provided.empty()) { if (encryption_version == target_encryption_version) { // encryption version is explicitly given, but not compatible with < v1.5.0 @@ -71,8 +74,8 @@ void StorageOptions::SetEncryptionVersion(string &storage_version_user_provided) case 
EncryptionTypes::V0_0: // we set this to V0 to V1.5.0 if no explicit storage version provided - if (!storage_version.IsValid() && storage_version_user_provided.empty()) { - storage_version = SerializationCompatibility::FromString("v1.5.0").serialization_version; + if (storage_version == StorageVersion::INVALID && storage_version_user_provided.empty()) { + storage_version = StorageVersion::V1_5_0; break; } // if storage version is provided, we do nothing @@ -114,8 +117,7 @@ void StorageOptions::Initialize(unordered_map &options) { row_group_size = entry.second.GetValue(); } else if (entry.first == "storage_version") { storage_version_user_provided = entry.second.ToString(); - storage_version = - SerializationCompatibility::FromString(storage_version_user_provided).serialization_version; + storage_version = StorageCompatibility::FromString(storage_version_user_provided).storage_version; } else if (entry.first == "compress") { if (entry.second.DefaultCastAs(LogicalType::BOOLEAN).GetValue()) { compress_in_memory = CompressInMemory::COMPRESS; @@ -124,6 +126,21 @@ void StorageOptions::Initialize(unordered_map &options) { } } else if (entry.first == "debug_encryption_version") { encryption_version = EncryptionTypes::StringToVersion(entry.second.ToString()); + } else if (entry.first == "io_mode") { + auto io_mode_str = StringUtil::Upper(entry.second.ToString()); + if (io_mode_str == "BUFFERED_IO") { + io_mode = FileIOMode::BUFFERED_IO; + } else if (io_mode_str == "MMAP") { + io_mode = FileIOMode::MMAP; + } else if (io_mode_str == "DIRECT_IO") { + io_mode = FileIOMode::DIRECT_IO; + } else { + throw BinderException( + "Unrecognized IO_MODE \"%s\". 
Valid values are 'BUFFERED_IO', 'MMAP', or 'DIRECT_IO'.", + entry.second.ToString()); + } + } else if (entry.first == "mmap_reserve_size") { + mmap_reserve_size = DBConfig::ParseMemoryLimit(entry.second.ToString()); } else { throw BinderException("Unrecognized option for attach \"%s\"", entry.first); } @@ -197,6 +214,32 @@ optional_ptr StorageManager::GetWAL() { return wal.get(); } +// comparison used to see whether the storage version is compatible +bool StorageManager::TargetAtLeastVersion(StorageVersion target_version, idx_t storage_version) { + if (storage_version < static_cast(target_version)) { + // storage version is lower then required storage version + return false; + } + return true; +} + +// comparison used to see whether the storage version is compatible +bool StorageManager::IsPriorToVersion(StorageVersion target_version, StorageVersion storage_version) { + if (storage_version < target_version) { + // storage version is lower then required storage version + return true; + } + return false; +} + +bool StorageManager::TargetAtLeastVersion(StorageVersion target_version, StorageVersion storage_version) { + if (storage_version < target_version) { + // storage version is lower then required storage version + return false; + } + return true; +} + bool StorageManager::HasWAL() const { if (InMemory() || read_only || !load_complete) { return false; @@ -369,7 +412,7 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { DEFAULT_BLOCK_HEADER_STORAGE_SIZE); table_io_manager = make_uniq(*block_manager, DEFAULT_ROW_GROUP_SIZE); // in-memory databases can always use the latest storage version - storage_version = GetSerializationVersion("latest"); + storage_version = GetLatestStorageVersion(); load_complete = true; return; } @@ -382,7 +425,17 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { StorageManagerOptions options; options.read_only = read_only; - options.use_direct_io = config.options.use_direct_io; + // MMAP + encryption 
would corrupt the file (in-place decryption); demote to BUFFERED_IO. + auto resolved_io_mode = + storage_options.io_mode ? *storage_options.io_mode : Settings::Get(config); + if (storage_options.encryption && resolved_io_mode == FileIOMode::MMAP) { + DUCKDB_LOG_WARNING(db.GetDatabase(), + "MMAP IO_MODE is incompatible with encryption; falling back to BUFFERED_IO for \"%s\"", + path); + resolved_io_mode = FileIOMode::BUFFERED_IO; + } + options.io_mode = resolved_io_mode; + options.mmap_reserve_size = storage_options.mmap_reserve_size; options.debug_initialize = config.options.debug_initialize; options.storage_version = storage_options.storage_version; @@ -438,9 +491,10 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { // No encryption; use the default option. options.block_header_size = config.options.default_block_header_size; } - if (!options.storage_version.IsValid()) { + + if (options.storage_version == StorageVersion::INVALID) { // when creating a new database we default to the serialization version specified in the config - options.storage_version = config.options.serialization_compatibility.serialization_version; + options.storage_version = config.options.storage_compatibility.storage_version; } // Initialize the block manager before creating a new database. @@ -497,42 +551,39 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context) { } } - unique_ptr timer = nullptr; + unique_ptr timer = nullptr; // Start timing the storage load step. auto client_context = context.GetClientContext(); if (client_context) { auto profiler = client_context->client_data->profiler; - timer = make_uniq(profiler->StartTimer(MetricType::ATTACH_LOAD_STORAGE_LATENCY)); + timer = make_uniq(profiler->StartTimer()); } // Load the checkpoint from storage. auto checkpoint_reader = SingleFileCheckpointReader(*this); checkpoint_reader.LoadFromStorage(); - // End timing the storage load step. + // Reset the timer (also ends it). 
if (timer) { - timer->EndTimer(); timer = nullptr; } // Start timing the WAL replay step. if (client_context) { auto profiler = client_context->client_data->profiler; - timer = make_uniq(profiler->StartTimer(MetricType::ATTACH_REPLAY_WAL_LATENCY)); + timer = make_uniq(profiler->StartTimer()); } // Replay the WAL. wal_path = GetWALPath(); wal = WriteAheadLog::Replay(context, *this, wal_path); - // End timing the WAL replay step. - if (timer) { - timer->EndTimer(); - } + // Timer will go out of scope here, if set. } - if (row_group_size > 122880ULL && GetStorageVersion() < 4) { + // + if (row_group_size > 122880ULL && IsPriorToVersion(StorageVersion::V1_2_0, GetStorageVersion())) { throw InvalidInputException("Unsupported row group size %llu - row group sizes >= 122_880 are only supported " "with STORAGE_VERSION '1.2.0' or above.\nExplicitly specify a newer storage " "version when creating the database to enable larger row groups", @@ -705,14 +756,15 @@ void SingleFileStorageManager::CreateCheckpoint(QueryContext context, Checkpoint try { // Start timing the checkpoint. auto client_context = context.GetClientContext(); - ActiveTimer profiler; + MetricsTimer timer; if (client_context) { - profiler = client_context->client_data->profiler->StartTimer(MetricType::CHECKPOINT_LATENCY); + timer = client_context->client_data->profiler->StartTimer(); } // Write the checkpoint. 
auto checkpointer = CreateCheckpointWriter(context, options); checkpointer->CreateCheckpoint(); + timer.EndTimer(); } catch (std::exception &ex) { ErrorData error(ex); diff --git a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index 810c15ec6..c2f5c2ed9 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -190,24 +190,31 @@ void ArrayColumnData::InitializeAppend(ColumnAppendState &state) { state.child_appends.push_back(std::move(child_append)); } -void ArrayColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { +void ArrayColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t count) { if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { Vector append_vector(Vector::Ref(vector)); append_vector.Flatten(); - Append(stats, state, append_vector, count); + Append(state, append_vector, count); return; } // Append validity - validity->Append(stats, state.child_appends[0], vector, count); + validity->Append(state.child_appends[0], vector, count); // Append child column auto array_size = ArrayType::GetSize(type); - auto &child_vec = ArrayVector::GetChildMutable(vector); - child_column->Append(ArrayStats::GetChildStats(stats), state.child_appends[1], child_vec, count * array_size); + const auto &child_vec = ArrayVector::GetChild(vector); + child_column->Append(state.child_appends[1], child_vec, count * array_size); this->count += count; } +void ArrayColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + validity->FinalizeAppendLocked(finalize_state, state.child_appends[0]); + + ColumnDataFinalizeAppendState child_finalize_state(finalize_state, LogicalTypeId::ARRAY); + child_column->FinalizeAppendLocked(child_finalize_state, state.child_appends[1]); +} + void ArrayColumnData::RevertAppend(row_t new_count) { // Revert validity 
validity->RevertAppend(new_count); @@ -237,31 +244,36 @@ unique_ptr ArrayColumnData::GetUpdateStatistics() { return nullptr; } -void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { +void ArrayColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { // Create state for validity & child column if (state.child_states.empty()) { state.child_states.push_back(make_uniq()); } // Fetch validity - validity->FetchRow(transaction, *state.child_states[0], storage_index, row_id, result, result_idx); + validity->FetchRowsAtSegmentLevel(transaction, *state.child_states[0], offsets, sel, fetch_count, result, + result_offset); // Fetch child column auto &child_vec = ArrayVector::GetChildMutable(result); auto &child_type = ArrayType::GetChildType(type); auto array_size = ArrayType::GetSize(type); - // We need to fetch between [row_id * array_size, (row_id + 1) * array_size) - ColumnScanState child_state(nullptr); - child_state.Initialize(state.context, child_type, nullptr); + for (idx_t idx = 0; idx < fetch_count; idx++) { + // We need to fetch between [row_id * array_size, (row_id + 1) * array_size) + ColumnScanState child_state(nullptr); + child_state.Initialize(state.context, child_type, nullptr); - const auto child_offset = UnsafeNumericCast(row_id) * array_size; + const auto row_id = offsets[sel.get_index(idx)]; + const auto child_offset = row_id * array_size; - child_column->InitializeScanWithOffset(child_state, child_offset); - Vector child_scan(child_type, array_size); - child_column->ScanCount(child_state, child_scan, array_size); - VectorOperations::Copy(child_scan, child_vec, array_size, 0, result_idx * array_size); + child_column->InitializeScanWithOffset(child_state, child_offset); + Vector 
child_scan(child_type, array_size); + child_column->ScanCount(child_state, child_scan, array_size); + VectorOperations::Copy(child_scan, child_vec, array_size, 0, (result_offset + idx) * array_size); + } } void ArrayColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { @@ -376,11 +388,12 @@ void ArrayColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSt } void ArrayColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { + vector &result, + const ColumnSegmentInfoScanOptions &options) { col_path.push_back(0); - validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); col_path.back() = 1; - child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); } void ArrayColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/column_checkpoint_state.cpp b/src/duckdb/src/storage/table/column_checkpoint_state.cpp index 31efad68f..e4e60bdd4 100644 --- a/src/duckdb/src/storage/table/column_checkpoint_state.cpp +++ b/src/duckdb/src/storage/table/column_checkpoint_state.cpp @@ -44,7 +44,7 @@ shared_ptr ColumnCheckpointState::GetFinalResult() { PartialBlockForCheckpoint::PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state, BlockManager &block_manager) - : PartialBlock(state, block_manager, segment.block) { + : PartialBlock(state, block_manager, segment.GetBlockHandle()) { PartialBlockForCheckpoint::AddSegmentToTail(data, segment, 0); } @@ -80,7 +80,7 @@ void PartialBlockForCheckpoint::Flush(QueryContext context, const idx_t free_spa D_ASSERT(segment.offset_in_block == 0); segment.segment.ConvertToPersistent(context, &block_manager, state.block_id); // update the block after it has been converted to a persistent segment - 
block_handle = segment.segment.block; + block_handle = segment.segment.GetBlockHandle(); } else { // subsequent segments are MARKED as persistent - they don't need to be rewritten segment.segment.MarkAsPersistent(block_handle, segment.offset_in_block); @@ -147,13 +147,13 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme } // LCOV_EXCL_STOP // Merge the segment statistics into the global statistics. - global_stats->Merge(segment->stats.statistics); + global_stats->Merge(segment->GetStats()); block_id_t block_id = INVALID_BLOCK; uint32_t offset_in_block = 0; unique_lock partial_block_lock; - if (segment->stats.statistics.IsConstant()) { + if (segment->GetStats().IsConstant()) { // Constant block. segment->ConvertToPersistent(partial_block_manager.GetClientContext(), nullptr, INVALID_BLOCK); } else if (segment_size != 0) { @@ -172,7 +172,7 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme D_ASSERT(offset_in_block > 0); auto &pstate = *allocation.partial_block; // pin the source block - auto old_handle = buffer_manager.Pin(segment->block); + auto old_handle = buffer_manager.Pin(segment->GetBlockHandle()); // pin the target block auto new_handle = buffer_manager.Pin(pstate.block_handle); // memcpy the contents of the old block to the new block @@ -196,12 +196,12 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme } else { // Empty segment, which does not have to go to disk. // We still need to change its type to persistent, because we need to write its metadata. 
- segment->segment_type = ColumnSegmentType::PERSISTENT; - segment->block.reset(); + segment->SetSegmentType(ColumnSegmentType::PERSISTENT); + segment->GetBlockHandle().reset(); } // construct the data pointer - DataPointer data_pointer(segment->stats.statistics.Copy()); + DataPointer data_pointer(segment->GetStats().Copy()); data_pointer.block_pointer.block_id = block_id; data_pointer.block_pointer.offset = offset_in_block; data_pointer.row_start = 0; diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index 279f69e15..85a9622f3 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -1,6 +1,8 @@ #include "duckdb/storage/table/column_data.hpp" #include "duckdb/common/exception/transaction_exception.hpp" +#include "duckdb/common/types/validity_mask.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/serializer/serializer.hpp" @@ -8,6 +10,7 @@ #include "duckdb/function/variant/variant_shredding.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/planner/filter/table_filter_functions.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/data_pointer.hpp" #include "duckdb/storage/data_table.hpp" @@ -20,6 +23,9 @@ #include "duckdb/storage/table/variant_column_data.hpp" #include "duckdb/storage/table/update_segment.hpp" #include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/storage/statistics/array_stats.hpp" +#include "duckdb/storage/statistics/list_stats.hpp" +#include "duckdb/storage/statistics/struct_stats.hpp" #include "duckdb/storage/table/data_table_info.hpp" #include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/scan_state.hpp" @@ -27,23 +33,17 @@ namespace duckdb { static bool 
IsDirectNullCheckFilter(const TableFilter &filter) { - switch (filter.filter_type) { - case TableFilterType::EXPRESSION_FILTER: { - auto &expr = filter.Cast().expr; - if (expr->GetExpressionClass() != ExpressionClass::BOUND_OPERATOR) { - return false; - } - auto &op = expr->Cast(); - if ((op.GetExpressionType() != ExpressionType::OPERATOR_IS_NULL && - op.GetExpressionType() != ExpressionType::OPERATOR_IS_NOT_NULL) || - op.children.size() != 1) { - return false; - } - return op.children[0]->GetExpressionClass() == ExpressionClass::BOUND_REF; + auto &expr = ExpressionFilter::GetExpressionFilter(filter, "ColumnData::IsDirectNullCheckFilter").expr; + if (expr->GetExpressionClass() != ExpressionClass::BOUND_OPERATOR) { + return false; } - default: + auto &op = expr->Cast(); + if ((op.GetExpressionType() != ExpressionType::OPERATOR_IS_NULL && + op.GetExpressionType() != ExpressionType::OPERATOR_IS_NOT_NULL) || + op.GetChildren().size() != 1) { return false; } + return op.GetChildren()[0]->GetExpressionClass() == ExpressionClass::BOUND_REF; } ColumnData::ColumnData(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, LogicalType type_p, @@ -94,7 +94,7 @@ bool ColumnData::HasChanges(idx_t start_row, idx_t end_row) const { bool ColumnData::HasChanges() const { for (auto &segment_node : data.SegmentNodes()) { auto &segment = segment_node.GetNode(); - if (segment.segment_type == ColumnSegmentType::TRANSIENT) { + if (segment.GetSegmentType() == ColumnSegmentType::TRANSIENT) { // transient segment: always need to write to disk return true; } @@ -201,7 +201,7 @@ void ColumnData::BeginScanVectorInternal(ColumnScanState &state) { auto ¤t = state.current->GetNode(); current.Skip(state); } - D_ASSERT(state.current->GetNode().type == type); + D_ASSERT(state.current->GetNode().GetType() == type); } idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, ScanVectorType scan_type, @@ -304,7 +304,8 @@ void 
ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto if (!updates) { return; } - updates->FetchRow(transaction, NumericCast(row_id), result, result_idx); + const idx_t offset = NumericCast(row_id); + updates->FetchRows(transaction, &offset, *FlatVector::IncrementalSelectionVector(), 1, result, result_idx); } void ColumnData::UpdateInternal(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, @@ -392,18 +393,32 @@ void ColumnData::Skip(ColumnScanState &state, idx_t s_count) { state.Next(s_count); } -void ColumnData::Append(BaseStatistics &append_stats, ColumnAppendState &state, Vector &vector, idx_t append_count) { +void ColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t append_count) { UnifiedVectorFormat vdata; vector.ToUnifiedFormat(vdata); - AppendData(append_stats, state, vdata, append_count); + AppendData(state, vdata, append_count); +} + +void ColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + // flush remaining stats + state.FinalFlush(finalize_state.global_stats); } -void ColumnData::Append(ColumnAppendState &state, Vector &vector, idx_t append_count) { +void ColumnData::FinalizeAppend(optional_ptr table_stats, ColumnAppendState &state) { if (!stats) { - throw InternalException("ColumnData::Append called on a column with a parent or without stats"); + throw InternalException("ColumnData::FinalizeAppend called on a column with a parent or without stats"); } lock_guard l(stats_lock); - Append(stats->statistics, state, vector, append_count); + ColumnDataFinalizeAppendState finalize_state(stats->statistics); + if (table_stats) { + finalize_state.global_stats.emplace_back(*table_stats); + } + FinalizeAppend(finalize_state, state); +} + +void ColumnData::FinalizeAppendLocked(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + lock_guard l(stats_lock); + FinalizeAppend(finalize_state, state); } FilterPropagateResult 
ColumnData::CheckZonemap(ColumnScanState &state, TableFilter &filter) { @@ -414,15 +429,19 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt return FilterPropagateResult::NO_PRUNING_POSSIBLE; } // for dynamic filters we never consider the segment being "checked" as it can always change - state.segment_checked = filter.filter_type != TableFilterType::DYNAMIC_FILTER; + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ColumnData::CheckZonemap"); + bool is_dynamic = ExpressionFilter::ContainsInternalFunction(*expr_filter.expr, DynamicFilterScalarFun::NAME); + state.segment_checked = !is_dynamic; FilterPropagateResult prune_result; { lock_guard l(stats_lock); auto &segment_stats = IsDirectNullCheckFilter(filter) && !state.child_states.empty() && state.child_states[0].current - ? state.child_states[0].current->GetNode().stats.statistics - : state.current->GetNode().stats.statistics; - prune_result = filter.CheckStatistics(segment_stats); + ? state.child_states[0].current->GetNode().GetStatsMutable() + : state.current->GetNode().GetStatsMutable(); + auto context = state.context.GetClientContext(); + prune_result = + context ? expr_filter.CheckStatistics(*context, segment_stats) : expr_filter.CheckStatistics(segment_stats); if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -434,14 +453,17 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt } auto update_stats = updates->GetStatistics(); // combine the update and original prune result - FilterPropagateResult update_result = filter.CheckStatistics(*update_stats); + auto context = state.context.GetClientContext(); + FilterPropagateResult update_result = + context ? 
expr_filter.CheckStatistics(*context, *update_stats) : expr_filter.CheckStatistics(*update_stats); if (prune_result == update_result) { return prune_result; } return FilterPropagateResult::NO_PRUNING_POSSIBLE; } -FilterPropagateResult ColumnData::CheckZonemap(const StorageIndex &index, TableFilter &filter) { +FilterPropagateResult ColumnData::CheckZonemap(optional_ptr context, const StorageIndex &index, + TableFilter &filter) { if (!stats) { throw InternalException("ColumnData::CheckZonemap called on a column without stats"); } @@ -451,9 +473,13 @@ FilterPropagateResult ColumnData::CheckZonemap(const StorageIndex &index, TableF if (!child_stats) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - return filter.CheckStatistics(*child_stats); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ColumnData::CheckZonemap"); + return context ? expr_filter.CheckStatistics(*context, *child_stats) + : expr_filter.CheckStatistics(*child_stats); } - return filter.CheckStatistics(stats->statistics); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ColumnData::CheckZonemap"); + return context ? 
expr_filter.CheckStatistics(*context, stats->statistics) + : expr_filter.CheckStatistics(stats->statistics); } const BaseStatistics &ColumnData::GetStatisticsRef() const { @@ -492,47 +518,111 @@ void ColumnData::MergeIntoStatistics(BaseStatistics &other) { return other.Merge(stats->statistics); } +void ColumnAppendState::InitializeStats(const LogicalType &type) { + append_stats = BaseStatistics::CreateEmpty(type).ToUnique(); +} + +void ColumnAppendState::FlushSegmentStats() { + // we finished appending to this segment but we have another segment to append to + // first flush the stats into the column segment + // flush stats into the ColumnData and ColumnSegment stats + auto &append_segment = current->GetNode(); + append_segment.GetStatsMutable().Merge(*append_stats); + + // now merge the stats into the "full_append_stats" + if (!full_append_stats) { + full_append_stats = std::move(append_stats); + } else { + full_append_stats->Merge(*append_stats); + } + // clear the current append stats + append_stats.reset(); +} + +void ColumnAppendState::FinalFlush(vector> &global_stats) { + // flush stats to the final segment + FlushSegmentStats(); + // now merge the final append stats into the global stats + for (auto &stats_ref : global_stats) { + auto &target_stats = stats_ref.get(); + target_stats.Merge(*full_append_stats); + } +} + +ColumnDataFinalizeAppendState::ColumnDataFinalizeAppendState(ColumnDataFinalizeAppendState &parent, + LogicalTypeId type_transform, optional_idx child_offset) { + for (auto &parent_stats_ref : parent.global_stats) { + auto &parent_stats = parent_stats_ref.get(); + + optional_ptr transformed_stats; + switch (type_transform) { + case LogicalTypeId::ARRAY: + transformed_stats = ArrayStats::GetChildStats(parent_stats); + break; + case LogicalTypeId::LIST: + transformed_stats = ListStats::GetChildStats(parent_stats); + break; + case LogicalTypeId::STRUCT: + transformed_stats = StructStats::GetChildStats(parent_stats, child_offset.GetIndex()); + 
break; + case LogicalTypeId::VARIANT: + transformed_stats = VariantStats::GetUnshreddedStats(parent_stats); + break; + default: + throw InternalException("Unsupported type for ColumnDataFinalizeAppendState inheritance"); + } + global_stats.emplace_back(*transformed_stats); + } +} + void ColumnData::InitializeAppend(ColumnAppendState &state) { auto l = data.Lock(); if (data.IsEmpty(l)) { // no segments yet, append an empty segment - AppendTransientSegment(l, 0); + AppendTransientSegment(l, 0, nullptr); } auto segment = data.GetLastSegment(l); auto &last_segment = segment->GetNode(); - if (last_segment.segment_type == ColumnSegmentType::PERSISTENT || + if (last_segment.GetSegmentType() == ColumnSegmentType::PERSISTENT || !last_segment.GetCompressionFunction().init_append) { // we cannot append to this segment - append a new segment auto total_rows = segment->GetRowStart() + last_segment.count; - AppendTransientSegment(l, total_rows); + AppendTransientSegment(l, total_rows, last_segment); state.current = data.GetLastSegment(l); } else { state.current = segment; } + state.InitializeStats(GetType()); auto &append_segment = state.current->GetNode(); - D_ASSERT(append_segment.segment_type == ColumnSegmentType::TRANSIENT); + D_ASSERT(append_segment.GetSegmentType() == ColumnSegmentType::TRANSIENT); append_segment.InitializeAppend(state); D_ASSERT(append_segment.GetCompressionFunction().append); } -void ColumnData::AppendData(BaseStatistics &append_stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t append_count) { +void ColumnData::AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t append_count) { idx_t offset = 0; while (true) { // append the data from the vector auto &append_segment = state.current->GetNode(); idx_t copied_elements = append_segment.Append(state, vdata, offset, append_count); this->count += copied_elements; - append_stats.Merge(append_segment.stats.statistics); if (copied_elements == append_count) { // finished copying 
everything break; } + // segment is full and we have more to copy + // first flush the stats into the segment and the column data + { + lock_guard guard(stats_lock); + state.FlushSegmentStats(); + } + // re-initialize the stats + state.InitializeStats(GetType()); // we couldn't fit everything we wanted in the current column segment, create a new one { auto l = data.Lock(); - AppendTransientSegment(l, state.current->GetRowStart() + append_segment.count); + AppendTransientSegment(l, state.current->GetRowStart() + append_segment.count, append_segment); state.current = data.GetLastSegment(l); state.current->GetNode().InitializeAppend(state); } @@ -572,7 +662,7 @@ void ColumnData::RevertAppend(row_t new_count_p) { data.EraseSegments(l, segment_index + 1); auto &transient = segment->GetNode(); - D_ASSERT(transient.segment_type == ColumnSegmentType::TRANSIENT); + D_ASSERT(transient.GetSegmentType() == ColumnSegmentType::TRANSIENT); segment->SetNext(nullptr); transient.RevertAppend(new_count - segment->GetRowStart()); } @@ -592,19 +682,42 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { return ScanVector(state, result, STANDARD_VECTOR_SIZE, ScanVectorType::SCAN_FLAT_VECTOR); } -void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { - if (UnsafeNumericCast(row_id) > count) { - throw InternalException("ColumnData::FetchRow - row_id out of range"); - } - auto segment = data.GetSegment(UnsafeNumericCast(row_id)); +void ColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { + FetchRowsAtSegmentLevel(transaction, state, offsets, sel, fetch_count, result, result_offset); +} - // now perform the fetch within the segment - auto index_in_segment = row_id - 
UnsafeNumericCast(segment->GetRowStart()); - segment->GetNode().FetchRow(state, index_in_segment, result, result_idx); - // merge any updates made to this row +void ColumnData::FetchRowsAtSegmentLevel(TransactionData transaction, ColumnFetchState &state, const idx_t *offsets, + const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { + if (fetch_count == 0) { + return; + } - FetchUpdateRow(transaction, row_id, result, result_idx); + optional_ptr> current_segment; + idx_t segment_start = 0; + idx_t segment_end = 0; + for (idx_t idx = 0; idx < fetch_count; idx++) { + const idx_t offset = offsets[sel.get_index(idx)]; + if (offset > count) { + throw InternalException("ColumnData::FetchRowsAtSegmentLevel - row_id %lld out of range for count %lld", + offset, count); + } + if (!current_segment || offset < segment_start || offset >= segment_end) { + current_segment = data.GetSegment(offset); + segment_start = current_segment->GetRowStart(); + segment_end = segment_start + current_segment->GetNode().count; + } + const idx_t index_in_segment = offset - segment_start; + current_segment->GetNode().FetchRow(state, NumericCast(index_in_segment), result, result_offset + idx); + } + { + const lock_guard update_guard(update_lock); + if (updates) { + updates->FetchRows(transaction, offsets, sel, fetch_count, result, result_offset); + } + } } idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector, idx_t row_group_start) { @@ -634,27 +747,40 @@ void ColumnData::UpdateColumn(TransactionData transaction, DuckTableEntry &table ColumnData::Update(transaction, table_entry, column_path[0], update_vector, row_ids, update_count, row_group_start); } -void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { - const auto block_size = block_manager.GetBlockSize(); - const auto type_size = GetTypeIdSize(type.InternalType()); - auto vector_segment_size = block_size; - - if (data_type == 
ColumnDataType::INITIAL_TRANSACTION_LOCAL && start_row == 0) { -#if STANDARD_VECTOR_SIZE < 1024 - vector_segment_size = 1024 * type_size; -#else - vector_segment_size = STANDARD_VECTOR_SIZE * type_size; -#endif +void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row, optional_ptr prev_segment) { + auto &db = GetDatabase(); + auto &config = DBConfig::GetConfig(db); + + idx_t segment_size; + if (!prev_segment || prev_segment->GetSegmentType() == ColumnSegmentType::PERSISTENT) { + // We start with the `initial_bytes` setting, but we ensure that we have enough space for at least one row. + const auto initial_bytes = Settings::Get(config); + segment_size = MaxValue(GetTypeIdSize(type.InternalType()), initial_bytes); + } else { + segment_size = prev_segment->SegmentSize() * 2; + } + + // BIT (validity) segments can only hold rows in multiples of STANDARD_VECTOR_SIZE; + // any segment below STANDARD_MASK_SIZE triggers a dead-segment overflow chain + if (type.InternalType() == PhysicalType::BIT) { + segment_size = MaxValue(segment_size, ValidityMask::STANDARD_MASK_SIZE); + } + + // We set the segment size to the next power of two minus the block header size to + // ensure that we have fixed-size segments which we can group when offloading to temporary storage. + // FIXME: turn this into the min. temporary buffer size instead of a magical number, + // FIXME: once we allow offloading tinier buffers. + if (segment_size >= 1024) { + const auto block_header_size = block_manager.GetBlockHeaderSize(); + segment_size = NextPowerOfTwo(segment_size) - block_header_size; } - // The segment size is bound by the block size, but can be smaller. - idx_t segment_size = block_size < vector_segment_size ? block_size : vector_segment_size; + // The maximum segment size is always the block size of the corresponding block manager. 
+ const auto block_size = block_manager.GetBlockSize(); + segment_size = MinValue(block_size, segment_size); allocation_size += segment_size; - auto &db = GetDatabase(); - auto &config = DBConfig::GetConfig(db); auto function = config.GetCompressionFunction(CompressionType::COMPRESSION_UNCOMPRESSED, type.InternalType()); - auto new_segment = ColumnSegment::CreateTransientSegment(db, function, type, segment_size, block_manager); AppendSegment(l, std::move(new_segment)); } @@ -757,7 +883,7 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist // create a persistent segment auto segment = ColumnSegment::CreatePersistentSegment( - GetDatabase(), block_manager, data_pointer.block_pointer.block_id, data_pointer.block_pointer.offset, type, + GetDatabase(), block_manager, data_pointer.block_pointer.block_id, data_pointer.block_pointer.offset, data_pointer.tuple_count, data_pointer.compression_type, std::move(data_pointer.statistics), std::move(data_pointer.segment_state)); @@ -768,7 +894,7 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist bool ColumnData::IsPersistent() { for (auto &segment : data.Segments()) { - if (segment.segment_type != ColumnSegmentType::PERSISTENT) { + if (segment.GetSegmentType() != ColumnSegmentType::PERSISTENT) { return false; } } @@ -1136,7 +1262,7 @@ struct ListBlockIds : public BlockIdVisitor { }; void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { + vector &result, const ColumnSegmentInfoScanOptions &options) { D_ASSERT(!col_path.empty()); // convert the column path to a string @@ -1164,13 +1290,13 @@ void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_gro column_info.compression_type = CompressionTypeToString(segment.GetCompressionFunction().type); { lock_guard l(stats_lock); - column_info.segment_stats = segment.stats.statistics.ToStruct(); + column_info.segment_stats = 
segment.GetStats().ToStruct(); } column_info.has_updates = ColumnData::HasUpdates(); // persistent // block_id // block_offset - if (segment.segment_type == ColumnSegmentType::PERSISTENT) { + if (segment.GetSegmentType() == ColumnSegmentType::PERSISTENT) { column_info.persistent = true; column_info.block_id = segment.GetBlockId(); column_info.block_offset = segment.GetBlockOffset(); @@ -1188,7 +1314,9 @@ void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_gro compression_function.visit_block_ids(segment, list_block_ids); } } - if (compression_function.get_segment_info) { + // get_segment_info can be expensive (e.g. for bitpacked segments it reads each block to + // gather per-segment compression mode counts). Skip it unless EXTENDED info is requested. + if (options.include_segment_info && compression_function.get_segment_info) { auto segment_info = compression_function.get_segment_info(context, segment); vector sinfo; for (auto &item : segment_info) { diff --git a/src/duckdb/src/storage/table/column_data_checkpointer.cpp b/src/duckdb/src/storage/table/column_data_checkpointer.cpp index 6f534e086..8ce537261 100644 --- a/src/duckdb/src/storage/table/column_data_checkpointer.cpp +++ b/src/duckdb/src/storage/table/column_data_checkpointer.cpp @@ -1,4 +1,5 @@ #include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/storage/compression/standard_compression_state.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/database.hpp" @@ -88,7 +89,7 @@ ColumnDataCheckpointer::ColumnDataCheckpointer(vector &callback) { +void ColumnDataCheckpointer::ScanSegments(const std::function &callback) { auto &first_state = checkpoint_states[0]; auto &col_data = first_state.get().original_column; @@ -108,7 +109,7 @@ void ColumnDataCheckpointer::ScanSegments(const std::function ColumnDataCheckpointer::DetectBestCompressionMet InitAnalyze(); // scan over all the segments and run the analyze step - ScanSegments([&](Vector &scan_vector, 
idx_t count) { + ScanSegments([&](Vector &scan_vector) { for (idx_t i = 0; i < checkpoint_states.size(); i++) { auto &functions = compression_functions[i]; auto &states = analyze_states[i]; @@ -202,7 +203,7 @@ vector ColumnDataCheckpointer::DetectBestCompressionMet if (!state) { continue; } - if (!func->analyze(*state, scan_vector, count)) { + if (!func->analyze(*state, scan_vector)) { state = nullptr; func = nullptr; } @@ -337,11 +338,11 @@ void ColumnDataCheckpointer::WriteToDisk() { // Analyze the candidate functions } // Scan over the existing segment + changes and compress the data - ScanSegments([&](Vector &scan_vector, idx_t count) { + ScanSegments([&](Vector &scan_vector) { for (idx_t i = 0; i < checkpoint_states.size(); i++) { auto &function = analyze_result[i].function; auto &compression_state = compression_states[i]; - function->compress(*compression_state, scan_vector, count); + function->compress(*compression_state, scan_vector); } }); @@ -379,7 +380,7 @@ void ColumnDataCheckpointer::WritePersistentSegments(ColumnCheckpointState &stat current_row += segment.count; // merge the persistent stats into the global column stats - state.global_stats->Merge(segment.stats.statistics); + state.global_stats->Merge(segment.GetStats()); state.data_pointers.push_back(std::move(pointer)); } if (error_segment_start.IsValid()) { diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp index 54ab8bd88..c3a0cf85b 100644 --- a/src/duckdb/src/storage/table/column_segment.cpp +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -11,18 +11,13 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/prefix_range_filter.hpp" -#include 
"duckdb/planner/filter/struct_filter.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/data_pointer.hpp" #include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/planner/table_filter_state.hpp" #include "duckdb/planner/filter/bloom_filter.hpp" -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" -#include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include @@ -33,21 +28,21 @@ namespace duckdb { //===--------------------------------------------------------------------===// unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstance &db, BlockManager &block_manager, - block_id_t block_id, idx_t offset, - const LogicalType &type, idx_t count, + block_id_t block_id, idx_t offset, idx_t count, CompressionType compression_type, BaseStatistics statistics, unique_ptr segment_state) { auto &config = DBConfig::GetConfig(db); shared_ptr block; + auto &type = statistics.GetType(); auto function = config.GetCompressionFunction(compression_type, type.InternalType()); if (block_id != INVALID_BLOCK) { block = block_manager.RegisterBlock(block_id); } auto segment_size = block_manager.GetBlockSize(); - return make_uniq(db, std::move(block), type, ColumnSegmentType::PERSISTENT, count, function, + return make_uniq(db, std::move(block), ColumnSegmentType::PERSISTENT, count, function, std::move(statistics), block_id, offset, segment_size, std::move(segment_state)); } @@ -60,22 +55,22 @@ unique_ptr ColumnSegment::CreateTransientSegment(DatabaseInstance D_ASSERT(&buffer_manager == &block_manager.buffer_manager); auto block = buffer_manager.RegisterTransientMemory(segment_size, block_manager); - return make_uniq(db, std::move(block), type, ColumnSegmentType::TRANSIENT, 0U, function, + return make_uniq(db, std::move(block), ColumnSegmentType::TRANSIENT, 0U, function, BaseStatistics::CreateEmpty(type), INVALID_BLOCK, 0U, segment_size); } 
//===--------------------------------------------------------------------===// // Construct/Destruct //===--------------------------------------------------------------------===// -ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block_p, const LogicalType &type, +ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block_p, const ColumnSegmentType segment_type, const idx_t count, const CompressionFunction &function_p, BaseStatistics statistics, const block_id_t block_id_p, const idx_t offset, const idx_t segment_size_p, const unique_ptr segment_state_p) - : SegmentBase(count), db(db), type(type), type_size(GetTypeIdSize(type.InternalType())), - segment_type(segment_type), stats(std::move(statistics)), block(std::move(block_p)), function(function_p), - block_id(block_id_p), offset(offset), segment_size(segment_size_p) { + : SegmentBase(count), db(db), segment_type(segment_type), block(std::move(block_p)), + function(function_p), block_id(block_id_p), offset(offset), segment_size(segment_size_p), + stats(std::move(statistics)) { if (function.get().init_segment) { segment_state = function.get().init_segment(*this, block_id, segment_state_p.get()); } @@ -85,10 +80,9 @@ ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block } ColumnSegment::ColumnSegment(ColumnSegment &other) - : SegmentBase(other.count.load()), db(other.db), type(std::move(other.type)), - type_size(other.type_size), segment_type(other.segment_type), stats(std::move(other.stats)), + : SegmentBase(other.count.load()), db(other.db), segment_type(other.segment_type), block(std::move(other.block)), function(other.function), block_id(other.block_id), offset(other.offset), - segment_size(other.segment_size), segment_state(std::move(other.segment_state)) { + segment_size(other.segment_size), segment_state(std::move(other.segment_state)), stats(std::move(other.stats)) { // For constant segments (CompressionType::COMPRESSION_CONSTANT) the block is a nullptr. 
D_ASSERT(!block || segment_size <= GetBlockSize()); } @@ -202,7 +196,7 @@ idx_t ColumnSegment::Append(ColumnAppendState &state, UnifiedVectorFormat &appen if (!function.get().append) { throw InternalException("Attempting to append to a segment without append method"); } - return function.get().append(*state.append_state, *this, stats, append_data, offset, count); + return function.get().append(*state.append_state, *this, *state.append_stats, append_data, offset, count); } idx_t ColumnSegment::FinalizeAppend(ColumnAppendState &state) { @@ -210,7 +204,7 @@ idx_t ColumnSegment::FinalizeAppend(ColumnAppendState &state) { if (!function.get().finalize_append) { throw InternalException("Attempting to call FinalizeAppend on a segment without a finalize_append method"); } - auto result_count = function.get().finalize_append(*this, stats); + auto result_count = function.get().finalize_append(*this, *state.append_stats); state.append_state.reset(); return result_count; } @@ -247,7 +241,7 @@ void ColumnSegment::ConvertToPersistent(QueryContext context, optional_ptr(); - return opt_filter.FilterSelection(sel, vector, vdata, filter_state, scan_count, approved_tuple_count); - } - case TableFilterType::CONJUNCTION_OR: { - // similar to the CONJUNCTION_AND, but we need to take care of the SelectionVectors (OR all of them) - auto &state = filter_state.Cast(); - idx_t count_total = 0; - SelectionVector result_sel(approved_tuple_count); - auto &conjunction_or = filter.Cast(); - for (idx_t child_idx = 0; child_idx < conjunction_or.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction_or.child_filters[child_idx]; - SelectionVector temp_sel; - temp_sel.Initialize(sel); - idx_t temp_tuple_count = approved_tuple_count; - idx_t temp_count = FilterSelection(temp_sel, vector, vdata, child_filter, *state.child_states[child_idx], - scan_count, temp_tuple_count); - // tuples passed, move them into the actual result vector - for (idx_t i = 0; i < temp_count; i++) { - auto 
new_idx = temp_sel.get_index(i); - bool is_new_idx = true; - for (idx_t res_idx = 0; res_idx < count_total; res_idx++) { - if (result_sel.get_index(res_idx) == new_idx) { - is_new_idx = false; - break; - } +static idx_t ExecuteExpressionFilterSelection(SelectionVector &sel, Vector &vector, ExpressionFilterState &state, + idx_t scan_count, idx_t &approved_tuple_count) { + if (approved_tuple_count == 0) { + return 0; + } + D_ASSERT(state.executor); + SelectionVector result_sel(approved_tuple_count); + if (scan_count > STANDARD_VECTOR_SIZE) { + // scan count is > vector size - split up the vector into multiple chunks + idx_t offset = 0; + idx_t result_offset = 0; + idx_t current_sel_offset = 0; + SelectionVector current_sel(approved_tuple_count); + while (offset < scan_count) { + idx_t chunk_count = MinValue(STANDARD_VECTOR_SIZE, scan_count - offset); + idx_t chunk_end = offset + chunk_count; + DataChunk chunk; + chunk.data.emplace_back(vector, offset, chunk_end); + + // construct the relevant selection vector for the current chunk (offset ... 
offset + chunk_count) + idx_t current_count = 0; + for (; current_sel_offset < approved_tuple_count; current_sel_offset++) { + auto sel_index = sel.get_index(current_sel_offset); + if (sel_index >= chunk_end) { + // exhausted the chunk + break; } - if (is_new_idx) { - result_sel.set_index(count_total++, new_idx); + if (sel_index < offset) { + throw InternalException("sel_index < offset in expression filter"); } + current_sel.set_index(current_count++, sel_index - offset); } - } - sel.Initialize(result_sel); - approved_tuple_count = count_total; - return approved_tuple_count; - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and = filter.Cast(); - auto &state = filter_state.Cast(); - for (idx_t child_idx = 0; child_idx < conjunction_and.child_filters.size(); child_idx++) { - auto &child_filter = *conjunction_and.child_filters[child_idx]; - FilterSelection(sel, vector, vdata, child_filter, *state.child_states[child_idx], scan_count, - approved_tuple_count); - } - return approved_tuple_count; - } - case TableFilterType::BLOOM_FILTER: { - auto &bloom_filter = filter.Cast(); - auto &state = filter_state.Cast(); - return bloom_filter.Filter(vector, sel, approved_tuple_count, state); - } - case TableFilterType::PERFECT_HASH_JOIN_FILTER: { - auto &perfect_hash_join_filter = filter.Cast(); - auto &state = filter_state.Cast(); - return perfect_hash_join_filter.Filter(vector, sel, approved_tuple_count, state); - } - case TableFilterType::PREFIX_RANGE_FILTER: { - auto &prefix_range_filter = filter.Cast(); - auto &state = filter_state.Cast(); - return prefix_range_filter.Filter(vector, sel, approved_tuple_count, state); - } - case TableFilterType::EXPRESSION_FILTER: { - auto &state = filter_state.Cast(); - SelectionVector result_sel(approved_tuple_count); - if (scan_count > STANDARD_VECTOR_SIZE) { - // scan count is > vector size - split up the vector into multiple chunks - idx_t offset = 0; - idx_t result_offset = 0; - idx_t current_sel_offset = 0; - 
SelectionVector current_sel(approved_tuple_count); - while (offset < scan_count) { - idx_t chunk_count = MinValue(STANDARD_VECTOR_SIZE, scan_count - offset); - idx_t chunk_end = offset + chunk_count; - DataChunk chunk; - chunk.data.emplace_back(vector, offset, chunk_end); - chunk.SetCardinality(chunk_count); - - // construct the relevant selection vector for the current chunk (offset ... offset + chunk_count) - idx_t current_count = 0; - for (; current_sel_offset < approved_tuple_count; current_sel_offset++) { - auto sel_index = sel.get_index(current_sel_offset); - if (sel_index >= chunk_end) { - // exhausted the chunk - break; - } - if (sel_index < offset) { - throw InternalException("sel_index < offset in expression filter"); - } - current_sel.set_index(current_count++, sel_index - offset); - } - if (current_count == 0) { - // no matching tuples in this chunk - offset += chunk_count; - continue; - } - auto current_result_data = result_sel.data() + result_offset; - SelectionVector current_result_sel(current_result_data, result_sel.Capacity() - result_offset); - idx_t new_matches = - state.executor->SelectExpression(chunk, current_result_sel, current_sel, current_count); - // increment all matches by the offset - for (idx_t i = 0; i < new_matches; i++) { - current_result_data[i] += offset; - } - result_offset += new_matches; + if (current_count == 0) { + // no matching tuples in this chunk offset += chunk_count; + continue; } - approved_tuple_count = result_offset; - } else { - // standard case: we can handle everything at once - run the expression once - DataChunk chunk; - chunk.data.emplace_back(Vector::Ref(vector)); - chunk.SetCardinality(scan_count); - approved_tuple_count = state.executor->SelectExpression(chunk, result_sel, sel, approved_tuple_count); + auto current_result_data = result_sel.data() + result_offset; + SelectionVector current_result_sel(current_result_data, result_sel.Capacity() - result_offset); + idx_t new_matches = 
state.executor->SelectExpression(chunk, current_result_sel, current_sel, current_count); + // increment all matches by the offset + for (idx_t i = 0; i < new_matches; i++) { + current_result_data[i] += offset; + } + result_offset += new_matches; + offset += chunk_count; } - sel.Initialize(result_sel); - return approved_tuple_count; - } - default: - throw InternalException("FIXME: unsupported type for filter selection"); + approved_tuple_count = result_offset; + } else { + // standard case: we can handle everything at once - run the expression once + DataChunk chunk; + chunk.data.emplace_back(Vector::Ref(vector)); + chunk.SetChildCardinality(scan_count); + SelectionVector identity_sel; + optional_ptr current_sel = &sel; + if (!sel.IsSet()) { + identity_sel = SelectionVector::Incremental(approved_tuple_count); + current_sel = &identity_sel; + } + approved_tuple_count = state.executor->SelectExpression(chunk, result_sel, current_sel, approved_tuple_count); } + sel.Initialize(result_sel); + return approved_tuple_count; +} + +idx_t ColumnSegment::FilterSelection(SelectionVector &sel, Vector &vector, UnifiedVectorFormat &vdata, + const TableFilter &filter, TableFilterState &filter_state, idx_t scan_count, + idx_t &approved_tuple_count) { + (void)vdata; + auto &state = filter_state.Cast(); + return ExecuteExpressionFilterSelection(sel, vector, state, scan_count, approved_tuple_count); } const CompressionFunction &ColumnSegment::GetCompressionFunction() { diff --git a/src/duckdb/src/storage/table/geo_column_data.cpp b/src/duckdb/src/storage/table/geo_column_data.cpp index 183f45822..2c67a9e7f 100644 --- a/src/duckdb/src/storage/table/geo_column_data.cpp +++ b/src/duckdb/src/storage/table/geo_column_data.cpp @@ -105,10 +105,13 @@ void GeoColumnData::Skip(ColumnScanState &state, idx_t count) { void GeoColumnData::InitializeAppend(ColumnAppendState &state) { base_column->InitializeAppend(state); } -void GeoColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, 
Vector &vector, idx_t add_count) { - base_column->Append(stats, state, vector, add_count); +void GeoColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t add_count) { + base_column->Append(state, vector, add_count); count += add_count; } +void GeoColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + base_column->FinalizeAppend(finalize_state, state); +} void GeoColumnData::RevertAppend(row_t new_count) { base_column->RevertAppend(new_count); count = UnsafeNumericCast(new_count); @@ -137,20 +140,22 @@ idx_t GeoColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) return fetch_count; } -void GeoColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { +void GeoColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { // Not a shredded column, so just emit the binary format immediately if (storage_type == GeometryStorageType::WKB) { - return base_column->FetchRow(transaction, state, storage_index, row_id, result, result_idx); + return base_column->FetchRows(transaction, state, storage_index, offsets, sel, fetch_count, result, + result_offset); } // Otherwise, we need to fetch and reassemble DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), {base_column->GetType()}, 1); + chunk.Initialize(Allocator::DefaultAllocator(), {base_column->GetType()}, fetch_count); - base_column->FetchRow(transaction, state, storage_index, row_id, chunk.data[0], 0); + base_column->FetchRows(transaction, state, storage_index, offsets, sel, fetch_count, chunk.data[0], 0); - Reassemble(chunk.data[0], result, 1, storage_type, result_idx); + Reassemble(chunk.data[0], result, fetch_count, storage_type, result_offset); } 
//---------------------------------------------------------------------------------------------------------------------- @@ -350,11 +355,30 @@ unique_ptr GeoColumnData::Checkpoint(const RowGroup &row_ checkpoint_state->inner_column = base_column; checkpoint_state->inner_column_state = checkpoint_state->inner_column->Checkpoint(row_group, info, old_column_stats); + + if (base_column->GetType().id() == LogicalTypeId::GEOMETRY) { + // Get the stats from the base column. + checkpoint_state->global_stats = checkpoint_state->inner_column_state->GetStatistics(); + } else if (checkpoint_state->storage_type == GeometryStorageType::SPATIAL) { + // Legacy spatial storage - we cannot interpret the stats of the old format + auto new_stats = checkpoint_state->inner_column_state->GetStatistics(); + checkpoint_state->global_stats = GeometryStats::CreateUnknown(type).ToUnique(); + checkpoint_state->global_stats->CopyBase(*new_stats); + } else { + // Otherwise interpret stats from shredded column + const auto types = Geometry::GetSpecializedType(checkpoint_state->storage_type); + const auto gtype = types.first; + const auto vtype = types.second; + + auto new_stats = checkpoint_state->inner_column_state->GetStatistics(); + InterpretStats(*new_stats, *checkpoint_state->global_stats, gtype, vtype); + } + return std::move(checkpoint_state); } // Old storage version, write as old type - if (GetStorageManager().GetStorageVersion() < 7) { + if (StorageManager::IsPriorToVersion(StorageVersion::V1_5_0, GetStorageManager().GetStorageVersion())) { auto legacy_type = Geometry::GetSpatialGeometryType(); auto new_column = CreateColumn(block_manager, this->info, base_column->column_index, legacy_type, GetDataType(), this); @@ -382,20 +406,21 @@ unique_ptr GeoColumnData::Checkpoint(const RowGroup &row_ auto to_scan = MinValue(total_count - scanned, static_cast(STANDARD_VECTOR_SIZE)); Scan(TransactionData::Committed(), vector_index++, scan_state, scan_chunk.data[0], to_scan); - 
scan_chunk.SetCardinality(to_scan); // Verify the scan chunk scan_chunk.Verify(GetDatabase()); append_chunk.Reset(); - append_chunk.SetCardinality(to_scan); + append_chunk.SetChildCardinality(to_scan); // Make the split Specialize(scan_chunk.data[0], append_chunk.data[0], to_scan, GeometryStorageType::SPATIAL); // Append into the new specialized column - new_column->Append(empty_stats, append_state, append_chunk.data[0], to_scan); + new_column->Append(append_state, append_chunk.data[0], to_scan); } + ColumnDataFinalizeAppendState finalize_append_state(empty_stats); + new_column->FinalizeAppend(finalize_append_state, append_state); // Move then new column into our checkpoint state checkpoint_state->inner_column = new_column; @@ -477,20 +502,21 @@ unique_ptr GeoColumnData::Checkpoint(const RowGroup &row_ auto to_scan = MinValue(total_count - scanned, static_cast(STANDARD_VECTOR_SIZE)); Scan(TransactionData::Committed(), vector_index++, scan_state, scan_chunk.data[0], to_scan); - scan_chunk.SetCardinality(to_scan); // Verify the scan chunk scan_chunk.Verify(GetDatabase()); append_chunk.Reset(); - append_chunk.SetCardinality(to_scan); + append_chunk.SetChildCardinality(to_scan); // Make the split Specialize(scan_chunk.data[0], append_chunk.data[0], to_scan, new_storage_type); // Append into the new specialized column - new_column->Append(dummy_stats, append_state, append_chunk.data[0], to_scan); + new_column->Append(append_state, append_chunk.data[0], to_scan); } + ColumnDataFinalizeAppendState finalize_append_state(dummy_stats); + new_column->FinalizeAppend(finalize_append_state, append_state); // Merge the stats into the checkpoint state's global stats InterpretStats(dummy_stats, *checkpoint_state->global_stats, new_geom_type, new_vert_type); @@ -575,8 +601,9 @@ idx_t GeoColumnData::GetMaxEntry() { } void GeoColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { - return 
base_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); + vector &result, + const ColumnSegmentInfoScanOptions &options) { + return base_column->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); } void GeoColumnData::Verify(RowGroup &parent) { @@ -590,16 +617,16 @@ void GeoColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { //---------------------------------------------------------------------------------------------------------------------- // Specialize //---------------------------------------------------------------------------------------------------------------------- -void GeoColumnData::Specialize(Vector &source, Vector &target, idx_t count, GeometryStorageType type) { +void GeoColumnData::Specialize(const Vector &source, Vector &target, idx_t count, GeometryStorageType type) { Geometry::ToVectorizedFormat(source, target, count, type); } -void GeoColumnData::Reassemble(Vector &source, Vector &target, idx_t count, GeometryStorageType type, +void GeoColumnData::Reassemble(const Vector &source, Vector &target, idx_t count, GeometryStorageType type, idx_t result_offset) { Geometry::FromVectorizedFormat(source, target, count, type, result_offset); } -static const BaseStatistics *GetVertexStats(BaseStatistics &stats, GeometryType geom_type) { +static const BaseStatistics *GetVertexStats(const BaseStatistics &stats, GeometryType geom_type) { switch (geom_type) { case GeometryType::POINT: { return StructStats::GetChildStats(stats); @@ -634,7 +661,7 @@ static const BaseStatistics *GetVertexStats(BaseStatistics &stats, GeometryType } } -void GeoColumnData::InterpretStats(BaseStatistics &source, BaseStatistics &target, GeometryType geom_type, +void GeoColumnData::InterpretStats(const BaseStatistics &source, BaseStatistics &target, GeometryType geom_type, VertexType vert_type) { // Copy base stats target.CopyBase(source); diff --git a/src/duckdb/src/storage/table/in_memory_checkpoint.cpp 
b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp index c66d418bb..caeba0ae4 100644 --- a/src/duckdb/src/storage/table/in_memory_checkpoint.cpp +++ b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp @@ -91,7 +91,8 @@ InMemoryTableDataWriter::InMemoryTableDataWriter(InMemoryCheckpointer &checkpoin } void InMemoryTableDataWriter::WriteUnchangedTable(MetaBlockPointer pointer, - const vector &metadata_pointers, idx_t total_rows) { + const vector &metadata_pointers, idx_t total_rows, + idx_t next_row_id) { } void InMemoryTableDataWriter::FinalizeTable(const TableStatistics &global_stats, DataTableInfo &info, @@ -117,7 +118,7 @@ MetadataManager &InMemoryTableDataWriter::GetMetadataManager() { InMemoryPartialBlock::InMemoryPartialBlock(ColumnData &data, ColumnSegment &segment, PartialBlockState state, BlockManager &block_manager) - : PartialBlock(state, block_manager, segment.block) { + : PartialBlock(state, block_manager, segment.GetBlockHandle()) { InMemoryPartialBlock::AddSegmentToTail(data, segment, 0); } diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index cf79f5279..f979cfee7 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -138,7 +138,6 @@ idx_t ListColumnData::ScanCount(ColumnScanState &state, Vector &result, idx_t co state.last_offset = last_entry; ListVector::SetListSize(result, child_scan_count); - FlatVector::SetSize(result, count_t(scan_count)); return scan_count; } @@ -182,7 +181,7 @@ void ListColumnData::InitializeAppend(ColumnAppendState &state) { state.child_appends.push_back(std::move(child_append_state)); } -void ListColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { +void ListColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t count) { D_ASSERT(count > 0); // construct the list_entry_t entries to append to the column data @@ -235,13 
+234,21 @@ void ListColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vec // append the child vector if (child_count > 0) { - child_column->Append(ListStats::GetChildStats(stats), state.child_appends[1], child_vector, child_count); + child_column->Append(state.child_appends[1], child_vector, child_count); } // append the list offsets - ColumnData::AppendData(stats, state, vdata, count); + ColumnData::AppendData(state, vdata, count); // append the validity data vdata.validity = append_mask; - validity->AppendData(stats, state.child_appends[0], vdata, count); + validity->AppendData(state.child_appends[0], vdata, count); +} + +void ListColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + ColumnData::FinalizeAppend(finalize_state, state); + validity->FinalizeAppendLocked(finalize_state, state.child_appends[0]); + + ColumnDataFinalizeAppendState child_finalize_state(finalize_state, LogicalTypeId::LIST); + child_column->FinalizeAppendLocked(child_finalize_state, state.child_appends[1]); } void ListColumnData::RevertAppend(row_t new_count) { @@ -273,49 +280,50 @@ void ListColumnData::UpdateColumn(TransactionData transaction, DuckTableEntry &t unique_ptr ListColumnData::GetUpdateStatistics() { return nullptr; } - -void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { - // insert any child states that are required - // we need two (validity & list child) - // note that we need a scan state for the child vector - // this is because we will (potentially) fetch more than one tuple from the list child +void ListColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { + // insert the validity child state if 
(state.child_states.empty()) { auto child_state = make_uniq(); state.child_states.push_back(std::move(child_state)); } - // now perform the fetch within the segment - auto start_offset = row_id == 0 ? 0 : FetchListOffset(UnsafeNumericCast(row_id - 1)); - auto end_offset = FetchListOffset(UnsafeNumericCast(row_id)); - validity->FetchRow(transaction, *state.child_states[0], storage_index, row_id, result, result_idx); + validity->FetchRowsAtSegmentLevel(transaction, *state.child_states[0], offsets, sel, fetch_count, result, + result_offset); auto &validity_mask = FlatVector::ValidityMutable(result); auto list_data = FlatVector::GetDataMutable(result); - auto &list_entry = list_data[result_idx]; - // set the list entry offset to the size of the current list - list_entry.offset = ListVector::GetListSize(result); - list_entry.length = end_offset - start_offset; - if (!validity_mask.RowIsValid(result_idx)) { - // the list is NULL! no need to fetch the child - D_ASSERT(list_entry.length == 0); - return; - } + for (idx_t idx = 0; idx < fetch_count; idx++) { + const auto row_id = offsets[sel.get_index(idx)]; + auto start_offset = row_id == 0 ? 0 : FetchListOffset(row_id - 1); + auto end_offset = FetchListOffset(row_id); + auto result_idx = result_offset + idx; + auto &list_entry = list_data[result_idx]; + // set the list entry offset to the size of the current list + list_entry.offset = ListVector::GetListSize(result); + list_entry.length = end_offset - start_offset; + if (!validity_mask.RowIsValid(result_idx)) { + // the list is NULL! 
no need to fetch the child + D_ASSERT(list_entry.length == 0); + continue; + } - // now we need to read from the child all the elements between [offset...length] - auto child_scan_count = list_entry.length; - if (child_scan_count > 0) { - ColumnScanState child_state(nullptr); - auto &child_type = ListType::GetChildType(result.GetType()); - Vector child_scan(child_type, child_scan_count); - // seek the scan towards the specified position and read [length] entries - child_state.Initialize(state.context, child_type, nullptr); - child_column->InitializeScanWithOffset(child_state, start_offset); - D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || - child_state.offset_in_column + child_scan_count <= child_column->GetMaxEntry()); - child_column->ScanCount(child_state, child_scan, child_scan_count); - - ListVector::Append(result, child_scan, child_scan_count); + // now we need to read from the child all the elements between [offset...length] + auto child_scan_count = list_entry.length; + if (child_scan_count > 0) { + ColumnScanState child_state(nullptr); + auto &child_type = ListType::GetChildType(result.GetType()); + Vector child_scan(child_type, child_scan_count); + // seek the scan towards the specified position and read [length] entries + child_state.Initialize(state.context, child_type, nullptr); + child_column->InitializeScanWithOffset(child_state, start_offset); + D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || + child_state.offset_in_column + child_scan_count <= child_column->GetMaxEntry()); + child_column->ScanCount(child_state, child_scan, child_scan_count); + + ListVector::Append(result, child_scan, child_scan_count); + } } } @@ -432,12 +440,13 @@ void ListColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSta } void ListColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); 
+ vector &result, + const ColumnSegmentInfoScanOptions &options) { + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result, options); col_path.push_back(0); - validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); col_path.back() = 1; - child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); } } // namespace duckdb diff --git a/src/duckdb/src/storage/table/per_column_metadata_blocks.cpp b/src/duckdb/src/storage/table/per_column_metadata_blocks.cpp new file mode 100644 index 000000000..afc7d6f34 --- /dev/null +++ b/src/duckdb/src/storage/table/per_column_metadata_blocks.cpp @@ -0,0 +1,125 @@ +#include "duckdb/storage/table/per_column_metadata_blocks.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +static constexpr idx_t COLUMN_INDEX_BIT = idx_t(1) << 63; + +idx_t PerColumnMetadataBlock::GetPacked() { + idx_t packed = index; + if (is_column_index) { + packed |= COLUMN_INDEX_BIT; + } + return packed; +} + +PerColumnMetadataBlock PerColumnMetadataBlock::Unpack(idx_t packed) { + PerColumnMetadataBlock result; + result.is_column_index = (packed & COLUMN_INDEX_BIT) != 0; + result.index = packed & ~COLUMN_INDEX_BIT; + return result; +} + +vector> PerColumnMetadataBlocks::GetBlocksForColumns(const vector &columns) const { + vector> result(columns.size()); + if (columns.empty()) { + return result; + } + idx_t col_pos = 0; + bool collecting = false; + for (auto &entry : data) { + if (entry.is_column_index) { + collecting = false; + // skip past requested columns that are before the current entry + while (col_pos < columns.size() && columns[col_pos] < entry.index) { + col_pos++; + } + if (col_pos >= columns.size()) { + break; + } + if 
(columns[col_pos] == entry.index) { + collecting = true; + } + } else if (collecting) { + result[col_pos].push_back(entry.index); + } + } + return result; +} + +void PerColumnMetadataBlocks::AddColumn(idx_t col_idx, const vector &blocks) { + if (blocks.empty()) { + return; + } +#ifdef D_ASSERT_IS_ENABLED + // assert sorted insertion: col_idx must be greater than the last column index + for (idx_t i = data.size(); i > 0; i--) { + if (data[i - 1].is_column_index) { + D_ASSERT(col_idx > data[i - 1].index); + break; + } + } +#endif + PerColumnMetadataBlock marker; + marker.is_column_index = true; + marker.index = col_idx; + data.push_back(marker); + for (auto &block_id : blocks) { + PerColumnMetadataBlock block; + block.is_column_index = false; + block.index = block_id; + data.push_back(block); + } +} + +PerColumnMetadataBlocks PerColumnMetadataBlocks::Merge(const PerColumnMetadataBlocks &a, + const PerColumnMetadataBlocks &b) { + PerColumnMetadataBlocks result; + result.data.reserve(a.data.size() + b.data.size()); + idx_t ai = 0; + idx_t bi = 0; + // each marker is followed by its blocks until the next marker; data is sorted by column index + while (ai < a.data.size() && bi < b.data.size()) { + D_ASSERT(a.data[ai].is_column_index && b.data[bi].is_column_index); + D_ASSERT(a.data[ai].index != b.data[bi].index); + bool take_a = a.data[ai].index < b.data[bi].index; + const auto &src = take_a ? a.data : b.data; + idx_t &pos = take_a ? 
ai : bi; + result.data.push_back(src[pos++]); // marker + while (pos < src.size() && !src[pos].is_column_index) { + result.data.push_back(src[pos++]); + } + } + while (ai < a.data.size()) { + result.data.push_back(a.data[ai++]); + } + while (bi < b.data.size()) { + result.data.push_back(b.data[bi++]); + } + return result; +} + +void PerColumnMetadataBlocks::RemoveColumn(idx_t col_idx) { + idx_t start = data.size(); + idx_t end = data.size(); + for (idx_t i = 0; i < data.size(); i++) { + if (data[i].is_column_index && data[i].index == col_idx) { + start = i; + // find the end: next column marker or end of data + end = data.size(); + for (idx_t j = i + 1; j < data.size(); j++) { + if (data[j].is_column_index) { + end = j; + break; + } + } + break; + } + } + if (start < data.size()) { + data.erase(data.begin() + NumericCast(start), data.begin() + NumericCast(end)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/persistent_table_data.cpp b/src/duckdb/src/storage/table/persistent_table_data.cpp index 6da743f87..8630317b2 100644 --- a/src/duckdb/src/storage/table/persistent_table_data.cpp +++ b/src/duckdb/src/storage/table/persistent_table_data.cpp @@ -2,7 +2,7 @@ namespace duckdb { -PersistentTableData::PersistentTableData(idx_t column_count) : total_rows(0), row_group_count(0) { +PersistentTableData::PersistentTableData(idx_t column_count) : total_rows(0), next_row_id(0), row_group_count(0) { } PersistentTableData::~PersistentTableData() { diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index 33f3a513f..46b0524c9 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -1,4 +1,5 @@ #include "duckdb/storage/table/row_group.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/transaction/commit_state.hpp" #include "duckdb/common/exception.hpp" @@ -10,6 +11,7 @@ #include "duckdb/common/vector/flat_vector.hpp" #include 
"duckdb/execution/adaptive_filter.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/checkpoint/table_data_writer.hpp" #include "duckdb/storage/metadata/metadata_reader.hpp" @@ -28,6 +30,7 @@ #include "duckdb/storage/table/row_group_collection.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/common/storage_compatibility.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" @@ -56,6 +59,8 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) this->deletes_pointers = std::move(pointer.deletes_pointers); this->has_metadata_blocks = pointer.has_metadata_blocks; this->extra_metadata_blocks = std::move(pointer.extra_metadata_blocks); + this->has_per_column_metadata_blocks = pointer.has_per_column_metadata_blocks; + this->per_column_metadata_blocks = std::move(pointer.per_column_metadata_blocks); Verify(); } @@ -164,6 +169,14 @@ ColumnData &RowGroup::GetColumn(storage_t c) const { return *columns[c]; } +ColumnData &RowGroup::GetRawColumnData(const StorageIndex &c) const { + return GetColumn(c); +} + +ColumnData &RowGroup::GetRawColumnData(storage_t c) const { + return GetColumn(c); +} + void RowGroup::LoadColumn(storage_t c) const { if (c == COLUMN_IDENTIFIER_ROW_ID) { LoadRowIdColumnData(); @@ -349,7 +362,8 @@ void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalTyp Initialize(context_p, type, column_id, options); } -void CollectionScanState::Initialize(const QueryContext &context, const vector &types) { +void CollectionScanState::Initialize(const QueryContext &context_p, const vector &types) { + context = context_p; auto &column_ids = GetColumnIds(); D_ASSERT(column_scans.empty()); column_scans.reserve(column_scans.size()); @@ -369,7 +383,7 @@ void CollectionScanState::Initialize(const QueryContext 
&context, const vector &node, idx_t vector_offset) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); - if (!CheckZonemap(filters)) { + if (!CheckZonemap(state.context.GetClientContext(), filters)) { return false; } if (!RefersToSameObject(node.GetNode(), *this)) { @@ -398,7 +412,7 @@ bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, SegmentNode< bool RowGroup::InitializeScan(CollectionScanState &state, SegmentNode &node) { auto &column_ids = state.GetColumnIds(); auto &filters = state.GetFilterInfo(); - if (!CheckZonemap(filters)) { + if (!CheckZonemap(state.context.GetClientContext(), filters)) { return false; } if (!RefersToSameObject(node.GetNode(), *this)) { @@ -421,6 +435,24 @@ bool RowGroup::InitializeScan(CollectionScanState &state, SegmentNode return true; } +unique_ptr RowGroup::CreateNewRowGroupCopy(RowGroupCollection &new_collection, idx_t new_column_count) { + auto row_group = make_uniq(new_collection, this->count); + row_group->deletes_pointers = deletes_pointers; + row_group->deletes_is_loaded = deletes_is_loaded.load(); + row_group->owned_version_info = owned_version_info; + row_group->version_info = version_info.load(); + row_group->columns.resize(new_column_count); + if (is_loaded) { + row_group->is_loaded = unique_ptr[]>(new atomic[new_column_count]); + } + if (!column_pointers.empty()) { + row_group->column_pointers.resize(new_column_count); + } + row_group->has_per_column_metadata_blocks = has_per_column_metadata_blocks; + row_group->has_changes = true; + return row_group; +} + unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, const LogicalType &target_type, idx_t changed_idx, ExpressionExecutor &executor, CollectionScanState &scan_state, SegmentNode &node, @@ -453,27 +485,45 @@ unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, con executor.ExecuteExpression(scan_chunk, append_vector); column_data->Append(append_state, append_vector, scan_chunk.size()); } + 
column_data->FinalizeAppend(nullptr, append_state); - // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->count); - row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); - auto &cols = GetColumns(); - for (idx_t i = 0; i < cols.size(); i++) { + auto supports_per_column_writes = new_collection.SupportsPerColumnWrites(); + if (!supports_per_column_writes) { + // ensure all columns are loaded, as checkpointing will need to load them anyway, and it's better to fail-fast + // in case these don't fit in memory here + GetColumns(); + } + unique_lock lock(row_group_lock); + auto row_group = CreateNewRowGroupCopy(new_collection, columns.size()); + // copy existing columns, but swap out the one at changed_idx + for (idx_t i = 0; i < columns.size(); i++) { if (i == changed_idx) { - // this is the altered column: use the new column - row_group->columns.push_back(std::move(column_data)); + row_group->columns[i] = std::move(column_data); + if (row_group->is_loaded) { + row_group->is_loaded[i] = true; + } column_data.reset(); } else { - // this column was not altered: use the data directly - row_group->columns.push_back(cols[i]); + row_group->columns[i] = columns[i]; + if (row_group->is_loaded) { + row_group->is_loaded[i] = is_loaded[i].load(); + } + if (!row_group->column_pointers.empty()) { + row_group->column_pointers[i] = column_pointers[i]; + } } } + if (has_per_column_metadata_blocks) { + row_group->per_column_metadata_blocks = per_column_metadata_blocks; + row_group->per_column_metadata_blocks.RemoveColumn(changed_idx); + } + lock.unlock(); row_group->Verify(); return row_group; } unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, ColumnDefinition &new_column, - ExpressionExecutor &executor, Vector &result) { + ExpressionExecutor &executor) { Verify(); // construct a new column data for the new column @@ -483,36 +533,49 @@ unique_ptr RowGroup::AddColumn(RowGroupCollection &new_collection, Col idx_t 
rows_to_write = this->count; if (rows_to_write > 0) { DataChunk dummy_chunk; + DataChunk result_chunk; + result_chunk.Initialize(Allocator::DefaultAllocator(), {new_column.GetType()}); + auto &result = result_chunk.data[0]; ColumnAppendState state; added_column->InitializeAppend(state); for (idx_t i = 0; i < rows_to_write; i += STANDARD_VECTOR_SIZE) { idx_t rows_in_this_vector = MinValue(rows_to_write - i, STANDARD_VECTOR_SIZE); - dummy_chunk.SetCardinality(rows_in_this_vector); + dummy_chunk.SetChildCardinality(rows_in_this_vector); + result_chunk.Reset(); executor.ExecuteExpression(dummy_chunk, result); added_column->Append(state, result, rows_in_this_vector); } + added_column->FinalizeAppend(nullptr, state); } - // set up the row_group based on this row_group - auto row_group = make_uniq(new_collection, this->count); - row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); - row_group->columns = columns; - if (is_loaded) { - // preserve lazy loading - row_group->is_loaded = unique_ptr[]>(new atomic[columns.size() + 1]); - for (idx_t i = 0; i < columns.size(); i++) { - row_group->is_loaded[i].store(is_loaded[i]); + if (!new_collection.SupportsPerColumnWrites()) { + // ensure all columns are loaded, as checkpointing will need to load them anyway, and it's better to fail-fast + // in case these don't fit in memory here + GetColumns(); + } + + unique_lock lock(row_group_lock); + auto row_group = CreateNewRowGroupCopy(new_collection, columns.size() + 1); + // copy existing columns + for (idx_t i = 0; i < columns.size(); i++) { + row_group->columns[i] = columns[i]; + if (row_group->is_loaded) { + row_group->is_loaded[i] = is_loaded[i].load(); + } + if (!row_group->column_pointers.empty()) { + row_group->column_pointers[i] = column_pointers[i]; } - // new column is always loaded - row_group->is_loaded[columns.size()] = true; - row_group->column_pointers = column_pointers; - row_group->column_pointers.emplace_back(); + } + if (has_per_column_metadata_blocks) { + 
row_group->per_column_metadata_blocks = per_column_metadata_blocks; } // add the new column - row_group->columns.push_back(std::move(added_column)); - row_group->has_changes = true; - + row_group->columns[columns.size()] = std::move(added_column); + if (row_group->is_loaded) { + row_group->is_loaded[columns.size()] = true; + } + lock.unlock(); row_group->Verify(); return row_group; } @@ -522,33 +585,33 @@ unique_ptr RowGroup::RemoveColumn(RowGroupCollection &new_collection, D_ASSERT(removed_column < columns.size()); - auto row_group = make_uniq(new_collection, this->count); - row_group->SetVersionInfo(GetOrCreateVersionInfoPtr()); - + if (!new_collection.SupportsPerColumnWrites()) { + // ensure all columns are loaded, as checkpointing will need to load them anyway, and it's better to fail-fast + // in case these don't fit in memory here + GetColumns(); + } + unique_lock lock(row_group_lock); + auto row_group = CreateNewRowGroupCopy(new_collection, columns.size() - 1); // copy over all columns except for the removed one - if (is_loaded) { - // preserve lazy loading - row_group->is_loaded = unique_ptr[]>(new atomic[columns.size() - 1]); - idx_t new_idx = 0; - for (idx_t i = 0; i < columns.size(); i++) { - if (i == removed_column) { - continue; - } - row_group->is_loaded[new_idx].store(is_loaded[i]); - row_group->column_pointers.emplace_back(column_pointers[i]); - row_group->columns.emplace_back(columns[i]); - new_idx++; + idx_t target_idx = 0; + for (idx_t i = 0; i < columns.size(); i++) { + if (i == removed_column) { + continue; } - } else { - for (idx_t i = 0; i < columns.size(); i++) { - if (i == removed_column) { - continue; - } - row_group->columns.emplace_back(columns[i]); + row_group->columns[target_idx] = columns[i]; + if (row_group->is_loaded) { + row_group->is_loaded[target_idx] = is_loaded[i].load(); + } + if (!row_group->column_pointers.empty()) { + row_group->column_pointers[target_idx] = column_pointers[i]; } + target_idx++; } - row_group->has_changes = 
true; - + if (has_per_column_metadata_blocks) { + row_group->per_column_metadata_blocks = per_column_metadata_blocks; + row_group->per_column_metadata_blocks.RemoveColumn(removed_column); + } + lock.unlock(); row_group->Verify(); return row_group; } @@ -598,10 +661,11 @@ FilterPropagateResult RowGroup::CheckRowIdFilter(const TableFilter &filter, idx_ NumericStats::SetMin(dummy_stats, UnsafeNumericCast(beg_row)); NumericStats::SetMax(dummy_stats, UnsafeNumericCast(end_row)); - return filter.CheckStatistics(dummy_stats); + auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "RowGroup::CheckRowIdFilter"); + return expr_filter.CheckStatistics(dummy_stats); } -bool RowGroup::CheckZonemap(ScanFilterInfo &filters) { +bool RowGroup::CheckZonemap(optional_ptr context, ScanFilterInfo &filters) { auto &filter_list = filters.GetFilterList(); // new row group - label all filters as up for grabs again filters.CheckAllFilters(); @@ -610,11 +674,11 @@ bool RowGroup::CheckZonemap(ScanFilterInfo &filters) { auto &filter = entry.filter; const auto &base_column_index = entry.table_column_index; - auto prune_result = GetColumn(base_column_index).CheckZonemap(base_column_index, filter); + auto prune_result = GetColumn(base_column_index).CheckZonemap(context, base_column_index, filter); if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { return false; } - if (filter.IsOnlyForZoneMapFiltering()) { + if (ExpressionFilter::IsRootOptionalFilter(filter)) { // these are only for row group checking, set as always true so we don't check it filters.SetFilterAlwaysTrue(i); } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { @@ -657,7 +721,9 @@ bool RowGroup::CheckZonemapSegments(CollectionScanState &state) { } D_ASSERT(target_row >= row_start); D_ASSERT(target_row <= row_start + this->count); - idx_t target_vector_index = (target_row - row_start) / STANDARD_VECTOR_SIZE; + // current_segment->GetRowStart() is already row-group-relative, and 
state.vector_index uses the same + // coordinate space. Subtracting row_start here incorrectly makes the target segment-local. + idx_t target_vector_index = target_row / STANDARD_VECTOR_SIZE; if (!target_vector_index_max.IsValid() || target_vector_index_max.GetIndex() < target_vector_index) { target_vector_index_max = target_vector_index; @@ -714,6 +780,8 @@ void RowGroup::Scan(ScanOptions options, CollectionScanState &state, DataChunk & NextVector(state); continue; } + state.rows_scanned += count; + auto &block_manager = GetBlockManager(); if (block_manager.Prefetch()) { PrefetchState prefetch_state; @@ -794,6 +862,7 @@ void RowGroup::Scan(ScanOptions options, CollectionScanState &state, DataChunk & auto &col_data = GetColumn(col_idx); col_data.Skip(state.column_scans[i]); } + filter_info.EndFilter(filter_state); state.vector_index++; continue; } @@ -864,7 +933,7 @@ optional_ptr RowGroup::GetVersionInfo() { if (!HasUnloadedDeletes()) { return version_info; } - // deletes are not loaded - reload + D_ASSERT(!deletes_pointers.empty()); auto root_delete = deletes_pointers[0]; auto loaded_info = RowVersionManager::Deserialize(root_delete, GetBlockManager().GetMetadataManager()); SetVersionInfo(std::move(loaded_info)); @@ -926,30 +995,33 @@ idx_t RowGroup::GetSelVector(ScanOptions options, idx_t vector_idx, SelectionVec return vinfo->GetSelVector(options, vector_idx, sel_vector, max_count); } -bool RowGroup::Fetch(TransactionData transaction, idx_t row) { - if (UnsafeNumericCast(row) > count) { - throw InternalException("RowGroup::Fetch - row_id out of range for row group"); +idx_t RowGroup::Fetch(TransactionData transaction, const idx_t *offsets, idx_t fetch_count, + SelectionVector &visible_sel) { + if (fetch_count == 0) { + return 0; } auto vinfo = GetVersionInfo(); if (!vinfo) { - return true; + // No version info at all, which means every row is visible. 
+ return fetch_count; } - return vinfo->Fetch(transaction, row); + return vinfo->GetVisibleRows(transaction, offsets, fetch_count, visible_sel); } -void RowGroup::FetchRow(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, - row_t row_id, DataChunk &result, idx_t result_idx) { - if (UnsafeNumericCast(row_id) > count) { - throw InternalException("RowGroup::FetchRow - row_id out of range for row group"); +void RowGroup::FetchRows(TransactionData transaction, ColumnFetchState &state, const vector &column_ids, + const idx_t *offsets, const SelectionVector &visible_sel, idx_t visible_count, + DataChunk &result, idx_t result_offset) { + if (visible_count == 0) { + return; } + for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { auto &column = column_ids[col_idx]; auto &result_vector = result.data[col_idx]; D_ASSERT(result_vector.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(!FlatVector::IsNull(result_vector, result_idx)); - // regular column: fetch data from the base column auto &col_data = GetColumn(column); - col_data.FetchRow(transaction, state, column, row_id, result_vector, result_idx); + col_data.FetchRows(transaction, state, column, offsets, visible_sel, visible_count, result_vector, + result_offset); } } @@ -995,8 +1067,21 @@ void RowGroup::RevertAppend(idx_t new_count) { Verify(); } -void RowGroup::InitializeAppend(RowGroupAppendState &append_state) { - append_state.row_group = this; +RowGroupAppendState::RowGroupAppendState(TableAppendState &parent_p) + : parent(parent_p), row_group(nullptr), offset_in_row_group(0) { +} + +RowGroupAppendState::~RowGroupAppendState() { +} + +void RowGroup::InitializeAppend(SegmentNode &row_group, RowGroupAppendState &append_state) { + append_state.row_group = row_group; + row_group.GetNode().InitializeAppendInternal(append_state); +} +void RowGroup::InitializeAppendInternal(RowGroupAppendState &append_state) { + if (!RefersToSameObject(append_state.row_group->GetNode(), *this)) 
{ + throw InternalException("RowGroup::InitializeAppend mismatch - call RowGroupAppendState::InitializeAppend"); + } append_state.offset_in_row_group = this->count; // for each column, initialize the append state append_state.states = make_unsafe_uniq_array(GetColumnCount()); @@ -1018,6 +1103,22 @@ void RowGroup::Append(RowGroupAppendState &state, DataChunk &chunk, idx_t append state.offset_in_row_group += append_count; } +void RowGroup::FinalizeAppend(RowGroupAppendState &state) { + auto &parent_stats = state.parent.stats; + if (parent_stats.Empty()) { + for (idx_t c = 0; c < GetColumnCount(); c++) { + auto &col_data = GetColumn(c); + col_data.FinalizeAppend(nullptr, state.states[c]); + } + } else { + auto stats_lock = parent_stats.GetLock(); + for (idx_t c = 0; c < GetColumnCount(); c++) { + auto &col_data = GetColumn(c); + col_data.FinalizeAppend(parent_stats.GetStats(*stats_lock, c).Statistics(), state.states[c]); + } + } +} + void RowGroup::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_t count) { auto &vinfo = GetOrCreateVersionInfo(); vinfo.CleanupAppend(lowest_transaction, start, count); @@ -1129,6 +1230,23 @@ CompressionType ColumnCheckpointInfo::GetCompressionType() { return info.compression_types[column_idx]; } +shared_ptr RowGroup::CheckpointColumn(const RowGroup &row_group, idx_t column_idx, RowGroupWriteInfo &info, + RowGroupWriteData &write_data) { + auto &column = row_group.GetColumn(column_idx); + ColumnCheckpointInfo checkpoint_info(info, column_idx); + auto checkpoint_state = column.Checkpoint(row_group, checkpoint_info); + + auto result_col = checkpoint_state->GetFinalResult(); + // FIXME: we should get rid of the checkpoint state statistics - and instead use the stats in the ColumnData + // directly + auto stats = checkpoint_state->GetStatistics(); + result_col->MergeStatistics(*stats); + + write_data.statistics.push_back(stats->Copy()); + write_data.states.push_back(std::move(checkpoint_state)); + return result_col; +} + 
vector RowGroup::WriteToDisk(RowGroupWriteInfo &info, const vector> &row_groups) { vector result; @@ -1162,24 +1280,8 @@ vector RowGroup::WriteToDisk(RowGroupWriteInfo &info, for (idx_t column_idx = 0; column_idx < column_count; column_idx++) { for (idx_t row_group_idx = 0; row_group_idx < row_groups.size(); row_group_idx++) { auto &row_group = row_groups[row_group_idx].get(); - auto &row_group_write_data = result[row_group_idx]; - // if we are loading a column just to checkpoint we unload it post-checkpoint - auto keep_column_loaded = row_group.ColumnIsLoaded(column_idx); - - auto &column = row_group.GetColumn(column_idx); - ColumnCheckpointInfo checkpoint_info(info, column_idx); - auto checkpoint_state = column.Checkpoint(row_group, checkpoint_info); - - auto result_col = checkpoint_state->GetFinalResult(); - // FIXME: we should get rid of the checkpoint state statistics - and instead use the stats in the ColumnData - // directly - auto stats = checkpoint_state->GetStatistics(); - result_col->MergeStatistics(*stats); - - result_columns[row_group_idx].emplace_back(std::move(result_col)); - row_group_write_data.statistics.push_back(stats->Copy()); - row_group_write_data.states.push_back(std::move(checkpoint_state)); - row_group_write_data.keep_column_loaded.emplace_back(keep_column_loaded); + result_columns[row_group_idx].emplace_back( + CheckpointColumn(row_group, column_idx, info, result[row_group_idx])); } } @@ -1233,46 +1335,50 @@ bool RowGroup::HasUnloadedDeletes() const { return !deletes_is_loaded; } -vector RowGroup::GetOrComputeExtraMetadataBlocks(bool force_compute) { - if (has_metadata_blocks && !force_compute) { - return extra_metadata_blocks; - } +PerColumnMetadataBlocks RowGroup::ComputePerColumnMetadataBlocks() const { + PerColumnMetadataBlocks result; if (column_pointers.empty()) { - // no pointers - return {}; - } - vector read_pointers; - // column_pointers stores the beginning of each column - // if columns are big - they may span multiple 
metadata blocks - // we need to figure out all blocks that this row group points to - // we need to follow the linked list in the metadata blocks to allow for this + return result; + } + auto &metadata_manager = GetCollection().GetMetadataManager(); - idx_t last_idx = column_pointers.size() - 1; - if (column_pointers.size() > 1) { - // for all but the last column pointer - we can just follow the linked list until we reach the last column - MetadataReader reader(metadata_manager, column_pointers[0]); - auto last_pointer = column_pointers[last_idx]; - read_pointers = reader.GetRemainingBlocks(last_pointer); - } - // for the last column we need to deserialize the column - because we don't know where it stops auto &types = GetCollection().GetTypes(); - MetadataReader reader(metadata_manager, column_pointers[last_idx], &read_pointers); - ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), last_idx, reader, types[last_idx]); - unordered_set result_as_set; - for (auto &ptr : read_pointers) { - result_as_set.emplace(ptr.block_pointer); - } - for (auto &ptr : column_pointers) { - result_as_set.erase(ptr.block_pointer); + for (idx_t i = 0; i < column_pointers.size(); i++) { + auto &start = column_pointers[i]; + vector col_read_pointers; + MetadataReader col_reader(metadata_manager, start, &col_read_pointers); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, col_reader, types[i]); + vector extra_blocks; + for (auto &ptr : col_read_pointers) { + if (ptr.block_pointer != start.block_pointer) { + extra_blocks.emplace_back(ptr.block_pointer); + } + } + result.AddColumn(i, extra_blocks); } - return {result_as_set.begin(), result_as_set.end()}; + return result; } const vector &RowGroup::GetColumnStartPointers() const { return column_pointers; } +vector RowGroup::GetExtraMetadataBlockPointers() const { + vector extra_metadata_block_pointers; + if (has_per_column_metadata_blocks) { + per_column_metadata_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { 
extra_metadata_block_pointers.emplace_back(block_id, 0); }); + } else { + D_ASSERT(has_metadata_blocks); + extra_metadata_block_pointers.reserve(extra_metadata_blocks.size()); + for (auto &block_pointer : extra_metadata_blocks) { + extra_metadata_block_pointers.emplace_back(block_pointer, 0); + } + } + return extra_metadata_block_pointers; +} + bool RowGroup::CanReuseMetadata(RowGroupWriter &writer) const { if (!Settings::Get(writer.GetDatabase())) { // disabled by configuration @@ -1282,10 +1388,6 @@ bool RowGroup::CanReuseMetadata(RowGroupWriter &writer) const { // no existing metadata on disk - cannot re-use return false; } - if (HasChanges()) { - // we have changes - need to rewrite - return false; - } auto &table_writer = writer.GetTableWriter(); if (table_writer.RequireLegacyStartRow() && table_writer.RowIdsChanged()) { // row-ids changed and we are targeting an old storage version that requires "start_row" - cannot re-use @@ -1294,33 +1396,137 @@ bool RowGroup::CanReuseMetadata(RowGroupWriter &writer) const { return true; } +bool RowGroup::HasUnchangedColumns() const { + for (idx_t c = 0; c < columns.size(); c++) { + if (!ColumnIsLoaded(c)) { + return true; + } + if (!columns[c]->HasAnyChanges()) { + return true; + } + } + return false; +} + RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { - if (CanReuseMetadata(writer)) { - // we have existing metadata and the row group has not been changed - // re-use previous metadata + bool can_reuse_metadata = CanReuseMetadata(writer); + if (can_reuse_metadata && !HasChanges()) { RowGroupWriteData result; - result.reuse_existing_metadata_blocks = true; - result.existing_extra_metadata_blocks = GetOrComputeExtraMetadataBlocks(); - return result; + result.write_action = RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA; + if (GetCollection().SupportsPerColumnWrites()) { + if (has_per_column_metadata_blocks) { + return result; + } + per_column_metadata_blocks = ComputePerColumnMetadataBlocks(); + 
has_per_column_metadata_blocks = true; + return result; + } + + if (has_metadata_blocks) { + return result; + } + + if (!has_per_column_metadata_blocks) { + auto meta_blocks = ComputePerColumnMetadataBlocks(); + + vector computed_extra_metadata_blocks; + meta_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { computed_extra_metadata_blocks.emplace_back(block_id); }); + extra_metadata_blocks = computed_extra_metadata_blocks; + has_metadata_blocks = true; + return result; + } + + D_ASSERT(has_per_column_metadata_blocks && !GetCollection().SupportsPerColumnWrites()); + // we loaded column-level metadata from disk, but don't support writing it anymore, so we need to fall back to a + // full checkpoint as we need to write out the metadata in a single go } + + // determine which columns can be reused + bool partial_reuse = + can_reuse_metadata && has_per_column_metadata_blocks && GetCollection().SupportsPerColumnWrites(); auto &compression_types = writer.GetCompressionTypes(); - if (columns.size() != compression_types.size()) { - throw InternalException("RowGroup::WriteToDisk - mismatch in column count vs compression types"); + RowGroupWriteData result; + if (partial_reuse) { + result.write_action = RowGroupWriteAction::PARTIALLY_REUSE_COLUMN_METADATA; + } else { + result.write_action = RowGroupWriteAction::FULLY_CHECKPOINT_ROW_GROUP; + } + + auto result_row_group = make_shared_ptr(GetCollection(), this->count); + result_row_group->columns.resize(GetColumnCount()); + result_row_group->column_pointers.resize(GetColumnCount()); + result_row_group->deletes_pointers = deletes_pointers; + result_row_group->deletes_is_loaded = deletes_is_loaded.load(); + result_row_group->owned_version_info = owned_version_info; + result_row_group->version_info = version_info.load(); + if (is_loaded) { + result_row_group->is_loaded = unique_ptr[]>(new atomic[GetColumnCount()]); + for (idx_t c = 0; c < GetColumnCount(); c++) { + result_row_group->is_loaded[c] = true; + } } + + 
RowGroupWriteInfo info(writer.GetPartialBlockManager(), compression_types, writer.GetCheckpointOptions()); + + vector reused_columns; for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - if (!ColumnIsLoaded(column_idx)) { - // don't load a column just for verification - continue; + bool column_has_changes = true; + if (partial_reuse) { + if (!ColumnIsLoaded(column_idx)) { + column_has_changes = false; + } else if (!columns[column_idx]->HasAnyChanges()) { + column_has_changes = false; + } } - auto &column = GetColumn(column_idx); - if (column.count != this->count) { - throw InternalException("Corrupted in-memory column - column with index %llu has misaligned count (row " - "group has %llu rows, column has %llu)", - column_idx, this->count.load(), column.count.load()); + + if (!column_has_changes) { + // reuse this column's metadata + result.states.push_back(nullptr); + reused_columns.emplace_back(column_idx); + result_row_group->column_pointers[column_idx] = column_pointers[column_idx]; + // carry forward existing column data and statistics + if (!ColumnIsLoaded(column_idx)) { + result_row_group->columns[column_idx] = nullptr; + result_row_group->is_loaded[column_idx] = false; + // column not loaded - stats will be merged from previous table stats during Checkpoint + result.statistics.push_back(BaseStatistics::CreateEmpty(GetCollection().GetTypes()[column_idx])); + } else { + result_row_group->columns[column_idx] = columns[column_idx]; + auto col_stats = columns[column_idx]->GetStatistics(); + result.statistics.push_back(col_stats + ? 
std::move(*col_stats) + : BaseStatistics::CreateEmpty(GetCollection().GetTypes()[column_idx])); + } + } else { + // checkpoint this column + auto &column = GetColumn(column_idx); + if (column.count != this->count) { + throw InternalException("Corrupted in-memory column - column with index %llu has misaligned count " + "(row group has %llu rows, column has %llu)", + column_idx, this->count.load(), column.count.load()); + } + result_row_group->columns[column_idx] = CheckpointColumn(*this, column_idx, info, result); } } - RowGroupWriteInfo info(writer.GetPartialBlockManager(), compression_types, writer.GetCheckpointOptions()); - return WriteToDisk(info); + + if (!reused_columns.empty()) { + D_ASSERT(partial_reuse); + // carry forward the extras for reused columns onto the new row group, so RowGroup::Checkpoint + // can look them up via this->per_column_metadata_blocks + auto extras = per_column_metadata_blocks.GetBlocksForColumns(reused_columns); + for (idx_t i = 0; i < reused_columns.size(); i++) { + result_row_group->per_column_metadata_blocks.AddColumn(reused_columns[i], extras[i]); + } + result_row_group->has_per_column_metadata_blocks = true; + } else if (partial_reuse) { + // we planned to partially re-use column metadata, but every column ended up being rewritten - + // downgrade to a full checkpoint as there is no column metadata to carry forward + result.write_action = RowGroupWriteAction::FULLY_CHECKPOINT_ROW_GROUP; + } + + result.result_row_group = std::move(result_row_group); + return result; } void IncrementSegmentStart(PersistentColumnData &data, idx_t start_increment) { @@ -1340,26 +1546,28 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite // construct the row group pointer and write the column meta data to disk row_group_pointer.row_start = row_group_start; row_group_pointer.tuple_count = count; - if (write_data.reuse_existing_metadata_blocks) { + if (write_data.write_action == 
RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA) { // we are re-using the previous metadata row_group_pointer.data_pointers = column_pointers; - row_group_pointer.has_metadata_blocks = true; - row_group_pointer.extra_metadata_blocks = write_data.existing_extra_metadata_blocks; + row_group_pointer.has_metadata_blocks = has_metadata_blocks; + row_group_pointer.extra_metadata_blocks = extra_metadata_blocks; + row_group_pointer.has_per_column_metadata_blocks = has_per_column_metadata_blocks; + row_group_pointer.per_column_metadata_blocks = per_column_metadata_blocks; if (metadata_manager) { row_group_pointer.deletes_pointers = CheckpointDeletes(writer); - vector extra_metadata_block_pointers; - extra_metadata_block_pointers.reserve(write_data.existing_extra_metadata_blocks.size()); - for (auto &block_pointer : write_data.existing_extra_metadata_blocks) { - extra_metadata_block_pointers.emplace_back(block_pointer, 0); + vector metadata_block_pointers_to_be_cleared; + if (has_per_column_metadata_blocks) { + per_column_metadata_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { metadata_block_pointers_to_be_cleared.emplace_back(block_id, 0); }); + } else { + metadata_block_pointers_to_be_cleared.reserve(extra_metadata_blocks.size()); + for (auto &block_pointer : extra_metadata_blocks) { + metadata_block_pointers_to_be_cleared.emplace_back(block_pointer, 0); + } } metadata_manager->ClearModifiedBlocks(column_pointers); - metadata_manager->ClearModifiedBlocks(extra_metadata_block_pointers); - metadata_manager->ClearModifiedBlocks(deletes_pointers); - - // remember metadata_blocks to avoid loading them on future checkpoints - has_metadata_blocks = true; - extra_metadata_blocks = row_group_pointer.extra_metadata_blocks; + metadata_manager->ClearModifiedBlocks(metadata_block_pointers_to_be_cleared); } // merge row group stats into the global stats auto lock = global_stats.GetLock(); @@ -1373,27 +1581,57 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData 
write_data, RowGroupWrite } return row_group_pointer; } - D_ASSERT(write_data.states.size() == columns.size()); + // write path: write column metadata to disk (with optional per-column reuse) + D_ASSERT(write_data.states.size() == GetColumnCount()); + + // merge stats { auto lock = global_stats.GetLock(); for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - global_stats.GetStats(*lock, column_idx).Statistics().Merge(write_data.statistics[column_idx]); + bool is_reused = !write_data.states[column_idx]; + if (is_reused) { + if (!ColumnIsLoaded(column_idx) && + collection.get().GetTypes()[column_idx].id() != LogicalTypeId::VARIANT) { + writer.SetHasUnloadedColumn(column_idx); + continue; + } + GetColumn(column_idx).MergeIntoStatistics(global_stats.GetStats(*lock, column_idx).Statistics()); + } else { + global_stats.GetStats(*lock, column_idx).Statistics().Merge(write_data.statistics[column_idx]); + } } } - vector column_metadata; - unordered_set metadata_blocks; - writer.StartWritingColumns(column_metadata); auto serialization_options = SerializationOptions(writer.GetAttachedDatabase()); - for (auto &state : write_data.states) { - // get the current position of the table data writer + // collect blocks that need to be preserved for reused columns + vector reused_column_blocks; + // per-column extras for newly written columns, collected separately and merged with the + // partial map of reused-column extras (carried over onto this row group by WriteToDisk). 
+ PerColumnMetadataBlocks written_column_blocks; + + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + bool is_reused = !write_data.states[column_idx]; + if (is_reused) { + // reuse existing column pointer (extras are already carried on this->per_column_metadata_blocks) + auto col_ptr = column_pointers[column_idx]; + row_group_pointer.data_pointers.push_back(col_ptr); + reused_column_blocks.push_back(col_ptr); + continue; + } + // write new metadata for this column + auto &state = write_data.states[column_idx]; + D_ASSERT(state); + + // track blocks written for this column + vector col_written_blocks; + writer.StartWritingColumns(col_written_blocks); + auto &data_writer = writer.GetPayloadWriter(); auto pointer = writer.GetMetaBlockPointer(); // store the stats and the data pointers in the row group pointers row_group_pointer.data_pointers.push_back(pointer); - metadata_blocks.insert(pointer.block_pointer); // Write pointers to the column segments. // @@ -1408,32 +1646,59 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite serializer.Begin(); persistent_data.Serialize(serializer); serializer.End(); + + writer.FinishWritingColumns(); + + // collect per-column extra blocks (excluding the start block) + vector col_extra_blocks; + for (auto &written_ptr : col_written_blocks) { + if (written_ptr.block_pointer != pointer.block_pointer) { + col_extra_blocks.push_back(written_ptr.block_pointer); + } + } + written_column_blocks.AddColumn(column_idx, col_extra_blocks); } - writer.FinishWritingColumns(); - row_group_pointer.has_metadata_blocks = true; - for (auto &column_pointer : column_metadata) { - auto entry = metadata_blocks.find(column_pointer.block_pointer); - if (entry != metadata_blocks.end()) { - // this metadata block is already stored in "data_pointers" - no need to duplicate it - continue; + if (GetCollection().SupportsPerColumnWrites()) { + row_group_pointer.has_per_column_metadata_blocks = true; + if 
(write_data.write_action == RowGroupWriteAction::PARTIALLY_REUSE_COLUMN_METADATA) { + // merge reused-column extras (on this) with newly-written-column extras + D_ASSERT(has_per_column_metadata_blocks); + row_group_pointer.per_column_metadata_blocks = + PerColumnMetadataBlocks::Merge(per_column_metadata_blocks, written_column_blocks); + + // reused column blocks must be preserved by ClearModifiedBlocks + per_column_metadata_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { reused_column_blocks.emplace_back(block_id, 0); }); + } else { + row_group_pointer.per_column_metadata_blocks = written_column_blocks; } - // this metadata block is not stored - add it to the extra metadata blocks - row_group_pointer.extra_metadata_blocks.push_back(column_pointer.block_pointer); - metadata_blocks.insert(column_pointer.block_pointer); + row_group_pointer.has_metadata_blocks = false; + row_group_pointer.extra_metadata_blocks.clear(); + } else { + // Per-column reuse is not supported, so instead flatten the newly-written per-column extras into + // extra_metadata_blocks to still allow reusing the full row-group metadata on future checkpoints. 
+ D_ASSERT(write_data.write_action != RowGroupWriteAction::PARTIALLY_REUSE_COLUMN_METADATA); + row_group_pointer.has_per_column_metadata_blocks = false; + row_group_pointer.per_column_metadata_blocks = {}; + row_group_pointer.has_metadata_blocks = true; + row_group_pointer.extra_metadata_blocks.clear(); + written_column_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { row_group_pointer.extra_metadata_blocks.push_back(block_id); }); } + if (metadata_manager) { row_group_pointer.deletes_pointers = CheckpointDeletes(writer); + metadata_manager->ClearModifiedBlocks(reused_column_blocks); } - // set up the pointers correctly within this row group for future operations + + // cache metadata pointers for future checkpoint reuse column_pointers = row_group_pointer.data_pointers; - has_metadata_blocks = true; + has_metadata_blocks = row_group_pointer.has_metadata_blocks; extra_metadata_blocks = row_group_pointer.extra_metadata_blocks; - for (idx_t c = 0; c < columns.size(); c++) { - if (!write_data.keep_column_loaded[c]) { - UnloadColumn(c); - } - } + has_per_column_metadata_blocks = row_group_pointer.has_per_column_metadata_blocks; + per_column_metadata_blocks = row_group_pointer.per_column_metadata_blocks; + Verify(); return row_group_pointer; } @@ -1497,15 +1762,26 @@ vector RowGroup::CheckpointDeletes(RowGroupWriter &writer) { return vinfo->Checkpoint(writer); } -void RowGroup::Serialize(RowGroupPointer &pointer, Serializer &serializer) { +void RowGroup::Serialize(RowGroupPointer &pointer, Serializer &serializer, bool supports_per_column_writes) { serializer.WriteProperty(100, "row_start", pointer.row_start); serializer.WriteProperty(101, "tuple_count", pointer.tuple_count); serializer.WriteProperty(102, "data_pointers", pointer.data_pointers); serializer.WriteProperty(103, "delete_pointers", pointer.deletes_pointers); - if (serializer.ShouldSerialize(6)) { + if (serializer.ShouldSerialize(StorageVersion::V1_4_0) && !supports_per_column_writes) { 
serializer.WriteProperty(104, "has_metadata_blocks", pointer.has_metadata_blocks); serializer.WritePropertyWithDefault(105, "extra_metadata_blocks", pointer.extra_metadata_blocks); } + if (supports_per_column_writes) { + D_ASSERT(serializer.ShouldSerialize(StorageVersion::V1_4_0)); + // also write legacy metadata blocks for v1.4 and v1.5 + serializer.WriteProperty(104, "has_metadata_blocks", pointer.has_per_column_metadata_blocks); + vector extra_metadata_block_ids; + pointer.per_column_metadata_blocks.ForEachBlock( + [&](idx_t, idx_t block_id) { extra_metadata_block_ids.push_back(block_id); }); + serializer.WritePropertyWithDefault(105, "extra_metadata_blocks", extra_metadata_block_ids); + serializer.WriteProperty(106, "has_per_column_metadata_blocks", pointer.has_per_column_metadata_blocks); + serializer.WritePropertyWithDefault(107, "per_column_metadata_blocks", pointer.per_column_metadata_blocks.data); + } } RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { @@ -1516,6 +1792,15 @@ RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { result.deletes_pointers = deserializer.ReadProperty>(103, "delete_pointers"); result.has_metadata_blocks = deserializer.ReadPropertyWithExplicitDefault(104, "has_metadata_blocks", false); result.extra_metadata_blocks = deserializer.ReadPropertyWithDefault>(105, "extra_metadata_blocks"); + result.has_per_column_metadata_blocks = + deserializer.ReadPropertyWithExplicitDefault(106, "has_per_column_metadata_blocks", false); + result.per_column_metadata_blocks = { + deserializer.ReadPropertyWithDefault>(107, "per_column_metadata_blocks")}; + if (result.has_per_column_metadata_blocks) { + // per-column metadata supersedes legacy extra_metadata_blocks + result.has_metadata_blocks = false; + result.extra_metadata_blocks.clear(); + } return result; } @@ -1538,14 +1823,7 @@ struct DuckDBPartitionRowGroup : public PartitionRowGroup { if (!is_exact || row_group->HasChanges()) { return false; } - if 
(stats.GetStatsType() == StatisticsType::STRING_STATS) { - if (!StringStats::HasMinMax(stats) || !StringStats::HasMaxStringLength(stats)) { - return false; - } - const idx_t max_length = StringStats::MaxStringLength(stats); - return max_length == StringStats::Max(stats).length() && max_length == StringStats::Min(stats).length(); - } - return stats.GetStatsType() == StatisticsType::NUMERIC_STATS; + return true; } }; @@ -1571,10 +1849,14 @@ PartitionStatistics RowGroup::GetPartitionStats(SegmentNode &row_group // GetColumnSegmentInfo //===--------------------------------------------------------------------===// void RowGroup::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, - vector &result) { + vector &result, const ColumnSegmentInfoScanOptions &options) { for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { + if (options.loaded_segments_only && !ColumnIsLoaded(col_idx)) { + // column is not loaded - skip it + continue; + } auto &col_data = GetColumn(col_idx); - col_data.GetColumnSegmentInfo(context, row_group_index, {col_idx}, result); + col_data.GetColumnSegmentInfo(context, row_group_index, {col_idx}, result, options); } } @@ -1619,8 +1901,13 @@ idx_t RowGroup::Delete(TransactionData transaction, DuckTableEntry &table_entry, void RowGroup::Verify() { #ifdef DEBUG - for (auto &column : GetColumns()) { - column->Verify(*this); + for (idx_t c = 0; c < columns.size(); c++) { + if (!ColumnIsLoaded(c)) { + continue; + } + if (columns[c]) { + columns[c]->Verify(*this); + } } lock_guard guard(row_group_lock); if (row_id_is_loaded) { diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index ba0def7c7..7b37aa6a5 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -7,6 +7,8 @@ #include "duckdb/execution/index/art/art.hpp" #include "duckdb/execution/index/bound_index.hpp" #include 
"duckdb/main/client_context.hpp" +#include "duckdb/main/profiling_utils.hpp" +#include "duckdb/main/query_profiler.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/parallel/task_executor.hpp" #include "duckdb/planner/constraints/bound_not_null_constraint.hpp" @@ -22,10 +24,24 @@ #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/common/storage_compatibility.hpp" #include "duckdb/common/type_visitor.hpp" namespace duckdb { +static bool CanRebuildExistingIndexesAfterVacuum(DataTableInfo &info, AttachedDatabase &attached, idx_t total_rows) { + auto &indexes = info.GetIndexes(); + if (indexes.Empty() || indexes.HasUnbound()) { + return false; + } + auto vacuum_rebuild_threshold = attached.GetVacuumRebuildIndexThreshold(); + if (vacuum_rebuild_threshold == 0 || total_rows > vacuum_rebuild_threshold) { + return false; + } + auto index_types = indexes.DistinctIndexTypes(); + return index_types.size() == 1 && index_types.count(ART::TYPE_NAME); +} + //===--------------------------------------------------------------------===// // Row Group Segment Tree //===--------------------------------------------------------------------===// @@ -44,18 +60,19 @@ void RowGroupSegmentTree::Initialize(PersistentTableData &data, optional_ptr RowGroupSegmentTree::LoadSegment() const { +optional> RowGroupSegmentTree::LoadSegment() const { if (current_row_group >= max_row_group) { reader.reset(); finished_loading = true; - return nullptr; + return nullopt; } BinaryDeserializer deserializer(*reader); deserializer.Begin(); auto row_group_pointer = RowGroup::Deserialize(deserializer); deserializer.End(); current_row_group++; - return make_shared_ptr(collection, std::move(row_group_pointer)); + auto row_start = row_group_pointer.row_start; + return LoadedSegment(make_shared_ptr(collection, std::move(row_group_pointer)), row_start); } 
//===--------------------------------------------------------------------===// @@ -70,9 +87,10 @@ RowGroupCollection::RowGroupCollection(shared_ptr info_p, TableIO RowGroupCollection::RowGroupCollection(shared_ptr info_p, BlockManager &block_manager, vector types_p, idx_t row_start, idx_t total_rows_p, idx_t row_group_size_p) - : block_manager(block_manager), row_group_size(row_group_size_p), total_rows(total_rows_p), info(std::move(info_p)), - types(std::move(types_p)), owned_row_groups(make_shared_ptr(*this, row_start)), - allocation_size(0), row_group_append_mode(RowGroupAppendMode::APPEND_TO_EXISTING) { + : block_manager(block_manager), row_group_size(row_group_size_p), total_rows(total_rows_p), + next_row_id(total_rows_p), info(std::move(info_p)), types(std::move(types_p)), + owned_row_groups(make_shared_ptr(*this, row_start)), allocation_size(0), + row_group_append_mode(RowGroupAppendMode::APPEND_TO_EXISTING) { // If the table contains shredded types (variant / geometry) then we can't append to an existing row group for (auto &type : types) { if (TypeVisitor::Contains(type, LogicalTypeId::VARIANT) || @@ -87,6 +105,10 @@ idx_t RowGroupCollection::GetTotalRows() const { return total_rows.load(); } +idx_t RowGroupCollection::GetNextRowId() const { + return next_row_id.load(); +} + idx_t RowGroupCollection::GetRowGroupCount() const { auto row_groups = GetRowGroups(); return row_groups->GetSegmentCount(); @@ -129,6 +151,8 @@ void RowGroupCollection::Initialize(PersistentTableData &data) { D_ASSERT(owned_row_groups->GetBaseRowId() == 0); auto l = owned_row_groups->Lock(); this->total_rows = data.total_rows; + this->next_row_id = data.next_row_id; + D_ASSERT(this->next_row_id >= this->total_rows); metadata_pointer = data.base_table_pointer; metadata_pointers = data.read_metadata_pointers; owned_row_groups->Initialize(data, metadata_pointers); @@ -144,12 +168,15 @@ void RowGroupCollection::FinalizeCheckpoint(MetaBlockPointer pointer, void 
RowGroupCollection::Initialize(PersistentCollectionData &data) { stats.InitializeEmpty(types); auto l = owned_row_groups->Lock(); + auto base_row_id = owned_row_groups->GetBaseRowId(); for (auto &row_group_data : data.row_group_data) { + D_ASSERT(row_group_data.start == base_row_id + total_rows.load()); auto row_group = make_uniq(*this, row_group_data); row_group->MergeIntoStatistics(stats); total_rows += row_group->count; owned_row_groups->AppendSegment(l, std::move(row_group), row_group_data.start); } + next_row_id = total_rows.load(); } void RowGroupCollection::SetRowGroupAppendMode(RowGroupAppendMode mode) { @@ -205,15 +232,18 @@ void RowGroupCollection::Verify() { #ifdef DEBUG idx_t current_total_rows = 0; auto row_groups = GetRowGroups(); - row_groups->Verify(); + row_groups->Verify(SegmentTreeVerifyMode::NON_OVERLAPPING); + idx_t current_rowid_end = row_groups->GetBaseRowId(); for (auto &entry : row_groups->SegmentNodes()) { auto &row_group = entry.GetNode(); row_group.Verify(); D_ASSERT(&row_group.GetCollection() == this); - D_ASSERT(entry.GetRowStart() == row_groups->GetBaseRowId() + current_total_rows); + D_ASSERT(entry.GetRowStart() >= current_rowid_end); current_total_rows += row_group.count; + current_rowid_end = entry.GetRowStart() + row_group.count; } D_ASSERT(current_total_rows == total_rows.load()); + D_ASSERT(row_groups->GetBaseRowId() + next_row_id.load() == current_rowid_end); #endif } @@ -226,7 +256,7 @@ void RowGroupCollection::InitializeScan(const QueryContext &context, CollectionS state.row_groups = GetRowGroups(); auto row_group = state.GetRootSegment(); D_ASSERT(row_group); - state.max_row = state.row_groups->GetBaseRowId() + total_rows; + state.max_row = state.row_groups->GetBaseRowId() + next_row_id.load(); state.Initialize(context, GetTypes()); while (row_group && !row_group->GetNode().InitializeScan(state, *row_group)) { row_group = state.GetNextRowGroup(*row_group); @@ -267,9 +297,9 @@ bool 
RowGroupCollection::InitializeScanInRowGroup(ClientContext &context, Collec void RowGroupCollection::InitializeParallelScan(ParallelCollectionScanState &state) { state.collection = this; state.row_groups = GetRowGroups(); - state.current_row_group = state.GetRootSegment(*state.row_groups); + state.AssignRowGroup(state.GetRootSegment(*state.row_groups)); state.vector_index = 0; - state.max_row = state.row_groups->GetBaseRowId() + total_rows; + state.max_row = state.row_groups->GetBaseRowId() + next_row_id.load(); state.batch_index = 0; state.processed_rows = 0; } @@ -303,14 +333,14 @@ bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollec D_ASSERT(vector_index * STANDARD_VECTOR_SIZE < current_row_group.count); state.vector_index++; if (state.vector_index * STANDARD_VECTOR_SIZE >= current_row_group.count) { - state.current_row_group = state.GetNextRowGroup(*state.row_groups, *row_group).get(); + state.AssignRowGroup(state.GetNextRowGroup(*state.row_groups, *row_group).get()); state.vector_index = 0; } } else { state.processed_rows += current_row_group.count; vector_index = 0; max_row = row_start + current_row_group.count; - state.current_row_group = state.GetNextRowGroup(*state.row_groups, *row_group).get(); + state.AssignRowGroup(state.GetNextRowGroup(*state.row_groups, *row_group).get()); } max_row = MinValue(max_row, state.max_row); scan_state.batch_index = ++state.batch_index; @@ -429,32 +459,78 @@ RowGroupIterationHelper RowGroupCollection::Chunks(DuckTransaction &transaction, //===--------------------------------------------------------------------===// void RowGroupCollection::Fetch(TransactionData transaction, DataChunk &result, const vector &column_ids, const Vector &row_identifiers, idx_t fetch_count, ColumnFetchState &state) { - // figure out which row_group to fetch from + if (fetch_count == 0) { + result.SetChildCardinality(0); + return; + } + ALWAYS_ASSERT(fetch_count <= STANDARD_VECTOR_SIZE); + auto row_ids = 
FlatVector::GetData(row_identifiers); - idx_t count = 0; auto row_groups = GetRowGroups(); - for (idx_t i = 0; i < fetch_count; i++) { - auto row_id = row_ids[i]; + idx_t count = 0; + + // Stack-allocated scratch buffers reused across runs/iterations within this call. + // + // Row positions in the row group. + idx_t offsets[STANDARD_VECTOR_SIZE]; + // Visible row positions in the row group. + sel_t visible_sel_buffer[STANDARD_VECTOR_SIZE]; + // Filter selection vector. + SelectionVector filter_sel(visible_sel_buffer, STANDARD_VECTOR_SIZE); + + idx_t pos = 0; + while (pos < fetch_count) { + // 1. resolve the row group containing row_ids[pos] optional_ptr> row_group; { idx_t segment_index; auto l = row_groups->Lock(); - if (!row_groups->TryGetSegmentIndex(l, UnsafeNumericCast(row_id), segment_index)) { - // in parallel append scenarios it is possible for the row_id + if (!row_groups->TryGetSegmentIndex(l, NumericCast(row_ids[pos]), segment_index)) { + // row not yet visible, skip + pos++; continue; } row_group = row_groups->GetSegmentByIndex(l, UnsafeNumericCast(segment_index)); } auto ¤t_row_group = row_group->GetNode(); - auto offset_in_row_group = UnsafeNumericCast(row_id) - row_group->GetRowStart(); - if (state.fetch_type == FetchType::TRANSACTIONAL_FETCH && - !current_row_group.Fetch(transaction, offset_in_row_group)) { + const idx_t row_start = row_group->GetRowStart(); + const idx_t row_end = row_start + current_row_group.count; + + // 2. extend the run while consecutive row-ids stay in [row_start, row_end) + const idx_t run_start = pos; + offsets[0] = NumericCast(row_ids[pos]) - row_start; + pos++; + while (pos < fetch_count) { + const idx_t rid = NumericCast(row_ids[pos]); + if (rid < row_start || rid >= row_end) { + break; + } + offsets[pos - run_start] = rid - row_start; + pos++; + } + const idx_t run_count = pos - run_start; + + // 3. bulk visibility check for the whole run. 
+ idx_t visible_count = 0; + const_reference sel_for_fetch(*FlatVector::IncrementalSelectionVector()); + if (state.fetch_type == FetchType::FORCE_FETCH) { + visible_count = run_count; + } else { + visible_count = current_row_group.Fetch(transaction, offsets, run_count, filter_sel); + if (visible_count != run_count) { + sel_for_fetch = filter_sel; + } + } + + if (visible_count == 0) { continue; } + + // 4. bulk per-column fetch state.row_group = row_group; - current_row_group.FetchRow(transaction, state, column_ids, UnsafeNumericCast(offset_in_row_group), - result, count); - count++; + current_row_group.FetchRows(transaction, state, column_ids, offsets, sel_for_fetch.get(), visible_count, result, + count); + count += visible_count; } result.SetChildCardinality(count); } @@ -472,7 +548,8 @@ bool RowGroupCollection::CanFetch(TransactionData transaction, const row_t row_i } auto ¤t_row_group = row_group->GetNode(); auto offset_in_row_group = UnsafeNumericCast(row_id) - row_group->GetRowStart(); - return current_row_group.Fetch(transaction, offset_in_row_group); + SelectionVector visible_sel(1); + return current_row_group.Fetch(transaction, &offset_in_row_group, /*count=*/1, visible_sel) == 1; } //===--------------------------------------------------------------------===// @@ -493,7 +570,9 @@ bool RowGroupCollection::IsEmpty() const { } void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppendState &state) { - state.row_start = UnsafeNumericCast(total_rows.load()); + auto next_row_id = this->next_row_id.load(); + D_ASSERT(next_row_id >= total_rows.load()); + state.row_start = UnsafeNumericCast(next_row_id); state.current_row = state.row_start; state.total_append_count = 0; @@ -504,22 +583,25 @@ void RowGroupCollection::InitializeAppend(TransactionData transaction, TableAppe bool needs_new_row_group = state.row_groups->IsEmpty(l) || row_group_append_mode == RowGroupAppendMode::REQUIRE_NEW; // Otherwise we evaluate the row_group_append_mode if 
(!needs_new_row_group) { - if (info->GetIndexes().Empty()) { - // We honor SUGGEST_NEW unless the table has indexes because there is no vacuuming for indexed tables... + auto last_row_group = state.row_groups->GetLastSegment(l); + D_ASSERT(last_row_group->GetRowEnd() == state.row_groups->GetBaseRowId() + next_row_id); + if (info->GetIndexes().Empty() || CanRebuildExistingIndexesAfterVacuum(*info, GetAttached(), GetTotalRows())) { + // Honor SUGGEST_NEW if vacuum can compact the table later, either because there are no indexes or because + // the existing indexes can be rebuilt after vacuuming. needs_new_row_group = row_group_append_mode == RowGroupAppendMode::SUGGEST_NEW; } else { - // ... and if it has indexes we will ignore row_group_append_mode and try to append, unless the last row - // group is full already. - needs_new_row_group = row_group_size < state.row_groups->GetLastSegment(l)->GetNode().count; + // If the table has indexes that vacuum cannot rebuild, ignore row_group_append_mode and try to append, + // unless the last row group is full already. 
+ needs_new_row_group = row_group_size < last_row_group->GetNode().count; } } if (needs_new_row_group) { - AppendRowGroup(l, state.row_groups->GetBaseRowId() + total_rows); + AppendRowGroup(l, state.row_groups->GetBaseRowId() + next_row_id); } state.start_row_group = state.row_groups->GetLastSegment(l); - D_ASSERT(state.row_groups->GetBaseRowId() + total_rows == + D_ASSERT(state.row_groups->GetBaseRowId() + next_row_id == state.start_row_group->GetRowStart() + state.start_row_group->GetNode().count); - state.start_row_group->GetNode().InitializeAppend(state.row_group_append_state); + RowGroup::InitializeAppend(*state.start_row_group, state.row_group_append_state); state.transaction = transaction; state.row_group_start = state.start_row_group->GetRowStart(); @@ -533,30 +615,31 @@ void RowGroupCollection::InitializeAppend(TableAppendState &state) { InitializeAppend(tdata, state); } -bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { +optional_idx RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { + const idx_t row_group_size = GetRowGroupSize(); D_ASSERT(chunk.ColumnCount() == types.size()); chunk.Verify(GetDatabase()); - bool new_row_group = false; + optional_idx flushed_row_group_idx; idx_t total_append_count = chunk.size(); idx_t remaining = chunk.size(); state.total_append_count += total_append_count; while (true) { - auto current_row_group = state.row_group_append_state.row_group; + auto ¤t_row_group = state.row_group_append_state.row_group->GetNode(); // check how much we can fit into the current row_group idx_t append_count = MinValue(remaining, row_group_size - state.row_group_append_state.offset_in_row_group); if (append_count > 0) { - auto previous_allocation_size = current_row_group->GetAllocationSize(); - current_row_group->Append(state.row_group_append_state, chunk, append_count); - allocation_size += current_row_group->GetAllocationSize() - previous_allocation_size; - // merge the stats - 
current_row_group->MergeIntoStatistics(stats); + auto previous_allocation_size = current_row_group.GetAllocationSize(); + current_row_group.Append(state.row_group_append_state, chunk, append_count); + allocation_size += current_row_group.GetAllocationSize() - previous_allocation_size; } remaining -= append_count; if (remaining == 0) { break; } + // finalize the append state for the current row group + current_row_group.FinalizeAppend(state.row_group_append_state); // we expect max 1 iteration of this loop (i.e. a single chunk should never overflow more than one // row_group) D_ASSERT(chunk.size() == remaining + append_count); @@ -565,14 +648,14 @@ bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { chunk.Slice(append_count, remaining); } // append a new row_group - new_row_group = true; + flushed_row_group_idx = state.row_group_append_state.row_group->GetIndex(); auto next_start = state.row_group_start + state.row_group_append_state.offset_in_row_group; auto l = state.row_groups->Lock(); AppendRowGroup(l, next_start); // set up the append state for this row_group auto last_row_group = state.row_groups->GetLastSegment(l); - last_row_group->GetNode().InitializeAppend(state.row_group_append_state); + RowGroup::InitializeAppend(*last_row_group, state.row_group_append_state); state.row_group_start = next_start; } state.current_row += row_t(total_append_count); @@ -584,10 +667,15 @@ bool RowGroupCollection::Append(DataChunk &chunk, TableAppendState &state) { column_stats.UpdateDistinctStatistics(chunk.data[col_idx], chunk.size(), state.hashes); } - return new_row_group; + return flushed_row_group_idx; } void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppendState &state) { + // first finalize the append of the final row group we appended to + auto &last_row_group = state.row_group_append_state.row_group->GetNode(); + last_row_group.FinalizeAppend(state.row_group_append_state); + + // now push version info into all row 
groups auto remaining = state.total_append_count; auto row_group = state.start_row_group; while (remaining > 0) { @@ -598,6 +686,8 @@ void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppend row_group = state.row_groups->GetNextSegment(*row_group); } total_rows += state.total_append_count; + next_row_id += state.total_append_count; + D_ASSERT(next_row_id.load() >= total_rows.load()); state.total_append_count = 0; state.start_row_group = nullptr; @@ -606,14 +696,8 @@ void RowGroupCollection::FinalizeAppend(TransactionData transaction, TableAppend auto global_stats_lock = stats.GetLock(); for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { auto &global_stats = stats.GetStats(*global_stats_lock, col_idx); - if (!global_stats.HasDistinctStats()) { - continue; - } auto &local_stats = state.stats.GetStats(*local_stats_lock, col_idx); - if (!local_stats.HasDistinctStats()) { - continue; - } - global_stats.DistinctStats().Merge(local_stats.DistinctStats()); + global_stats.Merge(local_stats); } Verify(); @@ -654,8 +738,10 @@ void RowGroupCollection::RevertAppendInternal(idx_t new_end_idx) { if (last_segment->GetRowEnd() <= new_end_idx) { return; } + D_ASSERT(new_end_idx >= row_groups->GetBaseRowId()); auto reverted_row_groups = make_shared_ptr(*this, row_groups->GetBaseRowId()); auto rlock = reverted_row_groups->Lock(); + idx_t new_total_rows = 0; for (auto &entry : row_groups->SegmentNodes(l)) { idx_t row_start = entry.GetRowStart(); idx_t row_end = row_start + entry.GetCount(); @@ -668,10 +754,13 @@ void RowGroupCollection::RevertAppendInternal(idx_t new_end_idx) { // this is the last row group - have to revert WITHIN it entry.GetNode().RevertAppend(new_end_idx - row_start); } - reverted_row_groups->AppendSegment(rlock, entry.ReferenceNode()); + new_total_rows += entry.GetNode().count; + reverted_row_groups->AppendSegment(rlock, entry.ReferenceNode(), row_start); } SetRowGroups(std::move(reverted_row_groups)); - total_rows = new_end_idx; + 
total_rows = new_total_rows; + next_row_id = new_end_idx - row_groups->GetBaseRowId(); + D_ASSERT(next_row_id.load() >= total_rows.load()); } void RowGroupCollection::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_t count) { @@ -709,10 +798,12 @@ bool RowGroupCollection::IsPersistent() const { void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptr table, optional_ptr commit_state) { D_ASSERT(data.types == types); - auto segments = data.GetRowGroups()->MoveSegments(); + auto source_row_groups = data.GetRowGroups(); + auto segments = source_row_groups->MoveSegments(); auto row_groups = GetRowGroups(); - auto start_index = row_groups->GetBaseRowId() + total_rows.load(); - auto index = start_index; + D_ASSERT(next_row_id.load() >= total_rows.load()); + auto target_base_row_id = row_groups->GetBaseRowId(); + auto start_index = target_base_row_id + next_row_id.load(); // check if the row groups we are merging are optimistically written // if all row groups are optimistically written we keep around the block pointers @@ -737,25 +828,35 @@ void RowGroupCollection::MergeStorage(RowGroupCollection &data, optional_ptrGetNode().IsPersistent(); + idx_t merged_count = 0; + idx_t source_offset = 0; + idx_t target_row_start = start_index; for (auto &entry : segments) { + D_ASSERT(entry->GetRowStart() == source_row_groups->GetBaseRowId() + source_offset); auto row_group = entry->MoveNode(); row_group->MoveToCollection(*this); + idx_t row_group_count = row_group->count; - if (commit_state && (index - start_index) < optimistically_written_count) { + if (commit_state && merged_count < optimistically_written_count) { // serialize the block pointers of this row group - auto persistent_data = row_group->SerializeRowGroupInfo(index); + auto persistent_data = row_group->SerializeRowGroupInfo(target_row_start); persistent_data.types = types; row_group_data->row_group_data.push_back(std::move(persistent_data)); } - index += row_group->count; - 
row_groups->AppendSegment(std::move(row_group)); + merged_count += row_group_count; + source_offset += row_group_count; + row_groups->AppendSegment(std::move(row_group), target_row_start); + target_row_start += row_group_count; } if (commit_state && optimistically_written_count > 0) { // if we have serialized the row groups - push the serialized block pointers into the commit state commit_state->AddRowGroupData(*table, start_index, optimistically_written_count, std::move(row_group_data)); } stats.MergeStats(data.stats); + D_ASSERT(source_offset == data.total_rows.load()); + D_ASSERT(data.next_row_id.load() == data.total_rows.load()); total_rows += data.total_rows.load(); + next_row_id = target_row_start - target_base_row_id; if (is_persistent) { SetRowGroupAppendMode(RowGroupAppendMode::SUGGEST_NEW); } @@ -841,7 +942,11 @@ void RowGroupCollection::Update(TransactionData transaction, DuckTableEntry &tab auto l = stats.GetLock(); for (idx_t i = 0; i < column_ids.size(); i++) { auto column_id = column_ids[i]; - stats.MergeStats(*l, column_id.index, *current_row_group.GetStatistics(column_id.index)); + // Use EXPAND_BOUNDS here: the row group stats include original data already counted in collection stats, + // so additive stats (like total_string_length) cannot be maintained correctly here. EXPAND_BOUNDS + // correctly expands min/max bounds while invalidating total_string_length until the next checkpoint. + stats.MergeStats(*l, column_id.index, *current_row_group.GetStatistics(column_id.index), + StatsMergeType::EXPAND_BOUNDS); } } while (pos < updates.size()); } @@ -974,13 +1079,11 @@ void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableInd // If we are in WAL replay, delete data will be buffered, and so we sort the column_ids // since the sorted form will be the mapping used to get back physical IDs from the buffered index chunk. 
- vector column_ids; - for (auto &col : indexed_column_id_set) { - column_ids.emplace_back(col); - } + vector column_ids {indexed_column_id_set.begin(), indexed_column_id_set.end()}; sort(column_ids.begin(), column_ids.end()); vector column_types; + column_types.reserve(column_ids.size()); for (auto &col : column_ids) { column_types.push_back(types[col.GetPrimaryIndex()]); } @@ -1012,10 +1115,6 @@ void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableInd } result_chunk.data[j].Reference(Value(types[j]), count_t(fetch_chunk.size())); } - result_chunk.SetCardinality(fetch_chunk); - - DataChunk remaining_result_chunk; - unique_ptr remaining_row_ids; for (auto &entry : indexes.IndexEntries()) { auto &index = *entry.index; @@ -1077,7 +1176,6 @@ void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableInd auto col_id = column_ids[i].GetPrimaryIndex(); index_column_chunk.data[i].Reference(result_chunk.data[col_id]); } - index_column_chunk.SetCardinality(result_chunk.size()); auto &unbound_index = index.Cast(); unbound_index.BufferChunk(index_column_chunk, row_identifiers, column_ids, BufferedIndexReplay::DEL_ENTRY); } @@ -1197,8 +1295,6 @@ struct VacuumState { //! Whether we are allowed to rebuild indexes after a vacuum (only true when vacuum_rebuild_indexes //! threshold is set, the table's row count is within the threshold, and all indexes are bound ART's). bool can_rebuild_indexes = false; - //! 
Whether any operation (empty group drop or vacuum merge) actually remapped row IDs - bool row_ids_changed = false; idx_t row_start = 0; idx_t next_vacuum_idx = 0; vector row_group_counts; @@ -1213,20 +1309,26 @@ class VacuumTask : public BaseCheckpointTask { } void ExecuteTask() override { + MetricsTimer timer; + auto context = checkpoint_state.writer.TryGetClientContext(); + if (context) { + timer = QueryProfiler::Get(*context).StartTimer(); + } + auto &collection = checkpoint_state.collection; const idx_t row_group_size = collection.GetRowGroupSize(); auto &types = collection.GetTypes(); + // create the new set of target row groups (initially empty) - vector> new_row_groups; + vector>> new_row_groups; vector append_counts; idx_t row_group_rows = merge_rows; for (idx_t target_idx = 0; target_idx < target_count; target_idx++) { idx_t current_row_group_rows = MinValue(row_group_rows, row_group_size); - auto new_row_group = make_uniq(collection, current_row_group_rows); + auto new_row_group = make_shared_ptr(collection, current_row_group_rows); new_row_group->InitializeEmpty(types, ColumnDataType::MAIN_TABLE); - new_row_groups.push_back(std::move(new_row_group)); + new_row_groups.push_back(make_uniq>(0ULL, std::move(new_row_group), target_idx)); append_counts.push_back(0); - row_group_rows -= current_row_group_rows; } @@ -1242,7 +1344,8 @@ class VacuumTask : public BaseCheckpointTask { // fill the new row group with the merged rows TableAppendState append_state; - new_row_groups[current_append_idx]->InitializeAppend(append_state.row_group_append_state); + auto &initial_append_row_group = *new_row_groups[current_append_idx]; + RowGroup::InitializeAppend(initial_append_row_group, append_state.row_group_append_state); TableScanState scan_state; scan_state.Initialize(column_ids); @@ -1274,17 +1377,22 @@ class VacuumTask : public BaseCheckpointTask { scan_chunk.Flatten(); idx_t remaining = scan_chunk.size(); while (remaining > 0) { + auto ¤t_append_row_group = 
new_row_groups[current_append_idx]->GetNode(); idx_t append_count = MinValue(remaining, row_group_size - append_counts[current_append_idx]); - new_row_groups[current_append_idx]->Append(append_state.row_group_append_state, scan_chunk, - append_count); + current_append_row_group.Append(append_state.row_group_append_state, scan_chunk, append_count); append_counts[current_append_idx] += append_count; remaining -= append_count; const bool row_group_full = append_counts[current_append_idx] == row_group_size; const bool last_row_group = current_append_idx + 1 >= new_row_groups.size(); if (remaining > 0 || (row_group_full && !last_row_group)) { + // finalize the last append + new_row_groups[current_append_idx]->GetNode().FinalizeAppend( + append_state.row_group_append_state); + // move to the next row group current_append_idx++; - new_row_groups[current_append_idx]->InitializeAppend(append_state.row_group_append_state); + RowGroup::InitializeAppend(*new_row_groups[current_append_idx], + append_state.row_group_append_state); // slice chunk for the next append scan_chunk.Slice(append_count, remaining); } @@ -1294,9 +1402,12 @@ class VacuumTask : public BaseCheckpointTask { current_row_group.CommitDrop(); checkpoint_state.DropSegment(c_idx); } + // finalize the final append + new_row_groups[current_append_idx]->GetNode().FinalizeAppend(append_state.row_group_append_state); + idx_t total_append_count = 0; for (idx_t target_idx = 0; target_idx < target_count; target_idx++) { - auto &row_group = new_row_groups[target_idx]; + auto row_group = new_row_groups[target_idx]->MoveNode(); row_group->Verify(); // assign the new row group to the current segment @@ -1309,6 +1420,10 @@ class VacuumTask : public BaseCheckpointTask { "Mismatch in row group count %d vs verify count %d in RowGroupCollection::Checkpoint", merge_rows, total_append_count); } + + // Explicitly end the timer for the vacuum tasks here. 
+ timer.EndTimer(); + // merging is complete - execute checkpoint tasks of the target row groups for (idx_t i = 0; i < target_count; i++) { auto checkpoint_task = collection.GetCheckpointTask(checkpoint_state, segment_idx + i); @@ -1340,73 +1455,91 @@ void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkp return; } - // if there are indexes - we cannot change row-ids - // this limits what kind of vacuuming we can do + // Vacuuming has two independent rowid constraints: + // - indexes may force surviving rowids to remain stable + // - the target storage version can force rowids to be written without gaps + // Indexes store rowids, so rowids can only be changed when there are no indexes, or when all indexes can be + // rebuilt after vacuuming. bool has_indexes = !info->GetIndexes().Empty(); - // *unless* vacuum_rebuild_indexes threshold is set, the table's row count - // is within the threshold, and all indexes are bound ART indexes, - // in which case we allow vacuuming and rebuild the indexes afterward. - auto vacuum_rebuild_threshold = Settings::Get(checkpoint_state.writer.GetDatabase()); - auto index_types = info->GetIndexes().DistinctIndexTypes(); - state.can_rebuild_indexes = has_indexes && !info->GetIndexes().HasUnbound() && index_types.size() == 1 && - index_types.count(ART::TYPE_NAME) && vacuum_rebuild_threshold > 0 && - GetTotalRows() <= vacuum_rebuild_threshold; - - // We can move around rowids if we either 1) don't have any indexes at all or 2) can_rebuild_indexes is true (in - // which case indexes are entirely rebuilt after vacuuming). + // vacuum_rebuild_indexes allows rowid-changing vacuum for indexed tables when the table's row count is within + // the threshold and all indexes are bound ART indexes. + state.can_rebuild_indexes = + CanRebuildExistingIndexesAfterVacuum(*info, checkpoint_state.writer.GetAttached(), GetTotalRows()); + + // can_change_row_ids only answers whether index state allows changing row_ids. 
state.can_change_row_ids = !has_indexes || state.can_rebuild_indexes; + // obtain the set of committed row counts for each row group - auto row_group_count = checkpoint_state.SegmentCount(); - vector committed_counts; - state.row_group_counts.reserve(checkpoint_state.SegmentCount()); - if (checkpoint_row_group_count.IsValid() && checkpoint_row_group_count.GetIndex() > row_group_count) { + auto num_row_groups = checkpoint_state.SegmentCount(); + state.row_group_counts.reserve(num_row_groups); + + if (checkpoint_row_group_count.IsValid() && checkpoint_row_group_count.GetIndex() > num_row_groups) { // we have row groups that were concurrently appended to this collection // don't vacuum - otherwise we can move row groups which could cause committed row-ids to be moved around // while transactions are still processing / depending on them being stable (during e.g. commit) state.can_vacuum_deletes = false; return; } - bool dropped_any_rowgroups = false; + bool legacy_vacuum_with_stable_row_ids = + !checkpoint_state.writer.CanPersistRowIdGaps() && !state.can_change_row_ids; + // RowIdsChanged condition: legacy storage cannot persist rowid gaps. If we are allowed to change rowids, + // we record a gap as seen; if we later see a row group with live rows, then we know rowids have to be shifted. + bool rowid_gap_seen = false; + vector committed_counts; for (auto &entry : checkpoint_state.row_groups.SegmentNodes()) { auto &row_group = entry.GetNode(); - auto row_group_count = row_group.GetCommittedRowCount(); + auto row_group_num_rows = row_group.GetCommittedRowCount(); + if (legacy_vacuum_with_stable_row_ids) { + // In this legacy path, empty row groups in the middle must be kept because they would create a rowid gap + // that cannot be persisted or densely rewritten. Keep the committed counts so the pass below can + // still remove trailing deleted row groups. 
+ committed_counts.emplace_back(row_group_num_rows); + } + if (row_group_num_rows == 0) { + if (!checkpoint_state.writer.CanPersistRowIdGaps()) { + // Older storage versions cannot represent rowid gaps. Dropping a row group in the middle of the + // table therefore requires dense rowid rewriting. + if (!state.can_change_row_ids) { + // Indexes prevent rowid changes, so keep the row group for now. If this is part of a trailing + // deleted suffix, we can still drop it in the legacy_vacuum_with_stable_row_ids pass below. + state.row_group_counts.emplace_back(); + continue; + } + // track this gap until we know whether it is followed by live rows. + rowid_gap_seen = true; + } + // Drop the empty row group. Newer storage versions persist the resulting rowid gap; older storage reaches + // this path only when rowids may be densely rewritten. + row_group.CommitDrop(); + checkpoint_state.DropSegment(entry.GetIndex()); + state.row_group_counts.push_back(row_group_num_rows); + continue; + } if (!state.can_change_row_ids) { idx_t total_count = row_group.count; - committed_counts.emplace_back(row_group_count); - // we cannot change row ids, and this row group has deletes - // vacuuming here would alter row ids - so skip it - if (total_count != row_group_count) { + if (total_count != row_group_num_rows) { + // We have partial deletes and cannot change rowid's, so skip it. state.row_group_counts.emplace_back(); continue; } + // Otherwise, the row group is fully live. We can still consider fully live row groups for merging as + // that does not change rowids. } - if (row_group_count == 0) { - // empty row group - we can drop it entirely. 
- row_group.CommitDrop(); - checkpoint_state.DropSegment(entry.GetIndex()); - dropped_any_rowgroups = true; - state.row_group_counts.push_back(row_group_count); - continue; - } - if (dropped_any_rowgroups) { - // if there are any dropped row groups before a live row group, all the row ids of the row groups following - // the dropped row group will have their row ids shifted forward (to keep row ids contiguous). - state.row_ids_changed = true; + if (rowid_gap_seen) { + // checkpointing will later re-number rowids densely. + checkpoint_state.writer.SetRowIdsChanged(); } - state.row_group_counts.push_back(row_group_count); + state.row_group_counts.push_back(row_group_num_rows); } - if (!state.can_change_row_ids && options.type != CheckpointType::CONCURRENT_CHECKPOINT) { - // if we cannot change row ids we might still be able to vacuum trailing deletions - // since that would not change the row ids of any non-deleted rows + if (legacy_vacuum_with_stable_row_ids) { + // With older storage and stable rowids, gap-forming drops were skipped above. Removing a suffix of fully + // deleted row groups is still safe because no surviving rowid is shifted and no persisted rowid gap remains + // before live data. auto segment_count = state.row_group_counts.size(); for (idx_t i = segment_count; i > 0; i--) { auto segment_idx = i - 1; - if (!committed_counts[segment_idx].IsValid()) { - // cannot vacuum this row group - break; - } - if (committed_counts[segment_idx].GetIndex() != 0) { + if (committed_counts[segment_idx] != 0) { // multiple rows found here - skip break; } @@ -1450,11 +1583,20 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // hence we target_count should be less than merge_count for a merge to be worth it // we greedily prefer to merge to the lowest target_count // i.e. 
we prefer to merge 2 row groups into 1, than 3 row groups into 2 - const idx_t row_group_size = GetRowGroupSize(); + // + // RowIdsChanged conditions for scheduled vacuum merges (over-approximation, not exact): + // 0) rowids are allowed to change in the first place. + // 1) the selected merge crosses an existing rowid gap. + // 2) the selected merge includes a partially deleted row group. + const idx_t target_row_group_size = GetRowGroupSize(); for (target_count = 1; target_count <= MAX_MERGE_COUNT; target_count++) { - auto total_target_size = target_count * row_group_size; + auto total_target_size = target_count * target_row_group_size; merge_count = 0; merge_rows = 0; + optional_idx expected_row_start; + // Conditions 1 and 2 are evaluated for each candidate merge window. The flag is only applied if this + // candidate is selected below. + bool candidate_changes_row_ids = false; for (next_idx = segment_idx; next_idx < checkpoint_state.SegmentCount(); next_idx++) { if (!state.row_group_counts[next_idx].IsValid()) { // cannot vacuum this row group - break @@ -1462,12 +1604,42 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi } auto next_row_count = state.row_group_counts[next_idx].GetIndex(); if (next_row_count == 0) { + if (!state.can_change_row_ids) { + // This row group was dropped under a storage format that can persist rowid gaps, but we are not + // allowed to change rowids. Do not schedule vacuuming across this row group. 
+ break; + } continue; } + auto next_segment = checkpoint_state.GetSegment(next_idx); + if (!next_segment) { + break; + } + auto &next_row_group = next_segment->GetNode(); + auto next_total_count = next_row_group.count.load(); if (merge_rows + next_row_count > total_target_size) { // does not fit break; } + if (!expected_row_start.IsValid()) { + expected_row_start = next_segment->GetRowStart(); + } else if (next_segment->GetRowStart() != expected_row_start.GetIndex()) { + if (!state.can_change_row_ids) { + break; + } + // Condition 1 candidate: this candidate vacuum task will compact across a rowid gap, which will shift + // rowids within the selected row groups. + candidate_changes_row_ids = true; + } + if (next_row_count != next_total_count) { + if (!state.can_change_row_ids) { + break; + } + // Condition 2 candidate: if this merge is selected, it includes a partially deleted row group, which + // may mean shifting row-ids (over-approximation). + candidate_changes_row_ids = true; + } + expected_row_start = next_segment->GetRowStart() + next_total_count; // we can merge this row group together with the other row group merge_rows += next_row_count; merge_count++; @@ -1481,7 +1653,8 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // merge it with a row group with 1 row, creating a row group with 100K+2 rows // etc. This leads to constant rewriting of the original 100K rows. 
idx_t minimum_target = - MinValue(state.row_group_counts[segment_idx].GetIndex() * 2, row_group_size) * target_count; + MinValue(state.row_group_counts[segment_idx].GetIndex() * 2, target_row_group_size) * + target_count; if (merge_rows >= STANDARD_VECTOR_SIZE && merge_rows < minimum_target) { // we haven't reached the minimum target - don't do this vacuum next_idx = segment_idx + 1; @@ -1492,6 +1665,10 @@ bool RowGroupCollection::ScheduleVacuumTasks(CollectionCheckpointState &checkpoi // we can reduce "merge_count" row groups to "target_count" // perform the merge at this level perform_merge = true; + if (candidate_changes_row_ids) { + // Apply Condition 2 or 3 for the selected merge window. + checkpoint_state.writer.SetRowIdsChanged(); + } break; } } @@ -1530,9 +1707,6 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl VacuumState vacuum_state; InitializeVacuumState(checkpoint_state, vacuum_state, writer.GetRowGroupCount()); - if (vacuum_state.row_ids_changed) { - writer.SetRowIdsChanged(); - } try { // schedule tasks @@ -1544,8 +1718,6 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl if (vacuum_tasks) { // vacuum tasks were scheduled - don't schedule a checkpoint task yet total_vacuum_tasks++; - vacuum_state.row_ids_changed = true; - writer.SetRowIdsChanged(); continue; } if (checkpoint_state.SegmentIsDropped(segment_idx)) { @@ -1587,7 +1759,7 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl break; } auto &write_state = checkpoint_state.write_data[segment_idx]; - if (!write_state.reuse_existing_metadata_blocks) { + if (write_state.write_action != RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA) { table_has_changes = true; break; } @@ -1600,19 +1772,15 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { auto entry = 
checkpoint_state.GetSegment(segment_idx); auto &row_group = entry->GetNode(); - auto &write_state = checkpoint_state.write_data[segment_idx]; metadata_manager.ClearModifiedBlocks(row_group.GetColumnStartPointers()); - D_ASSERT(write_state.reuse_existing_metadata_blocks); - vector extra_metadata_block_pointers; - extra_metadata_block_pointers.reserve(write_state.existing_extra_metadata_blocks.size()); - for (auto &block_pointer : write_state.existing_extra_metadata_blocks) { - extra_metadata_block_pointers.emplace_back(block_pointer, 0); - } + D_ASSERT(checkpoint_state.write_data[segment_idx].write_action == + RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA); + vector extra_metadata_block_pointers = row_group.GetExtraMetadataBlockPointers(); metadata_manager.ClearModifiedBlocks(extra_metadata_block_pointers); auto row_group_writer = checkpoint_state.writer.GetRowGroupWriter(row_group); row_group.CheckpointDeletes(*row_group_writer); } - writer.WriteUnchangedTable(metadata_pointer, metadata_pointers, total_rows.load()); + writer.WriteUnchangedTable(metadata_pointer, metadata_pointers, total_rows.load(), next_row_id.load()); // copy over existing stats into the global stats CopyStats(global_stats); return; @@ -1627,6 +1795,9 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl global_stats.InitializeEmpty(stats); idx_t new_total_rows = 0; + idx_t new_next_row_id = 0; + auto base_row_id = row_groups->GetBaseRowId(); + auto can_persist_rowid_gaps = writer.CanPersistRowIdGaps(); unordered_set columns_with_incomplete_stats; for (idx_t segment_idx = 0; segment_idx < checkpoint_state.SegmentCount(); segment_idx++) { auto entry = checkpoint_state.GetSegment(segment_idx); @@ -1637,10 +1808,21 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl } auto &existing_row_group = entry->GetNode(); auto &row_group_writer = checkpoint_state.writers[segment_idx]; + auto source_row_start = entry->GetRowStart(); + auto 
row_start = source_row_start; + if (!can_persist_rowid_gaps) { + // Older storage versions do not serialize next_row_id, so the checkpointed row groups must have + // contiguous numbering for rowids. + row_start = base_row_id + new_total_rows; + // If rowids must remain stable, dense old-storage output must not move this surviving row group. + D_ASSERT(vacuum_state.can_change_row_ids || row_start == source_row_start); + } if (!row_group_writer) { // row group was not checkpointed - this can happen if compressing is disabled for in-memory tables - new_row_groups->AppendSegment(l, entry->ReferenceNode()); + D_ASSERT(row_start == source_row_start); + new_row_groups->AppendSegment(l, entry->ReferenceNode(), row_start); new_total_rows += existing_row_group.count; + new_next_row_id = MaxValue(new_next_row_id, row_start + existing_row_group.count - base_row_id); auto lock = global_stats.GetLock(); for (idx_t column_idx = 0; column_idx < existing_row_group.GetColumnCount(); column_idx++) { @@ -1651,8 +1833,21 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl continue; } auto &row_group_write_data = checkpoint_state.write_data[segment_idx]; - idx_t row_start = new_total_rows; - bool metadata_reuse = row_group_write_data.reuse_existing_metadata_blocks; + auto write_action = row_group_write_data.write_action; + auto debug_verify_blocks = Settings::Get(GetAttached().GetDatabase()) && + dynamic_cast(&checkpoint_state.writer) != nullptr; + vector reuse_column; + if (debug_verify_blocks) { + if (write_action == RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA) { + auto existing_column_count = entry->ReferenceNode()->GetColumnCount(); + reuse_column.resize(existing_column_count, true); + } else { + reuse_column.resize(row_group_write_data.states.size()); + for (idx_t column_idx = 0; column_idx < row_group_write_data.states.size(); column_idx++) { + reuse_column[column_idx] = !row_group_write_data.states[column_idx]; + } + } + } auto 
new_row_group = std::move(row_group_write_data.result_row_group); if (!new_row_group) { // row group was unchanged - emit previous row group @@ -1660,9 +1855,6 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl } auto &row_group = *new_row_group; RowGroupPointer pointer_copy; - auto debug_verify_blocks = Settings::Get(GetAttached().GetDatabase()) && - dynamic_cast(&checkpoint_state.writer) != nullptr; - // check if we should write this row group to the persistent storage // don't write it if it only has uncommitted transaction-local changes made AFTER this checkpoint was started auto pointer = @@ -1678,49 +1870,95 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl writer.AddRowGroup(std::move(pointer), std::move(row_group_writer)); new_total_rows += row_group.count; - new_row_groups->AppendSegment(l, std::move(new_row_group)); + new_next_row_id = MaxValue(new_next_row_id, row_start + row_group.count - base_row_id); + new_row_groups->AppendSegment(l, std::move(new_row_group), row_start); if (debug_verify_blocks) { - if (!pointer_copy.has_metadata_blocks) { + if (!pointer_copy.has_metadata_blocks && !pointer_copy.has_per_column_metadata_blocks) { throw InternalException("Checkpointing should always remember metadata blocks"); } - if (metadata_reuse && pointer_copy.data_pointers != row_group.GetColumnStartPointers()) { - throw InternalException("Column start pointers changed during metadata reuse"); + if (SupportsPerColumnWrites() != pointer_copy.has_per_column_metadata_blocks) { + throw InternalException( + "Checkpointing should always remember per-column metadata blocks when supporting it"); + } + if (write_action == RowGroupWriteAction::PARTIALLY_REUSE_COLUMN_METADATA && !SupportsPerColumnWrites()) { + throw InternalException("Partially reusing column metadata should only be done when supporting it"); + } + if (write_action == RowGroupWriteAction::REUSE_EXISTING_ROW_GROUP_METADATA && + 
pointer_copy.data_pointers != row_group.GetColumnStartPointers()) { + throw InternalException("Column start pointers changed during full metadata reuse"); } - // Capture blocks that have been written - vector all_written_blocks = pointer_copy.data_pointers; - vector all_metadata_blocks; - for (auto &block : pointer_copy.extra_metadata_blocks) { - all_written_blocks.emplace_back(block, 0); - all_metadata_blocks.emplace_back(block, 0); + // Verify per_column_metadata_blocks matches full deserialization + if (pointer_copy.has_per_column_metadata_blocks) { + const auto &column_start_ptrs = row_group.GetColumnStartPointers(); + auto &col_types = row_group.GetCollection().GetTypes(); + auto &mm = row_group.GetCollection().GetMetadataManager(); + vector columns; + vector> deserialized_extras; + deserialized_extras.reserve(column_start_ptrs.size()); + for (idx_t i = 0; i < column_start_ptrs.size(); i++) { + columns.emplace_back(i); + deserialized_extras.emplace_back(); + vector col_read_pointers; + MetadataReader reader(mm, column_start_ptrs[i], &col_read_pointers); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, reader, col_types[i]); + // collect extra blocks from deserialization (excluding start block) + for (auto &ptr : col_read_pointers) { + if (ptr.block_pointer != column_start_ptrs[i].block_pointer) { + deserialized_extras[i].emplace_back(ptr.block_pointer); + } + } + } + auto blocks_for_columns = pointer_copy.per_column_metadata_blocks.GetBlocksForColumns(columns); + if (deserialized_extras != blocks_for_columns) { + throw InternalException("per_column_metadata_blocks mismatch"); + } } - // Verify that we can load the metadata correctly again - vector all_quick_read_blocks; - for (auto &ptr : row_group.GetColumnStartPointers()) { - all_quick_read_blocks.emplace_back(ptr); - if (metadata_reuse && !block_manager.GetMetadataManager().BlockHasBeenCleared(ptr)) { - throw InternalException("Found column start block that was not cleared"); + // Verify 
blocks are cleared for partial column reuse + for (idx_t col_idx = 0; col_idx < pointer_copy.data_pointers.size(); col_idx++) { + if (!reuse_column[col_idx]) { + continue; + } + // reused column: its start block should be cleared + if (!block_manager.GetMetadataManager().BlockHasBeenCleared(pointer_copy.data_pointers[col_idx])) { + throw InternalException("Partial reuse: column %llu start block was not cleared", col_idx); } } - auto extra_metadata_blocks = row_group.GetOrComputeExtraMetadataBlocks(/* force_compute: */ true); - for (auto &ptr : extra_metadata_blocks) { - auto block_pointer = MetaBlockPointer(ptr, 0); - all_quick_read_blocks.emplace_back(block_pointer); - if (metadata_reuse && !block_manager.GetMetadataManager().BlockHasBeenCleared(block_pointer)) { - throw InternalException("Found extra metadata block that was not cleared"); + pointer_copy.per_column_metadata_blocks.ForEachBlock([&](idx_t col_idx, idx_t block_id) { + if (!reuse_column[col_idx]) { + return; } + // reused column: extra blocks should be cleared + auto block_ptr = MetaBlockPointer(block_id, 0); + if (!block_manager.GetMetadataManager().BlockHasBeenCleared(block_ptr)) { + throw InternalException("Partial reuse: column extra block %llu was not cleared", block_id); + } + }); + + // Capture blocks that have been written + vector all_written_blocks = pointer_copy.data_pointers; + if (pointer_copy.has_metadata_blocks) { + for (auto &block : pointer_copy.extra_metadata_blocks) { + all_written_blocks.emplace_back(block, 0); + } + } else { + if (!pointer_copy.has_per_column_metadata_blocks) { + throw InternalException("Checkpointing should always remember metadata blocks"); + } + pointer_copy.per_column_metadata_blocks.ForEachBlock( + [&](idx_t, idx_t block) { all_written_blocks.emplace_back(block, 0); }); } - // Deserialize all columns to check if the quick read via GetOrComputeExtraMetadataBlocks was correct + // Deserialize all columns to check if what's on disk matches our metadata vector 
all_full_read_blocks; - auto column_start_pointers = row_group.GetColumnStartPointers(); - auto &types = row_group.GetCollection().GetTypes(); + const auto &column_start_pointers = row_group.GetColumnStartPointers(); + auto &column_types = row_group.GetCollection().GetTypes(); auto &metadata_manager = row_group.GetCollection().GetMetadataManager(); for (idx_t i = 0; i < column_start_pointers.size(); i++) { MetadataReader reader(metadata_manager, column_start_pointers[i], &all_full_read_blocks); - ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, reader, types[i]); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), i, reader, column_types[i]); } // Derive sets of blocks to compare @@ -1728,25 +1966,16 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl for (auto &ptr : all_written_blocks) { all_written_block_ids.insert(ptr.block_pointer); } - set all_quick_read_block_ids; - for (auto &ptr : all_quick_read_blocks) { - all_quick_read_block_ids.insert(ptr.block_pointer); - } set all_full_read_block_ids; for (auto &ptr : all_full_read_blocks) { all_full_read_block_ids.insert(ptr.block_pointer); } - if (all_written_block_ids != all_quick_read_block_ids || - all_quick_read_block_ids != all_full_read_block_ids) { + if (all_written_block_ids != all_full_read_block_ids) { std::stringstream oss; oss << "\nWritten: "; for (auto &block : all_written_blocks) { oss << block << ", "; } - oss << "\nQuick read: "; - for (auto &block : all_quick_read_blocks) { - oss << block << ", "; - } oss << "\nFull read: "; for (auto &block : all_full_read_blocks) { oss << block << ", "; @@ -1765,6 +1994,11 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl set all_written_deletes_block_ids; for (auto &ptr : pointer_copy.deletes_pointers) { all_written_deletes_block_ids.insert(ptr.block_pointer); + // delete ptr should be cleared (unless it's been newly written) + if 
(block_manager.GetMetadataManager().BlockIsModified(ptr) && + !block_manager.GetMetadataManager().BlockHasBeenCleared(ptr)) { + throw InternalException("Delete ptr %llu was not cleared", ptr.block_pointer); + } } set all_read_deletes_block_ids; for (auto &ptr : read_deletes_pointers) { @@ -1791,12 +2025,14 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl } if (!columns_with_incomplete_stats.empty()) { // for any columns that have incomplete stats we need to merge in the previous global stats to ensure the stats - // are correct + // are correct — use EXPAND_BOUNDS so additive stats (e.g. total_string_length) are invalidated rather than + // double-counted (the collection stats include contributions from all row groups, including those already + // merged) auto lock = global_stats.GetLock(); for (auto &column_idx : columns_with_incomplete_stats) { auto stats_lock = stats.GetLock(); auto &column_stats = stats.GetStats(*stats_lock, column_idx); - global_stats.MergeStats(*lock, column_idx, column_stats.Statistics()); + global_stats.MergeStats(*lock, column_idx, column_stats.Statistics(), StatsMergeType::EXPAND_BOUNDS); } } l.Release(); @@ -1808,6 +2044,8 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl writer.FlushPartialBlocks(); // override the row group segment tree total_rows = new_total_rows; + next_row_id = new_next_row_id; + D_ASSERT(next_row_id.load() >= total_rows.load()); SetRowGroups(std::move(new_row_groups)); Verify(); // Rebuild indexes if: @@ -1816,7 +2054,7 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl // and all the indexes are bound ART's), // and // 2) we have changed rowids. 
- if (vacuum_state.can_rebuild_indexes && vacuum_state.row_ids_changed) { + if (vacuum_state.can_rebuild_indexes && writer.RowIdsChanged()) { writer.SetRebuildIndexes(); } } @@ -1895,17 +2133,50 @@ vector RowGroupCollection::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -vector RowGroupCollection::GetColumnSegmentInfo(const QueryContext &context) { +vector RowGroupCollection::GetColumnSegmentInfo(const QueryContext &context, + const ColumnSegmentInfoScanOptions &options) const { vector result; - auto row_groups = GetRowGroups(); - auto lock = row_groups->Lock(); - for (auto &node : row_groups->SegmentNodes(lock)) { - auto &row_group = node.GetNode(); - row_group.GetColumnSegmentInfo(context, node.GetIndex(), result); + ColumnSegmentInfoScanState state; + state.options = options; + InitializeColumnSegmentInfoScan(state); + while (ScanColumnSegmentInfo(context, state, result)) { } return result; } +void RowGroupCollection::InitializeColumnSegmentInfoScan(ColumnSegmentInfoScanState &state) const { + // Pin a consistent snapshot of the row groups. Holding the shared_ptr keeps the + // segment tree alive even if a concurrent operation (alter, vacuum) installs a + // new collection. The segment-tree node lock is acquired only briefly here to + // fetch the first node (and may lazily load it from disk). + state.row_groups = GetRowGroups(); + state.current_row_group = state.row_groups->GetRootSegment(); +} + +bool RowGroupCollection::ScanColumnSegmentInfo(const QueryContext &context, ColumnSegmentInfoScanState &state, + vector &result) const { + if (!state.current_row_group) { + return false; + } + auto &node = *state.current_row_group; + node.GetNode().GetColumnSegmentInfo(context, node.GetIndex(), result, state.options); + // Advance to the next row group. 
For lazy-loading segment trees this acquires + // the node lock briefly only while more segments still need to be loaded. + state.current_row_group = state.row_groups->GetNextSegment(node); + return true; +} + +bool RowGroupCollection::SupportsPerColumnWrites() { + auto version = StorageCompatibility::FromDatabase(GetAttached()); + if (version.storage_version >= StorageCompatibility::FromString("v2.0.0").storage_version) { + return true; + } + if (version.storage_version >= StorageCompatibility::FromString("v1.4.0").storage_version) { + return Settings::Get(GetAttached().GetDatabase()); + } + return false; +} + //===--------------------------------------------------------------------===// // Alter //===--------------------------------------------------------------------===// @@ -1917,9 +2188,7 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont new_types.push_back(new_column.GetType()); auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_groups->GetBaseRowId(), total_rows.load(), row_group_size); - - DataChunk dummy_chunk; - Vector default_vector(new_column.GetType()); + result->next_row_id = next_row_id.load(); result->stats.InitializeAddColumn(stats, new_column.GetType()); auto lock = result->stats.GetLock(); @@ -1928,12 +2197,13 @@ shared_ptr RowGroupCollection::AddColumn(ClientContext &cont // fill the column with its DEFAULT value, or NULL if none is specified auto new_stats = make_uniq(new_column.GetType()); auto result_row_groups = result->GetRowGroups(); - for (auto ¤t_row_group : row_groups->Segments()) { - auto new_row_group = current_row_group.AddColumn(*result, new_column, default_executor, default_vector); + for (auto &node : row_groups->SegmentNodes()) { + auto ¤t_row_group = node.GetNode(); + auto new_row_group = current_row_group.AddColumn(*result, new_column, default_executor); // merge in the statistics new_row_group->MergeIntoStatistics(new_column_idx, new_column_stats.Statistics()); - 
result_row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group), node.GetRowStart()); } return result; @@ -1947,15 +2217,17 @@ shared_ptr RowGroupCollection::RemoveColumn(idx_t col_idx) { auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_groups->GetBaseRowId(), total_rows.load(), row_group_size); + result->next_row_id = next_row_id.load(); result->stats.InitializeRemoveColumn(stats, col_idx); auto result_lock = result->stats.GetLock(); result->stats.DestroyTableSample(*result_lock); auto result_row_groups = result->GetRowGroups(); - for (auto ¤t_row_group : row_groups->Segments()) { + for (auto &node : row_groups->SegmentNodes()) { + auto ¤t_row_group = node.GetNode(); auto new_row_group = current_row_group.RemoveColumn(*result, col_idx); - result_row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group), node.GetRowStart()); } return result; } @@ -1971,6 +2243,7 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont auto result = make_shared_ptr(info, block_manager, std::move(new_types), row_groups->GetBaseRowId(), total_rows.load(), row_group_size); + result->next_row_id = next_row_id.load(); result->stats.InitializeAlterType(stats, changed_idx, target_type); vector scan_types; @@ -1990,7 +2263,7 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont TableScanState scan_state; scan_state.Initialize(bound_columns); scan_state.table_state.Initialize(context, GetTypes()); - scan_state.table_state.max_row = row_groups->GetBaseRowId() + total_rows; + scan_state.table_state.max_row = row_groups->GetBaseRowId() + next_row_id.load(); // now alter the type of the column within all of the row_groups individually auto lock = result->stats.GetLock(); @@ -2002,7 +2275,7 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont auto new_row_group = current_row_group.AlterType(*result, target_type, changed_idx, 
executor, scan_state.table_state, node, scan_chunk); new_row_group->MergeIntoStatistics(changed_idx, changed_stats.Statistics()); - result_row_groups->AppendSegment(std::move(new_row_group)); + result_row_groups->AppendSegment(std::move(new_row_group), node.GetRowStart()); } return result; } @@ -2042,7 +2315,7 @@ void RowGroupCollection::VerifyNewConstraint(const QueryContext &context, DataTa } // Verify the NOT NULL constraint. - if (VectorOperations::HasNull(scan_chunk.data[0], scan_chunk.size())) { + if (VectorOperations::HasNull(scan_chunk.data[0])) { auto name = parent.Columns()[physical_index].GetName(); throw ConstraintException("NOT NULL constraint failed: %s.%s", info->GetTableName(), name); } diff --git a/src/duckdb/src/storage/table/row_group_reorderer.cpp b/src/duckdb/src/storage/table/row_group_reorderer.cpp index 482b9bac5..6ab37e900 100644 --- a/src/duckdb/src/storage/table/row_group_reorderer.cpp +++ b/src/duckdb/src/storage/table/row_group_reorderer.cpp @@ -22,9 +22,10 @@ bool CompareValues(const Value &v1, const Value &v2, const OrderByStatistics ord return (order == OrderByStatistics::MAX && v1 < v2) || (order == OrderByStatistics::MIN && v1 > v2); } -idx_t GetQualifyingTupleCount(RowGroup &row_group, BaseStatistics &stats, const OrderByColumnType type) { +idx_t GetMinimumQualifyingTupleCount(RowGroup &row_group, BaseStatistics &stats, const OrderByColumnType type, + TransactionData transaction) { if (!stats.CanHaveNull()) { - return row_group.count; + return row_group.GetVisibleRowCount(transaction); } if (type == OrderByColumnType::NUMERIC) { @@ -42,9 +43,9 @@ idx_t GetQualifyingTupleCount(RowGroup &row_group, BaseStatistics &stats, const } template -void AddRowGroups(multimap &row_group_map, It it, const End &end, +bool AddRowGroups(multimap &row_group_map, It it, const End &end, vector>> &ordered_row_groups, const idx_t row_limit, - const OrderByColumnType column_type, const OrderByStatistics stat_type) { + const OrderByColumnType 
column_type, const OrderByStatistics stat_type, TransactionData transaction) { const auto opposite_stat_type = stat_type == OrderByStatistics::MAX ? OrderByStatistics::MIN : OrderByStatistics::MAX; @@ -57,7 +58,7 @@ void AddRowGroups(multimap &row_group_map, It i idx_t qualify_later = 0; idx_t last_unresolved_row_group_sum = - GetQualifyingTupleCount(it->second.row_group.get().GetNode(), *last_stats, column_type); + GetMinimumQualifyingTupleCount(it->second.row_group.get().GetNode(), *last_stats, column_type, transaction); for (; it != end; ++it) { auto ¤t_key = it->first; auto &row_group = it->second.row_group; @@ -82,14 +83,16 @@ void AddRowGroups(multimap &row_group_map, It i auto &upcoming_row_group = last_unresolved_entry->second.row_group.get().GetNode(); auto &upcoming_stats = *last_unresolved_entry->second.stats; - last_unresolved_row_group_sum += GetQualifyingTupleCount(upcoming_row_group, upcoming_stats, column_type); + last_unresolved_row_group_sum += + GetMinimumQualifyingTupleCount(upcoming_row_group, upcoming_stats, column_type, transaction); last_unresolved_boundary = RowGroupReorderer::RetrieveStat(upcoming_stats, opposite_stat_type, column_type); } if (qualifying_tuples >= row_limit) { - return; + return true; } ordered_row_groups.emplace_back(row_group); } + return false; } template @@ -108,31 +111,35 @@ void InsertAllRowGroups(It it, const End &end, vector &row_group_map, const optional_idx row_limit, +bool SetRowGroupVector(multimap &row_group_map, const optional_idx row_limit, const idx_t row_group_offset, const OrderType order_type, const OrderByColumnType column_type, - vector>> &ordered_row_groups) { + vector>> &ordered_row_groups, TransactionData transaction) { const auto stat_type = order_type == OrderType::ASCENDING ? 
OrderByStatistics::MIN : OrderByStatistics::MAX; if (order_type == OrderType::ASCENDING) { auto end = row_group_map.end(); auto it = SkipOffsetPrunedRowGroups(row_group_map.begin(), end, row_group_offset); if (it == end) { - return; + return false; } if (row_limit.IsValid()) { - AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, stat_type); + return AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, + stat_type, transaction); } else { InsertAllRowGroups(it, end, ordered_row_groups); + return false; } } else { auto end = row_group_map.rend(); auto it = SkipOffsetPrunedRowGroups(row_group_map.rbegin(), end, row_group_offset); if (it == end) { - return; + return false; } if (row_limit.IsValid()) { - AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, stat_type); + return AddRowGroups(row_group_map, it, end, ordered_row_groups, row_limit.GetIndex(), column_type, + stat_type, transaction); } else { InsertAllRowGroups(it, end, ordered_row_groups); + return false; } } } @@ -199,8 +206,8 @@ OffsetPruningResult FindOffsetPrunableChunks(It it, const End &end, const OrderB } // namespace -RowGroupReorderer::RowGroupReorderer(const RowGroupOrderOptions &options_p) - : options(options_p), offset(0), initialized(false) { +RowGroupReorderer::RowGroupReorderer(const RowGroupOrderOptions &options_p, TransactionData transaction_p) + : options(options_p), transaction(transaction_p), offset(0), initialized(false) { } optional_ptr> RowGroupReorderer::GetNextRowGroup(SegmentNode &row_group) { @@ -347,12 +354,17 @@ optional_ptr> RowGroupReorderer::GetRootSegment(RowGroupSe AppendRowGroups(null_only_groups, options.leading_null_group_offset, ordered_row_groups); AppendRowGroups(ambiguous_groups, 0, ordered_row_groups); SetRowGroupVector(row_group_map, options.row_limit, options.row_group_offset, options.order_type, - options.column_type, ordered_row_groups); + 
options.column_type, ordered_row_groups, transaction); } else { - SetRowGroupVector(row_group_map, options.row_limit, options.row_group_offset, options.order_type, - options.column_type, ordered_row_groups); + auto pruned_by_limit = + SetRowGroupVector(row_group_map, options.row_limit, options.row_group_offset, options.order_type, + options.column_type, ordered_row_groups, transaction); AppendRowGroups(ambiguous_groups, 0, ordered_row_groups); - AppendRowGroups(null_only_groups, 0, ordered_row_groups); + // Once an ascending NULLS LAST limit is already satisfied by earlier non-null row groups, null-only + // groups cannot contribute to the ordered prefix anymore. + if (!pruned_by_limit || options.order_type != OrderType::ASCENDING) { + AppendRowGroups(null_only_groups, 0, ordered_row_groups); + } } if (ordered_row_groups.empty()) { diff --git a/src/duckdb/src/storage/table/row_id_column_data.cpp b/src/duckdb/src/storage/table/row_id_column_data.cpp index ff5975a4b..b3ca2393c 100644 --- a/src/duckdb/src/storage/table/row_id_column_data.cpp +++ b/src/duckdb/src/storage/table/row_id_column_data.cpp @@ -115,12 +115,15 @@ idx_t RowIdColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resul throw InternalException("Fetch is not supported for row id columns"); } -void RowIdColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { +void RowIdColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_count, Vector &result, + idx_t result_offset) { result.SetVectorType(VectorType::FLAT_VECTOR); auto data = FlatVector::GetDataMutable(result); auto row_start = state.row_group->GetRowStart(); - data[result_idx] = UnsafeNumericCast(row_start) + row_id; + for (idx_t idx = 0; idx < fetch_count; idx++) { + data[result_offset + idx] = 
NumericCast(row_start + offsets[sel.get_index(idx)]); + } } void RowIdColumnData::Skip(ColumnScanState &state, idx_t count) { @@ -132,12 +135,11 @@ void RowIdColumnData::InitializeAppend(ColumnAppendState &state) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { +void RowIdColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t count) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t count) { +void RowIdColumnData::AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { throw InternalException("RowIdColumnData cannot be appended to"); } diff --git a/src/duckdb/src/storage/table/row_number_column_data.cpp b/src/duckdb/src/storage/table/row_number_column_data.cpp index 89c6b5e54..013feab8a 100644 --- a/src/duckdb/src/storage/table/row_number_column_data.cpp +++ b/src/duckdb/src/storage/table/row_number_column_data.cpp @@ -72,9 +72,10 @@ idx_t RowNumberColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &r throw InternalException("Fetch is not supported for row number columns"); } -void RowNumberColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, - const StorageIndex &storage_index, row_t row_id, Vector &result, idx_t result_idx) { - throw InternalException("FetchRow is not supported for row number columns"); +void RowNumberColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, + const StorageIndex &storage_index, const idx_t *offsets, const SelectionVector &sel, + idx_t fetch_count, Vector &result, idx_t result_offset) { + throw InternalException("FetchRows is not supported for row number columns"); } void RowNumberColumnData::Skip(ColumnScanState &state, idx_t count) { @@ -86,12 +87,11 @@ void 
RowNumberColumnData::InitializeAppend(ColumnAppendState &state) { throw InternalException("RowNumberColumnData cannot be appended to"); } -void RowNumberColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { +void RowNumberColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t count) { throw InternalException("RowNumberColumnData cannot be appended to"); } -void RowNumberColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t count) { +void RowNumberColumnData::AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { throw InternalException("RowNumberColumnData cannot be appended to"); } diff --git a/src/duckdb/src/storage/table/row_version_manager.cpp b/src/duckdb/src/storage/table/row_version_manager.cpp index 1279f37a9..30175eac0 100644 --- a/src/duckdb/src/storage/table/row_version_manager.cpp +++ b/src/duckdb/src/storage/table/row_version_manager.cpp @@ -48,14 +48,40 @@ idx_t RowVersionManager::GetSelVector(ScanOptions options, idx_t vector_idx, Sel return chunk_info->GetSelVector(options, sel_vector, max_count); } -bool RowVersionManager::Fetch(TransactionData transaction, idx_t row) { - lock_guard lock(version_lock); - idx_t vector_index = row / STANDARD_VECTOR_SIZE; - auto info = GetChunkInfo(vector_index); - if (!info) { - return true; +idx_t RowVersionManager::GetVisibleRows(TransactionData transaction, const idx_t *offsets, idx_t count, + SelectionVector &visible_sel) { + if (count == 0) { + return 0; + } + const lock_guard lock(version_lock); + idx_t visible_count = 0; + idx_t idx = 0; + while (idx < count) { + const idx_t vector_idx = offsets[idx] / STANDARD_VECTOR_SIZE; + const idx_t base = vector_idx * STANDARD_VECTOR_SIZE; + const idx_t vector_end = base + STANDARD_VECTOR_SIZE; + // The first position after the current same-version-vector run. 
+ idx_t next_idx = idx + 1; + // Extend the run while offsets remain within the same version-info vector. + while (next_idx < count && offsets[next_idx] >= base && offsets[next_idx] < vector_end) { + next_idx++; + } + auto info = GetChunkInfo(vector_idx); + if (!info) { + // no version info, which means entire chunk is visible + for (idx_t k = idx; k < next_idx; k++) { + visible_sel.set_index(visible_count++, k); + } + } else { + for (idx_t k = idx; k < next_idx; k++) { + if (info->Fetch(transaction, NumericCast(offsets[k] - base))) { + visible_sel.set_index(visible_count++, k); + } + } + } + idx = next_idx; } - return info->Fetch(transaction, UnsafeNumericCast(row - vector_index * STANDARD_VECTOR_SIZE)); + return visible_count; } void RowVersionManager::FillVectorInfo(idx_t vector_idx) { diff --git a/src/duckdb/src/storage/table/scan_state.cpp b/src/duckdb/src/storage/table/scan_state.cpp index 3d2f4b0ff..64bbcda2c 100644 --- a/src/duckdb/src/storage/table/scan_state.cpp +++ b/src/duckdb/src/storage/table/scan_state.cpp @@ -180,6 +180,13 @@ ParallelCollectionScanState::ParallelCollectionScanState() : collection(nullptr), current_row_group(nullptr), processed_rows(0) { } +void ParallelCollectionScanState::AssignRowGroup(optional_ptr> row_group) { + current_row_group = row_group; + while (current_row_group && !ShouldScanPartition(*current_row_group)) { + current_row_group = GetNextRowGroup(*row_groups, *current_row_group).get(); + } +} + optional_ptr> ParallelCollectionScanState::GetRootSegment(RowGroupSegmentTree &row_groups) const { if (reorderer) { return reorderer->GetRootSegment(row_groups); diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp index f27b4cca7..515b61c4d 100644 --- a/src/duckdb/src/storage/table/standard_column_data.cpp +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -122,10 +122,14 @@ void StandardColumnData::InitializeAppend(ColumnAppendState &state) { 
state.child_appends.push_back(std::move(child_append)); } -void StandardColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t count) { - ColumnData::AppendData(stats, state, vdata, count); - validity->AppendData(stats, state.child_appends[0], vdata, count); +void StandardColumnData::AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { + ColumnData::AppendData(state, vdata, count); + validity->AppendData(state.child_appends[0], vdata, count); +} + +void StandardColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + ColumnData::FinalizeAppend(finalize_state, state); + validity->FinalizeAppendLocked(finalize_state, state.child_appends[0]); } void StandardColumnData::RevertAppend(row_t new_count) { @@ -191,15 +195,16 @@ unique_ptr StandardColumnData::GetUpdateStatistics() { return stats; } -void StandardColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, - const StorageIndex &storage_index, row_t row_id, Vector &result, idx_t result_idx) { - // find the segment the row belongs to +void StandardColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, + const StorageIndex &storage_index, const idx_t *offsets, const SelectionVector &sel, + idx_t fetch_count, Vector &result, idx_t result_offset) { if (state.child_states.empty()) { - auto child_state = make_uniq(); - state.child_states.push_back(std::move(child_state)); + state.child_states.emplace_back(make_uniq()); } - ColumnData::FetchRow(transaction, state, storage_index, row_id, result, result_idx); - validity->FetchRow(transaction, *state.child_states[0], storage_index, row_id, result, result_idx); + // Bulk fetch the data and the validity in two passes. 
+ FetchRowsAtSegmentLevel(transaction, state, offsets, sel, fetch_count, result, result_offset); + validity->FetchRowsAtSegmentLevel(transaction, *state.child_states[0], offsets, sel, fetch_count, result, + result_offset); } void StandardColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { @@ -215,6 +220,11 @@ void StandardColumnData::SetValidityData(shared_ptr validity this->validity = std::move(validity_p); } +ValidityColumnData &StandardColumnData::GetValidityData() { + D_ASSERT(validity); + return *validity; +} + struct StandardColumnCheckpointState : public ColumnCheckpointState { StandardColumnCheckpointState(const RowGroup &row_group, ColumnData &column_data, PartialBlockManager &partial_block_manager) @@ -322,11 +332,11 @@ void StandardColumnData::InitializeColumn(PersistentColumnData &column_data, Bas } void StandardColumnData::GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, - vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); + vector col_path, vector &result, + const ColumnSegmentInfoScanOptions &options) { + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result, options); col_path.push_back(0); - validity->GetColumnSegmentInfo(context, row_group_index, std::move(col_path), result); + validity->GetColumnSegmentInfo(context, row_group_index, std::move(col_path), result, options); } void StandardColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp index c5c5d9d60..1c1898aa2 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -136,7 +136,7 @@ static void ScanChild(ColumnScanState &state, Vector &result, const std::functio target.Reset(); input.Reset(); auto scan_count = callback(input.data[0]); - input.SetCardinality(scan_count); + 
FlatVector::SetSize(input.data[0], scan_count); executor.Execute(input, target); result.Reference(target.data[0]); } else { @@ -224,25 +224,33 @@ void StructColumnData::InitializeAppend(ColumnAppendState &state) { } } -void StructColumnData::Append(BaseStatistics &stats, ColumnAppendState &state, Vector &vector, idx_t count) { +void StructColumnData::Append(ColumnAppendState &state, const Vector &vector, idx_t count) { if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { Vector append_vector(Vector::Ref(vector)); append_vector.Flatten(); - Append(stats, state, append_vector, count); + Append(state, append_vector, count); return; } // append the null values - validity->Append(stats, state.child_appends[0], vector, count); + validity->Append(state.child_appends[0], vector, count); auto &child_entries = StructVector::GetEntries(vector); for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Append(StructStats::GetChildStats(stats, i), state.child_appends[i + 1], child_entries[i], - count); + sub_columns[i]->Append(state.child_appends[i + 1], child_entries[i], count); } this->count += count; } +void StructColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + validity->FinalizeAppendLocked(finalize_state, state.child_appends[0]); + + for (idx_t i = 0; i < sub_columns.size(); i++) { + ColumnDataFinalizeAppendState child_finalize_state(finalize_state, LogicalTypeId::STRUCT, i); + sub_columns[i]->FinalizeAppendLocked(child_finalize_state, state.child_appends[i + 1]); + } +} + void StructColumnData::RevertAppend(row_t new_count) { validity->RevertAppend(new_count); for (auto &sub_column : sub_columns) { @@ -316,10 +324,11 @@ unique_ptr StructColumnData::GetUpdateStatistics() { return stats.ToUnique(); } -void StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, const StorageIndex &storage_index, - row_t row_id, Vector &result, idx_t result_idx) { +void 
StructColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, + const StorageIndex &storage_index, const idx_t *offsets, const SelectionVector &sel, + idx_t fetch_count, Vector &result, idx_t result_offset) { // fetch the validity state - validity->FetchRow(transaction, state, storage_index, row_id, result, result_idx); + validity->FetchRowsAtSegmentLevel(transaction, state, offsets, sel, fetch_count, result, result_offset); if (storage_index.IsPushdownExtract()) { auto &index_children = storage_index.GetChildIndexes(); D_ASSERT(index_children.size() == 1); @@ -329,14 +338,19 @@ void StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &s auto &child_type = StructType::GetChildTypes(type)[child_index].second; if (!child_storage_index.HasChildren() && child_storage_index.HasType() && child_storage_index.GetType() != child_type) { - Vector intermediate(child_type, 1); - sub_column.FetchRow(transaction, state, child_storage_index, row_id, intermediate, 0); auto context = transaction.transaction->context.lock(); - auto fetched_row = intermediate.GetValue(0).CastAs(*context, result.GetType()); - result.SetValue(result_idx, fetched_row); + for (idx_t idx = 0; idx < fetch_count; idx++) { + const idx_t offset = offsets[sel.get_index(idx)]; + Vector intermediate(child_type, 1); + sub_column.FetchRows(transaction, state, child_storage_index, &offset, + *FlatVector::IncrementalSelectionVector(), /*count=*/1, intermediate, 0); + auto fetched_row = intermediate.GetValue(0).CastAs(*context, result.GetType()); + result.SetValue(result_offset + idx, fetched_row); + } return; } else { - sub_column.FetchRow(transaction, state, child_storage_index, row_id, result, result_idx); + sub_column.FetchRows(transaction, state, child_storage_index, offsets, sel, fetch_count, result, + result_offset); return; } } @@ -344,7 +358,8 @@ void StructColumnData::FetchRow(TransactionData transaction, ColumnFetchState &s auto &child_entries = 
StructVector::GetEntries(result); // fetch the sub-column states for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->FetchRow(transaction, state, storage_index, row_id, child_entries[i], result_idx); + sub_columns[i]->FetchRows(transaction, state, storage_index, offsets, sel, fetch_count, child_entries[i], + result_offset); } } @@ -505,12 +520,13 @@ void StructColumnData::InitializeColumn(PersistentColumnData &column_data, BaseS } void StructColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { + vector &result, + const ColumnSegmentInfoScanOptions &options) { col_path.push_back(0); - validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); for (idx_t i = 0; i < sub_columns.size(); i++) { col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); } } diff --git a/src/duckdb/src/storage/table/table_statistics.cpp b/src/duckdb/src/storage/table/table_statistics.cpp index 95a09efa1..f5dbb80df 100644 --- a/src/duckdb/src/storage/table/table_statistics.cpp +++ b/src/duckdb/src/storage/table/table_statistics.cpp @@ -165,8 +165,8 @@ void TableStatistics::MergeStats(idx_t i, BaseStatistics &stats) { MergeStats(*l, i, stats); } -void TableStatistics::MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats) { - column_stats[i]->Statistics().Merge(stats); +void TableStatistics::MergeStats(TableStatisticsLock &lock, idx_t i, BaseStatistics &stats, StatsMergeType merge_type) { + column_stats[i]->Statistics().Merge(stats, merge_type); } ColumnStatistics &TableStatistics::GetStats(TableStatisticsLock &lock, idx_t i) { @@ -280,7 +280,7 @@ unique_ptr TableStatistics::GetLock() { return make_uniq(*stats_lock); } -bool TableStatistics::Empty() 
{ +bool TableStatistics::Empty() const { D_ASSERT(column_stats.empty() == (stats_lock.get() == nullptr)); return column_stats.empty(); } diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp index 6dff6ad38..9669c533f 100644 --- a/src/duckdb/src/storage/table/update_segment.cpp +++ b/src/duckdb/src/storage/table/update_segment.cpp @@ -5,6 +5,7 @@ #include "duckdb/storage/table/column_data.hpp" #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/update_info.hpp" +#include "duckdb/storage/statistics/stats_writer.hpp" #include @@ -18,7 +19,7 @@ static UpdateSegment::fetch_committed_range_function_t GetFetchCommittedRangeFun static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type); static UpdateSegment::rollback_update_function_t GetRollbackUpdateFunction(PhysicalType type); static UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type); -static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type); +static UpdateSegment::fetch_rows_function_t GetFetchRowsFunction(PhysicalType type); static UpdateSegment::get_effective_updates_t GetEffectiveUpdatesFunction(PhysicalType type); UpdateSegment::UpdateSegment(ColumnData &column_data) @@ -31,7 +32,7 @@ UpdateSegment::UpdateSegment(ColumnData &column_data) this->fetch_update_function = GetFetchUpdateFunction(physical_type); this->fetch_committed_function = GetFetchCommittedFunction(physical_type); this->fetch_committed_range = GetFetchCommittedRangeFunction(physical_type); - this->fetch_row_function = GetFetchRowFunction(physical_type); + this->fetch_rows_function = GetFetchRowsFunction(physical_type); this->merge_update_function = GetMergeUpdateFunction(physical_type); this->rollback_update_function = GetRollbackUpdateFunction(physical_type); this->statistics_update_function = GetStatisticsUpdateFunction(physical_type); @@ -102,6 +103,33 @@ idx_t 
UpdateInfo::GetAllocSize(idx_t type_size) { return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); } +idx_t UpdateInfo::GetAllocSize(idx_t type_size, idx_t capacity) { + return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * capacity); +} + +idx_t UpdateInfo::GetCompactCapacity(idx_t count) { + // Round up to next power of two, with a minimum of 8, capped at STANDARD_VECTOR_SIZE + if (count <= 8) { + return 8; + } + idx_t capacity = NextPowerOfTwo(count); + return MinValue(capacity, STANDARD_VECTOR_SIZE); +} + +void UpdateInfo::Initialize(UpdateInfo &info, DuckTableEntry &table_entry, transaction_t transaction_id, + idx_t row_group_start, idx_t capacity) { + info.max = UnsafeNumericCast(capacity); + info.row_group_start = row_group_start; + info.table = &table_entry; + info.column_index = COLUMN_IDENTIFIER_ROW_ID; + info.segment = nullptr; + info.vector_index = 0; + info.prev = UndoBufferPointer(); + info.next = UndoBufferPointer(); + info.N = 0; + info.version_number = transaction_id; +} + void UpdateInfo::Initialize(UpdateInfo &info, DuckTableEntry &table_entry, transaction_t transaction_id, idx_t row_group_start) { info.max = STANDARD_VECTOR_SIZE; @@ -410,104 +438,131 @@ void UpdateSegment::FetchCommittedRange(idx_t start_row, idx_t count, Vector &re } //===--------------------------------------------------------------------===// -// Fetch Row +// Fetch Rows //===--------------------------------------------------------------------===// -static void FetchRowValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, idx_t row_idx, - Vector &result, idx_t result_idx) { +static bool FindUpdatedTuple(UpdateInfo ¤t, idx_t row_idx, idx_t &update_idx) { + auto tuples = current.GetTuples(); + idx_t left = 0; + idx_t right = current.N; + while (left < right) { + idx_t mid = left + (right - left) / 2; + if (tuples[mid] == row_idx) { + update_idx = mid; + return true; + } else if (tuples[mid] < 
row_idx) { + left = mid + 1; + } else { + right = mid; + } + } + return false; +} + +static void FetchRowsValidity(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_offset, idx_t count, + idx_t vector_offset, Vector &result, idx_t result_offset) { auto &result_mask = FlatVector::ValidityMutable(result); UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo ¤t) { auto info_data = current.GetData(); - auto tuples = current.GetTuples(); - if (current.N == 0) { - return; - } - idx_t left = 0; - idx_t right = current.N; - while (left < right) { - idx_t mid = left + (right - left) / 2; - if (tuples[mid] == row_idx) { - result_mask.Set(result_idx, info_data[mid]); - return; - } else if (tuples[mid] < row_idx) { - left = mid + 1; - } else { - right = mid; + for (idx_t idx = 0; idx < count; idx++) { + const idx_t row_idx = offsets[sel.get_index(fetch_offset + idx)] - vector_offset; + idx_t update_idx; + if (FindUpdatedTuple(current, row_idx, update_idx)) { + result_mask.Set(result_offset + fetch_offset + idx, info_data[update_idx]); } } }); } template -static void TemplatedFetchRow(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, idx_t row_idx, - Vector &result, idx_t result_idx) { +static void TemplatedFetchRows(transaction_t start_time, transaction_t transaction_id, UpdateInfo &info, + const idx_t *offsets, const SelectionVector &sel, idx_t fetch_offset, idx_t count, + idx_t vector_offset, Vector &result, idx_t result_offset) { auto result_data = FlatVector::GetDataMutable(result); UpdateInfo::UpdatesForTransaction(info, start_time, transaction_id, [&](UpdateInfo ¤t) { auto info_data = current.GetData(); - auto tuples = current.GetTuples(); - // FIXME: we could do a binary search in here - for (idx_t i = 0; i < current.N; i++) { - if (tuples[i] == row_idx) { - result_data[result_idx] = info_data[i]; - break; - } else if (tuples[i] > 
row_idx) { - break; + for (idx_t idx = 0; idx < count; idx++) { + const idx_t row_idx = offsets[sel.get_index(fetch_offset + idx)] - vector_offset; + idx_t update_idx; + if (FindUpdatedTuple(current, row_idx, update_idx)) { + result_data[result_offset + fetch_offset + idx] = info_data[update_idx]; } } }); } -static UpdateSegment::fetch_row_function_t GetFetchRowFunction(PhysicalType type) { +static UpdateSegment::fetch_rows_function_t GetFetchRowsFunction(PhysicalType type) { switch (type) { case PhysicalType::BIT: - return FetchRowValidity; + return FetchRowsValidity; case PhysicalType::BOOL: case PhysicalType::INT8: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::INT16: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::INT32: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::INT64: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::UINT8: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::UINT16: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::UINT32: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::UINT64: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::INT128: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::UINT128: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::FLOAT: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::DOUBLE: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::INTERVAL: - return TemplatedFetchRow; + return TemplatedFetchRows; case PhysicalType::VARCHAR: - return TemplatedFetchRow; + return TemplatedFetchRows; default: - throw NotImplementedException("Unimplemented type for update segment fetch row"); + throw NotImplementedException("Unimplemented type for update segment fetch rows"); } } -void 
UpdateSegment::FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx) { - if (row_id > column_data.count) { - throw InternalException("UpdateSegment::FetchRow out of range"); +void UpdateSegment::FetchRows(TransactionData transaction, const idx_t *offsets, const SelectionVector &sel, + idx_t fetch_count, Vector &result, idx_t result_offset) { + if (fetch_count == 0 || !root) { + return; } - idx_t vector_index = row_id / STANDARD_VECTOR_SIZE; + auto lock_handle = lock.GetSharedLock(); - auto entry = GetUpdateNode(*lock_handle, vector_index); - if (!entry.IsSet()) { - return; + for (idx_t idx = 0; idx < fetch_count;) { + const idx_t offset = offsets[sel.get_index(idx)]; + if (offset > column_data.count) { + throw InternalException("UpdateSegment::FetchRows out of range"); + } + const idx_t vector_index = offset / STANDARD_VECTOR_SIZE; + const idx_t vector_offset = vector_index * STANDARD_VECTOR_SIZE; + idx_t vector_count = 1; + while (idx + vector_count < fetch_count) { + const idx_t next_offset = offsets[sel.get_index(idx + vector_count)]; + if (next_offset > column_data.count) { + throw InternalException("UpdateSegment::FetchRows out of range"); + } + if (next_offset / STANDARD_VECTOR_SIZE != vector_index) { + break; + } + vector_count++; + } + + auto entry = GetUpdateNode(*lock_handle, vector_index); + if (entry.IsSet()) { + auto pin = entry.Pin(); + fetch_rows_function(transaction.start_time, transaction.transaction_id, UpdateInfo::Get(pin), offsets, sel, + idx, vector_count, vector_offset, result, result_offset); + } + idx += vector_count; } - idx_t row_in_vector = row_id - vector_index * STANDARD_VECTOR_SIZE; - auto pin = entry.Pin(); - fetch_row_function(transaction.start_time, transaction.transaction_id, UpdateInfo::Get(pin), row_in_vector, result, - result_idx); } //===--------------------------------------------------------------------===// @@ -709,6 +764,16 @@ string_t UpdateSelectElement::Operation(UpdateSegment &segment, 
string_t element return segment.GetStringHeap().AddBlob(element); } +template +static T GetUpdateNullPadding() { + return T(); +} + +template <> +string_t GetUpdateNullPadding() { + return string_t(uint32_t(0)); +} + template static void InitializeUpdateData(UpdateInfo &base_info, Vector &base_data, UpdateInfo &update_info, UnifiedVectorFormat &update, const SelectionVector &sel) { @@ -727,6 +792,7 @@ static void InitializeUpdateData(UpdateInfo &base_info, Vector &base_data, Updat for (idx_t i = 0; i < base_info.N; i++) { auto base_idx = base_tuples[i]; if (!base_validity.RowIsValid(base_idx)) { + base_tuple_data[i] = GetUpdateNullPadding(); continue; } base_tuple_data[i] = UpdateSelectElement::Operation(*base_info.segment, base_array_data[base_idx]); @@ -827,7 +893,8 @@ struct ExtractValidityEntry { template static void MergeUpdateLoopInternal(UpdateInfo &base_info, V *base_table_data, UpdateInfo &update_info, const SelectionVector &update_vector_sel, const V *update_vector_data, row_t *ids, - idx_t count, const SelectionVector &sel, idx_t row_group_start) { + idx_t count, const SelectionVector &sel, idx_t row_group_start, + const ValidityMask *base_table_validity = nullptr) { auto base_id = row_group_start + base_info.vector_index * STANDARD_VECTOR_SIZE; #ifdef DEBUG // all of these should be sorted, otherwise the below algorithm does not work @@ -888,6 +955,8 @@ static void MergeUpdateLoopInternal(UpdateInfo &base_info, V *base_table_data, U if (base_info_offset < base_info.N && base_tuples[base_info_offset] == update_id) { // it is! we have to move the tuple from base_info->ids[base_info_offset] to update_info result_values[result_offset] = base_info_data[base_info_offset]; + } else if (base_table_validity && !base_table_validity->RowIsValid(update_id)) { + result_values[result_offset] = GetUpdateNullPadding(); } else { // it is not! 
we have to move base_table_data[update_id] to update_info result_values[result_offset] = UpdateSelectElement::Operation( @@ -945,8 +1014,9 @@ static void MergeUpdateLoop(UpdateInfo &base_info, Vector &base_data, UpdateInfo idx_t row_group_start) { auto base_table_data = FlatVector::GetDataMutable(base_data); auto update_vector_data = update.GetData(update); + auto &base_validity = FlatVector::Validity(base_data); MergeUpdateLoopInternal(base_info, base_table_data, update_info, *update.sel, update_vector_data, ids, count, - sel, row_group_start); + sel, row_group_start, &base_validity); } static UpdateSegment::merge_update_function_t GetMergeUpdateFunction(PhysicalType type) { @@ -1050,17 +1120,19 @@ idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, U SelectionVector &sel) { auto update_data = update.GetData(update); auto &mask = update.validity; + + StatsWriter stats_writer(stats.statistics.GetType()); + idx_t not_null_count = 0; if (mask.CannotHaveNull()) { stats.statistics.SetHasNoNullFast(); for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); auto &str = update_data[idx]; - StringStats::Update(stats.statistics, str); + stats_writer.Update(str); } sel.Initialize(nullptr); - return count; + not_null_count = count; } else { - idx_t not_null_count = 0; sel.Initialize(STANDARD_VECTOR_SIZE); for (idx_t i = 0; i < count; i++) { auto idx = update.sel->get_index(i); @@ -1068,7 +1140,7 @@ idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, U stats.statistics.SetHasNoNullFast(); sel.set_index(not_null_count++, i); auto &str = update_data[idx]; - StringStats::Update(stats.statistics, str); + stats_writer.Update(str); } else { stats.statistics.SetHasNullFast(); } @@ -1076,8 +1148,9 @@ idx_t UpdateStringStatistics(UpdateSegment *segment, SegmentStatistics &stats, U if (not_null_count == count) { sel.Initialize(nullptr); } - return not_null_count; } + stats_writer.Merge(stats.statistics); + return 
not_null_count; } UpdateSegment::statistics_update_function_t GetStatisticsUpdateFunction(PhysicalType type) { @@ -1278,6 +1351,39 @@ void UpdateSegment::InitializeUpdateInfo(idx_t vector_idx) { } } +void UpdateSegment::ReallocateRootInfoIfNeeded(UpdateInfo ¤t_info, idx_t update_count, idx_t vector_index) { + idx_t required_capacity = MinValue(idx_t(current_info.N) + update_count, STANDARD_VECTOR_SIZE); + if (required_capacity <= idx_t(current_info.max)) { + return; + } + idx_t new_capacity = UpdateInfo::GetCompactCapacity(required_capacity); + idx_t new_alloc_size = UpdateInfo::GetAllocSize(type_size, new_capacity); + auto new_handle = root->allocator.Allocate(new_alloc_size); + auto &new_info = UpdateInfo::Get(new_handle); + + new_info.segment = current_info.segment; + new_info.table = current_info.table; + new_info.column_index = current_info.column_index; + new_info.row_group_start = current_info.row_group_start; + new_info.version_number.store(current_info.version_number.load()); + new_info.vector_index = current_info.vector_index; + new_info.N = current_info.N; + new_info.max = UnsafeNumericCast(new_capacity); + new_info.prev = current_info.prev; + new_info.next = current_info.next; + + memcpy(new_info.GetTuples(), current_info.GetTuples(), sizeof(sel_t) * current_info.N); + memcpy(new_info.GetValues(), current_info.GetValues(), type_size * current_info.N); + + if (new_info.next.IsSet()) { + auto next_pin = new_info.next.Pin(); + auto &next_info = UpdateInfo::Get(next_pin); + next_info.prev = new_handle.GetBufferPointer(); + } + + root->info[vector_index] = new_handle.GetBufferPointer(); +} + void UpdateSegment::Update(TransactionData transaction, DuckTableEntry &table_entry, idx_t column_index, Vector &update_p, row_t *ids, idx_t count, Vector &base_data, idx_t row_group_start) { // obtain an exclusive lock @@ -1299,11 +1405,9 @@ void UpdateSegment::Update(TransactionData transaction, DuckTableEntry &table_en // for strings - we need to push all strings we 
are going to place here into the string heap of the segment update_p.Flatten(); auto update_data = FlatVector::GetDataMutable(update_p); - auto &validity = FlatVector::ValidityMutable(update_p); for (idx_t i = 0; i < count; i++) { - if (validity.RowIsValid(i)) { - update_data[i] = GetStringHeap().AddBlob(update_data[i]); - } + auto idx = sel.get_index(i); + update_data[idx] = GetStringHeap().AddBlob(update_data[idx]); } update_p.ToUnifiedFormat(update_format); } @@ -1341,11 +1445,15 @@ void UpdateSegment::Update(TransactionData transaction, DuckTableEntry &table_en // this transaction in the version chain auto root_pointer = root->info[vector_index]; auto root_pin = root_pointer.Pin(); - auto &base_info = UpdateInfo::Get(root_pin); UndoBufferReference node_ref; - CheckForConflicts(base_info.next, transaction, ids, sel, count, UnsafeNumericCast(vector_offset), - node_ref); + CheckForConflicts(UpdateInfo::Get(root_pin).next, transaction, ids, sel, count, + UnsafeNumericCast(vector_offset), node_ref); + + ReallocateRootInfoIfNeeded(UpdateInfo::Get(root_pin), count, vector_index); + root_pointer = root->info[vector_index]; + root_pin = root_pointer.Pin(); + auto &base_info = UpdateInfo::Get(root_pin); // there are no conflicts - continue with the update unsafe_unique_array update_info_data; @@ -1388,11 +1496,12 @@ void UpdateSegment::Update(TransactionData transaction, DuckTableEntry &table_en node->Verify(); } else { // there is no version info yet: create the top level update info and fill it with the updates - // allocate space for the UpdateInfo in the allocator - idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); + // allocate space for the UpdateInfo in the allocator (compact: sized to actual count) + idx_t compact_capacity = UpdateInfo::GetCompactCapacity(count); + idx_t alloc_size = UpdateInfo::GetAllocSize(type_size, compact_capacity); auto handle = root->allocator.Allocate(alloc_size); auto &update_info = UpdateInfo::Get(handle); - 
UpdateInfo::Initialize(update_info, table_entry, TRANSACTION_ID_START - 1, row_group_start); + UpdateInfo::Initialize(update_info, table_entry, TRANSACTION_ID_START - 1, row_group_start, compact_capacity); update_info.column_index = column_index; InitializeUpdateInfo(update_info, ids, sel, count, vector_index, vector_offset); diff --git a/src/duckdb/src/storage/table/validity_column_data.cpp b/src/duckdb/src/storage/table/validity_column_data.cpp index 50d36dc93..9ce649f90 100644 --- a/src/duckdb/src/storage/table/validity_column_data.cpp +++ b/src/duckdb/src/storage/table/validity_column_data.cpp @@ -36,10 +36,8 @@ void ValidityColumnData::UpdateWithBase(TransactionData transaction, DuckTableEn row_group_start); } -void ValidityColumnData::AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, - idx_t count) { - lock_guard l(stats_lock); - ColumnData::AppendData(stats, state, vdata, count); +void ValidityColumnData::AppendData(ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) { + ColumnData::AppendData(state, vdata, count); } struct ValidityColumnCheckpointState : public ColumnCheckpointState { diff --git a/src/duckdb/src/storage/table/variant/variant_shredding.cpp b/src/duckdb/src/storage/table/variant/variant_shredding.cpp index 75d0f23a8..8cd55ae33 100644 --- a/src/duckdb/src/storage/table/variant/variant_shredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_shredding.cpp @@ -490,7 +490,7 @@ LogicalType VariantShreddingStats::GetShreddedType() const { return LogicalType::STRUCT(child_types); } -void VariantShreddingStats::Update(Vector &input, idx_t count) { +void VariantShreddingStats::Update(const Vector &input, idx_t count) { RecursiveUnifiedVectorFormat recursive_format; Vector::RecursiveToUnifiedFormat(input, recursive_format); UnifiedVariantVectorData variant(recursive_format); @@ -694,7 +694,7 @@ void DuckDBVariantShredding::WriteVariantValues(UnifiedVariantVectorData &varian } } -void 
VariantColumnData::ShredVariantData(Vector &input, Vector &output, idx_t count) { +void VariantColumnData::ShredVariantData(const Vector &input, Vector &output, idx_t count) { RecursiveUnifiedVectorFormat recursive_format; Vector::RecursiveToUnifiedFormat(input, recursive_format); UnifiedVariantVectorData variant(recursive_format); diff --git a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp index 8cc8b7d6a..054c77222 100644 --- a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp @@ -114,7 +114,7 @@ static vector UnshredTypedObject(UnifiedVariantVectorData &variant continue; } auto &obj_value = res[i]; - obj_value.AddChild(child_name, std::move(values[i])); + obj_value.AddChild(child_name.GetIdentifierName(), std::move(values[i])); } } return res; diff --git a/src/duckdb/src/storage/table/variant_column_data.cpp b/src/duckdb/src/storage/table/variant_column_data.cpp index b0b00a74f..acb39da51 100644 --- a/src/duckdb/src/storage/table/variant_column_data.cpp +++ b/src/duckdb/src/storage/table/variant_column_data.cpp @@ -81,7 +81,7 @@ bool FindShreddedColumnInternal(const ColumnData &shredded, referenceAppend(stats, state.child_appends[0], vector, count); + validity->Append(state.child_appends[0], vector, count); if (IsShredded()) { throw InternalException("Can't append to a shredded VariantColumnData"); } else { for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->Append(VariantStats::GetUnshreddedStats(stats), state.child_appends[i + 1], vector, count); + sub_columns[i]->Append(state.child_appends[i + 1], vector, count); } - VariantStats::MarkAsNotShredded(stats); } this->count += count; } +void VariantColumnData::FinalizeAppend(ColumnDataFinalizeAppendState &finalize_state, ColumnAppendState &state) { + for (idx_t i = 0; i < sub_columns.size(); i++) { + ColumnDataFinalizeAppendState 
child_finalize_state(finalize_state, LogicalTypeId::VARIANT); + sub_columns[i]->FinalizeAppend(child_finalize_state, state.child_appends[i + 1]); + } + if (!IsShredded()) { + for (auto &stats_ref : finalize_state.global_stats) { + VariantStats::MarkAsNotShredded(stats_ref.get()); + } + } +} + void VariantColumnData::RevertAppend(row_t new_count) { validity->RevertAppend(new_count); for (idx_t i = 0; i < sub_columns.size(); i++) { @@ -443,75 +430,79 @@ unique_ptr VariantColumnData::GetUpdateStatistics() { return nullptr; } -void VariantColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, - const StorageIndex &storage_index, row_t row_id, Vector &result, idx_t result_idx) { +void VariantColumnData::FetchRows(TransactionData transaction, ColumnFetchState &state, + const StorageIndex &storage_index, const idx_t *offsets, const SelectionVector &sel, + idx_t fetch_count, Vector &result, idx_t result_offset) { if (storage_index.IsPushdownExtract() && IsShredded()) { StorageIndex struct_extract; if (PushdownShreddedFieldExtract(storage_index.GetChildIndex(0), struct_extract)) { //! Shredded field exists and is fully shredded, //! 
add the storage index to create a pushed-down 'struct_extract' to get the leaf - sub_columns[1]->FetchRow(transaction, state, struct_extract, row_id, result, result_idx); + sub_columns[1]->FetchRows(transaction, state, struct_extract, offsets, sel, fetch_count, result, + result_offset); return; } } - Vector variant_vec(LogicalType::VARIANT(), result_idx + 1); - validity->FetchRow(transaction, state, storage_index, row_id, variant_vec, result_idx); - if (IsShredded()) { - auto intermediate = CreateUnshreddingIntermediate(result_idx + 1); - auto &child_vectors = StructVector::GetEntries(intermediate); - // fetch the validity state - // fetch the sub-column states - StorageIndex empty(0); - for (idx_t i = 0; i < sub_columns.size(); i++) { - sub_columns[i]->FetchRow(transaction, state, empty, row_id, child_vectors[i], result_idx); - } - if (result_idx) { - intermediate.SetValue(0, intermediate.GetValue(result_idx)); + + for (idx_t idx = 0; idx < fetch_count; idx++) { + const idx_t offset = offsets[sel.get_index(idx)]; + const idx_t result_idx = result_offset + idx; + Vector variant_vec(LogicalType::VARIANT(), 1); + validity->FetchRowsAtSegmentLevel(transaction, state, &offset, *FlatVector::IncrementalSelectionVector(), + /*count=*/1, variant_vec, 0); + if (IsShredded()) { + auto intermediate = CreateUnshreddingIntermediate(1); + auto &child_vectors = StructVector::GetEntries(intermediate); + // fetch the validity state + // fetch the sub-column states + StorageIndex empty(0); + for (idx_t i = 0; i < sub_columns.size(); i++) { + sub_columns[i]->FetchRows(transaction, state, empty, &offset, *FlatVector::IncrementalSelectionVector(), + /*count=*/1, child_vectors[i], 0); + } + + //! FIXME: adjust UnshredVariantData so we can write the value in place directly. 
+ Vector unshredded(variant_vec.GetType(), 1); + VariantUtils::UnshredVariantData(intermediate, unshredded, 1); + variant_vec.SetValue(0, unshredded.GetValue(0)); + } else { + sub_columns[0]->FetchRows(transaction, state, storage_index, &offset, + *FlatVector::IncrementalSelectionVector(), /*count=*/1, variant_vec, 0); } - //! FIXME: adjust UnshredVariantData so we can write the value in place directly. - Vector unshredded(variant_vec.GetType(), 1); - VariantUtils::UnshredVariantData(intermediate, unshredded, 1); - variant_vec.SetValue(0, unshredded.GetValue(0)); - } else { - sub_columns[0]->FetchRow(transaction, state, storage_index, row_id, variant_vec, result_idx); - if (result_idx) { - variant_vec.SetValue(0, variant_vec.GetValue(result_idx)); + if (!storage_index.IsPushdownExtract()) { + //! No extract required + D_ASSERT(result.GetType().id() == LogicalTypeId::VARIANT); + result.SetValue(result_idx, variant_vec.GetValue(0)); + continue; } - } - if (!storage_index.IsPushdownExtract()) { - //! No extract required - D_ASSERT(result.GetType().id() == LogicalTypeId::VARIANT); - result.SetValue(result_idx, variant_vec.GetValue(0)); - return; - } + Vector extracted_variant(variant_vec.GetType(), 1); + vector components; + reference path_iter(storage_index.GetChildIndex(0)); - Vector extracted_variant(variant_vec.GetType(), 1); - vector components; - reference path_iter(storage_index.GetChildIndex(0)); + while (true) { + auto ¤t = path_iter.get(); + auto &field_name = current.GetFieldName(); + components.emplace_back(field_name); + if (!current.HasChildren()) { + break; + } + path_iter = current.GetChildIndex(0); + } + VariantUtils::VariantExtract(variant_vec, components, extracted_variant, 1); - while (true) { - auto ¤t = path_iter.get(); - auto &field_name = current.GetFieldName(); - components.emplace_back(field_name); - if (!current.HasChildren()) { - break; + if (result.GetType().id() == LogicalTypeId::VARIANT) { + //! 
No cast required + result.SetValue(result_idx, extracted_variant.GetValue(0)); + continue; } - path_iter = current.GetChildIndex(0); - } - VariantUtils::VariantExtract(variant_vec, components, extracted_variant, 1); - if (result.GetType().id() == LogicalTypeId::VARIANT) { - //! No cast required - result.SetValue(result_idx, extracted_variant.GetValue(0)); - return; + //! Need to perform the cast here as well + auto context = transaction.transaction->context.lock(); + auto fetched_row = extracted_variant.GetValue(0).CastAs(*context, result.GetType()); + result.SetValue(result_idx, fetched_row); } - - //! Need to perform the cast here as well - auto context = transaction.transaction->context.lock(); - auto fetched_row = extracted_variant.GetValue(0).CastAs(*context, result.GetType()); - result.SetValue(result_idx, fetched_row); } void VariantColumnData::VisitBlockIds(BlockIdVisitor &visitor) const { @@ -606,6 +597,42 @@ unique_ptr VariantColumnData::CreateCheckpointState(const return make_uniq(row_group, *this, partial_block_manager); } +static void AppendShredded(Vector &input, Vector &append_vector, idx_t count, VariantShreddedAppendInput &append_data) { + D_ASSERT(append_vector.GetType().id() == LogicalTypeId::STRUCT); + auto &child_vectors = StructVector::GetEntries(append_vector); + D_ASSERT(child_vectors.size() == 2); + + //! 
Create the new column data for the shredded data + VariantColumnData::ShredVariantData(input, append_vector, count); + auto &unshredded_vector = child_vectors[0]; + auto &shredded_vector = child_vectors[1]; + + auto &unshredded = append_data.unshredded; + auto &shredded = append_data.shredded; + + auto &unshredded_append_state = append_data.unshredded_append_state; + auto &shredded_append_state = append_data.shredded_append_state; + + unshredded.Append(unshredded_append_state, unshredded_vector, count); + shredded.Append(shredded_append_state, shredded_vector, count); +} + +static void FinalizeShreddedAppend(VariantShreddedAppendInput &append_data) { + auto &unshredded_stats = append_data.unshredded_stats; + auto &shredded_stats = append_data.shredded_stats; + + auto &unshredded = append_data.unshredded; + auto &shredded = append_data.shredded; + + auto &unshredded_append_state = append_data.unshredded_append_state; + auto &shredded_append_state = append_data.shredded_append_state; + + ColumnDataFinalizeAppendState unshredded_finalize_state(unshredded_stats); + unshredded.FinalizeAppend(unshredded_finalize_state, unshredded_append_state); + ColumnDataFinalizeAppendState shredded_finalize_state(shredded_stats); + shredded.FinalizeAppend(shredded_finalize_state, shredded_append_state); +} + vector> VariantColumnData::WriteShreddedData(const RowGroup &row_group, const LogicalType &shredded_type, BaseStatistics &stats) { @@ -659,6 +686,8 @@ vector> VariantColumnData::WriteShreddedData(const RowGro AppendShredded(scan_vector, append_vector, to_scan, append_data); } + FinalizeShreddedAppend(append_data); + stats = std::move(transformed_stats); return ret; } @@ -819,12 +848,13 @@ void VariantColumnData::InitializeColumn(PersistentColumnData &column_data, Base } void VariantColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, - vector &result) { + vector &result, + const ColumnSegmentInfoScanOptions &options) { 
col_path.push_back(0); - validity->GetColumnSegmentInfo(context, row_group_index, col_path, result); + validity->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); for (idx_t i = 0; i < sub_columns.size(); i++) { col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result, options); } } diff --git a/src/duckdb/src/storage/table_index_list.cpp b/src/duckdb/src/storage/table_index_list.cpp index 0d1a0a3d9..d6e3d8119 100644 --- a/src/duckdb/src/storage/table_index_list.cpp +++ b/src/duckdb/src/storage/table_index_list.cpp @@ -95,7 +95,7 @@ void TableIndexList::AddIndex(unique_ptr index) { } } -void TableIndexList::RemoveIndex(const string &name) { +void TableIndexList::RemoveIndex(const Identifier &name) { lock_guard lock(index_entries_lock); for (idx_t i = 0; i < index_entries.size(); i++) { auto &index = *index_entries[i]->index; @@ -133,7 +133,7 @@ bool TableIndexList::NameIsUnique(const string &name) { return true; } -optional_ptr TableIndexList::Find(const string &name) { +optional_ptr TableIndexList::Find(const Identifier &name) { for (auto &entry : index_entries) { auto &index = *entry->index; if (index.GetIndexName() == name) { @@ -166,7 +166,7 @@ void TableIndexList::Bind(ClientContext &context, DataTableInfo &table_info, con vector column_names; for (auto &col : table.GetColumns().Logical()) { column_types.push_back(col.Type()); - column_names.push_back(col.Name()); + column_names.emplace_back(col.Name()); } unique_lock lock(index_entries_lock); @@ -205,7 +205,8 @@ void TableIndexList::Bind(ClientContext &context, DataTableInfo &table_info, con // Add the table to the binder. 
vector dummy_column_ids; - binder->bind_context.AddBaseTable(TableIndex(0), string(), column_names, column_types, dummy_column_ids, table); + binder->bind_context.AddBaseTable(TableIndex(0), Identifier(), StringsToIdentifiers(column_names), column_types, + dummy_column_ids, table); // Create an IndexBinder to bind the index IndexBinder idx_binder(*binder, context); @@ -414,7 +415,6 @@ void TableIndexList::ReferenceIndexChunk(DataChunk &table_chunk, DataChunk &inde auto col_id = mapped_column_ids[i].GetPrimaryIndex(); index_chunk.data[i].Reference(table_chunk.data[col_id]); } - index_chunk.SetCardinality(table_chunk); } } // namespace duckdb diff --git a/src/duckdb/src/storage/temporary_file_manager.cpp b/src/duckdb/src/storage/temporary_file_manager.cpp index 6a8445da6..3e002088d 100644 --- a/src/duckdb/src/storage/temporary_file_manager.cpp +++ b/src/duckdb/src/storage/temporary_file_manager.cpp @@ -6,6 +6,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/common/encryption_functions.hpp" +#include "duckdb/storage/storage_options.hpp" #include "zstd.h" namespace duckdb { @@ -142,6 +143,10 @@ idx_t BlockIndexManager::GetMaxIndex() const { return max_index; } +idx_t BlockIndexManager::GetUsedBlockCount() const { + return indexes_in_use.size(); +} + bool BlockIndexManager::HasFreeBlocks() const { return !free_indexes.empty(); } @@ -197,7 +202,7 @@ TemporaryFileHandle::TemporaryFileLock::TemporaryFileLock(mutex &mutex) : lock(m TemporaryFileIndex TemporaryFileHandle::TryGetBlockIndex(idx_t block_header_size) { TemporaryFileLock lock(file_lock); - if (index_manager.GetMaxIndex() >= max_allowed_index && index_manager.HasFreeBlocks()) { + if (index_manager.GetMaxIndex() >= max_allowed_index && !index_manager.HasFreeBlocks()) { // file is at capacity return TemporaryFileIndex(); } @@ -318,7 +323,7 @@ TemporaryFileInformation TemporaryFileHandle::GetTemporaryFile() { TemporaryFileLock lock(file_lock); TemporaryFileInformation 
info; info.path = path; - info.size = GetPositionInFile(index_manager.GetMaxIndex()); + info.size = GetPositionInFile(index_manager.GetUsedBlockCount()); return info; } @@ -328,7 +333,8 @@ void TemporaryFileHandle::CreateFileIfNotExists(TemporaryFileLock &) { } auto &fs = FileSystem::GetFileSystem(db); auto open_flags = FileFlags::FILE_FLAGS_READ | FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE; - if (db.config.options.use_direct_io) { + // Temp files have no per-file IO_MODE; they follow the global default_io_mode setting. + if (Settings::Get(db) == FileIOMode::DIRECT_IO) { open_flags |= FileFlags::FILE_FLAGS_DIRECT_IO; } handle = fs.OpenFile(path, open_flags); diff --git a/src/duckdb/src/storage/wal_replay.cpp b/src/duckdb/src/storage/wal_replay.cpp index 833885c45..c443b0571 100644 --- a/src/duckdb/src/storage/wal_replay.cpp +++ b/src/duckdb/src/storage/wal_replay.cpp @@ -57,14 +57,14 @@ class ReplayState { WALReplayState replay_state; struct ReplayIndexInfo { - ReplayIndexInfo(TableIndexList &index_list, unique_ptr index, const string &table_schema, - const string &table_name) + ReplayIndexInfo(TableIndexList &index_list, unique_ptr index, const Identifier &table_schema, + const Identifier &table_name) : index_list(index_list), index(std::move(index)), table_schema(table_schema), table_name(table_name) {}; reference index_list; unique_ptr index; - string table_schema; - string table_name; + Identifier table_schema; + Identifier table_name; }; vector replay_index_infos; }; @@ -436,7 +436,7 @@ unique_ptr WriteAheadLogReplayer::ReplayLog(unique_ptrclient_data->profiler; - profiler.AddToCounter(MetricType::WAL_REPLAY_ENTRY_COUNT, replay_entry_count); + profiler.AddToMetricCounter(MetricStorageWALReplayEntryCount::Name, replay_entry_count); } } catch (std::exception &ex) { // LCOV_EXCL_START ErrorData error(ex); @@ -733,8 +733,8 @@ void WriteAheadLogDeserializer::ReplayDropTable() { DropInfo info; info.type = CatalogType::TABLE_ENTRY; - info.schema 
= deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -815,12 +815,13 @@ void WriteAheadLogDeserializer::ReplayAlter() { vector column_names; for (auto &col : column_list.Logical()) { column_types.push_back(col.Type()); - column_names.push_back(col.Name()); + column_names.emplace_back(col.Name()); } // Create a binder to bind the parsed expressions. vector column_indexes; - binder->bind_context.AddBaseTable(TableIndex(0), string(), column_names, column_types, column_indexes, table); + binder->bind_context.AddBaseTable(TableIndex(0), Identifier(), StringsToIdentifiers(column_names), column_types, + column_indexes, table); IndexBinder idx_binder(*binder, context); // Bind the parsed expressions to create unbound expressions. @@ -866,8 +867,8 @@ void WriteAheadLogDeserializer::ReplayCreateView() { void WriteAheadLogDeserializer::ReplayDropView() { DropInfo info; info.type = CatalogType::VIEW_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -879,7 +880,7 @@ void WriteAheadLogDeserializer::ReplayDropView() { //===--------------------------------------------------------------------===// void WriteAheadLogDeserializer::ReplayCreateSchema() { CreateSchemaInfo info; - info.schema = deserializer.ReadProperty(101, "schema"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); if (DeserializeOnly()) { return; } @@ -891,7 +892,7 @@ void WriteAheadLogDeserializer::ReplayDropSchema() { DropInfo info; info.type = CatalogType::SCHEMA_ENTRY; - info.name = deserializer.ReadProperty(101, "schema"); + info.name = 
Identifier(deserializer.ReadProperty(101, "schema")); if (DeserializeOnly()) { return; } @@ -912,8 +913,8 @@ void WriteAheadLogDeserializer::ReplayDropType() { DropInfo info; info.type = CatalogType::TYPE_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -941,9 +942,9 @@ void WriteAheadLogDeserializer::ReplayCreateTrigger() { void WriteAheadLogDeserializer::ReplayDropTrigger() { DropInfo info; info.type = CatalogType::TRIGGER_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); - auto table_name = deserializer.ReadPropertyWithDefault(103, "table"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); + auto table_name = deserializer.ReadPropertyWithDefault(103, "table"); if (DeserializeOnly()) { return; } @@ -971,8 +972,8 @@ void WriteAheadLogDeserializer::ReplayCreateSequence() { void WriteAheadLogDeserializer::ReplayDropSequence() { DropInfo info; info.type = CatalogType::SEQUENCE_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -992,7 +993,7 @@ void WriteAheadLogDeserializer::ReplaySequenceValue() { } // fetch the sequence from the catalog - auto &seq = catalog.GetEntry(context, schema, name); + auto &seq = catalog.GetEntry(context, Identifier(schema), Identifier(name)); seq.ReplayValue(usage_count, counter, last_value); } @@ -1011,8 +1012,8 @@ void WriteAheadLogDeserializer::ReplayCreateMacro() { void 
WriteAheadLogDeserializer::ReplayDropMacro() { DropInfo info; info.type = CatalogType::MACRO_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -1034,8 +1035,8 @@ void WriteAheadLogDeserializer::ReplayCreateTableMacro() { void WriteAheadLogDeserializer::ReplayDropTableMacro() { DropInfo info; info.type = CatalogType::TABLE_MACRO_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -1082,8 +1083,8 @@ void WriteAheadLogDeserializer::ReplayCreateIndex() { void WriteAheadLogDeserializer::ReplayDropIndex() { DropInfo info; info.type = CatalogType::INDEX_ENTRY; - info.schema = deserializer.ReadProperty(101, "schema"); - info.name = deserializer.ReadProperty(102, "name"); + info.schema = Identifier(deserializer.ReadProperty(101, "schema")); + info.name = Identifier(deserializer.ReadProperty(102, "name")); if (DeserializeOnly()) { return; } @@ -1103,8 +1104,8 @@ void WriteAheadLogDeserializer::ReplayDropIndex() { // Replay Data //===--------------------------------------------------------------------===// void WriteAheadLogDeserializer::ReplayUseTable() { - auto schema_name = deserializer.ReadProperty(101, "schema"); - auto table_name = deserializer.ReadProperty(102, "table"); + auto schema_name = deserializer.ReadProperty(101, "schema"); + auto table_name = deserializer.ReadProperty(102, "table"); if (DeserializeOnly()) { return; } @@ -1150,7 +1151,7 @@ void WriteAheadLogDeserializer::ReplayRowGroupData() { } auto &storage = state.current_table->GetStorage(); auto &table_info = 
storage.GetDataTableInfo(); - auto base_row = storage.GetTotalRows(); + auto base_row = storage.GetNextRowId(); RowGroupCollection new_row_groups(table_info, table_info->GetIOManager(), storage.GetTypes(), base_row); new_row_groups.Initialize(data); @@ -1164,11 +1165,11 @@ void WriteAheadLogDeserializer::ReplayRowGroupData() { column_ids.emplace_back(col.StorageOid()); } Vector row_id_vector(LogicalType::ROW_TYPE, STANDARD_VECTOR_SIZE); - auto row_ids = FlatVector::GetDataMutable(row_id_vector); - auto current_row_id = storage.GetTotalRows(); + auto current_row_id = storage.GetNextRowId(); for (auto &chunk : new_row_groups.Chunks(transaction, column_ids)) { + auto row_id_writer = FlatVector::Writer(row_id_vector, chunk.size()); for (idx_t r = 0; r < chunk.size(); r++) { - row_ids[r] = NumericCast(current_row_id + r); + row_id_writer.WriteValue(NumericCast(current_row_id + r)); } current_row_id += chunk.size(); for (auto &index : indexes.Indexes()) { @@ -1202,9 +1203,9 @@ void WriteAheadLogDeserializer::ReplayDelete() { // Delete the row IDs from the current table. 
auto &storage = state.current_table->GetStorage(); - auto total_rows = storage.GetTotalRows(); + auto next_row_id = storage.GetNextRowId(); for (idx_t i = 0; i < chunk.size(); i++) { - if (source_ids[i] >= UnsafeNumericCast(total_rows)) { + if (source_ids[i] >= UnsafeNumericCast(next_row_id)) { throw SerializationException("invalid row ID delete in WAL"); } } diff --git a/src/duckdb/src/storage/write_ahead_log.cpp b/src/duckdb/src/storage/write_ahead_log.cpp index 49c4e232f..ee3aa483e 100644 --- a/src/duckdb/src/storage/write_ahead_log.cpp +++ b/src/duckdb/src/storage/write_ahead_log.cpp @@ -262,7 +262,8 @@ void WriteAheadLog::WriteHeader() { auto &single_file_block_manager = database.GetStorageManager().GetBlockManager().Cast(); auto file_version_number = single_file_block_manager.GetVersionNumber(); - if (file_version_number > 66) { + // double check + if (StorageManager::TargetAtLeastVersion(StorageVersion::V1_3_0, file_version_number)) { auto db_identifier = single_file_block_manager.GetDBIdentifier(); serializer.WriteList(102, "db_identifier", MainHeader::DB_IDENTIFIER_LEN, [&](Serializer::List &list, idx_t i) { list.WriteElement(db_identifier[i]); }); @@ -336,7 +337,7 @@ void WriteAheadLog::WriteSequenceValue(SequenceValue val) { serializer.WriteProperty(103, "usage_count", val.usage_count); serializer.WriteProperty(104, "counter", val.counter); // we only support writing last_value from version 2.0.0 onwards - if (storage_manager.GetStorageVersion() >= 8) { + if (StorageManager::TargetAtLeastVersion(StorageVersion::V2_0_0, storage_manager.GetStorageVersion())) { serializer.WriteProperty(105, "last_value", val.entry->GetData().last_value); } serializer.End(); @@ -376,10 +377,11 @@ void WriteAheadLog::WriteDropTableMacro(const TableMacroCatalogEntry &entry) { //===--------------------------------------------------------------------===// void SerializeIndex(AttachedDatabase &db, WriteAheadLogSerializer &serializer, TableIndexList &list, - const string &name) { 
+ const Identifier &name) { case_insensitive_map_t options; auto storage_version = db.GetStorageManager().GetStorageVersion(); - auto v1_0_0_storage = storage_version < 3; + // Before: serialization version 3 + auto v1_0_0_storage = StorageManager::IsPriorToVersion(StorageVersion::V1_2_0, storage_version); if (!v1_0_0_storage) { options["v1_0_0_storage"] = v1_0_0_storage; } @@ -481,7 +483,7 @@ void WriteAheadLog::WriteDropSchema(const SchemaCatalogEntry &entry) { //===--------------------------------------------------------------------===// // DATA //===--------------------------------------------------------------------===// -void WriteAheadLog::WriteSetTable(const string &schema, const string &table) { +void WriteAheadLog::WriteSetTable(const Identifier &schema, const Identifier &table) { WriteAheadLogSerializer serializer(*this, WALType::USE_TABLE); serializer.WriteProperty(101, "schema", schema); serializer.WriteProperty(102, "table", table); diff --git a/src/duckdb/src/transaction/commit_state.cpp b/src/duckdb/src/transaction/commit_state.cpp index 9584393b8..16655ab5e 100644 --- a/src/duckdb/src/transaction/commit_state.cpp +++ b/src/duckdb/src/transaction/commit_state.cpp @@ -35,7 +35,7 @@ void CommitDropState::DropBlock(block_id_t block_id) { dropped_block_ids.push_back(block_id); } -void CommitDropState::RemoveIndex(TableIndexList &indexes, string name) { +void CommitDropState::RemoveIndex(TableIndexList &indexes, Identifier name) { pending_index_removals.push_back(PendingIndexRemoval {indexes, std::move(name)}); } diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp index a7a0c0a88..cad22b173 100644 --- a/src/duckdb/src/transaction/duck_transaction.cpp +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -17,6 +17,7 @@ #include "duckdb/main/config.hpp" #include "duckdb/storage/table/column_data.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/main/query_profiler.hpp" #include 
"duckdb/main/attached_database.hpp" #include "duckdb/storage/storage_lock.hpp" #include "duckdb/storage/table/data_table_info.hpp" @@ -213,11 +214,11 @@ ErrorData DuckTransaction::WriteToWAL(ClientContext &context, AttachedDatabase & commit_state = storage_manager.GenStorageCommitState(*wal); auto &profiler = *context.client_data->profiler; - - auto commit_timer = profiler.StartTimer(MetricType::COMMIT_LOCAL_STORAGE_LATENCY); + auto commit_timer = profiler.StartTimer(); storage->Commit(commit_state.get()); + commit_timer.EndTimer(); - auto wal_timer = profiler.StartTimer(MetricType::WRITE_TO_WAL_LATENCY); + auto wal_timer = profiler.StartTimer(); undo_buffer.WriteToWAL(*wal, commit_state.get()); if (commit_state->HasRowGroupData()) { // if we have optimistically written any data AND we are writing to the WAL, we have written references to @@ -225,6 +226,8 @@ ErrorData DuckTransaction::WriteToWAL(ClientContext &context, AttachedDatabase & // hence we need to ensure those optimistically written blocks are persisted storage_manager.GetBlockManager().FileSync(); } + wal_timer.EndTimer(); + } catch (std::exception &ex) { // Call RevertCommit() outside this try-catch as it itself may throw error_data = ErrorData(ex); diff --git a/src/duckdb/src/transaction/meta_transaction.cpp b/src/duckdb/src/transaction/meta_transaction.cpp index cad81b349..cbf448e96 100644 --- a/src/duckdb/src/transaction/meta_transaction.cpp +++ b/src/duckdb/src/transaction/meta_transaction.cpp @@ -194,7 +194,7 @@ void MetaTransaction::SetActiveQuery(transaction_t query_number) { } } -optional_ptr MetaTransaction::GetReferencedDatabase(const string &name) { +optional_ptr MetaTransaction::GetReferencedDatabase(const Identifier &name) { lock_guard guard(referenced_database_lock); auto entry = used_databases.find(name); if (entry != used_databases.end()) { @@ -203,10 +203,10 @@ optional_ptr MetaTransaction::GetReferencedDatabase(const stri return nullptr; } -shared_ptr 
MetaTransaction::GetReferencedDatabaseOwning(const string &name) { +shared_ptr MetaTransaction::GetReferencedDatabaseOwning(const Identifier &name) { lock_guard guard(referenced_database_lock); for (auto &entry : referenced_databases) { - if (StringUtil::CIEquals(entry.first.get().name, name)) { + if (entry.first.get().name == name) { return entry.second; } } diff --git a/src/duckdb/src/transaction/transaction_context.cpp b/src/duckdb/src/transaction/transaction_context.cpp index 9c5c3c01e..33884d6f0 100644 --- a/src/duckdb/src/transaction/transaction_context.cpp +++ b/src/duckdb/src/transaction/transaction_context.cpp @@ -7,6 +7,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/transaction/meta_transaction.hpp" #include "duckdb/main/attached_database.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -43,6 +44,14 @@ void TransactionContext::BeginTransaction() { } } +void TransactionContext::SetInvalidationPolicy(TransactionInvalidationPolicy new_invalidation_policy) { + if (new_invalidation_policy == TransactionInvalidationPolicy::STANDARD_POLICY) { + // if no policy is specified explicitly use the default one from the settings + new_invalidation_policy = Settings::Get(context); + } + invalidation_policy = new_invalidation_policy; +} + void TransactionContext::Commit() { if (!current_transaction) { throw TransactionException("failed to commit: no transaction active"); diff --git a/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h b/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h index 904275dfa..97d827867 100644 --- a/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h +++ b/src/duckdb/third_party/concurrentqueue/lightweightsemaphore.h @@ -9,6 +9,16 @@ #include #include // For std::make_signed +// VS2012 doesn't support deleted functions. +// In this case, we declare the function normally but don't define it. 
+#ifndef MOODYCAMEL_DELETE_FUNCTION +#if defined(_MSC_VER) && _MSC_VER < 1800 +#define MOODYCAMEL_DELETE_FUNCTION +#else +#define MOODYCAMEL_DELETE_FUNCTION = delete +#endif +#endif + #if defined(_WIN32) // Avoid including windows.h in a header; we only need a handful of // items, so we'll redeclare them here (this is relatively safe since diff --git a/src/duckdb/third_party/fmt/include/fmt/format.h b/src/duckdb/third_party/fmt/include/fmt/format.h index 4c5163010..b5ab10c1e 100644 --- a/src/duckdb/third_party/fmt/include/fmt/format.h +++ b/src/duckdb/third_party/fmt/include/fmt/format.h @@ -321,24 +321,13 @@ inline typename Container::value_type* get_data(Container& c) { return c.data(); } -#ifdef _SECURE_SCL -// Make a checked iterator to avoid MSVC warnings. -template using checked_ptr = stdext::checked_array_iterator; -template checked_ptr make_checked(T* p, std::size_t size) { - return {p, size}; -} -#else -template using checked_ptr = T*; -template inline T* make_checked(T* p, std::size_t) { return p; } -#endif - template ::value)> -inline checked_ptr reserve( +inline typename Container::value_type* reserve( std::back_insert_iterator& it, std::size_t n) { Container& c = get_container(it); std::size_t size = c.size(); c.resize(size + n); - return make_checked(get_data(c) + size, n); + return get_data(c) + size; } template @@ -540,7 +529,7 @@ template void buffer::append(const U* begin, const U* end) { std::size_t new_size = size_ + to_unsigned(end - begin); reserve(new_size); - std::uninitialized_copy(begin, end, make_checked(ptr_, capacity_) + size_); + std::uninitialized_copy(begin, end, ptr_ + size_); size_ = new_size; } } // namespace internal @@ -642,7 +631,7 @@ class basic_memory_buffer : private Allocator, public internal::buffer { if (data == other.store_) { this->set(store_, capacity); std::uninitialized_copy(other.store_, other.store_ + size, - internal::make_checked(store_, capacity)); + store_); } else { this->set(data, capacity); // Set pointer to 
the inline array so that delete is not called @@ -689,7 +678,7 @@ void basic_memory_buffer::grow(std::size_t size) { T* new_data = std::allocator_traits::allocate(*this, new_capacity); // The following code doesn't throw, so the raw pointer above doesn't leak. std::uninitialized_copy(old_data, old_data + this->size(), - internal::make_checked(new_data, new_capacity)); + new_data); this->set(new_data, new_capacity); // deallocate must not throw according to the standard, but even if it does, // the buffer already uses the new storage and will deallocate it in @@ -1565,7 +1554,7 @@ template class basic_writer { } buffer -= s.size(); std::uninitialized_copy(s.data(), s.data() + s.size(), - make_checked(buffer, s.size())); + buffer); }); } }; diff --git a/src/duckdb/third_party/httplib/httplib.hpp b/src/duckdb/third_party/httplib/httplib.hpp index 06e2d9406..7e2b36591 100644 --- a/src/duckdb/third_party/httplib/httplib.hpp +++ b/src/duckdb/third_party/httplib/httplib.hpp @@ -904,9 +904,26 @@ class ThreadPool final : public TaskQueue { public: explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { - while (n) { - threads_.emplace_back(worker(*this)); - n--; + // Exception-safe construction: if std::thread creation fails partway + // (e.g. EAGAIN on pthread_create under thread-resource pressure), the + // partially-built threads_ vector would otherwise destruct with joinable + // std::thread objects, calling std::terminate(). Instead, signal shutdown + // to the workers we already spawned, join them, and rethrow. + try { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } catch (...) 
{ + { + duckdb::unique_lock lock(mutex_); + shutdown_ = true; + } + cond_.notify_all(); + for (auto &t : threads_) { + if (t.joinable()) { t.join(); } + } + throw; } } diff --git a/src/duckdb/extension/jemalloc/include/malloc_ncpus.h b/src/duckdb/third_party/jemalloc/include/duckdb/malloc_ncpus.h similarity index 93% rename from src/duckdb/extension/jemalloc/include/malloc_ncpus.h rename to src/duckdb/third_party/jemalloc/include/duckdb/malloc_ncpus.h index 18b044a44..b7df59819 100644 --- a/src/duckdb/extension/jemalloc/include/malloc_ncpus.h +++ b/src/duckdb/third_party/jemalloc/include/duckdb/malloc_ncpus.h @@ -19,4 +19,4 @@ unsigned duckdb_malloc_ncpus(); } #endif -#endif // MALLOC_NCPUS_H +#endif // MALLOC_NCPUS_H \ No newline at end of file diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/activity_callback.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/activity_callback.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/activity_callback.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/activity_callback.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_externs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_externs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_inlines_a.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_inlines_a.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_inlines_a.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_inlines_a.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_inlines_b.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_inlines_b.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_inlines_b.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_inlines_b.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_stats.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_stats.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_stats.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_stats.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_structs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_structs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_structs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_structs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/arena_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/arena_types.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/assert.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/assert.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/assert.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/assert.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic.h rename to 
src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_c11.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_c11.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_c11.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_c11.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_gcc_atomic.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_gcc_atomic.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_gcc_atomic.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_gcc_atomic.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_gcc_sync.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_gcc_sync.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_gcc_sync.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_gcc_sync.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_msvc.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_msvc.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/atomic_msvc.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/atomic_msvc.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_externs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_externs.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_inlines.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_inlines.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_inlines.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_inlines.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_structs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_structs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/background_thread_structs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/background_thread_structs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/base.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/base.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/base.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/base.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/batcher.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/batcher.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/batcher.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/batcher.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_info.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_info.h 
similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_info.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_info.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_stats.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_stats.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_stats.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_stats.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bin_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bin_types.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bit_util.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bit_util.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bit_util.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bit_util.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bitmap.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/bitmap.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/bitmap.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/bitmap.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/buf_writer.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/buf_writer.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/buf_writer.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/buf_writer.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/cache_bin.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/cache_bin.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/cache_bin.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/cache_bin.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ckh.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ckh.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ckh.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ckh.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/counter.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/counter.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/counter.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/counter.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ctl.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ctl.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ctl.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ctl.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/decay.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/decay.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/decay.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/decay.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/div.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/div.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/div.h rename to 
src/duckdb/third_party/jemalloc/include/jemalloc/internal/div.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ecache.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ecache.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ecache.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ecache.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/edata.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/edata.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/edata.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/edata.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/edata_cache.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/edata_cache.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/edata_cache.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/edata_cache.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ehooks.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ehooks.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ehooks.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ehooks.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/emap.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/emap.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/emap.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/emap.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/emitter.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/emitter.h similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/emitter.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/emitter.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/eset.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/eset.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/eset.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/eset.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/exp_grow.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/exp_grow.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/exp_grow.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/exp_grow.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent_dss.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent_dss.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent_dss.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent_dss.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent_mmap.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent_mmap.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/extent_mmap.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/extent_mmap.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/fb.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/fb.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/fb.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/fb.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/fxp.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/fxp.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/fxp.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/fxp.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hash.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hash.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hash.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hash.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hook.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hook.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hook.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hook.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa_hooks.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa_hooks.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa_hooks.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa_hooks.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa_opts.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa_opts.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpa_opts.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpa_opts.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpdata.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpdata.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/hpdata.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/hpdata.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/inspect.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/inspect.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/inspect.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/inspect.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_decls.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_decls.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_decls.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_decls.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h similarity index 99% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h index 5e3998031..dd04c5516 100644 --- a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h +++ 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_defs.h @@ -205,7 +205,7 @@ /* #undef JEMALLOC_EXPERIMENTAL_SMALLOCX_API */ /* JEMALLOC_PROF enables allocation profiling. */ -#define JEMALLOC_PROF +/* #undef JEMALLOC_PROF */ /* Use libunwind for profile backtracing if defined. */ /* #undef JEMALLOC_PROF_LIBUNWIND */ @@ -214,7 +214,7 @@ /* #undef JEMALLOC_PROF_LIBGCC */ /* Use gcc intrinsics for profile backtracing if defined. */ -#define JEMALLOC_PROF_GCC +/* #undef JEMALLOC_PROF_GCC */ /* JEMALLOC_PAGEID enabled page id */ /* #undef JEMALLOC_PAGEID */ diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_externs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_externs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_includes.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_includes.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_includes.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_includes.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_a.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_a.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_a.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_a.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_b.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_b.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_b.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_b.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_c.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_c.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_c.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_inlines_c.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_macros.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_macros.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_macros.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_macros.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_overrides.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_overrides.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_overrides.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_overrides.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_internal_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_internal_types.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_preamble.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_preamble.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/jemalloc_preamble.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/jemalloc_preamble.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/large_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/large_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/large_externs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/large_externs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/lockedint.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/lockedint.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/lockedint.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/lockedint.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/log.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/log.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/log.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/log.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/malloc_io.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/malloc_io.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/malloc_io.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/malloc_io.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mpsc_queue.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/mpsc_queue.h similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mpsc_queue.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/mpsc_queue.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mutex.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/mutex.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mutex.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/mutex.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mutex_prof.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/mutex_prof.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/mutex_prof.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/mutex_prof.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/nstime.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/nstime.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/nstime.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/nstime.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pa.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/pa.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pa.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/pa.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pac.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/pac.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pac.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/pac.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pages.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/pages.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pages.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/pages.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pai.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/pai.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/pai.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/pai.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/peak.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/peak.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/peak.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/peak.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/peak_event.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/peak_event.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/peak_event.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/peak_event.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ph.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ph.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ph.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ph.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/private_namespace.gen.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/private_namespace.gen.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/private_namespace.gen.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/private_namespace.gen.h 
diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/private_namespace.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/private_namespace.h similarity index 99% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/private_namespace.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/private_namespace.h index efe904660..0d7243e06 100644 --- a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/private_namespace.h +++ b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/private_namespace.h @@ -748,6 +748,7 @@ #define opt_malloc_conf_symlink JEMALLOC_N(opt_malloc_conf_symlink) #define opt_prof_bt_max JEMALLOC_N(opt_prof_bt_max) #define opt_prof_pid_namespace JEMALLOC_N(opt_prof_pid_namespace) +#define prof_get_pid_namespace JEMALLOC_N(prof_get_pid_namespace) #define os_page JEMALLOC_N(os_page) #define pa_shard_nactive JEMALLOC_N(pa_shard_nactive) #define pa_shard_ndirty JEMALLOC_N(pa_shard_ndirty) diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prng.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prng.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prng.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prng.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_data.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_data.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_data.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_data.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_externs.h rename to 
src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_externs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_hook.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_hook.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_hook.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_hook.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_inlines.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_inlines.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_inlines.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_inlines.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_log.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_log.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_log.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_log.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_recent.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_recent.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_recent.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_recent.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_stats.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_stats.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_stats.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_stats.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_structs.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_structs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_structs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_structs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_sys.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_sys.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_sys.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_sys.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/prof_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/prof_types.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/psset.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/psset.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/psset.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/psset.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/public_namespace.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/public_namespace.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/public_namespace.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/public_namespace.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/public_unnamespace.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/public_unnamespace.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/public_unnamespace.h 
rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/public_unnamespace.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ql.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ql.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ql.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ql.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/qr.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/qr.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/qr.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/qr.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/quantum.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/quantum.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/quantum.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/quantum.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rb.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/rb.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rb.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/rb.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rtree.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/rtree.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rtree.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/rtree.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rtree_tsd.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/rtree_tsd.h similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/rtree_tsd.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/rtree_tsd.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/safety_check.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/safety_check.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/safety_check.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/safety_check.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/san.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/san.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/san.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/san.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/san_bump.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/san_bump.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/san_bump.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/san_bump.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sc.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/sc.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sc.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/sc.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sec.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/sec.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sec.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/sec.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sec_opts.h 
b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/sec_opts.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sec_opts.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/sec_opts.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/seq.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/seq.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/seq.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/seq.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/slab_data.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/slab_data.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/slab_data.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/slab_data.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/smoothstep.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/smoothstep.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/smoothstep.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/smoothstep.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/spin.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/spin.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/spin.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/spin.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/stats.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/stats.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/stats.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/stats.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sz.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/sz.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/sz.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/sz.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_externs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_externs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_externs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_externs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_inlines.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_inlines.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_inlines.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_inlines.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_structs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_structs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_structs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_structs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tcache_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tcache_types.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/test_hooks.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/test_hooks.h similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/test_hooks.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/test_hooks.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/thread_event.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/thread_event.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/thread_event.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/thread_event.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ticker.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/ticker.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/ticker.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/ticker.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_generic.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_generic.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_generic.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_generic.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_internals.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_internals.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_internals.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_internals.h diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_malloc_thread_cleanup.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_malloc_thread_cleanup.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_malloc_thread_cleanup.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_malloc_thread_cleanup.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_tls.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_tls.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_tls.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_tls.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_types.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_types.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_types.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_types.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_win.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_win.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/tsd_win.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/tsd_win.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/typed_list.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/typed_list.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/typed_list.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/typed_list.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/util.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/util.h similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/util.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/util.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/witness.h b/src/duckdb/third_party/jemalloc/include/jemalloc/internal/witness.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/internal/witness.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/internal/witness.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_defs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_defs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_defs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_defs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_macros.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_macros.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_macros.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_macros.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_mangle.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_mangle.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_mangle.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_mangle.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_mangle_jet.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_mangle_jet.h similarity 
index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_mangle_jet.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_mangle_jet.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_protos.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_protos.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_protos.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_protos.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_protos_jet.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_protos_jet.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_protos_jet.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_protos_jet.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_rename.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_rename.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_rename.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_rename.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_typedefs.h b/src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_typedefs.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/jemalloc/jemalloc_typedefs.h rename to src/duckdb/third_party/jemalloc/include/jemalloc/jemalloc_typedefs.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/C99/stdbool.h b/src/duckdb/third_party/jemalloc/include/msvc_compat/C99/stdbool.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/C99/stdbool.h rename to src/duckdb/third_party/jemalloc/include/msvc_compat/C99/stdbool.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/C99/stdint.h 
b/src/duckdb/third_party/jemalloc/include/msvc_compat/C99/stdint.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/C99/stdint.h rename to src/duckdb/third_party/jemalloc/include/msvc_compat/C99/stdint.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/strings.h b/src/duckdb/third_party/jemalloc/include/msvc_compat/strings.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/strings.h rename to src/duckdb/third_party/jemalloc/include/msvc_compat/strings.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/windows_extra.h b/src/duckdb/third_party/jemalloc/include/msvc_compat/windows_extra.h similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/include/msvc_compat/windows_extra.h rename to src/duckdb/third_party/jemalloc/include/msvc_compat/windows_extra.h diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/arena.c b/src/duckdb/third_party/jemalloc/src/arena.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/arena.c rename to src/duckdb/third_party/jemalloc/src/arena.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/background_thread.c b/src/duckdb/third_party/jemalloc/src/background_thread.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/background_thread.c rename to src/duckdb/third_party/jemalloc/src/background_thread.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/base.c b/src/duckdb/third_party/jemalloc/src/base.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/base.c rename to src/duckdb/third_party/jemalloc/src/base.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/batcher.c b/src/duckdb/third_party/jemalloc/src/batcher.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/batcher.c rename to src/duckdb/third_party/jemalloc/src/batcher.c diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/src/bin.c b/src/duckdb/third_party/jemalloc/src/bin.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/bin.c rename to src/duckdb/third_party/jemalloc/src/bin.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/bin_info.c b/src/duckdb/third_party/jemalloc/src/bin_info.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/bin_info.c rename to src/duckdb/third_party/jemalloc/src/bin_info.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/bitmap.c b/src/duckdb/third_party/jemalloc/src/bitmap.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/bitmap.c rename to src/duckdb/third_party/jemalloc/src/bitmap.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/buf_writer.c b/src/duckdb/third_party/jemalloc/src/buf_writer.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/buf_writer.c rename to src/duckdb/third_party/jemalloc/src/buf_writer.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/cache_bin.c b/src/duckdb/third_party/jemalloc/src/cache_bin.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/cache_bin.c rename to src/duckdb/third_party/jemalloc/src/cache_bin.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/ckh.c b/src/duckdb/third_party/jemalloc/src/ckh.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/ckh.c rename to src/duckdb/third_party/jemalloc/src/ckh.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/counter.c b/src/duckdb/third_party/jemalloc/src/counter.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/counter.c rename to src/duckdb/third_party/jemalloc/src/counter.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/ctl.c b/src/duckdb/third_party/jemalloc/src/ctl.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/ctl.c rename to 
src/duckdb/third_party/jemalloc/src/ctl.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/decay.c b/src/duckdb/third_party/jemalloc/src/decay.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/decay.c rename to src/duckdb/third_party/jemalloc/src/decay.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/div.c b/src/duckdb/third_party/jemalloc/src/div.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/div.c rename to src/duckdb/third_party/jemalloc/src/div.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/ecache.c b/src/duckdb/third_party/jemalloc/src/ecache.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/ecache.c rename to src/duckdb/third_party/jemalloc/src/ecache.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/edata.c b/src/duckdb/third_party/jemalloc/src/edata.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/edata.c rename to src/duckdb/third_party/jemalloc/src/edata.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/edata_cache.c b/src/duckdb/third_party/jemalloc/src/edata_cache.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/edata_cache.c rename to src/duckdb/third_party/jemalloc/src/edata_cache.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/ehooks.c b/src/duckdb/third_party/jemalloc/src/ehooks.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/ehooks.c rename to src/duckdb/third_party/jemalloc/src/ehooks.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/emap.c b/src/duckdb/third_party/jemalloc/src/emap.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/emap.c rename to src/duckdb/third_party/jemalloc/src/emap.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/eset.c b/src/duckdb/third_party/jemalloc/src/eset.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/eset.c 
rename to src/duckdb/third_party/jemalloc/src/eset.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/exp_grow.c b/src/duckdb/third_party/jemalloc/src/exp_grow.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/exp_grow.c rename to src/duckdb/third_party/jemalloc/src/exp_grow.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/extent.c b/src/duckdb/third_party/jemalloc/src/extent.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/extent.c rename to src/duckdb/third_party/jemalloc/src/extent.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/extent_dss.c b/src/duckdb/third_party/jemalloc/src/extent_dss.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/extent_dss.c rename to src/duckdb/third_party/jemalloc/src/extent_dss.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/extent_mmap.c b/src/duckdb/third_party/jemalloc/src/extent_mmap.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/extent_mmap.c rename to src/duckdb/third_party/jemalloc/src/extent_mmap.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/fxp.c b/src/duckdb/third_party/jemalloc/src/fxp.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/fxp.c rename to src/duckdb/third_party/jemalloc/src/fxp.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/hook.c b/src/duckdb/third_party/jemalloc/src/hook.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/hook.c rename to src/duckdb/third_party/jemalloc/src/hook.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/hpa.c b/src/duckdb/third_party/jemalloc/src/hpa.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/hpa.c rename to src/duckdb/third_party/jemalloc/src/hpa.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/hpa_hooks.c b/src/duckdb/third_party/jemalloc/src/hpa_hooks.c similarity index 100% rename from 
src/duckdb/extension/jemalloc/jemalloc/src/hpa_hooks.c rename to src/duckdb/third_party/jemalloc/src/hpa_hooks.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/hpdata.c b/src/duckdb/third_party/jemalloc/src/hpdata.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/hpdata.c rename to src/duckdb/third_party/jemalloc/src/hpdata.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/inspect.c b/src/duckdb/third_party/jemalloc/src/inspect.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/inspect.c rename to src/duckdb/third_party/jemalloc/src/inspect.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c b/src/duckdb/third_party/jemalloc/src/jemalloc.c similarity index 99% rename from src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c rename to src/duckdb/third_party/jemalloc/src/jemalloc.c index f9591211d..dfb0c160c 100644 --- a/src/duckdb/extension/jemalloc/jemalloc/src/jemalloc.c +++ b/src/duckdb/third_party/jemalloc/src/jemalloc.c @@ -25,7 +25,7 @@ #include "jemalloc/internal/thread_event.h" #include "jemalloc/internal/util.h" -#include "malloc_ncpus.h" +#include "duckdb/malloc_ncpus.h" /******************************************************************************/ /* Data. */ diff --git a/src/duckdb/third_party/jemalloc/src/jemalloc_cpp.cpp b/src/duckdb/third_party/jemalloc/src/jemalloc_cpp.cpp new file mode 100644 index 000000000..aedccfa55 --- /dev/null +++ b/src/duckdb/third_party/jemalloc/src/jemalloc_cpp.cpp @@ -0,0 +1,312 @@ +#include +#include +// NOLINTBEGIN(misc-use-anonymous-namespace) + +#ifndef DUCKDB_OVERRIDE_NEW_DELETE +#error "OVERRIDE_NEW_DELETE not properly defined" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#include "jemalloc/internal/jemalloc_preamble.h" +#include "jemalloc/internal/jemalloc_internal_includes.h" + +#ifdef __cplusplus +} +#endif + +// All operators in this file are exported. 
+ +// Possibly alias hidden versions of malloc and sdallocx to avoid an extra plt +// thunk? +// +// extern __typeof (sdallocx) sdallocx_int +// __attribute ((alias ("sdallocx"), +// visibility ("hidden"))); +// +// ... but it needs to work with jemalloc namespaces. + +void *operator new(std::size_t size); +void *operator new[](std::size_t size); +void *operator new(std::size_t size, const std::nothrow_t &) noexcept; +void *operator new[](std::size_t size, const std::nothrow_t &) noexcept; +void operator delete(void *ptr) noexcept; +void operator delete[](void *ptr) noexcept; +void operator delete(void *ptr, const std::nothrow_t &) noexcept; +void operator delete[](void *ptr, const std::nothrow_t &) noexcept; + +#if __cpp_sized_deallocation >= 201309 +/* C++14's sized-delete operators. */ +void operator delete(void *ptr, std::size_t size) noexcept; +void operator delete[](void *ptr, std::size_t size) noexcept; +#endif + +#if __cpp_aligned_new >= 201606 +/* C++17's over-aligned operators. */ +void *operator new(std::size_t size, std::align_val_t); +void *operator new(std::size_t size, std::align_val_t, const std::nothrow_t &) noexcept; +void *operator new[](std::size_t size, std::align_val_t); +void *operator new[](std::size_t size, std::align_val_t, const std::nothrow_t &) noexcept; +void operator delete(void* ptr, std::align_val_t) noexcept; +void operator delete(void* ptr, std::align_val_t, const std::nothrow_t &) noexcept; +void operator delete(void* ptr, std::size_t size, std::align_val_t al) noexcept; +void operator delete[](void* ptr, std::align_val_t) noexcept; +void operator delete[](void* ptr, std::align_val_t, const std::nothrow_t &) noexcept; +void operator delete[](void* ptr, std::size_t size, std::align_val_t al) noexcept; +#endif + +JEMALLOC_NOINLINE +static void * +handleOOM(std::size_t size, bool nothrow) { + if (opt_experimental_infallible_new) { + const char *huge_warning = (size >= ((std::size_t)1 << 30)) ? 
+ "This may be caused by heap corruption, if the large size " + "is unexpected (suggest building with sanitizers for " + "debugging)." : ""; + + safety_check_fail(": Allocation of size %zu failed. " + "%s opt.experimental_infallible_new is true. Aborting.\n", + size, huge_warning); + return nullptr; + } + + void *ptr = nullptr; + + while (ptr == nullptr) { + std::new_handler handler; + // GCC-4.8 and clang 4.0 do not have std::get_new_handler. + { + static std::mutex mtx; + std::lock_guard lock(mtx); + + handler = std::set_new_handler(nullptr); + std::set_new_handler(handler); + } + if (handler == nullptr) + break; + + try { + handler(); + } catch (const std::bad_alloc &) { + break; + } + + ptr = je_malloc(size); + } + + if (ptr == nullptr && !nothrow) + std::__throw_bad_alloc(); + return ptr; +} + +template +JEMALLOC_NOINLINE +static void * +fallbackNewImpl(std::size_t size) noexcept(IsNoExcept) { + void *ptr = malloc_default(size); + if (likely(ptr != nullptr)) { + return ptr; + } + return handleOOM(size, IsNoExcept); +} + +template +JEMALLOC_ALWAYS_INLINE +void * +newImpl(std::size_t size) noexcept(IsNoExcept) { + LOG("core.operator_new.entry", "size: %zu", size); + + void * ret = imalloc_fastpath(size, &fallbackNewImpl); + + LOG("core.operator_new.exit", "result: %p", ret); + return ret; +} + +void * +operator new(std::size_t size) { + return newImpl(size); +} + +void * +operator new[](std::size_t size) { + return newImpl(size); +} + +void * +operator new(std::size_t size, const std::nothrow_t &) noexcept { + return newImpl(size); +} + +void * +operator new[](std::size_t size, const std::nothrow_t &) noexcept { + return newImpl(size); +} + +#if __cpp_aligned_new >= 201606 + +template +JEMALLOC_ALWAYS_INLINE +void * +alignedNewImpl(std::size_t size, std::align_val_t alignment) noexcept(IsNoExcept) { + void *ptr = je_aligned_alloc(static_cast(alignment), size); + if (likely(ptr != nullptr)) { + return ptr; + } + + return handleOOM(size, IsNoExcept); +} + +void * 
+operator new(std::size_t size, std::align_val_t alignment) { + return alignedNewImpl(size, alignment); +} + +void * +operator new[](std::size_t size, std::align_val_t alignment) { + return alignedNewImpl(size, alignment); +} + +void * +operator new(std::size_t size, std::align_val_t alignment, const std::nothrow_t &) noexcept { + return alignedNewImpl(size, alignment); +} + +void * +operator new[](std::size_t size, std::align_val_t alignment, const std::nothrow_t &) noexcept { + return alignedNewImpl(size, alignment); +} + +#endif // __cpp_aligned_new + +void +operator delete(void *ptr) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete[](void *ptr) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete(void *ptr, const std::nothrow_t &) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void operator delete[](void *ptr, const std::nothrow_t &) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +#if __cpp_sized_deallocation >= 201309 + +JEMALLOC_ALWAYS_INLINE +void +sizedDeleteImpl(void* ptr, std::size_t size) noexcept { + if (unlikely(ptr == nullptr)) { + return; + } + LOG("core.operator_delete.entry", "ptr: %p, size: %zu", ptr, size); + + je_sdallocx_noflags(ptr, size); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete(void *ptr, std::size_t size) noexcept { + sizedDeleteImpl(ptr, size); +} + +void +operator delete[](void *ptr, std::size_t size) noexcept { + sizedDeleteImpl(ptr, size); +} + +#endif // __cpp_sized_deallocation + +#if __cpp_aligned_new >= 201606 + +JEMALLOC_ALWAYS_INLINE +void +alignedSizedDeleteImpl(void* ptr, std::size_t size, std::align_val_t alignment) + 
noexcept { + if (config_debug) { + assert(((size_t)alignment & ((size_t)alignment - 1)) == 0); + } + if (unlikely(ptr == nullptr)) { + return; + } + LOG("core.operator_delete.entry", "ptr: %p, size: %zu, alignment: %zu", + ptr, size, alignment); + + je_sdallocx_impl(ptr, size, MALLOCX_ALIGN(alignment)); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete(void* ptr, std::align_val_t) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete[](void* ptr, std::align_val_t) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete(void* ptr, std::align_val_t, const std::nothrow_t&) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete[](void* ptr, std::align_val_t, const std::nothrow_t&) noexcept { + LOG("core.operator_delete.entry", "ptr: %p", ptr); + + je_free_impl(ptr); + + LOG("core.operator_delete.exit", ""); +} + +void +operator delete(void* ptr, std::size_t size, std::align_val_t alignment) noexcept { + alignedSizedDeleteImpl(ptr, size, alignment); +} + +void +operator delete[](void* ptr, std::size_t size, std::align_val_t alignment) noexcept { + alignedSizedDeleteImpl(ptr, size, alignment); +} + +#endif // __cpp_aligned_new +// NOLINTEND(misc-use-anonymous-namespace) diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/large.c b/src/duckdb/third_party/jemalloc/src/large.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/large.c rename to src/duckdb/third_party/jemalloc/src/large.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/log.c b/src/duckdb/third_party/jemalloc/src/log.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/log.c rename to src/duckdb/third_party/jemalloc/src/log.c 
diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/malloc_io.c b/src/duckdb/third_party/jemalloc/src/malloc_io.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/malloc_io.c rename to src/duckdb/third_party/jemalloc/src/malloc_io.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/mutex.c b/src/duckdb/third_party/jemalloc/src/mutex.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/mutex.c rename to src/duckdb/third_party/jemalloc/src/mutex.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/nstime.c b/src/duckdb/third_party/jemalloc/src/nstime.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/nstime.c rename to src/duckdb/third_party/jemalloc/src/nstime.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/pa.c b/src/duckdb/third_party/jemalloc/src/pa.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/pa.c rename to src/duckdb/third_party/jemalloc/src/pa.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/pa_extra.c b/src/duckdb/third_party/jemalloc/src/pa_extra.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/pa_extra.c rename to src/duckdb/third_party/jemalloc/src/pa_extra.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/pac.c b/src/duckdb/third_party/jemalloc/src/pac.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/pac.c rename to src/duckdb/third_party/jemalloc/src/pac.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/pages.c b/src/duckdb/third_party/jemalloc/src/pages.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/pages.c rename to src/duckdb/third_party/jemalloc/src/pages.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/pai.c b/src/duckdb/third_party/jemalloc/src/pai.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/pai.c rename to src/duckdb/third_party/jemalloc/src/pai.c 
diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/peak_event.c b/src/duckdb/third_party/jemalloc/src/peak_event.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/peak_event.c rename to src/duckdb/third_party/jemalloc/src/peak_event.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof.c b/src/duckdb/third_party/jemalloc/src/prof.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof.c rename to src/duckdb/third_party/jemalloc/src/prof.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof_data.c b/src/duckdb/third_party/jemalloc/src/prof_data.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof_data.c rename to src/duckdb/third_party/jemalloc/src/prof_data.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof_log.c b/src/duckdb/third_party/jemalloc/src/prof_log.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof_log.c rename to src/duckdb/third_party/jemalloc/src/prof_log.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof_recent.c b/src/duckdb/third_party/jemalloc/src/prof_recent.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof_recent.c rename to src/duckdb/third_party/jemalloc/src/prof_recent.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof_stats.c b/src/duckdb/third_party/jemalloc/src/prof_stats.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof_stats.c rename to src/duckdb/third_party/jemalloc/src/prof_stats.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/prof_sys.c b/src/duckdb/third_party/jemalloc/src/prof_sys.c similarity index 96% rename from src/duckdb/extension/jemalloc/jemalloc/src/prof_sys.c rename to src/duckdb/third_party/jemalloc/src/prof_sys.c index 4ae97ccce..91e184fee 100644 --- a/src/duckdb/extension/jemalloc/jemalloc/src/prof_sys.c +++ b/src/duckdb/third_party/jemalloc/src/prof_sys.c @@ -18,7 
+18,26 @@ * already hooked _Unwind_Backtrace. We'll temporarily disable hooking. */ #undef _Unwind_Backtrace +/* + * libunwind's (pulled in by libunwind-dev on the musl extension + * build images) shadows gcc's own on the include path and only + * declares _Unwind_Backtrace under _GNU_SOURCE. Without it the declaration is + * missing, and gcc 14 treats implicit function declarations as errors in + * C99-and-later modes. + * + * Define _GNU_SOURCE just for this include and restore the prior state right + * after, so the change stays confined to and does not alter the + * feature-test posture seen by the rest of this translation unit. + */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#define JE_DUCKDB_TMP_GNU_SOURCE +#endif #include +#ifdef JE_DUCKDB_TMP_GNU_SOURCE +#undef _GNU_SOURCE +#undef JE_DUCKDB_TMP_GNU_SOURCE +#endif #define _Unwind_Backtrace JEMALLOC_TEST_HOOK(_Unwind_Backtrace, test_hooks_libc_hook) #endif diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/psset.c b/src/duckdb/third_party/jemalloc/src/psset.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/psset.c rename to src/duckdb/third_party/jemalloc/src/psset.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/rtree.c b/src/duckdb/third_party/jemalloc/src/rtree.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/rtree.c rename to src/duckdb/third_party/jemalloc/src/rtree.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/safety_check.c b/src/duckdb/third_party/jemalloc/src/safety_check.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/safety_check.c rename to src/duckdb/third_party/jemalloc/src/safety_check.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/san.c b/src/duckdb/third_party/jemalloc/src/san.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/san.c rename to src/duckdb/third_party/jemalloc/src/san.c diff --git 
a/src/duckdb/extension/jemalloc/jemalloc/src/san_bump.c b/src/duckdb/third_party/jemalloc/src/san_bump.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/san_bump.c rename to src/duckdb/third_party/jemalloc/src/san_bump.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/sc.c b/src/duckdb/third_party/jemalloc/src/sc.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/sc.c rename to src/duckdb/third_party/jemalloc/src/sc.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/sec.c b/src/duckdb/third_party/jemalloc/src/sec.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/sec.c rename to src/duckdb/third_party/jemalloc/src/sec.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/stats.c b/src/duckdb/third_party/jemalloc/src/stats.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/stats.c rename to src/duckdb/third_party/jemalloc/src/stats.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/sz.c b/src/duckdb/third_party/jemalloc/src/sz.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/sz.c rename to src/duckdb/third_party/jemalloc/src/sz.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/tcache.c b/src/duckdb/third_party/jemalloc/src/tcache.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/tcache.c rename to src/duckdb/third_party/jemalloc/src/tcache.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/test_hooks.c b/src/duckdb/third_party/jemalloc/src/test_hooks.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/test_hooks.c rename to src/duckdb/third_party/jemalloc/src/test_hooks.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/thread_event.c b/src/duckdb/third_party/jemalloc/src/thread_event.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/thread_event.c rename to 
src/duckdb/third_party/jemalloc/src/thread_event.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/ticker.c b/src/duckdb/third_party/jemalloc/src/ticker.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/ticker.c rename to src/duckdb/third_party/jemalloc/src/ticker.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/tsd.c b/src/duckdb/third_party/jemalloc/src/tsd.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/tsd.c rename to src/duckdb/third_party/jemalloc/src/tsd.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/util.c b/src/duckdb/third_party/jemalloc/src/util.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/util.c rename to src/duckdb/third_party/jemalloc/src/util.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/witness.c b/src/duckdb/third_party/jemalloc/src/witness.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/witness.c rename to src/duckdb/third_party/jemalloc/src/witness.c diff --git a/src/duckdb/extension/jemalloc/jemalloc/src/zone.c b/src/duckdb/third_party/jemalloc/src/zone.c similarity index 100% rename from src/duckdb/extension/jemalloc/jemalloc/src/zone.c rename to src/duckdb/third_party/jemalloc/src/zone.c diff --git a/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp b/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp index 8edab2a1a..be8530b40 100644 --- a/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp +++ b/src/duckdb/third_party/utf8proc/include/utf8proc_wrapper.hpp @@ -36,6 +36,20 @@ class Utf8Proc { //! Returns the position (in bytes) of the previous grapheme cluster static size_t PreviousGraphemeCluster(const char *s, size_t len, size_t pos); + //! Find the next legal-UTF8 string + static bool FindNextLegalUTF8(std::string &str); + + //! Given a string that may end with a truncated (incomplete) UTF-8 sequence, + //! 
compute the smallest valid UTF-8 string that is strictly greater than every + //! string that shares the same byte prefix. Returns false when no finite upper + //! bound exists (e.g. the string is empty or consists entirely of U+10FFFF). + static bool ValidUpperBound(const std::string &str, std::string &result); + + //! Given a string that may contain an invalid or truncated UTF-8 sequence, + //! compute a valid UTF-8 lower bound by replacing the first invalid byte (and + //! everything after it) with a NUL byte. + static bool ValidLowerBound(const std::string &str, std::string &result); + //! Transform a codepoint to utf8 and writes it to "c", sets "sz" to the size of the codepoint static bool CodepointToUtf8(int cp, int &sz, char *c); //! Returns the codepoint length in bytes when encoded in UTF8 diff --git a/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp b/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp index c8ae60a61..9c406949d 100644 --- a/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp +++ b/src/duckdb/third_party/utf8proc/utf8proc_wrapper.cpp @@ -3,7 +3,7 @@ #include "duckdb/common/assert.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/helper.hpp" - +#include "duckdb/function/scalar/string_common.hpp" namespace duckdb { @@ -406,4 +406,98 @@ size_t Utf8Proc::RenderWidth(const std::string &str) { return render_width; } +bool Utf8Proc::ValidUpperBound(const string &str, string &result) { + if (str.empty()) { + return false; + } + auto *data = reinterpret_cast(str.data()); + idx_t pos = str.size(); + + // Find the start of the rightmost (possibly incomplete) sequence. 
+ idx_t start = pos - 1; + while (start > 0 && (data[start] & 0xC0) == 0x80) { + start--; + } + + uint8_t lead = data[start]; + int total_len; + if ((lead & 0x80) == 0) { + total_len = 1; + } else if ((lead & 0xE0) == 0xC0) { + total_len = 2; + } else if ((lead & 0xF0) == 0xE0) { + total_len = 3; + } else if ((lead & 0xF8) == 0xF0) { + total_len = 4; + } else { + // Invalid leading byte; strip it and let FindNextLegalUTF8 backtrack. + result = string(str.data(), start); + return FindNextLegalUTF8(result); + } + + // Fill any missing continuation bytes with 0xBF so the sequence is complete, + // then delegate the increment (and any backtracking) to FindNextLegalUTF8. + int present = static_cast(pos - start); + result = str.substr(0, pos); + for (int i = present; i < total_len; i++) { + result += '\xBF'; + } + return FindNextLegalUTF8(result); +} + +bool Utf8Proc::ValidLowerBound(const string &str, string &result) { + size_t invalid_pos; + if (Analyze(str.c_str(), str.size(), nullptr, &invalid_pos) != UnicodeType::INVALID) { + result = str; + return true; + } + // Everything before invalid_pos is valid UTF-8. A NUL byte is always less + // than any byte a valid continuation or leading byte could produce (>= 0x80). + result = str.substr(0, invalid_pos) + '\0'; + return true; +} + +bool Utf8Proc::FindNextLegalUTF8(string &str) { + while (!str.empty()) { + // Find the start of the last codepoint. + idx_t last_codepoint_start; + for (last_codepoint_start = str.size(); last_codepoint_start > 0; last_codepoint_start--) { + if (IsCharacter(str[last_codepoint_start - 1])) { + break; + } + } + if (last_codepoint_start == 0) { + throw InvalidInputException("Invalid UTF8 found in string \"%s\"", str); + } + last_codepoint_start--; + // Surrogates (U+D800–U+DFFF, encoded as ED A0 80–ED BF BF) cannot be decoded + // with UTF8ToCodepoint, which throws for them. This can happen when ValidUpperBound + // fills a truncated ED-prefixed sequence with 0xBF bytes. 
Jump straight to U+E000. + auto *seq = reinterpret_cast(str.c_str() + last_codepoint_start); + if (seq[0] == 0xED && (seq[1] & 0xE0) == 0xA0) { + char cp_bytes[4]; + int cp_size; + Utf8Proc::CodepointToUtf8(0xE000, cp_size, cp_bytes); + str = str.substr(0, last_codepoint_start) + string(cp_bytes, static_cast(cp_size)); + return true; + } + int codepoint_size; + auto codepoint = Utf8Proc::UTF8ToCodepoint(str.c_str() + last_codepoint_start, codepoint_size) + 1; + if (codepoint >= 0xD800 && codepoint <= 0xDFFF) { + // incremented codepoint falls within surrogate range; skip to next valid character + codepoint = 0xE000; + } + char next_codepoint_text[4]; + int next_codepoint_size; + if (Utf8Proc::CodepointToUtf8(codepoint, next_codepoint_size, next_codepoint_text)) { + auto s = static_cast(next_codepoint_size); + str = str.substr(0, last_codepoint_start) + string(next_codepoint_text, s); + return true; + } + // Last codepoint was U+10FFFF; pop it and try the preceding one. + str = str.substr(0, last_codepoint_start); + } + return false; +} + } // namespace duckdb diff --git a/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp b/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp index 4373c3dbb..1fdfd5cdb 100644 --- a/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp +++ b/src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp @@ -1,8 +1,8 @@ -#include "extension/core_functions/aggregate/algebraic/covar.cpp" +#include "extension/core_functions/aggregate/algebraic/avg.cpp" -#include "extension/core_functions/aggregate/algebraic/corr.cpp" +#include "extension/core_functions/aggregate/algebraic/covar.cpp" #include "extension/core_functions/aggregate/algebraic/stddev.cpp" -#include "extension/core_functions/aggregate/algebraic/avg.cpp" +#include "extension/core_functions/aggregate/algebraic/corr.cpp" diff --git a/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp 
b/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp index e9775f68a..ffd7c8dad 100644 --- a/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp +++ b/src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp @@ -1,20 +1,20 @@ -#include "extension/core_functions/aggregate/distributive/string_agg.cpp" - -#include "extension/core_functions/aggregate/distributive/product.cpp" - #include "extension/core_functions/aggregate/distributive/approx_count.cpp" -#include "extension/core_functions/aggregate/distributive/bool.cpp" - #include "extension/core_functions/aggregate/distributive/arg_min_max.cpp" -#include "extension/core_functions/aggregate/distributive/sum.cpp" +#include "extension/core_functions/aggregate/distributive/product.cpp" #include "extension/core_functions/aggregate/distributive/skew.cpp" +#include "extension/core_functions/aggregate/distributive/kurtosis.cpp" + +#include "extension/core_functions/aggregate/distributive/sum.cpp" + #include "extension/core_functions/aggregate/distributive/bitagg.cpp" #include "extension/core_functions/aggregate/distributive/bitstring_agg.cpp" -#include "extension/core_functions/aggregate/distributive/kurtosis.cpp" +#include "extension/core_functions/aggregate/distributive/string_agg.cpp" + +#include "extension/core_functions/aggregate/distributive/bool.cpp" diff --git a/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp b/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp index e1a562545..4f892afa3 100644 --- a/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp +++ b/src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp @@ -1,12 +1,12 @@ -#include "extension/core_functions/aggregate/holistic/mode.cpp" - -#include "extension/core_functions/aggregate/holistic/mad.cpp" +#include "extension/core_functions/aggregate/holistic/quantile.cpp" -#include "extension/core_functions/aggregate/holistic/approximate_quantile.cpp" +#include 
"extension/core_functions/aggregate/holistic/mode.cpp" -#include "extension/core_functions/aggregate/holistic/quantile.cpp" +#include "extension/core_functions/aggregate/holistic/reservoir_quantile.cpp" #include "extension/core_functions/aggregate/holistic/approx_top_k.cpp" -#include "extension/core_functions/aggregate/holistic/reservoir_quantile.cpp" +#include "extension/core_functions/aggregate/holistic/mad.cpp" + +#include "extension/core_functions/aggregate/holistic/approximate_quantile.cpp" diff --git a/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp b/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp index ebe309de1..071bc16d4 100644 --- a/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp +++ b/src/duckdb/ub_extension_core_functions_aggregate_regression.cpp @@ -1,7 +1,3 @@ -#include "extension/core_functions/aggregate/regression/regr_slope.cpp" - -#include "extension/core_functions/aggregate/regression/regr_count.cpp" - #include "extension/core_functions/aggregate/regression/regr_r2.cpp" #include "extension/core_functions/aggregate/regression/regr_avg.cpp" @@ -10,5 +6,9 @@ #include "extension/core_functions/aggregate/regression/regr_sxx_syy.cpp" +#include "extension/core_functions/aggregate/regression/regr_slope.cpp" + +#include "extension/core_functions/aggregate/regression/regr_count.cpp" + #include "extension/core_functions/aggregate/regression/regr_intercept.cpp" diff --git a/src/duckdb/ub_extension_core_functions_scalar_array.cpp b/src/duckdb/ub_extension_core_functions_scalar_array.cpp index e4f63a369..9b9a475b4 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_array.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_array.cpp @@ -1,4 +1,4 @@ -#include "extension/core_functions/scalar/array/array_functions.cpp" - #include "extension/core_functions/scalar/array/array_value.cpp" +#include "extension/core_functions/scalar/array/array_functions.cpp" + diff --git 
a/src/duckdb/ub_extension_core_functions_scalar_date.cpp b/src/duckdb/ub_extension_core_functions_scalar_date.cpp index 735def33b..869059d9b 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_date.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_date.cpp @@ -1,20 +1,20 @@ -#include "extension/core_functions/scalar/date/age.cpp" - #include "extension/core_functions/scalar/date/time_bucket.cpp" -#include "extension/core_functions/scalar/date/make_date.cpp" +#include "extension/core_functions/scalar/date/current.cpp" -#include "extension/core_functions/scalar/date/epoch.cpp" +#include "extension/core_functions/scalar/date/date_sub.cpp" -#include "extension/core_functions/scalar/date/date_part.cpp" +#include "extension/core_functions/scalar/date/make_date.cpp" -#include "extension/core_functions/scalar/date/date_trunc.cpp" +#include "extension/core_functions/scalar/date/age.cpp" -#include "extension/core_functions/scalar/date/date_sub.cpp" +#include "extension/core_functions/scalar/date/date_part.cpp" #include "extension/core_functions/scalar/date/to_interval.cpp" -#include "extension/core_functions/scalar/date/current.cpp" +#include "extension/core_functions/scalar/date/epoch.cpp" + +#include "extension/core_functions/scalar/date/date_trunc.cpp" #include "extension/core_functions/scalar/date/date_diff.cpp" diff --git a/src/duckdb/ub_extension_core_functions_scalar_debug.cpp b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp index 578281472..4ba1ce617 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_debug.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_debug.cpp @@ -1,6 +1,6 @@ #include "extension/core_functions/scalar/debug/sleep.cpp" -#include "extension/core_functions/scalar/debug/index_key.cpp" - #include "extension/core_functions/scalar/debug/vector_type.cpp" +#include "extension/core_functions/scalar/debug/index_key.cpp" + diff --git a/src/duckdb/ub_extension_core_functions_scalar_generic.cpp 
b/src/duckdb/ub_extension_core_functions_scalar_generic.cpp index 29c4ccb1d..0242be3c2 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_generic.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_generic.cpp @@ -1,22 +1,22 @@ -#include "extension/core_functions/scalar/generic/current_setting.cpp" +#include "extension/core_functions/scalar/generic/binning.cpp" -#include "extension/core_functions/scalar/generic/least.cpp" +#include "extension/core_functions/scalar/generic/system_functions.cpp" #include "extension/core_functions/scalar/generic/stats.cpp" -#include "extension/core_functions/scalar/generic/binning.cpp" +#include "extension/core_functions/scalar/generic/current_setting.cpp" -#include "extension/core_functions/scalar/generic/system_functions.cpp" +#include "extension/core_functions/scalar/generic/least.cpp" #include "extension/core_functions/scalar/generic/type_functions.cpp" +#include "extension/core_functions/scalar/generic/replace_type.cpp" + +#include "extension/core_functions/scalar/generic/can_implicitly_cast.cpp" + #include "extension/core_functions/scalar/generic/cast_to_type.cpp" #include "extension/core_functions/scalar/generic/alias.cpp" -#include "extension/core_functions/scalar/generic/replace_type.cpp" - #include "extension/core_functions/scalar/generic/hash.cpp" -#include "extension/core_functions/scalar/generic/can_implicitly_cast.cpp" - diff --git a/src/duckdb/ub_extension_core_functions_scalar_list.cpp b/src/duckdb/ub_extension_core_functions_scalar_list.cpp index a120190d0..a0909eac1 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_list.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_list.cpp @@ -1,22 +1,22 @@ -#include "extension/core_functions/scalar/list/list_distance.cpp" - -#include "extension/core_functions/scalar/list/array_slice.cpp" - -#include "extension/core_functions/scalar/list/list_transform.cpp" +#include "extension/core_functions/scalar/list/list_reduce.cpp" -#include 
"extension/core_functions/scalar/list/list_sort.cpp" +#include "extension/core_functions/scalar/list/flatten.cpp" -#include "extension/core_functions/scalar/list/list_has_any_or_all.cpp" +#include "extension/core_functions/scalar/list/list_filter.cpp" #include "extension/core_functions/scalar/list/list_value.cpp" -#include "extension/core_functions/scalar/list/flatten.cpp" +#include "extension/core_functions/scalar/list/list_sort.cpp" -#include "extension/core_functions/scalar/list/list_reduce.cpp" +#include "extension/core_functions/scalar/list/range.cpp" -#include "extension/core_functions/scalar/list/list_filter.cpp" +#include "extension/core_functions/scalar/list/list_transform.cpp" -#include "extension/core_functions/scalar/list/range.cpp" +#include "extension/core_functions/scalar/list/list_distance.cpp" + +#include "extension/core_functions/scalar/list/list_has_any_or_all.cpp" #include "extension/core_functions/scalar/list/list_aggregates.cpp" +#include "extension/core_functions/scalar/list/array_slice.cpp" + diff --git a/src/duckdb/ub_extension_core_functions_scalar_map.cpp b/src/duckdb/ub_extension_core_functions_scalar_map.cpp index a180cb10f..78538a430 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_map.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_map.cpp @@ -1,16 +1,16 @@ #include "extension/core_functions/scalar/map/switch.cpp" -#include "extension/core_functions/scalar/map/map_extract.cpp" - -#include "extension/core_functions/scalar/map/cardinality.cpp" - -#include "extension/core_functions/scalar/map/map_from_entries.cpp" - #include "extension/core_functions/scalar/map/map_entries.cpp" #include "extension/core_functions/scalar/map/map.cpp" #include "extension/core_functions/scalar/map/map_keys_values.cpp" +#include "extension/core_functions/scalar/map/cardinality.cpp" + #include "extension/core_functions/scalar/map/map_concat.cpp" +#include "extension/core_functions/scalar/map/map_extract.cpp" + +#include 
"extension/core_functions/scalar/map/map_from_entries.cpp" + diff --git a/src/duckdb/ub_extension_core_functions_scalar_string.cpp b/src/duckdb/ub_extension_core_functions_scalar_string.cpp index 9bba7dbc6..9da959101 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_string.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_string.cpp @@ -1,50 +1,50 @@ -#include "extension/core_functions/scalar/string/parse_formatted_bytes.cpp" - -#include "extension/core_functions/scalar/string/parse_path.cpp" +#include "extension/core_functions/scalar/string/translate.cpp" -#include "extension/core_functions/scalar/string/format_bytes.cpp" +#include "extension/core_functions/scalar/string/jaro_winkler.cpp" -#include "extension/core_functions/scalar/string/chr.cpp" +#include "extension/core_functions/scalar/string/parse_formatted_bytes.cpp" -#include "extension/core_functions/scalar/string/left_right.cpp" +#include "extension/core_functions/scalar/string/hamming.cpp" -#include "extension/core_functions/scalar/string/jaro_winkler.cpp" +#include "extension/core_functions/scalar/string/unicode.cpp" -#include "extension/core_functions/scalar/string/starts_with.cpp" +#include "extension/core_functions/scalar/string/chr.cpp" -#include "extension/core_functions/scalar/string/replace.cpp" +#include "extension/core_functions/scalar/string/format_bytes.cpp" #include "extension/core_functions/scalar/string/reverse.cpp" -#include "extension/core_functions/scalar/string/levenshtein.cpp" +#include "extension/core_functions/scalar/string/hex.cpp" #include "extension/core_functions/scalar/string/instr.cpp" -#include "extension/core_functions/scalar/string/pad.cpp" +#include "extension/core_functions/scalar/string/parse_path.cpp" + +#include "extension/core_functions/scalar/string/starts_with.cpp" + +#include "extension/core_functions/scalar/string/to_base.cpp" #include "extension/core_functions/scalar/string/ascii.cpp" -#include "extension/core_functions/scalar/string/printf.cpp" 
+#include "extension/core_functions/scalar/string/repeat.cpp" -#include "extension/core_functions/scalar/string/bar.cpp" +#include "extension/core_functions/scalar/string/url_encode.cpp" -#include "extension/core_functions/scalar/string/hamming.cpp" +#include "extension/core_functions/scalar/string/jaccard.cpp" -#include "extension/core_functions/scalar/string/translate.cpp" +#include "extension/core_functions/scalar/string/levenshtein.cpp" #include "extension/core_functions/scalar/string/damerau_levenshtein.cpp" -#include "extension/core_functions/scalar/string/hex.cpp" - -#include "extension/core_functions/scalar/string/url_encode.cpp" +#include "extension/core_functions/scalar/string/replace.cpp" -#include "extension/core_functions/scalar/string/trim.cpp" +#include "extension/core_functions/scalar/string/printf.cpp" -#include "extension/core_functions/scalar/string/to_base.cpp" +#include "extension/core_functions/scalar/string/left_right.cpp" -#include "extension/core_functions/scalar/string/repeat.cpp" +#include "extension/core_functions/scalar/string/pad.cpp" -#include "extension/core_functions/scalar/string/jaccard.cpp" +#include "extension/core_functions/scalar/string/trim.cpp" -#include "extension/core_functions/scalar/string/unicode.cpp" +#include "extension/core_functions/scalar/string/bar.cpp" diff --git a/src/duckdb/ub_extension_core_functions_scalar_struct.cpp b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp index 15e35c7a4..7cc473ffb 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_struct.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_struct.cpp @@ -1,8 +1,8 @@ +#include "extension/core_functions/scalar/struct/struct_insert.cpp" + #include "extension/core_functions/scalar/struct/struct_values.cpp" #include "extension/core_functions/scalar/struct/struct_update.cpp" -#include "extension/core_functions/scalar/struct/struct_insert.cpp" - #include "extension/core_functions/scalar/struct/struct_keys.cpp" diff --git 
a/src/duckdb/ub_extension_core_functions_scalar_union.cpp b/src/duckdb/ub_extension_core_functions_scalar_union.cpp index 5b68509a9..fad24f297 100644 --- a/src/duckdb/ub_extension_core_functions_scalar_union.cpp +++ b/src/duckdb/ub_extension_core_functions_scalar_union.cpp @@ -1,6 +1,6 @@ -#include "extension/core_functions/scalar/union/union_tag.cpp" - #include "extension/core_functions/scalar/union/union_extract.cpp" #include "extension/core_functions/scalar/union/union_value.cpp" +#include "extension/core_functions/scalar/union/union_tag.cpp" + diff --git a/src/duckdb/ub_extension_icu_third_party_icu_common.cpp b/src/duckdb/ub_extension_icu_third_party_icu_common.cpp index eb8a9dcbf..4a230511f 100644 --- a/src/duckdb/ub_extension_icu_third_party_icu_common.cpp +++ b/src/duckdb/ub_extension_icu_third_party_icu_common.cpp @@ -1,222 +1,222 @@ -#include "extension/icu/third_party/icu/common/resbund.cpp" +#include "extension/icu/third_party/icu/common/uniset_props.cpp" -#include "extension/icu/third_party/icu/common/util.cpp" +#include "extension/icu/third_party/icu/common/propname.cpp" -#include "extension/icu/third_party/icu/common/emojiprops.cpp" +#include "extension/icu/third_party/icu/common/unorm.cpp" #include "extension/icu/third_party/icu/common/normalizer2.cpp" -#include "extension/icu/third_party/icu/common/ucol_swp.cpp" +#include "extension/icu/third_party/icu/common/dtintrv.cpp" -#include "extension/icu/third_party/icu/common/icudataver.cpp" +#include "extension/icu/third_party/icu/common/loclikely.cpp" -#include "extension/icu/third_party/icu/common/charstr.cpp" +#include "extension/icu/third_party/icu/common/cstr.cpp" -#include "extension/icu/third_party/icu/common/ustrcase_locale.cpp" +#include "extension/icu/third_party/icu/common/stringpiece.cpp" -#include "extension/icu/third_party/icu/common/utrace.cpp" +#include "extension/icu/third_party/icu/common/ustring.cpp" -#include "extension/icu/third_party/icu/common/ucmndata.cpp" +#include 
"extension/icu/third_party/icu/common/cmemory.cpp" -#include "extension/icu/third_party/icu/common/uiter.cpp" +#include "extension/icu/third_party/icu/common/uvectr32.cpp" -#include "extension/icu/third_party/icu/common/ucat.cpp" +#include "extension/icu/third_party/icu/common/ucurr.cpp" -#include "extension/icu/third_party/icu/common/uniset_props.cpp" +#include "extension/icu/third_party/icu/common/bytestream.cpp" -#include "extension/icu/third_party/icu/common/utext.cpp" +#include "extension/icu/third_party/icu/common/appendable.cpp" -#include "extension/icu/third_party/icu/common/schriter.cpp" +#include "extension/icu/third_party/icu/common/locavailable.cpp" -#include "extension/icu/third_party/icu/common/caniter.cpp" +#include "extension/icu/third_party/icu/common/utypes.cpp" -#include "extension/icu/third_party/icu/common/udata.cpp" +#include "extension/icu/third_party/icu/common/utrie_swap.cpp" -#include "extension/icu/third_party/icu/common/unistr_case_locale.cpp" +#include "extension/icu/third_party/icu/common/schriter.cpp" -#include "extension/icu/third_party/icu/common/errorcode.cpp" +#include "extension/icu/third_party/icu/common/utext.cpp" -#include "extension/icu/third_party/icu/common/unifunct.cpp" +#include "extension/icu/third_party/icu/common/unistr.cpp" -#include "extension/icu/third_party/icu/common/bytestream.cpp" +#include "extension/icu/third_party/icu/common/uloc_keytype.cpp" -#include "extension/icu/third_party/icu/common/uset_props.cpp" +#include "extension/icu/third_party/icu/common/ustrtrns.cpp" -#include "extension/icu/third_party/icu/common/locresdata.cpp" +#include "extension/icu/third_party/icu/common/uniset.cpp" -#include "extension/icu/third_party/icu/common/utrie_swap.cpp" +#include "extension/icu/third_party/icu/common/uobject.cpp" #include "extension/icu/third_party/icu/common/umutex.cpp" -#include "extension/icu/third_party/icu/common/dtintrv.cpp" - -#include "extension/icu/third_party/icu/common/bytestrieiterator.cpp" - 
-#include "extension/icu/third_party/icu/common/characterproperties.cpp" +#include "extension/icu/third_party/icu/common/unistr_case_locale.cpp" -#include "extension/icu/third_party/icu/common/messagepattern.cpp" +#include "extension/icu/third_party/icu/common/ulist.cpp" -#include "extension/icu/third_party/icu/common/stringpiece.cpp" +#include "extension/icu/third_party/icu/common/uarrsort.cpp" -#include "extension/icu/third_party/icu/common/unistr_case.cpp" +#include "extension/icu/third_party/icu/common/ustack.cpp" -#include "extension/icu/third_party/icu/common/util_props.cpp" +#include "extension/icu/third_party/icu/common/characterproperties.cpp" -#include "extension/icu/third_party/icu/common/locavailable.cpp" +#include "extension/icu/third_party/icu/common/ubidiln.cpp" -#include "extension/icu/third_party/icu/common/ustring.cpp" +#include "extension/icu/third_party/icu/common/udataswp.cpp" -#include "extension/icu/third_party/icu/common/ucharstrieiterator.cpp" +#include "extension/icu/third_party/icu/common/locdispnames.cpp" -#include "extension/icu/third_party/icu/common/uvectr64.cpp" +#include "extension/icu/third_party/icu/common/unisetspan.cpp" -#include "extension/icu/third_party/icu/common/usetiter.cpp" +#include "extension/icu/third_party/icu/common/uinvchar.cpp" -#include "extension/icu/third_party/icu/common/ucharstrie.cpp" +#include "extension/icu/third_party/icu/common/ucln_cmn.cpp" -#include "extension/icu/third_party/icu/common/loclikely.cpp" +#include "extension/icu/third_party/icu/common/uinit.cpp" -#include "extension/icu/third_party/icu/common/ulist.cpp" +#include "extension/icu/third_party/icu/common/localeprioritylist.cpp" -#include "extension/icu/third_party/icu/common/utf_impl.cpp" +#include "extension/icu/third_party/icu/common/unifilt.cpp" -#include "extension/icu/third_party/icu/common/propname.cpp" +#include "extension/icu/third_party/icu/common/ubidiwrt.cpp" #include "extension/icu/third_party/icu/common/fixedstring.cpp" -#include 
"extension/icu/third_party/icu/common/sharedobject.cpp" +#include "extension/icu/third_party/icu/common/umath.cpp" -#include "extension/icu/third_party/icu/common/ustack.cpp" +#include "extension/icu/third_party/icu/common/uenum.cpp" -#include "extension/icu/third_party/icu/common/appendable.cpp" +#include "extension/icu/third_party/icu/common/uset_props.cpp" -#include "extension/icu/third_party/icu/common/normlzr.cpp" +#include "extension/icu/third_party/icu/common/unifunct.cpp" -#include "extension/icu/third_party/icu/common/bytestrie.cpp" +#include "extension/icu/third_party/icu/common/errorcode.cpp" -#include "extension/icu/third_party/icu/common/umath.cpp" +#include "extension/icu/third_party/icu/common/uloc_tag.cpp" -#include "extension/icu/third_party/icu/common/cmemory.cpp" +#include "extension/icu/third_party/icu/common/ustrenum.cpp" + +#include "extension/icu/third_party/icu/common/parsepos.cpp" #include "extension/icu/third_party/icu/common/uset.cpp" -#include "extension/icu/third_party/icu/common/filterednormalizer2.cpp" +#include "extension/icu/third_party/icu/common/cstring.cpp" -#include "extension/icu/third_party/icu/common/uloc_keytype.cpp" +#include "extension/icu/third_party/icu/common/ucasemap.cpp" -#include "extension/icu/third_party/icu/common/bmpset.cpp" +#include "extension/icu/third_party/icu/common/uiter.cpp" -#include "extension/icu/third_party/icu/common/uinit.cpp" +#include "extension/icu/third_party/icu/common/lsr.cpp" -#include "extension/icu/third_party/icu/common/uniset_closure.cpp" +#include "extension/icu/third_party/icu/common/edits.cpp" -#include "extension/icu/third_party/icu/common/simpleformatter.cpp" +#include "extension/icu/third_party/icu/common/emojiprops.cpp" -#include "extension/icu/third_party/icu/common/edits.cpp" +#include "extension/icu/third_party/icu/common/stringtriebuilder.cpp" -#include "extension/icu/third_party/icu/common/locdspnm.cpp" +#include "extension/icu/third_party/icu/common/uvectr64.cpp" -#include 
"extension/icu/third_party/icu/common/ushape.cpp" +#include "extension/icu/third_party/icu/common/resbund.cpp" -#include "extension/icu/third_party/icu/common/stringtriebuilder.cpp" +#include "extension/icu/third_party/icu/common/ucharstrie.cpp" -#include "extension/icu/third_party/icu/common/ustrtrns.cpp" +#include "extension/icu/third_party/icu/common/filterednormalizer2.cpp" -#include "extension/icu/third_party/icu/common/usc_impl.cpp" +#include "extension/icu/third_party/icu/common/uchar.cpp" -#include "extension/icu/third_party/icu/common/bytesinkutil.cpp" +#include "extension/icu/third_party/icu/common/chariter.cpp" -#include "extension/icu/third_party/icu/common/ubidiln.cpp" +#include "extension/icu/third_party/icu/common/locid.cpp" -#include "extension/icu/third_party/icu/common/uarrsort.cpp" +#include "extension/icu/third_party/icu/common/ustrcase_locale.cpp" -#include "extension/icu/third_party/icu/common/localematcher.cpp" +#include "extension/icu/third_party/icu/common/unistr_case.cpp" -#include "extension/icu/third_party/icu/common/locdispnames.cpp" +#include "extension/icu/third_party/icu/common/sharedobject.cpp" -#include "extension/icu/third_party/icu/common/propsvec.cpp" +#include "extension/icu/third_party/icu/common/simpleformatter.cpp" -#include "extension/icu/third_party/icu/common/lsr.cpp" +#include "extension/icu/third_party/icu/common/util.cpp" #include "extension/icu/third_party/icu/common/normalizer2impl.cpp" -#include "extension/icu/third_party/icu/common/unifilt.cpp" +#include "extension/icu/third_party/icu/common/uprops.cpp" -#include "extension/icu/third_party/icu/common/unisetspan.cpp" +#include "extension/icu/third_party/icu/common/patternprops.cpp" -#include "extension/icu/third_party/icu/common/uniset.cpp" +#include "extension/icu/third_party/icu/common/locdspnm.cpp" -#include "extension/icu/third_party/icu/common/uchar.cpp" +#include "extension/icu/third_party/icu/common/ucharstrieiterator.cpp" -#include 
"extension/icu/third_party/icu/common/ucln_cmn.cpp" +#include "extension/icu/third_party/icu/common/messagepattern.cpp" -#include "extension/icu/third_party/icu/common/chariter.cpp" +#include "extension/icu/third_party/icu/common/bytestrie.cpp" -#include "extension/icu/third_party/icu/common/uobject.cpp" +#include "extension/icu/third_party/icu/common/uscript.cpp" -#include "extension/icu/third_party/icu/common/ruleiter.cpp" +#include "extension/icu/third_party/icu/common/utf_impl.cpp" -#include "extension/icu/third_party/icu/common/pluralmap.cpp" +#include "extension/icu/third_party/icu/common/uniset_closure.cpp" -#include "extension/icu/third_party/icu/common/ucurr.cpp" +#include "extension/icu/third_party/icu/common/bmpset.cpp" -#include "extension/icu/third_party/icu/common/ubidiwrt.cpp" +#include "extension/icu/third_party/icu/common/pluralmap.cpp" -#include "extension/icu/third_party/icu/common/udataswp.cpp" +#include "extension/icu/third_party/icu/common/ucat.cpp" -#include "extension/icu/third_party/icu/common/unistr.cpp" +#include "extension/icu/third_party/icu/common/udata.cpp" -#include "extension/icu/third_party/icu/common/patternprops.cpp" +#include "extension/icu/third_party/icu/common/normlzr.cpp" -#include "extension/icu/third_party/icu/common/uloc_tag.cpp" +#include "extension/icu/third_party/icu/common/ustr_wcs.cpp" -#include "extension/icu/third_party/icu/common/cstr.cpp" +#include "extension/icu/third_party/icu/common/locbased.cpp" -#include "extension/icu/third_party/icu/common/uscript.cpp" +#include "extension/icu/third_party/icu/common/usetiter.cpp" -#include "extension/icu/third_party/icu/common/udatamem.cpp" +#include "extension/icu/third_party/icu/common/localematcher.cpp" -#include "extension/icu/third_party/icu/common/uinvchar.cpp" +#include "extension/icu/third_party/icu/common/charstr.cpp" -#include "extension/icu/third_party/icu/common/unorm.cpp" +#include "extension/icu/third_party/icu/common/bytesinkutil.cpp" #include 
"extension/icu/third_party/icu/common/resbund_cnv.cpp" -#include "extension/icu/third_party/icu/common/unistr_props.cpp" +#include "extension/icu/third_party/icu/common/bytestrieiterator.cpp" -#include "extension/icu/third_party/icu/common/uvectr32.cpp" +#include "extension/icu/third_party/icu/common/icudataver.cpp" -#include "extension/icu/third_party/icu/common/ucasemap.cpp" +#include "extension/icu/third_party/icu/common/resource.cpp" -#include "extension/icu/third_party/icu/common/uenum.cpp" +#include "extension/icu/third_party/icu/common/ucol_swp.cpp" -#include "extension/icu/third_party/icu/common/ustrenum.cpp" +#include "extension/icu/third_party/icu/common/ruleiter.cpp" -#include "extension/icu/third_party/icu/common/locid.cpp" +#include "extension/icu/third_party/icu/common/udatamem.cpp" -#include "extension/icu/third_party/icu/common/ustr_wcs.cpp" +#include "extension/icu/third_party/icu/common/ures_cnv.cpp" -#include "extension/icu/third_party/icu/common/parsepos.cpp" +#include "extension/icu/third_party/icu/common/ushape.cpp" -#include "extension/icu/third_party/icu/common/resource.cpp" +#include "extension/icu/third_party/icu/common/propsvec.cpp" -#include "extension/icu/third_party/icu/common/ures_cnv.cpp" +#include "extension/icu/third_party/icu/common/usc_impl.cpp" -#include "extension/icu/third_party/icu/common/localeprioritylist.cpp" +#include "extension/icu/third_party/icu/common/util_props.cpp" -#include "extension/icu/third_party/icu/common/uprops.cpp" +#include "extension/icu/third_party/icu/common/uhash_us.cpp" -#include "extension/icu/third_party/icu/common/uchriter.cpp" +#include "extension/icu/third_party/icu/common/unistr_props.cpp" -#include "extension/icu/third_party/icu/common/ubidi.cpp" +#include "extension/icu/third_party/icu/common/ucmndata.cpp" #include "extension/icu/third_party/icu/common/ustrfmt.cpp" -#include "extension/icu/third_party/icu/common/locbased.cpp" +#include "extension/icu/third_party/icu/common/uchriter.cpp" 
-#include "extension/icu/third_party/icu/common/cstring.cpp" +#include "extension/icu/third_party/icu/common/caniter.cpp" -#include "extension/icu/third_party/icu/common/uhash_us.cpp" +#include "extension/icu/third_party/icu/common/locresdata.cpp" -#include "extension/icu/third_party/icu/common/utypes.cpp" +#include "extension/icu/third_party/icu/common/utrace.cpp" + +#include "extension/icu/third_party/icu/common/ubidi.cpp" diff --git a/src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp b/src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp index de95582d9..8171e1d4a 100644 --- a/src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp +++ b/src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp @@ -1,182 +1,182 @@ -#include "extension/icu/third_party/icu/i18n/selfmt.cpp" +#include "extension/icu/third_party/icu/i18n/tztrans.cpp" -#include "extension/icu/third_party/icu/i18n/coll.cpp" +#include "extension/icu/third_party/icu/i18n/collationruleparser.cpp" -#include "extension/icu/third_party/icu/i18n/umsg.cpp" +#include "extension/icu/third_party/icu/i18n/basictz.cpp" -#include "extension/icu/third_party/icu/i18n/region.cpp" +#include "extension/icu/third_party/icu/i18n/utmscale.cpp" -#include "extension/icu/third_party/icu/i18n/ufieldpositer.cpp" +#include "extension/icu/third_party/icu/i18n/tzrule.cpp" -#include "extension/icu/third_party/icu/i18n/collation.cpp" +#include "extension/icu/third_party/icu/i18n/listformatter.cpp" -#include "extension/icu/third_party/icu/i18n/dangical.cpp" +#include "extension/icu/third_party/icu/i18n/coleitr.cpp" -#include "extension/icu/third_party/icu/i18n/collationcompare.cpp" +#include "extension/icu/third_party/icu/i18n/collationdata.cpp" -#include "extension/icu/third_party/icu/i18n/udat.cpp" +#include "extension/icu/third_party/icu/i18n/ucoleitr.cpp" -#include "extension/icu/third_party/icu/i18n/uitercollationiterator.cpp" +#include "extension/icu/third_party/icu/i18n/sortkey.cpp" -#include 
"extension/icu/third_party/icu/i18n/number_notation.cpp" +#include "extension/icu/third_party/icu/i18n/dangical.cpp" #include "extension/icu/third_party/icu/i18n/uregion.cpp" -#include "extension/icu/third_party/icu/i18n/format.cpp" +#include "extension/icu/third_party/icu/i18n/udateintervalformat.cpp" -#include "extension/icu/third_party/icu/i18n/tmutamt.cpp" +#include "extension/icu/third_party/icu/i18n/ucln_in.cpp" -#include "extension/icu/third_party/icu/i18n/collationiterator.cpp" +#include "extension/icu/third_party/icu/i18n/formattedval_iterimpl.cpp" -#include "extension/icu/third_party/icu/i18n/unumsys.cpp" +#include "extension/icu/third_party/icu/i18n/collationsettings.cpp" -#include "extension/icu/third_party/icu/i18n/dcfmtsym.cpp" +#include "extension/icu/third_party/icu/i18n/rbtz.cpp" -#include "extension/icu/third_party/icu/i18n/ucal.cpp" +#include "extension/icu/third_party/icu/i18n/collationtailoring.cpp" -#include "extension/icu/third_party/icu/i18n/astro.cpp" +#include "extension/icu/third_party/icu/i18n/tznames_impl.cpp" -#include "extension/icu/third_party/icu/i18n/udateintervalformat.cpp" +#include "extension/icu/third_party/icu/i18n/currunit.cpp" -#include "extension/icu/third_party/icu/i18n/unum.cpp" +#include "extension/icu/third_party/icu/i18n/region.cpp" -#include "extension/icu/third_party/icu/i18n/collationfastlatin.cpp" +#include "extension/icu/third_party/icu/i18n/ulocdata.cpp" -#include "extension/icu/third_party/icu/i18n/erarules.cpp" +#include "extension/icu/third_party/icu/i18n/format.cpp" -#include "extension/icu/third_party/icu/i18n/smpdtfst.cpp" +#include "extension/icu/third_party/icu/i18n/collationdatawriter.cpp" -#include "extension/icu/third_party/icu/i18n/collationweights.cpp" +#include "extension/icu/third_party/icu/i18n/collationkeys.cpp" -#include "extension/icu/third_party/icu/i18n/number_padding.cpp" +#include "extension/icu/third_party/icu/i18n/formattedvalue.cpp" + +#include 
"extension/icu/third_party/icu/i18n/decContext.cpp" #include "extension/icu/third_party/icu/i18n/reldtfmt.cpp" -#include "extension/icu/third_party/icu/i18n/collationsettings.cpp" +#include "extension/icu/third_party/icu/i18n/standardplural.cpp" -#include "extension/icu/third_party/icu/i18n/formattedvalue.cpp" +#include "extension/icu/third_party/icu/i18n/alphaindex.cpp" -#include "extension/icu/third_party/icu/i18n/ulocdata.cpp" +#include "extension/icu/third_party/icu/i18n/datefmt.cpp" -#include "extension/icu/third_party/icu/i18n/rbtz.cpp" +#include "extension/icu/third_party/icu/i18n/uitercollationiterator.cpp" -#include "extension/icu/third_party/icu/i18n/coleitr.cpp" +#include "extension/icu/third_party/icu/i18n/collationdatareader.cpp" -#include "extension/icu/third_party/icu/i18n/collationsets.cpp" +#include "extension/icu/third_party/icu/i18n/number_notation.cpp" -#include "extension/icu/third_party/icu/i18n/listformatter.cpp" +#include "extension/icu/third_party/icu/i18n/dtrule.cpp" + +#include "extension/icu/third_party/icu/i18n/unum.cpp" + +#include "extension/icu/third_party/icu/i18n/ulistformatter.cpp" + +#include "extension/icu/third_party/icu/i18n/tmutamt.cpp" #include "extension/icu/third_party/icu/i18n/bocsu.cpp" -#include "extension/icu/third_party/icu/i18n/displayoptions.cpp" +#include "extension/icu/third_party/icu/i18n/utf16collationiterator.cpp" -#include "extension/icu/third_party/icu/i18n/zrule.cpp" +#include "extension/icu/third_party/icu/i18n/tmunit.cpp" -#include "extension/icu/third_party/icu/i18n/dtitvinf.cpp" +#include "extension/icu/third_party/icu/i18n/fpositer.cpp" -#include "extension/icu/third_party/icu/i18n/measunit.cpp" +#include "extension/icu/third_party/icu/i18n/collationfastlatinbuilder.cpp" -#include "extension/icu/third_party/icu/i18n/rulebasedcollator.cpp" +#include "extension/icu/third_party/icu/i18n/number_modifiers.cpp" -#include "extension/icu/third_party/icu/i18n/basictz.cpp" +#include 
"extension/icu/third_party/icu/i18n/collationcompare.cpp" -#include "extension/icu/third_party/icu/i18n/collationrootelements.cpp" +#include "extension/icu/third_party/icu/i18n/measunit.cpp" -#include "extension/icu/third_party/icu/i18n/collationdata.cpp" +#include "extension/icu/third_party/icu/i18n/astro.cpp" -#include "extension/icu/third_party/icu/i18n/collationdatawriter.cpp" +#include "extension/icu/third_party/icu/i18n/udat.cpp" -#include "extension/icu/third_party/icu/i18n/collationruleparser.cpp" +#include "extension/icu/third_party/icu/i18n/utf8collationiterator.cpp" -#include "extension/icu/third_party/icu/i18n/utmscale.cpp" +#include "extension/icu/third_party/icu/i18n/selfmt.cpp" -#include "extension/icu/third_party/icu/i18n/measure.cpp" +#include "extension/icu/third_party/icu/i18n/collationiterator.cpp" -#include "extension/icu/third_party/icu/i18n/simpletz.cpp" +#include "extension/icu/third_party/icu/i18n/tzfmt.cpp" -#include "extension/icu/third_party/icu/i18n/currunit.cpp" +#include "extension/icu/third_party/icu/i18n/units_converter.cpp" -#include "extension/icu/third_party/icu/i18n/datefmt.cpp" +#include "extension/icu/third_party/icu/i18n/rulebasedcollator.cpp" -#include "extension/icu/third_party/icu/i18n/scientificnumberformatter.cpp" +#include "extension/icu/third_party/icu/i18n/dcfmtsym.cpp" -#include "extension/icu/third_party/icu/i18n/tzrule.cpp" +#include "extension/icu/third_party/icu/i18n/dtitvinf.cpp" -#include "extension/icu/third_party/icu/i18n/collationfastlatinbuilder.cpp" +#include "extension/icu/third_party/icu/i18n/collationfcd.cpp" -#include "extension/icu/third_party/icu/i18n/collationbuilder.cpp" +#include "extension/icu/third_party/icu/i18n/udatpg.cpp" -#include "extension/icu/third_party/icu/i18n/olsontz.cpp" +#include "extension/icu/third_party/icu/i18n/scientificnumberformatter.cpp" -#include "extension/icu/third_party/icu/i18n/currfmt.cpp" +#include "extension/icu/third_party/icu/i18n/coll.cpp" -#include 
"extension/icu/third_party/icu/i18n/number_modifiers.cpp" +#include "extension/icu/third_party/icu/i18n/unumsys.cpp" -#include "extension/icu/third_party/icu/i18n/number_decimfmtprops.cpp" +#include "extension/icu/third_party/icu/i18n/displayoptions.cpp" -#include "extension/icu/third_party/icu/i18n/tztrans.cpp" +#include "extension/icu/third_party/icu/i18n/smpdtfst.cpp" -#include "extension/icu/third_party/icu/i18n/tmunit.cpp" +#include "extension/icu/third_party/icu/i18n/zrule.cpp" -#include "extension/icu/third_party/icu/i18n/utf8collationiterator.cpp" +#include "extension/icu/third_party/icu/i18n/collationbuilder.cpp" #include "extension/icu/third_party/icu/i18n/vzone.cpp" -#include "extension/icu/third_party/icu/i18n/zonemeta.cpp" - -#include "extension/icu/third_party/icu/i18n/units_converter.cpp" +#include "extension/icu/third_party/icu/i18n/ufieldpositer.cpp" -#include "extension/icu/third_party/icu/i18n/tzfmt.cpp" +#include "extension/icu/third_party/icu/i18n/ucol_sit.cpp" -#include "extension/icu/third_party/icu/i18n/udatpg.cpp" +#include "extension/icu/third_party/icu/i18n/collationdatabuilder.cpp" -#include "extension/icu/third_party/icu/i18n/tznames.cpp" +#include "extension/icu/third_party/icu/i18n/collationfastlatin.cpp" -#include "extension/icu/third_party/icu/i18n/fpositer.cpp" +#include "extension/icu/third_party/icu/i18n/number_padding.cpp" -#include "extension/icu/third_party/icu/i18n/decContext.cpp" +#include "extension/icu/third_party/icu/i18n/erarules.cpp" -#include "extension/icu/third_party/icu/i18n/collationkeys.cpp" +#include "extension/icu/third_party/icu/i18n/collationweights.cpp" -#include "extension/icu/third_party/icu/i18n/number_affixutils.cpp" +#include "extension/icu/third_party/icu/i18n/collationrootelements.cpp" -#include "extension/icu/third_party/icu/i18n/sortkey.cpp" +#include "extension/icu/third_party/icu/i18n/measure.cpp" -#include "extension/icu/third_party/icu/i18n/utf16collationiterator.cpp" +#include 
"extension/icu/third_party/icu/i18n/collationsets.cpp" -#include "extension/icu/third_party/icu/i18n/formattedval_iterimpl.cpp" +#include "extension/icu/third_party/icu/i18n/ucal.cpp" -#include "extension/icu/third_party/icu/i18n/ucoleitr.cpp" +#include "extension/icu/third_party/icu/i18n/number_decimfmtprops.cpp" -#include "extension/icu/third_party/icu/i18n/gender.cpp" +#include "extension/icu/third_party/icu/i18n/umsg.cpp" -#include "extension/icu/third_party/icu/i18n/dtrule.cpp" +#include "extension/icu/third_party/icu/i18n/tznames.cpp" #include "extension/icu/third_party/icu/i18n/curramt.cpp" -#include "extension/icu/third_party/icu/i18n/collationtailoring.cpp" +#include "extension/icu/third_party/icu/i18n/simpletz.cpp" -#include "extension/icu/third_party/icu/i18n/ucol_sit.cpp" +#include "extension/icu/third_party/icu/i18n/number_affixutils.cpp" -#include "extension/icu/third_party/icu/i18n/formatted_string_builder.cpp" +#include "extension/icu/third_party/icu/i18n/zonemeta.cpp" -#include "extension/icu/third_party/icu/i18n/collationdatareader.cpp" +#include "extension/icu/third_party/icu/i18n/olsontz.cpp" #include "extension/icu/third_party/icu/i18n/ztrans.cpp" -#include "extension/icu/third_party/icu/i18n/tznames_impl.cpp" - -#include "extension/icu/third_party/icu/i18n/standardplural.cpp" - #include "extension/icu/third_party/icu/i18n/fphdlimp.cpp" -#include "extension/icu/third_party/icu/i18n/ucln_in.cpp" - -#include "extension/icu/third_party/icu/i18n/ulistformatter.cpp" +#include "extension/icu/third_party/icu/i18n/gender.cpp" -#include "extension/icu/third_party/icu/i18n/collationdatabuilder.cpp" +#include "extension/icu/third_party/icu/i18n/currfmt.cpp" -#include "extension/icu/third_party/icu/i18n/alphaindex.cpp" +#include "extension/icu/third_party/icu/i18n/formatted_string_builder.cpp" -#include "extension/icu/third_party/icu/i18n/collationfcd.cpp" +#include "extension/icu/third_party/icu/i18n/collation.cpp" diff --git 
a/src/duckdb/ub_extension_json_json_functions.cpp b/src/duckdb/ub_extension_json_json_functions.cpp index c12e03236..1523a5983 100644 --- a/src/duckdb/ub_extension_json_json_functions.cpp +++ b/src/duckdb/ub_extension_json_json_functions.cpp @@ -1,46 +1,46 @@ -#include "extension/json/json_functions/json_extract.cpp" - -#include "extension/json/json_functions/json_valid.cpp" +#include "extension/json/json_functions/json_contains.cpp" #include "extension/json/json_functions/read_json.cpp" -#include "extension/json/json_functions/json_structure.cpp" +#include "extension/json/json_functions/json_extract.cpp" -#include "extension/json/json_functions/json_merge_patch_diff.cpp" +#include "extension/json/json_functions/json_type.cpp" -#include "extension/json/json_functions/json_array_length.cpp" +#include "extension/json/json_functions/json_valid.cpp" -#include "extension/json/json_functions/json_serialize_sql.cpp" +#include "extension/json/json_functions/json_transform.cpp" -#include "extension/json/json_functions/json_exists.cpp" +#include "extension/json/json_functions/json_keys.cpp" -#include "extension/json/json_functions/json_deep_merge.cpp" +#include "extension/json/json_functions/copy_json.cpp" -#include "extension/json/json_functions/json_keys.cpp" +#include "extension/json/json_functions/json_table_in_out.cpp" -#include "extension/json/json_functions/read_json_objects.cpp" +#include "extension/json/json_functions/json_serialize_plan.cpp" -#include "extension/json/json_functions/json_value.cpp" +#include "extension/json/json_functions/json_strip_nulls.cpp" -#include "extension/json/json_functions/json_create.cpp" +#include "extension/json/json_functions/json_merge_patch_diff.cpp" -#include "extension/json/json_functions/json_table_in_out.cpp" +#include "extension/json/json_functions/json_exists.cpp" -#include "extension/json/json_functions/json_serialize_plan.cpp" +#include "extension/json/json_functions/json_structure.cpp" -#include 
"extension/json/json_functions/json_normalize.cpp" +#include "extension/json/json_functions/json_pretty.cpp" -#include "extension/json/json_functions/json_transform.cpp" +#include "extension/json/json_functions/read_json_objects.cpp" -#include "extension/json/json_functions/json_contains.cpp" +#include "extension/json/json_functions/json_create.cpp" -#include "extension/json/json_functions/json_type.cpp" +#include "extension/json/json_functions/json_deep_merge.cpp" -#include "extension/json/json_functions/copy_json.cpp" +#include "extension/json/json_functions/json_serialize_sql.cpp" -#include "extension/json/json_functions/json_pretty.cpp" +#include "extension/json/json_functions/json_value.cpp" -#include "extension/json/json_functions/json_strip_nulls.cpp" +#include "extension/json/json_functions/json_array_length.cpp" + +#include "extension/json/json_functions/json_normalize.cpp" #include "extension/json/json_functions/json_merge_patch.cpp" diff --git a/src/duckdb/ub_extension_parquet_decoder.cpp b/src/duckdb/ub_extension_parquet_decoder.cpp index d436738aa..5d625aaf8 100644 --- a/src/duckdb/ub_extension_parquet_decoder.cpp +++ b/src/duckdb/ub_extension_parquet_decoder.cpp @@ -1,12 +1,12 @@ -#include "extension/parquet/decoder/byte_stream_split_decoder.cpp" - -#include "extension/parquet/decoder/delta_byte_array_decoder.cpp" - #include "extension/parquet/decoder/rle_decoder.cpp" #include "extension/parquet/decoder/dictionary_decoder.cpp" +#include "extension/parquet/decoder/delta_length_byte_array_decoder.cpp" + +#include "extension/parquet/decoder/byte_stream_split_decoder.cpp" + #include "extension/parquet/decoder/delta_binary_packed_decoder.cpp" -#include "extension/parquet/decoder/delta_length_byte_array_decoder.cpp" +#include "extension/parquet/decoder/delta_byte_array_decoder.cpp" diff --git a/src/duckdb/ub_extension_parquet_reader.cpp b/src/duckdb/ub_extension_parquet_reader.cpp index f3dacfafe..28b599256 100644 --- 
a/src/duckdb/ub_extension_parquet_reader.cpp +++ b/src/duckdb/ub_extension_parquet_reader.cpp @@ -1,14 +1,14 @@ -#include "extension/parquet/reader/expression_column_reader.cpp" - -#include "extension/parquet/reader/string_column_reader.cpp" +#include "extension/parquet/reader/struct_column_reader.cpp" #include "extension/parquet/reader/list_column_reader.cpp" +#include "extension/parquet/reader/decimal_column_reader.cpp" + #include "extension/parquet/reader/variant_column_reader.cpp" -#include "extension/parquet/reader/row_number_column_reader.cpp" +#include "extension/parquet/reader/string_column_reader.cpp" -#include "extension/parquet/reader/decimal_column_reader.cpp" +#include "extension/parquet/reader/expression_column_reader.cpp" -#include "extension/parquet/reader/struct_column_reader.cpp" +#include "extension/parquet/reader/row_number_column_reader.cpp" diff --git a/src/duckdb/ub_extension_parquet_reader_variant.cpp b/src/duckdb/ub_extension_parquet_reader_variant.cpp index c5f86f0b4..4e55f0b16 100644 --- a/src/duckdb/ub_extension_parquet_reader_variant.cpp +++ b/src/duckdb/ub_extension_parquet_reader_variant.cpp @@ -1,4 +1,4 @@ -#include "extension/parquet/reader/variant/variant_shredded_conversion.cpp" - #include "extension/parquet/reader/variant/variant_binary_decoder.cpp" +#include "extension/parquet/reader/variant/variant_shredded_conversion.cpp" + diff --git a/src/duckdb/ub_extension_parquet_writer.cpp b/src/duckdb/ub_extension_parquet_writer.cpp index 2e12faccf..e1d3beb0f 100644 --- a/src/duckdb/ub_extension_parquet_writer.cpp +++ b/src/duckdb/ub_extension_parquet_writer.cpp @@ -1,14 +1,16 @@ -#include "extension/parquet/writer/enum_column_writer.cpp" - -#include "extension/parquet/writer/decimal_column_writer.cpp" +#include "extension/parquet/writer/boolean_column_writer.cpp" #include "extension/parquet/writer/struct_column_writer.cpp" -#include "extension/parquet/writer/boolean_column_writer.cpp" +#include 
"extension/parquet/writer/array_column_writer.cpp" + +#include "extension/parquet/writer/enum_column_writer.cpp" #include "extension/parquet/writer/list_column_writer.cpp" -#include "extension/parquet/writer/array_column_writer.cpp" +#include "extension/parquet/writer/decimal_column_writer.cpp" #include "extension/parquet/writer/primitive_column_writer.cpp" +#include "extension/parquet/writer/variant_column_writer.cpp" + diff --git a/src/duckdb/ub_src_common.cpp b/src/duckdb/ub_src_common.cpp index a1c0b5c89..c09dc26a7 100644 --- a/src/duckdb/ub_src_common.cpp +++ b/src/duckdb/ub_src_common.cpp @@ -1,5 +1,3 @@ -#include "src/common/allocator.cpp" - #include "src/common/assert.cpp" #include "src/common/bignum.cpp" @@ -16,6 +14,8 @@ #include "src/common/client_box_renderer_context.cpp" +#include "src/common/clustered_aggregate.cpp" + #include "src/common/column_data_collection_render_interface.cpp" #include "src/common/column_index.cpp" @@ -58,8 +58,12 @@ #include "src/common/hive_partitioning.cpp" +#include "src/common/identifier.cpp" + #include "src/common/local_file_system.cpp" +#include "src/common/memory_mapped_file.cpp" + #include "src/common/opener_file_system.cpp" #include "src/common/path.cpp" @@ -76,10 +80,10 @@ #include "src/common/render_tree.cpp" -#include "src/common/serialization_compatibility.cpp" - #include "src/common/stacktrace.cpp" +#include "src/common/storage_compatibility.cpp" + #include "src/common/string_util.cpp" #include "src/common/thread_util.cpp" diff --git a/src/duckdb/ub_src_common_allocator.cpp b/src/duckdb/ub_src_common_allocator.cpp new file mode 100644 index 000000000..d1a04b9db --- /dev/null +++ b/src/duckdb/ub_src_common_allocator.cpp @@ -0,0 +1,6 @@ +#include "src/common/allocator/allocator.cpp" + +#include "src/common/allocator/allocator_jemalloc.cpp" + +#include "src/common/allocator/allocator_standard.cpp" + diff --git a/src/duckdb/ub_src_common_enums.cpp b/src/duckdb/ub_src_common_enums.cpp index e979f2e50..c90e668fb 100644 
--- a/src/duckdb/ub_src_common_enums.cpp +++ b/src/duckdb/ub_src_common_enums.cpp @@ -1,5 +1,7 @@ #include "src/common/enums/catalog_type.cpp" +#include "src/common/enums/column_segment_info_scan_type.cpp" + #include "src/common/enums/compression_type.cpp" #include "src/common/enums/date_part_specifier.cpp" @@ -14,8 +16,6 @@ #include "src/common/enums/logical_operator_type.cpp" -#include "src/common/enums/metric_type.cpp" - #include "src/common/enums/optimizer_type.cpp" #include "src/common/enums/physical_operator_type.cpp" diff --git a/src/duckdb/ub_src_execution_operator_helper.cpp b/src/duckdb/ub_src_execution_operator_helper.cpp index 15ca26744..200dae991 100644 --- a/src/duckdb/ub_src_execution_operator_helper.cpp +++ b/src/duckdb/ub_src_execution_operator_helper.cpp @@ -4,8 +4,12 @@ #include "src/execution/operator/helper/physical_buffered_collector.cpp" +#include "src/execution/operator/helper/physical_connect.cpp" + #include "src/execution/operator/helper/physical_create_secret.cpp" +#include "src/execution/operator/helper/physical_disconnect.cpp" + #include "src/execution/operator/helper/physical_execute.cpp" #include "src/execution/operator/helper/physical_explain_analyze.cpp" diff --git a/src/duckdb/ub_src_execution_operator_set.cpp b/src/duckdb/ub_src_execution_operator_set.cpp index d85e9dd39..823e75ff9 100644 --- a/src/duckdb/ub_src_execution_operator_set.cpp +++ b/src/duckdb/ub_src_execution_operator_set.cpp @@ -2,5 +2,7 @@ #include "src/execution/operator/set/physical_recursive_cte.cpp" +#include "src/execution/operator/set/physical_recursive_cte_runtime.cpp" + #include "src/execution/operator/set/physical_union.cpp" diff --git a/src/duckdb/ub_src_function_scalar_operator.cpp b/src/duckdb/ub_src_function_scalar_operator.cpp index f31f65916..7f8f9828d 100644 --- a/src/duckdb/ub_src_function_scalar_operator.cpp +++ b/src/duckdb/ub_src_function_scalar_operator.cpp @@ -2,6 +2,8 @@ #include "src/function/scalar/operator/arithmetic.cpp" +#include 
"src/function/scalar/operator/decimal_division.cpp" + #include "src/function/scalar/operator/multiply.cpp" #include "src/function/scalar/operator/subtract.cpp" diff --git a/src/duckdb/ub_src_function_scalar_string.cpp b/src/duckdb/ub_src_function_scalar_string.cpp index fd559a5f1..8b3677d53 100644 --- a/src/duckdb/ub_src_function_scalar_string.cpp +++ b/src/duckdb/ub_src_function_scalar_string.cpp @@ -14,6 +14,8 @@ #include "src/function/scalar/string/nfc_normalize.cpp" +#include "src/function/scalar/string/overlay.cpp" + #include "src/function/scalar/string/path_join.cpp" #include "src/function/scalar/string/prefix.cpp" diff --git a/src/duckdb/ub_src_function_scalar_variant.cpp b/src/duckdb/ub_src_function_scalar_variant.cpp index 2fb7bc824..fb6938aed 100644 --- a/src/duckdb/ub_src_function_scalar_variant.cpp +++ b/src/duckdb/ub_src_function_scalar_variant.cpp @@ -1,5 +1,11 @@ +#include "src/function/scalar/variant/variant_bind_utils.cpp" + +#include "src/function/scalar/variant/variant_exists.cpp" + #include "src/function/scalar/variant/variant_extract.cpp" +#include "src/function/scalar/variant/variant_keys.cpp" + #include "src/function/scalar/variant/variant_normalize.cpp" #include "src/function/scalar/variant/variant_typeof.cpp" diff --git a/src/duckdb/ub_src_function_table_system.cpp b/src/duckdb/ub_src_function_table_system.cpp index 5897012fb..7ded15b59 100644 --- a/src/duckdb/ub_src_function_table_system.cpp +++ b/src/duckdb/ub_src_function_table_system.cpp @@ -30,6 +30,8 @@ #include "src/function/table/system/duckdb_memory.cpp" +#include "src/function/table/system/duckdb_metrics.cpp" + #include "src/function/table/system/duckdb_optimizers.cpp" #include "src/function/table/system/duckdb_prepared_statements.cpp" diff --git a/src/duckdb/ub_src_main.cpp b/src/duckdb/ub_src_main.cpp index 9c6fa3a79..e61f67a80 100644 --- a/src/duckdb/ub_src_main.cpp +++ b/src/duckdb/ub_src_main.cpp @@ -42,16 +42,18 @@ #include "src/main/extension_manager.cpp" +#include 
"src/main/gathered_metrics.cpp" + #include "src/main/materialized_query_result.cpp" +#include "src/main/metrics_manager.cpp" + #include "src/main/pending_query_result.cpp" #include "src/main/prepared_statement.cpp" #include "src/main/prepared_statement_data.cpp" -#include "src/main/profiling_info.cpp" - #include "src/main/profiling_utils.cpp" #include "src/main/query_profiler.cpp" diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp index 94608c4b4..5f936aea0 100644 --- a/src/duckdb/ub_src_optimizer.cpp +++ b/src/duckdb/ub_src_optimizer.cpp @@ -48,12 +48,16 @@ #include "src/optimizer/outer_join_simplification.cpp" +#include "src/optimizer/partial_aggregate_pushdown.cpp" + #include "src/optimizer/partitioned_execution.cpp" #include "src/optimizer/projection_pullup.cpp" #include "src/optimizer/regex_range_filter.cpp" +#include "src/optimizer/remote_pushdown_optimizer.cpp" + #include "src/optimizer/remove_duplicate_groups.cpp" #include "src/optimizer/remove_unused_columns.cpp" diff --git a/src/duckdb/ub_src_optimizer_join_order.cpp b/src/duckdb/ub_src_optimizer_join_order.cpp index b8ba49ccd..1fd90dc79 100644 --- a/src/duckdb/ub_src_optimizer_join_order.cpp +++ b/src/duckdb/ub_src_optimizer_join_order.cpp @@ -2,10 +2,14 @@ #include "src/optimizer/join_order/cost_model.cpp" +#include "src/optimizer/join_order/filter_info.cpp" + #include "src/optimizer/join_order/join_node.cpp" #include "src/optimizer/join_order/join_order_optimizer.cpp" +#include "src/optimizer/join_order/join_predicate.cpp" + #include "src/optimizer/join_order/join_relation_set.cpp" #include "src/optimizer/join_order/plan_enumerator.cpp" diff --git a/src/duckdb/ub_src_optimizer_rule.cpp b/src/duckdb/ub_src_optimizer_rule.cpp index 3b772d09b..bb336c4f5 100644 --- a/src/duckdb/ub_src_optimizer_rule.cpp +++ b/src/duckdb/ub_src_optimizer_rule.cpp @@ -10,6 +10,8 @@ #include "src/optimizer/rule/constant_order_normalization.cpp" +#include 
"src/optimizer/rule/contains_to_in_clause.cpp" + #include "src/optimizer/rule/date_part_simplification.cpp" #include "src/optimizer/rule/date_trunc_simplification.cpp" diff --git a/src/duckdb/ub_src_parallel.cpp b/src/duckdb/ub_src_parallel.cpp index cf5b9bcf0..bf2a536b3 100644 --- a/src/duckdb/ub_src_parallel.cpp +++ b/src/duckdb/ub_src_parallel.cpp @@ -32,5 +32,9 @@ #include "src/parallel/task_scheduler.cpp" +#include "src/parallel/task_scheduler_pool.cpp" + +#include "src/parallel/task_scheduler_queue.cpp" + #include "src/parallel/thread_context.cpp" diff --git a/src/duckdb/ub_src_parser_parsed_data.cpp b/src/duckdb/ub_src_parser_parsed_data.cpp index 9696f6d6b..b07703b51 100644 --- a/src/duckdb/ub_src_parser_parsed_data.cpp +++ b/src/duckdb/ub_src_parser_parsed_data.cpp @@ -12,6 +12,8 @@ #include "src/parser/parsed_data/comment_on_column_info.cpp" +#include "src/parser/parsed_data/connect_info.cpp" + #include "src/parser/parsed_data/copy_info.cpp" #include "src/parser/parsed_data/create_aggregate_function_info.cpp" @@ -54,6 +56,8 @@ #include "src/parser/parsed_data/detach_info.cpp" +#include "src/parser/parsed_data/disconnect_info.cpp" + #include "src/parser/parsed_data/drop_info.cpp" #include "src/parser/parsed_data/exported_table_data.cpp" diff --git a/src/duckdb/ub_src_parser_peg_transformer.cpp b/src/duckdb/ub_src_parser_peg_transformer.cpp index 38e9101bb..5fea064d8 100644 --- a/src/duckdb/ub_src_parser_peg_transformer.cpp +++ b/src/duckdb/ub_src_parser_peg_transformer.cpp @@ -16,6 +16,8 @@ #include "src/parser/peg/transformer/transform_common.cpp" +#include "src/parser/peg/transformer/transform_connect.cpp" + #include "src/parser/peg/transformer/transform_copy.cpp" #include "src/parser/peg/transformer/transform_create_index.cpp" @@ -56,7 +58,7 @@ #include "src/parser/peg/transformer/transform_generated.cpp" -#include "src/parser/peg/transformer/transform_import.cpp" +#include "src/parser/peg/transformer/transform_generic_copy_option.cpp" #include 
"src/parser/peg/transformer/transform_insert.cpp" diff --git a/src/duckdb/ub_src_parser_query_node.cpp b/src/duckdb/ub_src_parser_query_node.cpp index c97a578d5..1dcb8a845 100644 --- a/src/duckdb/ub_src_parser_query_node.cpp +++ b/src/duckdb/ub_src_parser_query_node.cpp @@ -4,6 +4,8 @@ #include "src/parser/query_node/insert_query_node.cpp" +#include "src/parser/query_node/merge_query_node.cpp" + #include "src/parser/query_node/recursive_cte_node.cpp" #include "src/parser/query_node/select_node.cpp" diff --git a/src/duckdb/ub_src_parser_statement.cpp b/src/duckdb/ub_src_parser_statement.cpp index 1f570d100..c892e913f 100644 --- a/src/duckdb/ub_src_parser_statement.cpp +++ b/src/duckdb/ub_src_parser_statement.cpp @@ -4,6 +4,8 @@ #include "src/parser/statement/call_statement.cpp" +#include "src/parser/statement/connect_statement.cpp" + #include "src/parser/statement/copy_database_statement.cpp" #include "src/parser/statement/copy_statement.cpp" @@ -14,6 +16,8 @@ #include "src/parser/statement/detach_statement.cpp" +#include "src/parser/statement/disconnect_statement.cpp" + #include "src/parser/statement/drop_statement.cpp" #include "src/parser/statement/execute_statement.cpp" diff --git a/src/duckdb/ub_src_planner.cpp b/src/duckdb/ub_src_planner.cpp index acfd17384..d37d33047 100644 --- a/src/duckdb/ub_src_planner.cpp +++ b/src/duckdb/ub_src_planner.cpp @@ -12,6 +12,8 @@ #include "src/planner/column_binding.cpp" +#include "src/planner/column_qualifier.cpp" + #include "src/planner/expression.cpp" #include "src/planner/expression_binder.cpp" @@ -32,8 +34,6 @@ #include "src/planner/table_binding.cpp" -#include "src/planner/table_filter.cpp" - #include "src/planner/table_filter_set.cpp" #include "src/planner/table_filter_state.cpp" diff --git a/src/duckdb/ub_src_planner_binder_statement.cpp b/src/duckdb/ub_src_planner_binder_statement.cpp index a0d0b3305..423cc5776 100644 --- a/src/duckdb/ub_src_planner_binder_statement.cpp +++ 
b/src/duckdb/ub_src_planner_binder_statement.cpp @@ -2,6 +2,8 @@ #include "src/planner/binder/statement/bind_call.cpp" +#include "src/planner/binder/statement/bind_connect.cpp" + #include "src/planner/binder/statement/bind_copy.cpp" #include "src/planner/binder/statement/bind_copy_database.cpp" diff --git a/src/duckdb/ub_src_planner_expression_binder.cpp b/src/duckdb/ub_src_planner_expression_binder.cpp index ad437b96a..a7fa87f3c 100644 --- a/src/duckdb/ub_src_planner_expression_binder.cpp +++ b/src/duckdb/ub_src_planner_expression_binder.cpp @@ -1,5 +1,3 @@ -#include "src/planner/expression_binder/aggregate_binder.cpp" - #include "src/planner/expression_binder/alter_binder.cpp" #include "src/planner/expression_binder/base_select_binder.cpp" @@ -36,8 +34,6 @@ #include "src/planner/expression_binder/table_function_binder.cpp" -#include "src/planner/expression_binder/try_operator_binder.cpp" - #include "src/planner/expression_binder/update_binder.cpp" #include "src/planner/expression_binder/where_binder.cpp" diff --git a/src/duckdb/ub_src_planner_filter.cpp b/src/duckdb/ub_src_planner_filter.cpp index c5b7de870..6a6f4169d 100644 --- a/src/duckdb/ub_src_planner_filter.cpp +++ b/src/duckdb/ub_src_planner_filter.cpp @@ -22,3 +22,17 @@ #include "src/planner/filter/struct_filter.cpp" +#include "src/planner/filter/table_filter_bloom_function.cpp" + +#include "src/planner/filter/table_filter_dynamic_function.cpp" + +#include "src/planner/filter/table_filter_functions.cpp" + +#include "src/planner/filter/table_filter_optional_function.cpp" + +#include "src/planner/filter/table_filter_perfect_hash_join_function.cpp" + +#include "src/planner/filter/table_filter_prefix_range_function.cpp" + +#include "src/planner/filter/table_filter_selectivity_optional_function.cpp" + diff --git a/src/duckdb/ub_src_planner_subquery.cpp b/src/duckdb/ub_src_planner_subquery.cpp index 98630c1ee..058d47ccc 100644 --- a/src/duckdb/ub_src_planner_subquery.cpp +++ 
b/src/duckdb/ub_src_planner_subquery.cpp @@ -1,3 +1,5 @@ +#include "src/planner/subquery/delim_join_cte_rewriter.cpp" + #include "src/planner/subquery/flatten_dependent_join.cpp" #include "src/planner/subquery/rewrite_correlated_expressions.cpp" diff --git a/src/duckdb/ub_src_storage.cpp b/src/duckdb/ub_src_storage.cpp index b9179c387..5b4f7029d 100644 --- a/src/duckdb/ub_src_storage.cpp +++ b/src/duckdb/ub_src_storage.cpp @@ -12,6 +12,8 @@ #include "src/storage/data_table.cpp" +#include "src/storage/database_handle.cpp" + #include "src/storage/index.cpp" #include "src/storage/local_storage.cpp" diff --git a/src/duckdb/ub_src_storage_compression.cpp b/src/duckdb/ub_src_storage_compression.cpp index b7759ce8e..50caa586f 100644 --- a/src/duckdb/ub_src_storage_compression.cpp +++ b/src/duckdb/ub_src_storage_compression.cpp @@ -20,6 +20,8 @@ #include "src/storage/compression/rle.cpp" +#include "src/storage/compression/standard_compression_state.cpp" + #include "src/storage/compression/string_uncompressed.cpp" #include "src/storage/compression/uncompressed.cpp" diff --git a/src/duckdb/ub_src_storage_table.cpp b/src/duckdb/ub_src_storage_table.cpp index bbd2de5eb..544a7f2da 100644 --- a/src/duckdb/ub_src_storage_table.cpp +++ b/src/duckdb/ub_src_storage_table.cpp @@ -16,6 +16,8 @@ #include "src/storage/table/list_column_data.cpp" +#include "src/storage/table/per_column_metadata_blocks.cpp" + #include "src/storage/table/persistent_table_data.cpp" #include "src/storage/table/row_group.cpp" diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 76c6791d9..7a4a6d144 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -746,7 +746,7 @@ jobject ProcessVector(JNIEnv *env, Connection *conn_ref, Vector &vec, idx_t row_ continue; } Vector variant_vec(variant_val.type()); - variant_vec.SetValue(0, variant_val); + variant_vec.SetValue(0, variant_val); jobject variant_j_vec = ProcessVector(env, conn_ref, variant_vec, 1); 
env->CallVoidMethod(variant_j_vec, J_DuckVector_retainConstlenData); env->SetObjectArrayElement(varlen_data, row_idx, variant_j_vec);