Skip to content
  •  
  •  
  •  
321 changes: 149 additions & 172 deletions CMakeLists.txt

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,14 @@ struct TimeTZAverageOperation : public BaseSumOperation<AverageSetOperation, Add
}
};

LogicalType GetAvgStateType(const AggregateFunction &function) {
LogicalType GetAvgStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("count", LogicalType::UBIGINT);
children.emplace_back("value", function.GetArguments()[0]);
return LogicalType::STRUCT(std::move(children));
}

LogicalType GetKahanAvgStateType(const AggregateFunction &function) {
LogicalType GetKahanAvgStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("count", LogicalType::UBIGINT);
children.emplace_back("value", LogicalType::DOUBLE);
Expand Down Expand Up @@ -289,9 +289,9 @@ AggregateFunction GetAverageAggregate(PhysicalType type) {
unique_ptr<FunctionData> BindDecimalAvg(BindAggregateFunctionInput &input) {
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
auto decimal_type = arguments[0]->return_type;
function = GetAverageAggregate(decimal_type.InternalType());
function.name = "avg";
auto decimal_type = arguments[0]->GetReturnType();
function.ReplaceImplementation(GetAverageAggregate(decimal_type.InternalType()));
function.SetName("avg");
function.GetArguments()[0] = decimal_type;
function.SetReturnType(LogicalType::DOUBLE);
return make_uniq<AverageDecimalBindData>(
Expand Down
21 changes: 17 additions & 4 deletions src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,27 @@
namespace duckdb {

LogicalType GetCorrStateType() {
child_list_t<LogicalType> covar_children;
covar_children.emplace_back("count", LogicalType::UBIGINT);
covar_children.emplace_back("meanx", LogicalType::DOUBLE);
covar_children.emplace_back("meany", LogicalType::DOUBLE);
covar_children.emplace_back("co_moment", LogicalType::DOUBLE);
auto cov_pop_type = LogicalType::STRUCT(std::move(covar_children));

child_list_t<LogicalType> stddev_types;
stddev_types.emplace_back("count", LogicalType::UBIGINT);
stddev_types.emplace_back("mean", LogicalType::DOUBLE);
stddev_types.emplace_back("dsquared", LogicalType::DOUBLE);
auto stddev_type = LogicalType::STRUCT(std::move(stddev_types));

child_list_t<LogicalType> state_children;
state_children.emplace_back("cov_pop", CovarPopFun::GetFunction().GetStateType());
state_children.emplace_back("dev_pop_x", VarPopFun::GetFunction().GetStateType());
state_children.emplace_back("dev_pop_y", VarPopFun::GetFunction().GetStateType());
state_children.emplace_back("cov_pop", std::move(cov_pop_type));
state_children.emplace_back("dev_pop_x", stddev_type);
state_children.emplace_back("dev_pop_y", stddev_type);
return LogicalType::STRUCT(std::move(state_children));
}

LogicalType GetCorrExportStateType(const AggregateFunction &) {
LogicalType GetCorrExportStateType(const BoundAggregateFunction &) {
return GetCorrStateType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace duckdb {

namespace {

LogicalType GetCovarStateType(const AggregateFunction &) {
LogicalType GetCovarStateType(const BoundAggregateFunction &) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("count", LogicalType::UBIGINT);
child_types.emplace_back("meanx", LogicalType::DOUBLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace duckdb {

namespace {

LogicalType GetStddevStateType(const AggregateFunction &) {
LogicalType GetStddevStateType(const BoundAggregateFunction &) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("count", LogicalType::UBIGINT);
child_types.emplace_back("mean", LogicalType::DOUBLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, id
D_ASSERT(input_count == 1);
auto &input = inputs[0];

auto input_validity = input.Validity(count);
auto input_validity = input.Validity();

if (count > STANDARD_VECTOR_SIZE) {
throw InternalException("ApproxCountDistinct - count must be at most vector size");
}
Vector hash_vec(LogicalType::HASH, count);
VectorOperations::Hash(input, hash_vec, count);

auto states = state_vector.Values<ApproxDistinctCountState *>(count);
auto hashes = hash_vec.Values<hash_t>(count);
auto states = state_vector.Values<ApproxDistinctCountState *>();
auto hashes = hash_vec.Values<hash_t>();
for (idx_t i = 0; i < count; i++) {
if (!input_validity.IsValid(i)) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ struct ArgMinMaxBase {
auto &context = input.GetClientContext();
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
if (arguments[1]->GetReturnType().InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->GetReturnType());
}
function.GetArguments()[0] = arguments[0]->return_type;
function.SetReturnType(arguments[0]->return_type);
function.GetArguments()[0] = arguments[0]->GetReturnType();
function.SetReturnType(arguments[0]->GetReturnType());

auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
return unique_ptr<FunctionData>(std::move(function_data));
Expand All @@ -207,7 +207,7 @@ struct SpecializedGenericArgMinMaxState {
}

static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) {
by.ToUnifiedFormat(count, result);
by.ToUnifiedFormat(result);
}
};

Expand All @@ -220,7 +220,7 @@ struct GenericArgMinMaxState {
static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) {
OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count);
extra_state.ToUnifiedFormat(count, result);
extra_state.ToUnifiedFormat(result);
}
};

Expand All @@ -234,7 +234,7 @@ struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR> {

auto &arg = inputs[0];
UnifiedVectorFormat adata;
arg.ToUnifiedFormat(count, adata);
arg.ToUnifiedFormat(adata);

using ARG_TYPE = typename STATE::ARG_TYPE;
using BY_TYPE = typename STATE::BY_TYPE;
Expand All @@ -245,7 +245,7 @@ struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR> {
const auto bys = UnifiedVectorFormat::GetData<BY_TYPE>(bdata);

UnifiedVectorFormat sdata;
state_vector.ToUnifiedFormat(count, sdata);
state_vector.ToUnifiedFormat(sdata);

STATE *last_state = nullptr;
sel_t assign_sel[STANDARD_VECTOR_SIZE];
Expand Down Expand Up @@ -354,11 +354,11 @@ struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR> {
auto &context = input.GetClientContext();
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
if (arguments[1]->GetReturnType().InternalType() == PhysicalType::VARCHAR) {
ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->GetReturnType());
}
function.GetArguments()[0] = arguments[0]->return_type;
function.SetReturnType(arguments[0]->return_type);
function.GetArguments()[0] = arguments[0]->GetReturnType();
function.SetReturnType(arguments[0]->GetReturnType());

auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
return unique_ptr<FunctionData>(std::move(function_data));
Expand Down Expand Up @@ -401,8 +401,9 @@ AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type,
AggregateFunction::StateDestroy<STATE, OP>);
#else
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
function.GetArguments() = {type, by_type};
function.return_type = type;
function.GetSignature().GetParameter(0).SetType(type);
function.GetSignature().GetParameter(1).SetType(by_type);
function.SetReturnType(type);
return function;
#endif
}
Expand Down Expand Up @@ -462,8 +463,9 @@ AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const
function.SetBindCallback(GetBindFunction<OP>(null_handling));
#else
auto function = GetGenericArgMinMaxFunction<OP>(null_handling);
function.GetArguments() = {type, by_type};
function.return_type = type;
function.GetSignature().GetParameter(0).SetType(type);
function.GetSignature().GetParameter(1).SetType(by_type);
function.SetReturnType(type);
#endif
return function;
}
Expand Down Expand Up @@ -526,8 +528,8 @@ unique_ptr<FunctionData> BindDecimalArgMinMax(BindAggregateFunctionInput &input)
auto &context = input.GetClientContext();
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
auto decimal_type = arguments[0]->return_type;
auto by_type = arguments[1]->return_type;
auto decimal_type = arguments[0]->GetReturnType();
auto by_type = arguments[1]->GetReturnType();

// To avoid a combinatorial explosion, cast the ordering argument to one from the list
auto by_types = ArgMaxByTypes();
Expand All @@ -554,9 +556,9 @@ unique_ptr<FunctionData> BindDecimalArgMinMax(BindAggregateFunctionInput &input)
by_type = by_types[best_target];
}

auto name = std::move(function.name);
function = GetDecimalArgMinMaxFunction<OP>(by_type, decimal_type, NULL_HANDLING);
function.name = std::move(name);
auto name = function.GetName();
function.ReplaceImplementation(GetDecimalArgMinMaxFunction<OP>(by_type, decimal_type, NULL_HANDLING));
function.SetName(std::move(name));
function.SetReturnType(decimal_type);

auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING);
Expand Down Expand Up @@ -669,8 +671,8 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp
STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format, bind_data.nulls_last);
STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format, bind_data.nulls_last);

n_vector.ToUnifiedFormat(count, n_format);
state_vector.ToUnifiedFormat(count, state_format);
n_vector.ToUnifiedFormat(n_format);
state_vector.ToUnifiedFormat(state_format);

auto states = UnifiedVectorFormat::GetData<STATE *>(state_format);

Expand Down Expand Up @@ -719,7 +721,7 @@ void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp
// Bind
//------------------------------------------------------------------------------
template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
void SpecializeArgMinMaxNFunction(BoundAggregateFunction &function) {
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
using OP = MinMaxNOperation;

Expand All @@ -733,7 +735,7 @@ void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
}

template <class VAL_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) {
void SpecializeArgMinMaxNFunction(PhysicalType arg_type, BoundAggregateFunction &function) {
switch (arg_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
Expand All @@ -759,7 +761,7 @@ void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &func
}

template <class COMPARATOR>
void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, BoundAggregateFunction &function) {
switch (val_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
Expand All @@ -785,7 +787,7 @@ void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type,
}

template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) {
void SpecializeArgMinMaxNullNFunction(BoundAggregateFunction &function) {
using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
using OP = MinMaxNOperation;

Expand All @@ -798,7 +800,7 @@ void SpecializeArgMinMaxNullNFunction(AggregateFunction &function) {
}

template <class VAL_TYPE, bool NULLS_LAST, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &function) {
void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, BoundAggregateFunction &function) {
switch (arg_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
Expand All @@ -824,7 +826,7 @@ void SpecializeArgMinMaxNullNFunction(PhysicalType arg_type, AggregateFunction &
}

template <bool NULLS_LAST, class COMPARATOR>
void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
void SpecializeArgMinMaxNullNFunction(PhysicalType val_type, PhysicalType arg_type, BoundAggregateFunction &function) {
switch (val_type) {
#ifndef DUCKDB_SMALLER_BINARY
case PhysicalType::VARCHAR:
Expand Down Expand Up @@ -858,14 +860,14 @@ unique_ptr<FunctionData> ArgMinMaxNBind(BindAggregateFunctionInput &input) {
auto &function = input.GetBoundFunction();
auto &arguments = input.GetArguments();
for (auto &arg : arguments) {
if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
if (arg->GetReturnType().id() == LogicalTypeId::UNKNOWN) {
throw ParameterNotResolvedException();
}
}

const auto val_type = arguments[0]->return_type.InternalType();
const auto arg_type = arguments[1]->return_type.InternalType();
function.SetReturnType(LogicalType::LIST(arguments[0]->return_type));
const auto val_type = arguments[0]->GetReturnType().InternalType();
const auto arg_type = arguments[1]->GetReturnType().InternalType();
function.SetReturnType(LogicalType::LIST(arguments[0]->GetReturnType()));

// Specialize the function based on the input types
auto function_data = make_uniq<ArgMinMaxFunctionData>(NULL_HANDLING, NULLS_LAST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct BitState {
};

template <class T>
LogicalType GetBitStateType(const AggregateFunction &function) {
LogicalType GetBitStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("is_set", LogicalType::BOOLEAN);

Expand All @@ -28,7 +28,7 @@ LogicalType GetBitStateType(const AggregateFunction &function) {
return LogicalType::STRUCT(std::move(child_types));
}

LogicalType GetBitStringStateType(const AggregateFunction &function) {
LogicalType GetBitStringStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("is_set", LogicalType::BOOLEAN);
child_types.emplace_back("value", function.GetReturnType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ struct BitstringAggBindData : public FunctionData {
}

static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &) {
const BoundAggregateFunction &) {
auto &bind_data = bind_data_p->Cast<BitstringAggBindData>();
serializer.WriteProperty(100, "min", bind_data.min);
serializer.WriteProperty(101, "max", bind_data.max);
}

static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &) {
static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, BoundAggregateFunction &) {
Value min;
Value max;
deserializer.ReadProperty(100, "min", min);
Expand Down Expand Up @@ -271,7 +271,7 @@ void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &typ
function.SetStatisticsCallback(
BitstringPropagateStats); // stores min and max from column stats in BitstringAggBindData
bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring
function.GetArguments() = {type, type, type};
function.GetSignature() = FunctionSignature({type, type, type}, LogicalType::BIT);
function.SetStatisticsCallback(nullptr); // min and max are provided as arguments
bitstring_agg.AddFunction(function);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct BoolOrFunFunction {
}
};

LogicalType GetBoolAndStateType(const AggregateFunction &function) {
LogicalType GetBoolAndStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> child_types;
child_types.emplace_back("empty", LogicalType::BOOLEAN);
child_types.emplace_back("val", LogicalType::BOOLEAN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ struct KurtosisOperation {
}
};

LogicalType GetKurtosisStateType(const AggregateFunction &function) {
LogicalType GetKurtosisStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("n", LogicalType::UBIGINT);
children.emplace_back("sum", LogicalType::DOUBLE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct ProductFunction {
}
};

LogicalType GetProductStateType(const AggregateFunction &function) {
LogicalType GetProductStateType(const BoundAggregateFunction &function) {
child_list_t<LogicalType> children;
children.emplace_back("empty", LogicalType::BOOLEAN);
children.emplace_back("val", LogicalType::DOUBLE);
Expand Down
Loading