diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index 96393aba..cc0c5c1e 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -2,13 +2,7 @@
import asyncio
import pathlib
from loguru import logger
-from typing import (
- Any,
- AsyncIterator,
- Dict,
- List,
- Optional,
-)
+from typing import Any, AsyncIterator
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
@@ -21,7 +15,7 @@ class BaseModelContainer(abc.ABC):
# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
- prompt_template: Optional[PromptTemplate] = None
+ prompt_template: PromptTemplate | None = None
# HF Model instance
hf_model: HFModel
@@ -34,7 +28,7 @@ class BaseModelContainer(abc.ABC):
# The bool is a master switch for accepting requests
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
- active_job_ids: Dict[str, Any] = {}
+ active_job_ids: dict[str, Any] = {}
loaded: bool = False
load_lock: asyncio.Lock
load_condition: asyncio.Condition
@@ -98,7 +92,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
pass
@abc.abstractmethod
- def encode_tokens(self, text: str, **kwargs) -> List[int]:
+ def encode_tokens(self, text: str, **kwargs) -> list[int]:
"""
Encodes a string of text into a list of token IDs.
@@ -113,7 +107,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]:
pass
@abc.abstractmethod
- def decode_tokens(self, ids: List[int], **kwargs) -> str:
+ def decode_tokens(self, ids: list[int], **kwargs) -> str:
"""
Decodes a list of token IDs back into a string.
@@ -128,7 +122,7 @@ def decode_tokens(self, ids: List[int], **kwargs) -> str:
pass
@abc.abstractmethod
- def get_special_tokens(self) -> Dict[str, Any]:
+ def get_special_tokens(self) -> dict[str, Any]:
"""
Gets special tokens used by the model/tokenizer.
@@ -164,7 +158,7 @@ async def wait_for_jobs(self, skip_wait: bool = False):
# Optional methods
async def load_loras(
self, lora_directory: pathlib.Path, **kwargs
- ) -> Dict[str, List[str]]:
+ ) -> dict[str, list[str]]:
"""
Loads LoRA adapters. Base implementation does nothing or raises error.
@@ -184,7 +178,7 @@ async def load_loras(
],
}
- def get_loras(self) -> List[Any]:
+ def get_loras(self) -> list[Any]:
"""
Gets the currently loaded LoRA adapters. Base implementation returns empty list.
@@ -200,9 +194,9 @@ async def generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
- ) -> Dict[str, Any]:
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+ ) -> dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.
@@ -225,9 +219,9 @@ async def stream_generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
- ) -> AsyncIterator[Dict[str, Any]]:
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+ ) -> AsyncIterator[dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.
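The changes above are purely a typing modernization: builtin generics (PEP 585) and `X | None` unions (PEP 604) replace the `typing` aliases. A minimal sketch, assuming Python 3.10+ and using a hypothetical `ToyContainer` (not part of the project), showing that the new spellings need only `Any`/`AsyncIterator` from `typing`:

```python
from typing import Any, AsyncIterator


class ToyContainer:
    """Hypothetical container written in the modernized annotation style."""

    prompt_template: str | None = None

    def encode_tokens(self, text: str, **kwargs) -> list[int]:
        # stand-in tokenizer: one "token" per whitespace-separated word
        return [len(word) for word in text.split()]

    def decode_tokens(self, ids: list[int], **kwargs) -> str:
        return " ".join("x" * i for i in ids)

    def get_special_tokens(self) -> dict[str, Any]:
        return {"bos_token": "<s>", "eos_token": "</s>"}

    async def stream_generate(self, prompt: str) -> AsyncIterator[dict[str, Any]]:
        for index, word in enumerate(prompt.split()):
            yield {"index": index, "text": word, "finish_reason": None}
```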
diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index f9bb464c..2605bec3 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,7 +1,6 @@
import traceback
-import typing
from functools import lru_cache
-from typing import List
+from typing import Any
import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
@@ -16,7 +15,7 @@
class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""
- filters: List[ExLlamaV2Filter]
+ filters: list[ExLlamaV2Filter]
def __init__(self):
self.filters = []
@@ -123,7 +122,7 @@ def __init__(self, nonterminal: str, kbnf_string: str):
self.kbnf_string = kbnf_string
# Return the entire input string as the extracted string
- def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+ def extract(self, input_str: str) -> tuple[str, Any] | None:
return "", input_str
@property
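grammar.py keeps `functools.lru_cache` imported; a short sketch of why caching keyed on the grammar string pays off when the same JSON schema or KBNF string recurs across requests. `compile_grammar` is a hypothetical stand-in, not the real ExLlamaV2 filter API:

```python
from functools import lru_cache


@lru_cache(maxsize=32)
def compile_grammar(kbnf_string: str) -> tuple[str, int]:
    # pretend this is an expensive grammar/filter compilation step
    print(f"compiling {kbnf_string!r}")
    return kbnf_string, len(kbnf_string)


compile_grammar("start ::= 'yes' | 'no';")  # compiles and prints
compile_grammar("start ::= 'yes' | 'no';")  # served from the cache, no print
```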
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ae71b00f..5c5c934c 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -25,7 +25,6 @@
)
from itertools import zip_longest
from loguru import logger
-from typing import Dict, List, Optional
from backends.base_model_container import BaseModelContainer
from backends.exllamav2.grammar import (
@@ -58,45 +57,45 @@ class ExllamaV2Container(BaseModelContainer):
# Model directories
model_dir: pathlib.Path = pathlib.Path("models")
draft_model_dir: pathlib.Path = pathlib.Path("models")
- prompt_template: Optional[PromptTemplate] = None
+ prompt_template: PromptTemplate | None = None
# HF model instance
hf_model: HFModel
# Exl2 vars
- config: Optional[ExLlamaV2Config] = None
- model: Optional[ExLlamaV2] = None
- cache: Optional[ExLlamaV2Cache] = None
- tokenizer: Optional[ExLlamaV2Tokenizer] = None
- generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
- prompt_template: Optional[PromptTemplate] = None
+ config: ExLlamaV2Config | None = None
+ model: ExLlamaV2 | None = None
+ cache: ExLlamaV2Cache | None = None
+ tokenizer: ExLlamaV2Tokenizer | None = None
+ generator: ExLlamaV2DynamicGeneratorAsync | None = None
+ prompt_template: PromptTemplate | None = None
paged: bool = True
# Draft model vars
use_draft_model: bool = False
- draft_config: Optional[ExLlamaV2Config] = None
- draft_model: Optional[ExLlamaV2] = None
- draft_cache: Optional[ExLlamaV2Cache] = None
+ draft_config: ExLlamaV2Config | None = None
+ draft_model: ExLlamaV2 | None = None
+ draft_cache: ExLlamaV2Cache | None = None
# Internal config vars
cache_size: int = None
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
- max_batch_size: Optional[int] = None
+ max_batch_size: int | None = None
# GPU split vars
- gpu_split: List[float] = []
- draft_gpu_split: List[float] = []
+ gpu_split: list[float] = []
+ draft_gpu_split: list[float] = []
gpu_split_auto: bool = True
- autosplit_reserve: List[float] = [96 * 1024**2]
+ autosplit_reserve: list[float] = [96 * 1024**2]
use_tp: bool = False
# Vision vars
use_vision: bool = False
- vision_model: Optional[ExLlamaV2VisionTower] = None
+ vision_model: ExLlamaV2VisionTower | None = None
# Load synchronization
- active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
+ active_job_ids: dict[str, ExLlamaV2DynamicJobAsync | None] = {}
loaded: bool = False
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
@@ -272,6 +271,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
self.config.max_seq_len = unwrap(
user_max_seq_len, min(hf_model.hf_config.max_position_embeddings, 4096)
)
+ self.cache_size = self.config.max_seq_len
# Set the rope scale
self.config.scale_pos_emb = unwrap(
@@ -750,9 +750,9 @@ async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
- loras_to_load: List[ExLlamaV2Lora] = []
- success: List[str] = []
- failure: List[str] = []
+ loras_to_load: list[ExLlamaV2Lora] = []
+ success: list[str] = []
+ failure: list[str] = []
for lora in loras:
lora_name = lora.get("name")
@@ -869,7 +869,7 @@ def encode_tokens(self, text: str, **kwargs):
.tolist()
)
- def decode_tokens(self, ids: List[int], **kwargs):
+ def decode_tokens(self, ids: list[int], **kwargs):
"""Wrapper to decode tokens from a list of IDs"""
ids = torch.tensor([ids])
@@ -908,8 +908,8 @@ async def generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""Generate a response to a prompt."""
generations = []
@@ -969,8 +969,8 @@ async def stream_generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
try:
# Wait for load lock to be freed before processing
@@ -1242,8 +1242,8 @@ async def generate_gen(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""
Create generator function for prompt completion.
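Besides the annotation changes, this hunk seeds `cache_size` from the resolved `max_seq_len`. A simplified sketch of that defaulting chain with hypothetical names; the real loader may still override the cache size from user-supplied kwargs later on:

```python
def unwrap(value, default=None):
    # mirrors common.utils.unwrap: fall back when the value is None
    return default if value is None else value


def resolve_lengths(user_max_seq_len: int | None,
                    max_position_embeddings: int,
                    user_cache_size: int | None = None) -> tuple[int, int]:
    max_seq_len = unwrap(user_max_seq_len, min(max_position_embeddings, 4096))
    cache_size = unwrap(user_cache_size, max_seq_len)
    return max_seq_len, cache_size


print(resolve_lengths(None, 32768))          # (4096, 4096)
print(resolve_lengths(16384, 32768, 65536))  # (16384, 65536)
```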
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index d385de7c..389f581b 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -7,9 +7,6 @@
from typing import (
Any,
AsyncIterator,
- Dict,
- List,
- Optional,
)
from exllamav3 import (
@@ -49,7 +46,7 @@ class ExllamaV3Container(BaseModelContainer):
# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
- prompt_template: Optional[PromptTemplate] = None
+ prompt_template: PromptTemplate | None = None
# HF Model instance
hf_model: HFModel
@@ -58,26 +55,26 @@ class ExllamaV3Container(BaseModelContainer):
# The bool is a master switch for accepting requests
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
- active_job_ids: Dict[str, Any] = {}
+ active_job_ids: dict[str, Any] = {}
loaded: bool = False
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
# Exl3 vars
- model: Optional[Model] = None
- cache: Optional[Cache] = None
- draft_model: Optional[Model] = None
- draft_cache: Optional[Cache] = None
- tokenizer: Optional[Tokenizer] = None
- config: Optional[Config] = None
- draft_config: Optional[Config] = None
- generator: Optional[AsyncGenerator] = None
- vision_model: Optional[Model] = None
+ model: Model | None = None
+ cache: Cache | None = None
+ draft_model: Model | None = None
+ draft_cache: Cache | None = None
+ tokenizer: Tokenizer | None = None
+ config: Config | None = None
+ draft_config: Config | None = None
+ generator: AsyncGenerator | None = None
+ vision_model: Model | None = None
# Class-specific vars
- gpu_split: Optional[List[float]] = None
+ gpu_split: list[float] | None = None
gpu_split_auto: bool = True
- autosplit_reserve: Optional[List[float]] = [96 / 1024]
+ autosplit_reserve: list[float] | None = [96 / 1024]
use_tp: bool = False
tp_backend: str = "native"
max_seq_len: int = 4096
@@ -85,8 +82,8 @@ class ExllamaV3Container(BaseModelContainer):
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
chunk_size: int = 2048
- max_rq_tokens: Optional[int] = 2048
- max_batch_size: Optional[int] = None
+ max_rq_tokens: int | None = 2048
+ max_batch_size: int | None = None
# Required methods
@classmethod
@@ -579,7 +576,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
async with self.load_condition:
self.load_condition.notify_all()
- def encode_tokens(self, text: str, **kwargs) -> List[int]:
+ def encode_tokens(self, text: str, **kwargs) -> list[int]:
"""
Encodes a string of text into a list of token IDs.
@@ -607,7 +604,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]:
.tolist()
)
- def decode_tokens(self, ids: List[int], **kwargs) -> str:
+ def decode_tokens(self, ids: list[int], **kwargs) -> str:
"""
Decodes a list of token IDs back into a string.
@@ -666,9 +663,9 @@ async def generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
- ) -> Dict[str, Any]:
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+ ) -> dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.
@@ -738,9 +735,9 @@ async def stream_generate(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
- ) -> AsyncIterator[Dict[str, Any]]:
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+ ) -> AsyncIterator[dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.
@@ -859,8 +856,8 @@ async def generate_gen(
request_id: str,
prompt: str,
params: BaseSamplerRequest,
- abort_event: Optional[asyncio.Event] = None,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+ abort_event: asyncio.Event | None = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""
Create generator function for prompt completion.
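`stream_generate` is annotated as `AsyncIterator[dict[str, Any]]`; a self-contained sketch (not the ExLlamaV3 generator API) of the produce/consume/abort pattern that signature implies:

```python
import asyncio
from typing import Any, AsyncIterator


async def stream_generate(
    prompt: str, abort_event: asyncio.Event | None = None
) -> AsyncIterator[dict[str, Any]]:
    for index, word in enumerate(prompt.split()):
        if abort_event and abort_event.is_set():
            yield {"text": "", "finish_reason": "abort"}
            return
        await asyncio.sleep(0)  # yield control, as a real generator would
        yield {"index": index, "text": word, "finish_reason": None}
    yield {"text": "", "finish_reason": "stop"}


async def main():
    async for chunk in stream_generate("hello streaming world"):
        print(chunk)


asyncio.run(main())
```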
diff --git a/backends/exllamav3/sampler.py b/backends/exllamav3/sampler.py
index 7b08a9b1..eef7d944 100644
--- a/backends/exllamav3/sampler.py
+++ b/backends/exllamav3/sampler.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List
from exllamav3.generator.sampler import (
CustomSampler,
SS_Temperature,
@@ -20,7 +19,7 @@ class ExllamaV3SamplerBuilder:
Custom sampler chain/stack for TabbyAPI
"""
- stack: List[SS_Base] = field(default_factory=list)
+ stack: list[SS_Base] = field(default_factory=list)
def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay):
self.stack += [
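The `stack` field keeps `field(default_factory=list)`, which is the required spelling for a mutable dataclass default. A quick, standalone demonstration:

```python
from dataclasses import dataclass, field


@dataclass
class SamplerStackDemo:
    # default_factory gives every instance its own list; a bare `= []`
    # default is rejected by @dataclass because it would be shared.
    stack: list[str] = field(default_factory=list)


a, b = SamplerStackDemo(), SamplerStackDemo()
a.stack.append("temperature")
print(a.stack, b.stack)  # ['temperature'] []
```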
diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index c131e3cb..e1dd9b1b 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
+
import gc
import pathlib
import torch
from loguru import logger
-from typing import List, Optional
from common.utils import unwrap
from common.optional_dependencies import dependencies
@@ -17,7 +18,7 @@ class InfinityContainer:
loaded: bool = False
# Use a runtime type hint here
- engine: Optional["AsyncEmbeddingEngine"] = None
+ engine: AsyncEmbeddingEngine | None = None
def __init__(self, model_directory: pathlib.Path):
self.model_dir = model_directory
@@ -49,7 +50,7 @@ async def unload(self):
logger.info("Embedding model unloaded.")
- async def generate(self, sentence_input: List[str]):
+ async def generate(self, sentence_input: list[str]):
result_embeddings, usage = await self.engine.embed(sentence_input)
return {"embeddings": result_embeddings, "usage": usage}
diff --git a/common/args.py b/common/args.py
index a0da4f98..cd286877 100644
--- a/common/args.py
+++ b/common/args.py
@@ -1,7 +1,6 @@
"""Argparser for overriding config values"""
import argparse
-from typing import Optional
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
@@ -25,7 +24,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
def init_argparser(
- existing_parser: Optional[argparse.ArgumentParser] = None,
+ existing_parser: argparse.ArgumentParser | None = None,
) -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
diff --git a/common/auth.py b/common/auth.py
index b02cdd02..bd93afe0 100644
--- a/common/auth.py
+++ b/common/auth.py
@@ -10,7 +10,6 @@
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from loguru import logger
-from typing import Optional
from common.utils import coalesce
@@ -38,7 +37,7 @@ def verify_key(self, test_key: str, key_type: str):
# Global auth constants
-AUTH_KEYS: Optional[AuthKeys] = None
+AUTH_KEYS: AuthKeys | None = None
DISABLE_AUTH: bool = False
diff --git a/common/config_models.py b/common/config_models.py
index 0e71734c..c5772c42 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -6,17 +6,17 @@
PrivateAttr,
field_validator,
)
-from typing import List, Literal, Optional, Union
+from typing import Literal
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
-CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")]
+CACHE_TYPE = CACHE_SIZES | constr(pattern=r"^[2-8]\s*,\s*[2-8]$")
class Metadata(BaseModel):
"""metadata model for config options"""
- include_in_config: Optional[bool] = Field(True)
+ include_in_config: bool | None = Field(True)
class BaseConfigModel(BaseModel):
@@ -29,7 +29,7 @@ class ConfigOverrideConfig(BaseConfigModel):
"""Model for overriding a provided config file."""
# TODO: convert this to a pathlib.path?
- config: Optional[str] = Field(
+ config: str | None = Field(
None, description=("Path to an overriding config.yml file")
)
@@ -39,17 +39,17 @@ class ConfigOverrideConfig(BaseConfigModel):
class NetworkConfig(BaseConfigModel):
"""Options for networking"""
- host: Optional[str] = Field(
+ host: str | None = Field(
"127.0.0.1",
description=(
"The IP to host on (default: 127.0.0.1).\n"
"Use 0.0.0.0 to expose on all network adapters."
),
)
- port: Optional[int] = Field(
+ port: int | None = Field(
5000, description=("The port to host on (default: 5000).")
)
- disable_auth: Optional[bool] = Field(
+ disable_auth: bool | None = Field(
False,
description=(
"Disable HTTP token authentication with requests.\n"
@@ -57,21 +57,21 @@ class NetworkConfig(BaseConfigModel):
"Turn on this option if you are ONLY connecting from localhost."
),
)
- disable_fetch_requests: Optional[bool] = Field(
+ disable_fetch_requests: bool | None = Field(
False,
description=(
"Disable fetching external content in response to requests,"
"such as images from URLs."
),
)
- send_tracebacks: Optional[bool] = Field(
+ send_tracebacks: bool | None = Field(
False,
description=(
"Send tracebacks over the API (default: False).\n"
"NOTE: Only enable this for debug purposes."
),
)
- api_servers: Optional[List[Literal["oai", "kobold"]]] = Field(
+ api_servers: list[Literal["oai", "kobold"]] | None = Field(
["OAI"],
description=(
'Select API servers to enable (default: ["OAI"]).\n'
@@ -91,15 +91,15 @@ def api_server_validator(cls, api_servers):
class LoggingConfig(BaseConfigModel):
"""Options for logging"""
- log_prompt: Optional[bool] = Field(
+ log_prompt: bool | None = Field(
False,
description=("Enable prompt logging (default: False)."),
)
- log_generation_params: Optional[bool] = Field(
+ log_generation_params: bool | None = Field(
False,
description=("Enable generation parameter logging (default: False)."),
)
- log_requests: Optional[bool] = Field(
+ log_requests: bool | None = Field(
False,
description=(
"Enable request logging (default: False).\n"
@@ -123,7 +123,7 @@ class ModelConfig(BaseConfigModel):
"Windows users, do NOT put this path in quotes!"
),
)
- inline_model_loading: Optional[bool] = Field(
+ inline_model_loading: bool | None = Field(
False,
description=(
"Allow direct loading of models "
@@ -132,7 +132,7 @@ class ModelConfig(BaseConfigModel):
"Enable dummy models to add exceptions for invalid model names."
),
)
- use_dummy_models: Optional[bool] = Field(
+ use_dummy_models: bool | None = Field(
False,
description=(
"Sends dummy model names when the models endpoint is queried. "
@@ -140,7 +140,7 @@ class ModelConfig(BaseConfigModel):
"Enable this if the client is looking for specific OAI models.\n"
),
)
- dummy_model_names: List[str] = Field(
+ dummy_model_names: list[str] = Field(
default=["gpt-3.5-turbo"],
description=(
"A list of fake model names that are sent via the /v1/models endpoint. "
@@ -148,7 +148,7 @@ class ModelConfig(BaseConfigModel):
"Also used as bypasses for strict mode if inline_model_loading is true."
),
)
- model_name: Optional[str] = Field(
+ model_name: str | None = Field(
None,
description=(
"An initial model to load.\n"
@@ -156,7 +156,7 @@ class ModelConfig(BaseConfigModel):
"REQUIRED: This must be filled out to load a model on startup."
),
)
- use_as_default: List[str] = Field(
+ use_as_default: list[str] = Field(
default_factory=list,
description=(
"Names of args to use as a fallback for API load requests (default: []).\n"
@@ -165,22 +165,22 @@ class ModelConfig(BaseConfigModel):
"Example: ['max_seq_len', 'cache_mode']."
),
)
- backend: Optional[str] = Field(
+ backend: str | None = Field(
None,
description=(
"Backend to use for this model (auto-detect if not specified)\n"
"Options: exllamav2, exllamav3"
),
)
- max_seq_len: Optional[int] = Field(
+ max_seq_len: int | None = Field(
None,
description=(
"Max sequence length (default: 4096).\n"
- "Set to -1 to fetch from the model's config.json"
+ "set to -1 to fetch from the model's config.json"
),
ge=-1,
)
- cache_size: Optional[int] = Field(
+ cache_size: int | None = Field(
None,
description=(
"Size of the prompt cache to allocate (default: max_seq_len).\n"
@@ -190,7 +190,7 @@ class ModelConfig(BaseConfigModel):
multiple_of=256,
gt=0,
)
- cache_mode: Optional[CACHE_TYPE] = Field(
+ cache_mode: CACHE_TYPE | None = Field(
"FP16",
description=(
"Enable different cache modes for VRAM savings (default: FP16).\n"
@@ -199,7 +199,7 @@ class ModelConfig(BaseConfigModel):
"are integers from 2-8 (i.e. 8,8)."
),
)
- tensor_parallel: Optional[bool] = Field(
+ tensor_parallel: bool | None = Field(
False,
description=(
"Load model with tensor parallelism (default: False).\n"
@@ -207,7 +207,7 @@ class ModelConfig(BaseConfigModel):
"This ignores the gpu_split_auto value."
),
)
- tensor_parallel_backend: Optional[str] = Field(
+ tensor_parallel_backend: str | None = Field(
"native",
description=(
"Sets a backend type for tensor parallelism. (default: native).\n"
@@ -216,28 +216,28 @@ class ModelConfig(BaseConfigModel):
"NCCL is recommended for NVLink."
),
)
- gpu_split_auto: Optional[bool] = Field(
+ gpu_split_auto: bool | None = Field(
True,
description=(
"Automatically allocate resources to GPUs (default: True).\n"
"Not parsed for single GPU users."
),
)
- autosplit_reserve: List[float] = Field(
+ autosplit_reserve: list[float] = Field(
[96],
description=(
"Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n"
"Represented as an array of MB per GPU."
),
)
- gpu_split: List[float] = Field(
+ gpu_split: list[float] = Field(
default_factory=list,
description=(
"An integer array of GBs of VRAM to split between GPUs (default: []).\n"
"Used with tensor parallelism."
),
)
- rope_scale: Optional[float] = Field(
+ rope_scale: float | None = Field(
1.0,
description=(
"Rope scale (default: 1.0).\n"
@@ -246,7 +246,7 @@ class ModelConfig(BaseConfigModel):
"Leave blank to pull the value from the model."
),
)
- rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
+ rope_alpha: float | Literal["auto"] | None = Field(
None,
description=(
"Rope alpha (default: None).\n"
@@ -255,7 +255,7 @@ class ModelConfig(BaseConfigModel):
"or auto-calculate."
),
)
- chunk_size: Optional[int] = Field(
+ chunk_size: int | None = Field(
2048,
description=(
"Chunk size for prompt ingestion (default: 2048).\n"
@@ -265,7 +265,7 @@ class ModelConfig(BaseConfigModel):
),
gt=0,
)
- output_chunking: Optional[bool] = Field(
+ output_chunking: bool | None = Field(
True,
description=(
"Use output chunking (default: True)\n"
@@ -274,7 +274,7 @@ class ModelConfig(BaseConfigModel):
"Used by EXL3 models only.\n"
),
)
- max_batch_size: Optional[int] = Field(
+ max_batch_size: int | None = Field(
None,
description=(
"Set the maximum number of prompts to process at one time "
@@ -284,7 +284,7 @@ class ModelConfig(BaseConfigModel):
),
ge=1,
)
- prompt_template: Optional[str] = Field(
+ prompt_template: str | None = Field(
None,
description=(
"Set the prompt template for this model. (default: None)\n"
@@ -294,7 +294,7 @@ class ModelConfig(BaseConfigModel):
"NOTE: Only works with chat completion message lists!"
),
)
- vision: Optional[bool] = Field(
+ vision: bool | None = Field(
False,
description=(
"Enables vision support if the model supports it. (default: False)"
@@ -312,18 +312,18 @@ class DraftModelConfig(BaseConfigModel):
"""
# TODO: convert this to a pathlib.path?
- draft_model_dir: Optional[str] = Field(
+ draft_model_dir: str | None = Field(
"models",
description=("Directory to look for draft models (default: models)"),
)
- draft_model_name: Optional[str] = Field(
+ draft_model_name: str | None = Field(
None,
description=(
"An initial draft model to load.\n"
"Ensure the model is in the model directory."
),
)
- draft_rope_scale: Optional[float] = Field(
+ draft_rope_scale: float | None = Field(
1.0,
description=(
"Rope scale for draft models (default: 1.0).\n"
@@ -331,7 +331,7 @@ class DraftModelConfig(BaseConfigModel):
"Use if the draft model was trained on long context with rope."
),
)
- draft_rope_alpha: Optional[float] = Field(
+ draft_rope_alpha: float | None = Field(
None,
description=(
"Rope alpha for draft models (default: None).\n"
@@ -340,14 +340,14 @@ class DraftModelConfig(BaseConfigModel):
"or auto-calculate."
),
)
- draft_cache_mode: Optional[CACHE_SIZES] = Field(
+ draft_cache_mode: CACHE_SIZES | None = Field(
"FP16",
description=(
"Cache mode for draft models to save VRAM (default: FP16).\n"
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
),
)
- draft_gpu_split: List[float] = Field(
+ draft_gpu_split: list[float] = Field(
default_factory=list,
description=(
"An integer array of GBs of VRAM to split between GPUs (default: []).\n"
@@ -359,7 +359,7 @@ class DraftModelConfig(BaseConfigModel):
class SamplingConfig(BaseConfigModel):
"""Options for Sampling"""
- override_preset: Optional[str] = Field(
+ override_preset: str | None = Field(
None,
description=(
"Select a sampler override preset (default: None).\n"
@@ -376,7 +376,7 @@ class SamplingConfig(BaseConfigModel):
class LoraInstanceModel(BaseConfigModel):
"""Model representing an instance of a Lora."""
- name: Optional[str] = None
+ name: str | None = None
scaling: float = Field(1.0, ge=0)
@@ -384,10 +384,10 @@ class LoraConfig(BaseConfigModel):
"""Options for Loras"""
# TODO: convert this to a pathlib.path?
- lora_dir: Optional[str] = Field(
+ lora_dir: str | None = Field(
"loras", description=("Directory to look for LoRAs (default: loras).")
)
- loras: Optional[List[LoraInstanceModel]] = Field(
+ loras: list[LoraInstanceModel] | None = Field(
None,
description=(
"List of LoRAs to load and associated scaling factors "
@@ -407,11 +407,11 @@ class EmbeddingsConfig(BaseConfigModel):
"""
# TODO: convert this to a pathlib.path?
- embedding_model_dir: Optional[str] = Field(
+ embedding_model_dir: str | None = Field(
"models",
description=("Directory to look for embedding models (default: models)."),
)
- embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field(
+ embeddings_device: Literal["cpu", "auto", "cuda"] | None = Field(
"cpu",
description=(
"Device to load embedding models on (default: cpu).\n"
@@ -420,7 +420,7 @@ class EmbeddingsConfig(BaseConfigModel):
"If using an AMD GPU, set this value to 'cuda'."
),
)
- embedding_model_name: Optional[str] = Field(
+ embedding_model_name: str | None = Field(
None,
description=("An initial embedding model to load on the infinity backend."),
)
@@ -429,7 +429,7 @@ class EmbeddingsConfig(BaseConfigModel):
class DeveloperConfig(BaseConfigModel):
"""Options for development and experimentation"""
- unsafe_launch: Optional[bool] = Field(
+ unsafe_launch: bool | None = Field(
False,
description=(
"Skip Exllamav2 version check (default: False).\n"
@@ -437,10 +437,10 @@ class DeveloperConfig(BaseConfigModel):
"than enabling this flag."
),
)
- disable_request_streaming: Optional[bool] = Field(
+ disable_request_streaming: bool | None = Field(
False, description=("Disable API request streaming (default: False).")
)
- realtime_process_priority: Optional[bool] = Field(
+ realtime_process_priority: bool | None = Field(
False,
description=(
"Set process to use a higher priority.\n"
@@ -453,27 +453,27 @@ class DeveloperConfig(BaseConfigModel):
class TabbyConfigModel(BaseModel):
"""Base model for a TabbyConfig."""
- config: Optional[ConfigOverrideConfig] = Field(
+ config: ConfigOverrideConfig | None = Field(
default_factory=ConfigOverrideConfig.model_construct
)
- network: Optional[NetworkConfig] = Field(
+ network: NetworkConfig | None = Field(
default_factory=NetworkConfig.model_construct
)
- logging: Optional[LoggingConfig] = Field(
+ logging: LoggingConfig | None = Field(
default_factory=LoggingConfig.model_construct
)
- model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
- draft_model: Optional[DraftModelConfig] = Field(
+ model: ModelConfig | None = Field(default_factory=ModelConfig.model_construct)
+ draft_model: DraftModelConfig | None = Field(
default_factory=DraftModelConfig.model_construct
)
- lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
- embeddings: Optional[EmbeddingsConfig] = Field(
+ lora: LoraConfig | None = Field(default_factory=LoraConfig.model_construct)
+ embeddings: EmbeddingsConfig | None = Field(
default_factory=EmbeddingsConfig.model_construct
)
- sampling: Optional[SamplingConfig] = Field(
+ sampling: SamplingConfig | None = Field(
default_factory=SamplingConfig.model_construct
)
- developer: Optional[DeveloperConfig] = Field(
+ developer: DeveloperConfig | None = Field(
default_factory=DeveloperConfig.model_construct
)
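Pydantic v2 treats `X | None` the same as `Optional[X]`, so the conversions in this file are behavior-neutral. A small sketch with field names borrowed from `NetworkConfig` but simplified defaults, including a `Literal` list field:

```python
from typing import Literal

from pydantic import BaseModel, Field


class NetworkDemo(BaseModel):
    host: str | None = Field("127.0.0.1")
    port: int | None = Field(5000)
    api_servers: list[Literal["oai", "kobold"]] | None = Field(
        default_factory=lambda: ["oai"]
    )


print(NetworkDemo().model_dump())   # {'host': '127.0.0.1', 'port': 5000, 'api_servers': ['oai']}
print(NetworkDemo(port=None).port)  # None: `| None` only widens the accepted type
```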
diff --git a/common/downloader.py b/common/downloader.py
index 8307bbcd..8f3155f2 100644
--- a/common/downloader.py
+++ b/common/downloader.py
@@ -10,7 +10,6 @@
from fnmatch import fnmatch
from loguru import logger
from rich.progress import Progress
-from typing import List, Optional
from common.logger import get_progress_bar
from common.tabby_config import config
@@ -27,7 +26,7 @@ class RepoItem:
async def _download_file(
session: aiohttp.ClientSession,
repo_item: RepoItem,
- token: Optional[str],
+ token: str | None,
download_path: pathlib.Path,
chunk_limit: int,
progress: Progress,
@@ -92,7 +91,7 @@ def _get_repo_info(repo_id, revision, token):
]
-def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
+def _get_download_folder(repo_id: str, repo_type: str, folder_name: str | None):
"""Gets the download folder for the repo."""
if repo_type == "lora":
@@ -105,7 +104,7 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
def _check_exclusions(
- filename: str, include_patterns: List[str], exclude_patterns: List[str]
+ filename: str, include_patterns: list[str], exclude_patterns: list[str]
):
include_result = any(fnmatch(filename, pattern) for pattern in include_patterns)
exclude_result = any(fnmatch(filename, pattern) for pattern in exclude_patterns)
@@ -115,14 +114,14 @@ def _check_exclusions(
async def hf_repo_download(
repo_id: str,
- folder_name: Optional[str],
- revision: Optional[str],
- token: Optional[str],
- include: Optional[List[str]],
- exclude: Optional[List[str]],
- chunk_limit: Optional[float] = None,
- timeout: Optional[int] = None,
- repo_type: Optional[str] = "model",
+ folder_name: str | None,
+ revision: str | None,
+ token: str | None,
+ include: list[str] | None,
+ exclude: list[str] | None,
+ chunk_limit: float | None = None,
+ timeout: int | None = None,
+ repo_type: str | None = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""
diff --git a/common/gen_logging.py b/common/gen_logging.py
index fcd3c01d..6cf5ddcb 100644
--- a/common/gen_logging.py
+++ b/common/gen_logging.py
@@ -3,7 +3,6 @@
"""
from loguru import logger
-from typing import Optional
from common.tabby_config import config
@@ -29,7 +28,7 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")
-def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str] = None):
+def log_prompt(prompt: str, request_id: str, negative_prompt: str | None = None):
"""Logs the prompt to console."""
if config.logging.log_prompt:
formatted_prompt = "\n" + prompt
@@ -55,7 +54,7 @@ def log_response(request_id: str, response: str):
def log_metrics(
request_id: str,
metrics: dict,
- context_len: Optional[int],
+ context_len: int | None,
max_seq_len: int,
):
initial_response = (
diff --git a/common/health.py b/common/health.py
index 4d21d6af..a31cb689 100644
--- a/common/health.py
+++ b/common/health.py
@@ -3,7 +3,6 @@
from datetime import datetime, timezone
from functools import partial
from pydantic import BaseModel, Field
-from typing import Union
class UnhealthyEvent(BaseModel):
@@ -24,7 +23,7 @@ def __init__(self):
self.issues: deque[UnhealthyEvent] = deque(maxlen=100)
self._lock = asyncio.Lock()
- async def add_unhealthy_event(self, error: Union[str, Exception]):
+ async def add_unhealthy_event(self, error: str | Exception):
"""Add a new unhealthy event"""
async with self._lock:
if isinstance(error, Exception):
diff --git a/common/model.py b/common/model.py
index 4ac4861a..757688e8 100644
--- a/common/model.py
+++ b/common/model.py
@@ -10,7 +10,6 @@
from fastapi import HTTPException
from loguru import logger
from ruamel.yaml import YAML
-from typing import Dict, Optional
from backends.base_model_container import BaseModelContainer
from common.logger import get_loading_progress_bar
@@ -21,11 +20,11 @@
from common.utils import deep_merge_dict, unwrap
# Global variables for model container
-container: Optional[BaseModelContainer] = None
+container: BaseModelContainer | None = None
embeddings_container = None
-_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {}
+_BACKEND_REGISTRY: dict[str, BaseModelContainer] = {}
if dependencies.exllamav2:
from backends.exllamav2.model import ExllamaV2Container
@@ -42,7 +41,7 @@
if dependencies.extras:
from backends.infinity.model import InfinityContainer
- embeddings_container: Optional[InfinityContainer] = None
+ embeddings_container: InfinityContainer | None = None
class ModelType(Enum):
diff --git a/common/multimodal.py b/common/multimodal.py
index b92386f3..11e401dc 100644
--- a/common/multimodal.py
+++ b/common/multimodal.py
@@ -3,7 +3,6 @@
from common import model
from loguru import logger
from pydantic import BaseModel, Field
-from typing import List
from common.optional_dependencies import dependencies
@@ -18,7 +17,7 @@ class MultimodalEmbeddingWrapper(BaseModel):
type: str = None
content: list = Field(default_factory=list)
- text_alias: List[str] = Field(default_factory=list)
+ text_alias: list[str] = Field(default_factory=list)
async def add(self, url: str):
# Determine the type of vision embedding to use
diff --git a/common/networking.py b/common/networking.py
index 597ed078..f4d72917 100644
--- a/common/networking.py
+++ b/common/networking.py
@@ -7,7 +7,6 @@
from fastapi import Depends, HTTPException, Request
from loguru import logger
from pydantic import BaseModel
-from typing import Optional
from uuid import uuid4
from common.tabby_config import config
@@ -17,7 +16,7 @@ class TabbyRequestErrorMessage(BaseModel):
"""Common request error type."""
message: str
- trace: Optional[str] = None
+ trace: str | None = None
class TabbyRequestError(BaseModel):
diff --git a/common/sampling.py b/common/sampling.py
index 49be5b99..832661ed 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -14,7 +14,6 @@
field_validator,
model_validator,
)
-from typing import Dict, List, Optional, Union
from common.utils import filter_none_values, unwrap
@@ -23,7 +22,7 @@
class BaseSamplerRequest(BaseModel):
"""Common class for sampler params that are used in APIs"""
- max_tokens: Optional[int] = Field(
+ max_tokens: int | None = Field(
default_factory=lambda: get_default_sampler_value("max_tokens"),
validation_alias=AliasChoices(
"max_tokens", "max_completion_tokens", "max_length"
@@ -33,7 +32,7 @@ class BaseSamplerRequest(BaseModel):
ge=0,
)
- min_tokens: Optional[int] = Field(
+ min_tokens: int | None = Field(
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
validation_alias=AliasChoices("min_tokens", "min_length"),
description="Aliases: min_length",
@@ -41,76 +40,76 @@ class BaseSamplerRequest(BaseModel):
ge=0,
)
- stop: Optional[Union[str, List[Union[str, int]]]] = Field(
+ stop: str | list[str | int] | None = Field(
default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"),
description="Aliases: stop_sequence",
)
- banned_strings: Optional[Union[str, List[str]]] = Field(
+ banned_strings: str | list[str] | None = Field(
default_factory=lambda: get_default_sampler_value("banned_strings", [])
)
- banned_tokens: Optional[Union[List[int], str]] = Field(
+ banned_tokens: list[int] | str | None = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
- allowed_tokens: Optional[Union[List[int], str]] = Field(
+ allowed_tokens: list[int] | str | None = Field(
default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
description="Aliases: allowed_token_ids",
examples=[[128, 330]],
)
- token_healing: Optional[bool] = Field(
+ token_healing: bool | None = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
- temperature: Optional[float] = Field(
+ temperature: float | None = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
ge=0,
le=10,
)
- temperature_last: Optional[bool] = Field(
+ temperature_last: bool | None = Field(
default_factory=lambda: get_default_sampler_value("temperature_last", False),
)
- smoothing_factor: Optional[float] = Field(
+ smoothing_factor: float | None = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
ge=0,
)
- top_k: Optional[int] = Field(
+ top_k: int | None = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0),
ge=-1,
)
- top_p: Optional[float] = Field(
+ top_p: float | None = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
ge=0,
le=1,
examples=[1.0],
)
- top_a: Optional[float] = Field(
+ top_a: float | None = Field(
default_factory=lambda: get_default_sampler_value("top_a", 0.0)
)
- min_p: Optional[float] = Field(
+ min_p: float | None = Field(
default_factory=lambda: get_default_sampler_value("min_p", 0.0)
)
- tfs: Optional[float] = Field(
+ tfs: float | None = Field(
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
examples=[1.0],
)
- typical: Optional[float] = Field(
+ typical: float | None = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
@@ -119,30 +118,30 @@ class BaseSamplerRequest(BaseModel):
le=1,
)
- skew: Optional[float] = Field(
+ skew: float | None = Field(
default_factory=lambda: get_default_sampler_value("skew", 0.0),
examples=[0.0],
)
- xtc_probability: Optional[float] = Field(
+ xtc_probability: float | None = Field(
default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0),
)
- xtc_threshold: Optional[float] = Field(
+ xtc_threshold: float | None = Field(
default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1)
)
- frequency_penalty: Optional[float] = Field(
+ frequency_penalty: float | None = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0),
ge=0,
)
- presence_penalty: Optional[float] = Field(
+ presence_penalty: float | None = Field(
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0),
ge=0,
)
- repetition_penalty: Optional[float] = Field(
+ repetition_penalty: float | None = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
description="Aliases: rep_pen",
@@ -150,7 +149,7 @@ class BaseSamplerRequest(BaseModel):
gt=0,
)
- penalty_range: Optional[int] = Field(
+ penalty_range: int | None = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
@@ -163,91 +162,91 @@ class BaseSamplerRequest(BaseModel):
),
)
- repetition_decay: Optional[int] = Field(
+ repetition_decay: int | None = Field(
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
- dry_multiplier: Optional[float] = Field(
+ dry_multiplier: float | None = Field(
default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
)
- dry_base: Optional[float] = Field(
+ dry_base: float | None = Field(
default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
)
- dry_allowed_length: Optional[int] = Field(
+ dry_allowed_length: int | None = Field(
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
)
- dry_range: Optional[int] = Field(
+ dry_range: int | None = Field(
default_factory=lambda: get_default_sampler_value("dry_range", 0),
validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"),
description=("Aliases: dry_penalty_last_n"),
)
- dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
+ dry_sequence_breakers: str | list[str] | None = Field(
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)
- mirostat_mode: Optional[int] = Field(
+ mirostat_mode: int | None = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0),
alias=AliasChoices("mirostat_mode", "mirostat"),
)
- mirostat_tau: Optional[float] = Field(
+ mirostat_tau: float | None = Field(
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
examples=[1.5],
)
- mirostat_eta: Optional[float] = Field(
+ mirostat_eta: float | None = Field(
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
examples=[0.3],
)
- add_bos_token: Optional[bool] = Field(
+ add_bos_token: bool | None = Field(
default_factory=lambda: get_default_sampler_value("add_bos_token")
)
- ban_eos_token: Optional[bool] = Field(
+ ban_eos_token: bool | None = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
description="Aliases: ignore_eos",
examples=[False],
)
- logit_bias: Optional[Dict[int, float]] = Field(
+ logit_bias: dict[int, float] | None = Field(
default_factory=lambda: get_default_sampler_value("logit_bias"),
examples=[{"1": 10, "2": 50}],
)
- negative_prompt: Optional[str] = Field(
+ negative_prompt: str | None = Field(
default_factory=lambda: get_default_sampler_value("negative_prompt")
)
- json_schema: Optional[object] = Field(
+ json_schema: object | None = Field(
default_factory=lambda: get_default_sampler_value("json_schema"),
)
- regex_pattern: Optional[str] = Field(
+ regex_pattern: str | None = Field(
default_factory=lambda: get_default_sampler_value("regex_pattern"),
)
- grammar_string: Optional[str] = Field(
+ grammar_string: str | None = Field(
default_factory=lambda: get_default_sampler_value("grammar_string"),
)
- speculative_ngram: Optional[bool] = Field(
+ speculative_ngram: bool | None = Field(
default_factory=lambda: get_default_sampler_value("speculative_ngram"),
)
- cfg_scale: Optional[float] = Field(
+ cfg_scale: float | None = Field(
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
examples=[1.0],
)
- max_temp: Optional[float] = Field(
+ max_temp: float | None = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
@@ -255,7 +254,7 @@ class BaseSamplerRequest(BaseModel):
ge=0,
)
- min_temp: Optional[float] = Field(
+ min_temp: float | None = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
@@ -263,14 +262,14 @@ class BaseSamplerRequest(BaseModel):
ge=0,
)
- temp_exponent: Optional[float] = Field(
+ temp_exponent: float | None = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
ge=0,
)
- logprobs: Optional[int] = Field(
+ logprobs: int | None = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0),
ge=0,
)
@@ -335,7 +334,7 @@ def after_validate(self):
class SamplerOverridesContainer(BaseModel):
- selected_preset: Optional[str] = None
+ selected_preset: str | None = None
overrides: dict = {}
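Every sampler field pairs the new union syntax with a `default_factory` lambda; the lambda defers the lookup to request time, so preset changes affect later requests. A sketch with a hypothetical `SAMPLER_DEFAULTS` store standing in for `get_default_sampler_value`:

```python
from pydantic import BaseModel, Field

SAMPLER_DEFAULTS = {"temperature": 1.0}  # hypothetical stand-in for the preset store


def get_default(key: str, fallback=None):
    return SAMPLER_DEFAULTS.get(key, fallback)


class MiniSamplerRequest(BaseModel):
    temperature: float | None = Field(
        default_factory=lambda: get_default("temperature", 1.0), ge=0, le=10
    )


print(MiniSamplerRequest().temperature)  # 1.0
SAMPLER_DEFAULTS["temperature"] = 0.7
print(MiniSamplerRequest().temperature)  # 0.7
```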
diff --git a/common/tabby_config.py b/common/tabby_config.py
index 9c4cc5d6..535cb6b2 100644
--- a/common/tabby_config.py
+++ b/common/tabby_config.py
@@ -2,7 +2,6 @@
from inspect import getdoc
from os import getenv
from textwrap import dedent
-from typing import Optional
from loguru import logger
from pydantic import BaseModel
@@ -22,7 +21,7 @@ class TabbyConfig(TabbyConfigModel):
model_defaults: dict = {}
draft_model_defaults: dict = {}
- def load(self, arguments: Optional[dict] = None):
+ def load(self, arguments: dict | None = None):
"""Synchronously loads the global application config"""
# config is applied in order of items in the list
diff --git a/common/templating.py b/common/templating.py
index cc0cceb1..864ee337 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -7,7 +7,6 @@
from dataclasses import dataclass, field
from datetime import datetime
from importlib.metadata import version as package_version
-from typing import List, Optional
from jinja2 import Template, TemplateError
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
@@ -28,8 +27,8 @@ class TemplateLoadError(Exception):
class TemplateMetadata:
"""Represents the parsed metadata from a template."""
- stop_strings: List[str] = field(default_factory=list)
- tool_start: Optional[str] = None
+ stop_strings: list[str] = field(default_factory=list)
+ tool_start: str | None = None
class PromptTemplate:
@@ -44,7 +43,7 @@ class PromptTemplate:
enable_async=True,
extensions=[loopcontrols],
)
- metadata: Optional[TemplateMetadata] = None
+ metadata: TemplateMetadata | None = None
async def extract_metadata(self, template_vars: dict):
"""
@@ -145,7 +144,7 @@ async def from_file(cls, template_path: pathlib.Path):
@classmethod
async def from_model_json(
- cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
+ cls, json_path: pathlib.Path, key: str, name: str | None = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index a7b0f0c1..92441564 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -3,7 +3,6 @@
import pathlib
from loguru import logger
from pydantic import BaseModel
-from typing import Dict, List, Optional, Set, Union
class GenerationConfig(BaseModel):
@@ -12,7 +11,7 @@ class GenerationConfig(BaseModel):
Will be expanded as needed.
"""
- eos_token_id: Optional[Union[int, List[int]]] = None
+ eos_token_id: int | list[int] | None = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
@@ -44,8 +43,8 @@ class HuggingFaceConfig(BaseModel):
"""
max_position_embeddings: int = 4096
- eos_token_id: Optional[Union[int, List[int]]] = None
- quantization_config: Optional[Dict] = None
+ eos_token_id: int | list[int] | None = None
+ quantization_config: dict | None = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
@@ -62,7 +61,7 @@ async def from_directory(cls, model_directory: pathlib.Path):
def quant_method(self):
"""Wrapper method to fetch quant type"""
- if isinstance(self.quantization_config, Dict):
+ if isinstance(self.quantization_config, dict):
return self.quantization_config.get("quant_method")
else:
return None
@@ -83,7 +82,7 @@ class TokenizerConfig(BaseModel):
An abridged version of HuggingFace's tokenizer config.
"""
- add_bos_token: Optional[bool] = True
+ add_bos_token: bool | None = True
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
@@ -111,8 +110,8 @@ class HFModel:
"""
hf_config: HuggingFaceConfig
- tokenizer_config: Optional[TokenizerConfig] = None
- generation_config: Optional[GenerationConfig] = None
+ tokenizer_config: TokenizerConfig | None = None
+ generation_config: GenerationConfig | None = None
@classmethod
async def from_directory(cls, model_directory: pathlib.Path):
@@ -156,7 +155,7 @@ def quant_method(self):
def eos_tokens(self):
"""Combines and returns EOS tokens from various configs"""
- eos_ids: Set[int] = set()
+ eos_ids: set[int] = set()
eos_ids.update(self.hf_config.eos_tokens())
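`eos_token_id` may arrive as an int, a list of ints, or None depending on the config file, and the configs fold everything into a `set[int]`. A simplified, standalone sketch of that normalization:

```python
def collect_eos_ids(*values: int | list[int] | None) -> set[int]:
    eos_ids: set[int] = set()
    for value in values:
        if value is None:
            continue
        if isinstance(value, int):
            eos_ids.add(value)
        else:
            eos_ids.update(value)
    return eos_ids


print(collect_eos_ids(2, [2, 32000], None))  # {2, 32000}
```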
diff --git a/common/utils.py b/common/utils.py
index b0d7ad24..da841799 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -1,12 +1,12 @@
"""Common utility functions"""
-from types import NoneType
-from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
+from types import NoneType, UnionType
+from typing import Type, Union, get_args, get_origin, TypeVar
T = TypeVar("T")
-def unwrap(wrapped: Optional[T], default: T = None) -> T:
+def unwrap(wrapped: T | None, default: T = None) -> T:
"""Unwrap function for Optionals."""
if wrapped is None:
return default
@@ -19,7 +19,7 @@ def coalesce(*args):
return next((arg for arg in args if arg is not None), None)
-def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]:
+def filter_none_values(collection: dict | list) -> dict | list:
"""Remove None values from a collection."""
if isinstance(collection, dict):
@@ -32,7 +32,7 @@ def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]:
return collection
-def deep_merge_dict(dict1: Dict, dict2: Dict, copy: bool = False) -> Dict:
+def deep_merge_dict(dict1: dict, dict2: dict, copy: bool = False) -> dict:
"""
Merge 2 dictionaries. If copy is true, the original dictionary isn't modified.
"""
@@ -49,7 +49,7 @@ def deep_merge_dict(dict1: Dict, dict2: Dict, copy: bool = False) -> Dict:
return dict1
-def deep_merge_dicts(*dicts: Dict) -> Dict:
+def deep_merge_dicts(*dicts: dict) -> dict:
"""
Merge an arbitrary amount of dictionaries.
We wanna do in-place modification for each level, so do not copy.
@@ -84,11 +84,14 @@ def is_list_type(type_hint) -> bool:
def unwrap_optional_type(type_hint) -> Type:
"""
- Unwrap Optional[type] annotations.
+ Unwrap type | None annotations to extract the base type.
This is not the same as unwrap.
"""
- if get_origin(type_hint) is Union:
+ origin = get_origin(type_hint)
+
+    # Unions built via typing constructs still report typing.Union
+    if origin is UnionType or origin is Union:
         args = get_args(type_hint)
if NoneType in args:
for arg in args:
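The dual check in `unwrap_optional_type` exists because the two union spellings are not interchangeable at runtime on Python 3.10-3.13: class-only `|` unions report `types.UnionType`, while unions built through `typing` constructs (`Optional`, or any `|` expression involving `Literal`) still report `typing.Union`. A standard-library-only demonstration:

```python
from types import NoneType, UnionType
from typing import Literal, Optional, Union, get_args, get_origin

print(get_origin(int | None) is UnionType)           # True
print(get_origin(Optional[int]) is Union)            # True
print(get_origin(Literal["auto"] | None) is Union)   # True, despite the `|` spelling
print(get_args(int | None) == get_args(Optional[int]) == (int, NoneType))  # True
```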
diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md
index 98cee556..00546e1b 100644
--- a/docs/02.-Server-options.md
+++ b/docs/02.-Server-options.md
@@ -21,7 +21,7 @@ All of these options have descriptive comments above them. You should not need t
| disable_auth | Bool (False) | Disables API authentication |
| disable_fetch_requests | Bool (False) | Disables fetching external content when responding to requests (ex. fetching images from URLs) |
| send_tracebacks | Bool (False) | Send server tracebacks to client.
Note: It's not recommended to enable this if sharing the instance with others. |
-| api_servers | List[String] (["OAI"]) | API servers to enable. Possible values `"OAI", "Kobold"` |
+| api_servers | list[String] (["OAI"]) | API servers to enable. Possible values `"OAI", "Kobold"` |
### Logging Options
@@ -58,14 +58,14 @@ Note: Most of the options here will only apply on initial model load/startup (ep
| model_dir | String ("models") | Directory to look for models.
Note: Persisted across subsequent load requests |
| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) |
| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.
Note: Persisted across subsequent load requests |
-| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests |
+| dummy_model_names | list[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests |
| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. |
-| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.
Note: Also applies to the `draft` sub-block |
+| use_as_default | list[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.
Note: Also applies to the `draft` sub-block |
| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. |
| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.
Note: `gpu_split_auto` is ignored when this is enabled. |
| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. |
-| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.
Represented as an array of MB per GPU used. |
-| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. |
+| autosplit_reserve | list[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.
Represented as an array of MB per GPU used. |
+| gpu_split | list[Float] ([]) | Float array of GBs to split a model between GPUs. |
| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)
Note: If the model has YaRN support, this option will not apply. |
| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.
Note: If the model has YaRN support, this option will not apply. |
| cache_mode | String ("FP16") | Cache mode for the model.
Options: FP16, Q8, Q6, Q4 |
@@ -86,7 +86,7 @@ Note: Sub-block of Model Options. Same rules apply.
| draft_rope_scale | Float (1.0) | String: RoPE scale value for the draft model. |
| draft_rope_alpha | Float (1.0) | RoPE alpha value for the draft model. Leave blank for auto-calculation. |
| draft_cache_mode | String ("FP16") | Cache mode for the draft model.
Options: FP16, Q8, Q6, Q4 |
-| draft_gpu_split | List[Float] ([]) | Float array of GBs to split a draft model between GPUs. |
+| draft_gpu_split | list[Float] ([]) | Float array of GBs to split a draft model between GPUs. |
### Lora Options
@@ -95,7 +95,7 @@ Note: Sub-block of Mode Options. Same rules apply.
| Config Option | Type (Default) | Description |
|---------------|------------------|--------------------------------------------------------------|
| lora_dir | String ("loras") | Directory to look for loras.
Note: Persisted across subsequent load requests |
-| loras | List[loras] ([]) | List of lora objects to apply to the model. Each object contains a name and scaling. |
+| loras | list[loras] ([]) | List of lora objects to apply to the model. Each object contains a name and scaling. |
| name | String (None) | Folder name of a lora to load.
Note: An element of the `loras` key |
| scaling | Float (1.0) | "Weight" to apply the lora on the parent model. For example, applying a lora with 0.9 scaling will lower the amount of application on the parent model.
Note: An element of the `loras` key |
diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py
index 5432130b..a58bceba 100644
--- a/endpoints/Kobold/types/generation.py
+++ b/endpoints/Kobold/types/generation.py
@@ -1,6 +1,5 @@
from functools import partial
from pydantic import BaseModel, Field, field_validator
-from typing import List, Optional
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.utils import unwrap
@@ -8,9 +7,9 @@
class GenerateRequest(BaseSamplerRequest):
prompt: str
- genkey: Optional[str] = None
- use_default_badwordsids: Optional[bool] = False
- dynatemp_range: Optional[float] = Field(
+ genkey: str | None = None
+ use_default_badwordsids: bool | None = False
+ dynatemp_range: float | None = Field(
default_factory=partial(get_default_sampler_value, "dynatemp_range")
)
@@ -43,7 +42,7 @@ class GenerateResponseResult(BaseModel):
class GenerateResponse(BaseModel):
- results: List[GenerateResponseResult] = Field(default_factory=list)
+ results: list[GenerateResponseResult] = Field(default_factory=list)
class StreamGenerateChunk(BaseModel):
diff --git a/endpoints/Kobold/types/token.py b/endpoints/Kobold/types/token.py
index e6639d94..3f5a9fe6 100644
--- a/endpoints/Kobold/types/token.py
+++ b/endpoints/Kobold/types/token.py
@@ -1,5 +1,4 @@
from pydantic import BaseModel
-from typing import List
class TokenCountRequest(BaseModel):
@@ -12,4 +11,4 @@ class TokenCountResponse(BaseModel):
"""Represents a KAI tokenization response."""
value: int
- ids: List[int]
+ ids: list[int]
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 52523149..86118463 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
from pydantic import AliasChoices, BaseModel, Field, field_validator
from time import time
-from typing import Literal, Union, List, Optional, Dict
+from typing import Literal
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -10,11 +12,11 @@
class ChatCompletionLogprob(BaseModel):
token: str
logprob: float
- top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list)
+ top_logprobs: list[ChatCompletionLogprob] | None = Field(default_factory=list)
class ChatCompletionLogprobs(BaseModel):
- content: List[ChatCompletionLogprob] = Field(default_factory=list)
+ content: list[ChatCompletionLogprob] = Field(default_factory=list)
class ChatCompletionImageUrl(BaseModel):
@@ -23,58 +25,58 @@ class ChatCompletionImageUrl(BaseModel):
class ChatCompletionMessagePart(BaseModel):
type: Literal["text", "image_url"] = "text"
- text: Optional[str] = None
- image_url: Optional[ChatCompletionImageUrl] = None
+ text: str | None = None
+ image_url: ChatCompletionImageUrl | None = None
class ChatCompletionMessage(BaseModel):
role: str = "user"
- content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
- tool_calls: Optional[List[ToolCall]] = None
- tool_call_id: Optional[str] = None
+ content: str | list[ChatCompletionMessagePart] | None = None
+ tool_calls: list[ToolCall] | None = None
+ tool_call_id: str | None = None
class ChatCompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
- finish_reason: Optional[str] = None
+ finish_reason: str | None = None
# lets us understand why it stopped and whether we need to generate a tool_call
- stop_str: Optional[str] = None
+ stop_str: str | None = None
message: ChatCompletionMessage
- logprobs: Optional[ChatCompletionLogprobs] = None
+ logprobs: ChatCompletionLogprobs | None = None
class ChatCompletionStreamChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
- finish_reason: Optional[str] = None
- delta: Union[ChatCompletionMessage, dict] = {}
- logprobs: Optional[ChatCompletionLogprobs] = None
+ finish_reason: str | None = None
+ delta: ChatCompletionMessage | dict = {}
+ logprobs: ChatCompletionLogprobs | None = None
# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
- messages: List[ChatCompletionMessage]
- prompt_template: Optional[str] = None
- add_generation_prompt: Optional[bool] = True
- template_vars: Optional[dict] = Field(
+ messages: list[ChatCompletionMessage]
+ prompt_template: str | None = None
+ add_generation_prompt: bool | None = True
+ template_vars: dict | None = Field(
default={},
validation_alias=AliasChoices("template_vars", "chat_template_kwargs"),
description="Aliases: chat_template_kwargs",
)
- response_prefix: Optional[str] = None
- model: Optional[str] = None
+ response_prefix: str | None = None
+ model: str | None = None
# tools follows the OAI schema format, functions is more flexible
# both are available in the chat template.
- tools: Optional[List[ToolSpec]] = None
- functions: Optional[List[Dict]] = None
+ tools: list[ToolSpec] | None = None
+ functions: list[dict] | None = None
# Chat completions requests do not have a BOS token preference. Backend
# respects the tokenization config for the individual model.
- add_bos_token: Optional[bool] = None
+ add_bos_token: bool | None = None
@field_validator("add_bos_token", mode="after")
def force_bos_token(cls, v):
@@ -84,17 +86,17 @@ def force_bos_token(cls, v):
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
- choices: List[ChatCompletionRespChoice]
+ choices: list[ChatCompletionRespChoice]
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "chat.completion"
- usage: Optional[UsageStats] = None
+ usage: UsageStats | None = None
class ChatCompletionStreamChunk(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
- choices: List[ChatCompletionStreamChoice]
+ choices: list[ChatCompletionStreamChoice]
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "chat.completion.chunk"
- usage: Optional[UsageStats] = None
+ usage: UsageStats | None = None
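Note on the pattern in this file: the added `from __future__ import annotations` stores every annotation as a string, which is why the quoted forward reference in `top_logprobs: Optional[List["ChatCompletionLogprob"]]` could become the unquoted `list[ChatCompletionLogprob] | None`. A standalone dataclass sketch of the same mechanism, not taken from this codebase (dataclasses never evaluate their annotations, so the effect is easy to see):

from __future__ import annotations  # PEP 563: annotations are kept as strings

from dataclasses import dataclass


@dataclass
class Node:
    value: int
    # Unquoted self-reference plus PEP 585/604 syntax; because the annotation
    # stays a string, it does not matter that Node is not fully defined yet.
    children: list[Node] | None = None


tree = Node(1, children=[Node(2), Node(3)])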
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index 16ef2edd..fb7a4a82 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -1,7 +1,6 @@
"""Common types for OAI."""
from pydantic import BaseModel, Field
-from typing import Optional, Union
from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -10,13 +9,13 @@ class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
- prompt_time: Optional[float] = None
- prompt_tokens_per_sec: Optional[Union[float, str]] = None
+ prompt_time: float | None = None
+ prompt_tokens_per_sec: float | str | None = None
completion_tokens: int
- completion_time: Optional[float] = None
- completion_tokens_per_sec: Optional[Union[float, str]] = None
+ completion_time: float | None = None
+ completion_tokens_per_sec: float | str | None = None
total_tokens: int
- total_time: Optional[float] = None
+ total_time: float | None = None
class CompletionResponseFormat(BaseModel):
@@ -24,7 +23,7 @@ class CompletionResponseFormat(BaseModel):
class ChatCompletionStreamOptions(BaseModel):
- include_usage: Optional[bool] = False
+ include_usage: bool | None = False
class CommonCompletionRequest(BaseSamplerRequest):
@@ -32,29 +31,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Model information
# This parameter is not used, the loaded model is used instead
- model: Optional[str] = None
+ model: str | None = None
# Generation info (remainder is in BaseSamplerRequest superclass)
- stream: Optional[bool] = False
- stream_options: Optional[ChatCompletionStreamOptions] = None
- response_format: Optional[CompletionResponseFormat] = Field(
+ stream: bool | None = False
+ stream_options: ChatCompletionStreamOptions | None = None
+ response_format: CompletionResponseFormat | None = Field(
default_factory=CompletionResponseFormat
)
- n: Optional[int] = Field(
+ n: int | None = Field(
default_factory=lambda: get_default_sampler_value("n", 1),
ge=1,
)
# Extra OAI request stuff
- best_of: Optional[int] = Field(
+ best_of: int | None = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
- echo: Optional[bool] = Field(
+ echo: bool | None = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
- suffix: Optional[str] = Field(
+ suffix: str | None = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
- user: Optional[str] = Field(
+ user: str | None = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
diff --git a/endpoints/OAI/types/completion.py b/endpoints/OAI/types/completion.py
index d0a7187e..bb3f8443 100644
--- a/endpoints/OAI/types/completion.py
+++ b/endpoints/OAI/types/completion.py
@@ -2,7 +2,6 @@
from pydantic import BaseModel, Field
from time import time
-from typing import Dict, List, Optional, Union
from uuid import uuid4
from endpoints.OAI.types.common import CommonCompletionRequest, UsageStats
@@ -11,10 +10,10 @@
class CompletionLogProbs(BaseModel):
"""Represents log probabilities for a completion request."""
- text_offset: List[int] = Field(default_factory=list)
- token_logprobs: List[Optional[float]] = Field(default_factory=list)
- tokens: List[str] = Field(default_factory=list)
- top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+ text_offset: list[int] = Field(default_factory=list)
+ token_logprobs: list[float | None] = Field(default_factory=list)
+ tokens: list[str] = Field(default_factory=list)
+ top_logprobs: list[dict[str, float] | None] = Field(default_factory=list)
class CompletionRespChoice(BaseModel):
@@ -22,8 +21,8 @@ class CompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
- finish_reason: Optional[str] = None
- logprobs: Optional[CompletionLogProbs] = None
+ finish_reason: str | None = None
+ logprobs: CompletionLogProbs | None = None
text: str
@@ -33,15 +32,15 @@ class CompletionRequest(CommonCompletionRequest):
# Prompt can also contain token ids, but that's out of scope
# for this project.
- prompt: Union[str, List[str]]
+ prompt: str | list[str]
class CompletionResponse(BaseModel):
"""Represents a completion response."""
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
- choices: List[CompletionRespChoice]
+ choices: list[CompletionRespChoice]
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "text_completion"
- usage: Optional[UsageStats] = None
+ usage: UsageStats | None = None
diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py
index 41419c4d..abb68fd7 100644
--- a/endpoints/OAI/types/embedding.py
+++ b/endpoints/OAI/types/embedding.py
@@ -1,23 +1,21 @@
-from typing import List, Optional, Union
-
from pydantic import BaseModel, Field
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
- completion_tokens: Optional[int] = 0
+ completion_tokens: int | None = 0
class EmbeddingsRequest(BaseModel):
- input: Union[str, List[str]] = Field(
+ input: str | list[str] = Field(
..., description="List of input texts to generate embeddings for."
)
encoding_format: str = Field(
"float",
description="Encoding format for the embeddings. Can be 'float' or 'base64'.",
)
- model: Optional[str] = Field(
+ model: str | None = Field(
None,
description="Name of the embedding model to use. "
"If not provided, the default model will be used.",
@@ -26,7 +24,7 @@ class EmbeddingsRequest(BaseModel):
class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.")
- embedding: Union[List[float], str] = Field(
+ embedding: list[float] | str = Field(
..., description="Embedding values as a list of floats."
)
index: int = Field(
@@ -36,6 +34,6 @@ class EmbeddingObject(BaseModel):
class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
- data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
+ data: list[EmbeddingObject] = Field(..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.")
diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py
index b5b9611f..138fda9d 100644
--- a/endpoints/OAI/types/tools.py
+++ b/endpoints/OAI/types/tools.py
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
-from typing import Dict, Literal
+from typing import Literal
from uuid import uuid4
@@ -8,7 +8,7 @@ class Function(BaseModel):
name: str
description: str
- parameters: Dict[str, object]
+ parameters: dict[str, object]
class ToolSpec(BaseModel):
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index b559bb2b..157c96ba 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -3,7 +3,6 @@
import asyncio
import pathlib
from asyncio import CancelledError
-from typing import List, Optional
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
@@ -33,7 +32,7 @@
def _create_response(
- request_id: str, generations: List[dict], model_name: Optional[str]
+ request_id: str, generations: list[dict], model_name: str | None
):
"""Create a chat completion response from the provided text."""
@@ -111,8 +110,8 @@ def _create_response(
def _create_stream_chunk(
request_id: str,
- generation: Optional[dict] = None,
- model_name: Optional[str] = None,
+ generation: dict | None = None,
+ model_name: str | None = None,
is_usage_chunk: bool = False,
):
"""Create a chat completion stream chunk from the provided text."""
@@ -212,8 +211,8 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
async def format_messages_with_template(
- messages: List[ChatCompletionMessage],
- existing_template_vars: Optional[dict] = None,
+ messages: list[ChatCompletionMessage],
+ existing_template_vars: dict | None = None,
):
"""Barebones function to format chat completion messages into a prompt."""
@@ -221,7 +220,7 @@ async def format_messages_with_template(
mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
# Convert all messages to a dictionary representation
- message_dicts: List[dict] = []
+ message_dicts: list[dict] = []
for message in messages:
if isinstance(message.content, list):
concatenated_content = ""
@@ -317,7 +316,7 @@ async def stream_generate_chat_completion(
"""Generator for the generation process."""
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
- gen_tasks: List[asyncio.Task] = []
+ gen_tasks: list[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
@@ -414,7 +413,7 @@ async def generate_chat_completion(
request: Request,
model_path: pathlib.Path,
):
- gen_tasks: List[asyncio.Task] = []
+ gen_tasks: list[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
try:
@@ -462,14 +461,14 @@ async def generate_tool_calls(
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
- generations: List[str],
+ generations: list[str],
request: Request,
):
- gen_tasks: List[asyncio.Task] = []
+ gen_tasks: list[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
# Tracks which generations asked for a tool call
- tool_idx: List[int] = []
+ tool_idx: list[int] = []
# Copy to make sure the parent JSON schema doesn't get modified
tool_data = data.model_copy(deep=True)
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index f66d381d..4de10361 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -9,7 +9,6 @@
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
-from typing import List, Optional, Union
from common import model
from common.auth import get_key_permission
@@ -39,7 +38,7 @@ def _parse_gen_request_id(n: int, request_id: str, task_idx: int):
def _create_response(
- request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
+ request_id: str, generations: dict | list[dict], model_name: str = ""
):
"""Create a completion response from the provided choices."""
@@ -47,7 +46,7 @@ def _create_response(
if not isinstance(generations, list):
generations = [generations]
- choices: List[CompletionRespChoice] = []
+ choices: list[CompletionRespChoice] = []
for index, generation in enumerate(generations):
logprob_response = None
@@ -103,7 +102,7 @@ async def _stream_collector(
prompt: str,
params: CompletionRequest,
abort_event: asyncio.Event,
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None,
):
"""Collects a stream and places results in a common queue"""
@@ -200,7 +199,7 @@ async def stream_generate_completion(
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
- gen_tasks: List[asyncio.Task] = []
+ gen_tasks: list[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
@@ -261,7 +260,7 @@ async def generate_completion(
):
"""Non-streaming generate for completions"""
- gen_tasks: List[asyncio.Task] = []
+ gen_tasks: list[asyncio.Task] = []
try:
logger.info(f"Received completion request {request.state.id}")
diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py
index c1ebdedf..66e466fe 100644
--- a/endpoints/OAI/utils/tools.py
+++ b/endpoints/OAI/utils/tools.py
@@ -1,6 +1,5 @@
import json
from loguru import logger
-from typing import List
from endpoints.OAI.types.tools import ToolCall
@@ -30,7 +29,7 @@
class ToolCallProcessor:
@staticmethod
- def from_json(tool_calls_str: str) -> List[ToolCall]:
+ def from_json(tool_calls_str: str) -> list[ToolCall]:
"""Postprocess tool call JSON to a parseable class"""
tool_calls = json.loads(tool_calls_str)
@@ -42,15 +41,15 @@ def from_json(tool_calls_str: str) -> List[ToolCall]:
return [ToolCall(**tool_call) for tool_call in tool_calls]
@staticmethod
- def dump(tool_calls: List[ToolCall]) -> List[dict]:
+ def dump(tool_calls: list[ToolCall]) -> list[dict]:
"""
Convert ToolCall objects to a list of dictionaries.
Args:
- tool_calls (List[ToolCall]): List of ToolCall objects to convert
+ tool_calls (list[ToolCall]): List of ToolCall objects to convert
Returns:
- List[dict]: List of dictionaries representing the tool calls
+ list[dict]: List of dictionaries representing the tool calls
"""
# Don't use list comprehension here
@@ -64,12 +63,12 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]:
return dumped_tool_calls
@staticmethod
- def to_json(tool_calls: List[ToolCall]) -> str:
+ def to_json(tool_calls: list[ToolCall]) -> str:
"""
Convert ToolCall objects to JSON string representation.
Args:
- tool_calls (List[ToolCall]): List of ToolCall objects to convert
+ tool_calls (list[ToolCall]): List of ToolCall objects to convert
Returns:
str: JSON representation of the tool calls
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 800c7d90..d69d683f 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -1,7 +1,6 @@
import asyncio
import pathlib
from sys import maxsize
-from typing import Optional
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse
@@ -403,7 +402,7 @@ async def unload_embedding_model():
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
- mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
+ mm_embeddings: MultimodalEmbeddingWrapper | None = None
if isinstance(data.text, str):
text = data.text
diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py
index cf49501f..c9f8a040 100644
--- a/endpoints/core/types/download.py
+++ b/endpoints/core/types/download.py
@@ -1,5 +1,4 @@
from pydantic import BaseModel, Field
-from typing import List, Optional
def _generate_include_list():
@@ -11,13 +10,13 @@ class DownloadRequest(BaseModel):
repo_id: str
repo_type: str = "model"
- folder_name: Optional[str] = None
- revision: Optional[str] = None
- token: Optional[str] = None
- include: List[str] = Field(default_factory=_generate_include_list)
- exclude: List[str] = Field(default_factory=list)
- chunk_limit: Optional[int] = None
- timeout: Optional[int] = None
+ folder_name: str | None = None
+ revision: str | None = None
+ token: str | None = None
+ include: list[str] = Field(default_factory=_generate_include_list)
+ exclude: list[str] = Field(default_factory=list)
+ chunk_limit: int | None = None
+ timeout: int | None = None
class DownloadResponse(BaseModel):
diff --git a/endpoints/core/types/health.py b/endpoints/core/types/health.py
index ad5fffef..189fa008 100644
--- a/endpoints/core/types/health.py
+++ b/endpoints/core/types/health.py
@@ -11,5 +11,5 @@ class HealthCheckResponse(BaseModel):
"healthy", description="System health status"
)
issues: list[UnhealthyEvent] = Field(
- default_factory=list, description="List of issues"
+ default_factory=list, description="list of issues"
)
diff --git a/endpoints/core/types/lora.py b/endpoints/core/types/lora.py
index 8435a8a4..88e9ba8f 100644
--- a/endpoints/core/types/lora.py
+++ b/endpoints/core/types/lora.py
@@ -2,7 +2,6 @@
from pydantic import BaseModel, Field
from time import time
-from typing import Optional, List
class LoraCard(BaseModel):
@@ -12,32 +11,32 @@ class LoraCard(BaseModel):
object: str = "lora"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
- scaling: Optional[float] = None
+ scaling: float | None = None
class LoraList(BaseModel):
"""Represents a list of Lora cards."""
object: str = "list"
- data: List[LoraCard] = Field(default_factory=list)
+ data: list[LoraCard] = Field(default_factory=list)
class LoraLoadInfo(BaseModel):
"""Represents a single Lora load info."""
name: str
- scaling: Optional[float] = 1.0
+ scaling: float | None = 1.0
class LoraLoadRequest(BaseModel):
"""Represents a Lora load request."""
- loras: List[LoraLoadInfo]
+ loras: list[LoraLoadInfo]
skip_queue: bool = False
class LoraLoadResponse(BaseModel):
"""Represents a Lora load response."""
- success: List[str] = Field(default_factory=list)
- failure: List[str] = Field(default_factory=list)
+ success: list[str] = Field(default_factory=list)
+ failure: list[str] = Field(default_factory=list)
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index 84229294..ad0d8a85 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -1,8 +1,10 @@
"""Contains model card types."""
+from __future__ import annotations
+
from pydantic import BaseModel, Field, ConfigDict
from time import time
-from typing import List, Literal, Optional, Union
+from typing import Literal
from common.config_models import LoggingConfig
from common.tabby_config import config
@@ -13,19 +15,19 @@ class ModelCardParameters(BaseModel):
# Safe to do this since it's guaranteed to fetch a max seq len
# from model_container
- max_seq_len: Optional[int] = None
- cache_size: Optional[int] = None
- cache_mode: Optional[str] = "FP16"
- rope_scale: Optional[float] = 1.0
- rope_alpha: Optional[float] = 1.0
- max_batch_size: Optional[int] = 1
- chunk_size: Optional[int] = 2048
- prompt_template: Optional[str] = None
- prompt_template_content: Optional[str] = None
- use_vision: Optional[bool] = False
+ max_seq_len: int | None = None
+ cache_size: int | None = None
+ cache_mode: str | None = "FP16"
+ rope_scale: float | None = 1.0
+ rope_alpha: float | None = 1.0
+ max_batch_size: int | None = 1
+ chunk_size: int | None = 2048
+ prompt_template: str | None = None
+ prompt_template_content: str | None = None
+ use_vision: bool | None = False
# Draft is another model, so include it in the card params
- draft: Optional["ModelCard"] = None
+ draft: ModelCard | None = None
class ModelCard(BaseModel):
@@ -35,15 +37,15 @@ class ModelCard(BaseModel):
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
- logging: Optional[LoggingConfig] = None
- parameters: Optional[ModelCardParameters] = None
+ logging: LoggingConfig | None = None
+ parameters: ModelCardParameters | None = None
class ModelList(BaseModel):
"""Represents a list of model cards."""
object: str = "list"
- data: List[ModelCard] = Field(default_factory=list)
+ data: list[ModelCard] = Field(default_factory=list)
class DraftModelLoadRequest(BaseModel):
@@ -53,13 +55,13 @@ class DraftModelLoadRequest(BaseModel):
draft_model_name: str
# Config arguments
- draft_rope_scale: Optional[float] = None
- draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
+ draft_rope_scale: float | None = None
+ draft_rope_alpha: float | Literal["auto"] | None = Field(
description='Automatically calculated if set to "auto"',
default=None,
examples=[1.0],
)
- draft_gpu_split: Optional[List[float]] = Field(
+ draft_gpu_split: list[float] | None = Field(
default_factory=list,
examples=[[24.0, 20.0]],
)
@@ -75,54 +77,54 @@ class ModelLoadRequest(BaseModel):
model_name: str
# Config arguments
- backend: Optional[str] = Field(
+ backend: str | None = Field(
description="Backend to use",
default=None,
)
- max_seq_len: Optional[int] = Field(
+ max_seq_len: int | None = Field(
description="Leave this blank to use the model's base sequence length",
default=None,
examples=[4096],
)
- cache_size: Optional[int] = Field(
+ cache_size: int | None = Field(
description="Number in tokens, must be multiple of 256",
default=None,
examples=[4096],
)
- cache_mode: Optional[str] = None
- tensor_parallel: Optional[bool] = None
- tensor_parallel_backend: Optional[str] = "native"
- gpu_split_auto: Optional[bool] = None
- autosplit_reserve: Optional[List[float]] = None
- gpu_split: Optional[List[float]] = Field(
+ cache_mode: str | None = None
+ tensor_parallel: bool | None = None
+ tensor_parallel_backend: str | None = "native"
+ gpu_split_auto: bool | None = None
+ autosplit_reserve: list[float] | None = None
+ gpu_split: list[float] | None = Field(
default_factory=list,
examples=[[24.0, 20.0]],
)
- rope_scale: Optional[float] = Field(
+ rope_scale: float | None = Field(
description="Automatically pulled from the model's config if not present",
default=None,
examples=[1.0],
)
- rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
+ rope_alpha: float | Literal["auto"] | None = Field(
description='Automatically calculated if set to "auto"',
default=None,
examples=[1.0],
)
- chunk_size: Optional[int] = None
- output_chunking: Optional[bool] = True
- prompt_template: Optional[str] = None
- vision: Optional[bool] = None
+ chunk_size: int | None = None
+ output_chunking: bool | None = True
+ prompt_template: str | None = None
+ vision: bool | None = None
# Non-config arguments
- draft_model: Optional[DraftModelLoadRequest] = None
- skip_queue: Optional[bool] = False
+ draft_model: DraftModelLoadRequest | None = None
+ skip_queue: bool | None = False
class EmbeddingModelLoadRequest(BaseModel):
embedding_model_name: str
# Set default from the config
- embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
+ embeddings_device: str | None = Field(config.embeddings.embeddings_device)
class ModelLoadResponse(BaseModel):
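The reworked request models keep unions that mix `Literal` values with numbers, e.g. `rope_alpha: float | Literal["auto"] | None`. A trimmed-down, hypothetical stand-in (field names mirror the patch, but the class below is illustration only, not the project's model) showing that pydantic v2 validates the PEP 604 spelling just like the old `Optional[Union[...]]` one:

from typing import Literal

from pydantic import BaseModel, ConfigDict, Field


class LoadRequestSketch(BaseModel):
    # Allow a field called model_name without tripping pydantic's protected
    # "model_" namespace warning.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    gpu_split: list[float] | None = Field(default_factory=list)
    rope_alpha: float | Literal["auto"] | None = None


req = LoadRequestSketch(model_name="some-model", gpu_split=[24.0, 20.0], rope_alpha="auto")
print(req.model_dump())  # {'model_name': 'some-model', 'gpu_split': [24.0, 20.0], 'rope_alpha': 'auto'}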
diff --git a/endpoints/core/types/sampler_overrides.py b/endpoints/core/types/sampler_overrides.py
index 18627829..2a2efab0 100644
--- a/endpoints/core/types/sampler_overrides.py
+++ b/endpoints/core/types/sampler_overrides.py
@@ -1,5 +1,4 @@
from pydantic import BaseModel, Field
-from typing import List, Optional
from common.sampling import SamplerOverridesContainer
@@ -7,17 +6,17 @@
class SamplerOverrideListResponse(SamplerOverridesContainer):
"""Sampler override list response"""
- presets: Optional[List[str]]
+ presets: list[str] | None
class SamplerOverrideSwitchRequest(BaseModel):
"""Sampler override switch request"""
- preset: Optional[str] = Field(
+ preset: str | None = Field(
default=None, description="Pass a sampler override preset name"
)
- overrides: Optional[dict] = Field(
+ overrides: dict | None = Field(
default=None,
description=(
"Sampling override parent takes in individual keys and overrides. "
diff --git a/endpoints/core/types/template.py b/endpoints/core/types/template.py
index a82ef48d..a3b98fb2 100644
--- a/endpoints/core/types/template.py
+++ b/endpoints/core/types/template.py
@@ -1,12 +1,11 @@
from pydantic import BaseModel, Field
-from typing import List
class TemplateList(BaseModel):
"""Represents a list of templates."""
object: str = "list"
- data: List[str] = Field(default_factory=list)
+ data: list[str] = Field(default_factory=list)
class TemplateSwitchRequest(BaseModel):
diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py
index d43e65e4..8a28d880 100644
--- a/endpoints/core/types/token.py
+++ b/endpoints/core/types/token.py
@@ -1,7 +1,6 @@
"""Tokenization types"""
from pydantic import BaseModel
-from typing import List, Union
from endpoints.OAI.types.chat_completion import ChatCompletionMessage
@@ -25,20 +24,20 @@ def get_params(self):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
- text: Union[str, List[ChatCompletionMessage]]
+ text: str | list[ChatCompletionMessage]
class TokenEncodeResponse(BaseModel):
"""Represents a tokenization response."""
- tokens: List[int]
+ tokens: list[int]
length: int
class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""
- tokens: List[int]
+ tokens: list[int]
class TokenDecodeResponse(BaseModel):
diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py
index 20c9433c..f2b39850 100644
--- a/endpoints/core/utils/model.py
+++ b/endpoints/core/utils/model.py
@@ -1,6 +1,5 @@
import pathlib
from asyncio import CancelledError
-from typing import Optional
from common import model
from common.networking import get_generator_error, handle_request_disconnect
@@ -13,7 +12,7 @@
)
-def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
+def get_model_list(model_path: pathlib.Path, draft_model_path: str | None = None):
"""Get the list of models from the provided path."""
# Convert the provided draft model path to a pathlib path for
diff --git a/endpoints/server.py b/endpoints/server.py
index c17cbb44..4287f747 100644
--- a/endpoints/server.py
+++ b/endpoints/server.py
@@ -3,7 +3,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
-from typing import Optional
from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends
@@ -13,7 +12,7 @@
from endpoints.core.router import router as CoreRouter
-def setup_app(host: Optional[str] = None, port: Optional[int] = None):
+def setup_app(host: str | None = None, port: int | None = None):
"""Includes the correct routers for startup"""
app = FastAPI(
diff --git a/main.py b/main.py
index 661613eb..3f497901 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,6 @@
import platform
import signal
from loguru import logger
-from typing import Optional
from common import gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
@@ -104,8 +103,8 @@ async def entrypoint_async():
def entrypoint(
- args: Optional[argparse.Namespace] = None,
- parser: Optional[argparse.ArgumentParser] = None,
+ args: argparse.Namespace | None = None,
+ parser: argparse.ArgumentParser | None = None,
):
setup_logger()
diff --git a/start.py b/start.py
index 95bdd366..4811c9ac 100644
--- a/start.py
+++ b/start.py
@@ -9,7 +9,6 @@
import sys
import traceback
from shutil import copyfile, which
-from typing import List
# Checks for uv installation
has_uv = which("uv") is not None
@@ -154,7 +153,7 @@ def migrate_start_options(start_options: dict):
return migrated
-def run_pip(command: List[str]):
+def run_pip(command: list[str]):
if has_uv:
command.insert(0, "uv")