diff --git a/src/ops/group.c b/src/ops/group.c index d2402231..aa9509c7 100644 --- a/src/ops/group.c +++ b/src/ops/group.c @@ -34,6 +34,11 @@ typedef struct { double sum_f, min_f, max_f, prod_f, first_f, last_f, sum_sq_f; int64_t sum_i, min_i, max_i, prod_i, first_i, last_i, sum_sq_i; + /* Parallel f64 sum of the integer stream — used by AVG so the + * mean of an i64 column whose sum exceeds 2^63 (e.g. ClickBench + * UserID, signed values around ±9e18 × 10^7 rows) stays accurate + * instead of being whatever (uint64) wrap left in sum_i. */ + double sum_d; int64_t cnt; int64_t null_count; bool has_first; @@ -44,6 +49,7 @@ static void reduce_acc_init(reduce_acc_t* acc) { acc->prod_f = 1.0; acc->first_f = 0; acc->last_f = 0; acc->sum_sq_f = 0; acc->sum_i = 0; acc->min_i = INT64_MAX; acc->max_i = INT64_MIN; acc->prod_i = 1; acc->first_i = 0; acc->last_i = 0; acc->sum_sq_i = 0; + acc->sum_d = 0; acc->cnt = 0; acc->null_count = 0; acc->has_first = false; } @@ -92,6 +98,7 @@ static inline bool sym_lex_gt(int64_t a, int64_t b) { return sym_lex_lt(b, a); } (acc)->sum_i = (int64_t)((uint64_t)(acc)->sum_i + (uint64_t)v); \ (acc)->sum_sq_i = (int64_t)((uint64_t)(acc)->sum_sq_i + (uint64_t)v * (uint64_t)v); \ (acc)->prod_i = (int64_t)((uint64_t)(acc)->prod_i * (uint64_t)v); \ + (acc)->sum_d += (double)v; \ if (v < (acc)->min_i) (acc)->min_i = v; \ if (v > (acc)->max_i) (acc)->max_i = v; \ if (!(acc)->has_first) { (acc)->first_i = v; (acc)->has_first = true; } \ @@ -159,6 +166,7 @@ static void reduce_range(ray_t* input, int64_t start, int64_t end, acc->sum_i = (int64_t)((uint64_t)acc->sum_i + (uint64_t)v); acc->sum_sq_i = (int64_t)((uint64_t)acc->sum_sq_i + (uint64_t)v * (uint64_t)v); acc->prod_i = (int64_t)((uint64_t)acc->prod_i * (uint64_t)v); + acc->sum_d += (double)v; if (v < acc->min_i) acc->min_i = v; if (v > acc->max_i) acc->max_i = v; if (!acc->has_first) { acc->first_i = v; acc->has_first = true; } @@ -262,6 +270,7 @@ static void reduce_merge(reduce_acc_t* dst, const reduce_acc_t* src, int8_t in_t dst->sum_i = (int64_t)((uint64_t)dst->sum_i + (uint64_t)src->sum_i); dst->sum_sq_i = (int64_t)((uint64_t)dst->sum_sq_i + (uint64_t)src->sum_sq_i); dst->prod_i = (int64_t)((uint64_t)dst->prod_i * (uint64_t)src->prod_i); + dst->sum_d += src->sum_d; if (in_type == RAY_SYM) { /* Lex compare for SYM min/max (see sym_lex_lt). */ if (src->cnt > 0) { @@ -2051,7 +2060,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) { /* COUNT returns total length including nulls — matches ray_count_fn's * "count all elements" semantics, not SQL's COUNT(col) non-null count. */ case OP_COUNT: result = ray_i64(scan_n); break; - case OP_AVG: result = merged.cnt > 0 ? ray_f64(in_type == RAY_F64 ? merged.sum_f / merged.cnt : (double)merged.sum_i / merged.cnt) : ray_typed_null(-RAY_F64); break; + case OP_AVG: result = merged.cnt > 0 ? ray_f64(in_type == RAY_F64 ? merged.sum_f / merged.cnt : merged.sum_d / merged.cnt) : ray_typed_null(-RAY_F64); break; case OP_FIRST: result = merged.has_first ? (in_type == RAY_F64 ? ray_f64(merged.first_f) : reduction_i64_result(merged.first_i, in_type)) : ray_typed_null(-in_type); break; case OP_LAST: result = merged.has_first ? (in_type == RAY_F64 ? ray_f64(merged.last_f) : reduction_i64_result(merged.last_i, in_type)) : ray_typed_null(-in_type); break; case OP_VAR: case OP_VAR_POP: @@ -2060,7 +2069,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) { if (insufficient) { result = ray_typed_null(-RAY_F64); break; } double mean, var_pop; if (in_type == RAY_F64) { mean = merged.sum_f / merged.cnt; var_pop = merged.sum_sq_f / merged.cnt - mean * mean; } - else { mean = (double)merged.sum_i / merged.cnt; var_pop = (double)merged.sum_sq_i / merged.cnt - mean * mean; } + else { mean = merged.sum_d / merged.cnt; var_pop = (double)merged.sum_sq_i / merged.cnt - mean * mean; } if (var_pop < 0) var_pop = 0; double val; if (op->opcode == OP_VAR_POP) val = var_pop; @@ -2090,7 +2099,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) { /* COUNT returns total length including nulls — matches ray_count_fn's * "count all elements" semantics, not SQL's COUNT(col) non-null count. */ case OP_COUNT: return ray_i64(scan_n); - case OP_AVG: return acc.cnt > 0 ? ray_f64(in_type == RAY_F64 ? acc.sum_f / acc.cnt : (double)acc.sum_i / acc.cnt) : ray_typed_null(-RAY_F64); + case OP_AVG: return acc.cnt > 0 ? ray_f64(in_type == RAY_F64 ? acc.sum_f / acc.cnt : acc.sum_d / acc.cnt) : ray_typed_null(-RAY_F64); case OP_FIRST: return acc.has_first ? (in_type == RAY_F64 ? ray_f64(acc.first_f) : reduction_i64_result(acc.first_i, in_type)) : ray_typed_null(-in_type); case OP_LAST: return acc.has_first ? (in_type == RAY_F64 ? ray_f64(acc.last_f) : reduction_i64_result(acc.last_i, in_type)) : ray_typed_null(-in_type); case OP_VAR: case OP_VAR_POP: @@ -2099,7 +2108,7 @@ ray_t* exec_reduction(ray_graph_t* g, ray_op_t* op, ray_t* input) { if (insufficient) return ray_typed_null(-RAY_F64); double mean, var_pop; if (in_type == RAY_F64) { mean = acc.sum_f / acc.cnt; var_pop = acc.sum_sq_f / acc.cnt - mean * mean; } - else { mean = (double)acc.sum_i / acc.cnt; var_pop = (double)acc.sum_sq_i / acc.cnt - mean * mean; } + else { mean = acc.sum_d / acc.cnt; var_pop = (double)acc.sum_sq_i / acc.cnt - mean * mean; } if (var_pop < 0) var_pop = 0; double val; if (op->opcode == OP_VAR_POP) val = var_pop;