From 22c14107ad1727352f0edf1d3e0af2f56e6ae931 Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Fri, 12 Dec 2025 15:20:16 +0545 Subject: [PATCH 1/5] Add replica groups in dstack-service add_replica_groups_model Replica Groups AutoScaling Rolling deployment and UI Replica Groups implementation clean up --- src/dstack/_internal/cli/utils/run.py | 63 ++++- .../_internal/core/models/configurations.py | 157 ++++++++++++ src/dstack/_internal/core/models/runs.py | 1 + .../server/background/tasks/process_runs.py | 192 +++++++++++++- ...a7d_add_runmodel_desired_replica_counts.py | 26 ++ src/dstack/_internal/server/models.py | 2 +- .../server/services/runs/__init__.py | 72 +++++- .../server/services/runs/replicas.py | 236 +++++++++++++++--- .../_internal/server/services/runs/spec.py | 1 + .../server/services/services/__init__.py | 39 ++- 10 files changed, 733 insertions(+), 56 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 68dc828f7..48d121bb9 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -281,16 +281,38 @@ def _format_job_name( show_deployment_num: bool, show_replica: bool, show_job: bool, + group_index: Optional[int] = None, + last_shown_group_index: Optional[int] = None, ) -> str: name_parts = [] + prefix = "" if show_replica: - name_parts.append(f"replica={job.job_spec.replica_num}") + # Show group information if replica groups are used + if group_index is not None: + # Show group=X replica=Y when group changes, or just replica=Y when same group + if group_index != last_shown_group_index: + # First job in group: use 3 spaces indent + prefix = " " + name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}") + else: + # Subsequent job in same group: align "replica=" with first job's "replica=" + # Calculate padding: width of " group={last_shown_group_index} " + padding_width = 3 + len(f"group={last_shown_group_index}") + 1 + prefix = " " * padding_width + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + # Legacy behavior: no replica groups + prefix = " " + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + prefix = " " + if show_job: name_parts.append(f"job={job.job_spec.job_num}") name_suffix = ( f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else "" ) - name_value = " " + (" ".join(name_parts) if name_parts else "") + name_value = prefix + (" ".join(name_parts) if name_parts else "") name_value += name_suffix return name_value @@ -359,6 +381,14 @@ def get_runs_table( ) merge_job_rows = len(run.jobs) == 1 and not show_deployment_num + # Replica Group Changes: Build mapping from replica group names to indices + group_name_to_index: Dict[str, int] = {} + # Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration) + replica_groups = getattr(run.run_spec.configuration, "replica_groups", None) + if replica_groups: + for idx, group in enumerate(replica_groups): + group_name_to_index[group.name] = idx + run_row: Dict[Union[str, int], Any] = { "NAME": _format_run_name(run, show_deployment_num), "SUBMITTED": format_date(run.submitted_at), @@ -372,13 +402,35 @@ def get_runs_table( if not merge_job_rows: add_row_from_dict(table, run_row) - for job in run.jobs: + # Sort jobs by group index first, then by replica_num within each group + def 
get_job_sort_key(job: Job) -> tuple: + group_index = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + # Use a large number for jobs without groups to put them at the end + return (group_index if group_index is not None else 999999, job.job_spec.replica_num) + + sorted_jobs = sorted(run.jobs, key=get_job_sort_key) + + last_shown_group_index: Optional[int] = None + for job in sorted_jobs: latest_job_submission = job.job_submissions[-1] status_formatted = _format_job_submission_status(latest_job_submission, verbose) + # Get group index for this job + group_index: Optional[int] = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + job_row: Dict[Union[str, int], Any] = { "NAME": _format_job_name( - job, latest_job_submission, show_deployment_num, show_replica, show_job + job, + latest_job_submission, + show_deployment_num, + show_replica, + show_job, + group_index=group_index, + last_shown_group_index=last_shown_group_index, ), "STATUS": status_formatted, "PROBES": _format_job_probes( @@ -390,6 +442,9 @@ def get_runs_table( "GPU": "-", "PRICE": "-", } + # Update last shown group index for next iteration + if group_index is not None: + last_shown_group_index = group_index jpd = latest_job_submission.job_provisioning_data if jpd is not None: shared_offer: Optional[InstanceOfferWithAvailability] = None diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 158c59b34..93dd9909d 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel): @root_validator def check_image_or_commands_present(cls, values): + # If replica_groups is present, skip validation - commands come from replica groups + replica_groups = values.get("replica_groups") + if replica_groups: + return values + if not values.get("commands") and not values.get("image"): raise ValueError("Either `commands` or `image` must be set") return values @@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]): ) +class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel): + name: Annotated[ + str, + Field(description="The name of the replica group"), + ] + replicas: Annotated[ + Range[int], + Field( + description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " + "If it's a range, the `scaling` property is required" + ), + ] + scaling: Annotated[ + Optional[ScalingSpec], + Field(description="The auto-scaling rules. 
Required if `replicas` is set to a range"), + ] = None + probes: Annotated[ + list[ProbeConfig], + Field(description="List of probes used to determine job health for this replica group"), + ] = [] + rate_limits: Annotated[ + list[RateLimit], + Field(description="Rate limiting rules for this replica group"), + ] = [] + # TODO: Extract to ConfigurationWithResourcesParams mixin + resources: Annotated[ + ResourcesSpec, + Field(description="The resources requirements for replicas in this group"), + ] = ResourcesSpec() + + @validator("replicas") + def convert_replicas(cls, v: Range[int]) -> Range[int]: + if v.max is None: + raise ValueError("The maximum number of replicas is required") + if v.min is None: + v.min = 0 + if v.min < 0: + raise ValueError("The minimum number of replicas must be greater than or equal to 0") + return v + + @root_validator() + def override_commands_validation(cls, values): + """ + Override parent validator from ConfigurationWithCommandsParams. + ReplicaGroup always requires commands (no image option). + """ + commands = values.get("commands", []) + if not commands: + raise ValueError("`commands` must be set for replica groups") + return values + + @root_validator() + def validate_scaling(cls, values): + scaling = values.get("scaling") + replicas = values.get("replicas") + if replicas and replicas.min != replicas.max and not scaling: + raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") + if replicas and replicas.min == replicas.max and scaling: + raise ValueError("To use `scaling`, `replicas` must be set to a range.") + return values + + @validator("rate_limits") + def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: + counts = Counter(limit.prefix for limit in v) + duplicates = [prefix for prefix, count in counts.items() if count > 1] + if duplicates: + raise ValueError( + f"Prefixes {duplicates} are used more than once." + " Each rate limit should have a unique path prefix" + ) + return v + + @validator("probes") + def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: + if has_duplicates(v): + raise ValueError("Probes must be unique") + return v + + class ServiceConfigurationParams(CoreModel): port: Annotated[ # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. @@ -771,6 +855,19 @@ class ServiceConfigurationParams(CoreModel): Field(description="List of probes used to determine job health"), ] = [] + replica_groups: Annotated[ + Optional[List[ReplicaGroup]], + Field( + description=( + "List of replica groups. Each group defines replicas with shared configuration " + "(commands, port, resources, scaling, probes, rate_limits). " + "When specified, the top-level `replicas`, `commands`, `port`, `resources`, " + "`scaling`, `probes`, and `rate_limits` are ignored. " + "Each replica group must have a unique name." + ) + ), + ] = None + @validator("port") def convert_port(cls, v) -> PortMapping: if isinstance(v, int): @@ -807,6 +904,12 @@ def validate_gateway( @root_validator() def validate_scaling(cls, values): + replica_groups = values.get("replica_groups") + # If replica_groups are set, we don't need to validate scaling. + # Each replica group has its own scaling. 
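+        # The pairing of a `replicas` range with `scaling` is enforced per
+        # group in ReplicaGroup.validate_scaling instead.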
+ if replica_groups: + return values + scaling = values.get("scaling") replicas = values.get("replicas") if replicas and replicas.min != replicas.max and not scaling: @@ -815,6 +918,42 @@ def validate_scaling(cls, values): raise ValueError("To use `scaling`, `replicas` must be set to a range.") return values + @root_validator() + def normalize_to_replica_groups(cls, values): + replica_groups = values.get("replica_groups") + if replica_groups: + return values + + # TEMP: prove we’re here and see the inputs + print( + "[normalize_to_replica_groups]", + "commands:", + values.get("commands"), + "replicas:", + values.get("replicas"), + "resources:", + values.get("resources"), + "scaling:", + values.get("scaling"), + "probes:", + values.get("probes"), + "rate_limits:", + values.get("rate_limits"), + ) + # If replica_groups is not set, we need to normalize the configuration to replica groups. + values["replica_groups"] = [ + ReplicaGroup( + name="default", + replicas=values.get("replicas"), + commands=values.get("commands"), + resources=values.get("resources"), + scaling=values.get("scaling"), + probes=values.get("probes"), + rate_limits=values.get("rate_limits"), + ) + ] + return values + @validator("rate_limits") def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: counts = Counter(limit.prefix for limit in v) @@ -836,6 +975,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: raise ValueError("Probes must be unique") return v + @validator("replica_groups") + def validate_replica_groups( + cls, v: Optional[List[ReplicaGroup]] + ) -> Optional[List[ReplicaGroup]]: + if v is None: + return v + if not v: + raise ValueError("`replica_groups` cannot be an empty list") + # Check for duplicate names + names = [group.name for group in v] + if len(names) != len(set(names)): + duplicates = [name for name in set(names) if names.count(name) > 1] + raise ValueError( + f"Duplicate replica group names found: {duplicates}. " + "Each replica group must have a unique name." 
+ ) + return v + class ServiceConfigurationConfig( ProfileParamsConfig, diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index a966bc34a..59596eabb 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -253,6 +253,7 @@ class JobSpec(CoreModel): job_num: int job_name: str jobs_per_replica: int = 1 # default value for backward compatibility + replica_group: Optional[str] = "default" app_specs: Optional[List[AppSpec]] user: Optional[UnixUser] = None # default value for backward compatibility commands: List[str] diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index af2dcee8d..1a1a33fa6 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -1,5 +1,6 @@ import asyncio import datetime +import json from typing import List, Optional, Set, Tuple from sqlalchemy import and_, or_, select @@ -8,6 +9,7 @@ import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.configurations import ReplicaGroup from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, @@ -38,6 +40,7 @@ from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.prometheus.client_metrics import run_metrics from dstack._internal.server.services.runs import ( + create_group_run_spec, fmt, process_terminating_run, run_model_to_run, @@ -47,6 +50,7 @@ is_replica_registered, retry_run_replica_jobs, scale_run_replicas, + scale_run_replicas_per_group, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count @@ -192,7 +196,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model)) return - run_model.desired_replica_count = 1 + # run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 await update_service_desired_replica_count( @@ -203,12 +207,24 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): last_scaled_at=None, ) - if run_model.desired_replica_count == 0: - # stay zero scaled - return + if run_model.desired_replica_count == 0: + # stay zero scaled + return + - await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) - switch_run_status(session, run_model, RunStatus.SUBMITTED) + # Per group scaling because single replica is also normalized to replica groups. 
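+    # desired_replica_counts is persisted as a JSON object keyed by group
+    # name, e.g. '{"default": 1}'.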
+ replica_groups = run.run_spec.configuration.replica_groups or [] + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + else: + run_model.desired_replica_count = 1 + await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) + + switch_run_status(session=session, run_model=run_model, new_status=RunStatus.SUBMITTED) def _retrying_run_ready_for_resubmission(run_model: RunModel, run: Run) -> bool: @@ -444,6 +460,32 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) + replica_groups = run_spec.configuration.replica_groups or [] + if replica_groups: + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + + # Handle per-group rolling deployment + await _update_jobs_to_new_deployment_in_place( + session=session, + run_model=run_model, + run_spec=run_spec, + replica_groups=replica_groups, + ) + # Process per-group rolling deployment + for group in replica_groups: + await _handle_rolling_deployment_for_group( + session=session, + run_model=run_model, + group=group, + base_run_spec=run_spec, + desired_replica_counts=counts, + ) + return max_replica_count = run_model.desired_replica_count if _has_out_of_date_replicas(run_model): @@ -509,7 +551,10 @@ async def _handle_run_replicas( async def _update_jobs_to_new_deployment_in_place( - session: AsyncSession, run_model: RunModel, run_spec: RunSpec + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + replica_groups: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. @@ -518,14 +563,30 @@ async def _update_jobs_to_new_deployment_in_place( session=session, project=run_model.project, ) + base_run_spec = run_spec + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + + # Determine which group this replica belongs to + replica_group_name = None + group_run_spec = base_run_spec + + if replica_groups: + job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + replica_group_name = job_spec.replica_group or "default" + + for group in replica_groups: + if group.name == replica_group_name: + group_run_spec = create_group_run_spec(base_run_spec, group) + break + # FIXME: Handle getting image configuration errors or skip it. 
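+        # If no group matches the job's replica_group, group_run_spec stays
+        # equal to the unmodified base run spec.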
new_job_specs = await get_job_specs_from_run_spec( - run_spec=run_spec, + run_spec=group_run_spec, secrets=secrets, replica_num=replica_num, ) @@ -543,8 +604,15 @@ async def _update_jobs_to_new_deployment_in_place( job_model.deployment_num = run_model.deployment_num -def _has_out_of_date_replicas(run: RunModel) -> bool: +def _has_out_of_date_replicas(run: RunModel, group_filter: Optional[str] = None) -> bool: for job in run.jobs: + # Filter jobs by group if specified + if group_filter is not None: + job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + # Handle None case: treat None as "default" for backward compatibility + job_replica_group = job_spec.replica_group or "default" + if job_replica_group != group_filter: + continue if job.deployment_num < run.deployment_num and not ( job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN ): @@ -607,3 +675,109 @@ def _should_stop_on_master_done(run: Run) -> bool: if is_master_job(job) and job.job_submissions[-1].status == JobStatus.DONE: return True return False + + +async def _handle_rolling_deployment_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + base_run_spec: RunSpec, + desired_replica_counts: dict, +) -> None: + """ + Handle rolling deployment for a single replica group. + """ + from dstack._internal.server.services.runs.replicas import ( + _build_replica_lists, + scale_run_replicas_for_group, + ) + + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Check if group has out-of-date replicas + if not _has_out_of_date_replicas(run_model, group_filter=group.name): + return # Group is up-to-date + + # Calculate max replicas (allow surge during deployment) + group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE + + # Count non-terminated replicas for this group only + + non_terminated_replica_count = len( + { + j.replica_num + for j in run_model.jobs + if not j.status.is_finished() and _job_belongs_to_group(job=j, group_name=group.name) + } + ) + + # Start new up-to-date replicas if needed + if non_terminated_replica_count < group_max_replica_count: + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, + jobs=run_model.jobs, + group_filter=group.name, + ) + + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_max_replica_count - non_terminated_replica_count, + base_run_spec=base_run_spec, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + # Stop out-of-date replicas that are not registered + replicas_to_stop_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + # Check if replica is out-of-date and not registered + if ( + any(j.deployment_num < run_model.deployment_num for j in jobs) + and any( + j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() + for j in jobs + ) + and not is_replica_registered(jobs) + ): + replicas_to_stop_count += 1 + + # Stop excessive registered out-of-date replicas + non_terminating_registered_replicas_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + # Filter by group + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + + if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): + 
non_terminating_registered_replicas_count += 1 + + replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired) + + if replicas_to_stop_count > 0: + # Build lists again to get current state + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, + jobs=run_model.jobs, + group_filter=group.name, + ) + + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=-replicas_to_stop_count, + base_run_spec=base_run_spec, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + +def _job_belongs_to_group(job: JobModel, group_name: str) -> bool: + job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + return job_spec.replica_group == group_name diff --git a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py new file mode 100644 index 000000000..f615560cb --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py @@ -0,0 +1,26 @@ +"""add runmodel desired_replica_counts + +Revision ID: 706e0acc3a7d +Revises: 22d74df9897e +Create Date: 2025-12-18 10:54:13.508297 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "706e0acc3a7d" +down_revision = "22d74df9897e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.add_column(sa.Column("desired_replica_counts", sa.Text(), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.drop_column("desired_replica_counts") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 5274d9ebf..72170f3c9 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -405,7 +405,7 @@ class RunModel(BaseModel): priority: Mapped[int] = mapped_column(Integer, default=0) deployment_num: Mapped[int] = mapped_column(Integer) desired_replica_count: Mapped[int] = mapped_column(Integer) - + desired_replica_counts: Mapped[Optional[str]] = mapped_column(Text, nullable=True) jobs: Mapped[List["JobModel"]] = relationship( back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]" ) diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 5773403cf..bfa9a90d8 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -18,6 +18,7 @@ ServerClientError, ) from dstack._internal.core.models.common import ApplyAction +from dstack._internal.core.models.configurations import ReplicaGroup from dstack._internal.core.models.profiles import ( RetryEvent, ) @@ -486,12 +487,8 @@ async def submit_run( submitted_at = common_utils.get_current_datetime() initial_status = RunStatus.SUBMITTED - initial_replicas = 1 if run_spec.merged_profile.schedule is not None: initial_status = RunStatus.PENDING - initial_replicas = 0 - elif run_spec.configuration.type == "service": - initial_replicas = run_spec.configuration.replicas.min or 0 run_model = RunModel( id=uuid.uuid4(), @@ -519,12 +516,50 @@ async def submit_run( if run_spec.configuration.type == "service": await services.register_service(session, 
run_model, run_spec) + service_config = run_spec.configuration - for replica_num in range(initial_replicas): + global_replica_num = 0 # Global counter across all groups for unique replica_num + + for replica_group in service_config.replica_groups: + if run_spec.merged_profile.schedule is not None: + group_initial_replicas = 0 + else: + group_initial_replicas = replica_group.replicas.min or 0 + + # Each replica in this group gets the same group-specific configuration + for group_replica_num in range(group_initial_replicas): + group_run_spec = create_group_run_spec( + base_run_spec=run_spec, + replica_group=replica_group, + ) + jobs = await get_jobs_from_run_spec( + run_spec=group_run_spec, + secrets=secrets, + replica_num=global_replica_num, + ) + + for job in jobs: + job.job_spec.replica_group = replica_group.name + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) + events.emit( + session, + f"Job created on run submission. Status: {job_model.status.upper()}", + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + ], + ) + global_replica_num += 1 + else: jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=0, ) for job in jobs: job_model = create_job_model_for_new_submission( @@ -552,6 +587,31 @@ async def submit_run( return common_utils.get_or_error(run) +def create_group_run_spec( + base_run_spec: RunSpec, + replica_group: ReplicaGroup, +) -> RunSpec: + # Create a copy of the configuration as a dict + config_dict = base_run_spec.configuration.dict() + + # Override with group-specific values (only if provided) + if replica_group.commands: + config_dict["commands"] = replica_group.commands + + if replica_group.resources: + config_dict["resources"] = replica_group.resources + + # Create new configuration object with merged values + # Use the same class as base (ServiceConfiguration) + new_config = base_run_spec.configuration.__class__.parse_obj(config_dict) + + # Create new RunSpec with modified configuration + # Preserve all other RunSpec properties (repo_data, file_archives, etc.) 
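+    # Note: only `commands` and `resources` are merged from the group above;
+    # the group's `scaling`, `probes`, and `rate_limits` are not applied by
+    # this helper.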
+ run_spec_dict = base_run_spec.dict() + run_spec_dict["configuration"] = new_config + return RunSpec.parse_obj(run_spec_dict) + + def create_job_model_for_new_submission( run_model: RunModel, job: Job, diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index 43065d96d..c8cb3ddc6 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -1,8 +1,9 @@ -from typing import List +from typing import Dict, List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec +from dstack._internal.core.models.configurations import ReplicaGroup +from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( @@ -11,7 +12,11 @@ switch_job_status, ) from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.runs import create_job_model_for_new_submission, logger +from dstack._internal.server.services.runs import ( + create_group_run_spec, + create_job_model_for_new_submission, + logger, +) from dstack._internal.server.services.secrets import get_project_secrets_mapping @@ -23,8 +28,28 @@ async def retry_run_replica_jobs( session=session, project=run_model.project, ) + + # Determine replica group from existing job + base_run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + job_spec = JobSpec.parse_raw(latest_jobs[0].job_spec_data) + replica_group_name = job_spec.replica_group + replica_group = None + + # Find matching replica group + if replica_group_name and base_run_spec.configuration.replica_groups: + for group in base_run_spec.configuration.replica_groups: + if group.name == replica_group_name: + replica_group = group + break + + run_spec = ( + base_run_spec + if replica_group is None + else create_group_run_spec(base_run_spec, replica_group) + ) + new_jobs = await get_jobs_from_run_spec( - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=run_spec, secrets=secrets, replica_num=latest_jobs[0].replica_num, ) @@ -41,6 +66,10 @@ async def retry_run_replica_jobs( job_model.termination_reason_message = "Replica is to be retried" switch_job_status(session, job_model, JobStatus.TERMINATING) + # Set replica_group on retried jobs to maintain group identity + if replica_group_name: + new_job.job_spec.replica_group = replica_group_name + new_job_model = create_job_model_for_new_submission( run_model=run_model, job=new_job, @@ -64,7 +93,6 @@ def is_replica_registered(jobs: list[JobModel]) -> bool: async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int): if replicas_diff == 0: - # nothing to do return logger.info( @@ -74,14 +102,48 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica abs(replicas_diff), ) + active_replicas, inactive_replicas = _build_replica_lists(run_model, run_model.jobs) + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session, + run_model, + active_replicas, + inactive_replicas, + replicas_diff, + run_spec, + group_name=None, + ) + + +def _build_replica_lists( + run_model: RunModel, + jobs: 
List[JobModel], + group_filter: Optional[str] = None, +) -> Tuple[ + List[Tuple[int, bool, int, List[JobModel]]], List[Tuple[int, bool, int, List[JobModel]]] +]: # lists of (importance, is_out_of_date, replica_num, jobs) active_replicas = [] inactive_replicas = [] - for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs): + for replica_num, replica_jobs in group_jobs_by_replica_latest(jobs): + # Filter by group if specified + if group_filter is not None: + try: + job_spec = JobSpec.parse_raw(replica_jobs[0].job_spec_data) + if job_spec.replica_group != group_filter: + continue + except Exception: + continue + statuses = set(job.status for job in replica_jobs) deployment_num = replica_jobs[0].deployment_num # same for all jobs is_out_of_date = deployment_num < run_model.deployment_num + if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses: # if there are any terminating or finished jobs, the replica is inactive inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs)) @@ -98,44 +160,71 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica # all jobs are running and ready, the replica is active and has the importance of 3 active_replicas.append((3, is_out_of_date, replica_num, replica_jobs)) - # sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) + # Sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - if replicas_diff < 0: - for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): - # scale down the less important replicas first - for job in replica_jobs: - if job.status.is_finished() or job.status == JobStatus.TERMINATING: - continue - job.status = JobStatus.TERMINATING - job.termination_reason = JobTerminationReason.SCALED_DOWN - # background task will process the job later - else: - scheduled_replicas = 0 + return active_replicas, inactive_replicas - # rerun inactive replicas - for _, _, _, replica_jobs in inactive_replicas: - if scheduled_replicas == replicas_diff: - break - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) - scheduled_replicas += 1 +def _scale_down_replicas( + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + count: int, +) -> None: + """Scale down by terminating the least important replicas""" + if count <= 0: + return + + for _, _, _, replica_jobs in reversed(active_replicas[-count:]): + for job in replica_jobs: + if job.status.is_finished() or job.status == JobStatus.TERMINATING: + continue + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + +async def _scale_up_replicas( + session: AsyncSession, + run_model: RunModel, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], + replicas_diff: int, + run_spec: RunSpec, + group_name: Optional[str] = None, +) -> None: + """Scale up by retrying inactive replicas and creating new ones""" + if replicas_diff <= 0: + return + + scheduled_replicas = 0 + + # Retry inactive replicas first + for _, _, _, replica_jobs in inactive_replicas: + if scheduled_replicas == replicas_diff: + break + await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) + scheduled_replicas += 1 + + # Create new replicas + if scheduled_replicas < replicas_diff: secrets = await 
get_project_secrets_mapping( session=session, project=run_model.project, ) - for replica_num in range( - len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff - ): - # FIXME: Handle getting image configuration errors or skip it. + max_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + + new_replicas_needed = replicas_diff - scheduled_replicas + for i in range(new_replicas_needed): + new_replica_num = max_replica_num + 1 + i jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=new_replica_num, ) for job in jobs: + # Set replica_group if specified + if group_name is not None: + job.job_spec.replica_group = group_name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, @@ -148,3 +237,90 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica actor=events.SystemActor(), targets=[events.Target.from_model(job_model)], ) + # Append to run_model.jobs so that when processing later replica groups in the same + # transaction, run_model.jobs includes jobs from previously processed groups. + run_model.jobs.append(job_model) + + +async def scale_run_replicas_per_group( + session: AsyncSession, + run_model: RunModel, + replica_groups: List[ReplicaGroup], + desired_replica_counts: Dict[str, int], +) -> None: + """Scale each replica group independently""" + if not replica_groups: + return + + for group in replica_groups: + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Build replica lists filtered by this group + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, jobs=run_model.jobs, group_filter=group.name + ) + + # Count active replicas + active_group_count = len(active_replicas) + group_diff = group_desired - active_group_count + + if group_diff != 0: + # Check if rolling deployment is in progress for THIS GROUP + from dstack._internal.server.background.tasks.process_runs import ( + _has_out_of_date_replicas, + ) + + group_has_out_of_date = _has_out_of_date_replicas(run_model, group_filter=group.name) + + # During rolling deployment, don't scale down old replicas + # Let rolling deployment handle stopping old replicas + if group_diff < 0 and group_has_out_of_date: + # Skip scaling down during rolling deployment + continue + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_diff, + base_run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + +async def scale_run_replicas_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + replicas_diff: int, + base_run_spec: RunSpec, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], +) -> None: + """Scale a specific replica group up or down""" + if replicas_diff == 0: + return + + logger.info( + "%s: scaling %s %s replica(s) for group '%s'", + fmt(run_model), + "UP" if replicas_diff > 0 else "DOWN", + abs(replicas_diff), + group.name, + ) + + # Get group-specific run_spec + group_run_spec = create_group_run_spec(base_run_spec, group) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session=session, + run_model=run_model, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + 
replicas_diff=replicas_diff, + run_spec=group_run_spec, + group_name=group.name, + ) diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index 73b6d9fc7..53d0c2192 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py @@ -50,6 +50,7 @@ "env", "shell", "commands", + "replica_groups", ], } diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 05c1fa909..d029c145e 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -2,6 +2,7 @@ Application logic related to `type: service` runs. """ +import json import uuid from datetime import datetime from typing import Optional @@ -299,13 +300,39 @@ async def update_service_desired_replica_count( configuration: ServiceConfiguration, last_scaled_at: Optional[datetime], ) -> None: - scaler = get_service_scaler(configuration) stats = None if run_model.gateway_id is not None: conn = await get_or_add_gateway_connection(session, run_model.gateway_id) stats = await conn.get_stats(run_model.project.name, run_model.run_name) - run_model.desired_replica_count = scaler.get_desired_count( - current_desired_count=run_model.desired_replica_count, - stats=stats, - last_scaled_at=last_scaled_at, - ) + if configuration.replica_groups: + desired_replica_counts = {} + total = 0 + prev_counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + for group in configuration.replica_groups: + # temp group_wise config to get the group_wise desired replica count. + group_config = configuration.copy( + exclude={"replica_groups"}, + update={"replicas": group.replicas, "scaling": group.scaling}, + ) + scaler = get_service_scaler(group_config) + group_desired = scaler.get_desired_count( + current_desired_count=prev_counts.get(group.name, group.replicas.min or 0), + stats=stats, + last_scaled_at=last_scaled_at, + ) + desired_replica_counts[group.name] = group_desired + total += group_desired + run_model.desired_replica_counts = json.dumps(desired_replica_counts) + run_model.desired_replica_count = total + else: + # Todo Not required as single replica is normalized to replica_groups. 
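+        # Fallback path for configurations without normalized replica groups.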
+ scaler = get_service_scaler(configuration) + run_model.desired_replica_count = scaler.get_desired_count( + current_desired_count=run_model.desired_replica_count, + stats=stats, + last_scaled_at=last_scaled_at, + ) From 5abbcad6ae917a209cb352cb469c1d42d7a070cc Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 23 Dec 2025 15:12:39 +0545 Subject: [PATCH 2/5] Resolve Merge Conflict & Rename replica_groups to replicas --- src/dstack/_internal/cli/utils/run.py | 8 +- .../_internal/core/models/configurations.py | 150 +++++++++--------- .../server/background/tasks/process_runs.py | 25 +-- ...a7d_add_runmodel_desired_replica_counts.py | 2 +- .../server/services/runs/__init__.py | 2 +- .../server/services/runs/replicas.py | 10 +- .../_internal/server/services/runs/spec.py | 15 +- .../server/services/services/__init__.py | 14 +- 8 files changed, 112 insertions(+), 114 deletions(-) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 48d121bb9..84764406f 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -383,10 +383,10 @@ def get_runs_table( # Replica Group Changes: Build mapping from replica group names to indices group_name_to_index: Dict[str, int] = {} - # Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration) - replica_groups = getattr(run.run_spec.configuration, "replica_groups", None) - if replica_groups: - for idx, group in enumerate(replica_groups): + # Replica Group Changes: Check if replicas attribute exists (only available for ServiceConfiguration) + replicas = getattr(run.run_spec.configuration, "replicas", None) + if replicas: + for idx, group in enumerate(replicas): group_name_to_index[group.name] = idx run_row: Dict[Union[str, int], Any] = { diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 93dd9909d..f84d88ac0 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -612,8 +612,8 @@ class ConfigurationWithCommandsParams(CoreModel): @root_validator def check_image_or_commands_present(cls, values): - # If replica_groups is present, skip validation - commands come from replica groups - replica_groups = values.get("replica_groups") + # If replicas is present, skip validation - commands come from replica groups + replica_groups = values.get("replicas") if replica_groups: return values @@ -838,25 +838,25 @@ class ServiceConfigurationParams(CoreModel): SERVICE_HTTPS_DEFAULT ) auth: Annotated[bool, Field(description="Enable the authorization")] = True - replicas: Annotated[ - Range[int], - Field( - description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " - "If it's a range, the `scaling` property is required" - ), - ] = Range[int](min=1, max=1) - scaling: Annotated[ - Optional[ScalingSpec], - Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), - ] = None + # replicas: Annotated[ + # Range[int], + # Field( + # description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " + # "If it's a range, the `scaling` property is required" + # ), + # ] = Range[int](min=1, max=1) + # scaling: Annotated[ + # Optional[ScalingSpec], + # Field(description="The auto-scaling rules. 
Required if `replicas` is set to a range"), + # ] = None rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = [] probes: Annotated[ list[ProbeConfig], Field(description="List of probes used to determine job health"), ] = [] - replica_groups: Annotated[ - Optional[List[ReplicaGroup]], + replicas: Annotated[ + Optional[Union[Range[int], List[ReplicaGroup], int, str]], Field( description=( "List of replica groups. Each group defines replicas with shared configuration " @@ -882,15 +882,15 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]: return OpenAIChatModel(type="chat", name=v, format="openai") return v - @validator("replicas") - def convert_replicas(cls, v: Range[int]) -> Range[int]: - if v.max is None: - raise ValueError("The maximum number of replicas is required") - if v.min is None: - v.min = 0 - if v.min < 0: - raise ValueError("The minimum number of replicas must be greater than or equal to 0") - return v + # @validator("replicas") + # def convert_replicas(cls, v: Range[int]) -> Range[int]: + # if v.max is None: + # raise ValueError("The maximum number of replicas is required") + # if v.min is None: + # v.min = 0 + # if v.min < 0: + # raise ValueError("The minimum number of replicas must be greater than or equal to 0") + # return v @validator("gateway") def validate_gateway( @@ -902,53 +902,43 @@ def validate_gateway( ) return v - @root_validator() - def validate_scaling(cls, values): - replica_groups = values.get("replica_groups") - # If replica_groups are set, we don't need to validate scaling. - # Each replica group has its own scaling. - if replica_groups: - return values - - scaling = values.get("scaling") - replicas = values.get("replicas") - if replicas and replicas.min != replicas.max and not scaling: - raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") - if replicas and replicas.min == replicas.max and scaling: - raise ValueError("To use `scaling`, `replicas` must be set to a range.") - return values + # @root_validator() + # def validate_scaling(cls, values): + # replica_groups = values.get("replica_groups") + # # If replica_groups are set, we don't need to validate scaling. + # # Each replica group has its own scaling. + # if replica_groups: + # return values + + # scaling = values.get("scaling") + # replicas = values.get("replicas") + # if replicas and replicas.min != replicas.max and not scaling: + # raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") + # if replicas and replicas.min == replicas.max and scaling: + # raise ValueError("To use `scaling`, `replicas` must be set to a range.") + # return values @root_validator() - def normalize_to_replica_groups(cls, values): - replica_groups = values.get("replica_groups") - if replica_groups: - return values - - # TEMP: prove we’re here and see the inputs - print( - "[normalize_to_replica_groups]", - "commands:", - values.get("commands"), - "replicas:", - values.get("replicas"), - "resources:", - values.get("resources"), - "scaling:", - values.get("scaling"), - "probes:", - values.get("probes"), - "rate_limits:", - values.get("rate_limits"), - ) - # If replica_groups is not set, we need to normalize the configuration to replica groups. 
- values["replica_groups"] = [ + def normalize_replicas(cls, values): + replicas = values.get("replicas") + if isinstance(replicas, list) and len(replicas) > 0: + if all(isinstance(item, ReplicaGroup) for item in replicas): + return values + + # Handle backward compatibility: convert old-style replica config to groups + old_replicas = values.get("replicas") + if isinstance(old_replicas, Range): + replica_count = old_replicas + else: + replica_count = Range[int](min=1, max=1) + values["replicas"] = [ ReplicaGroup( name="default", - replicas=values.get("replicas"), - commands=values.get("commands"), + replicas=replica_count, + commands=values.get("commands", []), resources=values.get("resources"), scaling=values.get("scaling"), - probes=values.get("probes"), + probes=values.get("probes", []), rate_limits=values.get("rate_limits"), ) ] @@ -975,22 +965,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: raise ValueError("Probes must be unique") return v - @validator("replica_groups") - def validate_replica_groups( - cls, v: Optional[List[ReplicaGroup]] - ) -> Optional[List[ReplicaGroup]]: + @validator("replicas") + def validate_replicas(cls, v: Optional[List[ReplicaGroup]]) -> Optional[List[ReplicaGroup]]: if v is None: return v - if not v: - raise ValueError("`replica_groups` cannot be an empty list") - # Check for duplicate names - names = [group.name for group in v] - if len(names) != len(set(names)): - duplicates = [name for name in set(names) if names.count(name) > 1] - raise ValueError( - f"Duplicate replica group names found: {duplicates}. " - "Each replica group must have a unique name." - ) + if isinstance(v, (Range, int, str)): + return v + + if isinstance(v, list): + if not v: + raise ValueError("`replicas` cannot be an empty list") + # Check for duplicate names + names = [group.name for group in v] + if len(names) != len(set(names)): + duplicates = [name for name in set(names) if names.count(name) > 1] + raise ValueError( + f"Duplicate replica group names found: {duplicates}. " + "Each replica group must have a unique name." + ) return v diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 1a1a33fa6..413756c63 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -198,7 +198,9 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): # run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": - run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 + run_model.desired_replica_count = sum( + group.replicas.min or 0 for group in run.run_spec.configuration.replicas + ) await update_service_desired_replica_count( session, run_model, @@ -211,15 +213,14 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): # stay zero scaled return - # Per group scaling because single replica is also normalized to replica groups. 
- replica_groups = run.run_spec.configuration.replica_groups or [] + replicas = run.run_spec.configuration.replicas or [] counts = ( json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} ) - await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + await scale_run_replicas_per_group(session, run_model, replicas, counts) else: run_model.desired_replica_count = 1 await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) @@ -460,24 +461,24 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) - replica_groups = run_spec.configuration.replica_groups or [] - if replica_groups: + replicas = run_spec.configuration.replicas or [] + if replicas: counts = ( json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} ) - await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + await scale_run_replicas_per_group(session, run_model, replicas, counts) # Handle per-group rolling deployment await _update_jobs_to_new_deployment_in_place( session=session, run_model=run_model, run_spec=run_spec, - replica_groups=replica_groups, + replicas=replicas, ) # Process per-group rolling deployment - for group in replica_groups: + for group in replicas: await _handle_rolling_deployment_for_group( session=session, run_model=run_model, @@ -554,7 +555,7 @@ async def _update_jobs_to_new_deployment_in_place( session: AsyncSession, run_model: RunModel, run_spec: RunSpec, - replica_groups: Optional[List] = None, + replicas: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. @@ -575,11 +576,11 @@ async def _update_jobs_to_new_deployment_in_place( replica_group_name = None group_run_spec = base_run_spec - if replica_groups: + if replicas: job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) replica_group_name = job_spec.replica_group or "default" - for group in replica_groups: + for group in replicas: if group.name == replica_group_name: group_run_spec = create_group_run_spec(base_run_spec, group) break diff --git a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py index f615560cb..af4611b3c 100644 --- a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py +++ b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. 
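+# down_revision repointed to the new migration head while resolving the merge
+# conflict.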
revision = "706e0acc3a7d" -down_revision = "22d74df9897e" +down_revision = "903c91e24634" branch_labels = None depends_on = None diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index bfa9a90d8..b78c30bb7 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -520,7 +520,7 @@ async def submit_run( global_replica_num = 0 # Global counter across all groups for unique replica_num - for replica_group in service_config.replica_groups: + for replica_group in service_config.replicas: if run_spec.merged_profile.schedule is not None: group_initial_replicas = 0 else: diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index c8cb3ddc6..5884a096a 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -36,8 +36,8 @@ async def retry_run_replica_jobs( replica_group = None # Find matching replica group - if replica_group_name and base_run_spec.configuration.replica_groups: - for group in base_run_spec.configuration.replica_groups: + if replica_group_name and base_run_spec.configuration.replicas: + for group in base_run_spec.configuration.replicas: if group.name == replica_group_name: replica_group = group break @@ -245,14 +245,14 @@ async def _scale_up_replicas( async def scale_run_replicas_per_group( session: AsyncSession, run_model: RunModel, - replica_groups: List[ReplicaGroup], + replicas: List[ReplicaGroup], desired_replica_counts: Dict[str, int], ) -> None: """Scale each replica group independently""" - if not replica_groups: + if not replicas: return - for group in replica_groups: + for group in replicas: group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) # Build replica lists filtered by this group diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index 53d0c2192..8d184d8dc 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py @@ -50,7 +50,6 @@ "env", "shell", "commands", - "replica_groups", ], } @@ -89,7 +88,10 @@ def validate_run_spec_and_set_defaults( f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s" ) if isinstance(run_spec.configuration, ServiceConfiguration): - if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0: + # Check if any group has min=0 + if run_spec.merged_profile.schedule and any( + group.replicas.min == 0 for group in run_spec.configuration.replicas + ): raise ServerClientError( "Scheduled services with autoscaling to zero are not supported" ) @@ -150,11 +152,10 @@ def get_nodes_required_num(run_spec: RunSpec) -> int: nodes_required_num = 1 if run_spec.configuration.type == "task": nodes_required_num = run_spec.configuration.nodes - elif ( - run_spec.configuration.type == "service" - and run_spec.configuration.replicas.min is not None - ): - nodes_required_num = run_spec.configuration.replicas.min + elif run_spec.configuration.type == "service": + nodes_required_num = sum( + group.replicas.min or 0 for group in run_spec.configuration.replicas + ) return nodes_required_num diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index d029c145e..0df2a4838 100644 --- 
a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -142,7 +142,11 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi "The `https` configuration property is not applicable when running services without a gateway." " Please configure a gateway or remove the `https` property from the service configuration" ) - if run_spec.configuration.replicas.min != run_spec.configuration.replicas.max: + # Check if any group has autoscaling (min != max) + has_autoscaling = any( + group.replicas.min != group.replicas.max for group in run_spec.configuration.replicas + ) + if has_autoscaling: raise ServerClientError( "Auto-scaling is not supported when running services without a gateway." " Please configure a gateway or set `replicas` to a fixed value in the service configuration" @@ -304,7 +308,7 @@ async def update_service_desired_replica_count( if run_model.gateway_id is not None: conn = await get_or_add_gateway_connection(session, run_model.gateway_id) stats = await conn.get_stats(run_model.project.name, run_model.run_name) - if configuration.replica_groups: + if configuration.replicas: desired_replica_counts = {} total = 0 prev_counts = ( @@ -312,10 +316,10 @@ async def update_service_desired_replica_count( if run_model.desired_replica_counts else {} ) - for group in configuration.replica_groups: + for group in configuration.replicas: # temp group_wise config to get the group_wise desired replica count. group_config = configuration.copy( - exclude={"replica_groups"}, + exclude={"replicas"}, update={"replicas": group.replicas, "scaling": group.scaling}, ) scaler = get_service_scaler(group_config) @@ -329,7 +333,7 @@ async def update_service_desired_replica_count( run_model.desired_replica_counts = json.dumps(desired_replica_counts) run_model.desired_replica_count = total else: - # Todo Not required as single replica is normalized to replica_groups. + # Todo Not required as single replica is normalized to replicas. scaler = get_service_scaler(configuration) run_model.desired_replica_count = scaler.get_desired_count( current_desired_count=run_model.desired_replica_count, From abba7da9eb8a0712df548c7d1ccc800d5c2cbf1b Mon Sep 17 00:00:00 2001 From: Bihan Rana Date: Tue, 23 Dec 2025 21:57:00 +0545 Subject: [PATCH 3/5] Resolve pyright type check --- .../_internal/core/models/configurations.py | 62 +++++++------------ .../server/background/tasks/process_runs.py | 7 +-- .../server/services/runs/__init__.py | 2 +- .../_internal/server/services/runs/spec.py | 4 +- .../server/services/services/__init__.py | 29 +++++---- .../server/services/services/autoscalers.py | 29 ++++----- 6 files changed, 57 insertions(+), 76 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index f84d88ac0..6abdda1a4 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -838,17 +838,11 @@ class ServiceConfigurationParams(CoreModel): SERVICE_HTTPS_DEFAULT ) auth: Annotated[bool, Field(description="Enable the authorization")] = True - # replicas: Annotated[ - # Range[int], - # Field( - # description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " - # "If it's a range, the `scaling` property is required" - # ), - # ] = Range[int](min=1, max=1) - # scaling: Annotated[ - # Optional[ScalingSpec], - # Field(description="The auto-scaling rules. 
From abba7da9eb8a0712df548c7d1ccc800d5c2cbf1b Mon Sep 17 00:00:00 2001
From: Bihan Rana
Date: Tue, 23 Dec 2025 21:57:00 +0545
Subject: [PATCH 3/5] Resolve pyright type check

---
 .../_internal/core/models/configurations.py   | 62 +++++++------------
 .../server/background/tasks/process_runs.py   |  7 +--
 .../server/services/runs/__init__.py          |  2 +-
 .../_internal/server/services/runs/spec.py    |  4 +-
 .../server/services/services/__init__.py      | 29 +++++----
 .../server/services/services/autoscalers.py   | 29 ++++-----
 6 files changed, 57 insertions(+), 76 deletions(-)

diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index f84d88ac0..6abdda1a4 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -838,17 +838,11 @@ class ServiceConfigurationParams(CoreModel):
         SERVICE_HTTPS_DEFAULT
     )
     auth: Annotated[bool, Field(description="Enable the authorization")] = True
-    # replicas: Annotated[
-    #     Range[int],
-    #     Field(
-    #         description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
-    #         "If it's a range, the `scaling` property is required"
-    #     ),
-    # ] = Range[int](min=1, max=1)
-    # scaling: Annotated[
-    #     Optional[ScalingSpec],
-    #     Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
-    # ] = None
+
+    scaling: Annotated[
+        Optional[ScalingSpec],
+        Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
+    ] = None
     rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
     probes: Annotated[
         list[ProbeConfig],
@@ -856,7 +850,7 @@
     ] = []
 
     replicas: Annotated[
-        Optional[Union[Range[int], List[ReplicaGroup], int, str]],
+        Optional[Union[Range[int], List[ReplicaGroup]]],
         Field(
             description=(
                 "List of replica groups. Each group defines replicas with shared configuration "
@@ -882,16 +876,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
             return OpenAIChatModel(type="chat", name=v, format="openai")
         return v
 
-    # @validator("replicas")
-    # def convert_replicas(cls, v: Range[int]) -> Range[int]:
-    #     if v.max is None:
-    #         raise ValueError("The maximum number of replicas is required")
-    #     if v.min is None:
-    #         v.min = 0
-    #     if v.min < 0:
-    #         raise ValueError("The minimum number of replicas must be greater than or equal to 0")
-    #     return v
-
     @validator("gateway")
     def validate_gateway(
         cls, v: Optional[Union[bool, str]]
@@ -902,22 +886,6 @@
             )
         return v
 
-    # @root_validator()
-    # def validate_scaling(cls, values):
-    #     replica_groups = values.get("replica_groups")
-    #     # If replica_groups are set, we don't need to validate scaling.
-    #     # Each replica group has its own scaling.
-    #     if replica_groups:
-    #         return values
-
-    #     scaling = values.get("scaling")
-    #     replicas = values.get("replicas")
-    #     if replicas and replicas.min != replicas.max and not scaling:
-    #         raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
-    #     if replicas and replicas.min == replicas.max and scaling:
-    #         raise ValueError("To use `scaling`, `replicas` must be set to a range.")
-    #     return values
-
     @root_validator()
     def normalize_replicas(cls, values):
         replicas = values.get("replicas")
@@ -966,10 +934,12 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
         return v
 
     @validator("replicas")
-    def validate_replicas(cls, v: Optional[List[ReplicaGroup]]) -> Optional[List[ReplicaGroup]]:
+    def validate_replicas(
+        cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
+    ) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
         if v is None:
             return v
-        if isinstance(v, (Range, int, str)):
+        if isinstance(v, Range):
             return v
 
         if isinstance(v, list):
@@ -1007,6 +977,18 @@ class ServiceConfiguration(
 ):
     type: Literal["service"] = "service"
 
+    @property
+    def replica_groups(self) -> Optional[List[ReplicaGroup]]:
+        """
+        Get normalized replica groups. After validation, `replicas` is always
+        List[ReplicaGroup] or None. Use this property for type-safe access in code.
+        """
+        if self.replicas is None:
+            return None
+        if isinstance(self.replicas, list):
+            return self.replicas
+        return None
+
 
 AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
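
Note: the `replica_groups` property exists purely to narrow the `Union` type of `replicas` for pyright. A stripped-down illustration (the class and its field are stand-ins, not the real models):

    from typing import List, Optional, Union

    class ServiceConfigSketch:
        def __init__(self, replicas: Union[None, tuple, List[str]]):
            # tuple stands in for Range; strings stand in for ReplicaGroup
            self.replicas = replicas

        @property
        def replica_groups(self) -> Optional[List[str]]:
            # isinstance narrows the Union, so callers get a typed list or None
            return self.replicas if isinstance(self.replicas, list) else None

    cfg = ServiceConfigSketch(replicas=["cpu", "gpu"])
    assert (cfg.replica_groups or []) == ["cpu", "gpu"]
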
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index 413756c63..8885c6584 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -196,10 +196,9 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
             logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model))
             return
 
-    # run_model.desired_replica_count = 1
     if run.run_spec.configuration.type == "service":
         run_model.desired_replica_count = sum(
-            group.replicas.min or 0 for group in run.run_spec.configuration.replicas
+            group.replicas.min or 0 for group in (run.run_spec.configuration.replica_groups or [])
         )
         await update_service_desired_replica_count(
             session,
@@ -214,7 +213,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
             return
 
     # Per group scaling because single replica is also normalized to replica groups.
-    replicas = run.run_spec.configuration.replicas or []
+    replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups or []
     counts = (
         json.loads(run_model.desired_replica_counts)
         if run_model.desired_replica_counts
@@ -461,7 +460,7 @@ async def _handle_run_replicas(
         # FIXME: should only include scaling events, not retries and deployments
         last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
     )
-    replicas = run_spec.configuration.replicas or []
+    replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups or []
     if replicas:
         counts = (
             json.loads(run_model.desired_replica_counts)
diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py
index b78c30bb7..8f5ab7509 100644
--- a/src/dstack/_internal/server/services/runs/__init__.py
+++ b/src/dstack/_internal/server/services/runs/__init__.py
@@ -520,7 +520,7 @@ async def submit_run(
 
         global_replica_num = 0  # Global counter across all groups for unique replica_num
 
-        for replica_group in service_config.replicas:
+        for replica_group in service_config.replica_groups or []:
             if run_spec.merged_profile.schedule is not None:
                 group_initial_replicas = 0
             else:
diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py
index 8d184d8dc..66b22bc79 100644
--- a/src/dstack/_internal/server/services/runs/spec.py
+++ b/src/dstack/_internal/server/services/runs/spec.py
@@ -90,7 +90,7 @@ def validate_run_spec_and_set_defaults(
     if isinstance(run_spec.configuration, ServiceConfiguration):
         # Check if any group has min=0
         if run_spec.merged_profile.schedule and any(
-            group.replicas.min == 0 for group in run_spec.configuration.replicas
+            group.replicas.min == 0 for group in (run_spec.configuration.replica_groups or [])
         ):
             raise ServerClientError(
                 "Scheduled services with autoscaling to zero are not supported"
@@ -154,7 +154,7 @@ def get_nodes_required_num(run_spec: RunSpec) -> int:
         nodes_required_num = run_spec.configuration.nodes
     elif run_spec.configuration.type == "service":
         nodes_required_num = sum(
-            group.replicas.min or 0 for group in run_spec.configuration.replicas
+            group.replicas.min or 0 for group in (run_spec.configuration.replica_groups or [])
         )
     return nodes_required_num
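
Note: the schedule check now rejects a service if any single group can scale to zero, not just the service-wide minimum. A tiny sketch with hypothetical groups:

    groups = [("steady", 1), ("burst", 0)]  # (name, min_replicas) pairs
    schedule_is_set = True
    rejected = schedule_is_set and any(group_min == 0 for _, group_min in groups)
    assert rejected  # a scheduled service may not autoscale any group to zero
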
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 0df2a4838..222a954cb 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -144,7 +144,8 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi
     )
     # Check if any group has autoscaling (min != max)
     has_autoscaling = any(
-        group.replicas.min != group.replicas.max for group in run_spec.configuration.replicas
+        group.replicas.min != group.replicas.max
+        for group in (run_spec.configuration.replica_groups or [])
     )
     if has_autoscaling:
         raise ServerClientError(
@@ -308,7 +309,8 @@ async def update_service_desired_replica_count(
     if run_model.gateway_id is not None:
         conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
         stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-    if configuration.replicas:
+    replica_groups = configuration.replica_groups or []
+    if replica_groups:
         desired_replica_counts = {}
         total = 0
         prev_counts = (
@@ -316,13 +318,8 @@ async def update_service_desired_replica_count(
             if run_model.desired_replica_counts
             else {}
         )
-        for group in configuration.replicas:
-            # Temporary group-wise config to get the group-wise desired replica count.
-            group_config = configuration.copy(
-                exclude={"replicas"},
-                update={"replicas": group.replicas, "scaling": group.scaling},
-            )
-            scaler = get_service_scaler(group_config)
+        for group in replica_groups:
+            scaler = get_service_scaler(group.replicas, group.scaling)
             group_desired = scaler.get_desired_count(
                 current_desired_count=prev_counts.get(group.name, group.replicas.min or 0),
                 stats=stats,
@@ -334,9 +331,11 @@ async def update_service_desired_replica_count(
         run_model.desired_replica_count = total
     else:
         # TODO: Not required, since a single replica count is normalized to replicas.
-        scaler = get_service_scaler(configuration)
-        run_model.desired_replica_count = scaler.get_desired_count(
-            current_desired_count=run_model.desired_replica_count,
-            stats=stats,
-            last_scaled_at=last_scaled_at,
-        )
+        if configuration.replica_groups:
+            first_group = configuration.replica_groups[0]
+            scaler = get_service_scaler(count=first_group.replicas, scaling=first_group.scaling)
+            run_model.desired_replica_count = scaler.get_desired_count(
+                current_desired_count=run_model.desired_replica_count,
+                stats=stats,
+                last_scaled_at=last_scaled_at,
+            )
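
Note: the refactor drops the temporary `configuration.copy(...)` per group; each group now contributes its own desired count and the totals are aggregated. A sketch of the loop shape, with a fake scaler that just clamps into the group's range:

    def fake_get_desired_count(current: int, lo: int, hi: int) -> int:
        return max(lo, min(current, hi))  # stand-in for scaler.get_desired_count

    prev_counts = {"cpu": 5}                  # hypothetical persisted state
    groups = [("cpu", 1, 3), ("gpu", 0, 2)]   # (name, min, max)
    desired, total = {}, 0
    for name, lo, hi in groups:
        desired[name] = fake_get_desired_count(prev_counts.get(name, lo), lo, hi)
        total += desired[name]
    assert desired == {"cpu": 3, "gpu": 0} and total == 3
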
diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py
index cd6d06e58..641d2cee4 100644
--- a/src/dstack/_internal/server/services/services/autoscalers.py
+++ b/src/dstack/_internal/server/services/services/autoscalers.py
@@ -6,7 +6,8 @@
 from pydantic import BaseModel
 
 import dstack._internal.utils.common as common_utils
-from dstack._internal.core.models.configurations import ServiceConfiguration
+from dstack._internal.core.models.configurations import ScalingSpec
+from dstack._internal.core.models.resources import Range
 from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats
 
 
@@ -119,21 +120,21 @@ def get_desired_count(
         return new_desired_count
 
 
-def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
-    assert conf.replicas.min is not None
-    assert conf.replicas.max is not None
-    if conf.scaling is None:
+def get_service_scaler(count: Range[int], scaling: Optional[ScalingSpec]) -> BaseServiceScaler:
+    assert count.min is not None
+    assert count.max is not None
+    if scaling is None:
         return ManualScaler(
-            min_replicas=conf.replicas.min,
-            max_replicas=conf.replicas.max,
+            min_replicas=count.min,
+            max_replicas=count.max,
         )
-    if conf.scaling.metric == "rps":
+    if scaling.metric == "rps":
         return RPSAutoscaler(
             # replicas count validated by configuration model
-            min_replicas=conf.replicas.min,
-            max_replicas=conf.replicas.max,
-            target=conf.scaling.target,
-            scale_up_delay=conf.scaling.scale_up_delay,
-            scale_down_delay=conf.scaling.scale_down_delay,
+            min_replicas=count.min,
+            max_replicas=count.max,
+            target=scaling.target,
+            scale_up_delay=scaling.scale_up_delay,
+            scale_down_delay=scaling.scale_down_delay,
         )
-    raise ValueError(f"No scaler found for scaling parameters {conf.scaling}")
+    raise ValueError(f"No scaler found for scaling parameters {scaling}")
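
Note: `get_service_scaler` now takes a `(count, scaling)` pair instead of the whole configuration, so one scaler can be built per group. A hedged sketch of the selection logic with simplified stand-in types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Range:
        min: int
        max: int

    @dataclass
    class ScalingSpec:  # simplified stand-in for dstack's scaling spec
        metric: str = "rps"
        target: float = 10.0

    def pick_scaler(count: Range, scaling: Optional[ScalingSpec]) -> str:
        if scaling is None:
            return f"ManualScaler({count.min}..{count.max})"
        if scaling.metric == "rps":
            return f"RPSAutoscaler({count.min}..{count.max}, target={scaling.target})"
        raise ValueError(f"No scaler found for scaling parameters {scaling}")

    assert pick_scaler(Range(2, 2), None).startswith("ManualScaler")
    assert pick_scaler(Range(0, 4), ScalingSpec()).startswith("RPSAutoscaler")
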
From d97429269ee4efff7427ea9b84f2ec796a148bdc Mon Sep 17 00:00:00 2001
From: Bihan Rana
Date: Wed, 24 Dec 2025 15:04:59 +0545
Subject: [PATCH 4/5] Rename replicas to count and make replica names optional

---
 .../_internal/core/models/configurations.py   | 32 ++++++++++------
 .../server/background/tasks/process_runs.py   |  4 +-
 .../server/services/runs/__init__.py          |  2 +-
 .../server/services/runs/replicas.py          |  2 +-
 .../_internal/server/services/runs/spec.py    |  4 +-
 .../server/services/services/__init__.py      |  8 ++--
 .../core/models/test_configurations.py        | 37 +++++++++++++++----
 7 files changed, 60 insertions(+), 29 deletions(-)

diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 6abdda1a4..3583ec860 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -721,10 +721,12 @@ def schema_extra(schema: Dict[str, Any]):
 
 class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
     name: Annotated[
-        str,
-        Field(description="The name of the replica group"),
+        Optional[str],
+        Field(
+            description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc. based on position."
+        ),
     ]
-    replicas: Annotated[
+    count: Annotated[
         Range[int],
         Field(
             description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
@@ -733,7 +735,7 @@
     ]
     scaling: Annotated[
         Optional[ScalingSpec],
-        Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
+        Field(description="The auto-scaling rules. Required if `count` is set to a range"),
     ] = None
     probes: Annotated[
         list[ProbeConfig],
@@ -749,8 +751,8 @@ class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
         Field(description="The resources requirements for replicas in this group"),
     ] = ResourcesSpec()
 
-    @validator("replicas")
-    def convert_replicas(cls, v: Range[int]) -> Range[int]:
+    @validator("count")
+    def convert_count(cls, v: Range[int]) -> Range[int]:
         if v.max is None:
             raise ValueError("The maximum number of replicas is required")
         if v.min is None:
@@ -773,11 +775,11 @@ def override_commands_validation(cls, values):
     @root_validator()
     def validate_scaling(cls, values):
         scaling = values.get("scaling")
-        replicas = values.get("replicas")
-        if replicas and replicas.min != replicas.max and not scaling:
-            raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
-        if replicas and replicas.min == replicas.max and scaling:
-            raise ValueError("To use `scaling`, `replicas` must be set to a range.")
+        count = values.get("count")
+        if count and count.min != count.max and not scaling:
+            raise ValueError("When you set `count` to a range, ensure to specify `scaling`.")
+        if count and count.min == count.max and scaling:
+            raise ValueError("To use `scaling`, `count` must be set to a range.")
         return values
 
     @validator("rate_limits")
@@ -902,7 +904,7 @@ def normalize_replicas(cls, values):
             values["replicas"] = [
                 ReplicaGroup(
                     name="default",
-                    replicas=replica_count,
+                    count=replica_count,
                     commands=values.get("commands", []),
                     resources=values.get("resources"),
                     scaling=values.get("scaling"),
@@ -945,6 +947,12 @@ def validate_replicas(
         if isinstance(v, list):
             if not v:
                 raise ValueError("`replicas` cannot be an empty list")
+
+            # Assign default names to groups without names
+            for index, group in enumerate(v):
+                if group.name is None:
+                    group.name = f"replica{index}"
+
             # Check for duplicate names
             names = [group.name for group in v]
             if len(names) != len(set(names)):
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index 8885c6584..5b84e4b86 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -198,7 +198,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
 
     if run.run_spec.configuration.type == "service":
         run_model.desired_replica_count = sum(
-            group.replicas.min or 0 for group in (run.run_spec.configuration.replica_groups or [])
+            group.count.min or 0 for group in (run.run_spec.configuration.replica_groups or [])
         )
         await update_service_desired_replica_count(
             session,
@@ -692,7 +692,7 @@ async def _handle_rolling_deployment_for_group(
         scale_run_replicas_for_group,
     )
 
-    group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0)
+    group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
 
     # Check if group has out-of-date replicas
     if not _has_out_of_date_replicas(run_model, group_filter=group.name):
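
Note: unnamed groups get positional defaults, and duplicate names are rejected. The naming rule from `validate_replicas` in isolation:

    names = [None, "gpu", None]  # hypothetical group names as configured
    resolved = [n if n is not None else f"replica{i}" for i, n in enumerate(names)]
    assert resolved == ["replica0", "gpu", "replica2"]
    assert len(resolved) == len(set(resolved))  # duplicates would be an error
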
diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py
index 8f5ab7509..fdfd59a62 100644
--- a/src/dstack/_internal/server/services/runs/__init__.py
+++ b/src/dstack/_internal/server/services/runs/__init__.py
@@ -524,7 +524,7 @@ async def submit_run(
             if run_spec.merged_profile.schedule is not None:
                 group_initial_replicas = 0
             else:
-                group_initial_replicas = replica_group.replicas.min or 0
+                group_initial_replicas = replica_group.count.min or 0
 
             # Each replica in this group gets the same group-specific configuration
             for group_replica_num in range(group_initial_replicas):
diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py
index 5884a096a..3fe6de513 100644
--- a/src/dstack/_internal/server/services/runs/replicas.py
+++ b/src/dstack/_internal/server/services/runs/replicas.py
@@ -253,7 +253,7 @@ async def scale_run_replicas_per_group(
         return
 
     for group in replicas:
-        group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0)
+        group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
 
         # Build replica lists filtered by this group
         active_replicas, inactive_replicas = _build_replica_lists(
diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py
index 66b22bc79..f6fba1f59 100644
--- a/src/dstack/_internal/server/services/runs/spec.py
+++ b/src/dstack/_internal/server/services/runs/spec.py
@@ -90,7 +90,7 @@ def validate_run_spec_and_set_defaults(
     if isinstance(run_spec.configuration, ServiceConfiguration):
         # Check if any group has min=0
         if run_spec.merged_profile.schedule and any(
-            group.replicas.min == 0 for group in (run_spec.configuration.replica_groups or [])
+            group.count.min == 0 for group in (run_spec.configuration.replica_groups or [])
         ):
             raise ServerClientError(
                 "Scheduled services with autoscaling to zero are not supported"
@@ -154,7 +154,7 @@ def get_nodes_required_num(run_spec: RunSpec) -> int:
         nodes_required_num = run_spec.configuration.nodes
     elif run_spec.configuration.type == "service":
         nodes_required_num = sum(
-            group.replicas.min or 0 for group in (run_spec.configuration.replica_groups or [])
+            group.count.min or 0 for group in (run_spec.configuration.replica_groups or [])
         )
     return nodes_required_num
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 222a954cb..f8a321083 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -144,7 +144,7 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi
     )
     # Check if any group has autoscaling (min != max)
     has_autoscaling = any(
-        group.replicas.min != group.replicas.max
+        group.count.min != group.count.max
         for group in (run_spec.configuration.replica_groups or [])
     )
     if has_autoscaling:
@@ -319,9 +319,9 @@ async def update_service_desired_replica_count(
             else {}
         )
         for group in replica_groups:
-            scaler = get_service_scaler(group.replicas, group.scaling)
+            scaler = get_service_scaler(group.count, group.scaling)
             group_desired = scaler.get_desired_count(
-                current_desired_count=prev_counts.get(group.name, group.replicas.min or 0),
+                current_desired_count=prev_counts.get(group.name, group.count.min or 0),
                 stats=stats,
                 last_scaled_at=last_scaled_at,
             )
@@ -333,7 +333,7 @@ async def update_service_desired_replica_count(
         # TODO: Not required, since a single replica count is normalized to replicas.
         if configuration.replica_groups:
             first_group = configuration.replica_groups[0]
-            scaler = get_service_scaler(count=first_group.replicas, scaling=first_group.scaling)
+            scaler = get_service_scaler(count=first_group.count, scaling=first_group.scaling)
             run_model.desired_replica_count = scaler.get_desired_count(
                 current_desired_count=run_model.desired_replica_count,
                 stats=stats,
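
Note: when a group has no persisted desired count yet, it falls back to its configured minimum, mirroring `desired_replica_counts.get(group.name, group.count.min or 0)`:

    desired_replica_counts = {"cpu": 3}  # hypothetical persisted counts
    group_name, group_min = "gpu", 1     # group not yet present in the mapping
    group_desired = desired_replica_counts.get(group_name, group_min or 0)
    assert group_desired == 1
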
diff --git a/src/tests/_internal/core/models/test_configurations.py b/src/tests/_internal/core/models/test_configurations.py
index 79007fe19..91d54fc78 100644
--- a/src/tests/_internal/core/models/test_configurations.py
+++ b/src/tests/_internal/core/models/test_configurations.py
@@ -21,15 +21,32 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None):
             conf["scaling"] = scaling
         return conf
 
-    assert parse_run_configuration(test_conf(1)).replicas == Range(min=1, max=1)
-    assert parse_run_configuration(test_conf("2")).replicas == Range(min=2, max=2)
-    assert parse_run_configuration(test_conf("3..3")).replicas == Range(min=3, max=3)
+    # assert parse_run_configuration(test_conf(1)).replicas == Range(min=1, max=1)
+    # assert parse_run_configuration(test_conf("2")).replicas == Range(min=2, max=2)
+    # assert parse_run_configuration(test_conf("3..3")).replicas == Range(min=3, max=3)
+
+    config = parse_run_configuration(test_conf(1))
+    assert len(config.replicas) == 1
+    assert config.replicas[0].name == "default"
+    assert config.replicas[0].count == Range(min=1, max=1)
+
+    config = parse_run_configuration(test_conf("2"))
+    assert len(config.replicas) == 1
+    assert config.replicas[0].name == "default"
+    assert config.replicas[0].count == Range(min=2, max=2)
+
+    config = parse_run_configuration(test_conf("3..3"))
+    assert len(config.replicas) == 1
+    assert config.replicas[0].name == "default"
+    assert config.replicas[0].count == Range(min=3, max=3)
+
     with pytest.raises(
         ConfigurationError,
-        match="When you set `replicas` to a range, ensure to specify `scaling`",
+        match="When you set `count` to a range, ensure to specify `scaling`",
     ):
         parse_run_configuration(test_conf("0..10"))
-    assert parse_run_configuration(
+
+    config = parse_run_configuration(
         test_conf(
             "0..10",
             {
@@ -37,10 +54,16 @@ def test_conf(replicas: Any, scaling: Optional[Any] = None):
                 "target": 10,
             },
         )
-    ).replicas == Range(min=0, max=10)
+    )
+    assert len(config.replicas) == 1
+    assert config.replicas[0].name == "default"
+    assert config.replicas[0].count == Range(min=0, max=10)
+    assert config.replicas[0].scaling is not None
+    assert config.replicas[0].scaling.metric == "rps"
+
     with pytest.raises(
         ConfigurationError,
-        match="When you set `replicas` to a range, ensure to specify `scaling`",
+        match="When you set `count` to a range, ensure to specify `scaling`",
     ):
         parse_run_configuration(
             test_conf(
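
Note: the updated tests pin down the normalization: a scalar `replicas` value becomes a single group named "default". A usage sketch in the spirit of those tests (assumes the test module's imports; the service dict is a minimal assumed configuration):

    conf = {
        "type": "service",
        "port": 8000,
        "commands": ["python -m http.server 8000"],
        "replicas": 2,
    }
    config = parse_run_configuration(conf)
    assert len(config.replicas) == 1
    assert config.replicas[0].name == "default"
    assert config.replicas[0].count == Range(min=2, max=2)
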
From 1ec1d6d13abedf26af291ab6eb764cba9be66ab1 Mon Sep 17 00:00:00 2001
From: Bihan Rana
Date: Wed, 24 Dec 2025 18:54:05 +0545
Subject: [PATCH 5/5] Resolve review comments on probes and rate limits

---
 .../_internal/core/models/configurations.py   | 44 +++++--------------
 1 file changed, 10 insertions(+), 34 deletions(-)

diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index a252acd89..a3e0124b5 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -737,14 +737,6 @@ class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
         Optional[ScalingSpec],
         Field(description="The auto-scaling rules. Required if `count` is set to a range"),
     ] = None
-    probes: Annotated[
-        list[ProbeConfig],
-        Field(description="List of probes used to determine job health for this replica group"),
-    ] = []
-    rate_limits: Annotated[
-        list[RateLimit],
-        Field(description="Rate limiting rules for this replica group"),
-    ] = []
     # TODO: Extract to ConfigurationWithResourcesParams mixin
     resources: Annotated[
         ResourcesSpec,
@@ -782,23 +774,6 @@ def validate_scaling(cls, values):
             raise ValueError("To use `scaling`, `count` must be set to a range.")
         return values
 
-    @validator("rate_limits")
-    def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
-        counts = Counter(limit.prefix for limit in v)
-        duplicates = [prefix for prefix, count in counts.items() if count > 1]
-        if duplicates:
-            raise ValueError(
-                f"Prefixes {duplicates} are used more than once."
-                " Each rate limit should have a unique path prefix"
-            )
-        return v
-
-    @validator("probes")
-    def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
-        if has_duplicates(v):
-            raise ValueError("Probes must be unique")
-        return v
-
 
 class ServiceConfigurationParams(CoreModel):
     port: Annotated[
@@ -887,14 +862,18 @@ def validate_gateway(
             raise ValueError(
                 "The `gateway` property must be a string or boolean `false`, not boolean `true`"
             )
        return v
+
     @validator("replicas")
     def convert_replicas(cls, v: Range[int]) -> Range[int]:
-        if v.max is None:
-            raise ValueError("The maximum number of replicas is required")
-        if v.min is None:
-            v.min = 0
-        if v.min < 0:
-            raise ValueError("The minimum number of replicas must be greater than or equal to 0")
+        if isinstance(v, Range):
+            if v.max is None:
+                raise ValueError("The maximum number of replicas is required")
+            if v.min is None:
+                v.min = 0
+            if v.min < 0:
+                raise ValueError(
+                    "The minimum number of replicas must be greater than or equal to 0"
+                )
         return v
 
     @root_validator()
@@ -904,7 +883,6 @@ def normalize_replicas(cls, values):
 
         if all(isinstance(item, ReplicaGroup) for item in replicas):
             return values
-        # Handle backward compatibility: convert old-style replica config to groups
         old_replicas = values.get("replicas")
         if isinstance(old_replicas, Range):
             replica_count = old_replicas
@@ -917,8 +895,6 @@ def normalize_replicas(cls, values):
                     name="default",
                     commands=values.get("commands", []),
                     resources=values.get("resources"),
                     scaling=values.get("scaling"),
-                    probes=values.get("probes", []),
-                    rate_limits=values.get("rate_limits"),
                 )
             ]
             return values
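
Note: after this patch, probes and rate limits are configured only at the service level and apply to every group. An assumed YAML-equivalent dict for a two-group service (field names follow the patch series; the commands are illustrative):

    conf = {
        "type": "service",
        "port": 8000,
        "replicas": [
            {"name": "cpu", "count": 1, "commands": ["python serve.py --device cpu"]},
            {
                # unnamed group: defaults to "replica1" by position
                "count": "0..4",
                "commands": ["python serve.py --device cuda"],
                "scaling": {"metric": "rps", "target": 10},
            },
        ],
    }
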