diff --git a/extensions/aws/s3/MinifiToAwsInputStream.cpp b/extensions/aws/s3/MinifiToAwsInputStream.cpp new file mode 100644 index 0000000000..367c000ad3 --- /dev/null +++ b/extensions/aws/s3/MinifiToAwsInputStream.cpp @@ -0,0 +1,75 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "MinifiToAwsInputStream.h" + +#include +#include + +namespace org::apache::nifi::minifi::aws::s3 { + +/* Refills the get area from the wrapped stream. When unread characters remain, returns the current one without consuming it. Otherwise reads at most DEFAULT_BUFFER_SIZE bytes, capped at the bytes left before start_pos_ + content_length_, into buffer_ and exposes them through setg. Returns traits_type::eof() when the stream position has reached the end of that range, when the read yields 0 bytes, or when the read reports an error; in the error case badbit is also set on owner_. */ MinifiInputStreamBuf::int_type MinifiInputStreamBuf::underflow() { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + const uint64_t stream_pos = stream_->tell(); + if (stream_pos >= start_pos_ + content_length_) { + return traits_type::eof(); + } + /* never request more than what is left of the range that starts at start_pos_ and spans content_length_ bytes */ const auto remaining = (start_pos_ + content_length_) - stream_pos; + const auto to_read = std::min(utils::configuration::DEFAULT_BUFFER_SIZE, remaining); + const auto bytes_read = stream_->read(std::span(reinterpret_cast(buffer_.data()), gsl::narrow(to_read))); + if (io::isError(bytes_read)) { + owner_->setstate(std::ios_base::badbit); + return traits_type::eof(); + } + if (bytes_read == 0) { + return traits_type::eof(); + } + setg(buffer_.data(), buffer_.data(), buffer_.data() + bytes_read); + return traits_type::to_int_type(*gptr()); +} + +/* Converts a relative seek into a position measured from start_pos_ and forwards it to seekpos. Returns off_type(-1) when which lacks std::ios_base::in. For beg the target is off; for cur it is the current logical position plus off; for any other way value it is content_length_ plus off. */ MinifiInputStreamBuf::pos_type MinifiInputStreamBuf::seekoff(off_type off, std::ios_base::seekdir way, std::ios_base::openmode which) { + if (!(which & std::ios_base::in)) { + return {off_type(-1)}; + } + pos_type new_virtual_pos; + if (way == std::ios_base::beg) { + new_virtual_pos = pos_type(off); + } else if (way == std::ios_base::cur) { + /* subtract the buffered but still unread bytes (egptr - gptr) from the stream position; presumably tell() sits at the end of the last buffered read - TODO confirm */ const auto phys_pos = static_cast(stream_->tell()) - static_cast(egptr() - gptr()); + new_virtual_pos = pos_type(phys_pos - static_cast(start_pos_) + off); + } else { + new_virtual_pos = pos_type(static_cast(content_length_) + off); + } + return seekpos(new_virtual_pos, which); +} + +/* Moves the read position to pos, interpreted relative to start_pos_. Returns off_type(-1) when which lacks std::ios_base::in, or when pos is negative or greater than content_length_ (pos equal to content_length_ is accepted). On success seeks stream_ to start_pos_ + pos, empties the get area and returns pos. */ MinifiInputStreamBuf::pos_type MinifiInputStreamBuf::seekpos(pos_type pos, std::ios_base::openmode which) { + if (!(which & std::ios_base::in)) { + return {off_type(-1)}; + } + if (off_type(pos) < 0 || off_type(pos) > gsl::narrow(content_length_)) { + return {off_type(-1)}; + } + stream_->seek(start_pos_ + static_cast(off_type(pos))); + setg(buffer_.data(), buffer_.data(), buffer_.data()); // invalidate read buffer + return 
pos; +} + +} // namespace org::apache::nifi::minifi::aws::s3 diff --git a/extensions/aws/s3/MinifiToAwsInputStream.h b/extensions/aws/s3/MinifiToAwsInputStream.h new file mode 100644 index 0000000000..067ec0fb94 --- /dev/null +++ b/extensions/aws/s3/MinifiToAwsInputStream.h @@ -0,0 +1,60 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include + +#include "utils/ConfigurationUtils.h" +#include "minifi-cpp/io/InputStream.h" +#include "minifi-cpp/utils/gsl.h" + +namespace org::apache::nifi::minifi::aws::s3 { + +/* std::streambuf that serves reads from the shared stream passed to the constructor, limited to content_length bytes counted from the position that stream reports at construction. owner is a non-null pointer to the object on which underflow sets badbit when a read fails. Only the input-side virtuals underflow, seekoff and seekpos are overridden. */ class MinifiInputStreamBuf : public std::streambuf { + public: + /* stream_ is initialised first, so start_pos_ is taken from the stored stream; buffer_ is sized to DEFAULT_BUFFER_SIZE */ MinifiInputStreamBuf(std::shared_ptr stream, uint64_t content_length, gsl::not_null*> owner) + : stream_(std::move(stream)), + start_pos_(stream_->tell()), + content_length_(content_length), + buffer_(utils::configuration::DEFAULT_BUFFER_SIZE), + owner_(owner) {} + + protected: + int_type underflow() override; + pos_type seekoff(off_type off, std::ios_base::seekdir way, std::ios_base::openmode which) override; + pos_type seekpos(pos_type pos, std::ios_base::openmode which) override; + + private: + /* source of the bytes */ std::shared_ptr stream_; + /* stream position captured at construction; seek positions are relative to it */ uint64_t start_pos_; + /* number of bytes readable from start_pos_ */ uint64_t content_length_; + /* backing storage for the get area */ std::vector buffer_; + /* receives badbit on read errors */ gsl::not_null*> owner_; +}; + +/* std::basic_iostream that uses its own private MinifiInputStreamBuf base as stream buffer. The constructor passes this, after a static_cast, both to the MinifiInputStreamBuf base as owner and to the std::basic_iostream constructor; the cast target types are not legible in this excerpt. */ class MinifiToAwsInputStream : private MinifiInputStreamBuf, public std::basic_iostream { + public: + MinifiToAwsInputStream(std::shared_ptr stream, uint64_t content_length) + : MinifiInputStreamBuf(std::move(stream), content_length, gsl::not_null*>(static_cast*>(this))), + std::basic_iostream(static_cast(this)) {} +}; + +} // namespace org::apache::nifi::minifi::aws::s3 diff --git a/extensions/aws/s3/S3Wrapper.cpp b/extensions/aws/s3/S3Wrapper.cpp index 1bb3be6ac7..c568c5f762 100644 --- a/extensions/aws/s3/S3Wrapper.cpp +++ b/extensions/aws/s3/S3Wrapper.cpp @@ -22,6 +22,7 @@ #include #include +#include "MinifiToAwsInputStream.h" #include "S3ClientRequestSender.h" #include "utils/ArrayUtils.h" #include "utils/StringUtils.h" @@ -68,32 +69,11 @@ std::string S3Wrapper::getEncryptionString(Aws::S3Crt::Model::ServerSideEncrypti return ""; } -std::shared_ptr S3Wrapper::readFlowFileStream(const std::shared_ptr& stream, uint64_t read_limit, uint64_t& read_size_out) { - std::array buffer{}; - auto data_stream = std::make_shared(); - uint64_t read_size = 0; - 
while (read_size < read_limit) { - const auto next_read_size = (std::min)(read_limit - read_size, uint64_t{BUFFER_SIZE}); - const auto read_ret = stream->read(std::span(buffer).subspan(0, next_read_size)); - if (io::isError(read_ret)) { - throw StreamReadException("Reading flow file inputstream failed!"); - } - if (read_ret > 0) { - data_stream->write(reinterpret_cast(buffer.data()), gsl::narrow(read_ret)); - read_size += read_ret; - } else { - break; - } - } - read_size_out = read_size; - return data_stream; -} - std::optional S3Wrapper::putObject(const PutObjectRequestParameters& put_object_params, const std::shared_ptr& stream, uint64_t flow_size) { - uint64_t read_size{}; - auto data_stream = readFlowFileStream(stream, flow_size, read_size); auto request = createPutObjectRequest(put_object_params); - request.SetBody(data_stream); + auto aws_stream = std::make_shared(stream, flow_size); + request.SetBody(aws_stream); + request.SetContentLength(static_cast(flow_size)); // NOLINT(runtime/int,google-runtime-int) AWS SDK expects long long for content length auto aws_result = request_sender_->sendPutObjectRequest(request); if (!aws_result) { @@ -120,11 +100,9 @@ std::optional S3Wrapper::uploadParts(const PutObje const size_t start_part = upload_state.uploaded_parts + 1; const size_t last_part = start_part + part_count - 1; for (size_t part_number = start_part; part_number <= last_part; ++part_number) { - uint64_t read_size{}; const auto remaining = flow_size - total_read; const auto next_read_size = std::min(remaining, upload_state.part_size); - auto stream_ptr = readFlowFileStream(stream, next_read_size, read_size); - total_read += read_size; + auto aws_stream = std::make_shared(stream, next_read_size); auto upload_part_request = Aws::S3Crt::Model::UploadPartRequest{} .WithBucket(put_object_params.bucket) @@ -132,20 +110,24 @@ std::optional S3Wrapper::uploadParts(const PutObje .WithPartNumber(gsl::narrow(part_number)) .WithUploadId(upload_state.upload_id) 
.WithChecksumAlgorithm(put_object_params.checksum_algorithm); - upload_part_request.SetBody(stream_ptr); + upload_part_request.SetBody(aws_stream); + upload_part_request.SetContentLength(static_cast(next_read_size)); // NOLINT(runtime/int,google-runtime-int) AWS SDK expects long long for content length - Aws::Utils::ByteBuffer part_md5(Aws::Utils::HashingUtils::CalculateMD5(*stream_ptr)); + Aws::Utils::ByteBuffer part_md5(Aws::Utils::HashingUtils::CalculateMD5(*aws_stream)); upload_part_request.SetContentMD5(Aws::Utils::HashingUtils::Base64Encode(part_md5)); + // Reset to part start so the SDK reads the full part during the upload request. + aws_stream->seekg(0, std::ios::beg); auto upload_part_result = request_sender_->sendUploadPartRequest(upload_part_request); if (!upload_part_result) { logger_->log_error("Failed to upload part {} of {} of S3 object with key '{}'", part_number, last_part, put_object_params.object_key); return std::nullopt; } + total_read += next_read_size; result.part_etags.push_back(upload_part_result->GetETag()); upload_state.uploaded_etags.push_back(upload_part_result->GetETag()); upload_state.uploaded_parts += 1; - upload_state.uploaded_size += read_size; + upload_state.uploaded_size += next_read_size; multipart_upload_storage_->storeState(put_object_params.bucket, put_object_params.object_key, upload_state); logger_->log_info("Uploaded part {} of {} S3 object with key '{}'", part_number, last_part, put_object_params.object_key); } diff --git a/extensions/aws/s3/S3Wrapper.h b/extensions/aws/s3/S3Wrapper.h index 0f558c0d22..537eabbdbe 100644 --- a/extensions/aws/s3/S3Wrapper.h +++ b/extensions/aws/s3/S3Wrapper.h @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -301,7 +300,6 @@ class S3Wrapper { static int64_t writeFetchedBody(Aws::IOStream& source, int64_t data_size, io::OutputStream& output); static std::string getEncryptionString(Aws::S3Crt::Model::ServerSideEncryption encryption); - static 
std::shared_ptr readFlowFileStream(const std::shared_ptr& stream, uint64_t read_limit, uint64_t& read_size_out); std::optional> listVersions(const ListRequestParameters& params); std::optional> listObjects(const ListRequestParameters& params); diff --git a/extensions/aws/tests/MockS3RequestSender.h b/extensions/aws/tests/MockS3RequestSender.h index d209057eb1..2ee77cae03 100644 --- a/extensions/aws/tests/MockS3RequestSender.h +++ b/extensions/aws/tests/MockS3RequestSender.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -232,6 +233,10 @@ class MockS3RequestSender : public minifi::aws::s3::S3RequestSender { fail_on_part_ = 0; return std::nullopt; } + // Consume the body like the real SDK, allowing the next part to start at the correct position + if (auto body = request.GetBody()) { + body->ignore(std::numeric_limits::max()); + } upload_part_requests.push_back(request); Aws::S3Crt::Model::UploadPartResult result; result.SetETag("etag" + std::to_string(etag_counter_)); @@ -294,6 +299,10 @@ class MockS3RequestSender : public minifi::aws::s3::S3RequestSender { } static std::string getUploadPartRequestBody(const Aws::S3Crt::Model::UploadPartRequest& upload_part_request) { + // Seek to the beginning of this part's window before reading, because the + // underlying io::InputStream is shared across all parts and may be positioned + // elsewhere by the time this helper is called. + upload_part_request.GetBody()->seekg(0); std::istreambuf_iterator buf_it; return std::string(std::istreambuf_iterator(*upload_part_request.GetBody()), buf_it); }