Skip to content

Commit e838fb1

Browse files
committed
use AggregateFunctionPtr if possible
Signed-off-by: Murphy <mofei@starrocks.com>
1 parent 8d07edf commit e838fb1

4 files changed

Lines changed: 23 additions & 18 deletions

File tree

be/src/exprs/agg/aggregate.cpp

Whitespace-only changes.

be/src/exprs/agg/factory/aggregate_factory.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ template <typename NestedState, IsAggNullPred<NestedState> AggNullPredType>
353353
AggregateFunctionPtr AggregateFactory::MakeNullableAggregateFunctionVariadic(AggregateFunctionPtr nested_function,
354354
AggNullPredType null_pred) {
355355
using AggregateDataType = NullableAggregateFunctionState<NestedState, false>;
356-
return new NullableAggregateFunctionVariadic<AggregateFunctionPtr, AggregateDataType, AggNullPredType>(
357-
nested_function, std::move(null_pred));
356+
return new NullableAggregateFunctionVariadic<AggregateDataType, AggNullPredType>(nested_function,
357+
std::move(null_pred));
358358
}
359359

360360
template <LogicalType LT>

be/src/exprs/agg/factory/aggregate_resolver.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,16 @@ class AggregateFuncResolver {
214214
}
215215

216216
template <LogicalType ArgLT, LogicalType ResultLT, bool IsNull>
217-
const AggregateFunction* create_array_function(std::string& name) {
217+
AggregateFunctionPtr create_array_function(std::string& name) {
218218
if constexpr (IsNull) {
219219
if (name == "dict_merge") {
220220
auto dict_merge = track_function(AggregateFactory::MakeDictMergeAggregateFunction());
221221
return track_function(
222222
AggregateFactory::MakeNullableAggregateFunctionUnary<DictMergeState, false>(dict_merge));
223223
} else if (name == "retention") {
224-
auto retentoin = track_function(AggregateFactory::MakeRetentionAggregateFunction());
224+
auto retention = track_function(AggregateFactory::MakeRetentionAggregateFunction());
225225
return track_function(
226-
AggregateFactory::MakeNullableAggregateFunctionUnary<RetentionState, false>(retentoin));
226+
AggregateFactory::MakeNullableAggregateFunctionUnary<RetentionState, false>(retention));
227227
} else if (name == "window_funnel") {
228228
if constexpr (ArgLT == TYPE_INT || ArgLT == TYPE_BIGINT || ArgLT == TYPE_DATE ||
229229
ArgLT == TYPE_DATETIME) {
@@ -250,7 +250,7 @@ class AggregateFuncResolver {
250250
}
251251

252252
template <LogicalType ArgLT, LogicalType ResultLT, bool IsWindowFunc, bool IsNull>
253-
std::enable_if_t<isArithmeticLT<ArgLT>, const AggregateFunction*> create_decimal_function(std::string& name) {
253+
std::enable_if_t<isArithmeticLT<ArgLT>, AggregateFunctionPtr> create_decimal_function(std::string& name) {
254254
static_assert(lt_is_decimal128<ResultLT> || lt_is_decimal256<ResultLT>);
255255
if constexpr (IsNull) {
256256
using ResultType = RunTimeCppType<ResultLT>;

be/src/exprs/agg/nullable_aggregate.h

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "column/column.h"
18+
#include "exprs/agg/aggregate.h"
1819
#ifdef __x86_64__
1920
#include <immintrin.h>
2021
#elif defined(__ARM_NEON) && defined(__aarch64__)
@@ -108,7 +109,7 @@ struct AggNonNullPred {
108109
// For this case, the serialized output type is non-nullable, because only the state of input needs to be serialized.
109110
// If all the rows are NULL or `AggNullPred` returns true, we will return NULL.
110111
// The State must be NullableAggregateFunctionState
111-
template <typename NestedFuncPtr, typename State, bool IsWindowFunc, bool IgnoreNull = true,
112+
template <typename NestedAggregateFunctionPtr, typename State, bool IsWindowFunc, bool IgnoreNull = true,
112113
IsAggNullPred<typename State::NestedState> AggNullPred = AggNonNullPred<typename State::NestedState>>
113114
class NullableAggregateFunctionBase : public AggregateFunctionStateHelper<State> {
114115
using NestedState = typename State::NestedState;
@@ -117,7 +118,8 @@ class NullableAggregateFunctionBase : public AggregateFunctionStateHelper<State>
117118
public:
118119
bool is_exception_safe() const override { return nested_function->is_exception_safe(); }
119120

120-
explicit NullableAggregateFunctionBase(NestedFuncPtr nested_function_, AggNullPred null_pred = AggNullPred())
121+
explicit NullableAggregateFunctionBase(NestedAggregateFunctionPtr nested_function_,
122+
AggNullPred null_pred = AggNullPred())
121123
: nested_function(std::move(nested_function_)), null_pred(std::move(null_pred)) {}
122124
// as array_agg is not nullable, so it needn't create() here.
123125

@@ -311,17 +313,19 @@ class NullableAggregateFunctionBase : public AggregateFunctionStateHelper<State>
311313
}
312314

313315
protected:
314-
NestedFuncPtr nested_function;
316+
NestedAggregateFunctionPtr nested_function;
315317
AggNullPred null_pred;
316318
};
317319

318-
template <typename NestedFuncPtr, typename State, bool IsWindowFunc, bool IgnoreNull = true,
320+
template <typename NestedAggregateFunctionPtr, typename State, bool IsWindowFunc, bool IgnoreNull = true,
319321
IsAggNullPred<typename State::NestedState> AggNullPred = AggNonNullPred<typename State::NestedState>>
320322
class NullableAggregateFunctionUnary final
321-
: public NullableAggregateFunctionBase<NestedFuncPtr, State, IsWindowFunc, IgnoreNull, AggNullPred> {
323+
: public NullableAggregateFunctionBase<NestedAggregateFunctionPtr, State, IsWindowFunc, IgnoreNull,
324+
AggNullPred> {
322325
public:
323-
explicit NullableAggregateFunctionUnary(const NestedFuncPtr& nested_function, AggNullPred null_pred = AggNullPred())
324-
: NullableAggregateFunctionBase<NestedFuncPtr, State, IsWindowFunc, IgnoreNull, AggNullPred>(
326+
explicit NullableAggregateFunctionUnary(const NestedAggregateFunctionPtr& nested_function,
327+
AggNullPred null_pred = AggNullPred())
328+
: NullableAggregateFunctionBase<NestedAggregateFunctionPtr, State, IsWindowFunc, IgnoreNull, AggNullPred>(
325329
nested_function, std::move(null_pred)) {}
326330

327331
// NOTE: In stream MV, need handle input row by row, so need support single update.
@@ -948,14 +952,15 @@ class NullableAggregateFunctionUnary final
948952
}
949953
};
950954

951-
template <typename NestedFuncPtr, typename State,
955+
template <typename State,
952956
IsAggNullPred<typename State::NestedState> AggNullPredType = AggNonNullPred<typename State::NestedState>>
953957
class NullableAggregateFunctionVariadic final
954-
: public NullableAggregateFunctionBase<NestedFuncPtr, State, false, true, AggNullPredType> {
958+
: public NullableAggregateFunctionBase<AggregateFunctionPtr, State, false, true, AggNullPredType> {
955959
public:
956-
NullableAggregateFunctionVariadic(NestedFuncPtr nested_function, AggNullPredType null_pred = AggNullPredType())
957-
: NullableAggregateFunctionBase<NestedFuncPtr, State, false, true, AggNullPredType>(
958-
std::move(nested_function), std::move(null_pred)) {}
960+
NullableAggregateFunctionVariadic(AggregateFunctionPtr nested_function,
961+
AggNullPredType null_pred = AggNullPredType())
962+
: NullableAggregateFunctionBase<AggregateFunctionPtr, State, false, true, AggNullPredType>(
963+
nested_function, std::move(null_pred)) {}
959964

960965
void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
961966
size_t row_num) const override {

0 commit comments

Comments
 (0)