Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/ops/group.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
typedef struct {
double sum_f, min_f, max_f, prod_f, first_f, last_f, sum_sq_f;
int64_t sum_i, min_i, max_i, prod_i, first_i, last_i, sum_sq_i;
/* Parallel f64 sum of the integer stream — used by AVG so the
* mean of an i64 column whose sum exceeds 2^63 (e.g. ClickBench
* UserID, signed values around ±9e18 × 10^7 rows) stays accurate
* instead of being whatever (uint64) wrap left in sum_i. */
double sum_d;
int64_t cnt;
int64_t null_count;
bool has_first;
Expand All @@ -44,6 +49,7 @@ static void reduce_acc_init(reduce_acc_t* acc) {
acc->prod_f = 1.0; acc->first_f = 0; acc->last_f = 0; acc->sum_sq_f = 0;
acc->sum_i = 0; acc->min_i = INT64_MAX; acc->max_i = INT64_MIN;
acc->prod_i = 1; acc->first_i = 0; acc->last_i = 0; acc->sum_sq_i = 0;
acc->sum_d = 0;
acc->cnt = 0; acc->null_count = 0; acc->has_first = false;
}

Expand Down Expand Up @@ -92,6 +98,7 @@ static inline bool sym_lex_gt(int64_t a, int64_t b) { return sym_lex_lt(b, a); }
(acc)->sum_i = (int64_t)((uint64_t)(acc)->sum_i + (uint64_t)v); \
(acc)->sum_sq_i = (int64_t)((uint64_t)(acc)->sum_sq_i + (uint64_t)v * (uint64_t)v); \
(acc)->prod_i = (int64_t)((uint64_t)(acc)->prod_i * (uint64_t)v); \
(acc)->sum_d += (double)v; \
if (v < (acc)->min_i) (acc)->min_i = v; \
if (v > (acc)->max_i) (acc)->max_i = v; \
if (!(acc)->has_first) { (acc)->first_i = v; (acc)->has_first = true; } \
Expand Down Expand Up @@ -159,6 +166,7 @@ static void reduce_range(ray_t* input, int64_t start, int64_t end,
acc->sum_i = (int64_t)((uint64_t)acc->sum_i + (uint64_t)v);
acc->sum_sq_i = (int64_t)((uint64_t)acc->sum_sq_i + (uint64_t)v * (uint64_t)v);
acc->prod_i = (int64_t)((uint64_t)acc->prod_i * (uint64_t)v);
acc->sum_d += (double)v;
if (v < acc->min_i) acc->min_i = v;
if (v > acc->max_i) acc->max_i = v;
if (!acc->has_first) { acc->first_i = v; acc->has_first = true; }
Expand Down Expand Up @@ -262,6 +270,7 @@ static void reduce_merge(reduce_acc_t* dst, const reduce_acc_t* src, int8_t in_t
dst->sum_i = (int64_t)((uint64_t)dst->sum_i + (uint64_t)src->sum_i);
dst->sum_sq_i = (int64_t)((uint64_t)dst->sum_sq_i + (uint64_t)src->sum_sq_i);
dst->prod_i = (int64_t)((uint64_t)dst->prod_i * (uint64_t)src->prod_i);
dst->sum_d += src->sum_d;
if (in_type == RAY_SYM) {
/* Lex compare for SYM min/max (see sym_lex_lt). */
if (src->cnt > 0) {
Expand Down Expand Up @@ -2051,7 +2060,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) {
/* COUNT returns total length including nulls — matches ray_count_fn's
* "count all elements" semantics, not SQL's COUNT(col) non-null count. */
case OP_COUNT: result = ray_i64(scan_n); break;
case OP_AVG: result = merged.cnt > 0 ? ray_f64(in_type == RAY_F64 ? merged.sum_f / merged.cnt : (double)merged.sum_i / merged.cnt) : ray_typed_null(-RAY_F64); break;
case OP_AVG: result = merged.cnt > 0 ? ray_f64(in_type == RAY_F64 ? merged.sum_f / merged.cnt : merged.sum_d / merged.cnt) : ray_typed_null(-RAY_F64); break;
case OP_FIRST: result = merged.has_first ? (in_type == RAY_F64 ? ray_f64(merged.first_f) : reduction_i64_result(merged.first_i, in_type)) : ray_typed_null(-in_type); break;
case OP_LAST: result = merged.has_first ? (in_type == RAY_F64 ? ray_f64(merged.last_f) : reduction_i64_result(merged.last_i, in_type)) : ray_typed_null(-in_type); break;
case OP_VAR: case OP_VAR_POP:
Expand All @@ -2060,7 +2069,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) {
if (insufficient) { result = ray_typed_null(-RAY_F64); break; }
double mean, var_pop;
if (in_type == RAY_F64) { mean = merged.sum_f / merged.cnt; var_pop = merged.sum_sq_f / merged.cnt - mean * mean; }
else { mean = (double)merged.sum_i / merged.cnt; var_pop = (double)merged.sum_sq_i / merged.cnt - mean * mean; }
else { mean = merged.sum_d / merged.cnt; var_pop = (double)merged.sum_sq_i / merged.cnt - mean * mean; }
if (var_pop < 0) var_pop = 0;
double val;
if (op->opcode == OP_VAR_POP) val = var_pop;
Expand Down Expand Up @@ -2090,7 +2099,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) {
/* COUNT returns total length including nulls — matches ray_count_fn's
* "count all elements" semantics, not SQL's COUNT(col) non-null count. */
case OP_COUNT: return ray_i64(scan_n);
case OP_AVG: return acc.cnt > 0 ? ray_f64(in_type == RAY_F64 ? acc.sum_f / acc.cnt : (double)acc.sum_i / acc.cnt) : ray_typed_null(-RAY_F64);
case OP_AVG: return acc.cnt > 0 ? ray_f64(in_type == RAY_F64 ? acc.sum_f / acc.cnt : acc.sum_d / acc.cnt) : ray_typed_null(-RAY_F64);
case OP_FIRST: return acc.has_first ? (in_type == RAY_F64 ? ray_f64(acc.first_f) : reduction_i64_result(acc.first_i, in_type)) : ray_typed_null(-in_type);
case OP_LAST: return acc.has_first ? (in_type == RAY_F64 ? ray_f64(acc.last_f) : reduction_i64_result(acc.last_i, in_type)) : ray_typed_null(-in_type);
case OP_VAR: case OP_VAR_POP:
Expand All @@ -2099,7 +2108,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) {
if (insufficient) return ray_typed_null(-RAY_F64);
double mean, var_pop;
if (in_type == RAY_F64) { mean = acc.sum_f / acc.cnt; var_pop = acc.sum_sq_f / acc.cnt - mean * mean; }
else { mean = (double)acc.sum_i / acc.cnt; var_pop = (double)acc.sum_sq_i / acc.cnt - mean * mean; }
else { mean = acc.sum_d / acc.cnt; var_pop = (double)acc.sum_sq_i / acc.cnt - mean * mean; }
if (var_pop < 0) var_pop = 0;
double val;
if (op->opcode == OP_VAR_POP) val = var_pop;
Expand Down
Loading