diff --git a/cmake/external_dependencies.cmake b/cmake/external_dependencies.cmake index e2a3f4e3..aea1225e 100644 --- a/cmake/external_dependencies.cmake +++ b/cmake/external_dependencies.cmake @@ -74,7 +74,7 @@ endif() find_package_or_fetch( PACKAGE_NAME sparrow GIT_REPOSITORY https://github.com/man-group/sparrow.git - TAG 1.3.0 + TAG 2.0.0 ) unset(CREATE_JSON_READER_TARGET) diff --git a/conanfile.py b/conanfile.py index a6840562..17d303b6 100644 --- a/conanfile.py +++ b/conanfile.py @@ -43,7 +43,7 @@ def configure(self): self.options.rm_safe("fPIC") def requirements(self): - self.requires("sparrow/1.0.0") + self.requires("sparrow/1.2.0", options={"json_reader": True}) self.requires(f"flatbuffers/{self._flatbuffers_version}") self.requires("lz4/1.9.4") self.requires("zstd/1.5.7") diff --git a/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp b/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp index 6ca89a41..1fc561dd 100644 --- a/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp +++ b/include/sparrow_ipc/arrow_interface/arrow_array/private_data.hpp @@ -5,13 +5,14 @@ #include #include +#include "sparrow/buffer/buffer.hpp" + #include "sparrow_ipc/config/config.hpp" namespace sparrow_ipc { template - concept ArrowPrivateData = requires(T& t) - { + concept ArrowPrivateData = requires(T& t) { { t.buffers_ptrs() } -> std::same_as; { t.n_buffers() } -> std::convertible_to; }; @@ -19,7 +20,8 @@ namespace sparrow_ipc class arrow_array_private_data { public: - using optionally_owned_buffer = std::variant, std::span>; + + using optionally_owned_buffer = std::variant, std::span>; explicit arrow_array_private_data(std::vector&& buffers); [[nodiscard]] SPARROW_IPC_API const void** buffers_ptrs() noexcept; diff --git a/include/sparrow_ipc/compression.hpp b/include/sparrow_ipc/compression.hpp index c1d080e9..a1436e88 100644 --- a/include/sparrow_ipc/compression.hpp +++ b/include/sparrow_ipc/compression.hpp @@ -6,6 +6,8 @@ #include #include +#include + #include "sparrow_ipc/config/config.hpp" namespace sparrow_ipc @@ -20,39 +22,43 @@ namespace sparrow_ipc class SPARROW_IPC_API CompressionCache { - public: - CompressionCache(); - ~CompressionCache(); + public: + + CompressionCache(); + ~CompressionCache(); - CompressionCache(CompressionCache&&) noexcept; - CompressionCache& operator=(CompressionCache&&) noexcept; + CompressionCache(CompressionCache&&) noexcept; + CompressionCache& operator=(CompressionCache&&) noexcept; - CompressionCache(const CompressionCache&) = delete; - CompressionCache& operator=(const CompressionCache&) = delete; + CompressionCache(const CompressionCache&) = delete; + CompressionCache& operator=(const CompressionCache&) = delete; - std::optional> find(const void* data_ptr, const size_t data_size); - std::span store(const void* data_ptr, const size_t data_size, std::vector&& data); + std::optional> find(const void* data_ptr, const size_t data_size); + std::span + store(const void* data_ptr, const size_t data_size, std::vector&& data); - size_t size() const; - size_t count(const void* data_ptr, const size_t data_size) const; - bool empty() const; - void clear(); + size_t size() const; + size_t count(const void* data_ptr, const size_t data_size) const; + bool empty() const; + void clear(); - private: - std::unique_ptr m_pimpl; + private: + + std::unique_ptr m_pimpl; }; [[nodiscard]] SPARROW_IPC_API std::span compress( const CompressionType compression_type, const std::span& data, - CompressionCache& cache); + CompressionCache& cache + ); [[nodiscard]] SPARROW_IPC_API size_t get_compressed_size( const CompressionType compression_type, const std::span& data, - CompressionCache& cache); + CompressionCache& cache + ); - [[nodiscard]] SPARROW_IPC_API std::variant, std::span> decompress( - const CompressionType compression_type, - std::span data); + [[nodiscard]] SPARROW_IPC_API std::variant, std::span> + decompress(const CompressionType compression_type, std::span data); } diff --git a/include/sparrow_ipc/deserialize_decimal_array.hpp b/include/sparrow_ipc/deserialize_decimal_array.hpp new file mode 100644 index 00000000..cf30e536 --- /dev/null +++ b/include/sparrow_ipc/deserialize_decimal_array.hpp @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include +#include + +#include "Message_generated.h" +#include "sparrow_ipc/arrow_interface/arrow_array.hpp" +#include "sparrow_ipc/arrow_interface/arrow_schema.hpp" +#include "sparrow_ipc/deserialize_utils.hpp" + +namespace sparrow_ipc +{ + template + [[nodiscard]] sparrow::decimal_array deserialize_non_owning_decimal( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + std::string_view name, + const std::optional>& metadata, + bool nullable, + size_t& buffer_index, + int32_t scale, + int32_t precision + ) + { + constexpr std::size_t sizeof_decimal = sizeof(typename T::integer_type); + std::string format_str = "d:" + std::to_string(precision) + "," + std::to_string(scale); + if constexpr (sizeof_decimal != 16) // We don't need to specify the size for 128-bit + // decimals + { + format_str += "," + std::to_string(sizeof_decimal * 8); + } + + // Set up flags based on nullable + std::optional> flags; + if (nullable) + { + flags = std::unordered_set{sparrow::ArrowFlag::NULLABLE}; + } + + ArrowSchema schema = make_non_owning_arrow_schema( + format_str, + name.data(), + metadata, + flags, + 0, + nullptr, + nullptr + ); + + const auto compression = record_batch.compression(); + std::vector buffers; + + auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + + if (compression) + { + buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); + buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); + } + else + { + buffers.emplace_back(validity_buffer_span); + sparrow::buffer data_buffer_copy(data_buffer_span.size(), sparrow::buffer::default_allocator()); + std::memcpy(data_buffer_copy.data(), data_buffer_span.data(), data_buffer_span.size()); + buffers.emplace_back(std::move(data_buffer_copy)); + } + + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count( + validity_buffer_span, + record_batch.length() + ); + + ArrowArray array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(buffers) + ); + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; + return sparrow::decimal_array(std::move(ap)); + } +} \ No newline at end of file diff --git a/include/sparrow_ipc/deserialize_utils.hpp b/include/sparrow_ipc/deserialize_utils.hpp index 97f6be9d..c3f77d0c 100644 --- a/include/sparrow_ipc/deserialize_utils.hpp +++ b/include/sparrow_ipc/deserialize_utils.hpp @@ -60,10 +60,10 @@ namespace sparrow_ipc::utils * @param compression The compression algorithm to use. If nullptr, no decompression is performed. * * @return A `std::variant` containing either: - * - A `std::vector` with the decompressed data, or + * - A `sparrow::buffer` with the decompressed data, or * - A `std::span` providing a view of the original `buffer_span` if no decompression occurred. */ - [[nodiscard]] std::variant, std::span> get_decompressed_buffer( + [[nodiscard]] std::variant, std::span> get_decompressed_buffer( std::span buffer_span, const org::apache::arrow::flatbuf::BodyCompression* compression ); diff --git a/include/sparrow_ipc/utils.hpp b/include/sparrow_ipc/utils.hpp index 63f1fb89..8e567132 100644 --- a/include/sparrow_ipc/utils.hpp +++ b/include/sparrow_ipc/utils.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -13,6 +14,39 @@ namespace sparrow_ipc::utils // Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies SPARROW_IPC_API size_t align_to_8(const size_t n); + /** + * @brief Extracts words after ':' separated by ',' from a string. + * + * This function finds the position of ':' in the input string and then + * splits the remaining part by ',' to extract individual words. + * + * @param str Input string to parse (e.g., "prefix:word1,word2,word3") + * @return std::vector Vector of string views containing the extracted words + * Returns an empty vector if ':' is not found or if there are no words after it + * + * @example + * extract_words_after_colon("d:128,10") returns {"128", "10"} + * extract_words_after_colon("w:256") returns {"256"} + * extract_words_after_colon("no_colon") returns {} + */ + SPARROW_IPC_API std::vector extract_words_after_colon(std::string_view str); + + /** + * @brief Parse a string_view to int32_t using std::from_chars. + * + * This function converts a string view to a 32-bit integer using std::from_chars + * for efficient parsing. + * + * @param str The string view to parse + * @return std::optional The parsed integer value, or std::nullopt if parsing fails + * + * @example + * parse_to_int32("123") returns std::optional(123) + * parse_to_int32("abc") returns std::nullopt + * parse_to_int32("") returns std::nullopt + */ + SPARROW_IPC_API std::optional parse_to_int32(std::string_view str); + /** * @brief Checks if all record batches in a collection have consistent structure. * @@ -63,5 +97,24 @@ namespace sparrow_ipc::utils // Parse the format string // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc std::optional parse_format(std::string_view format_str, std::string_view sep); + + /** + * @brief Parse decimal format strings. + * + * This function parses decimal format strings which can be in two formats: + * - "d:precision,scale" (e.g., "d:19,10") + * - "d:precision,scale,bitWidth" (e.g., "d:19,10,128") + * + * @param format_str The format string to parse + * @return std::optional>> + * A tuple containing (precision, scale, optional bitWidth), or std::nullopt if parsing fails + * + * @example + * parse_decimal_format("d:19,10") returns std::optional{std::tuple{19, 10, std::nullopt}} + * parse_decimal_format("d:19,10,128") returns std::optional{std::tuple{19, 10, std::optional{128}}} + * parse_decimal_format("invalid") returns std::nullopt + */ + SPARROW_IPC_API std::optional>> parse_decimal_format(std::string_view format_str); + // size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch); } diff --git a/src/compression.cpp b/src/compression.cpp index 7579f6e4..d6cab632 100644 --- a/src/compression.cpp +++ b/src/compression.cpp @@ -15,7 +15,7 @@ namespace sparrow_ipc struct TupleHasher { template - static inline void hash_combine(std::size_t& seed, const T& v) + inline static void hash_combine(std::size_t& seed, const T& v) { std::hash hasher; seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); @@ -32,74 +32,87 @@ namespace sparrow_ipc class CompressionCacheImpl { - public: - CompressionCacheImpl() = default; - ~CompressionCacheImpl() = default; + public: - CompressionCacheImpl(CompressionCacheImpl&&) noexcept = default; - CompressionCacheImpl& operator=(CompressionCacheImpl&&) noexcept = default; + CompressionCacheImpl() = default; + ~CompressionCacheImpl() = default; - CompressionCacheImpl(const CompressionCacheImpl&) = delete; - CompressionCacheImpl& operator=(const CompressionCacheImpl&) = delete; + CompressionCacheImpl(CompressionCacheImpl&&) noexcept = default; + CompressionCacheImpl& operator=(CompressionCacheImpl&&) noexcept = default; - std::optional> find(const void* data_ptr, const size_t data_size) - { - auto it = m_cache.find({data_ptr, data_size}); - if (it != m_cache.end()) - { - return it->second; - } - return std::nullopt; - } + CompressionCacheImpl(const CompressionCacheImpl&) = delete; + CompressionCacheImpl& operator=(const CompressionCacheImpl&) = delete; - std::span store(const void* data_ptr, const size_t data_size, std::vector&& data) + std::optional> find(const void* data_ptr, const size_t data_size) + { + auto it = m_cache.find({data_ptr, data_size}); + if (it != m_cache.end()) { - auto [it, inserted] = m_cache.emplace(std::piecewise_construct, std::forward_as_tuple(data_ptr, data_size), std::forward_as_tuple(std::move(data))); - if (!inserted) - { - throw std::runtime_error("Key already exists in compression cache"); - } return it->second; } + return std::nullopt; + } - size_t size() const + std::span + store(const void* data_ptr, const size_t data_size, std::vector&& data) + { + auto [it, inserted] = m_cache.emplace( + std::piecewise_construct, + std::forward_as_tuple(data_ptr, data_size), + std::forward_as_tuple(std::move(data)) + ); + if (!inserted) { - return m_cache.size(); + throw std::runtime_error("Key already exists in compression cache"); } + return it->second; + } - size_t count(const void* data_ptr, const size_t data_size) const - { - return m_cache.count({data_ptr, data_size}); - } + size_t size() const + { + return m_cache.size(); + } - bool empty() const - { - return m_cache.empty(); - } + size_t count(const void* data_ptr, const size_t data_size) const + { + return m_cache.count({data_ptr, data_size}); + } - void clear() - { - m_cache.clear(); - } + bool empty() const + { + return m_cache.empty(); + } - private: - using cache_key_t = std::tuple; - using compression_cache_t = std::unordered_map, TupleHasher>; - compression_cache_t m_cache; + void clear() + { + m_cache.clear(); + } + + private: + + using cache_key_t = std::tuple; + using compression_cache_t = std::unordered_map, TupleHasher>; + compression_cache_t m_cache; }; - CompressionCache::CompressionCache() : m_pimpl(std::make_unique()) {} + CompressionCache::CompressionCache() + : m_pimpl(std::make_unique()) + { + } + CompressionCache::~CompressionCache() = default; CompressionCache::CompressionCache(CompressionCache&&) noexcept = default; CompressionCache& CompressionCache::operator=(CompressionCache&&) noexcept = default; - std::optional> CompressionCache::find(const void* data_ptr, const size_t data_size) + std::optional> + CompressionCache::find(const void* data_ptr, const size_t data_size) { return m_pimpl->find(data_ptr, data_size); } - std::span CompressionCache::store(const void* data_ptr, const size_t data_size, std::vector&& data) + std::span + CompressionCache::store(const void* data_ptr, const size_t data_size, std::vector&& data) { return m_pimpl->store(data_ptr, data_size, std::move(data)); } @@ -151,19 +164,25 @@ namespace sparrow_ipc throw std::invalid_argument("Unsupported compression type."); } } - } // namespace details + } // namespace details namespace { using compress_func = std::function(std::span)>; - using decompress_func = std::function(std::span, int64_t)>; + using decompress_func = std::function(std::span, int64_t)>; std::vector lz4_compress_with_header(std::span data) { const std::int64_t uncompressed_size = data.size(); const size_t max_compressed_size = LZ4F_compressFrameBound(uncompressed_size, nullptr); std::vector result(details::CompressionHeaderSize + max_compressed_size); - const size_t compressed_size = LZ4F_compressFrame(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, nullptr); + const size_t compressed_size = LZ4F_compressFrame( + result.data() + details::CompressionHeaderSize, + max_compressed_size, + data.data(), + uncompressed_size, + nullptr + ); if (LZ4F_isError(compressed_size)) { throw std::runtime_error("Failed to compress data with LZ4 frame format"); @@ -173,15 +192,23 @@ namespace sparrow_ipc return result; } - std::vector lz4_decompress(std::span data, const std::int64_t decompressed_size) + sparrow::buffer + lz4_decompress(std::span data, const std::int64_t decompressed_size) { - std::vector decompressed_data(decompressed_size); + sparrow::bufferdecompressed_data(decompressed_size, sparrow::buffer::default_allocator()); LZ4F_dctx* dctx = nullptr; LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); size_t compressed_size_in_out = data.size(); size_t decompressed_size_in_out = decompressed_size; - const size_t result = LZ4F_decompress(dctx, decompressed_data.data(), &decompressed_size_in_out, data.data(), &compressed_size_in_out, nullptr); - if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t)decompressed_size)) + const size_t result = LZ4F_decompress( + dctx, + decompressed_data.data(), + &decompressed_size_in_out, + data.data(), + &compressed_size_in_out, + nullptr + ); + if (LZ4F_isError(result) || (decompressed_size_in_out != (size_t) decompressed_size)) { throw std::runtime_error("Failed to decompress data with LZ4 frame format"); } @@ -194,7 +221,13 @@ namespace sparrow_ipc const std::int64_t uncompressed_size = data.size(); const size_t max_compressed_size = ZSTD_compressBound(uncompressed_size); std::vector result(details::CompressionHeaderSize + max_compressed_size); - const size_t compressed_size = ZSTD_compress(result.data() + details::CompressionHeaderSize, max_compressed_size, data.data(), uncompressed_size, 1); + const size_t compressed_size = ZSTD_compress( + result.data() + details::CompressionHeaderSize, + max_compressed_size, + data.data(), + uncompressed_size, + 1 + ); if (ZSTD_isError(compressed_size)) { throw std::runtime_error("Failed to compress data with ZSTD"); @@ -204,11 +237,17 @@ namespace sparrow_ipc return result; } - std::vector zstd_decompress(std::span data, const std::int64_t decompressed_size) + sparrow::buffer + zstd_decompress(std::span data, const std::int64_t decompressed_size) { - std::vector decompressed_data(decompressed_size); - const size_t result = ZSTD_decompress(decompressed_data.data(), decompressed_size, data.data(), data.size()); - if (ZSTD_isError(result) || (result != (size_t)decompressed_size)) + sparrow::buffer decompressed_data(decompressed_size, sparrow::buffer::default_allocator()); + const size_t result = ZSTD_decompress( + decompressed_data.data(), + decompressed_size, + data.data(), + data.size() + ); + if (ZSTD_isError(result) || (result != (size_t) decompressed_size)) { throw std::runtime_error("Failed to decompress data with ZSTD"); } @@ -219,14 +258,19 @@ namespace sparrow_ipc { const std::int64_t header = -1; result.reserve(details::CompressionHeaderSize + data.size()); - result.insert(result.end(), reinterpret_cast(&header), reinterpret_cast(&header) + sizeof(header)); + result.insert( + result.end(), + reinterpret_cast(&header), + reinterpret_cast(&header) + sizeof(header) + ); result.insert(result.end(), data.begin(), data.end()); } std::span compress_with_header( const std::span& data, compress_func comp_func, - CompressionCache& cache) + CompressionCache& cache + ) { const void* buffer_ptr = data.data(); const size_t buffer_size = data.size(); @@ -253,7 +297,8 @@ namespace sparrow_ipc return cache.store(buffer_ptr, buffer_size, std::move(result_vec)); } - std::variant, std::span> decompress_with_header(std::span data, decompress_func decomp_func) + std::variant, std::span> + decompress_with_header(std::span data, decompress_func decomp_func) { if (data.size() < details::CompressionHeaderSize) { @@ -278,12 +323,10 @@ namespace sparrow_ipc } return data.subspan(details::CompressionHeaderSize); } - } // namespace + } // namespace - std::span compress( - const CompressionType compression_type, - const std::span& data, - CompressionCache& cache) + std::span + compress(const CompressionType compression_type, const std::span& data, CompressionCache& cache) { switch (compression_type) { @@ -303,12 +346,14 @@ namespace sparrow_ipc size_t get_compressed_size( const CompressionType compression_type, const std::span& data, - CompressionCache& cache) + CompressionCache& cache + ) { return compress(compression_type, data, cache).size(); } - std::variant, std::span> decompress(const CompressionType compression_type, std::span data) + std::variant, std::span> + decompress(const CompressionType compression_type, std::span data) { if (data.empty()) { diff --git a/src/deserialize.cpp b/src/deserialize.cpp index 92063de1..f5b29fc5 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -2,6 +2,7 @@ #include +#include "sparrow_ipc/deserialize_decimal_array.hpp" #include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp" #include "sparrow_ipc/deserialize_primitive_array.hpp" #include "sparrow_ipc/deserialize_variable_size_binary_array.hpp" @@ -205,6 +206,73 @@ namespace sparrow_ipc ) ); break; + case org::apache::arrow::flatbuf::Type::Decimal: + { + const auto decimal_field = field->type_as_Decimal(); + const auto scale = decimal_field->scale(); + const auto precision = decimal_field->precision(); + if (decimal_field->bitWidth() == 32) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 64) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 128) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index, + scale, + precision + ) + ); + } + else if (decimal_field->bitWidth() == 256) + { + arrays.emplace_back( + deserialize_non_owning_decimal>( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index, + scale, + precision + ) + ); + } + break; + } default: throw std::runtime_error("Unsupported type."); } diff --git a/src/deserialize_utils.cpp b/src/deserialize_utils.cpp index f5e93e24..b1fb1af7 100644 --- a/src/deserialize_utils.cpp +++ b/src/deserialize_utils.cpp @@ -35,7 +35,7 @@ namespace sparrow_ipc::utils return body.subspan(buffer_metadata->offset(), buffer_metadata->length()); } - std::variant, std::span> get_decompressed_buffer( + std::variant, std::span> get_decompressed_buffer( std::span buffer_span, const org::apache::arrow::flatbuf::BodyCompression* compression ) diff --git a/src/flatbuffer_utils.cpp b/src/flatbuffer_utils.cpp index 4c4e79a7..8b715d4d 100644 --- a/src/flatbuffer_utils.cpp +++ b/src/flatbuffer_utils.cpp @@ -370,7 +370,7 @@ namespace sparrow_ipc } // Creates a Flatbuffers Decimal type from a format string - // The format string is expected to be in the format "d:precision,scale" + // The format string is expected to be in the format "d:precision,scale" or "d:precision,scale,bitWidth" std::pair> get_flatbuffer_decimal_type( flatbuffers::FlatBufferBuilder& builder, std::string_view format_str, @@ -378,29 +378,26 @@ namespace sparrow_ipc ) { // Decimal requires precision and scale. We need to parse the format_str. - // Format: "d:precision,scale" - const auto scale = utils::parse_format(format_str, ","); - if (!scale.has_value()) + // Format: "d:precision,scale" or "d:precision,scale,bitWidth" + const auto parsed = utils::parse_decimal_format(format_str); + if (!parsed.has_value()) { throw std::runtime_error( "Failed to parse Decimal " + std::to_string(bitWidth) - + " scale from format string: " + std::string(format_str) - ); - } - const size_t comma_pos = format_str.find(','); - const auto precision = utils::parse_format(format_str.substr(0, comma_pos), ":"); - if (!precision.has_value()) - { - throw std::runtime_error( - "Failed to parse Decimal " + std::to_string(bitWidth) - + " precision from format string: " + std::string(format_str) + + " format string: " + std::string(format_str) ); } + + const auto& [precision, scale, parsed_bitwidth] = parsed.value(); + + // Use the bitWidth from the format string if provided, otherwise use the parameter + const int32_t actual_bitwidth = parsed_bitwidth.value_or(bitWidth); + const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal( builder, - precision.value(), - scale.value(), - bitWidth + precision, + scale, + actual_bitwidth ); return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()}; } diff --git a/src/utils.cpp b/src/utils.cpp index 73db1369..8de95941 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1,9 +1,14 @@ #include "sparrow_ipc/utils.hpp" #include +#include +#include namespace sparrow_ipc::utils { + + // Parse the format string + // The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc std::optional parse_format(std::string_view format_str, std::string_view sep) { // Find the position of the delimiter @@ -29,8 +34,138 @@ namespace sparrow_ipc::utils return substr_size; } + std::optional>> parse_decimal_format(std::string_view format_str) + { + // Format can be "d:precision,scale" or "d:precision,scale,bitWidth" + // First, find the colon + const auto colon_pos = format_str.find(':'); + if (colon_pos == std::string_view::npos) + { + return std::nullopt; + } + + // Extract the part after the colon + std::string_view params_str(format_str.data() + colon_pos + 1, format_str.size() - colon_pos - 1); + + // Find the first comma (between precision and scale) + const auto first_comma_pos = params_str.find(','); + if (first_comma_pos == std::string_view::npos) + { + return std::nullopt; + } + + // Parse precision + std::string_view precision_str(params_str.data(), first_comma_pos); + int32_t precision = 0; + auto [ptr1, ec1] = std::from_chars( + precision_str.data(), + precision_str.data() + precision_str.size(), + precision + ); + if (ec1 != std::errc() || ptr1 != precision_str.data() + precision_str.size()) + { + return std::nullopt; + } + + // Find the second comma (between scale and bitWidth, if present) + const auto remaining_str = params_str.substr(first_comma_pos + 1); + const auto second_comma_pos = remaining_str.find(','); + + std::string_view scale_str; + std::optional bit_width; + + if (second_comma_pos == std::string_view::npos) + { + // Format is "d:precision,scale" + scale_str = remaining_str; + } + else + { + // Format is "d:precision,scale,bitWidth" + scale_str = std::string_view(remaining_str.data(), second_comma_pos); + std::string_view bitwidth_str(remaining_str.data() + second_comma_pos + 1, + remaining_str.size() - second_comma_pos - 1); + + // Parse bitWidth + int32_t bw = 0; + auto [ptr3, ec3] = std::from_chars( + bitwidth_str.data(), + bitwidth_str.data() + bitwidth_str.size(), + bw + ); + if (ec3 != std::errc() || ptr3 != bitwidth_str.data() + bitwidth_str.size()) + { + return std::nullopt; + } + bit_width = bw; + } + + // Parse scale + int32_t scale = 0; + auto [ptr2, ec2] = std::from_chars( + scale_str.data(), + scale_str.data() + scale_str.size(), + scale + ); + if (ec2 != std::errc() || ptr2 != scale_str.data() + scale_str.size()) + { + return std::nullopt; + } + + return std::make_tuple(precision, scale, bit_width); + } + size_t align_to_8(const size_t n) { return (n + 7) & -8; } + + std::vector extract_words_after_colon(std::string_view str) + { + std::vector result; + + // Find the position of ':' + const auto colon_pos = str.find(':'); + if (colon_pos == std::string_view::npos) + { + return result; // Return empty vector if ':' not found + } + + // Get the substring after ':' + std::string_view remaining = str.substr(colon_pos + 1); + + // If nothing after ':', return empty vector + if (remaining.empty()) + { + return result; + } + + // Split by ',' + size_t start = 0; + size_t comma_pos = remaining.find(','); + + while (comma_pos != std::string_view::npos) + { + result.push_back(remaining.substr(start, comma_pos - start)); + start = comma_pos + 1; + comma_pos = remaining.find(',', start); + } + + // Add the last word (or the only word if no comma was found) + result.push_back(remaining.substr(start)); + + return result; + } + + std::optional parse_to_int32(std::string_view str) + { + int32_t value = 0; + const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), value); + + if (ec != std::errc() || ptr != str.data() + str.size()) + { + return std::nullopt; + } + return value; + } } diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 8cb74e8f..96a1f308 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -22,8 +22,9 @@ const std::filesystem::path arrow_testing_data_dir = ARROW_TESTING_DATA_DIR; const std::filesystem::path tests_resources_files_path = arrow_testing_data_dir / "data" / "arrow-ipc-stream" / "integration" / "cpp-21.0.0"; -const std::filesystem::path tests_resources_files_path_with_compression = arrow_testing_data_dir / "data" / "arrow-ipc-stream" - / "integration" / "2.0.0-compression"; +const std::filesystem::path tests_resources_files_path_with_compression = arrow_testing_data_dir / "data" + / "arrow-ipc-stream" / "integration" + / "2.0.0-compression"; const std::vector files_paths_to_test = { tests_resources_files_path / "generated_primitive", @@ -33,16 +34,20 @@ const std::vector files_paths_to_test = { tests_resources_files_path / "generated_large_binary", tests_resources_files_path / "generated_binary_zerolength", tests_resources_files_path / "generated_binary_no_batches", + tests_resources_files_path / "generated_decimal32", + tests_resources_files_path / "generated_decimal64", + tests_resources_files_path / "generated_decimal", + tests_resources_files_path / "generated_decimal256", }; const std::vector files_paths_to_test_with_lz4_compression = { tests_resources_files_path_with_compression / "generated_lz4", - tests_resources_files_path_with_compression/ "generated_uncompressible_lz4", + tests_resources_files_path_with_compression / "generated_uncompressible_lz4", }; const std::vector files_paths_to_test_with_zstd_compression = { tests_resources_files_path_with_compression / "generated_zstd", - tests_resources_files_path_with_compression/ "generated_uncompressible_zstd", + tests_resources_files_path_with_compression / "generated_uncompressible_zstd", }; size_t get_number_of_batches(const std::filesystem::path& json_path) @@ -93,15 +98,27 @@ void compare_record_batches( } } -struct Lz4CompressionParams { +struct Lz4CompressionParams +{ static constexpr sparrow_ipc::CompressionType compression_type = sparrow_ipc::CompressionType::LZ4_FRAME; - static const std::vector& files() { return files_paths_to_test_with_lz4_compression; } + + static const std::vector& files() + { + return files_paths_to_test_with_lz4_compression; + } + static constexpr const char* name = "LZ4"; }; -struct ZstdCompressionParams { +struct ZstdCompressionParams +{ static constexpr sparrow_ipc::CompressionType compression_type = sparrow_ipc::CompressionType::ZSTD; - static const std::vector& files() { return files_paths_to_test_with_zstd_compression; } + + static const std::vector& files() + { + return files_paths_to_test_with_zstd_compression; + } + static constexpr const char* name = "ZSTD"; }; @@ -203,13 +220,19 @@ TEST_SUITE("Integration tests") } } - TEST_CASE_TEMPLATE("Compare record_batch serialization with stream file using compression", T, Lz4CompressionParams, ZstdCompressionParams) + TEST_CASE_TEMPLATE( + "Compare record_batch serialization with stream file using compression", + T, + Lz4CompressionParams, + ZstdCompressionParams + ) { for (const auto& file_path : T::files()) { std::filesystem::path json_path = file_path; json_path.replace_extension(".json"); - const std::string test_name = "Testing " + std::string(T::name) + " compression with " + file_path.filename().string(); + const std::string test_name = "Testing " + std::string(T::name) + " compression with " + + file_path.filename().string(); SUBCASE(test_name.c_str()) { // Load the JSON file @@ -254,7 +277,12 @@ TEST_SUITE("Integration tests") } } - TEST_CASE_TEMPLATE("Round trip of classic test files serialization/deserialization using compression", T, Lz4CompressionParams, ZstdCompressionParams) + TEST_CASE_TEMPLATE( + "Round trip of classic test files serialization/deserialization using compression", + T, + Lz4CompressionParams, + ZstdCompressionParams + ) { for (const auto& file_path : files_paths_to_test) { diff --git a/tests/test_utils.cpp b/tests/test_utils.cpp index 0619d68a..bd54c850 100644 --- a/tests/test_utils.cpp +++ b/tests/test_utils.cpp @@ -15,4 +15,244 @@ namespace sparrow_ipc CHECK_EQ(utils::align_to_8(15), 16); CHECK_EQ(utils::align_to_8(16), 16); } + + TEST_CASE("extract_words_after_colon") + { + SUBCASE("Basic case with multiple words") + { + auto result = utils::extract_words_after_colon("d:128,10"); + REQUIRE_EQ(result.size(), 2); + CHECK_EQ(result[0], "128"); + CHECK_EQ(result[1], "10"); + } + + SUBCASE("Single word after colon") + { + auto result = utils::extract_words_after_colon("w:256"); + REQUIRE_EQ(result.size(), 1); + CHECK_EQ(result[0], "256"); + } + + SUBCASE("Three words") + { + auto result = utils::extract_words_after_colon("d:10,5,128"); + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], "10"); + CHECK_EQ(result[1], "5"); + CHECK_EQ(result[2], "128"); + } + + SUBCASE("No colon in string") + { + auto result = utils::extract_words_after_colon("no_colon"); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Colon at end") + { + auto result = utils::extract_words_after_colon("prefix:"); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Empty string") + { + auto result = utils::extract_words_after_colon(""); + CHECK_EQ(result.size(), 0); + } + + SUBCASE("Only colon and comma") + { + auto result = utils::extract_words_after_colon(":,"); + REQUIRE_EQ(result.size(), 2); + CHECK_EQ(result[0], ""); + CHECK_EQ(result[1], ""); + } + + SUBCASE("Complex prefix") + { + auto result = utils::extract_words_after_colon("prefix:word1,word2,word3"); + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], "word1"); + CHECK_EQ(result[1], "word2"); + CHECK_EQ(result[2], "word3"); + } + } + + TEST_CASE("parse_to_int32") + { + SUBCASE("Valid positive integer") + { + auto result = utils::parse_to_int32("123"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 123); + } + + SUBCASE("Valid negative integer") + { + auto result = utils::parse_to_int32("-456"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), -456); + } + + SUBCASE("Zero") + { + auto result = utils::parse_to_int32("0"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 0); + } + + SUBCASE("Large valid number") + { + auto result = utils::parse_to_int32("2147483647"); // INT32_MAX + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 2147483647); + } + + SUBCASE("Invalid - not a number") + { + auto result = utils::parse_to_int32("abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - empty string") + { + auto result = utils::parse_to_int32(""); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - partial number with text") + { + auto result = utils::parse_to_int32("123abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - text with number") + { + auto result = utils::parse_to_int32("abc123"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - just a sign") + { + auto result = utils::parse_to_int32("-"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Valid with leading zeros") + { + auto result = utils::parse_to_int32("00123"); + REQUIRE(result.has_value()); + CHECK_EQ(result.value(), 123); + } + } + + TEST_CASE("parse_decimal_format") + { + SUBCASE("Basic format: d:19,10") + { + auto result = utils::parse_decimal_format("d:19,10"); + REQUIRE(result.has_value()); + const auto& [precision, scale, bitwidth] = result.value(); + CHECK_EQ(precision, 19); + CHECK_EQ(scale, 10); + CHECK_FALSE(bitwidth.has_value()); + } + + SUBCASE("Extended format: d:19,10,128") + { + auto result = utils::parse_decimal_format("d:19,10,128"); + REQUIRE(result.has_value()); + const auto& [precision, scale, bitwidth] = result.value(); + CHECK_EQ(precision, 19); + CHECK_EQ(scale, 10); + REQUIRE(bitwidth.has_value()); + CHECK_EQ(bitwidth.value(), 128); + } + + SUBCASE("Extended format: d:38,6,256") + { + auto result = utils::parse_decimal_format("d:38,6,256"); + REQUIRE(result.has_value()); + const auto& [precision, scale, bitwidth] = result.value(); + CHECK_EQ(precision, 38); + CHECK_EQ(scale, 6); + REQUIRE(bitwidth.has_value()); + CHECK_EQ(bitwidth.value(), 256); + } + + SUBCASE("Basic format with zero scale: d:10,0") + { + auto result = utils::parse_decimal_format("d:10,0"); + REQUIRE(result.has_value()); + const auto& [precision, scale, bitwidth] = result.value(); + CHECK_EQ(precision, 10); + CHECK_EQ(scale, 0); + CHECK_FALSE(bitwidth.has_value()); + } + + SUBCASE("Extended format with zero scale: d:10,0,64") + { + auto result = utils::parse_decimal_format("d:10,0,64"); + REQUIRE(result.has_value()); + const auto& [precision, scale, bitwidth] = result.value(); + CHECK_EQ(precision, 10); + CHECK_EQ(scale, 0); + REQUIRE(bitwidth.has_value()); + CHECK_EQ(bitwidth.value(), 64); + } + + SUBCASE("Invalid - no colon") + { + auto result = utils::parse_decimal_format("d19,10"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - no comma") + { + auto result = utils::parse_decimal_format("d:19"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - missing precision") + { + auto result = utils::parse_decimal_format("d:,10"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - missing scale") + { + auto result = utils::parse_decimal_format("d:19,"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - missing bitwidth when provided") + { + auto result = utils::parse_decimal_format("d:19,10,"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - non-numeric precision") + { + auto result = utils::parse_decimal_format("d:abc,10"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - non-numeric scale") + { + auto result = utils::parse_decimal_format("d:19,abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - non-numeric bitwidth") + { + auto result = utils::parse_decimal_format("d:19,10,abc"); + CHECK_FALSE(result.has_value()); + } + + SUBCASE("Invalid - empty string") + { + auto result = utils::parse_decimal_format(""); + CHECK_FALSE(result.has_value()); + } + } }