Skip to content

Commit 5abbcad

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Resolve Merge Conflict & Rename replica_groups to replicas
1 parent 22c1410 commit 5abbcad

File tree

8 files changed

+112
-114
lines changed

8 files changed

+112
-114
lines changed

src/dstack/_internal/cli/utils/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,10 @@ def get_runs_table(
383383

384384
# Replica Group Changes: Build mapping from replica group names to indices
385385
group_name_to_index: Dict[str, int] = {}
386-
# Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration)
387-
replica_groups = getattr(run.run_spec.configuration, "replica_groups", None)
388-
if replica_groups:
389-
for idx, group in enumerate(replica_groups):
386+
# Replica Group Changes: Check if replicas attribute exists (only available for ServiceConfiguration)
387+
replicas = getattr(run.run_spec.configuration, "replicas", None)
388+
if replicas:
389+
for idx, group in enumerate(replicas):
390390
group_name_to_index[group.name] = idx
391391

392392
run_row: Dict[Union[str, int], Any] = {

src/dstack/_internal/core/models/configurations.py

Lines changed: 71 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -612,8 +612,8 @@ class ConfigurationWithCommandsParams(CoreModel):
612612

613613
@root_validator
614614
def check_image_or_commands_present(cls, values):
615-
# If replica_groups is present, skip validation - commands come from replica groups
616-
replica_groups = values.get("replica_groups")
615+
# If replicas is present, skip validation - commands come from replica groups
616+
replica_groups = values.get("replicas")
617617
if replica_groups:
618618
return values
619619

@@ -838,25 +838,25 @@ class ServiceConfigurationParams(CoreModel):
838838
SERVICE_HTTPS_DEFAULT
839839
)
840840
auth: Annotated[bool, Field(description="Enable the authorization")] = True
841-
replicas: Annotated[
842-
Range[int],
843-
Field(
844-
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
845-
"If it's a range, the `scaling` property is required"
846-
),
847-
] = Range[int](min=1, max=1)
848-
scaling: Annotated[
849-
Optional[ScalingSpec],
850-
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
851-
] = None
841+
# replicas: Annotated[
842+
# Range[int],
843+
# Field(
844+
# description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
845+
# "If it's a range, the `scaling` property is required"
846+
# ),
847+
# ] = Range[int](min=1, max=1)
848+
# scaling: Annotated[
849+
# Optional[ScalingSpec],
850+
# Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
851+
# ] = None
852852
rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
853853
probes: Annotated[
854854
list[ProbeConfig],
855855
Field(description="List of probes used to determine job health"),
856856
] = []
857857

858-
replica_groups: Annotated[
859-
Optional[List[ReplicaGroup]],
858+
replicas: Annotated[
859+
Optional[Union[Range[int], List[ReplicaGroup], int, str]],
860860
Field(
861861
description=(
862862
"List of replica groups. Each group defines replicas with shared configuration "
@@ -882,15 +882,15 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
882882
return OpenAIChatModel(type="chat", name=v, format="openai")
883883
return v
884884

885-
@validator("replicas")
886-
def convert_replicas(cls, v: Range[int]) -> Range[int]:
887-
if v.max is None:
888-
raise ValueError("The maximum number of replicas is required")
889-
if v.min is None:
890-
v.min = 0
891-
if v.min < 0:
892-
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
893-
return v
885+
# @validator("replicas")
886+
# def convert_replicas(cls, v: Range[int]) -> Range[int]:
887+
# if v.max is None:
888+
# raise ValueError("The maximum number of replicas is required")
889+
# if v.min is None:
890+
# v.min = 0
891+
# if v.min < 0:
892+
# raise ValueError("The minimum number of replicas must be greater than or equal to 0")
893+
# return v
894894

895895
@validator("gateway")
896896
def validate_gateway(
@@ -902,53 +902,43 @@ def validate_gateway(
902902
)
903903
return v
904904

905-
@root_validator()
906-
def validate_scaling(cls, values):
907-
replica_groups = values.get("replica_groups")
908-
# If replica_groups are set, we don't need to validate scaling.
909-
# Each replica group has its own scaling.
910-
if replica_groups:
911-
return values
912-
913-
scaling = values.get("scaling")
914-
replicas = values.get("replicas")
915-
if replicas and replicas.min != replicas.max and not scaling:
916-
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
917-
if replicas and replicas.min == replicas.max and scaling:
918-
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
919-
return values
905+
# @root_validator()
906+
# def validate_scaling(cls, values):
907+
# replica_groups = values.get("replica_groups")
908+
# # If replica_groups are set, we don't need to validate scaling.
909+
# # Each replica group has its own scaling.
910+
# if replica_groups:
911+
# return values
912+
913+
# scaling = values.get("scaling")
914+
# replicas = values.get("replicas")
915+
# if replicas and replicas.min != replicas.max and not scaling:
916+
# raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
917+
# if replicas and replicas.min == replicas.max and scaling:
918+
# raise ValueError("To use `scaling`, `replicas` must be set to a range.")
919+
# return values
920920

921921
@root_validator()
922-
def normalize_to_replica_groups(cls, values):
923-
replica_groups = values.get("replica_groups")
924-
if replica_groups:
925-
return values
926-
927-
# TEMP: prove we’re here and see the inputs
928-
print(
929-
"[normalize_to_replica_groups]",
930-
"commands:",
931-
values.get("commands"),
932-
"replicas:",
933-
values.get("replicas"),
934-
"resources:",
935-
values.get("resources"),
936-
"scaling:",
937-
values.get("scaling"),
938-
"probes:",
939-
values.get("probes"),
940-
"rate_limits:",
941-
values.get("rate_limits"),
942-
)
943-
# If replica_groups is not set, we need to normalize the configuration to replica groups.
944-
values["replica_groups"] = [
922+
def normalize_replicas(cls, values):
923+
replicas = values.get("replicas")
924+
if isinstance(replicas, list) and len(replicas) > 0:
925+
if all(isinstance(item, ReplicaGroup) for item in replicas):
926+
return values
927+
928+
# Handle backward compatibility: convert old-style replica config to groups
929+
old_replicas = values.get("replicas")
930+
if isinstance(old_replicas, Range):
931+
replica_count = old_replicas
932+
else:
933+
replica_count = Range[int](min=1, max=1)
934+
values["replicas"] = [
945935
ReplicaGroup(
946936
name="default",
947-
replicas=values.get("replicas"),
948-
commands=values.get("commands"),
937+
replicas=replica_count,
938+
commands=values.get("commands", []),
949939
resources=values.get("resources"),
950940
scaling=values.get("scaling"),
951-
probes=values.get("probes"),
941+
probes=values.get("probes", []),
952942
rate_limits=values.get("rate_limits"),
953943
)
954944
]
@@ -975,22 +965,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
975965
raise ValueError("Probes must be unique")
976966
return v
977967

978-
@validator("replica_groups")
979-
def validate_replica_groups(
980-
cls, v: Optional[List[ReplicaGroup]]
981-
) -> Optional[List[ReplicaGroup]]:
968+
@validator("replicas")
969+
def validate_replicas(cls, v: Optional[List[ReplicaGroup]]) -> Optional[List[ReplicaGroup]]:
982970
if v is None:
983971
return v
984-
if not v:
985-
raise ValueError("`replica_groups` cannot be an empty list")
986-
# Check for duplicate names
987-
names = [group.name for group in v]
988-
if len(names) != len(set(names)):
989-
duplicates = [name for name in set(names) if names.count(name) > 1]
990-
raise ValueError(
991-
f"Duplicate replica group names found: {duplicates}. "
992-
"Each replica group must have a unique name."
993-
)
972+
if isinstance(v, (Range, int, str)):
973+
return v
974+
975+
if isinstance(v, list):
976+
if not v:
977+
raise ValueError("`replicas` cannot be an empty list")
978+
# Check for duplicate names
979+
names = [group.name for group in v]
980+
if len(names) != len(set(names)):
981+
duplicates = [name for name in set(names) if names.count(name) > 1]
982+
raise ValueError(
983+
f"Duplicate replica group names found: {duplicates}. "
984+
"Each replica group must have a unique name."
985+
)
994986
return v
995987

996988

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
198198

199199
# run_model.desired_replica_count = 1
200200
if run.run_spec.configuration.type == "service":
201-
run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
201+
run_model.desired_replica_count = sum(
202+
group.replicas.min or 0 for group in run.run_spec.configuration.replicas
203+
)
202204
await update_service_desired_replica_count(
203205
session,
204206
run_model,
@@ -211,15 +213,14 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
211213
# stay zero scaled
212214
return
213215

214-
215216
# Per group scaling because single replica is also normalized to replica groups.
216-
replica_groups = run.run_spec.configuration.replica_groups or []
217+
replicas = run.run_spec.configuration.replicas or []
217218
counts = (
218219
json.loads(run_model.desired_replica_counts)
219220
if run_model.desired_replica_counts
220221
else {}
221222
)
222-
await scale_run_replicas_per_group(session, run_model, replica_groups, counts)
223+
await scale_run_replicas_per_group(session, run_model, replicas, counts)
223224
else:
224225
run_model.desired_replica_count = 1
225226
await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)
@@ -460,24 +461,24 @@ async def _handle_run_replicas(
460461
# FIXME: should only include scaling events, not retries and deployments
461462
last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
462463
)
463-
replica_groups = run_spec.configuration.replica_groups or []
464-
if replica_groups:
464+
replicas = run_spec.configuration.replicas or []
465+
if replicas:
465466
counts = (
466467
json.loads(run_model.desired_replica_counts)
467468
if run_model.desired_replica_counts
468469
else {}
469470
)
470-
await scale_run_replicas_per_group(session, run_model, replica_groups, counts)
471+
await scale_run_replicas_per_group(session, run_model, replicas, counts)
471472

472473
# Handle per-group rolling deployment
473474
await _update_jobs_to_new_deployment_in_place(
474475
session=session,
475476
run_model=run_model,
476477
run_spec=run_spec,
477-
replica_groups=replica_groups,
478+
replicas=replicas,
478479
)
479480
# Process per-group rolling deployment
480-
for group in replica_groups:
481+
for group in replicas:
481482
await _handle_rolling_deployment_for_group(
482483
session=session,
483484
run_model=run_model,
@@ -554,7 +555,7 @@ async def _update_jobs_to_new_deployment_in_place(
554555
session: AsyncSession,
555556
run_model: RunModel,
556557
run_spec: RunSpec,
557-
replica_groups: Optional[List] = None,
558+
replicas: Optional[List] = None,
558559
) -> None:
559560
"""
560561
Bump deployment_num for jobs that do not require redeployment.
@@ -575,11 +576,11 @@ async def _update_jobs_to_new_deployment_in_place(
575576
replica_group_name = None
576577
group_run_spec = base_run_spec
577578

578-
if replica_groups:
579+
if replicas:
579580
job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
580581
replica_group_name = job_spec.replica_group or "default"
581582

582-
for group in replica_groups:
583+
for group in replicas:
583584
if group.name == replica_group_name:
584585
group_run_spec = create_group_run_spec(base_run_spec, group)
585586
break

src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# revision identifiers, used by Alembic.
1313
revision = "706e0acc3a7d"
14-
down_revision = "22d74df9897e"
14+
down_revision = "903c91e24634"
1515
branch_labels = None
1616
depends_on = None
1717

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ async def submit_run(
520520

521521
global_replica_num = 0 # Global counter across all groups for unique replica_num
522522

523-
for replica_group in service_config.replica_groups:
523+
for replica_group in service_config.replicas:
524524
if run_spec.merged_profile.schedule is not None:
525525
group_initial_replicas = 0
526526
else:

src/dstack/_internal/server/services/runs/replicas.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ async def retry_run_replica_jobs(
3636
replica_group = None
3737

3838
# Find matching replica group
39-
if replica_group_name and base_run_spec.configuration.replica_groups:
40-
for group in base_run_spec.configuration.replica_groups:
39+
if replica_group_name and base_run_spec.configuration.replicas:
40+
for group in base_run_spec.configuration.replicas:
4141
if group.name == replica_group_name:
4242
replica_group = group
4343
break
@@ -245,14 +245,14 @@ async def _scale_up_replicas(
245245
async def scale_run_replicas_per_group(
246246
session: AsyncSession,
247247
run_model: RunModel,
248-
replica_groups: List[ReplicaGroup],
248+
replicas: List[ReplicaGroup],
249249
desired_replica_counts: Dict[str, int],
250250
) -> None:
251251
"""Scale each replica group independently"""
252-
if not replica_groups:
252+
if not replicas:
253253
return
254254

255-
for group in replica_groups:
255+
for group in replicas:
256256
group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0)
257257

258258
# Build replica lists filtered by this group

src/dstack/_internal/server/services/runs/spec.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
"env",
5151
"shell",
5252
"commands",
53-
"replica_groups",
5453
],
5554
}
5655

@@ -89,7 +88,10 @@ def validate_run_spec_and_set_defaults(
8988
f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s"
9089
)
9190
if isinstance(run_spec.configuration, ServiceConfiguration):
92-
if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0:
91+
# Check if any group has min=0
92+
if run_spec.merged_profile.schedule and any(
93+
group.replicas.min == 0 for group in run_spec.configuration.replicas
94+
):
9395
raise ServerClientError(
9496
"Scheduled services with autoscaling to zero are not supported"
9597
)
@@ -150,11 +152,10 @@ def get_nodes_required_num(run_spec: RunSpec) -> int:
150152
nodes_required_num = 1
151153
if run_spec.configuration.type == "task":
152154
nodes_required_num = run_spec.configuration.nodes
153-
elif (
154-
run_spec.configuration.type == "service"
155-
and run_spec.configuration.replicas.min is not None
156-
):
157-
nodes_required_num = run_spec.configuration.replicas.min
155+
elif run_spec.configuration.type == "service":
156+
nodes_required_num = sum(
157+
group.replicas.min or 0 for group in run_spec.configuration.replicas
158+
)
158159
return nodes_required_num
159160

160161

0 commit comments

Comments
 (0)