diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 0d1780bc..b2d323bd 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -917,6 +917,9 @@ async def generate_gen( # Apply temperature last to builder if params.temperature_last: sampler_builder.temperature(params.temperature) + + # Apply adaptive P + sampler_builder.adaptive_p(params.adaptive_target, params.adaptive_decay) # Build the sampler # Set greedy if temperature is 0 diff --git a/backends/exllamav3/sampler.py b/backends/exllamav3/sampler.py index 7b08a9b1..688cd183 100644 --- a/backends/exllamav3/sampler.py +++ b/backends/exllamav3/sampler.py @@ -11,6 +11,7 @@ SS_TopP, SS_Sample, SS_Base, + SS_AdaptiveP, ) @@ -21,6 +22,7 @@ class ExllamaV3SamplerBuilder: """ stack: List[SS_Base] = field(default_factory=list) + has_adaptive: bool = False def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay): self.stack += [ @@ -40,6 +42,11 @@ def top_p(self, top_p): def min_p(self, min_p): self.stack.append(SS_MinP(min_p)) + def adaptive_p(self, adaptive_target, adaptive_decay): + self.stack.append(SS_AdaptiveP(adaptive_target, adaptive_decay)) + if adaptive_target != 1.0: + self.has_adaptive = True + def greedy(self): self.stack.append(SS_Argmax()) @@ -50,5 +57,6 @@ def build(self, greedy): if greedy: return CustomSampler([SS_Argmax()]) else: - self.stack.append(SS_Sample()) + if not self.has_adaptive: + self.stack.append(SS_Sample()) return CustomSampler(self.stack) diff --git a/common/sampling.py b/common/sampling.py index 49be5b99..7333ee59 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -104,6 +104,13 @@ class BaseSamplerRequest(BaseModel): min_p: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("min_p", 0.0) ) + adaptive_target: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("adaptive_target", 0.0) + ) + + adaptive_decay: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("adaptive_decay", 0.0) + ) tfs: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("tfs", 1.0),