Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sagemaker-core/src/sagemaker/core/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,10 @@ def _get_default_spark_image(session):
except ImportError:
pass

# Spark 3.3 and below do not support py312; use 3.5 which supports both py39 and py312
if py_version == "312" and spark_version in ("2.4", "3.0", "3.1", "3.2", "3.3"):
spark_version = "3.5"

image_uri = image_uris.retrieve(
framework=SPARK_NAME,
region=region,
Expand Down
46 changes: 46 additions & 0 deletions sagemaker-core/tests/integ/remote_function/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,52 @@ def spark_test_container(sagemaker_session, sagemaker_sdk_tar_path, tmp_path_fac
)


@pytest.fixture(scope="session")
def spark_pre_execution_commands(sagemaker_session):
    """Build sagemaker-core wheel, upload to S3, and return pre-execution install commands.

    This mirrors the pattern used in sagemaker-mlops feature_processor integ tests.
    The Spark processing image does not have sagemaker-core pre-installed, so we must
    build the local dev wheel and install it in the container via pre_execution_commands.

    Args:
        sagemaker_session: Session used to resolve the default bucket and to
            upload the built wheel.

    Returns:
        list[str]: Shell commands that install awscli, copy the uploaded wheel
        from S3 to ``/tmp`` and pip-install it.

    Raises:
        subprocess.CalledProcessError: If the wheel build exits non-zero.
        FileNotFoundError: If the build produced no ``sagemaker_core-*.whl``.
    """
    import glob
    import subprocess
    import sys
    import tempfile
    from sagemaker.core.s3 import S3Uploader

    repo_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
    )
    core_dir = os.path.join(repo_root, "sagemaker-core")

    with tempfile.TemporaryDirectory() as dist_dir:
        # Use an argument list (no shell) and the interpreter that is running
        # the tests, so the build does not depend on a `python` executable on
        # PATH or on shell quoting of the temporary directory path.
        subprocess.run(
            [sys.executable, "-m", "build", "--wheel", "--outdir", dist_dir],
            cwd=core_dir,
            check=True,
        )
        wheels = glob.glob(os.path.join(dist_dir, "sagemaker_core-*.whl"))
        if not wheels:
            raise FileNotFoundError(f"No sagemaker-core wheel found in {dist_dir}")
        wheel_path = wheels[0]
        wheel_name = os.path.basename(wheel_path)

        s3_prefix = "s3://{}/spark-integ-test/wheels".format(
            sagemaker_session.default_bucket()
        )
        # Upload must happen before the temporary directory is cleaned up.
        S3Uploader.upload(wheel_path, s3_prefix, sagemaker_session=sagemaker_session)

    pip_install = "python3 -m pip install --root-user-action=ignore"
    aws_cli = "python3 -m awscli"
    return [
        f"{pip_install} awscli",
        f"{aws_cli} s3 cp {s3_prefix}/{wheel_name} /tmp/{wheel_name}",
        f"{pip_install} /tmp/{wheel_name}",
    ]


@pytest.fixture(scope="session")
def conda_env_yml():
"""Write conda yml file needed for tests."""
Expand Down
21 changes: 15 additions & 6 deletions sagemaker-core/tests/integ/remote_function/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,16 +574,18 @@ def my_func():
assert client_error_message in str(error)


@pytest.mark.skipif(
sys.version_info[:2] not in [(3, 9), (3, 12)],
reason="SageMaker Spark image only available for Python 3.9 and 3.12",
)
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
# @pytest.mark.skipif(
# sys.version_info[:2] not in [(3, 9), (3, 12)],
# reason="SageMaker Spark image only available for Python 3.9 and 3.12",
# )
@pytest.mark.spark_py312
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type, spark_pre_execution_commands):
@remote(
role=ROLE,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
keep_alive_period_in_seconds=60,
pre_execution_commands=spark_pre_execution_commands,
spark_config=SparkConfig(
configuration=[
{
Expand All @@ -598,7 +600,14 @@ def test_spark_transform():

spark = SparkSession.builder.getOrCreate()

assert spark.conf.get("spark.app.name") == "remote-spark-test"
# Avoid bare assert here: pytest's assertion rewriting injects _pytest
# module references into the function bytecode, which causes
# deserialization to fail in the Spark container (no pytest installed).
app_name = spark.conf.get("spark.app.name")
if app_name != "remote-spark-test":
raise RuntimeError(
f"Expected spark.app.name='remote-spark-test', got '{app_name}'"
)

test_spark_transform()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,11 +798,11 @@ def transform(raw_s3_data_as_df):
# sys.version_info[:2] not in [(3, 9), (3, 12)],
# reason=f"SageMaker Spark image only supports Python 3.9 and 3.12, got {sys.version_info[:2]}",
# )
@pytest.mark.skip(
reason="Lake Formation credential vending (GetTemporaryGlueTableCredentials) requires "
"full LF environment setup (resource registration, trust policy, data location grants) "
"that is not configured in CI. See quip-amazon.com/S3FEAMMMuKm0 for details."
)
# @pytest.mark.skip(
# reason="Lake Formation credential vending (GetTemporaryGlueTableCredentials) requires "
# "full LF environment setup (resource registration, trust policy, data location grants) "
# "that is not configured in CI. See quip-amazon.com/S3FEAMMMuKm0 for details."
# )
@pytest.mark.spark_py312
@pytest.mark.slow_test
def test_to_pipeline_and_execute_with_lake_formation(
Expand Down
Loading