Skip to content

Commit 86139c5

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Add replica groups in dstack-service
add_replica_groups_model Replica Groups AutoScaling Rolling deployment and UI Replica Groups implementation clean up
1 parent 168c631 commit 86139c5

File tree

10 files changed

+728
-55
lines changed

10 files changed

+728
-55
lines changed

src/dstack/_internal/cli/utils/run.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,16 +281,38 @@ def _format_job_name(
281281
show_deployment_num: bool,
282282
show_replica: bool,
283283
show_job: bool,
284+
group_index: Optional[int] = None,
285+
last_shown_group_index: Optional[int] = None,
284286
) -> str:
285287
name_parts = []
288+
prefix = ""
286289
if show_replica:
287-
name_parts.append(f"replica={job.job_spec.replica_num}")
290+
# Show group information if replica groups are used
291+
if group_index is not None:
292+
# Show group=X replica=Y when group changes, or just replica=Y when same group
293+
if group_index != last_shown_group_index:
294+
# First job in group: use 3 spaces indent
295+
prefix = " "
296+
name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
297+
else:
298+
# Subsequent job in same group: align "replica=" with first job's "replica="
299+
# Calculate padding: width of " group={last_shown_group_index} "
300+
padding_width = 3 + len(f"group={last_shown_group_index}") + 1
301+
prefix = " " * padding_width
302+
name_parts.append(f"replica={job.job_spec.replica_num}")
303+
else:
304+
# Legacy behavior: no replica groups
305+
prefix = " "
306+
name_parts.append(f"replica={job.job_spec.replica_num}")
307+
else:
308+
prefix = " "
309+
288310
if show_job:
289311
name_parts.append(f"job={job.job_spec.job_num}")
290312
name_suffix = (
291313
f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
292314
)
293-
name_value = " " + (" ".join(name_parts) if name_parts else "")
315+
name_value = prefix + (" ".join(name_parts) if name_parts else "")
294316
name_value += name_suffix
295317
return name_value
296318

@@ -359,6 +381,14 @@ def get_runs_table(
359381
)
360382
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num
361383

384+
# Replica Group Changes: Build mapping from replica group names to indices
385+
group_name_to_index: Dict[str, int] = {}
386+
# Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration)
387+
replica_groups = getattr(run.run_spec.configuration, "replica_groups", None)
388+
if replica_groups:
389+
for idx, group in enumerate(replica_groups):
390+
group_name_to_index[group.name] = idx
391+
362392
run_row: Dict[Union[str, int], Any] = {
363393
"NAME": _format_run_name(run, show_deployment_num),
364394
"SUBMITTED": format_date(run.submitted_at),
@@ -372,13 +402,35 @@ def get_runs_table(
372402
if not merge_job_rows:
373403
add_row_from_dict(table, run_row)
374404

375-
for job in run.jobs:
405+
# Sort jobs by group index first, then by replica_num within each group
406+
def get_job_sort_key(job: Job) -> tuple:
407+
group_index = None
408+
if group_name_to_index and job.job_spec.replica_group:
409+
group_index = group_name_to_index.get(job.job_spec.replica_group)
410+
# Use a large number for jobs without groups to put them at the end
411+
return (group_index if group_index is not None else 999999, job.job_spec.replica_num)
412+
413+
sorted_jobs = sorted(run.jobs, key=get_job_sort_key)
414+
415+
last_shown_group_index: Optional[int] = None
416+
for job in sorted_jobs:
376417
latest_job_submission = job.job_submissions[-1]
377418
status_formatted = _format_job_submission_status(latest_job_submission, verbose)
378419

420+
# Get group index for this job
421+
group_index: Optional[int] = None
422+
if group_name_to_index and job.job_spec.replica_group:
423+
group_index = group_name_to_index.get(job.job_spec.replica_group)
424+
379425
job_row: Dict[Union[str, int], Any] = {
380426
"NAME": _format_job_name(
381-
job, latest_job_submission, show_deployment_num, show_replica, show_job
427+
job,
428+
latest_job_submission,
429+
show_deployment_num,
430+
show_replica,
431+
show_job,
432+
group_index=group_index,
433+
last_shown_group_index=last_shown_group_index,
382434
),
383435
"STATUS": status_formatted,
384436
"PROBES": _format_job_probes(
@@ -390,6 +442,9 @@ def get_runs_table(
390442
"GPU": "-",
391443
"PRICE": "-",
392444
}
445+
# Update last shown group index for next iteration
446+
if group_index is not None:
447+
last_shown_group_index = group_index
393448
jpd = latest_job_submission.job_provisioning_data
394449
if jpd is not None:
395450
shared_offer: Optional[InstanceOfferWithAvailability] = None

src/dstack/_internal/core/models/configurations.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel):
612612

613613
@root_validator
614614
def check_image_or_commands_present(cls, values):
615+
# If replica_groups is present, skip validation - commands come from replica groups
616+
replica_groups = values.get("replica_groups")
617+
if replica_groups:
618+
return values
619+
615620
if not values.get("commands") and not values.get("image"):
616621
raise ValueError("Either `commands` or `image` must be set")
617622
return values
@@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]):
714719
)
715720

716721

722+
class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
723+
name: Annotated[
724+
str,
725+
Field(description="The name of the replica group"),
726+
]
727+
replicas: Annotated[
728+
Range[int],
729+
Field(
730+
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
731+
"If it's a range, the `scaling` property is required"
732+
),
733+
]
734+
scaling: Annotated[
735+
Optional[ScalingSpec],
736+
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
737+
] = None
738+
probes: Annotated[
739+
list[ProbeConfig],
740+
Field(description="List of probes used to determine job health for this replica group"),
741+
] = []
742+
rate_limits: Annotated[
743+
list[RateLimit],
744+
Field(description="Rate limiting rules for this replica group"),
745+
] = []
746+
# TODO: Extract to ConfigurationWithResourcesParams mixin
747+
resources: Annotated[
748+
ResourcesSpec,
749+
Field(description="The resources requirements for replicas in this group"),
750+
] = ResourcesSpec()
751+
752+
@validator("replicas")
753+
def convert_replicas(cls, v: Range[int]) -> Range[int]:
754+
if v.max is None:
755+
raise ValueError("The maximum number of replicas is required")
756+
if v.min is None:
757+
v.min = 0
758+
if v.min < 0:
759+
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
760+
return v
761+
762+
@root_validator()
763+
def override_commands_validation(cls, values):
764+
"""
765+
Override parent validator from ConfigurationWithCommandsParams.
766+
ReplicaGroup always requires commands (no image option).
767+
"""
768+
commands = values.get("commands", [])
769+
if not commands:
770+
raise ValueError("`commands` must be set for replica groups")
771+
return values
772+
773+
@root_validator()
774+
def validate_scaling(cls, values):
775+
scaling = values.get("scaling")
776+
replicas = values.get("replicas")
777+
if replicas and replicas.min != replicas.max and not scaling:
778+
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
779+
if replicas and replicas.min == replicas.max and scaling:
780+
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
781+
return values
782+
783+
@validator("rate_limits")
784+
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
785+
counts = Counter(limit.prefix for limit in v)
786+
duplicates = [prefix for prefix, count in counts.items() if count > 1]
787+
if duplicates:
788+
raise ValueError(
789+
f"Prefixes {duplicates} are used more than once."
790+
" Each rate limit should have a unique path prefix"
791+
)
792+
return v
793+
794+
@validator("probes")
795+
def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
796+
if has_duplicates(v):
797+
raise ValueError("Probes must be unique")
798+
return v
799+
800+
717801
class ServiceConfigurationParams(CoreModel):
718802
port: Annotated[
719803
# NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
@@ -771,6 +855,19 @@ class ServiceConfigurationParams(CoreModel):
771855
Field(description="List of probes used to determine job health"),
772856
] = []
773857

858+
replica_groups: Annotated[
859+
Optional[List[ReplicaGroup]],
860+
Field(
861+
description=(
862+
"List of replica groups. Each group defines replicas with shared configuration "
863+
"(commands, port, resources, scaling, probes, rate_limits). "
864+
"When specified, the top-level `replicas`, `commands`, `port`, `resources`, "
865+
"`scaling`, `probes`, and `rate_limits` are ignored. "
866+
"Each replica group must have a unique name."
867+
)
868+
),
869+
] = None
870+
774871
@validator("port")
775872
def convert_port(cls, v) -> PortMapping:
776873
if isinstance(v, int):
@@ -807,6 +904,12 @@ def validate_gateway(
807904

808905
@root_validator()
809906
def validate_scaling(cls, values):
907+
replica_groups = values.get("replica_groups")
908+
# If replica_groups are set, we don't need to validate scaling.
909+
# Each replica group has its own scaling.
910+
if replica_groups:
911+
return values
912+
810913
scaling = values.get("scaling")
811914
replicas = values.get("replicas")
812915
if replicas and replicas.min != replicas.max and not scaling:
@@ -815,6 +918,42 @@ def validate_scaling(cls, values):
815918
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
816919
return values
817920

921+
@root_validator()
922+
def normalize_to_replica_groups(cls, values):
923+
replica_groups = values.get("replica_groups")
924+
if replica_groups:
925+
return values
926+
927+
# TEMP: prove we’re here and see the inputs
928+
print(
929+
"[normalize_to_replica_groups]",
930+
"commands:",
931+
values.get("commands"),
932+
"replicas:",
933+
values.get("replicas"),
934+
"resources:",
935+
values.get("resources"),
936+
"scaling:",
937+
values.get("scaling"),
938+
"probes:",
939+
values.get("probes"),
940+
"rate_limits:",
941+
values.get("rate_limits"),
942+
)
943+
# If replica_groups is not set, we need to normalize the configuration to replica groups.
944+
values["replica_groups"] = [
945+
ReplicaGroup(
946+
name="default",
947+
replicas=values.get("replicas"),
948+
commands=values.get("commands"),
949+
resources=values.get("resources"),
950+
scaling=values.get("scaling"),
951+
probes=values.get("probes"),
952+
rate_limits=values.get("rate_limits"),
953+
)
954+
]
955+
return values
956+
818957
@validator("rate_limits")
819958
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
820959
counts = Counter(limit.prefix for limit in v)
@@ -836,6 +975,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
836975
raise ValueError("Probes must be unique")
837976
return v
838977

978+
@validator("replica_groups")
979+
def validate_replica_groups(
980+
cls, v: Optional[List[ReplicaGroup]]
981+
) -> Optional[List[ReplicaGroup]]:
982+
if v is None:
983+
return v
984+
if not v:
985+
raise ValueError("`replica_groups` cannot be an empty list")
986+
# Check for duplicate names
987+
names = [group.name for group in v]
988+
if len(names) != len(set(names)):
989+
duplicates = [name for name in set(names) if names.count(name) > 1]
990+
raise ValueError(
991+
f"Duplicate replica group names found: {duplicates}. "
992+
"Each replica group must have a unique name."
993+
)
994+
return v
995+
839996

840997
class ServiceConfigurationConfig(
841998
ProfileParamsConfig,

src/dstack/_internal/core/models/runs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class JobSpec(CoreModel):
253253
job_num: int
254254
job_name: str
255255
jobs_per_replica: int = 1 # default value for backward compatibility
256+
replica_group: Optional[str] = "default"
256257
app_specs: Optional[List[AppSpec]]
257258
user: Optional[UnixUser] = None # default value for backward compatibility
258259
commands: List[str]

0 commit comments

Comments
 (0)