Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/training/SDK/training_sdk_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"source": [
"from sagemaker.hyperpod.training import (\n",
" HyperPodPytorchJob,\n",
" Container,\n",
" Containers,\n",
" ReplicaSpec,\n",
" Resources,\n",
" RunPolicy,\n",
Expand All @@ -57,7 +57,7 @@
" template=Template(\n",
" spec=Spec(\n",
" containers=[\n",
" Container(\n",
" Containers(\n",
" name=\"container-name\",\n",
" image=\"448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist\",\n",
" image_pull_policy=\"Always\",\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List, Dict, Union
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import (
Container,
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
Containers,
ReplicaSpec,
Resources,
RunPolicy,
Expand Down Expand Up @@ -103,7 +103,7 @@ def to_domain(self) -> Dict:
]

# Create container object
container = Container(**container_kwargs)
container = Containers(**container_kwargs)

# Create pod spec kwargs
spec_kwargs = {"containers": list([container])}
Expand Down
5 changes: 1 addition & 4 deletions src/sagemaker/hyperpod/training/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import *
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_status import (
HyperPodPytorchJobStatus,
)
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import *
from sagemaker.hyperpod.training.hyperpod_pytorch_job import (
HyperPodPytorchJob,
_load_hp_job,
Expand Down
2,977 changes: 0 additions & 2,977 deletions src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_config.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,23 @@ class LabelSelector(BaseModel):
)


class NamespaceSelector(BaseModel):
    """A label query over the set of namespaces that the term applies to. The term is applied to the union of the namespaces selected by this field and the ones listed in the namespaces field. null selector and null or empty namespaces list means "this pod's namespace". An empty selector ({}) matches all namespaces."""

    # Unknown keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Both fields are optional; each is declared under its camelCase name
    # and carries a snake_case pydantic alias.
    matchExpressions: Optional[List[MatchExpressions]] = Field(
        description=(
            "matchExpressions is a list of label selector requirements. "
            "The requirements are ANDed."
        ),
        alias="match_expressions",
        default=None,
    )
    matchLabels: Optional[Dict[str, str]] = Field(
        description=(
            'matchLabels is a map of {key,value} pairs. '
            'A single {key,value} in the matchLabels map is equivalent to an '
            'element of matchExpressions, whose key field is "key", the operator '
            'is "In", and the values array contains only "value". '
            'The requirements are ANDed.'
        ),
        alias="match_labels",
        default=None,
    )


class TopologySpreadConstraints(BaseModel):
"""TopologySpreadConstraint specifies how to spread matching pods among the given topology."""

Expand Down Expand Up @@ -2955,6 +2972,134 @@ class Template(BaseModel):
)


class ReplicaSpec(BaseModel):
    """ReplicaSpec is a description of the replica"""

    # Unknown keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # The only required field: no default is supplied.
    name: str = Field(description="The name for the replica set")

    # Defaults to a single replica.
    replicas: Optional[int] = Field(
        description=(
            "Replicas is the desired number of replicas of the given "
            "template."
        ),
        default=1,
    )

    # Defaults to no spare capacity.
    spares: Optional[int] = Field(
        description=(
            "Spares requests spare resources from Kueue. E.g. If a job is "
            "configured with 4 replicas and 2 spares, job requests resources "
            "required to run 6 pods such as cpu, gpu"
        ),
        default=0,
    )

    template: Optional[Template] = Field(
        description=(
            "Template is the object that describes the pod that will be created "
            "for this replica."
        ),
        default=None,
    )


class LogMonitoringConfiguration(BaseModel):
    """LogMonitoringRule defines the criteria used to detect a SLOW or HANGING job"""

    # Unknown keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # `logPattern` and `name` are the two required fields (no default);
    # everything else is optional and defaults to None.
    expectedRecurringFrequencyInSeconds: Optional[int] = Field(
        description=(
            "Time interval between two subsequent matches for LogPattern beyond "
            "which, the rule evaluates to HANGING. When not specified, there is "
            "no constraint on duration between two subsequent matches for "
            "LogPattern."
        ),
        alias="expected_recurring_frequency_in_seconds",
        default=None,
    )
    expectedStartCutOffInSeconds: Optional[int] = Field(
        description=(
            "Time to first match for LogPattern beyond which, the rule evaluates "
            "to HANGING. When not specified, there is no constraint on time to "
            "first match for LogPattern."
        ),
        alias="expected_start_cut_off_in_seconds",
        default=None,
    )
    logPattern: str = Field(
        description=(
            "Regex to identify log lines to apply the rule to when the rule is "
            "active. This regex can optionally include one capturing group to "
            "extract a numeric metric value."
        ),
        alias="log_pattern",
    )
    metricEvaluationDataPoints: Optional[int] = Field(
        description=(
            "The number of consecutive times that a rule must evaluate to SLOW "
            "in order to mark a job as SLOW. When not specified, the default "
            "value is 1."
        ),
        alias="metric_evaluation_data_points",
        default=None,
    )
    metricThreshold: Optional[int] = Field(
        description=(
            "Threshold for value extracted by LogPattern if it has a capturing "
            "group. When not specified, Metric evaluation will not be performed."
        ),
        alias="metric_threshold",
        default=None,
    )
    name: str = Field(description="Name of the rule")
    operator: Optional[str] = Field(
        description=(
            "Operator to compare the value extracted by LogPattern to "
            "MetricThreshold. Rule evaluates to SLOW if value extracted by "
            "LogPattern compared to MetricThreshold using Operator evaluates to "
            "true. When not specified, Metric evaluation will not be performed. "
            "Following operator values are supported: gt (greater than) "
            "lt (lesser than) eq (equal to) gteq (greater than or equal to) "
            "lteq (less than or equal to)"
        ),
        default=None,
    )
    stopPattern: Optional[str] = Field(
        description=(
            "Regex to identify the log line at which to deactivate the rule. "
            "When not specified, the rule will always be active."
        ),
        alias="stop_pattern",
        default=None,
    )


class RestartPolicy(BaseModel):
    """Additional restart limiting configurations"""

    # Unknown keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Required: the evaluation window has no default.
    evalPeriodSeconds: int = Field(
        description="The period of evaluating the restart limit in seconds",
        alias="eval_period_seconds",
    )

    # The two restart counters are optional and default to None.
    maxFullJobRestarts: Optional[int] = Field(
        description=(
            "The max allowed number of full job restarts before failing the job"
        ),
        alias="max_full_job_restarts",
        default=None,
    )
    numRestartBeforeFullJobRestart: Optional[int] = Field(
        description="The number of standard restarts before a full job restart",
        alias="num_restart_before_full_job_restart",
        default=None,
    )


class RunPolicy(BaseModel):
    """RunPolicy"""

    # Unknown keys are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Every field is optional. Only `cleanPodPolicy` ("All") and
    # `ttlSecondsAfterFinished` (0) have a non-None default.
    activeDeadlineSeconds: Optional[int] = Field(
        default=None,
        alias="active_deadline_seconds",
        description=(
            "Specifies the duration in seconds relative to the startTime that "
            "the job may be active before the system tries to terminate it; "
            "value must be positive integer."
        ),
    )
    cleanPodPolicy: Optional[str] = Field(
        default="All",
        alias="clean_pod_policy",
        description=(
            "CleanPodPolicy defines the policy to kill pods after the job "
            "completes."
        ),
    )
    faultDeadlineSeconds: Optional[int] = Field(
        default=None,
        alias="fault_deadline_seconds",
        description=(
            "The limit on the fault time for the job (Status of Fault) before "
            "failing"
        ),
    )
    # No description is supplied for this field.
    jobMaxRetryCount: Optional[int] = Field(default=None, alias="job_max_retry_count")
    logMonitoringConfiguration: Optional[List[LogMonitoringConfiguration]] = Field(
        default=None,
        alias="log_monitoring_configuration",
        description=(
            "LogMonitoringConfiguration defines the log monitoring rules for "
            "SLOW and HANGING job detection"
        ),
    )
    restartPolicy: Optional[RestartPolicy] = Field(
        default=None,
        alias="restart_policy",
        description="Additional restart limiting configurations",
    )
    startupDeadlineSeconds: Optional[int] = Field(
        default=None,
        alias="startup_deadline_seconds",
        # Description typo fixed: "Staring" -> "Starting".
        description=(
            "The limit on the startup time for the job (Status of Starting) "
            "before failing"
        ),
    )
    suspend: Optional[bool] = Field(
        default=None,
        description="Suspend suspends HyperPodPytorchJob when set to true",
    )
    ttlSecondsAfterFinished: Optional[int] = Field(
        default=0,
        alias="ttl_seconds_after_finished",
        description=(
            "TTLSecondsAfterFinished is the TTL to clean up jobs. Set to -1 for "
            "infinite"
        ),
    )


class PodSets(BaseModel):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -3081,3 +3226,23 @@ class HyperPodPytorchJobStatus(BaseModel):
alias="start_time",
description="The time when job is first acknowledged by the controller. When using kueue, the job is also admitted It is represented in RFC3339 form and is in UTC.",
)


class _HyperPodPytorchJob(BaseModel):
    """Config defines the desired state of HyperPodPytorchJob"""

    # extra="ignore": unknown keys are dropped instead of raising.
    model_config = ConfigDict(extra="ignore")

    # Annotated as str although the description lists "int" as a supported
    # value; presumably a numeric count is passed as a string such as "8"
    # -- TODO confirm against callers.
    nprocPerNode: str = Field(
        description=(
            "Number of workers per node; supported values: [auto, cpu, gpu, int]."
        ),
        alias="nproc_per_node",
        default="auto",
    )
    replicaSpecs: Optional[List[ReplicaSpec]] = Field(
        description="The replicas to include as part of the job",
        alias="replica_specs",
        default=None,
    )
    runPolicy: Optional[RunPolicy] = Field(
        description="RunPolicy",
        alias="run_policy",
        default=None,
    )
7 changes: 2 additions & 5 deletions src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from pydantic import ConfigDict, Field
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import (
_HyperPodPytorchJob,
)
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_status import (
HyperPodPytorchJobStatus,
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
_HyperPodPytorchJob, HyperPodPytorchJobStatus
)
from sagemaker.hyperpod.common.config.metadata import Metadata
from kubernetes import client, config
Expand Down
6 changes: 3 additions & 3 deletions test/unit_tests/training/test_hyperpod_pytorch_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sagemaker.hyperpod.training import (
HyperPodPytorchJob,
HyperPodPytorchJobStatus,
Container,
Containers,
ReplicaSpec,
Resources,
RunPolicy,
Expand All @@ -27,7 +27,7 @@ def setUp(self):
template=Template(
spec=Spec(
containers=[
Container(
Containers(
name="test-container",
image="test-image",
resources=Resources(
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_get_success(self, mock_load_job, mock_custom_api, mock_verify_config):
template=Template(
spec=Spec(
containers=[
Container(
Containers(
name="test-container",
image="test-image",
resources=Resources(
Expand Down
Loading