Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions sagemaker-core/resource_plan.csv

Large diffs are not rendered by default.

13,351 changes: 3,227 additions & 10,124 deletions sagemaker-core/sample/sagemaker/2017-07-24/service-2.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sagemaker-core/src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Namespace package for SageMaker."""
__path__ = __import__('pkgutil').extend_path(__path__, __name__)

__path__ = __import__("pkgutil").extend_path(__path__, __name__)
1 change: 1 addition & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SageMaker Core package for low-level resource management and SDK foundations."""

from sagemaker.core.utils.utils import enable_textual_rich_console_and_traceback


Expand Down
2 changes: 1 addition & 1 deletion sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def _create_or_update_code_dir(
"""Placeholder docstring"""
code_dir = os.path.join(model_dir, "code")
resolved_code_dir = _get_resolved_path(code_dir)

# Validate that code_dir does not resolve to a sensitive system path
for sensitive_path in _SENSITIVE_SYSTEM_PATHS:
if resolved_code_dir != "/" and resolved_code_dir.startswith(sensitive_path):
Expand Down
436 changes: 33 additions & 403 deletions sagemaker-core/src/sagemaker/core/config_schema.py

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions sagemaker-core/src/sagemaker/core/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,7 @@ def tar_and_upload_dir(

try:
source_files = _list_files_to_compress(script, directory) + dependencies
tar_file = utils.create_tar_file(
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
)
tar_file = utils.create_tar_file(source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME))

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
Expand Down Expand Up @@ -1208,7 +1206,7 @@ def create_image_uri(
the image uri
"""
from sagemaker.core import image_uris

renamed_warning("The method create_image_uri")
return image_uris.retrieve(
framework=framework,
Expand Down
2 changes: 2 additions & 0 deletions sagemaker-core/src/sagemaker/core/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pathlib import Path
from urllib.parse import urlparse


def _sanitize_git_url(repo_url):
"""Sanitize Git repository URL to prevent URL injection attacks.

Expand Down Expand Up @@ -84,6 +85,7 @@ def _sanitize_git_url(repo_url):

return repo_url


def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
"""Git clone repo containing the training code and serving code.

Expand Down
9 changes: 6 additions & 3 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def _initialize(
self.sagemaker_client = sagemaker_client
else:
from sagemaker.core.user_agent import get_user_agent_extra_suffix

config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix())
self.sagemaker_client = self.boto_session.client("sagemaker", config=config)

Expand Down Expand Up @@ -1877,7 +1878,7 @@ def expand_role(self, role):
if "/" in role:
return role
return self.boto_session.resource("iam").Role(role).arn


def _expand_container_def(c_def):
"""Placeholder docstring"""
Expand Down Expand Up @@ -2714,7 +2715,9 @@ def _live_logging_deploy_done(sagemaker_client, endpoint_name, paginator, pagina
if endpoint_status != "Creating":
stop = True
if endpoint_status == "InService":
LOGGER.info("Created endpoint with name %s. Waiting for it to be InService", endpoint_name)
LOGGER.info(
"Created endpoint with name %s. Waiting for it to be InService", endpoint_name
)
else:
time.sleep(poll)

Expand Down Expand Up @@ -2974,4 +2977,4 @@ def container_def(
c_def["Mode"] = container_mode
if image_config:
c_def["ImageConfig"] = image_config
return c_def
return c_def
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty

def config_for_framework(framework):
"""Loads the JSON config for the given framework."""
fname = os.path.join(os.path.dirname(__file__), "..", "image_uri_config", "{}.json".format(framework))
fname = os.path.join(
os.path.dirname(__file__), "..", "image_uri_config", "{}.json".format(framework)
)
with open(fname) as f:
return json.load(f)

Expand Down
4 changes: 2 additions & 2 deletions sagemaker-core/src/sagemaker/core/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __next__(self):
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue

# Check buffer size before writing to prevent unbounded memory consumption
chunk_size = len(chunk["PayloadPart"]["Bytes"])
current_size = self.buffer.getbuffer().nbytes
Expand All @@ -192,6 +192,6 @@ def __next__(self):
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
f"No newline found in stream."
)

self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
4 changes: 2 additions & 2 deletions sagemaker-core/src/sagemaker/core/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ def __init__(self, root_path):
super(LocalFileDataSource, self).__init__()

self.root_path = os.path.abspath(root_path)

# Validate that the path is not in restricted locations
for restricted_path in _SENSITIVE_SYSTEM_PATHS:
if self.root_path != "/" and self.root_path.startswith(restricted_path):
raise ValueError(
f"Local Mode does not support mounting from restricted system paths. "
f"Got: {root_path}"
)

if not os.path.exists(self.root_path):
raise RuntimeError("Invalid data source: %s does not exist." % self.root_path)

Expand Down
4 changes: 2 additions & 2 deletions sagemaker-core/src/sagemaker/core/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def get_child_process_ids(pid):
"""
if not str(pid).isdigit():
raise ValueError("Invalid PID")

cmd = ["pgrep", "-P", str(pid)]

process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = process.communicate()
if err:
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-core/src/sagemaker/core/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def get_model_package_args(
if model_card is not None:
original_req = {}
if isinstance(model_card, ModelPackageModelCard):
original_req["ModelCardContent"] = model_card.model_card_content
original_req["ModelCardContent"] = model_card.model_card_content
else:
original_req["ModelCardContent"] = model_card.content
original_req["ModelCardContent"] = model_card.content
original_req["ModelCardStatus"] = model_card.model_card_status
model_package_args["model_card"] = original_req
return model_package_args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def _rmtree(path, image=None, is_studio=False):
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
path,
path,
)
raise
try:
Expand All @@ -82,7 +83,8 @@ def _rmtree(path, image=None, is_studio=False):
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
path,
path,
)
raise

Expand Down
5 changes: 4 additions & 1 deletion sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,10 @@ def _normalize_outputs(self, outputs=None):
# If the output's s3_uri is not an s3_uri, create one.
parse_result = urlparse(output.s3_output.s3_uri)
if parse_result.scheme != "s3":
if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file":
if (
getattr(self.sagemaker_session, "local_mode", False)
and parse_result.scheme == "file"
):
normalized_outputs.append(output)
continue
if _pipeline_config:
Expand Down
15 changes: 8 additions & 7 deletions sagemaker-core/src/sagemaker/core/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def remote(
"""

def _remote(func):

if job_conda_env:
RemoteExecutor._validate_env_name(job_conda_env)

Expand Down Expand Up @@ -775,7 +775,7 @@ def __init__(
+ "without spark_config or use_torchrun or use_mpirun. "
+ "Please provide instance_count = 1"
)

if job_conda_env:
self._validate_env_name(job_conda_env)

Expand Down Expand Up @@ -955,21 +955,22 @@ def _validate_submit_args(func, *args, **kwargs):
+ f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: "
+ f"{missing_kwargs_string}"
)

@staticmethod
def _validate_env_name(env_name: str) -> None:
"""Validate conda environment name to prevent command injection.

Args:
env_name (str): The environment name to validate

Raises:
ValueError: If the environment name contains invalid characters
"""

# Allow only alphanumeric, underscore, and hyphen
import re
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):

if not re.match(r"^[a-zA-Z0-9_-]+$", env_name):
raise ValueError(
f"Invalid environment name '{env_name}'. "
"Only alphanumeric characters, underscores, and hyphens are allowed."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,7 @@ def json_serialize_obj_to_s3(
)


def deserialize_obj_from_s3(
sagemaker_session: Session, s3_uri: str, verification_key=None
) -> Any:
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, verification_key=None) -> Any:
"""Downloads from S3 and then deserializes data objects.

Called from both job (verifying client-uploaded args) and client (verifying
Expand Down Expand Up @@ -394,6 +392,7 @@ def _upload_payload_and_metadata_to_s3(
sagemaker_session,
)


def _upload_payload_and_metadata_to_s3_signed(
bytes_to_upload: Union[bytes, io.BytesIO],
private_key: ec.EllipticCurvePrivateKey,
Expand Down Expand Up @@ -457,7 +456,6 @@ def _upload_payload_and_metadata_to_s3_hashed(
)



def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
"""Downloads from S3 and then deserializes exception with plain SHA-256 verification.

Expand Down
8 changes: 6 additions & 2 deletions sagemaker-core/src/sagemaker/core/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,9 @@ def _get_default_spark_image(session):
class _Job:
"""Helper class that interacts with the SageMaker training service."""

def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, verification_key: str):
def __init__(
self, job_name: str, s3_uri: str, sagemaker_session: Session, verification_key: str
):
"""Initialize a _Job object.

Args:
Expand Down Expand Up @@ -870,7 +872,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
"""
job_name = describe_training_job_response["TrainingJobName"]
s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
verification_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"]
verification_key = describe_training_job_response["Environment"][
"REMOTE_FUNCTION_SECRET_KEY"
]

job = _Job(job_name, s3_uri, sagemaker_session, verification_key)
job._last_describe_response = describe_training_job_response
Expand Down
Loading
Loading