Skip to content

Commit 19b8d5a

Browse files
committed
Initial changes
1 parent b28d55d commit 19b8d5a

13 files changed

Lines changed: 1211 additions & 5 deletions

cpp/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
128128
memory/MemoryManager.cc
129129
memory/ArrowMemoryPool.cc
130130
memory/ColumnarBatch.cc
131+
shuffle/BlockStatistics.cc
131132
shuffle/Dictionary.cc
132133
shuffle/FallbackRangePartitioner.cc
133134
shuffle/HashPartitioner.cc
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "shuffle/BlockStatistics.h"
19+
20+
#include <arrow/buffer.h>
21+
#include <arrow/type.h>
22+
#include <arrow/util/bit_util.h>
23+
24+
namespace gluten {
25+
namespace {
26+
27+
// Returns true if the row at the given index is valid (non-null).
28+
inline bool isRowValid(const std::shared_ptr<arrow::Buffer>& validityBuffer, uint32_t row) {
29+
if (!validityBuffer) {
30+
return true; // No validity buffer means all rows are valid.
31+
}
32+
return arrow::bit_util::GetBit(validityBuffer->data(), row);
33+
}
34+
35+
// Returns true if the column has any null rows.
36+
bool hasAnyNull(const std::shared_ptr<arrow::Buffer>& validityBuffer, uint32_t numRows) {
37+
if (!validityBuffer || numRows == 0) {
38+
return false;
39+
}
40+
// Check each bit — return early on first null found.
41+
for (uint32_t i = 0; i < numRows; ++i) {
42+
if (!arrow::bit_util::GetBit(validityBuffer->data(), i)) {
43+
return true;
44+
}
45+
}
46+
return false;
47+
}
48+
49+
// Copies the raw bytes of `value` to `dst`, then advances `dst` past them.
// No bounds check is performed: `dst` must point at sizeof(T) or more writable bytes.
template <typename T>
void writeBytes(uint8_t*& dst, T value) {
  constexpr size_t kWidth = sizeof(T);
  memcpy(dst, &value, kWidth);
  dst += kWidth;
}
54+
55+
// Decodes a T from the raw bytes at `src`, then advances `src` past them.
// No bounds check is performed: `src` must point at sizeof(T) or more readable bytes.
template <typename T>
T readBytes(const uint8_t*& src) {
  const uint8_t* const start = src;
  src = start + sizeof(T);
  T decoded;
  memcpy(&decoded, start, sizeof(T));
  return decoded;
}
62+
63+
template <typename T>
64+
void scanColumnMinMax(
65+
const std::shared_ptr<arrow::Buffer>& validityBuffer,
66+
const std::shared_ptr<arrow::Buffer>& valueBuffer,
67+
uint32_t numRows,
68+
ColumnStatistics& stats) {
69+
if (!valueBuffer || valueBuffer->size() == 0 || numRows == 0) {
70+
return;
71+
}
72+
73+
const auto* values = reinterpret_cast<const T*>(valueBuffer->data());
74+
bool foundAny = false;
75+
T minVal{};
76+
T maxVal{};
77+
78+
for (uint32_t i = 0; i < numRows; ++i) {
79+
if (!isRowValid(validityBuffer, i)) {
80+
continue;
81+
}
82+
T val = values[i];
83+
if (!foundAny) {
84+
minVal = val;
85+
maxVal = val;
86+
foundAny = true;
87+
} else {
88+
if (val < minVal) {
89+
minVal = val;
90+
}
91+
if (val > maxVal) {
92+
maxVal = val;
93+
}
94+
}
95+
}
96+
97+
if (foundAny) {
98+
stats.hasStats = true;
99+
stats.setMin(minVal);
100+
stats.setMax(maxVal);
101+
}
102+
}
103+
104+
} // namespace
105+
106+
void ColumnStatistics::merge(const ColumnStatistics& other) {
107+
hasNull = hasNull || other.hasNull;
108+
if (!other.hasStats) {
109+
return;
110+
}
111+
if (!hasStats) {
112+
hasStats = true;
113+
memcpy(minBytes, other.minBytes, 8);
114+
memcpy(maxBytes, other.maxBytes, 8);
115+
return;
116+
}
117+
// Both have stats — merge based on type.
118+
switch (static_cast<arrow::Type::type>(typeId)) {
119+
case arrow::Type::INT8:
120+
mergeTyped<int8_t>(other);
121+
break;
122+
case arrow::Type::INT16:
123+
mergeTyped<int16_t>(other);
124+
break;
125+
case arrow::Type::INT32:
126+
case arrow::Type::DATE32:
127+
mergeTyped<int32_t>(other);
128+
break;
129+
case arrow::Type::INT64:
130+
case arrow::Type::DATE64:
131+
case arrow::Type::TIMESTAMP:
132+
mergeTyped<int64_t>(other);
133+
break;
134+
case arrow::Type::FLOAT:
135+
mergeTyped<float>(other);
136+
break;
137+
case arrow::Type::DOUBLE:
138+
mergeTyped<double>(other);
139+
break;
140+
default:
141+
break;
142+
}
143+
}
144+
145+
// Writes this block's statistics to `out` in a single Write() call.
//
// Layout, in order (fields are copied with memcpy, i.e. in the host's native byte order):
//   kVersion | uint16 column count | int64 payloadSize | one record per entry of columnStats
//
// Returns Status::Invalid if the number of columns does not fit in the 16-bit count field;
// otherwise returns the status of the underlying stream write.
arrow::Status BlockStatistics::serialize(arrow::io::OutputStream* out, int64_t payloadSize) const {
  // The on-wire count is 16 bits wide. Refuse to emit a header whose count would silently wrap
  // and no longer match the column records that follow it.
  const auto numColumns = static_cast<uint16_t>(columnStats.size());
  if (numColumns != columnStats.size()) {
    return arrow::Status::Invalid("Too many columns for BlockStatistics: ", columnStats.size());
  }

  const uint32_t size = serializedSize();
  std::vector<uint8_t> buffer(size);
  uint8_t* ptr = buffer.data();

  writeBytes(ptr, kVersion);
  writeBytes(ptr, numColumns);
  writeBytes(ptr, payloadSize);

  for (const auto& col : columnStats) {
    col.serialize(ptr);
  }

  return out->Write(buffer.data(), size);
}
160+
161+
// Reads one serialized BlockStatistics from `in`.
//
// Consumes, in order: a 1-byte version, a uint16 column count, an int64 payload size, and then
// `column count` records of ColumnStatistics::kSerializedSize bytes each.
//
// Returns the decoded statistics paired with the payload size read from the header.
// Errors:
//   - Status::IOError if the stream ends before any field is fully read;
//   - Status::Invalid if the version byte differs from kVersion;
//   - any error reported by the stream's Read().
arrow::Result<std::pair<BlockStatistics, int64_t>> BlockStatistics::deserialize(arrow::io::InputStream* in) {
  // Read version. A short read is reported separately from a mismatch so the error message
  // never formats a byte that was not actually read.
  uint8_t version = 0;
  ARROW_ASSIGN_OR_RAISE(auto bytesRead, in->Read(sizeof(version), &version));
  if (bytesRead != sizeof(version)) {
    return arrow::Status::IOError("Unexpected end of stream reading BlockStatistics version");
  }
  if (version != kVersion) {
    return arrow::Status::Invalid("Unsupported BlockStatistics version: ", static_cast<int>(version));
  }

  // Read numColumns.
  uint16_t numColumns = 0;
  ARROW_ASSIGN_OR_RAISE(bytesRead, in->Read(sizeof(numColumns), &numColumns));
  if (bytesRead != sizeof(numColumns)) {
    return arrow::Status::IOError("Unexpected end of stream reading BlockStatistics numColumns");
  }

  // Read payloadSize.
  int64_t payloadSize = 0;
  ARROW_ASSIGN_OR_RAISE(bytesRead, in->Read(sizeof(payloadSize), &payloadSize));
  if (bytesRead != sizeof(payloadSize)) {
    return arrow::Status::IOError("Unexpected end of stream reading BlockStatistics payloadSize");
  }

  BlockStatistics stats;
  stats.columnStats.reserve(numColumns);

  // Each column record has a fixed size, so it is read into a stack buffer and decoded from there.
  for (uint16_t i = 0; i < numColumns; ++i) {
    uint8_t buf[ColumnStatistics::kSerializedSize];
    ARROW_ASSIGN_OR_RAISE(bytesRead, in->Read(sizeof(buf), buf));
    if (bytesRead != sizeof(buf)) {
      return arrow::Status::IOError("Unexpected end of stream reading BlockStatistics column ", i);
    }
    const uint8_t* ptr = buf;
    stats.columnStats.push_back(ColumnStatistics::deserialize(ptr));
  }

  return std::make_pair(std::move(stats), payloadSize);
}
198+
199+
// Merges `other` into this block column by column, pairing entries by their position in
// `columnStats`. Only positions present on both sides are merged; surplus entries on either
// side are left as they are.
void BlockStatistics::merge(const BlockStatistics& other) {
  const size_t mine = columnStats.size();
  const size_t theirs = other.columnStats.size();
  const size_t shared = mine < theirs ? mine : theirs;

  for (size_t col = 0; col != shared; ++col) {
    columnStats[col].merge(other.columnStats[col]);
  }
}
204+
205+
// Computes per-column statistics for one block of assembled shuffle buffers.
//
// `buffers` is consumed front to back in schema field order:
//   - BINARY / STRING / LARGE_BINARY / LARGE_STRING: 3 buffers (validity, length, value);
//     only the null flag is recorded.
//   - BOOL: 2 buffers (validity, bit-packed value); only the null flag is recorded.
//   - STRUCT / MAP / LIST / LARGE_LIST and NA: no buffer is consumed and no entry is emitted.
//   - every other type: 2 buffers (validity, value); min/max is computed for the integer,
//     date, timestamp, float and double types listed below, otherwise only the null flag.
//
// Each emitted entry carries the schema field index in `columnIndex`, so entries are not
// necessarily at the same position as their field.
//
// If `buffers` runs out before a field's buffers can be taken, processing stops and the
// statistics gathered so far are returned. Continuing would pair later fields with buffers that
// belong to the skipped field and publish wrong statistics.
//
// Returns an empty BlockStatistics when `numRows` is 0 or `buffers` is empty.
// `hasComplexType` is currently not consulted; complex fields are recognized from the schema.
BlockStatistics computeBlockStatistics(
    const std::shared_ptr<arrow::Schema>& schema,
    const std::vector<std::shared_ptr<arrow::Buffer>>& buffers,
    uint32_t numRows,
    bool hasComplexType) {
  (void)hasComplexType;

  BlockStatistics result;
  if (numRows == 0 || buffers.empty()) {
    return result;
  }

  // Builds the entry common to every supported field: identity plus null flag, no min/max yet.
  const auto makeColumn = [numRows](
                              int fieldIdx,
                              arrow::Type::type typeId,
                              const std::shared_ptr<arrow::Buffer>& validityBuf) {
    ColumnStatistics col{};
    col.columnIndex = static_cast<uint16_t>(fieldIdx);
    col.typeId = static_cast<uint8_t>(typeId);
    col.hasNull = hasAnyNull(validityBuf, numRows);
    col.hasStats = false;
    return col;
  };

  size_t bufIdx = 0;
  const int numFields = schema->num_fields();

  for (int fieldIdx = 0; fieldIdx < numFields; ++fieldIdx) {
    const auto typeId = schema->field(fieldIdx)->type()->id();

    switch (typeId) {
      case arrow::Type::BINARY:
      case arrow::Type::STRING:
      case arrow::Type::LARGE_BINARY:
      case arrow::Type::LARGE_STRING: {
        if (bufIdx + 3 > buffers.size()) {
          return result; // Buffer layout no longer matches the schema; stop here.
        }
        const auto& validityBuf = buffers[bufIdx];
        bufIdx += 3; // validity, length (skipped), value (skipped)

        // String stats not supported yet.
        result.columnStats.push_back(makeColumn(fieldIdx, typeId, validityBuf));
        break;
      }
      case arrow::Type::STRUCT:
      case arrow::Type::MAP:
      case arrow::Type::LIST:
      case arrow::Type::LARGE_LIST:
        // Complex types are skipped in assembleBuffers() per-field loop.
        // Their buffer is appended at the end. No stats for them.
        break;
      case arrow::Type::NA:
        // Null type has no buffers.
        break;
      case arrow::Type::BOOL: {
        if (bufIdx + 2 > buffers.size()) {
          return result; // Buffer layout no longer matches the schema; stop here.
        }
        const auto& validityBuf = buffers[bufIdx];
        bufIdx += 2; // validity, value (bit-packed, skipped for stats)

        // Bool stats not useful.
        result.columnStats.push_back(makeColumn(fieldIdx, typeId, validityBuf));
        break;
      }
      default: {
        // Fixed-width numeric types.
        if (bufIdx + 2 > buffers.size()) {
          return result; // Buffer layout no longer matches the schema; stop here.
        }
        const auto& validityBuf = buffers[bufIdx];
        const auto& valueBuf = buffers[bufIdx + 1];
        bufIdx += 2; // validity, value

        ColumnStatistics col = makeColumn(fieldIdx, typeId, validityBuf);

        switch (typeId) {
          case arrow::Type::INT8:
            scanColumnMinMax<int8_t>(validityBuf, valueBuf, numRows, col);
            break;
          case arrow::Type::INT16:
            scanColumnMinMax<int16_t>(validityBuf, valueBuf, numRows, col);
            break;
          case arrow::Type::INT32:
          case arrow::Type::DATE32:
            scanColumnMinMax<int32_t>(validityBuf, valueBuf, numRows, col);
            break;
          case arrow::Type::INT64:
          case arrow::Type::DATE64:
          case arrow::Type::TIMESTAMP:
            scanColumnMinMax<int64_t>(validityBuf, valueBuf, numRows, col);
            break;
          case arrow::Type::FLOAT:
            scanColumnMinMax<float>(validityBuf, valueBuf, numRows, col);
            break;
          case arrow::Type::DOUBLE:
            scanColumnMinMax<double>(validityBuf, valueBuf, numRows, col);
            break;
          default:
            // Unsupported type for min/max stats.
            break;
        }

        result.columnStats.push_back(col);
        break;
      }
    }
  }

  return result;
}
315+
316+
} // namespace gluten

0 commit comments

Comments
 (0)