diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 68dc828f7..84764406f 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -281,16 +281,38 @@ def _format_job_name( show_deployment_num: bool, show_replica: bool, show_job: bool, + group_index: Optional[int] = None, + last_shown_group_index: Optional[int] = None, ) -> str: name_parts = [] + prefix = "" if show_replica: - name_parts.append(f"replica={job.job_spec.replica_num}") + # Show group information if replica groups are used + if group_index is not None: + # Show group=X replica=Y when group changes, or just replica=Y when same group + if group_index != last_shown_group_index: + # First job in group: use 3 spaces indent + prefix = " " + name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}") + else: + # Subsequent job in same group: align "replica=" with first job's "replica=" + # Calculate padding: width of " group={last_shown_group_index} " + padding_width = 3 + len(f"group={last_shown_group_index}") + 1 + prefix = " " * padding_width + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + # Legacy behavior: no replica groups + prefix = " " + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + prefix = " " + if show_job: name_parts.append(f"job={job.job_spec.job_num}") name_suffix = ( f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else "" ) - name_value = " " + (" ".join(name_parts) if name_parts else "") + name_value = prefix + (" ".join(name_parts) if name_parts else "") name_value += name_suffix return name_value @@ -359,6 +381,14 @@ def get_runs_table( ) merge_job_rows = len(run.jobs) == 1 and not show_deployment_num + # Replica Group Changes: Build mapping from replica group names to indices + group_name_to_index: Dict[str, int] = {} + # Replica Group Changes: Check if replicas attribute exists (only available for ServiceConfiguration) + replicas = getattr(run.run_spec.configuration, "replicas", None) + if replicas: + for idx, group in enumerate(replicas): + group_name_to_index[group.name] = idx + run_row: Dict[Union[str, int], Any] = { "NAME": _format_run_name(run, show_deployment_num), "SUBMITTED": format_date(run.submitted_at), @@ -372,13 +402,35 @@ def get_runs_table( if not merge_job_rows: add_row_from_dict(table, run_row) - for job in run.jobs: + # Sort jobs by group index first, then by replica_num within each group + def get_job_sort_key(job: Job) -> tuple: + group_index = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + # Use a large number for jobs without groups to put them at the end + return (group_index if group_index is not None else 999999, job.job_spec.replica_num) + + sorted_jobs = sorted(run.jobs, key=get_job_sort_key) + + last_shown_group_index: Optional[int] = None + for job in sorted_jobs: latest_job_submission = job.job_submissions[-1] status_formatted = _format_job_submission_status(latest_job_submission, verbose) + # Get group index for this job + group_index: Optional[int] = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + job_row: Dict[Union[str, int], Any] = { "NAME": _format_job_name( - job, latest_job_submission, show_deployment_num, show_replica, show_job + job, + latest_job_submission, + show_deployment_num, + show_replica, + show_job, + group_index=group_index, + 
last_shown_group_index=last_shown_group_index, ), "STATUS": status_formatted, "PROBES": _format_job_probes( @@ -390,6 +442,9 @@ def get_runs_table( "GPU": "-", "PRICE": "-", } + # Update last shown group index for next iteration + if group_index is not None: + last_shown_group_index = group_index jpd = latest_job_submission.job_provisioning_data if jpd is not None: shared_offer: Optional[InstanceOfferWithAvailability] = None diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 158c59b34..6abdda1a4 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel): @root_validator def check_image_or_commands_present(cls, values): + # If replicas is present, skip validation - commands come from replica groups + replica_groups = values.get("replicas") + if replica_groups: + return values + if not values.get("commands") and not values.get("image"): raise ValueError("Either `commands` or `image` must be set") return values @@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]): ) +class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel): + name: Annotated[ + str, + Field(description="The name of the replica group"), + ] + replicas: Annotated[ + Range[int], + Field( + description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " + "If it's a range, the `scaling` property is required" + ), + ] + scaling: Annotated[ + Optional[ScalingSpec], + Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), + ] = None + probes: Annotated[ + list[ProbeConfig], + Field(description="List of probes used to determine job health for this replica group"), + ] = [] + rate_limits: Annotated[ + list[RateLimit], + Field(description="Rate limiting rules for this replica group"), + ] = [] + # TODO: Extract to ConfigurationWithResourcesParams mixin + resources: Annotated[ + ResourcesSpec, + Field(description="The resources requirements for replicas in this group"), + ] = ResourcesSpec() + + @validator("replicas") + def convert_replicas(cls, v: Range[int]) -> Range[int]: + if v.max is None: + raise ValueError("The maximum number of replicas is required") + if v.min is None: + v.min = 0 + if v.min < 0: + raise ValueError("The minimum number of replicas must be greater than or equal to 0") + return v + + @root_validator() + def override_commands_validation(cls, values): + """ + Override parent validator from ConfigurationWithCommandsParams. + ReplicaGroup always requires commands (no image option). 
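+        The inherited `check_image_or_commands_present` validator returns early
+        whenever `replicas` is present in values (true for every ReplicaGroup,
+        since `replicas` is a required field here), so this validator restores
+        the `commands` requirement for groups.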
+ """ + commands = values.get("commands", []) + if not commands: + raise ValueError("`commands` must be set for replica groups") + return values + + @root_validator() + def validate_scaling(cls, values): + scaling = values.get("scaling") + replicas = values.get("replicas") + if replicas and replicas.min != replicas.max and not scaling: + raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") + if replicas and replicas.min == replicas.max and scaling: + raise ValueError("To use `scaling`, `replicas` must be set to a range.") + return values + + @validator("rate_limits") + def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: + counts = Counter(limit.prefix for limit in v) + duplicates = [prefix for prefix, count in counts.items() if count > 1] + if duplicates: + raise ValueError( + f"Prefixes {duplicates} are used more than once." + " Each rate limit should have a unique path prefix" + ) + return v + + @validator("probes") + def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: + if has_duplicates(v): + raise ValueError("Probes must be unique") + return v + + class ServiceConfigurationParams(CoreModel): port: Annotated[ # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. @@ -754,13 +838,7 @@ class ServiceConfigurationParams(CoreModel): SERVICE_HTTPS_DEFAULT ) auth: Annotated[bool, Field(description="Enable the authorization")] = True - replicas: Annotated[ - Range[int], - Field( - description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " - "If it's a range, the `scaling` property is required" - ), - ] = Range[int](min=1, max=1) + scaling: Annotated[ Optional[ScalingSpec], Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), @@ -771,6 +849,19 @@ class ServiceConfigurationParams(CoreModel): Field(description="List of probes used to determine job health"), ] = [] + replicas: Annotated[ + Optional[Union[Range[int], List[ReplicaGroup]]], + Field( + description=( + "List of replica groups. Each group defines replicas with shared configuration " + "(commands, port, resources, scaling, probes, rate_limits). " + "When specified, the top-level `replicas`, `commands`, `port`, `resources`, " + "`scaling`, `probes`, and `rate_limits` are ignored. " + "Each replica group must have a unique name." 
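+                " Groups are scaled and rolled out independently of one another."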
+ ) + ), + ] = None + @validator("port") def convert_port(cls, v) -> PortMapping: if isinstance(v, int): @@ -785,16 +876,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]: return OpenAIChatModel(type="chat", name=v, format="openai") return v - @validator("replicas") - def convert_replicas(cls, v: Range[int]) -> Range[int]: - if v.max is None: - raise ValueError("The maximum number of replicas is required") - if v.min is None: - v.min = 0 - if v.min < 0: - raise ValueError("The minimum number of replicas must be greater than or equal to 0") - return v - @validator("gateway") def validate_gateway( cls, v: Optional[Union[bool, str]] @@ -806,13 +887,29 @@ def validate_gateway( return v @root_validator() - def validate_scaling(cls, values): - scaling = values.get("scaling") + def normalize_replicas(cls, values): replicas = values.get("replicas") - if replicas and replicas.min != replicas.max and not scaling: - raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") - if replicas and replicas.min == replicas.max and scaling: - raise ValueError("To use `scaling`, `replicas` must be set to a range.") + if isinstance(replicas, list) and len(replicas) > 0: + if all(isinstance(item, ReplicaGroup) for item in replicas): + return values + + # Handle backward compatibility: convert old-style replica config to groups + old_replicas = values.get("replicas") + if isinstance(old_replicas, Range): + replica_count = old_replicas + else: + replica_count = Range[int](min=1, max=1) + values["replicas"] = [ + ReplicaGroup( + name="default", + replicas=replica_count, + commands=values.get("commands", []), + resources=values.get("resources"), + scaling=values.get("scaling"), + probes=values.get("probes", []), + rate_limits=values.get("rate_limits"), + ) + ] return values @validator("rate_limits") @@ -836,6 +933,28 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: raise ValueError("Probes must be unique") return v + @validator("replicas") + def validate_replicas( + cls, v: Optional[Union[Range[int], List[ReplicaGroup]]] + ) -> Optional[Union[Range[int], List[ReplicaGroup]]]: + if v is None: + return v + if isinstance(v, Range): + return v + + if isinstance(v, list): + if not v: + raise ValueError("`replicas` cannot be an empty list") + # Check for duplicate names + names = [group.name for group in v] + if len(names) != len(set(names)): + duplicates = [name for name in set(names) if names.count(name) > 1] + raise ValueError( + f"Duplicate replica group names found: {duplicates}. " + "Each replica group must have a unique name." + ) + return v + class ServiceConfigurationConfig( ProfileParamsConfig, @@ -858,6 +977,18 @@ class ServiceConfiguration( ): type: Literal["service"] = "service" + @property + def replica_groups(self) -> Optional[List[ReplicaGroup]]: + """ + Get normalized replica groups. After validation, replicas is always List[ReplicaGroup] or None. + Use this property for type-safe access in code. 
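+
+        Example (mirrors usage in process_runs.py):
+            for group in (configuration.replica_groups or []):
+                ...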
+ """ + if self.replicas is None: + return None + if isinstance(self.replicas, list): + return self.replicas + return None + AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration] diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index a966bc34a..59596eabb 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -253,6 +253,7 @@ class JobSpec(CoreModel): job_num: int job_name: str jobs_per_replica: int = 1 # default value for backward compatibility + replica_group: Optional[str] = "default" app_specs: Optional[List[AppSpec]] user: Optional[UnixUser] = None # default value for backward compatibility commands: List[str] diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index af2dcee8d..8885c6584 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -1,5 +1,6 @@ import asyncio import datetime +import json from typing import List, Optional, Set, Tuple from sqlalchemy import and_, or_, select @@ -8,6 +9,7 @@ import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.configurations import ReplicaGroup from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, @@ -38,6 +40,7 @@ from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.prometheus.client_metrics import run_metrics from dstack._internal.server.services.runs import ( + create_group_run_spec, fmt, process_terminating_run, run_model_to_run, @@ -47,6 +50,7 @@ is_replica_registered, retry_run_replica_jobs, scale_run_replicas, + scale_run_replicas_per_group, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count @@ -192,9 +196,10 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model)) return - run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": - run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 + run_model.desired_replica_count = sum( + group.replicas.min or 0 for group in (run.run_spec.configuration.replica_groups or []) + ) await update_service_desired_replica_count( session, run_model, @@ -203,12 +208,23 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): last_scaled_at=None, ) - if run_model.desired_replica_count == 0: - # stay zero scaled - return + if run_model.desired_replica_count == 0: + # stay zero scaled + return - await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) - switch_run_status(session, run_model, RunStatus.SUBMITTED) + # Per group scaling because single replica is also normalized to replica groups. 
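+        # `desired_replica_counts` persists the per-group desired counts as a JSON
+        # object keyed by group name, e.g. '{"web": 2, "worker": 1}' (names
+        # illustrative); keys missing from it fall back to `group.replicas.min`.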
+ replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups or [] + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replicas, counts) + else: + run_model.desired_replica_count = 1 + await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) + + switch_run_status(session=session, run_model=run_model, new_status=RunStatus.SUBMITTED) def _retrying_run_ready_for_resubmission(run_model: RunModel, run: Run) -> bool: @@ -444,6 +460,32 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) + replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups or [] + if replicas: + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replicas, counts) + + # Handle per-group rolling deployment + await _update_jobs_to_new_deployment_in_place( + session=session, + run_model=run_model, + run_spec=run_spec, + replicas=replicas, + ) + # Process per-group rolling deployment + for group in replicas: + await _handle_rolling_deployment_for_group( + session=session, + run_model=run_model, + group=group, + base_run_spec=run_spec, + desired_replica_counts=counts, + ) + return max_replica_count = run_model.desired_replica_count if _has_out_of_date_replicas(run_model): @@ -509,7 +551,10 @@ async def _handle_run_replicas( async def _update_jobs_to_new_deployment_in_place( - session: AsyncSession, run_model: RunModel, run_spec: RunSpec + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + replicas: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. @@ -518,14 +563,30 @@ async def _update_jobs_to_new_deployment_in_place( session=session, project=run_model.project, ) + base_run_spec = run_spec + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + + # Determine which group this replica belongs to + replica_group_name = None + group_run_spec = base_run_spec + + if replicas: + job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + replica_group_name = job_spec.replica_group or "default" + + for group in replicas: + if group.name == replica_group_name: + group_run_spec = create_group_run_spec(base_run_spec, group) + break + # FIXME: Handle getting image configuration errors or skip it. 
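+        # `group_run_spec` carries the group's `commands` and `resources` (see
+        # `create_group_run_spec`), so the regenerated job specs reflect the
+        # group rather than the base service configuration.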
new_job_specs = await get_job_specs_from_run_spec( - run_spec=run_spec, + run_spec=group_run_spec, secrets=secrets, replica_num=replica_num, ) @@ -543,8 +604,15 @@ async def _update_jobs_to_new_deployment_in_place( job_model.deployment_num = run_model.deployment_num -def _has_out_of_date_replicas(run: RunModel) -> bool: +def _has_out_of_date_replicas(run: RunModel, group_filter: Optional[str] = None) -> bool: for job in run.jobs: + # Filter jobs by group if specified + if group_filter is not None: + job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + # Handle None case: treat None as "default" for backward compatibility + job_replica_group = job_spec.replica_group or "default" + if job_replica_group != group_filter: + continue if job.deployment_num < run.deployment_num and not ( job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN ): @@ -607,3 +675,109 @@ def _should_stop_on_master_done(run: Run) -> bool: if is_master_job(job) and job.job_submissions[-1].status == JobStatus.DONE: return True return False + + +async def _handle_rolling_deployment_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + base_run_spec: RunSpec, + desired_replica_counts: dict, +) -> None: + """ + Handle rolling deployment for a single replica group. + """ + from dstack._internal.server.services.runs.replicas import ( + _build_replica_lists, + scale_run_replicas_for_group, + ) + + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Check if group has out-of-date replicas + if not _has_out_of_date_replicas(run_model, group_filter=group.name): + return # Group is up-to-date + + # Calculate max replicas (allow surge during deployment) + group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE + + # Count non-terminated replicas for this group only + + non_terminated_replica_count = len( + { + j.replica_num + for j in run_model.jobs + if not j.status.is_finished() and _job_belongs_to_group(job=j, group_name=group.name) + } + ) + + # Start new up-to-date replicas if needed + if non_terminated_replica_count < group_max_replica_count: + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, + jobs=run_model.jobs, + group_filter=group.name, + ) + + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_max_replica_count - non_terminated_replica_count, + base_run_spec=base_run_spec, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + # Stop out-of-date replicas that are not registered + replicas_to_stop_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + # Check if replica is out-of-date and not registered + if ( + any(j.deployment_num < run_model.deployment_num for j in jobs) + and any( + j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() + for j in jobs + ) + and not is_replica_registered(jobs) + ): + replicas_to_stop_count += 1 + + # Stop excessive registered out-of-date replicas + non_terminating_registered_replicas_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + # Filter by group + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + + if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): + 
non_terminating_registered_replicas_count += 1
+
+    replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired)
+
+    if replicas_to_stop_count > 0:
+        # Build lists again to get current state
+        active_replicas, inactive_replicas = _build_replica_lists(
+            run_model=run_model,
+            jobs=run_model.jobs,
+            group_filter=group.name,
+        )
+
+        await scale_run_replicas_for_group(
+            session=session,
+            run_model=run_model,
+            group=group,
+            replicas_diff=-replicas_to_stop_count,
+            base_run_spec=base_run_spec,
+            active_replicas=active_replicas,
+            inactive_replicas=inactive_replicas,
+        )
+
+
+def _job_belongs_to_group(job: JobModel, group_name: str) -> bool:
+    job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
+    return job_spec.replica_group == group_name
diff --git a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py
new file mode 100644
index 000000000..af4611b3c
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py
@@ -0,0 +1,26 @@
+"""add runmodel desired_replica_counts
+
+Revision ID: 706e0acc3a7d
+Revises: 903c91e24634
+Create Date: 2025-12-18 10:54:13.508297
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "706e0acc3a7d"
+down_revision = "903c91e24634"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("desired_replica_counts", sa.Text(), nullable=True))
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("desired_replica_counts")
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 5274d9ebf..72170f3c9 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -405,7 +405,7 @@ class RunModel(BaseModel):
     priority: Mapped[int] = mapped_column(Integer, default=0)
     deployment_num: Mapped[int] = mapped_column(Integer)
     desired_replica_count: Mapped[int] = mapped_column(Integer)
-
+    desired_replica_counts: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    jobs: Mapped[List["JobModel"]] = relationship(
        back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"
    )
diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py
index 5773403cf..8f5ab7509 100644
--- a/src/dstack/_internal/server/services/runs/__init__.py
+++ b/src/dstack/_internal/server/services/runs/__init__.py
@@ -18,6 +18,7 @@
     ServerClientError,
 )
 from dstack._internal.core.models.common import ApplyAction
+from dstack._internal.core.models.configurations import ReplicaGroup
 from dstack._internal.core.models.profiles import (
     RetryEvent,
 )
@@ -486,12 +487,8 @@ async def submit_run(
     submitted_at = common_utils.get_current_datetime()
     initial_status = RunStatus.SUBMITTED
-    initial_replicas = 1
     if run_spec.merged_profile.schedule is not None:
         initial_status = RunStatus.PENDING
-        initial_replicas = 0
-    elif run_spec.configuration.type == "service":
-        initial_replicas = run_spec.configuration.replicas.min or 0

     run_model = RunModel(
         id=uuid.uuid4(),
@@ -519,12 +516,50 @@
     if run_spec.configuration.type == "service":
         await services.register_service(session, 
run_model, run_spec) + service_config = run_spec.configuration - for replica_num in range(initial_replicas): + global_replica_num = 0 # Global counter across all groups for unique replica_num + + for replica_group in service_config.replica_groups or []: + if run_spec.merged_profile.schedule is not None: + group_initial_replicas = 0 + else: + group_initial_replicas = replica_group.replicas.min or 0 + + # Each replica in this group gets the same group-specific configuration + for group_replica_num in range(group_initial_replicas): + group_run_spec = create_group_run_spec( + base_run_spec=run_spec, + replica_group=replica_group, + ) + jobs = await get_jobs_from_run_spec( + run_spec=group_run_spec, + secrets=secrets, + replica_num=global_replica_num, + ) + + for job in jobs: + job.job_spec.replica_group = replica_group.name + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) + events.emit( + session, + f"Job created on run submission. Status: {job_model.status.upper()}", + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + ], + ) + global_replica_num += 1 + else: jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=0, ) for job in jobs: job_model = create_job_model_for_new_submission( @@ -552,6 +587,31 @@ async def submit_run( return common_utils.get_or_error(run) +def create_group_run_spec( + base_run_spec: RunSpec, + replica_group: ReplicaGroup, +) -> RunSpec: + # Create a copy of the configuration as a dict + config_dict = base_run_spec.configuration.dict() + + # Override with group-specific values (only if provided) + if replica_group.commands: + config_dict["commands"] = replica_group.commands + + if replica_group.resources: + config_dict["resources"] = replica_group.resources + + # Create new configuration object with merged values + # Use the same class as base (ServiceConfiguration) + new_config = base_run_spec.configuration.__class__.parse_obj(config_dict) + + # Create new RunSpec with modified configuration + # Preserve all other RunSpec properties (repo_data, file_archives, etc.) 
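+    # Group `replicas` and `scaling` are deliberately not merged here: they are
+    # consumed by per-group scaling and the autoscaler, not by individual jobs.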
+ run_spec_dict = base_run_spec.dict() + run_spec_dict["configuration"] = new_config + return RunSpec.parse_obj(run_spec_dict) + + def create_job_model_for_new_submission( run_model: RunModel, job: Job, diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index 43065d96d..5884a096a 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -1,8 +1,9 @@ -from typing import List +from typing import Dict, List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec +from dstack._internal.core.models.configurations import ReplicaGroup +from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( @@ -11,7 +12,11 @@ switch_job_status, ) from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.runs import create_job_model_for_new_submission, logger +from dstack._internal.server.services.runs import ( + create_group_run_spec, + create_job_model_for_new_submission, + logger, +) from dstack._internal.server.services.secrets import get_project_secrets_mapping @@ -23,8 +28,28 @@ async def retry_run_replica_jobs( session=session, project=run_model.project, ) + + # Determine replica group from existing job + base_run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + job_spec = JobSpec.parse_raw(latest_jobs[0].job_spec_data) + replica_group_name = job_spec.replica_group + replica_group = None + + # Find matching replica group + if replica_group_name and base_run_spec.configuration.replicas: + for group in base_run_spec.configuration.replicas: + if group.name == replica_group_name: + replica_group = group + break + + run_spec = ( + base_run_spec + if replica_group is None + else create_group_run_spec(base_run_spec, replica_group) + ) + new_jobs = await get_jobs_from_run_spec( - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=run_spec, secrets=secrets, replica_num=latest_jobs[0].replica_num, ) @@ -41,6 +66,10 @@ async def retry_run_replica_jobs( job_model.termination_reason_message = "Replica is to be retried" switch_job_status(session, job_model, JobStatus.TERMINATING) + # Set replica_group on retried jobs to maintain group identity + if replica_group_name: + new_job.job_spec.replica_group = replica_group_name + new_job_model = create_job_model_for_new_submission( run_model=run_model, job=new_job, @@ -64,7 +93,6 @@ def is_replica_registered(jobs: list[JobModel]) -> bool: async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int): if replicas_diff == 0: - # nothing to do return logger.info( @@ -74,14 +102,48 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica abs(replicas_diff), ) + active_replicas, inactive_replicas = _build_replica_lists(run_model, run_model.jobs) + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session, + run_model, + active_replicas, + inactive_replicas, + replicas_diff, + run_spec, + group_name=None, + ) + + +def _build_replica_lists( + run_model: RunModel, + jobs: List[JobModel], 
+ group_filter: Optional[str] = None, +) -> Tuple[ + List[Tuple[int, bool, int, List[JobModel]]], List[Tuple[int, bool, int, List[JobModel]]] +]: # lists of (importance, is_out_of_date, replica_num, jobs) active_replicas = [] inactive_replicas = [] - for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs): + for replica_num, replica_jobs in group_jobs_by_replica_latest(jobs): + # Filter by group if specified + if group_filter is not None: + try: + job_spec = JobSpec.parse_raw(replica_jobs[0].job_spec_data) + if job_spec.replica_group != group_filter: + continue + except Exception: + continue + statuses = set(job.status for job in replica_jobs) deployment_num = replica_jobs[0].deployment_num # same for all jobs is_out_of_date = deployment_num < run_model.deployment_num + if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses: # if there are any terminating or finished jobs, the replica is inactive inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs)) @@ -98,44 +160,71 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica # all jobs are running and ready, the replica is active and has the importance of 3 active_replicas.append((3, is_out_of_date, replica_num, replica_jobs)) - # sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) + # Sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - if replicas_diff < 0: - for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): - # scale down the less important replicas first - for job in replica_jobs: - if job.status.is_finished() or job.status == JobStatus.TERMINATING: - continue - job.status = JobStatus.TERMINATING - job.termination_reason = JobTerminationReason.SCALED_DOWN - # background task will process the job later - else: - scheduled_replicas = 0 + return active_replicas, inactive_replicas - # rerun inactive replicas - for _, _, _, replica_jobs in inactive_replicas: - if scheduled_replicas == replicas_diff: - break - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) - scheduled_replicas += 1 +def _scale_down_replicas( + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + count: int, +) -> None: + """Scale down by terminating the least important replicas""" + if count <= 0: + return + + for _, _, _, replica_jobs in reversed(active_replicas[-count:]): + for job in replica_jobs: + if job.status.is_finished() or job.status == JobStatus.TERMINATING: + continue + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + +async def _scale_up_replicas( + session: AsyncSession, + run_model: RunModel, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], + replicas_diff: int, + run_spec: RunSpec, + group_name: Optional[str] = None, +) -> None: + """Scale up by retrying inactive replicas and creating new ones""" + if replicas_diff <= 0: + return + + scheduled_replicas = 0 + + # Retry inactive replicas first + for _, _, _, replica_jobs in inactive_replicas: + if scheduled_replicas == replicas_diff: + break + await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) + scheduled_replicas += 1 + + # Create new replicas + if scheduled_replicas < replicas_diff: secrets = await 
get_project_secrets_mapping( session=session, project=run_model.project, ) - for replica_num in range( - len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff - ): - # FIXME: Handle getting image configuration errors or skip it. + max_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + + new_replicas_needed = replicas_diff - scheduled_replicas + for i in range(new_replicas_needed): + new_replica_num = max_replica_num + 1 + i jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=new_replica_num, ) for job in jobs: + # Set replica_group if specified + if group_name is not None: + job.job_spec.replica_group = group_name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, @@ -148,3 +237,90 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica actor=events.SystemActor(), targets=[events.Target.from_model(job_model)], ) + # Append to run_model.jobs so that when processing later replica groups in the same + # transaction, run_model.jobs includes jobs from previously processed groups. + run_model.jobs.append(job_model) + + +async def scale_run_replicas_per_group( + session: AsyncSession, + run_model: RunModel, + replicas: List[ReplicaGroup], + desired_replica_counts: Dict[str, int], +) -> None: + """Scale each replica group independently""" + if not replicas: + return + + for group in replicas: + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Build replica lists filtered by this group + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, jobs=run_model.jobs, group_filter=group.name + ) + + # Count active replicas + active_group_count = len(active_replicas) + group_diff = group_desired - active_group_count + + if group_diff != 0: + # Check if rolling deployment is in progress for THIS GROUP + from dstack._internal.server.background.tasks.process_runs import ( + _has_out_of_date_replicas, + ) + + group_has_out_of_date = _has_out_of_date_replicas(run_model, group_filter=group.name) + + # During rolling deployment, don't scale down old replicas + # Let rolling deployment handle stopping old replicas + if group_diff < 0 and group_has_out_of_date: + # Skip scaling down during rolling deployment + continue + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_diff, + base_run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + +async def scale_run_replicas_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + replicas_diff: int, + base_run_spec: RunSpec, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], +) -> None: + """Scale a specific replica group up or down""" + if replicas_diff == 0: + return + + logger.info( + "%s: scaling %s %s replica(s) for group '%s'", + fmt(run_model), + "UP" if replicas_diff > 0 else "DOWN", + abs(replicas_diff), + group.name, + ) + + # Get group-specific run_spec + group_run_spec = create_group_run_spec(base_run_spec, group) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session=session, + run_model=run_model, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + 
replicas_diff=replicas_diff, + run_spec=group_run_spec, + group_name=group.name, + ) diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index 73b6d9fc7..66b22bc79 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py @@ -88,7 +88,10 @@ def validate_run_spec_and_set_defaults( f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s" ) if isinstance(run_spec.configuration, ServiceConfiguration): - if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0: + # Check if any group has min=0 + if run_spec.merged_profile.schedule and any( + group.replicas.min == 0 for group in (run_spec.configuration.replica_groups or []) + ): raise ServerClientError( "Scheduled services with autoscaling to zero are not supported" ) @@ -149,11 +152,10 @@ def get_nodes_required_num(run_spec: RunSpec) -> int: nodes_required_num = 1 if run_spec.configuration.type == "task": nodes_required_num = run_spec.configuration.nodes - elif ( - run_spec.configuration.type == "service" - and run_spec.configuration.replicas.min is not None - ): - nodes_required_num = run_spec.configuration.replicas.min + elif run_spec.configuration.type == "service": + nodes_required_num = sum( + group.replicas.min or 0 for group in (run_spec.configuration.replica_groups or []) + ) return nodes_required_num diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 05c1fa909..222a954cb 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -2,6 +2,7 @@ Application logic related to `type: service` runs. """ +import json import uuid from datetime import datetime from typing import Optional @@ -141,7 +142,12 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi "The `https` configuration property is not applicable when running services without a gateway." " Please configure a gateway or remove the `https` property from the service configuration" ) - if run_spec.configuration.replicas.min != run_spec.configuration.replicas.max: + # Check if any group has autoscaling (min != max) + has_autoscaling = any( + group.replicas.min != group.replicas.max + for group in (run_spec.configuration.replica_groups or []) + ) + if has_autoscaling: raise ServerClientError( "Auto-scaling is not supported when running services without a gateway." 
" Please configure a gateway or set `replicas` to a fixed value in the service configuration" @@ -299,13 +305,37 @@ async def update_service_desired_replica_count( configuration: ServiceConfiguration, last_scaled_at: Optional[datetime], ) -> None: - scaler = get_service_scaler(configuration) stats = None if run_model.gateway_id is not None: conn = await get_or_add_gateway_connection(session, run_model.gateway_id) stats = await conn.get_stats(run_model.project.name, run_model.run_name) - run_model.desired_replica_count = scaler.get_desired_count( - current_desired_count=run_model.desired_replica_count, - stats=stats, - last_scaled_at=last_scaled_at, - ) + replica_groups = configuration.replica_groups or [] + if replica_groups: + desired_replica_counts = {} + total = 0 + prev_counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + for group in replica_groups: + scaler = get_service_scaler(group.replicas, group.scaling) + group_desired = scaler.get_desired_count( + current_desired_count=prev_counts.get(group.name, group.replicas.min or 0), + stats=stats, + last_scaled_at=last_scaled_at, + ) + desired_replica_counts[group.name] = group_desired + total += group_desired + run_model.desired_replica_counts = json.dumps(desired_replica_counts) + run_model.desired_replica_count = total + else: + # Todo Not required as single replica is normalized to replicas. + if configuration.replica_groups: + first_group = configuration.replica_groups[0] + scaler = get_service_scaler(count=first_group.replicas, scaling=first_group.scaling) + run_model.desired_replica_count = scaler.get_desired_count( + current_desired_count=run_model.desired_replica_count, + stats=stats, + last_scaled_at=last_scaled_at, + ) diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index cd6d06e58..641d2cee4 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -6,7 +6,8 @@ from pydantic import BaseModel import dstack._internal.utils.common as common_utils -from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.configurations import ScalingSpec +from dstack._internal.core.models.resources import Range from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats @@ -119,21 +120,21 @@ def get_desired_count( return new_desired_count -def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: - assert conf.replicas.min is not None - assert conf.replicas.max is not None - if conf.scaling is None: +def get_service_scaler(count: Range[int], scaling: Optional[ScalingSpec]) -> BaseServiceScaler: + assert count.min is not None + assert count.max is not None + if scaling is None: return ManualScaler( - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=count.min, + max_replicas=count.max, ) - if conf.scaling.metric == "rps": + if scaling.metric == "rps": return RPSAutoscaler( # replicas count validated by configuration model - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, - target=conf.scaling.target, - scale_up_delay=conf.scaling.scale_up_delay, - scale_down_delay=conf.scaling.scale_down_delay, + min_replicas=count.min, + max_replicas=count.max, + target=scaling.target, + scale_up_delay=scaling.scale_up_delay, + scale_down_delay=scaling.scale_down_delay, ) - raise ValueError(f"No scaler 
found for scaling parameters {conf.scaling}") + raise ValueError(f"No scaler found for scaling parameters {scaling}")
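
For reference, a minimal usage sketch of the resulting configuration API (illustrative, not part of the diff; the group names and commands are hypothetical, and `ScalingSpec(metric="rps", target=...)` follows the fields read by `get_service_scaler` above):

from dstack._internal.core.models.configurations import (
    ReplicaGroup,
    ScalingSpec,
    ServiceConfiguration,
)
from dstack._internal.core.models.resources import Range

conf = ServiceConfiguration(
    type="service",
    port=8000,
    replicas=[
        # Fixed-size group: min == max, so no `scaling` is allowed.
        ReplicaGroup(
            name="gpu",
            replicas=Range[int](min=1, max=1),
            commands=["python serve.py --device cuda"],
        ),
        # Autoscaled group: a min != max range requires `scaling`.
        ReplicaGroup(
            name="cpu",
            replicas=Range[int](min=0, max=4),
            scaling=ScalingSpec(metric="rps", target=10),
            commands=["python serve.py --device cpu"],
        ),
    ],
)
# The top-level `commands`/`image` check is skipped because `replicas` is present
# (see check_image_or_commands_present), and `normalize_replicas` keeps the list
# as-is, so `replica_groups` exposes it directly.
assert [g.name for g in conf.replica_groups or []] == ["gpu", "cpu"]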