Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,21 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
if ".." in prefix:
raise ValueError("Traversal components are not allowed in S3 path!")

# Spot check: enforce ownership only when downloading from the session's default
# bucket. Cross-account buckets are left untouched.
expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket_name)
extra_args = None
if expected_owner:
extra_args = {"ExpectedBucketOwner": expected_owner}

# Try to download the prefix as an object first, in case it is a file and not a 'directory'.
# Do this first, in case the object has broader permissions than the bucket.
if not prefix.endswith("/"):
try:
file_destination = os.path.join(target, os.path.basename(prefix))
s3.Object(bucket_name, prefix).download_file(file_destination)
s3.Object(bucket_name, prefix).download_file(
file_destination, ExtraArgs=extra_args
)
return
except botocore.exceptions.ClientError as e:
err_info = e.response["Error"]
Expand All @@ -445,17 +454,19 @@ def download_folder(bucket_name, prefix, target, sagemaker_session):
else:
raise

_download_files_under_prefix(bucket_name, prefix, target, s3)
_download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=extra_args)


def _download_files_under_prefix(bucket_name, prefix, target, s3):
def _download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=None):
"""Download all S3 files which match the given prefix

Args:
bucket_name (str): S3 bucket name
prefix (str): S3 prefix within the bucket that will be downloaded
target (str): destination path where the downloaded items will be placed
s3 (boto3.resources.base.ServiceResource): S3 resource
extra_args (dict): Optional extra arguments passed to each download_file call.
Used to carry ExpectedBucketOwner when the bucket is the session's default.
"""
bucket = s3.Bucket(bucket_name)
for obj_sum in bucket.objects.filter(Prefix=prefix):
Expand All @@ -473,7 +484,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3):
# anything else will be raised.
if exc.errno != errno.EEXIST:
raise
obj.download_file(file_path)
obj.download_file(file_path, ExtraArgs=extra_args)


def create_tar_file(source_files, target=None):
Expand Down Expand Up @@ -620,6 +631,16 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
extra_args = {"ServerSideEncryption": "aws:kms"}
else:
extra_args = None

# Spot check: when the model is being uploaded to the session's default bucket,
# assert ownership to defend against bucket-squatting on the predictable default
# name. Other caller-supplied buckets are left untouched.
if sagemaker_session is not None:
expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket)
if expected_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_owner

sagemaker_session.boto_session.resource(
"s3", region_name=sagemaker_session.boto_region_name
).Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
Expand Down Expand Up @@ -767,7 +788,17 @@ def download_file(bucket_name, path, target, sagemaker_session):

s3 = boto_session.resource("s3", region_name=sagemaker_session.boto_region_name)
bucket = s3.Bucket(bucket_name)
bucket.download_file(path, target)

# Spot check: assert ownership only when downloading from the session's default
# bucket. Non-default buckets (e.g. caller-supplied model URIs pointing at shared
# or cross-account data) are downloaded without ExpectedBucketOwner to preserve
# legitimate cross-account flows.
expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket_name)
extra_args = None
if expected_owner:
extra_args = {"ExpectedBucketOwner": expected_owner}

bucket.download_file(path, target, ExtraArgs=extra_args)


def sts_regional_endpoint(region):
Expand Down
35 changes: 32 additions & 3 deletions sagemaker-core/src/sagemaker/core/experiments/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def upload_artifact(self, file_path, extra_args=None):
artifact_s3_key = "{}/{}/{}".format(
self.artifact_prefix, self.trial_component_name, artifact_name
)

# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_owner

self._s3_client.upload_file(
file_path,
self.artifact_bucket,
Expand Down Expand Up @@ -133,9 +143,21 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension=
artifact_s3_key = "{}/{}/{}".format(
self.artifact_prefix, self.trial_component_name, artifact_name
)
self._s3_client.put_object(
Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key

# Spot check: enforce ownership only when uploading to the session's default
# bucket. Cross-account destinations are left untouched.
put_kwargs = {
"Body": json.dumps(artifact_object),
"Bucket": self.artifact_bucket,
"Key": artifact_s3_key,
}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
put_kwargs["ExpectedBucketOwner"] = expected_owner

self._s3_client.put_object(**put_kwargs)
etag = self._try_get_etag(artifact_s3_key)
return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag

Expand All @@ -149,7 +171,14 @@ def _try_get_etag(self, key):
str: The S3 object ETag if it allows, otherwise return None.
"""
try:
response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
head_kwargs = {"Bucket": self.artifact_bucket, "Key": key}
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
self.artifact_bucket
)
if expected_owner:
head_kwargs["ExpectedBucketOwner"] = expected_owner

response = self._s3_client.head_object(**head_kwargs)
return response["ETag"]
except botocore.exceptions.ClientError as error:
# requires read permissions
Expand Down
11 changes: 11 additions & 0 deletions sagemaker-core/src/sagemaker/core/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def tar_and_upload_dir(
kms_key=None,
s3_resource=None,
settings: Optional[SessionSettings] = None,
expected_bucket_owner: Optional[str] = None,
) -> UploadedCode:
"""Package source files and upload a compress tar file to S3.

Expand Down Expand Up @@ -430,6 +431,12 @@ def tar_and_upload_dir(
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
of the SageMaker ``Session``, can be used to override the default encryption
behavior (default: None).
expected_bucket_owner (str): Optional. AWS account id passed as
``ExpectedBucketOwner`` on the upload. Callers should supply this when
``bucket`` is the session's default bucket (via
``Session._get_account_id_if_default_bucket``) to defend against
bucket-squatting on the predictable default name. Leave as ``None`` for
cross-account destination buckets.
Returns:
sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and
script name.
Expand Down Expand Up @@ -471,6 +478,10 @@ def tar_and_upload_dir(
else:
extra_args = None

if expected_bucket_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_bucket_owner

if s3_resource is None:
s3_resource = session.resource("s3", region_name=session.region_name)
else:
Expand Down
126 changes: 113 additions & 13 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,14 @@ def upload_data(self, path, bucket=None, key_prefix="data", callback=None, extra
bucket=bucket, key_prefix=key_prefix, sagemaker_session=self
)

# Spot check: if the resolved bucket is the session's default bucket, enforce
# ownership on the upload to defend against bucket-squatting on the predictable
# default name. Other buckets are left untouched to preserve cross-account flows.
expected_owner = self._get_account_id_if_default_bucket(bucket)
if expected_owner:
extra_args = dict(extra_args) if extra_args else {}
extra_args["ExpectedBucketOwner"] = expected_owner

# Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
files = []
key_suffix = None
Expand Down Expand Up @@ -484,10 +492,17 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):

s3_object = s3.Object(bucket_name=bucket, key=key)

# Spot check: enforce ownership only when writing to the session's default
# bucket. Cross-account destinations are left untouched.
put_kwargs = {"Body": body}
if kms_key is not None:
s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
else:
s3_object.put(Body=body)
put_kwargs["SSEKMSKeyId"] = kms_key
put_kwargs["ServerSideEncryption"] = "aws:kms"
expected_owner = self._get_account_id_if_default_bucket(bucket)
if expected_owner:
put_kwargs["ExpectedBucketOwner"] = expected_owner

s3_object.put(**put_kwargs)

s3_uri = "s3://{}/{}".format(bucket, key)
return s3_uri
Expand All @@ -511,11 +526,19 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
else:
s3 = self.s3_client

# Spot check: if the caller-supplied bucket is the session's default bucket,
# assert ownership on the list/download calls to defend against squatting on
# the predictable default bucket name. Non-default buckets are left untouched
# to preserve legitimate cross-account flows.
expected_owner = self._get_account_id_if_default_bucket(bucket)

# Initialize the variables used to loop through the contents of the S3 bucket.
keys = []
directories = []
next_token = ""
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
if expected_owner:
base_parameters["ExpectedBucketOwner"] = expected_owner

# Loop through the contents of the bucket, 1,000 objects at a time. Gathering all keys into
# a "keys" list.
Expand All @@ -542,6 +565,9 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):

# For each object key, create the directory on the local machine if needed, and then
# download the file.
download_extra_args = dict(extra_args) if extra_args else {}
if expected_owner:
download_extra_args["ExpectedBucketOwner"] = expected_owner
downloaded_paths = []
for dir_path in directories:
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
Expand All @@ -553,7 +579,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
if not os.path.exists(os.path.dirname(destination_path)):
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
s3.download_file(
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
Bucket=bucket,
Key=key,
Filename=destination_path,
ExtraArgs=download_extra_args or None,
)
downloaded_paths.append(destination_path)
return downloaded_paths
Expand All @@ -573,8 +602,16 @@ def read_s3_file(self, bucket, key_prefix):
else:
s3 = self.s3_client

# Spot check: assert ownership only when the caller is reading from the
# session's default bucket. Other buckets (e.g. JumpStart content buckets)
# are read without ExpectedBucketOwner to preserve cross-account flows.
get_kwargs = {"Bucket": bucket, "Key": key_prefix}
expected_owner = self._get_account_id_if_default_bucket(bucket)
if expected_owner:
get_kwargs["ExpectedBucketOwner"] = expected_owner

# Explicitly passing a None kms_key to boto3 throws a validation error.
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
s3_object = s3.get_object(**get_kwargs)

return s3_object["Body"].read().decode("utf-8")

Expand Down Expand Up @@ -699,32 +736,50 @@ def general_bucket_check_if_user_has_permission(
If there is any other error that comes up with calling head bucket, it is raised up here.
If there is no bucket, it will create one.

When the SDK selected the bucket name itself (``_default_bucket_set_by_sdk`` is True),
the probe is issued with ``ExpectedBucketOwner`` set to the caller's account id so that
S3 returns ``403`` if a bucket with the SDK-chosen name happens to exist in another
account (bucket-squatting defense). For user-overridden bucket names the probe is
issued without ``ExpectedBucketOwner`` to preserve legitimate cross-account usage.

Args:
bucket_name (str): Name of the S3 bucket
s3 (boto3.resources.base.ServiceResource): S3 resource from the boto session
region (str): The region in which to create the bucket.
bucket_creation_date_none (bool): Indicating whether S3 bucket already exists or not
"""
extra_args = {}
if self._default_bucket_set_by_sdk:
extra_args["ExpectedBucketOwner"] = self.account_id()

try:
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name, Prefix=self.default_bucket_prefix
Bucket=bucket_name, Prefix=self.default_bucket_prefix, **extra_args
)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
s3.meta.client.head_bucket(Bucket=bucket_name, **extra_args)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
# bucket does not exist or forbidden to access
# bucket does not exist, is owned by another account, or is forbidden to access
if bucket_creation_date_none:
if error_code == "404" and message == "Not Found":
self.create_bucket_for_not_exist_error(bucket_name, region, s3)
elif error_code == "403" and message == "Forbidden":
LOGGER.error(
"Bucket %s exists, but access is forbidden. Please try again after "
"adding appropriate access.",
bucket.name,
)
if self._default_bucket_set_by_sdk:
LOGGER.error(
"Bucket %s is not accessible as the default bucket. It may be "
"owned by another account or access is forbidden. To unblock, "
"pass a custom default_bucket parameter to sagemaker.Session.",
bucket.name,
)
else:
LOGGER.error(
"Bucket %s exists, but access is forbidden. Please try again "
"after adding appropriate access.",
bucket.name,
)
raise
else:
raise
Expand Down Expand Up @@ -773,6 +828,51 @@ def generate_default_sagemaker_bucket_name(self, boto_session):
).get_caller_identity()["Account"]
return "sagemaker-{}-{}".format(region, account)

def _get_account_id_if_default_bucket(self, bucket):
    """Resolve the ``ExpectedBucketOwner`` value for ``bucket``, if one applies.

    S3 helpers that are handed a bucket name call this to decide whether to
    assert bucket ownership. An account id is returned only when all of the
    following hold:

    * ``bucket`` is non-empty,
    * the SDK generated the default bucket name itself
      (``_default_bucket_set_by_sdk`` is truthy), and
    * ``bucket`` equals the already-resolved ``_default_bucket``.

    In every other case ``None`` is returned and the caller should issue its
    S3 request without ``ExpectedBucketOwner``. That covers user-overridden
    default buckets, which may legitimately live in another account, and any
    other caller-supplied bucket.

    The lookup is passive: it only reads ``_default_bucket`` and never
    resolves or creates the default bucket. While ``_default_bucket`` is
    still unset the result is ``None``.

    Args:
        bucket (str): The bucket name the caller is about to use.

    Returns:
        Optional[str]: The caller's account id when the ownership check
            applies, otherwise ``None``. ``None`` is also returned when the
            account id lookup itself fails.
    """
    # Nothing to check for an empty name, and a user-chosen default bucket
    # must never have caller-account ownership asserted on it.
    if not (bucket and self._default_bucket_set_by_sdk):
        return None

    # Compare against the cached default only; no resolution side effects.
    resolved_default = self._default_bucket
    is_sdk_default = bool(resolved_default) and bucket == resolved_default
    if not is_sdk_default:
        return None

    try:
        owner_account = self.account_id()
    except Exception:  # pylint: disable=broad-except
        # account_id() issues an STS call; a failure there downgrades to
        # "no ownership check" instead of blocking the S3 operation.
        LOGGER.warning(
            "Could not resolve caller account id for ExpectedBucketOwner check "
            "on bucket %s; proceeding without the check.",
            bucket,
        )
        owner_account = None
    return owner_account

def determine_bucket_and_prefix(
self, bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
Expand Down
Loading
Loading