Skip to content

Commit cb7ae26

Browse files
authored
feat[cuda]: device exports for varbinview (#6338)
Need to copy to VarBin before sending to cudf --------- Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent c782cee commit cb7ae26

8 files changed

Lines changed: 371 additions & 37 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#pragma once
5+
6+
#include <stdint.h>
7+
8+
// Maximum number of bytes that can be stored inline in a view.
constexpr int32_t MAX_INLINED_SIZE = 12;

// A byte buffer holding string data.
typedef uint8_t* Buffer;

// An i32 offsets buffer.
typedef int32_t* Offsets;

// View for a value of at most MAX_INLINED_SIZE bytes: the bytes are stored
// directly in `data`.
struct InlinedBinaryView {
    // Length of the value in bytes.
    int32_t size;
    // The value's bytes; only the first `size` bytes are meaningful.
    uint8_t data[MAX_INLINED_SIZE];
};

// View for a value longer than MAX_INLINED_SIZE bytes: the bytes are located
// at `buffers[index] + offset` in a separate data buffer.
struct RefBinaryView {
    // Length of the value in bytes.
    int32_t size;
    // NOTE(review): presumably the leading bytes of the value; confirm.
    uint8_t prefix[4];
    // Which data buffer holds the value.
    int32_t index;
    // Byte offset of the value within that buffer.
    int32_t offset;
};

// The BinaryView type is how we access values.
//
// `size` is the first member of both variants, so it may be read through
// either one to decide which variant is active.
union alignas(int64_t) BinaryView {
    InlinedBinaryView inlined;
    RefBinaryView ref;
};

// Both variants must fill the view exactly so that an array of views can be
// indexed and reinterpreted without padding surprises.
static_assert(sizeof(InlinedBinaryView) == 16, "InlinedBinaryView must be 16 bytes");
static_assert(sizeof(RefBinaryView) == 16, "RefBinaryView must be 16 bytes");
static_assert(sizeof(BinaryView) == 16, "BinaryView must be 16 bytes");
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
#include "varbinview.cuh"
6+
7+
// Single-threaded kernel that computes VarBin-style offsets from view sizes.
//
// views:       array of `num_strings` views; only the `size` field is read.
// num_strings: number of views.
// out_offsets: receives `num_strings + 1` entries. out_offsets[0] is 0 and
//              out_offsets[i + 1] is the summed size of views[0..=i].
// last_offset: receives the summed size of all views.
//
// NOTE(review): the running total is an int32_t, so it overflows once the
// combined size exceeds INT32_MAX — presumably the caller rules this out;
// confirm.
extern "C" __global__ void varbinview_compute_offsets(
    const BinaryView *views,
    int64_t num_strings,
    Offsets out_offsets,
    int32_t *last_offset
) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
    // in 32-bit unsigned arithmetic and can wrap back to 0 on a large grid.
    const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // force execution to be single-threaded to compute the prefix
    // sum.
    // TODO(aduffy): we could do this with a CUB kernel instead.
    //  Check the profiles later to see where this shows up.
    if (tid != 0) {
        return;
    }

    int32_t offset = 0;
    out_offsets[0] = 0;
    // The index must be as wide as `num_strings`; an `int` index would
    // overflow for more than INT_MAX views.
    for (int64_t i = 0; i < num_strings; i++) {
        offset += views[i].inlined.size;
        out_offsets[i + 1] = offset;
    }

    *last_offset = offset;
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
#include "varbinview.cuh"
6+
7+
// Lookup a string from a binary view, copying it into
// a destination buffer.
//
// view:    the view to resolve; it is only read.
// buffers: data buffers indexed by `view.ref.index` for non-inlined views.
// dst:     destination; exactly `view.inlined.size` bytes are written.
__device__ void copy_string_to_dst(
    const BinaryView& view,
    Buffer *buffers,
    uint8_t *dst
) {
    // `size` is the first field of both variants, so it is read before
    // deciding which variant applies.
    const int32_t size = view.inlined.size;
    const uint8_t *src;
    if (size <= MAX_INLINED_SIZE) {
        // TODO(aduffy): use uint64_t loads instead?
        src = view.inlined.data;
    } else {
        // Longer values are stored in one of the data buffers.
        const RefBinaryView ref = view.ref;
        src = buffers[ref.index] + ref.offset;
    }
    memcpy(dst, src, size);
}
25+
26+
// Copies every view's bytes into one contiguous destination buffer.
//
// len:         number of views.
// views:       array of `len` views.
// buffers:     data buffers referenced by non-inlined views.
// dst_buffer:  destination byte buffer.
// dst_offsets: start position in `dst_buffer` for each view; entry `tid` is
//              where string `tid` is written.
extern "C" __global__ void varbinview_copy_strings(
    int64_t len,
    BinaryView* views,
    Buffer* buffers,
    Buffer dst_buffer,
    Offsets dst_offsets
) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
    // in 32-bit unsigned arithmetic and can wrap on a large grid, making two
    // threads claim the same string.
    const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // Each thread is responsible for copying a single string.
    // Any excess threads do no work.
    if (tid >= len) {
        return;
    }

    // Work on a local copy of the view.
    BinaryView view = views[tid];
    const int32_t start = dst_offsets[tid];
    uint8_t *dst = dst_buffer + start;

    copy_string_to_dst(view, buffers, dst);
}

vortex-cuda/src/arrow/canonical.rs

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use vortex_array::arrays::PrimitiveArrayParts;
1212
use vortex_array::arrays::StructArray;
1313
use vortex_array::arrays::StructArrayParts;
1414
use vortex_array::buffer::BufferHandle;
15-
use vortex_array::validity::Validity;
15+
use vortex_array::vtable::ValidityHelper;
1616
use vortex_dtype::DecimalType;
1717
use vortex_dtype::datetime::AnyTemporal;
1818
use vortex_error::VortexResult;
@@ -26,6 +26,10 @@ use crate::arrow::DeviceType;
2626
use crate::arrow::ExportDeviceArray;
2727
use crate::arrow::PrivateData;
2828
use crate::arrow::SyncEvent;
29+
use crate::arrow::check_validity_empty;
30+
use crate::arrow::ensure_device_resident;
31+
use crate::arrow::varbinview::BinaryParts;
32+
use crate::arrow::varbinview::copy_varbinview_to_varbin;
2933
use crate::executor::CudaArrayExt;
3034

3135
/// An implementation of `ExportDeviceArray` that exports Vortex arrays to `ArrowDeviceArray` by
@@ -43,11 +47,11 @@ impl ExportDeviceArray for CanonicalDeviceArrayExport {
4347
) -> VortexResult<ArrowDeviceArray> {
4448
let cuda_array = array.execute_cuda(ctx).await?;
4549

46-
let (arrow_array, sync_event) = export_canonical(cuda_array, ctx).await?;
50+
let (arrow_array, _) = export_canonical(cuda_array, ctx).await?;
4751

4852
Ok(ArrowDeviceArray {
4953
array: arrow_array,
50-
sync_event,
54+
sync_event: None,
5155
device_id: ctx.stream().context().ordinal() as i64,
5256
device_type: DeviceType::Cuda,
5357
_reserved: Default::default(),
@@ -68,7 +72,7 @@ fn export_canonical(
6872
buffer, validity, ..
6973
} = primitive.into_parts();
7074

71-
check_validity_empty(validity)?;
75+
check_validity_empty(&validity)?;
7276

7377
let buffer = ensure_device_resident(buffer, ctx).await?;
7478

@@ -96,7 +100,7 @@ fn export_canonical(
96100
} = decimal.into_parts();
97101

98102
// verify that there is no null buffer
99-
check_validity_empty(validity)?;
103+
check_validity_empty(&validity)?;
100104

101105
// TODO(aduffy): GPU kernel for upcasting.
102106
vortex_ensure!(
@@ -120,12 +124,11 @@ fn export_canonical(
120124
buffer, validity, ..
121125
} = values.into_parts();
122126

123-
check_validity_empty(validity)?;
127+
check_validity_empty(&validity)?;
124128

125129
let buffer = ensure_device_resident(buffer, ctx).await?;
126130
export_fixed_size(buffer, len, 0, ctx)
127131
}
128-
129132
Canonical::Bool(bool_array) => {
130133
let BoolArrayParts {
131134
bits,
@@ -135,10 +138,40 @@ fn export_canonical(
135138
..
136139
} = bool_array.into_parts();
137140

138-
check_validity_empty(validity)?;
141+
check_validity_empty(&validity)?;
139142

140143
export_fixed_size(bits, len, offset, ctx)
141144
}
145+
Canonical::VarBinView(varbinview) => {
    let len = varbinview.len();
    check_validity_empty(varbinview.validity())?;

    // cudf doesn't support VarBinView, so translate to VarBin
    // (offsets + bytes) before exporting.
    let BinaryParts { offsets, bytes } =
        copy_varbinview_to_varbin(varbinview, ctx).await?;

    let offsets = ensure_device_resident(offsets, ctx).await?;
    let bytes = ensure_device_resident(bytes, ctx).await?;

    // Buffer order: validity (absent), offsets, data.
    let buffers = vec![None, Some(offsets), Some(bytes)];
    let mut private_data = PrivateData::new(buffers, vec![], ctx)?;
    let sync_event = private_data.sync_event();

    let arrow_array = ArrowArray {
        length: len as i64,
        null_count: 0,
        offset: 0,
        // 1 (optional) buffer for nulls, one for the offsets, one for the data.
        // Must match the three entries passed to `PrivateData::new` above,
        // otherwise the consumer never sees the data buffer.
        n_buffers: 3,
        buffers: private_data.buffer_ptrs.as_mut_ptr(),
        n_children: 0,
        children: std::ptr::null_mut(),
        release: Some(release_array),
        dictionary: std::ptr::null_mut(),
        private_data: Box::into_raw(private_data).cast(),
    };

    Ok((arrow_array, sync_event))
}
142175
// VarBinView is handled above by translating VarBinView -> VarBin with a kernel,
// since cudf doesn't support it. TODO(aduffy): support the remaining array types.
144177
c => todo!("support for exporting {} arrays", c.dtype()),
@@ -155,7 +188,7 @@ async fn export_struct(
155188
validity, fields, ..
156189
} = array.into_parts();
157190

158-
check_validity_empty(validity)?;
191+
check_validity_empty(&validity)?;
159192

160193
// We need the children to be held across await points.
161194
let mut children = Vec::with_capacity(fields.len());
@@ -220,26 +253,6 @@ fn export_fixed_size(
220253
Ok((arrow_array, sync_event))
221254
}
222255

223-
/// Check that the validity buffer is empty and does not need to be copied over the device boundary.
224-
fn check_validity_empty(validity: Validity) -> VortexResult<()> {
225-
if let Validity::AllInvalid | Validity::Array(_) = validity {
226-
vortex_bail!("Exporting array with non-trivial validity not supported yet")
227-
}
228-
229-
Ok(())
230-
}
231-
232-
async fn ensure_device_resident(
233-
buffer_handle: BufferHandle,
234-
ctx: &mut CudaExecutionCtx,
235-
) -> VortexResult<BufferHandle> {
236-
if buffer_handle.is_on_device() {
237-
Ok(buffer_handle)
238-
} else {
239-
ctx.move_to_device(buffer_handle)?.await
240-
}
241-
}
242-
243256
// export some nested data
244257

245258
unsafe extern "C" fn release_array(array: *mut ArrowArray) {

vortex-cuda/src/arrow/mod.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//! More documentation at <https://arrow.apache.org/docs/format/CDeviceDataInterface.html>
1010
1111
mod canonical;
12+
mod varbinview;
1213

1314
use std::ffi::c_void;
1415
use std::fmt::Debug;
@@ -24,7 +25,9 @@ use cudarc::runtime::sys::cudaEvent_t;
2425
use vortex_array::Array;
2526
use vortex_array::ArrayRef;
2627
use vortex_array::buffer::BufferHandle;
28+
use vortex_array::validity::Validity;
2729
use vortex_error::VortexResult;
30+
use vortex_error::vortex_bail;
2831
use vortex_error::vortex_err;
2932

3033
use crate::CudaBufferExt;
@@ -213,3 +216,23 @@ pub trait ExportDeviceArray: Debug + Send + Sync + 'static {
213216
ctx: &mut CudaExecutionCtx,
214217
) -> VortexResult<ArrowDeviceArray>;
215218
}
219+
220+
/// Check that the validity buffer is empty and does not need to be copied over the device boundary.
221+
pub(crate) fn check_validity_empty(validity: &Validity) -> VortexResult<()> {
222+
if let Validity::AllInvalid | Validity::Array(_) = validity {
223+
vortex_bail!("Exporting array with non-trivial validity not supported yet")
224+
}
225+
226+
Ok(())
227+
}
228+
229+
/// Return a handle to `buffer_handle`'s data that is on the device.
///
/// A handle that already reports `is_on_device()` is returned unchanged; otherwise it is
/// handed to `ctx.move_to_device` and the resulting future is awaited.
///
/// # Errors
///
/// Propagates any error from starting or completing the move.
pub(crate) async fn ensure_device_resident(
    buffer_handle: BufferHandle,
    ctx: &mut CudaExecutionCtx,
) -> VortexResult<BufferHandle> {
    if !buffer_handle.is_on_device() {
        return ctx.move_to_device(buffer_handle)?.await;
    }

    Ok(buffer_handle)
}

0 commit comments

Comments
 (0)