diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index b8d9ca6866..a8e6d9f59d 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -429,12 +429,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): if ".." in prefix: raise ValueError("Traversal components are not allowed in S3 path!") + # Spot check: enforce ownership only when downloading from the session's default + # bucket. Cross-account buckets are left untouched. + expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket_name) + extra_args = None + if expected_owner: + extra_args = {"ExpectedBucketOwner": expected_owner} + # Try to download the prefix as an object first, in case it is a file and not a 'directory'. # Do this first, in case the object has broader permissions than the bucket. if not prefix.endswith("/"): try: file_destination = os.path.join(target, os.path.basename(prefix)) - s3.Object(bucket_name, prefix).download_file(file_destination) + s3.Object(bucket_name, prefix).download_file( + file_destination, ExtraArgs=extra_args + ) return except botocore.exceptions.ClientError as e: err_info = e.response["Error"] @@ -445,10 +454,10 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): else: raise - _download_files_under_prefix(bucket_name, prefix, target, s3) + _download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=extra_args) -def _download_files_under_prefix(bucket_name, prefix, target, s3): +def _download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=None): """Download all S3 files which match the given prefix Args: @@ -456,6 +465,8 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3): prefix (str): S3 prefix within the bucket that will be downloaded target (str): destination path where the downloaded items will be placed s3 (boto3.resources.base.ServiceResource): S3 resource + extra_args (dict): Optional extra arguments passed to each download_file call. + Used to carry ExpectedBucketOwner when the bucket is the session's default. """ bucket = s3.Bucket(bucket_name) for obj_sum in bucket.objects.filter(Prefix=prefix): @@ -473,7 +484,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3): # anything else will be raised. if exc.errno != errno.EEXIST: raise - obj.download_file(file_path) + obj.download_file(file_path, ExtraArgs=extra_args) def create_tar_file(source_files, target=None): @@ -620,6 +631,16 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): extra_args = {"ServerSideEncryption": "aws:kms"} else: extra_args = None + + # Spot check: when the model is being uploaded to the session's default bucket, + # assert ownership to defend against bucket-squatting on the predictable default + # name. Other caller-supplied buckets are left untouched. + if sagemaker_session is not None: + expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket) + if expected_owner: + extra_args = dict(extra_args) if extra_args else {} + extra_args["ExpectedBucketOwner"] = expected_owner + sagemaker_session.boto_session.resource( "s3", region_name=sagemaker_session.boto_region_name ).Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args) @@ -767,7 +788,17 @@ def download_file(bucket_name, path, target, sagemaker_session): s3 = boto_session.resource("s3", region_name=sagemaker_session.boto_region_name) bucket = s3.Bucket(bucket_name) - bucket.download_file(path, target) + + # Spot check: assert ownership only when downloading from the session's default + # bucket. Non-default buckets (e.g. caller-supplied model URIs pointing at shared + # or cross-account data) are downloaded without ExpectedBucketOwner to preserve + # legitimate cross-account flows. + expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket_name) + extra_args = None + if expected_owner: + extra_args = {"ExpectedBucketOwner": expected_owner} + + bucket.download_file(path, target, ExtraArgs=extra_args) def sts_regional_endpoint(region): diff --git a/sagemaker-core/src/sagemaker/core/experiments/_helper.py b/sagemaker-core/src/sagemaker/core/experiments/_helper.py index d94dd31fca..6ce424da05 100644 --- a/sagemaker-core/src/sagemaker/core/experiments/_helper.py +++ b/sagemaker-core/src/sagemaker/core/experiments/_helper.py @@ -95,6 +95,16 @@ def upload_artifact(self, file_path, extra_args=None): artifact_s3_key = "{}/{}/{}".format( self.artifact_prefix, self.trial_component_name, artifact_name ) + + # Spot check: enforce ownership only when uploading to the session's default + # bucket. Cross-account destinations are left untouched. + expected_owner = self.sagemaker_session._get_account_id_if_default_bucket( + self.artifact_bucket + ) + if expected_owner: + extra_args = dict(extra_args) if extra_args else {} + extra_args["ExpectedBucketOwner"] = expected_owner + self._s3_client.upload_file( file_path, self.artifact_bucket, @@ -133,9 +143,21 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension= artifact_s3_key = "{}/{}/{}".format( self.artifact_prefix, self.trial_component_name, artifact_name ) - self._s3_client.put_object( - Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key + + # Spot check: enforce ownership only when uploading to the session's default + # bucket. Cross-account destinations are left untouched. + put_kwargs = { + "Body": json.dumps(artifact_object), + "Bucket": self.artifact_bucket, + "Key": artifact_s3_key, + } + expected_owner = self.sagemaker_session._get_account_id_if_default_bucket( + self.artifact_bucket ) + if expected_owner: + put_kwargs["ExpectedBucketOwner"] = expected_owner + + self._s3_client.put_object(**put_kwargs) etag = self._try_get_etag(artifact_s3_key) return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag @@ -149,7 +171,14 @@ def _try_get_etag(self, key): str: The S3 object ETag if it allows, otherwise return None. """ try: - response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key) + head_kwargs = {"Bucket": self.artifact_bucket, "Key": key} + expected_owner = self.sagemaker_session._get_account_id_if_default_bucket( + self.artifact_bucket + ) + if expected_owner: + head_kwargs["ExpectedBucketOwner"] = expected_owner + + response = self._s3_client.head_object(**head_kwargs) return response["ETag"] except botocore.exceptions.ClientError as error: # requires read permissions diff --git a/sagemaker-core/src/sagemaker/core/fw_utils.py b/sagemaker-core/src/sagemaker/core/fw_utils.py index f658ae9840..a520286141 100644 --- a/sagemaker-core/src/sagemaker/core/fw_utils.py +++ b/sagemaker-core/src/sagemaker/core/fw_utils.py @@ -400,6 +400,7 @@ def tar_and_upload_dir( kms_key=None, s3_resource=None, settings: Optional[SessionSettings] = None, + expected_bucket_owner: Optional[str] = None, ) -> UploadedCode: """Package source files and upload a compress tar file to S3. @@ -430,6 +431,12 @@ def tar_and_upload_dir( settings (sagemaker.session_settings.SessionSettings): Optional. The settings of the SageMaker ``Session``, can be used to override the default encryption behavior (default: None). + expected_bucket_owner (str): Optional. AWS account id passed as + ``ExpectedBucketOwner`` on the upload. Callers should supply this when + ``bucket`` is the session's default bucket (via + ``Session._get_account_id_if_default_bucket``) to defend against + bucket-squatting on the predictable default name. Leave as ``None`` for + cross-account destination buckets. Returns: sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and script name. @@ -471,6 +478,10 @@ def tar_and_upload_dir( else: extra_args = None + if expected_bucket_owner: + extra_args = dict(extra_args) if extra_args else {} + extra_args["ExpectedBucketOwner"] = expected_bucket_owner + if s3_resource is None: s3_resource = session.resource("s3", region_name=session.region_name) else: diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 41957e30a2..ffaa7e2f71 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -428,6 +428,14 @@ def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra bucket=bucket, key_prefix=key_prefix, sagemaker_session=self ) + # Spot check: if the resolved bucket is the session's default bucket, enforce + # ownership on the upload to defend against bucket-squatting on the predictable + # default name. Other buckets are left untouched to preserve cross-account flows. + expected_owner = self._get_account_id_if_default_bucket(bucket) + if expected_owner: + extra_args = dict(extra_args) if extra_args else {} + extra_args["ExpectedBucketOwner"] = expected_owner + # Generate a tuple for each file that we want to upload of the form (local_path, s3_key). files = [] key_suffix = None @@ -484,10 +492,17 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None): s3_object = s3.Object(bucket_name=bucket, key=key) + # Spot check: enforce ownership only when writing to the session's default + # bucket. Cross-account destinations are left untouched. + put_kwargs = {"Body": body} if kms_key is not None: - s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms") - else: - s3_object.put(Body=body) + put_kwargs["SSEKMSKeyId"] = kms_key + put_kwargs["ServerSideEncryption"] = "aws:kms" + expected_owner = self._get_account_id_if_default_bucket(bucket) + if expected_owner: + put_kwargs["ExpectedBucketOwner"] = expected_owner + + s3_object.put(**put_kwargs) s3_uri = "s3://{}/{}".format(bucket, key) return s3_uri @@ -511,11 +526,19 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): else: s3 = self.s3_client + # Spot check: if the caller-supplied bucket is the session's default bucket, + # assert ownership on the list/download calls to defend against squatting on + # the predictable default bucket name. Non-default buckets are left untouched + # to preserve legitimate cross-account flows. + expected_owner = self._get_account_id_if_default_bucket(bucket) + # Initialize the variables used to loop through the contents of the S3 bucket. keys = [] directories = [] next_token = "" base_parameters = {"Bucket": bucket, "Prefix": key_prefix} + if expected_owner: + base_parameters["ExpectedBucketOwner"] = expected_owner # Loop through the contents of the bucket, 1,000 objects at a time. Gathering all keys into # a "keys" list. @@ -542,6 +565,9 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # For each object key, create the directory on the local machine if needed, and then # download the file. + download_extra_args = dict(extra_args) if extra_args else {} + if expected_owner: + download_extra_args["ExpectedBucketOwner"] = expected_owner downloaded_paths = [] for dir_path in directories: os.makedirs(os.path.dirname(dir_path), exist_ok=True) @@ -553,7 +579,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): if not os.path.exists(os.path.dirname(destination_path)): os.makedirs(os.path.dirname(destination_path), exist_ok=True) s3.download_file( - Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args + Bucket=bucket, + Key=key, + Filename=destination_path, + ExtraArgs=download_extra_args or None, ) downloaded_paths.append(destination_path) return downloaded_paths @@ -573,8 +602,16 @@ def read_s3_file(self, bucket, key_prefix): else: s3 = self.s3_client + # Spot check: assert ownership only when the caller is reading from the + # session's default bucket. Other buckets (e.g. JumpStart content buckets) + # are read without ExpectedBucketOwner to preserve cross-account flows. + get_kwargs = {"Bucket": bucket, "Key": key_prefix} + expected_owner = self._get_account_id_if_default_bucket(bucket) + if expected_owner: + get_kwargs["ExpectedBucketOwner"] = expected_owner + # Explicitly passing a None kms_key to boto3 throws a validation error. - s3_object = s3.get_object(Bucket=bucket, Key=key_prefix) + s3_object = s3.get_object(**get_kwargs) return s3_object["Body"].read().decode("utf-8") @@ -699,32 +736,50 @@ def general_bucket_check_if_user_has_permission( If there is any other error that comes up with calling head bucket, it is raised up here If there is no bucket , it will create one + When the SDK selected the bucket name itself (``_default_bucket_set_by_sdk`` is True), + the probe is issued with ``ExpectedBucketOwner`` set to the caller's account id so that + S3 returns ``403`` if a bucket with the SDK-chosen name happens to exist in another + account (bucket-squatting defense). For user-overridden bucket names the probe is + issued without ``ExpectedBucketOwner`` to preserve legitimate cross-account usage. + Args: bucket_name (str): Name of the S3 bucket s3 (str): S3 object from boto session region (str): The region in which to create the bucket. bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not """ + extra_args = {} + if self._default_bucket_set_by_sdk: + extra_args["ExpectedBucketOwner"] = self.account_id() + try: if self.default_bucket_prefix: s3.meta.client.list_objects_v2( - Bucket=bucket_name, Prefix=self.default_bucket_prefix + Bucket=bucket_name, Prefix=self.default_bucket_prefix, **extra_args ) else: - s3.meta.client.head_bucket(Bucket=bucket_name) + s3.meta.client.head_bucket(Bucket=bucket_name, **extra_args) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] - # bucket does not exist or forbidden to access + # bucket does not exist, is owned by another account, or is forbidden to access if bucket_creation_date_none: if error_code == "404" and message == "Not Found": self.create_bucket_for_not_exist_error(bucket_name, region, s3) elif error_code == "403" and message == "Forbidden": - LOGGER.error( - "Bucket %s exists, but access is forbidden. Please try again after " - "adding appropriate access.", - bucket.name, - ) + if self._default_bucket_set_by_sdk: + LOGGER.error( + "Bucket %s is not accessible as the default bucket. It may be " + "owned by another account or access is forbidden. To unblock, " + "pass a custom default_bucket parameter to sagemaker.Session.", + bucket.name, + ) + else: + LOGGER.error( + "Bucket %s exists, but access is forbidden. Please try again " + "after adding appropriate access.", + bucket.name, + ) raise else: raise @@ -773,6 +828,51 @@ def generate_default_sagemaker_bucket_name(self, boto_session): ).get_caller_identity()["Account"] return "sagemaker-{}-{}".format(region, account) + def _get_account_id_if_default_bucket(self, bucket): + """Return the caller's account id if ``bucket`` is the SDK-generated default bucket. + + Used by S3 operations that receive a caller-supplied bucket name to apply the + "spot check": when the bucket matches the SDK-generated default bucket name, + the call should assert ownership via ``ExpectedBucketOwner`` to defend against + bucket-squatting on the predictable default name. For any other bucket — including + user-overridden default buckets (which may legitimately be cross-account), + JumpStart, marketplace artifacts, or shared-team buckets — this returns ``None`` + so the call proceeds unchanged. + + This check is passive: it does not trigger default-bucket resolution or creation. + If ``default_bucket()`` has not been called yet, this returns ``None``. + + Args: + bucket (str): The bucket name the caller is about to use. + + Returns: + Optional[str]: The expected account id, or ``None`` if the spot check does + not apply. + """ + if not bucket: + return None + + # Only apply the spot check when the SDK generated the bucket name itself. + # User-overridden default buckets (_default_bucket_name_override) may legitimately + # be in another account, so we must not assert caller-account ownership on them. + if not self._default_bucket_set_by_sdk: + return None + + # Use the already-resolved default bucket (fast, no side effects). + # _default_bucket is only set after default_bucket() has been called at least once. + if self._default_bucket and bucket == self._default_bucket: + try: + return self.account_id() + except Exception: # pylint: disable=broad-except + # account_id() issues an STS call; if it fails we skip the spot check + # rather than block the S3 operation. + LOGGER.warning( + "Could not resolve caller account id for ExpectedBucketOwner check " + "on bucket %s; proceeding without the check.", + bucket, + ) + return None + def determine_bucket_and_prefix( self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None ): diff --git a/sagemaker-core/src/sagemaker/core/lambda_helper.py b/sagemaker-core/src/sagemaker/core/lambda_helper.py index 7c1a4c26e7..af942404ab 100644 --- a/sagemaker-core/src/sagemaker/core/lambda_helper.py +++ b/sagemaker-core/src/sagemaker/core/lambda_helper.py @@ -123,12 +123,17 @@ def create(self): bucket, key_prefix = s3.determine_bucket_and_prefix( bucket=self.s3_bucket, key_prefix=None, sagemaker_session=self.session ) + # Spot check: if the resolved bucket is the session's default bucket, + # enforce ownership on the upload so an attacker cannot squat on the + # predictable default name. + expected_owner = self.session._get_account_id_if_default_bucket(bucket) key = _upload_to_s3( s3_client=_get_s3_client(self.session), function_name=self.function_name, zipped_code_dir=self.zipped_code_dir, s3_bucket=bucket, s3_key_prefix=key_prefix, + expected_bucket_owner=expected_owner, ) code = {"S3Bucket": bucket, "S3Key": key} @@ -179,6 +184,13 @@ def update(self): else: function_name_for_s3 = self.function_name + # Spot check: enforce ownership only when the resolved bucket is + # the session's default bucket (defends against squatting on the + # predictable default name). Other buckets are left untouched. + expected_owner = self.session._get_account_id_if_default_bucket( + bucket + ) + response = lambda_client.update_function_code( FunctionName=(self.function_name or self.function_arn), S3Bucket=bucket, @@ -188,6 +200,7 @@ def update(self): zipped_code_dir=self.zipped_code_dir, s3_bucket=bucket, s3_key_prefix=key_prefix, + expected_bucket_owner=expected_owner, ), ) return response @@ -276,13 +289,31 @@ def _get_lambda_client(session): return lambda_client -def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_prefix=None): +def _upload_to_s3( + s3_client, + function_name, + zipped_code_dir, + s3_bucket, + s3_key_prefix=None, + expected_bucket_owner=None, +): """Upload the zipped code to S3 bucket provided in the Lambda instance. Lambda instance must have a path to the zipped code folder and a S3 bucket to upload the code. The key will lambda/function_name/code and the S3 URI where the code is uploaded is in this format: s3://bucket_name/lambda/function_name/code. + Args: + s3_client: boto3 S3 client used for the upload. + function_name (str): Lambda function name used to build the S3 key. + zipped_code_dir (str): Local path to the zipped Lambda code. + s3_bucket (str): Destination S3 bucket. + s3_key_prefix (str): Optional S3 key prefix. + expected_bucket_owner (str): Optional account id passed as ``ExpectedBucketOwner`` + on the upload when the destination bucket should belong to that account + (typically the caller's account, when ``s3_bucket`` is the session's default + bucket). ``None`` leaves the upload untouched for cross-account flows. + Returns: the S3 key where the code is uploaded. """ @@ -292,7 +323,10 @@ def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_p function_name, "code", ) - s3_client.upload_file(zipped_code_dir, s3_bucket, key) + extra_args = None + if expected_bucket_owner: + extra_args = {"ExpectedBucketOwner": expected_bucket_owner} + s3_client.upload_file(zipped_code_dir, s3_bucket, key, ExtraArgs=extra_args) return key diff --git a/sagemaker-core/src/sagemaker/core/s3/client.py b/sagemaker-core/src/sagemaker/core/s3/client.py index f16350dda4..427d24a200 100644 --- a/sagemaker-core/src/sagemaker/core/s3/client.py +++ b/sagemaker-core/src/sagemaker/core/s3/client.py @@ -117,6 +117,13 @@ def upload_bytes(b: Union[bytes, io.BytesIO], s3_uri, kms_key=None, sagemaker_se else: extra_args = None + # Spot check: enforce ownership only when uploading to the session's default + # bucket. Cross-account destinations are left untouched. + expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket) + if expected_owner: + extra_args = dict(extra_args) if extra_args else {} + extra_args["ExpectedBucketOwner"] = expected_owner + b = b if isinstance(b, io.BytesIO) else io.BytesIO(b) sagemaker_session.s3_resource.Bucket(bucket).upload_fileobj( b, object_key, ExtraArgs=extra_args @@ -193,8 +200,17 @@ def read_bytes(s3_uri, sagemaker_session=None) -> bytes: bucket, object_key = parse_s3_url(s3_uri) + # Spot check: enforce ownership only when reading from the session's default + # bucket. Cross-account reads (e.g. JumpStart) are left untouched. + extra_args = None + expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket) + if expected_owner: + extra_args = {"ExpectedBucketOwner": expected_owner} + bytes_io = io.BytesIO() - sagemaker_session.s3_resource.Bucket(bucket).download_fileobj(object_key, bytes_io) + sagemaker_session.s3_resource.Bucket(bucket).download_fileobj( + object_key, bytes_io, ExtraArgs=extra_args + ) bytes_io.seek(0) return bytes_io.read() diff --git a/sagemaker-core/src/sagemaker/core/tools/api_coverage.json b/sagemaker-core/src/sagemaker/core/tools/api_coverage.json index 860edc8e79..65b0080b22 100644 --- a/sagemaker-core/src/sagemaker/core/tools/api_coverage.json +++ b/sagemaker-core/src/sagemaker/core/tools/api_coverage.json @@ -1 +1 @@ -{"SupportedAPIs": 361, "UnsupportedAPIs": 6} \ No newline at end of file +{"SupportedAPIs": 461, "UnsupportedAPIs": 48} \ No newline at end of file diff --git a/sagemaker-core/tests/unit/helper/test_session_helper.py b/sagemaker-core/tests/unit/helper/test_session_helper.py index 55589ee3ef..1c3efd7ddf 100644 --- a/sagemaker-core/tests/unit/helper/test_session_helper.py +++ b/sagemaker-core/tests/unit/helper/test_session_helper.py @@ -1412,3 +1412,296 @@ def test_general_bucket_check_without_prefix(self, mock_boto_session, mock_sagem "test-bucket", mock_s3_resource, mock_bucket, "us-west-2", True ) mock_s3_resource.meta.client.head_bucket.assert_called_once_with(Bucket="test-bucket") + + +class TestUploadDataSpotCheck: + """Spot-check behavior in Session.upload_data. + + ExpectedBucketOwner must be added to ExtraArgs when the destination bucket + is the session's default bucket, and NOT added for any other bucket. + """ + + def test_upload_to_default_bucket_includes_expected_owner( + self, mock_boto_session, mock_sagemaker_client, tmp_path + ): + test_file = tmp_path / "test.txt" + test_file.write_text("x") + + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_data( + path=str(test_file), + bucket="sagemaker-us-west-2-111111111111", + key_prefix="data", + ) + + call_args = mock_s3_object.upload_file.call_args + assert call_args[1]["ExtraArgs"] == {"ExpectedBucketOwner": "111111111111"} + + def test_upload_to_non_default_bucket_omits_expected_owner( + self, mock_boto_session, mock_sagemaker_client, tmp_path + ): + """Cross-account uploads (e.g. to a partner or shared bucket) must not + carry ExpectedBucketOwner, or they would break. + """ + test_file = tmp_path / "test.txt" + test_file.write_text("x") + + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_data( + path=str(test_file), + bucket="partner-cross-account-bucket", + key_prefix="data", + ) + + call_args = mock_s3_object.upload_file.call_args + # ExtraArgs should remain None since caller didn't pass any and bucket is not default. + assert call_args[1]["ExtraArgs"] is None + + def test_upload_default_bucket_merges_with_existing_extra_args( + self, mock_boto_session, mock_sagemaker_client, tmp_path + ): + """Existing ExtraArgs (e.g. KMS config) must be preserved alongside ExpectedBucketOwner.""" + test_file = tmp_path / "test.txt" + test_file.write_text("x") + + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_data( + path=str(test_file), + bucket="sagemaker-us-west-2-111111111111", + key_prefix="data", + extra_args={"ServerSideEncryption": "AES256"}, + ) + + merged = mock_s3_object.upload_file.call_args[1]["ExtraArgs"] + assert merged["ServerSideEncryption"] == "AES256" + assert merged["ExpectedBucketOwner"] == "111111111111" + + +class TestUploadStringAsFileBodySpotCheck: + """Spot check in Session.upload_string_as_file_body.""" + + def test_to_default_bucket_includes_expected_owner( + self, mock_boto_session, mock_sagemaker_client + ): + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_string_as_file_body( + body="data", + bucket="sagemaker-us-west-2-111111111111", + key="some/key", + ) + + mock_s3_object.put.assert_called_once_with( + Body="data", ExpectedBucketOwner="111111111111" + ) + + def test_to_non_default_bucket_omits_expected_owner( + self, mock_boto_session, mock_sagemaker_client + ): + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_string_as_file_body( + body="data", + bucket="shared-partner-bucket", + key="some/key", + ) + + mock_s3_object.put.assert_called_once_with(Body="data") + + def test_to_default_bucket_preserves_kms(self, mock_boto_session, mock_sagemaker_client): + mock_s3_resource = Mock() + mock_s3_object = Mock() + mock_s3_resource.Object.return_value = mock_s3_object + + session = Session( + boto_session=mock_boto_session, + sagemaker_client=mock_sagemaker_client, + default_bucket="sagemaker-us-west-2-111111111111", + ) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_resource = mock_s3_resource + + with patch.object(session, "account_id", return_value="111111111111"): + session.upload_string_as_file_body( + body="data", + bucket="sagemaker-us-west-2-111111111111", + key="some/key", + kms_key="kms-key-id", + ) + + mock_s3_object.put.assert_called_once_with( + Body="data", + SSEKMSKeyId="kms-key-id", + ServerSideEncryption="aws:kms", + ExpectedBucketOwner="111111111111", + ) + + +class TestReadS3FileSpotCheck: + """Spot check in Session.read_s3_file.""" + + def test_read_from_default_bucket_includes_expected_owner( + self, mock_boto_session, mock_sagemaker_client + ): + mock_s3_client = Mock() + mock_body = Mock() + mock_body.read.return_value = b"content" + mock_s3_client.get_object.return_value = {"Body": mock_body} + + session = Session(boto_session=mock_boto_session, sagemaker_client=mock_sagemaker_client) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_client = mock_s3_client + + with patch.object(session, "account_id", return_value="111111111111"): + session.read_s3_file("sagemaker-us-west-2-111111111111", "k") + + mock_s3_client.get_object.assert_called_once_with( + Bucket="sagemaker-us-west-2-111111111111", + Key="k", + ExpectedBucketOwner="111111111111", + ) + + def test_read_from_non_default_bucket_omits_expected_owner( + self, mock_boto_session, mock_sagemaker_client + ): + """JumpStart / cross-account reads must not break.""" + mock_s3_client = Mock() + mock_body = Mock() + mock_body.read.return_value = b"content" + mock_s3_client.get_object.return_value = {"Body": mock_body} + + session = Session(boto_session=mock_boto_session, sagemaker_client=mock_sagemaker_client) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_client = mock_s3_client + + with patch.object(session, "account_id", return_value="111111111111"): + session.read_s3_file("jumpstart-cache-prod-us-west-2", "k") + + mock_s3_client.get_object.assert_called_once_with( + Bucket="jumpstart-cache-prod-us-west-2", Key="k" + ) + + +class TestDownloadDataSpotCheck: + """Spot check in Session.download_data.""" + + def test_download_from_default_bucket_includes_expected_owner( + self, mock_boto_session, mock_sagemaker_client, tmp_path + ): + mock_s3_client = Mock() + mock_s3_client.list_objects_v2.return_value = { + "Contents": [{"Key": "p/f.txt", "Size": 1}] + } + + session = Session(boto_session=mock_boto_session, sagemaker_client=mock_sagemaker_client) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_client = mock_s3_client + + with patch.object(session, "account_id", return_value="111111111111"): + session.download_data( + path=str(tmp_path), + bucket="sagemaker-us-west-2-111111111111", + key_prefix="p/f.txt", + ) + + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="sagemaker-us-west-2-111111111111", + Prefix="p/f.txt", + ExpectedBucketOwner="111111111111", + ) + assert ( + mock_s3_client.download_file.call_args[1]["ExtraArgs"] + == {"ExpectedBucketOwner": "111111111111"} + ) + + def test_download_from_non_default_bucket_omits_expected_owner( + self, mock_boto_session, mock_sagemaker_client, tmp_path + ): + mock_s3_client = Mock() + mock_s3_client.list_objects_v2.return_value = { + "Contents": [{"Key": "p/f.txt", "Size": 1}] + } + + session = Session(boto_session=mock_boto_session, sagemaker_client=mock_sagemaker_client) + session._default_bucket = "sagemaker-us-west-2-111111111111" + session._default_bucket_set_by_sdk = True + session.s3_client = mock_s3_client + + with patch.object(session, "account_id", return_value="111111111111"): + session.download_data( + path=str(tmp_path), + bucket="jumpstart-cache-prod-us-west-2", + key_prefix="p/f.txt", + ) + + mock_s3_client.list_objects_v2.assert_called_once_with( + Bucket="jumpstart-cache-prod-us-west-2", Prefix="p/f.txt" + ) + assert mock_s3_client.download_file.call_args[1]["ExtraArgs"] is None diff --git a/sagemaker-core/tests/unit/session/test_session_bucket_operations.py b/sagemaker-core/tests/unit/session/test_session_bucket_operations.py index 9b5648b5e5..93af6c585d 100644 --- a/sagemaker-core/tests/unit/session/test_session_bucket_operations.py +++ b/sagemaker-core/tests/unit/session/test_session_bucket_operations.py @@ -222,3 +222,216 @@ def test_default_bucket_sdk_generated_with_owner_check(self, mock_boto_session): assert result == "sagemaker-us-west-2-123456789012" # Should check bucket ownership mock_s3_resource.meta.client.head_bucket.assert_called() + + +class TestGeneralBucketCheckExpectedBucketOwner: + """Test general_bucket_check_if_user_has_permission passes ExpectedBucketOwner + when the SDK generated the default bucket name (bucket-squatting defense). + """ + + @pytest.fixture + def mock_boto_session(self): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_session.resource.return_value = Mock() + return mock_session + + def test_probe_includes_expected_owner_when_sdk_selected_name(self, mock_boto_session): + """head_bucket should carry ExpectedBucketOwner when SDK picked the name.""" + mock_s3_resource = Mock() + mock_s3_resource.meta.client.head_bucket.return_value = None # 200 OK + + session = Session(boto_session=mock_boto_session) + session._default_bucket_set_by_sdk = True + with patch.object(session, "account_id", return_value="123456789012"): + session.general_bucket_check_if_user_has_permission( + "sagemaker-us-west-2-123456789012", + mock_s3_resource, + Mock(), + "us-west-2", + False, + ) + + mock_s3_resource.meta.client.head_bucket.assert_called_once_with( + Bucket="sagemaker-us-west-2-123456789012", + ExpectedBucketOwner="123456789012", + ) + + def test_probe_omits_expected_owner_when_user_selected_name(self, mock_boto_session): + """Probe must NOT pass ExpectedBucketOwner when the user picked the bucket name. + + This keeps legitimate cross-account bucket overrides working. + """ + mock_s3_resource = Mock() + mock_s3_resource.meta.client.head_bucket.return_value = None # 200 OK + + session = Session(boto_session=mock_boto_session) + session._default_bucket_set_by_sdk = False + session.general_bucket_check_if_user_has_permission( + "customer-cross-account-bucket", mock_s3_resource, Mock(), "us-west-2", False + ) + + mock_s3_resource.meta.client.head_bucket.assert_called_once_with( + Bucket="customer-cross-account-bucket" + ) + + def test_list_objects_probe_includes_expected_owner_when_sdk_selected( + self, mock_boto_session + ): + """list_objects_v2 branch (default_bucket_prefix set) passes ExpectedBucketOwner + only when SDK picked the name. + """ + mock_s3_resource = Mock() + mock_s3_resource.meta.client.list_objects_v2.return_value = {} # 200 OK + + session = Session(boto_session=mock_boto_session) + session.default_bucket_prefix = "team-prefix" + session._default_bucket_set_by_sdk = True + with patch.object(session, "account_id", return_value="123456789012"): + session.general_bucket_check_if_user_has_permission( + "sagemaker-us-west-2-123456789012", + mock_s3_resource, + Mock(), + "us-west-2", + False, + ) + + mock_s3_resource.meta.client.list_objects_v2.assert_called_once_with( + Bucket="sagemaker-us-west-2-123456789012", + Prefix="team-prefix", + ExpectedBucketOwner="123456789012", + ) + + def test_squatted_bucket_raises_on_403(self, mock_boto_session): + """If S3 returns 403 Forbidden (bucket owned by another account), + the probe must re-raise instead of silently accepting the bucket. + """ + mock_s3_resource = Mock() + mock_bucket = Mock() + mock_bucket.name = "sagemaker-us-west-2-123456789012" + mock_s3_resource.meta.client.head_bucket.side_effect = ClientError( + {"Error": {"Code": "403", "Message": "Forbidden"}}, "HeadBucket" + ) + + session = Session(boto_session=mock_boto_session) + session._default_bucket_set_by_sdk = True + with patch.object(session, "account_id", return_value="123456789012"): + with pytest.raises(ClientError): + session.general_bucket_check_if_user_has_permission( + "sagemaker-us-west-2-123456789012", + mock_s3_resource, + mock_bucket, + "us-west-2", + True, # bucket_creation_date_none -> enters 403 handling branch + ) + + def test_missing_bucket_still_triggers_creation(self, mock_boto_session): + """404 response path (bucket truly doesn't exist) must still create the bucket.""" + mock_s3_resource = Mock() + mock_bucket = Mock() + mock_bucket.name = "sagemaker-us-west-2-123456789012" + mock_s3_resource.meta.client.head_bucket.side_effect = ClientError( + {"Error": {"Code": "404", "Message": "Not Found"}}, "HeadBucket" + ) + + session = Session(boto_session=mock_boto_session) + session._default_bucket_set_by_sdk = True + + with patch.object(session, "account_id", return_value="123456789012"), patch.object( + session, "create_bucket_for_not_exist_error" + ) as mock_create: + session.general_bucket_check_if_user_has_permission( + "sagemaker-us-west-2-123456789012", + mock_s3_resource, + mock_bucket, + "us-west-2", + True, + ) + + mock_create.assert_called_once_with( + "sagemaker-us-west-2-123456789012", "us-west-2", mock_s3_resource + ) + + +class TestExpectedBucketOwnerIdIfDefaultBucket: + """Test the spot-check helper used by Group B S3 operations.""" + + @pytest.fixture + def mock_boto_session(self): + mock_session = Mock() + mock_session.region_name = "us-west-2" + mock_session.client.return_value = Mock() + mock_session.resource.return_value = Mock() + return mock_session + + def test_returns_none_for_empty_bucket(self, mock_boto_session): + session = Session(boto_session=mock_boto_session) + assert session._get_account_id_if_default_bucket(None) is None + assert session._get_account_id_if_default_bucket("") is None + + def test_returns_account_id_when_bucket_matches_resolved_default( + self, mock_boto_session + ): + session = Session(boto_session=mock_boto_session) + session._default_bucket = "sagemaker-us-west-2-123456789012" + session._default_bucket_set_by_sdk = True + + with patch.object(session, "account_id", return_value="123456789012"): + assert ( + session._get_account_id_if_default_bucket( + "sagemaker-us-west-2-123456789012" + ) + == "123456789012" + ) + + def test_returns_none_when_bucket_matches_user_override(self, mock_boto_session): + """User-overridden default buckets may be cross-account, so the spot check + must NOT fire — only SDK-generated names are vulnerable to squatting. + """ + session = Session(boto_session=mock_boto_session, default_bucket="my-override") + session._default_bucket_name_override = "my-override" + # _default_bucket_set_by_sdk remains False (user chose the name) + + assert session._get_account_id_if_default_bucket("my-override") is None + + def test_returns_none_for_non_default_bucket(self, mock_boto_session): + """Cross-account flows must not trigger the owner check.""" + session = Session(boto_session=mock_boto_session) + session._default_bucket = "sagemaker-us-west-2-123456789012" + session._default_bucket_set_by_sdk = True + + with patch.object(session, "account_id", return_value="123456789012"): + assert ( + session._get_account_id_if_default_bucket( + "jumpstart-cache-prod-us-west-2" + ) + is None + ) + + def test_returns_none_when_default_not_yet_resolved(self, mock_boto_session): + """Helper must be passive - not trigger default_bucket() resolution.""" + session = Session(boto_session=mock_boto_session) + session._default_bucket = None + session._default_bucket_name_override = None + + assert ( + session._get_account_id_if_default_bucket( + "sagemaker-us-west-2-123456789012" + ) + is None + ) + + def test_returns_none_when_account_id_fails(self, mock_boto_session): + """If STS call fails, fall back gracefully rather than block the S3 op.""" + session = Session(boto_session=mock_boto_session) + session._default_bucket = "sagemaker-us-west-2-123456789012" + session._default_bucket_set_by_sdk = True + + with patch.object(session, "account_id", side_effect=Exception("sts failure")): + assert ( + session._get_account_id_if_default_bucket( + "sagemaker-us-west-2-123456789012" + ) + is None + ) diff --git a/sagemaker-core/tests/unit/test_common_utils.py b/sagemaker-core/tests/unit/test_common_utils.py index 8aeb496922..19198e6860 100644 --- a/sagemaker-core/tests/unit/test_common_utils.py +++ b/sagemaker-core/tests/unit/test_common_utils.py @@ -2514,3 +2514,145 @@ def test_create_or_update_code_dir_validates_dependencies(self): None, tmpdir, ) + + +class TestDownloadFileSpotCheck: + """Spot-check behavior in common_utils.download_file.""" + + def test_download_from_default_bucket_includes_expected_owner(self, tmp_path): + from sagemaker.core.common_utils import download_file + + mock_session = Mock() + mock_session.boto_region_name = "us-west-2" + mock_boto_session = Mock() + mock_session.boto_session = mock_boto_session + mock_s3 = Mock() + mock_bucket = Mock() + mock_boto_session.resource.return_value = mock_s3 + mock_s3.Bucket.return_value = mock_bucket + + mock_session._get_account_id_if_default_bucket.return_value = "111111111111" + + download_file( + "sagemaker-us-west-2-111111111111", "k", str(tmp_path / "f"), mock_session + ) + + mock_session._get_account_id_if_default_bucket.assert_called_once_with( + "sagemaker-us-west-2-111111111111" + ) + mock_bucket.download_file.assert_called_once_with( + "k", str(tmp_path / "f"), ExtraArgs={"ExpectedBucketOwner": "111111111111"} + ) + + def test_download_from_non_default_bucket_omits_expected_owner(self, tmp_path): + from sagemaker.core.common_utils import download_file + + mock_session = Mock() + mock_session.boto_region_name = "us-west-2" + mock_boto_session = Mock() + mock_session.boto_session = mock_boto_session + mock_s3 = Mock() + mock_bucket = Mock() + mock_boto_session.resource.return_value = mock_s3 + mock_s3.Bucket.return_value = mock_bucket + + mock_session._get_account_id_if_default_bucket.return_value = None + + download_file("cross-account-bucket", "k", str(tmp_path / "f"), mock_session) + + mock_bucket.download_file.assert_called_once_with( + "k", str(tmp_path / "f"), ExtraArgs=None + ) + + +class TestSaveModelSpotCheck: + """Spot-check behavior in common_utils._save_model.""" + + def test_save_to_default_bucket_includes_expected_owner(self, tmp_path): + from sagemaker.core.common_utils import _save_model + from sagemaker.core.session_settings import SessionSettings + + model_file = tmp_path / "m.tar.gz" + model_file.write_text("x") + + mock_session = Mock() + mock_boto_session = Mock() + mock_session.boto_session = mock_boto_session + mock_session.boto_region_name = "us-west-2" + mock_session.settings = SessionSettings(encrypt_repacked_artifacts=False) + mock_s3 = Mock() + mock_obj = Mock() + mock_boto_session.resource.return_value = mock_s3 + mock_s3.Object.return_value = mock_obj + + mock_session._get_account_id_if_default_bucket.return_value = "111111111111" + + _save_model( + "s3://sagemaker-us-west-2-111111111111/m.tar.gz", + str(model_file), + mock_session, + kms_key=None, + ) + + call_args = mock_obj.upload_file.call_args + assert call_args[1]["ExtraArgs"] == {"ExpectedBucketOwner": "111111111111"} + + def test_save_to_non_default_bucket_omits_expected_owner(self, tmp_path): + from sagemaker.core.common_utils import _save_model + from sagemaker.core.session_settings import SessionSettings + + model_file = tmp_path / "m.tar.gz" + model_file.write_text("x") + + mock_session = Mock() + mock_boto_session = Mock() + mock_session.boto_session = mock_boto_session + mock_session.boto_region_name = "us-west-2" + mock_session.settings = SessionSettings(encrypt_repacked_artifacts=False) + mock_s3 = Mock() + mock_obj = Mock() + mock_boto_session.resource.return_value = mock_s3 + mock_s3.Object.return_value = mock_obj + + mock_session._get_account_id_if_default_bucket.return_value = None + + _save_model( + "s3://marketplace-vendor-bucket/m.tar.gz", + str(model_file), + mock_session, + kms_key=None, + ) + + call_args = mock_obj.upload_file.call_args + assert call_args[1]["ExtraArgs"] is None + + def test_save_to_default_bucket_preserves_kms(self, tmp_path): + from sagemaker.core.common_utils import _save_model + from sagemaker.core.session_settings import SessionSettings + + model_file = tmp_path / "m.tar.gz" + model_file.write_text("x") + + mock_session = Mock() + mock_boto_session = Mock() + mock_session.boto_session = mock_boto_session + mock_session.boto_region_name = "us-west-2" + mock_session.settings = SessionSettings() + mock_s3 = Mock() + mock_obj = Mock() + mock_boto_session.resource.return_value = mock_s3 + mock_s3.Object.return_value = mock_obj + + mock_session._get_account_id_if_default_bucket.return_value = "111111111111" + + _save_model( + "s3://sagemaker-us-west-2-111111111111/m.tar.gz", + str(model_file), + mock_session, + kms_key="kms-key-id", + ) + + merged = mock_obj.upload_file.call_args[1]["ExtraArgs"] + assert merged["ServerSideEncryption"] == "aws:kms" + assert merged["SSEKMSKeyId"] == "kms-key-id" + assert merged["ExpectedBucketOwner"] == "111111111111" diff --git a/sagemaker-core/tests/unit/test_lambda_helper.py b/sagemaker-core/tests/unit/test_lambda_helper.py index cc0a52be1e..885cce7405 100644 --- a/sagemaker-core/tests/unit/test_lambda_helper.py +++ b/sagemaker-core/tests/unit/test_lambda_helper.py @@ -478,7 +478,33 @@ def test_upload_to_s3(self): assert result == "prefix/lambda/my-function/code" mock_s3_client.upload_file.assert_called_once_with( - "/path/to/code.zip", "my-bucket", "prefix/lambda/my-function/code" + "/path/to/code.zip", + "my-bucket", + "prefix/lambda/my-function/code", + ExtraArgs=None, + ) + + def test_upload_to_s3_with_expected_bucket_owner(self): + """When expected_bucket_owner is provided (caller resolved default bucket), + ExtraArgs must carry ExpectedBucketOwner. + """ + mock_s3_client = Mock() + + result = _upload_to_s3( + mock_s3_client, + "my-function", + "/path/to/code.zip", + "sagemaker-us-west-2-111111111111", + "prefix", + expected_bucket_owner="111111111111", + ) + + assert result == "prefix/lambda/my-function/code" + mock_s3_client.upload_file.assert_called_once_with( + "/path/to/code.zip", + "sagemaker-us-west-2-111111111111", + "prefix/lambda/my-function/code", + ExtraArgs={"ExpectedBucketOwner": "111111111111"}, ) def test_zip_lambda_code(self, tmp_path):