diff --git a/.cspell.json b/.cspell.json index 0697ccc4edd..b72ea55deb2 100644 --- a/.cspell.json +++ b/.cspell.json @@ -261,6 +261,7 @@ "ossl", "ccrng", "KEYWRAP", + "HKDF", "NVME", // EC2 "IMDS", diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h index 0d062be1e00..e18e5deb8d1 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -372,6 +373,7 @@ namespace Aws * Return empty string on success, string with error message on error. */ Aws::String WritePartToDownloadStream(Aws::IOStream* partStream, uint64_t writeOffset); + void AddChecksumForPart(Aws:: IOStream* partStream, const PartPointer& shared); void ApplyDownloadConfiguration(const DownloadConfiguration& downloadConfig); @@ -389,6 +391,9 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; } void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; } + Aws::S3::Model::ChecksumAlgorithm GetChecksumAlgorithm() const { std::lock_guard locker(m_getterSetterLock); return m_checksumAlgorithm; } + void SetChecksumAlgorithm (const Aws::S3::Model::ChecksumAlgorithm& checksumAlgorithm) { std::lock_guard locker(m_getterSetterLock); m_checksumAlgorithm = checksumAlgorithm; } + private: void CleanupDownloadStream(); @@ -430,6 +435,7 @@ namespace Aws mutable std::condition_variable m_waitUntilFinishedSignal; mutable std::mutex m_getterSetterLock; Aws::String m_checksum; + Aws::S3::Model::ChecksumAlgorithm m_checksumAlgorithm; }; AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status); diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h index a4b5580fd6e..725f14c1219 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h @@ -144,6 +144,13 @@ namespace Aws * upload. Defaults to CRC64-NVME. */ Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME; + + /** + * Enable checksum validation for downloads. When enabled, checksums will be + * calculated during download and validated against S3 response headers. + * Defaults to true. + */ + bool validateChecksums = true; }; /** diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferHandle.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferHandle.cpp index 61fa1fee83e..fec20950d4b 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferHandle.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferHandle.cpp @@ -6,6 +6,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include "aws/core/utils/HashingUtils.h" #include @@ -370,6 +376,11 @@ namespace Aws AWS_LOGSTREAM_TRACE(CLASS_TAG, "Transfer handle ID [" << GetId() << "] Restarting transfer."); m_cancel.store(false); m_lastPart.store(false); + + // Clear checksum state for retry + std::lock_guard locker(m_getterSetterLock); + m_checksum.clear(); + m_checksumAlgorithm = Aws::S3::Model::ChecksumAlgorithm::NOT_SET; } bool TransferHandle::ShouldContinue() const @@ -423,6 +434,25 @@ namespace Aws return ""; } + void TransferHandle::AddChecksumForPart(Aws::IOStream *partStream, const PartPointer& partState) { + partStream->seekg(0); + Aws::String checksum = ""; + if (GetChecksumAlgorithm()==S3::Model::ChecksumAlgorithm::CRC32) { + checksum = Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::Crypto::CRC32().Calculate(*partStream).GetResult()); + } else if (GetChecksumAlgorithm()==S3::Model::ChecksumAlgorithm::CRC32C) { + checksum = Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::Crypto::CRC32C().Calculate(*partStream).GetResult()); + } else if (GetChecksumAlgorithm()==S3::Model::ChecksumAlgorithm::CRC64NVME) { + checksum = Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::Crypto::CRC64().Calculate(*partStream).GetResult()); + } else if (GetChecksumAlgorithm()==S3::Model::ChecksumAlgorithm::SHA1) { + checksum = Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::Crypto::Sha1().Calculate(*partStream).GetResult()); + } else if (GetChecksumAlgorithm()==S3::Model::ChecksumAlgorithm::SHA256) { + checksum = Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::Crypto::Sha256().Calculate(*partStream).GetResult()); + } + partState->SetChecksum(checksum); + partStream->clear(); + partStream->seekg(0); + } + void TransferHandle::ApplyDownloadConfiguration(const DownloadConfiguration& downloadConfig) { SetVersionId(downloadConfig.versionId); diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 996e427e114..4738bb9f487 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -15,6 +15,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -406,12 +409,9 @@ namespace Aws const auto fullObjectHashCalculator = [](const std::shared_ptr& handle, bool isRetry, S3::Model::ChecksumAlgorithm algorithm) -> std::shared_ptr { if (handle->GetChecksum().empty() && !isRetry) { - if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32 || algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { return Aws::MakeShared("TransferManager"); } - if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { - return Aws::MakeShared("TransferManager"); - } if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { return Aws::MakeShared("TransferManager"); } @@ -673,6 +673,10 @@ namespace Aws { return outcome.GetResult().GetChecksumCRC32C(); } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) + { + return outcome.GetResult().GetChecksumCRC64NVME(); + } else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) { return outcome.GetResult().GetChecksumSHA1(); @@ -965,6 +969,7 @@ namespace Aws headObjectRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); headObjectRequest.WithBucket(handle->GetBucketName()) .WithKey(handle->GetKey()); + headObjectRequest.SetChecksumMode(Aws::S3::Model::ChecksumMode::ENABLED); if(!handle->GetVersionId().empty()) { @@ -1000,6 +1005,18 @@ namespace Aws handle->SetContentType(headObjectOutcome.GetResult().GetContentType()); handle->SetMetadata(headObjectOutcome.GetResult().GetMetadata()); handle->SetEtag(headObjectOutcome.GetResult().GetETag()); + if (headObjectOutcome.GetResult().GetChecksumType() == Aws::S3::Model::ChecksumType::FULL_OBJECT) { + if (!headObjectOutcome.GetResult().GetChecksumCRC32C().empty()) { + handle->SetChecksum(headObjectOutcome.GetResult().GetChecksumCRC32C()); + handle->SetChecksumAlgorithm(S3::Model::ChecksumAlgorithm::CRC32C); + } else if (!headObjectOutcome.GetResult().GetChecksumCRC32().empty()) { + handle->SetChecksum(headObjectOutcome.GetResult().GetChecksumCRC32()); + handle->SetChecksumAlgorithm(S3::Model::ChecksumAlgorithm::CRC32); + } else if (!headObjectOutcome.GetResult().GetChecksumCRC64NVME().empty()) { + handle->SetChecksum(headObjectOutcome.GetResult().GetChecksumCRC64NVME()); + handle->SetChecksumAlgorithm(S3::Model::ChecksumAlgorithm::CRC64NVME); + } + } /* When bucket versioning is suspended, head object will return "null" for unversioned object. * Send following GetObject with "null" as versionId will result in 403 access denied error if your IAM role or policy * doesn't have GetObjectVersion permission. @@ -1082,6 +1099,7 @@ namespace Aws getObjectRangeRequest.SetRange(FormatRangeSpecifier(rangeStart, rangeEnd)); getObjectRangeRequest.SetResponseStreamFactory(responseStreamFunction); getObjectRangeRequest.SetIfMatch(handle->GetEtag()); + getObjectRangeRequest.SetChecksumMode(Aws::S3::Model::ChecksumMode::ENABLED); if(handle->GetVersionId().size() > 0) { getObjectRangeRequest.SetVersionId(handle->GetVersionId()); @@ -1204,6 +1222,19 @@ namespace Aws Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())}; if (errMsg.empty()) { + if (!outcome.GetResult().GetChecksumCRC32().empty()) { + partState->SetChecksum(outcome.GetResult().GetChecksumCRC32()); + } else if (!outcome.GetResult().GetChecksumCRC32C().empty()) { + partState->SetChecksum(outcome.GetResult().GetChecksumCRC32C()); + } else if (!outcome.GetResult().GetChecksumCRC64NVME().empty()) { + partState->SetChecksum(outcome.GetResult().GetChecksumCRC64NVME()); + } else if (!outcome.GetResult().GetChecksumSHA1().empty()) { + partState->SetChecksum(outcome.GetResult().GetChecksumSHA1()); + } else if (!outcome.GetResult().GetChecksumSHA256().empty()) { + partState->SetChecksum(outcome.GetResult().GetChecksumSHA256()); + } else { + if (m_transferConfig.validateChecksums) { handle->AddChecksumForPart(bufferStream, partState); } + } handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); } else { Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, @@ -1239,6 +1270,73 @@ namespace Aws { if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize()) { + if (m_transferConfig.validateChecksums && !handle->GetChecksum().empty() && + (handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC32 || + handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC32C || + handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC64NVME)) { + uint64_t combinedChecksum = 0; + bool first = true; + for (const auto& part: handle->GetCompletedParts()) { + Aws::String checksumStr = part.second->GetChecksum(); + uint64_t partSize = part.second->GetSizeInBytes(); + if (checksumStr.empty()) { continue; } + auto decoded = Aws::Utils::HashingUtils::Base64Decode(checksumStr); + const auto* raw = decoded.GetUnderlyingData(); + if (first) { + if (handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC64NVME) { + uint64_t partCrcBE = 0; + std::memcpy(&partCrcBE, raw, sizeof(uint64_t)); + const uint64_t partCrc = aws_ntoh64(partCrcBE); + combinedChecksum = partCrc; + } else { + uint32_t partCrcBE = 0; + std::memcpy(&partCrcBE, raw, sizeof(uint32_t)); + const uint32_t partCrc = aws_ntoh32(partCrcBE); + combinedChecksum = partCrc; + } + first = false; + } else { + if (handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC64NVME) { + uint64_t partCrcBE = 0; + std::memcpy(&partCrcBE, raw, sizeof(uint64_t)); + const uint64_t partCrc = aws_ntoh64(partCrcBE); + combinedChecksum = Aws::Crt::Checksum::CombineCRC64NVME(combinedChecksum, partCrc, partSize); + } + else { + uint32_t partCrcBE = 0; + std::memcpy(&partCrcBE, raw, sizeof(uint32_t)); + const uint32_t partCrc = aws_ntoh32(partCrcBE); + if (handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC32) { + combinedChecksum = Aws::Crt::Checksum::CombineCRC32(static_cast(combinedChecksum), partCrc, partSize); + } else { + combinedChecksum = Aws::Crt::Checksum::CombineCRC32C(static_cast(combinedChecksum), partCrc, partSize); + } + } + } + } + Aws::Utils::ByteBuffer checksumBuffer(handle->GetChecksumAlgorithm()== S3::Model::ChecksumAlgorithm::CRC64NVME ? 8 : 4); + if (handle->GetChecksumAlgorithm() == S3::Model::ChecksumAlgorithm::CRC64NVME) { + const uint64_t be = aws_hton64(combinedChecksum); + std::memcpy(checksumBuffer.GetUnderlyingData(), &be, sizeof(uint64_t)); + } else { + const uint32_t be = aws_hton32(static_cast(combinedChecksum)); + std::memcpy(checksumBuffer.GetUnderlyingData(), &be, sizeof(uint32_t)); + } + Aws::String combinedChecksumStr = Aws::Utils::HashingUtils::Base64Encode(checksumBuffer); + + if (combinedChecksumStr != handle->GetChecksum()) { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Full-object checksum mismatch. Expected: " << handle->GetChecksum() + << ", Calculated: " << combinedChecksumStr); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Full-object checksum validation failed", + false); + handle->SetError(error); + handle->UpdateStatus(TransferStatus::FAILED); + TriggerErrorCallback(handle, error); + } + } outcome.GetResult().GetBody().flush(); handle->UpdateStatus(TransferStatus::COMPLETED); } diff --git a/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp b/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp index 6d97c81b0c2..e5a844691e1 100644 --- a/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp +++ b/tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp @@ -4,10 +4,13 @@ #include #include #include +#include +#include #include #include #include #include +#include using namespace Aws; using namespace Aws::S3; @@ -36,11 +39,87 @@ class MockS3Client : public S3Client { } }; +class MockMultipartS3Client : public S3Client { +public: + Aws::String FULL_OBJECT_CHECKSUM; //"SBi/K+1ooBg=" + MockMultipartS3Client(Aws::String expected_checksum) : S3Client() { + FULL_OBJECT_CHECKSUM = expected_checksum; + }; + + HeadObjectOutcome HeadObject(const HeadObjectRequest&) const override { + HeadObjectResult result; + result.SetContentLength(78643200); + result.SetChecksumCRC64NVME(FULL_OBJECT_CHECKSUM); + result.SetChecksumType(Aws::S3::Model::ChecksumType::FULL_OBJECT); // This is key! + result.SetETag("\"test-etag-12345\""); // Add ETag + return HeadObjectOutcome(std::move(result)); + } + + GetObjectOutcome GetObject(const GetObjectRequest& request) const override { + GetObjectResult result; + + const uint64_t totalSize = 78643200; + const uint64_t partSize = 5242880; + const std::vector checksums = { + "wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=", + "oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=", + "zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY=" + }; + + if (request.RangeHasBeenSet()) { + auto range = request.GetRange(); + size_t dashPos = range.find('-'); + uint64_t start = std::stoull(range.substr(6, dashPos - 6)); + uint64_t end = std::stoull(range.substr(dashPos + 1)); + uint64_t size = end - start + 1; + + int partNum = static_cast(start) / partSize; + if (partNum < 15) { + result.SetContentRange(Aws::String("bytes ") + std::to_string(start) + "-" + std::to_string(end) + "/" + std::to_string(totalSize)); + result.SetChecksumCRC64NVME(checksums[partNum]); + result.SetContentLength(size); + result.SetETag(Aws::String("\"part-etag-") + std::to_string(partNum) + "\""); + + // Call the response stream factory if provided + if (request.GetResponseStreamFactory()) { + auto responseStream = request.GetResponseStreamFactory()(); + + // Write part-specific data to the response stream + char partChar = 'A' + (partNum % 3); + for (uint64_t i = 0; i < size; ++i) { + responseStream->put(partChar); + } + responseStream->flush(); + + // Simulate data received callback to track bytes transferred + if (request.GetDataReceivedEventHandler()) { + request.GetDataReceivedEventHandler()(nullptr, nullptr, size); + } + + result.ReplaceBody(responseStream); + } else { + // Fallback for non-factory requests + auto stream = Aws::New(ALLOCATION_TAG); + char partChar = 'A' + (partNum % 3); + for (uint64_t i = 0; i < size; ++i) { + stream->put(partChar); + } + stream->seekg(0, std::ios::beg); + result.ReplaceBody(stream); + } + } + } + + return GetObjectOutcome(std::move(result)); + } +}; + class TransferUnitTest : public testing::Test { protected: void SetUp() override { executor = Aws::MakeShared(ALLOCATION_TAG, 1); mockS3Client = Aws::MakeShared(ALLOCATION_TAG); + mockMultipartS3Client = Aws::MakeShared(ALLOCATION_TAG, "SBi/K+1ooBg="); } static void SetUpTestSuite() { @@ -53,6 +132,7 @@ class TransferUnitTest : public testing::Test { std::shared_ptr executor; std::shared_ptr mockS3Client; + std::shared_ptr mockMultipartS3Client; static SDKOptions _options; }; @@ -73,3 +153,79 @@ TEST_F(TransferUnitTest, ContentValidationShouldFail) { EXPECT_EQ(TransferStatus::FAILED, handle->GetStatus()); } + +TEST_F(TransferUnitTest, MultipartDownloadTest) { + TransferManagerConfiguration config(executor.get()); + config.s3Client = mockMultipartS3Client; + config.bufferSize = 5242880; // 5MB to ensure multipart + auto transferManager = TransferManager::Create(config); + + // Create a temporary file for download since multipart needs seekable stream + std::string tempFile; +#ifdef _WIN32 + char tempPath[MAX_PATH]; + GetTempPathA(MAX_PATH, tempPath); + tempFile = std::string(tempPath) + "test_download_" + std::to_string(rand()); +#else + tempFile = "/tmp/test_download_" + std::to_string(rand()); +#endif + auto createStreamFn = [tempFile]() -> Aws::IOStream* { + return Aws::New(ALLOCATION_TAG, tempFile.c_str(), + std::ios_base::out | std::ios_base::in | + std::ios_base::binary | std::ios_base::trunc); + }; + + // Download the full 78MB file + auto handle = transferManager->DownloadFile("test-bucket", "test-key", createStreamFn); + handle->WaitUntilFinished(); + + // Test multipart download functionality - should PASS with correct checksum + EXPECT_TRUE(handle->IsMultipart()); + EXPECT_EQ(78643200u, handle->GetBytesTotalSize()); + EXPECT_EQ(15u, handle->GetCompletedParts().size()); + EXPECT_EQ(0u, handle->GetFailedParts().size()); + EXPECT_EQ(0u, handle->GetPendingParts().size()); + EXPECT_EQ(TransferStatus::COMPLETED, handle->GetStatus()); // Should PASS + + // Clean up + std::remove(tempFile.c_str()); +} + +TEST_F(TransferUnitTest, MultipartDownloadTest_Fail) { + TransferManagerConfiguration config(executor.get()); + auto mockFailClient = Aws::MakeShared(ALLOCATION_TAG, "WRONG_CHECKSUM="); + config.s3Client = mockFailClient; + config.bufferSize = 5242880; // 5MB to ensure multipart + auto transferManager = TransferManager::Create(config); + + // Create a temporary file for download since multipart needs seekable stream + std::string tempFile; +#ifdef _WIN32 + char tempPath[MAX_PATH]; + GetTempPathA(MAX_PATH, tempPath); + tempFile = std::string(tempPath) + "test_download_" + std::to_string(rand()); +#else + tempFile = "/tmp/test_download_" + std::to_string(rand()); +#endif + auto createStreamFn = [tempFile]() -> Aws::IOStream* { + return Aws::New(ALLOCATION_TAG, tempFile.c_str(), + std::ios_base::out | std::ios_base::in | + std::ios_base::binary | std::ios_base::trunc); + }; + + // Download the full 78MB file + auto handle = transferManager->DownloadFile("test-bucket", "test-key", createStreamFn); + handle->WaitUntilFinished(); + + // Test multipart download functionality - should FAIL with wrong checksum + EXPECT_TRUE(handle->IsMultipart()); + EXPECT_EQ(78643200u, handle->GetBytesTotalSize()); + EXPECT_EQ(15u, handle->GetCompletedParts().size()); + EXPECT_EQ(0u, handle->GetFailedParts().size()); + EXPECT_EQ(0u, handle->GetPendingParts().size()); + EXPECT_EQ(TransferStatus::FAILED, handle->GetStatus()); // Should FAIL due to wrong checksum + EXPECT_EQ("Full-object checksum validation failed", handle->GetLastError().GetMessage()); + + // Clean up + std::remove(tempFile.c_str()); +} \ No newline at end of file