@@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel):
612612
@root_validator
def check_image_or_commands_present(cls, values):
    """
    Ensure the configuration states what to run.

    The check is skipped when `values` holds a non-empty `replica_groups`
    entry, because commands then come from the replica groups.

    Returns:
        The unchanged `values` dict.

    Raises:
        ValueError: if both `commands` and `image` are missing or empty.
    """
    if values.get("replica_groups"):
        # Commands come from the replica groups; nothing to verify here.
        return values

    if values.get("commands") or values.get("image"):
        return values
    raise ValueError("Either `commands` or `image` must be set")
@@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]):
714719 )
715720
716721
class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
    # A named group of replicas. On top of whatever it inherits from
    # `ConfigurationWithCommandsParams`, a group declares its own replica count,
    # auto-scaling rules, probes, rate limits and resource requirements.
    name: Annotated[
        str,
        Field(description="The name of the replica group"),
    ]
    replicas: Annotated[
        Range[int],
        Field(
            description=(
                "The number of replicas. "
                "Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
                "If it's a range, the `scaling` property is required"
            )
        ),
    ]
    scaling: Annotated[
        Optional[ScalingSpec],
        Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
    ] = None
    probes: Annotated[
        list[ProbeConfig],
        Field(description="List of probes used to determine job health for this replica group"),
    ] = []
    rate_limits: Annotated[
        list[RateLimit],
        Field(description="Rate limiting rules for this replica group"),
    ] = []
    # TODO: Extract to ConfigurationWithResourcesParams mixin
    resources: Annotated[
        ResourcesSpec,
        Field(description="The resources requirements for replicas in this group"),
    ] = ResourcesSpec()

    @validator("replicas")
    def convert_replicas(cls, v: Range[int]) -> Range[int]:
        """
        Normalize the replica range.

        An upper bound is mandatory; a missing lower bound is replaced with 0
        (the range object is updated in place); a negative lower bound is rejected.

        Raises:
            ValueError: if `max` is missing or `min` is negative.
        """
        if v.max is None:
            raise ValueError("The maximum number of replicas is required")
        if v.min is None:
            # An open lower bound means "down to zero replicas".
            v.min = 0
        elif v.min < 0:
            raise ValueError("The minimum number of replicas must be greater than or equal to 0")
        return v

    @root_validator()
    def override_commands_validation(cls, values):
        """
        Require a non-empty `commands` value for every replica group.

        NOTE(review): this validator does not share a name with the parent's
        `check_image_or_commands_present`; confirm whether the parent check is
        really replaced or simply runs in addition to this one.

        Raises:
            ValueError: if `commands` is missing or empty.
        """
        if values.get("commands", []):
            return values
        raise ValueError("`commands` must be set for replica groups")

    @root_validator()
    def validate_scaling(cls, values):
        """
        Keep `replicas` and `scaling` consistent with each other.

        A range of replicas (`min != max`) needs `scaling`; a fixed number of
        replicas (`min == max`) must not have it. Nothing is checked when
        `replicas` is absent from `values`.

        Raises:
            ValueError: if the two settings contradict each other.
        """
        replica_range = values.get("replicas")
        if not replica_range:
            return values

        scaling_spec = values.get("scaling")
        is_range = replica_range.min != replica_range.max
        if is_range and not scaling_spec:
            raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
        if not is_range and scaling_spec:
            raise ValueError("To use `scaling`, `replicas` must be set to a range.")
        return values

    @validator("rate_limits")
    def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
        """
        Reject rate limits that share a path prefix.

        Raises:
            ValueError: listing every prefix that occurs more than once.
        """
        occurrences = Counter(limit.prefix for limit in v)
        repeated = [prefix for prefix, seen in occurrences.items() if seen > 1]
        if not repeated:
            return v
        raise ValueError(
            f"Prefixes {repeated} are used more than once."
            " Each rate limit should have a unique path prefix"
        )

    @validator("probes")
    def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
        """
        Reject a probe list for which `has_duplicates` reports duplicates.

        Raises:
            ValueError: if the probes are not unique.
        """
        if not has_duplicates(v):
            return v
        raise ValueError("Probes must be unique")
799+
800+
717801class ServiceConfigurationParams (CoreModel ):
718802 port : Annotated [
719803 # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
@@ -771,6 +855,19 @@ class ServiceConfigurationParams(CoreModel):
771855 Field (description = "List of probes used to determine job health" ),
772856 ] = []
773857
# Optional list of replica groups; defaults to `None`.
# NOTE(review): the description lists `port` among the per-group settings, but the
# `ReplicaGroup` class declares no `port` field of its own - confirm it is inherited
# or adjust the wording.
replica_groups: Annotated[
    Optional[List[ReplicaGroup]],
    Field(
        description=(
            "List of replica groups. Each group defines replicas "
            "with shared configuration (commands, port, resources, "
            "scaling, probes, rate_limits). When specified, the top-level "
            "`replicas`, `commands`, `port`, `resources`, `scaling`, `probes`, "
            "and `rate_limits` are ignored. Each replica group must have "
            "a unique name."
        )
    ),
] = None
870+
774871 @validator ("port" )
775872 def convert_port (cls , v ) -> PortMapping :
776873 if isinstance (v , int ):
@@ -807,6 +904,12 @@ def validate_gateway(
807904
808905 @root_validator ()
809906 def validate_scaling (cls , values ):
907+ replica_groups = values .get ("replica_groups" )
908+ # If replica_groups are set, we don't need to validate scaling.
909+ # Each replica group has its own scaling.
910+ if replica_groups :
911+ return values
912+
810913 scaling = values .get ("scaling" )
811914 replicas = values .get ("replicas" )
812915 if replicas and replicas .min != replicas .max and not scaling :
@@ -815,6 +918,42 @@ def validate_scaling(cls, values):
815918 raise ValueError ("To use `scaling`, `replicas` must be set to a range." )
816919 return values
817920
@root_validator()
def normalize_to_replica_groups(cls, values):
    """
    Make sure `replica_groups` is populated.

    If `values` already holds a non-empty `replica_groups`, it is returned
    untouched. Otherwise a single group named "default" is built from the
    top-level `replicas`, `commands`, `resources`, `scaling`, `probes` and
    `rate_limits` values and stored under `replica_groups`.

    Returns:
        The (possibly updated) `values` dict.
    """
    if values.get("replica_groups"):
        return values

    # NOTE(review): keys missing from `values` are forwarded as `None`, and
    # `ReplicaGroup` rejects empty `commands` - confirm that configurations
    # relying on `image` alone are still expected to validate.
    values["replica_groups"] = [
        ReplicaGroup(
            name="default",
            replicas=values.get("replicas"),
            commands=values.get("commands"),
            resources=values.get("resources"),
            scaling=values.get("scaling"),
            probes=values.get("probes"),
            rate_limits=values.get("rate_limits"),
        )
    ]
    return values
956+
818957 @validator ("rate_limits" )
819958 def validate_rate_limits (cls , v : list [RateLimit ]) -> list [RateLimit ]:
820959 counts = Counter (limit .prefix for limit in v )
@@ -836,6 +975,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
836975 raise ValueError ("Probes must be unique" )
837976 return v
838977
@validator("replica_groups")
def validate_replica_groups(
    cls, v: Optional[List[ReplicaGroup]]
) -> Optional[List[ReplicaGroup]]:
    """
    Validate the list of replica groups.

    `None` is passed through unchanged. An empty list is rejected, and so is
    a list in which two or more groups share a name.

    Returns:
        The validated value.

    Raises:
        ValueError: if the list is empty or contains duplicate group names.
    """
    if v is None:
        return v
    if not v:
        raise ValueError("`replica_groups` cannot be an empty list")

    # A single counting pass; the counter keeps first-seen order, so the
    # duplicates are reported deterministically.
    name_counts = Counter(group.name for group in v)
    duplicates = [name for name, count in name_counts.items() if count > 1]
    if duplicates:
        raise ValueError(
            f"Duplicate replica group names found: {duplicates}. "
            "Each replica group must have a unique name."
        )
    return v
995+
839996
840997class ServiceConfigurationConfig (
841998 ProfileParamsConfig ,
0 commit comments