diff --git a/backends/base_model_container.py b/backends/base_model_container.py index 96393aba..cc0c5c1e 100644 --- a/backends/base_model_container.py +++ b/backends/base_model_container.py @@ -2,13 +2,7 @@ import asyncio import pathlib from loguru import logger -from typing import ( - Any, - AsyncIterator, - Dict, - List, - Optional, -) +from typing import Any, AsyncIterator from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate @@ -21,7 +15,7 @@ class BaseModelContainer(abc.ABC): # Exposed model information model_dir: pathlib.Path = pathlib.Path("models") - prompt_template: Optional[PromptTemplate] = None + prompt_template: PromptTemplate | None = None # HF Model instance hf_model: HFModel @@ -34,7 +28,7 @@ class BaseModelContainer(abc.ABC): # The bool is a master switch for accepting requests # The lock keeps load tasks sequential # The condition notifies any waiting tasks - active_job_ids: Dict[str, Any] = {} + active_job_ids: dict[str, Any] = {} loaded: bool = False load_lock: asyncio.Lock load_condition: asyncio.Condition @@ -98,7 +92,7 @@ async def unload(self, loras_only: bool = False, **kwargs): pass @abc.abstractmethod - def encode_tokens(self, text: str, **kwargs) -> List[int]: + def encode_tokens(self, text: str, **kwargs) -> list[int]: """ Encodes a string of text into a list of token IDs. @@ -113,7 +107,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]: pass @abc.abstractmethod - def decode_tokens(self, ids: List[int], **kwargs) -> str: + def decode_tokens(self, ids: list[int], **kwargs) -> str: """ Decodes a list of token IDs back into a string. @@ -128,7 +122,7 @@ def decode_tokens(self, ids: List[int], **kwargs) -> str: pass @abc.abstractmethod - def get_special_tokens(self) -> Dict[str, Any]: + def get_special_tokens(self) -> dict[str, Any]: """ Gets special tokens used by the model/tokenizer. @@ -164,7 +158,7 @@ async def wait_for_jobs(self, skip_wait: bool = False): # Optional methods async def load_loras( self, lora_directory: pathlib.Path, **kwargs - ) -> Dict[str, List[str]]: + ) -> dict[str, list[str]]: """ Loads LoRA adapters. Base implementation does nothing or raises error. @@ -184,7 +178,7 @@ async def load_loras( ], } - def get_loras(self) -> List[Any]: + def get_loras(self) -> list[Any]: """ Gets the currently loaded LoRA adapters. Base implementation returns empty list. @@ -200,9 +194,9 @@ async def generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> Dict[str, Any]: + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, + ) -> dict[str, Any]: """ Generates a complete response for a given prompt and parameters. @@ -225,9 +219,9 @@ async def stream_generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> AsyncIterator[Dict[str, Any]]: + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, + ) -> AsyncIterator[dict[str, Any]]: """ Generates a response iteratively (streaming) for a given prompt. 
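The hunks above replace `typing.Optional`, `List`, and `Dict` with PEP 604 unions (`X | None`) and PEP 585 builtin generics (`list[int]`, `dict[str, Any]`). A minimal sketch of the equivalence is below; it assumes Python 3.10+, since class-level annotations like these are evaluated at import time in modules that do not also add `from __future__ import annotations`, and the function body is illustrative rather than taken from the repository.

```python
from typing import Any, Dict, Optional

# Both annotations describe the same type; only the spelling changes.
old_style: Optional[Dict[str, Any]] = None   # pre-PEP 585/604 typing aliases
new_style: dict[str, Any] | None = None      # builtin generics + `|` unions, as in this diff


def encode_tokens(text: str) -> list[int]:
    """Illustrative stand-in for the encode_tokens signature above."""
    return [ord(ch) for ch in text]


print(encode_tokens("hi"))  # [104, 105]
```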
diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py index f9bb464c..2605bec3 100644 --- a/backends/exllamav2/grammar.py +++ b/backends/exllamav2/grammar.py @@ -1,7 +1,6 @@ import traceback -import typing from functools import lru_cache -from typing import List +from typing import Any import torch from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer @@ -16,7 +15,7 @@ class ExLlamaV2Grammar: """ExLlamaV2 class for various grammar filters/parsers.""" - filters: List[ExLlamaV2Filter] + filters: list[ExLlamaV2Filter] def __init__(self): self.filters = [] @@ -123,7 +122,7 @@ def __init__(self, nonterminal: str, kbnf_string: str): self.kbnf_string = kbnf_string # Return the entire input string as the extracted string - def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]: + def extract(self, input_str: str) -> tuple[str, Any] | None: return "", input_str @property diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ae71b00f..5c5c934c 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -25,7 +25,6 @@ ) from itertools import zip_longest from loguru import logger -from typing import Dict, List, Optional from backends.base_model_container import BaseModelContainer from backends.exllamav2.grammar import ( @@ -58,45 +57,45 @@ class ExllamaV2Container(BaseModelContainer): # Model directories model_dir: pathlib.Path = pathlib.Path("models") draft_model_dir: pathlib.Path = pathlib.Path("models") - prompt_template: Optional[PromptTemplate] = None + prompt_template: PromptTemplate | None = None # HF model instance hf_model: HFModel # Exl2 vars - config: Optional[ExLlamaV2Config] = None - model: Optional[ExLlamaV2] = None - cache: Optional[ExLlamaV2Cache] = None - tokenizer: Optional[ExLlamaV2Tokenizer] = None - generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None - prompt_template: Optional[PromptTemplate] = None + config: ExLlamaV2Config | None = None + model: ExLlamaV2 | None = None + cache: ExLlamaV2Cache | None = None + tokenizer: ExLlamaV2Tokenizer | None = None + generator: ExLlamaV2DynamicGeneratorAsync | None = None + prompt_template: PromptTemplate | None = None paged: bool = True # Draft model vars use_draft_model: bool = False - draft_config: Optional[ExLlamaV2Config] = None - draft_model: Optional[ExLlamaV2] = None - draft_cache: Optional[ExLlamaV2Cache] = None + draft_config: ExLlamaV2Config | None = None + draft_model: ExLlamaV2 | None = None + draft_cache: ExLlamaV2Cache | None = None # Internal config vars cache_size: int = None cache_mode: str = "FP16" draft_cache_mode: str = "FP16" - max_batch_size: Optional[int] = None + max_batch_size: int | None = None # GPU split vars - gpu_split: List[float] = [] - draft_gpu_split: List[float] = [] + gpu_split: list[float] = [] + draft_gpu_split: list[float] = [] gpu_split_auto: bool = True - autosplit_reserve: List[float] = [96 * 1024**2] + autosplit_reserve: list[float] = [96 * 1024**2] use_tp: bool = False # Vision vars use_vision: bool = False - vision_model: Optional[ExLlamaV2VisionTower] = None + vision_model: ExLlamaV2VisionTower | None = None # Load synchronization - active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {} + active_job_ids: dict[str, ExLlamaV2DynamicJobAsync | None] = {} loaded: bool = False load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() @@ -272,6 +271,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.config.max_seq_len = 
unwrap( user_max_seq_len, min(hf_model.hf_config.max_position_embeddings, 4096) ) + self.cache_size = self.config.max_seq_len # Set the rope scale self.config.scale_pos_emb = unwrap( @@ -750,9 +750,9 @@ async def load_loras(self, lora_directory: pathlib.Path, **kwargs): # Wait for existing generation jobs to finish await self.wait_for_jobs(kwargs.get("skip_wait")) - loras_to_load: List[ExLlamaV2Lora] = [] - success: List[str] = [] - failure: List[str] = [] + loras_to_load: list[ExLlamaV2Lora] = [] + success: list[str] = [] + failure: list[str] = [] for lora in loras: lora_name = lora.get("name") @@ -869,7 +869,7 @@ def encode_tokens(self, text: str, **kwargs): .tolist() ) - def decode_tokens(self, ids: List[int], **kwargs): + def decode_tokens(self, ids: list[int], **kwargs): """Wrapper to decode tokens from a list of IDs""" ids = torch.tensor([ids]) @@ -908,8 +908,8 @@ async def generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, ): """Generate a response to a prompt.""" generations = [] @@ -969,8 +969,8 @@ async def stream_generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, ): try: # Wait for load lock to be freed before processing @@ -1242,8 +1242,8 @@ async def generate_gen( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, ): """ Create generator function for prompt completion. 
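Beyond the annotation updates, the `create()` hunk above adds `self.cache_size = self.config.max_seq_len`, seeding the cache size from the resolved context length. A rough sketch of that defaulting chain using the `unwrap` helper from `common/utils.py` follows; the input values are made up for illustration.

```python
def unwrap(wrapped, default=None):
    """Mirror of common.utils.unwrap: fall back to `default` only when `wrapped` is None."""
    return default if wrapped is None else wrapped


# Hypothetical inputs: no user override, model config reports 8192 positions.
user_max_seq_len = None
max_position_embeddings = 8192

max_seq_len = unwrap(user_max_seq_len, min(max_position_embeddings, 4096))
cache_size = max_seq_len  # added in this diff: cache_size is seeded from the resolved max_seq_len

print(max_seq_len, cache_size)  # 4096 4096
```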
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index d385de7c..389f581b 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -7,9 +7,6 @@ from typing import ( Any, AsyncIterator, - Dict, - List, - Optional, ) from exllamav3 import ( @@ -49,7 +46,7 @@ class ExllamaV3Container(BaseModelContainer): # Exposed model information model_dir: pathlib.Path = pathlib.Path("models") - prompt_template: Optional[PromptTemplate] = None + prompt_template: PromptTemplate | None = None # HF Model instance hf_model: HFModel @@ -58,26 +55,26 @@ class ExllamaV3Container(BaseModelContainer): # The bool is a master switch for accepting requests # The lock keeps load tasks sequential # The condition notifies any waiting tasks - active_job_ids: Dict[str, Any] = {} + active_job_ids: dict[str, Any] = {} loaded: bool = False load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() # Exl3 vars - model: Optional[Model] = None - cache: Optional[Cache] = None - draft_model: Optional[Model] = None - draft_cache: Optional[Cache] = None - tokenizer: Optional[Tokenizer] = None - config: Optional[Config] = None - draft_config: Optional[Config] = None - generator: Optional[AsyncGenerator] = None - vision_model: Optional[Model] = None + model: Model | None = None + cache: Cache | None = None + draft_model: Model | None = None + draft_cache: Cache | None = None + tokenizer: Tokenizer | None = None + config: Config | None = None + draft_config: Config | None = None + generator: AsyncGenerator | None = None + vision_model: Model | None = None # Class-specific vars - gpu_split: Optional[List[float]] = None + gpu_split: list[float] | None = None gpu_split_auto: bool = True - autosplit_reserve: Optional[List[float]] = [96 / 1024] + autosplit_reserve: list[float] | None = [96 / 1024] use_tp: bool = False tp_backend: str = "native" max_seq_len: int = 4096 @@ -85,8 +82,8 @@ class ExllamaV3Container(BaseModelContainer): cache_mode: str = "FP16" draft_cache_mode: str = "FP16" chunk_size: int = 2048 - max_rq_tokens: Optional[int] = 2048 - max_batch_size: Optional[int] = None + max_rq_tokens: int | None = 2048 + max_batch_size: int | None = None # Required methods @classmethod @@ -579,7 +576,7 @@ async def unload(self, loras_only: bool = False, **kwargs): async with self.load_condition: self.load_condition.notify_all() - def encode_tokens(self, text: str, **kwargs) -> List[int]: + def encode_tokens(self, text: str, **kwargs) -> list[int]: """ Encodes a string of text into a list of token IDs. @@ -607,7 +604,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]: .tolist() ) - def decode_tokens(self, ids: List[int], **kwargs) -> str: + def decode_tokens(self, ids: list[int], **kwargs) -> str: """ Decodes a list of token IDs back into a string. @@ -666,9 +663,9 @@ async def generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> Dict[str, Any]: + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, + ) -> dict[str, Any]: """ Generates a complete response for a given prompt and parameters. 
@@ -738,9 +735,9 @@ async def stream_generate( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> AsyncIterator[Dict[str, Any]]: + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, + ) -> AsyncIterator[dict[str, Any]]: """ Generates a response iteratively (streaming) for a given prompt. @@ -859,8 +856,8 @@ async def generate_gen( request_id: str, prompt: str, params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + abort_event: asyncio.Event | None = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, ): """ Create generator function for prompt completion. diff --git a/backends/exllamav3/sampler.py b/backends/exllamav3/sampler.py index 7b08a9b1..eef7d944 100644 --- a/backends/exllamav3/sampler.py +++ b/backends/exllamav3/sampler.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import List from exllamav3.generator.sampler import ( CustomSampler, SS_Temperature, @@ -20,7 +19,7 @@ class ExllamaV3SamplerBuilder: Custom sampler chain/stack for TabbyAPI """ - stack: List[SS_Base] = field(default_factory=list) + stack: list[SS_Base] = field(default_factory=list) def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay): self.stack += [ diff --git a/backends/infinity/model.py b/backends/infinity/model.py index c131e3cb..e1dd9b1b 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import gc import pathlib import torch from loguru import logger -from typing import List, Optional from common.utils import unwrap from common.optional_dependencies import dependencies @@ -17,7 +18,7 @@ class InfinityContainer: loaded: bool = False # Use a runtime type hint here - engine: Optional["AsyncEmbeddingEngine"] = None + engine: AsyncEmbeddingEngine | None = None def __init__(self, model_directory: pathlib.Path): self.model_dir = model_directory @@ -49,7 +50,7 @@ async def unload(self): logger.info("Embedding model unloaded.") - async def generate(self, sentence_input: List[str]): + async def generate(self, sentence_input: list[str]): result_embeddings, usage = await self.engine.embed(sentence_input) return {"embeddings": result_embeddings, "usage": usage} diff --git a/common/args.py b/common/args.py index a0da4f98..cd286877 100644 --- a/common/args.py +++ b/common/args.py @@ -1,7 +1,6 @@ """Argparser for overriding config values""" import argparse -from typing import Optional from pydantic import BaseModel from common.config_models import TabbyConfigModel @@ -25,7 +24,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None: def init_argparser( - existing_parser: Optional[argparse.ArgumentParser] = None, + existing_parser: argparse.ArgumentParser | None = None, ) -> argparse.ArgumentParser: """ Initializes an argparse parser based on a Pydantic config schema. 
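In `backends/infinity/model.py`, the quoted `Optional["AsyncEmbeddingEngine"]` hint becomes a plain `AsyncEmbeddingEngine | None`, which only stays import-safe because the file also gains `from __future__ import annotations` (annotations are kept as unevaluated strings, so the optional dependency never has to resolve). A minimal sketch of that pattern; the `infinity_emb` import and the fallback are illustrative stand-ins, not the repository's actual optional-dependency handling.

```python
from __future__ import annotations

try:
    # Optional dependency; may be missing in a minimal install.
    from infinity_emb import AsyncEmbeddingEngine
except ImportError:
    AsyncEmbeddingEngine = None


class InfinityContainer:
    loaded: bool = False

    # With deferred annotations this stays the string "AsyncEmbeddingEngine | None"
    # and is never evaluated, so it is safe even when the import above failed.
    engine: AsyncEmbeddingEngine | None = None


container = InfinityContainer()
print(container.engine)  # None
```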
diff --git a/common/auth.py b/common/auth.py index b02cdd02..bd93afe0 100644 --- a/common/auth.py +++ b/common/auth.py @@ -10,7 +10,6 @@ from fastapi import Header, HTTPException, Request from pydantic import BaseModel from loguru import logger -from typing import Optional from common.utils import coalesce @@ -38,7 +37,7 @@ def verify_key(self, test_key: str, key_type: str): # Global auth constants -AUTH_KEYS: Optional[AuthKeys] = None +AUTH_KEYS: AuthKeys | None = None DISABLE_AUTH: bool = False diff --git a/common/config_models.py b/common/config_models.py index 0e71734c..c5772c42 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -6,17 +6,17 @@ PrivateAttr, field_validator, ) -from typing import List, Literal, Optional, Union +from typing import Literal CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] -CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] +CACHE_TYPE = CACHE_SIZES | constr(pattern=r"^[2-8]\s*,\s*[2-8]$") class Metadata(BaseModel): """metadata model for config options""" - include_in_config: Optional[bool] = Field(True) + include_in_config: bool | None = Field(True) class BaseConfigModel(BaseModel): @@ -29,7 +29,7 @@ class ConfigOverrideConfig(BaseConfigModel): """Model for overriding a provided config file.""" # TODO: convert this to a pathlib.path? - config: Optional[str] = Field( + config: str | None = Field( None, description=("Path to an overriding config.yml file") ) @@ -39,17 +39,17 @@ class ConfigOverrideConfig(BaseConfigModel): class NetworkConfig(BaseConfigModel): """Options for networking""" - host: Optional[str] = Field( + host: str | None = Field( "127.0.0.1", description=( "The IP to host on (default: 127.0.0.1).\n" "Use 0.0.0.0 to expose on all network adapters." ), ) - port: Optional[int] = Field( + port: int | None = Field( 5000, description=("The port to host on (default: 5000).") ) - disable_auth: Optional[bool] = Field( + disable_auth: bool | None = Field( False, description=( "Disable HTTP token authentication with requests.\n" @@ -57,21 +57,21 @@ class NetworkConfig(BaseConfigModel): "Turn on this option if you are ONLY connecting from localhost." ), ) - disable_fetch_requests: Optional[bool] = Field( + disable_fetch_requests: bool | None = Field( False, description=( "Disable fetching external content in response to requests," "such as images from URLs." ), ) - send_tracebacks: Optional[bool] = Field( + send_tracebacks: bool | None = Field( False, description=( "Send tracebacks over the API (default: False).\n" "NOTE: Only enable this for debug purposes." ), ) - api_servers: Optional[List[Literal["oai", "kobold"]]] = Field( + api_servers: list[Literal["oai", "kobold"]] | None = Field( ["OAI"], description=( 'Select API servers to enable (default: ["OAI"]).\n' @@ -91,15 +91,15 @@ def api_server_validator(cls, api_servers): class LoggingConfig(BaseConfigModel): """Options for logging""" - log_prompt: Optional[bool] = Field( + log_prompt: bool | None = Field( False, description=("Enable prompt logging (default: False)."), ) - log_generation_params: Optional[bool] = Field( + log_generation_params: bool | None = Field( False, description=("Enable generation parameter logging (default: False)."), ) - log_requests: Optional[bool] = Field( + log_requests: bool | None = Field( False, description=( "Enable request logging (default: False).\n" @@ -123,7 +123,7 @@ class ModelConfig(BaseConfigModel): "Windows users, do NOT put this path in quotes!" 
), ) - inline_model_loading: Optional[bool] = Field( + inline_model_loading: bool | None = Field( False, description=( "Allow direct loading of models " @@ -132,7 +132,7 @@ class ModelConfig(BaseConfigModel): "Enable dummy models to add exceptions for invalid model names." ), ) - use_dummy_models: Optional[bool] = Field( + use_dummy_models: bool | None = Field( False, description=( "Sends dummy model names when the models endpoint is queried. " @@ -140,7 +140,7 @@ class ModelConfig(BaseConfigModel): "Enable this if the client is looking for specific OAI models.\n" ), ) - dummy_model_names: List[str] = Field( + dummy_model_names: list[str] = Field( default=["gpt-3.5-turbo"], description=( "A list of fake model names that are sent via the /v1/models endpoint. " @@ -148,7 +148,7 @@ class ModelConfig(BaseConfigModel): "Also used as bypasses for strict mode if inline_model_loading is true." ), ) - model_name: Optional[str] = Field( + model_name: str | None = Field( None, description=( "An initial model to load.\n" @@ -156,7 +156,7 @@ class ModelConfig(BaseConfigModel): "REQUIRED: This must be filled out to load a model on startup." ), ) - use_as_default: List[str] = Field( + use_as_default: list[str] = Field( default_factory=list, description=( "Names of args to use as a fallback for API load requests (default: []).\n" @@ -165,22 +165,22 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." ), ) - backend: Optional[str] = Field( + backend: str | None = Field( None, description=( "Backend to use for this model (auto-detect if not specified)\n" "Options: exllamav2, exllamav3" ), ) - max_seq_len: Optional[int] = Field( + max_seq_len: int | None = Field( None, description=( "Max sequence length (default: 4096).\n" - "Set to -1 to fetch from the model's config.json" + "set to -1 to fetch from the model's config.json" ), ge=-1, ) - cache_size: Optional[int] = Field( + cache_size: int | None = Field( None, description=( "Size of the prompt cache to allocate (default: max_seq_len).\n" @@ -190,7 +190,7 @@ class ModelConfig(BaseConfigModel): multiple_of=256, gt=0, ) - cache_mode: Optional[CACHE_TYPE] = Field( + cache_mode: CACHE_TYPE | None = Field( "FP16", description=( "Enable different cache modes for VRAM savings (default: FP16).\n" @@ -199,7 +199,7 @@ class ModelConfig(BaseConfigModel): "are integers from 2-8 (i.e. 8,8)." ), ) - tensor_parallel: Optional[bool] = Field( + tensor_parallel: bool | None = Field( False, description=( "Load model with tensor parallelism (default: False).\n" @@ -207,7 +207,7 @@ class ModelConfig(BaseConfigModel): "This ignores the gpu_split_auto value." ), ) - tensor_parallel_backend: Optional[str] = Field( + tensor_parallel_backend: str | None = Field( "native", description=( "Sets a backend type for tensor parallelism. (default: native).\n" @@ -216,28 +216,28 @@ class ModelConfig(BaseConfigModel): "NCCL is recommended for NVLink." ), ) - gpu_split_auto: Optional[bool] = Field( + gpu_split_auto: bool | None = Field( True, description=( "Automatically allocate resources to GPUs (default: True).\n" "Not parsed for single GPU users." ), ) - autosplit_reserve: List[float] = Field( + autosplit_reserve: list[float] = Field( [96], description=( "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" "Represented as an array of MB per GPU." 
), ) - gpu_split: List[float] = Field( + gpu_split: list[float] = Field( default_factory=list, description=( "An integer array of GBs of VRAM to split between GPUs (default: []).\n" "Used with tensor parallelism." ), ) - rope_scale: Optional[float] = Field( + rope_scale: float | None = Field( 1.0, description=( "Rope scale (default: 1.0).\n" @@ -246,7 +246,7 @@ class ModelConfig(BaseConfigModel): "Leave blank to pull the value from the model." ), ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + rope_alpha: float | Literal["auto"] | None = Field( None, description=( "Rope alpha (default: None).\n" @@ -255,7 +255,7 @@ class ModelConfig(BaseConfigModel): "or auto-calculate." ), ) - chunk_size: Optional[int] = Field( + chunk_size: int | None = Field( 2048, description=( "Chunk size for prompt ingestion (default: 2048).\n" @@ -265,7 +265,7 @@ class ModelConfig(BaseConfigModel): ), gt=0, ) - output_chunking: Optional[bool] = Field( + output_chunking: bool | None = Field( True, description=( "Use output chunking (default: True)\n" @@ -274,7 +274,7 @@ class ModelConfig(BaseConfigModel): "Used by EXL3 models only.\n" ), ) - max_batch_size: Optional[int] = Field( + max_batch_size: int | None = Field( None, description=( "Set the maximum number of prompts to process at one time " @@ -284,7 +284,7 @@ class ModelConfig(BaseConfigModel): ), ge=1, ) - prompt_template: Optional[str] = Field( + prompt_template: str | None = Field( None, description=( "Set the prompt template for this model. (default: None)\n" @@ -294,7 +294,7 @@ class ModelConfig(BaseConfigModel): "NOTE: Only works with chat completion message lists!" ), ) - vision: Optional[bool] = Field( + vision: bool | None = Field( False, description=( "Enables vision support if the model supports it. (default: False)" @@ -312,18 +312,18 @@ class DraftModelConfig(BaseConfigModel): """ # TODO: convert this to a pathlib.path? - draft_model_dir: Optional[str] = Field( + draft_model_dir: str | None = Field( "models", description=("Directory to look for draft models (default: models)"), ) - draft_model_name: Optional[str] = Field( + draft_model_name: str | None = Field( None, description=( "An initial draft model to load.\n" "Ensure the model is in the model directory." ), ) - draft_rope_scale: Optional[float] = Field( + draft_rope_scale: float | None = Field( 1.0, description=( "Rope scale for draft models (default: 1.0).\n" @@ -331,7 +331,7 @@ class DraftModelConfig(BaseConfigModel): "Use if the draft model was trained on long context with rope." ), ) - draft_rope_alpha: Optional[float] = Field( + draft_rope_alpha: float | None = Field( None, description=( "Rope alpha for draft models (default: None).\n" @@ -340,14 +340,14 @@ class DraftModelConfig(BaseConfigModel): "or auto-calculate." ), ) - draft_cache_mode: Optional[CACHE_SIZES] = Field( + draft_cache_mode: CACHE_SIZES | None = Field( "FP16", description=( "Cache mode for draft models to save VRAM (default: FP16).\n" f"Possible values: {str(CACHE_SIZES)[15:-1]}." 
), ) - draft_gpu_split: List[float] = Field( + draft_gpu_split: list[float] = Field( default_factory=list, description=( "An integer array of GBs of VRAM to split between GPUs (default: []).\n" @@ -359,7 +359,7 @@ class DraftModelConfig(BaseConfigModel): class SamplingConfig(BaseConfigModel): """Options for Sampling""" - override_preset: Optional[str] = Field( + override_preset: str | None = Field( None, description=( "Select a sampler override preset (default: None).\n" @@ -376,7 +376,7 @@ class SamplingConfig(BaseConfigModel): class LoraInstanceModel(BaseConfigModel): """Model representing an instance of a Lora.""" - name: Optional[str] = None + name: str | None = None scaling: float = Field(1.0, ge=0) @@ -384,10 +384,10 @@ class LoraConfig(BaseConfigModel): """Options for Loras""" # TODO: convert this to a pathlib.path? - lora_dir: Optional[str] = Field( + lora_dir: str | None = Field( "loras", description=("Directory to look for LoRAs (default: loras).") ) - loras: Optional[List[LoraInstanceModel]] = Field( + loras: list[LoraInstanceModel] | None = Field( None, description=( "List of LoRAs to load and associated scaling factors " @@ -407,11 +407,11 @@ class EmbeddingsConfig(BaseConfigModel): """ # TODO: convert this to a pathlib.path? - embedding_model_dir: Optional[str] = Field( + embedding_model_dir: str | None = Field( "models", description=("Directory to look for embedding models (default: models)."), ) - embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field( + embeddings_device: Literal["cpu", "auto", "cuda"] | None = Field( "cpu", description=( "Device to load embedding models on (default: cpu).\n" @@ -420,7 +420,7 @@ class EmbeddingsConfig(BaseConfigModel): "If using an AMD GPU, set this value to 'cuda'." ), ) - embedding_model_name: Optional[str] = Field( + embedding_model_name: str | None = Field( None, description=("An initial embedding model to load on the infinity backend."), ) @@ -429,7 +429,7 @@ class EmbeddingsConfig(BaseConfigModel): class DeveloperConfig(BaseConfigModel): """Options for development and experimentation""" - unsafe_launch: Optional[bool] = Field( + unsafe_launch: bool | None = Field( False, description=( "Skip Exllamav2 version check (default: False).\n" @@ -437,10 +437,10 @@ class DeveloperConfig(BaseConfigModel): "than enabling this flag." 
), ) - disable_request_streaming: Optional[bool] = Field( + disable_request_streaming: bool | None = Field( False, description=("Disable API request streaming (default: False).") ) - realtime_process_priority: Optional[bool] = Field( + realtime_process_priority: bool | None = Field( False, description=( "Set process to use a higher priority.\n" @@ -453,27 +453,27 @@ class DeveloperConfig(BaseConfigModel): class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" - config: Optional[ConfigOverrideConfig] = Field( + config: ConfigOverrideConfig | None = Field( default_factory=ConfigOverrideConfig.model_construct ) - network: Optional[NetworkConfig] = Field( + network: NetworkConfig | None = Field( default_factory=NetworkConfig.model_construct ) - logging: Optional[LoggingConfig] = Field( + logging: LoggingConfig | None = Field( default_factory=LoggingConfig.model_construct ) - model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct) - draft_model: Optional[DraftModelConfig] = Field( + model: ModelConfig | None = Field(default_factory=ModelConfig.model_construct) + draft_model: DraftModelConfig | None = Field( default_factory=DraftModelConfig.model_construct ) - lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct) - embeddings: Optional[EmbeddingsConfig] = Field( + lora: LoraConfig | None = Field(default_factory=LoraConfig.model_construct) + embeddings: EmbeddingsConfig | None = Field( default_factory=EmbeddingsConfig.model_construct ) - sampling: Optional[SamplingConfig] = Field( + sampling: SamplingConfig | None = Field( default_factory=SamplingConfig.model_construct ) - developer: Optional[DeveloperConfig] = Field( + developer: DeveloperConfig | None = Field( default_factory=DeveloperConfig.model_construct ) diff --git a/common/downloader.py b/common/downloader.py index 8307bbcd..8f3155f2 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -10,7 +10,6 @@ from fnmatch import fnmatch from loguru import logger from rich.progress import Progress -from typing import List, Optional from common.logger import get_progress_bar from common.tabby_config import config @@ -27,7 +26,7 @@ class RepoItem: async def _download_file( session: aiohttp.ClientSession, repo_item: RepoItem, - token: Optional[str], + token: str | None, download_path: pathlib.Path, chunk_limit: int, progress: Progress, @@ -92,7 +91,7 @@ def _get_repo_info(repo_id, revision, token): ] -def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]): +def _get_download_folder(repo_id: str, repo_type: str, folder_name: str | None): """Gets the download folder for the repo.""" if repo_type == "lora": @@ -105,7 +104,7 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str def _check_exclusions( - filename: str, include_patterns: List[str], exclude_patterns: List[str] + filename: str, include_patterns: list[str], exclude_patterns: list[str] ): include_result = any(fnmatch(filename, pattern) for pattern in include_patterns) exclude_result = any(fnmatch(filename, pattern) for pattern in exclude_patterns) @@ -115,14 +114,14 @@ def _check_exclusions( async def hf_repo_download( repo_id: str, - folder_name: Optional[str], - revision: Optional[str], - token: Optional[str], - include: Optional[List[str]], - exclude: Optional[List[str]], - chunk_limit: Optional[float] = None, - timeout: Optional[int] = None, - repo_type: Optional[str] = "model", + folder_name: str | None, + revision: str | None, + token: str | None, + 
include: list[str] | None, + exclude: list[str] | None, + chunk_limit: float | None = None, + timeout: int | None = None, + repo_type: str | None = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" diff --git a/common/gen_logging.py b/common/gen_logging.py index fcd3c01d..6cf5ddcb 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -3,7 +3,6 @@ """ from loguru import logger -from typing import Optional from common.tabby_config import config @@ -29,7 +28,7 @@ def log_generation_params(**kwargs): logger.info(f"Generation options: {kwargs}\n") -def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str] = None): +def log_prompt(prompt: str, request_id: str, negative_prompt: str | None = None): """Logs the prompt to console.""" if config.logging.log_prompt: formatted_prompt = "\n" + prompt @@ -55,7 +54,7 @@ def log_response(request_id: str, response: str): def log_metrics( request_id: str, metrics: dict, - context_len: Optional[int], + context_len: int | None, max_seq_len: int, ): initial_response = ( diff --git a/common/health.py b/common/health.py index 4d21d6af..a31cb689 100644 --- a/common/health.py +++ b/common/health.py @@ -3,7 +3,6 @@ from datetime import datetime, timezone from functools import partial from pydantic import BaseModel, Field -from typing import Union class UnhealthyEvent(BaseModel): @@ -24,7 +23,7 @@ def __init__(self): self.issues: deque[UnhealthyEvent] = deque(maxlen=100) self._lock = asyncio.Lock() - async def add_unhealthy_event(self, error: Union[str, Exception]): + async def add_unhealthy_event(self, error: str | Exception): """Add a new unhealthy event""" async with self._lock: if isinstance(error, Exception): diff --git a/common/model.py b/common/model.py index 4ac4861a..757688e8 100644 --- a/common/model.py +++ b/common/model.py @@ -10,7 +10,6 @@ from fastapi import HTTPException from loguru import logger from ruamel.yaml import YAML -from typing import Dict, Optional from backends.base_model_container import BaseModelContainer from common.logger import get_loading_progress_bar @@ -21,11 +20,11 @@ from common.utils import deep_merge_dict, unwrap # Global variables for model container -container: Optional[BaseModelContainer] = None +container: BaseModelContainer | None = None embeddings_container = None -_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {} +_BACKEND_REGISTRY: dict[str, BaseModelContainer] = {} if dependencies.exllamav2: from backends.exllamav2.model import ExllamaV2Container @@ -42,7 +41,7 @@ if dependencies.extras: from backends.infinity.model import InfinityContainer - embeddings_container: Optional[InfinityContainer] = None + embeddings_container: InfinityContainer | None = None class ModelType(Enum): diff --git a/common/multimodal.py b/common/multimodal.py index b92386f3..11e401dc 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -3,7 +3,6 @@ from common import model from loguru import logger from pydantic import BaseModel, Field -from typing import List from common.optional_dependencies import dependencies @@ -18,7 +17,7 @@ class MultimodalEmbeddingWrapper(BaseModel): type: str = None content: list = Field(default_factory=list) - text_alias: List[str] = Field(default_factory=list) + text_alias: list[str] = Field(default_factory=list) async def add(self, url: str): # Determine the type of vision embedding to use diff --git a/common/networking.py b/common/networking.py index 597ed078..f4d72917 100644 --- a/common/networking.py +++ b/common/networking.py @@ -7,7 
+7,6 @@ from fastapi import Depends, HTTPException, Request from loguru import logger from pydantic import BaseModel -from typing import Optional from uuid import uuid4 from common.tabby_config import config @@ -17,7 +16,7 @@ class TabbyRequestErrorMessage(BaseModel): """Common request error type.""" message: str - trace: Optional[str] = None + trace: str | None = None class TabbyRequestError(BaseModel): diff --git a/common/sampling.py b/common/sampling.py index 49be5b99..832661ed 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -14,7 +14,6 @@ field_validator, model_validator, ) -from typing import Dict, List, Optional, Union from common.utils import filter_none_values, unwrap @@ -23,7 +22,7 @@ class BaseSamplerRequest(BaseModel): """Common class for sampler params that are used in APIs""" - max_tokens: Optional[int] = Field( + max_tokens: int | None = Field( default_factory=lambda: get_default_sampler_value("max_tokens"), validation_alias=AliasChoices( "max_tokens", "max_completion_tokens", "max_length" @@ -33,7 +32,7 @@ class BaseSamplerRequest(BaseModel): ge=0, ) - min_tokens: Optional[int] = Field( + min_tokens: int | None = Field( default_factory=lambda: get_default_sampler_value("min_tokens", 0), validation_alias=AliasChoices("min_tokens", "min_length"), description="Aliases: min_length", @@ -41,76 +40,76 @@ class BaseSamplerRequest(BaseModel): ge=0, ) - stop: Optional[Union[str, List[Union[str, int]]]] = Field( + stop: str | list[str | int] | None = Field( default_factory=lambda: get_default_sampler_value("stop", []), validation_alias=AliasChoices("stop", "stop_sequence"), description="Aliases: stop_sequence", ) - banned_strings: Optional[Union[str, List[str]]] = Field( + banned_strings: str | list[str] | None = Field( default_factory=lambda: get_default_sampler_value("banned_strings", []) ) - banned_tokens: Optional[Union[List[int], str]] = Field( + banned_tokens: list[int] | str | None = Field( default_factory=lambda: get_default_sampler_value("banned_tokens", []), validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), description="Aliases: custom_token_bans", examples=[[128, 330]], ) - allowed_tokens: Optional[Union[List[int], str]] = Field( + allowed_tokens: list[int] | str | None = Field( default_factory=lambda: get_default_sampler_value("allowed_tokens", []), validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"), description="Aliases: allowed_token_ids", examples=[[128, 330]], ) - token_healing: Optional[bool] = Field( + token_healing: bool | None = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) - temperature: Optional[float] = Field( + temperature: float | None = Field( default_factory=lambda: get_default_sampler_value("temperature", 1.0), examples=[1.0], ge=0, le=10, ) - temperature_last: Optional[bool] = Field( + temperature_last: bool | None = Field( default_factory=lambda: get_default_sampler_value("temperature_last", False), ) - smoothing_factor: Optional[float] = Field( + smoothing_factor: float | None = Field( default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0), ge=0, ) - top_k: Optional[int] = Field( + top_k: int | None = Field( default_factory=lambda: get_default_sampler_value("top_k", 0), ge=-1, ) - top_p: Optional[float] = Field( + top_p: float | None = Field( default_factory=lambda: get_default_sampler_value("top_p", 1.0), ge=0, le=1, examples=[1.0], ) - top_a: Optional[float] = Field( + top_a: float | None = Field( default_factory=lambda: 
get_default_sampler_value("top_a", 0.0) ) - min_p: Optional[float] = Field( + min_p: float | None = Field( default_factory=lambda: get_default_sampler_value("min_p", 0.0) ) - tfs: Optional[float] = Field( + tfs: float | None = Field( default_factory=lambda: get_default_sampler_value("tfs", 1.0), examples=[1.0], ) - typical: Optional[float] = Field( + typical: float | None = Field( default_factory=lambda: get_default_sampler_value("typical", 1.0), validation_alias=AliasChoices("typical", "typical_p"), description="Aliases: typical_p", @@ -119,30 +118,30 @@ class BaseSamplerRequest(BaseModel): le=1, ) - skew: Optional[float] = Field( + skew: float | None = Field( default_factory=lambda: get_default_sampler_value("skew", 0.0), examples=[0.0], ) - xtc_probability: Optional[float] = Field( + xtc_probability: float | None = Field( default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0), ) - xtc_threshold: Optional[float] = Field( + xtc_threshold: float | None = Field( default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1) ) - frequency_penalty: Optional[float] = Field( + frequency_penalty: float | None = Field( default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0), ge=0, ) - presence_penalty: Optional[float] = Field( + presence_penalty: float | None = Field( default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0), ge=0, ) - repetition_penalty: Optional[float] = Field( + repetition_penalty: float | None = Field( default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), validation_alias=AliasChoices("repetition_penalty", "rep_pen"), description="Aliases: rep_pen", @@ -150,7 +149,7 @@ class BaseSamplerRequest(BaseModel): gt=0, ) - penalty_range: Optional[int] = Field( + penalty_range: int | None = Field( default_factory=lambda: get_default_sampler_value("penalty_range", -1), validation_alias=AliasChoices( "penalty_range", @@ -163,91 +162,91 @@ class BaseSamplerRequest(BaseModel): ), ) - repetition_decay: Optional[int] = Field( + repetition_decay: int | None = Field( default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) - dry_multiplier: Optional[float] = Field( + dry_multiplier: float | None = Field( default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0) ) - dry_base: Optional[float] = Field( + dry_base: float | None = Field( default_factory=lambda: get_default_sampler_value("dry_base", 0.0) ) - dry_allowed_length: Optional[int] = Field( + dry_allowed_length: int | None = Field( default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) ) - dry_range: Optional[int] = Field( + dry_range: int | None = Field( default_factory=lambda: get_default_sampler_value("dry_range", 0), validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"), description=("Aliases: dry_penalty_last_n"), ) - dry_sequence_breakers: Optional[Union[str, List[str]]] = Field( + dry_sequence_breakers: str | list[str] | None = Field( default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) ) - mirostat_mode: Optional[int] = Field( + mirostat_mode: int | None = Field( default_factory=lambda: get_default_sampler_value("mirostat_mode", 0), alias=AliasChoices("mirostat_mode", "mirostat"), ) - mirostat_tau: Optional[float] = Field( + mirostat_tau: float | None = Field( default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5), examples=[1.5], ) - mirostat_eta: Optional[float] = Field( + mirostat_eta: float | None = Field( default_factory=lambda: 
get_default_sampler_value("mirostat_eta", 0.3), examples=[0.3], ) - add_bos_token: Optional[bool] = Field( + add_bos_token: bool | None = Field( default_factory=lambda: get_default_sampler_value("add_bos_token") ) - ban_eos_token: Optional[bool] = Field( + ban_eos_token: bool | None = Field( default_factory=lambda: get_default_sampler_value("ban_eos_token", False), validation_alias=AliasChoices("ban_eos_token", "ignore_eos"), description="Aliases: ignore_eos", examples=[False], ) - logit_bias: Optional[Dict[int, float]] = Field( + logit_bias: dict[int, float] | None = Field( default_factory=lambda: get_default_sampler_value("logit_bias"), examples=[{"1": 10, "2": 50}], ) - negative_prompt: Optional[str] = Field( + negative_prompt: str | None = Field( default_factory=lambda: get_default_sampler_value("negative_prompt") ) - json_schema: Optional[object] = Field( + json_schema: object | None = Field( default_factory=lambda: get_default_sampler_value("json_schema"), ) - regex_pattern: Optional[str] = Field( + regex_pattern: str | None = Field( default_factory=lambda: get_default_sampler_value("regex_pattern"), ) - grammar_string: Optional[str] = Field( + grammar_string: str | None = Field( default_factory=lambda: get_default_sampler_value("grammar_string"), ) - speculative_ngram: Optional[bool] = Field( + speculative_ngram: bool | None = Field( default_factory=lambda: get_default_sampler_value("speculative_ngram"), ) - cfg_scale: Optional[float] = Field( + cfg_scale: float | None = Field( default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), description="Aliases: guidance_scale", examples=[1.0], ) - max_temp: Optional[float] = Field( + max_temp: float | None = Field( default_factory=lambda: get_default_sampler_value("max_temp", 1.0), validation_alias=AliasChoices("max_temp", "dynatemp_high"), description="Aliases: dynatemp_high", @@ -255,7 +254,7 @@ class BaseSamplerRequest(BaseModel): ge=0, ) - min_temp: Optional[float] = Field( + min_temp: float | None = Field( default_factory=lambda: get_default_sampler_value("min_temp", 1.0), validation_alias=AliasChoices("min_temp", "dynatemp_low"), description="Aliases: dynatemp_low", @@ -263,14 +262,14 @@ class BaseSamplerRequest(BaseModel): ge=0, ) - temp_exponent: Optional[float] = Field( + temp_exponent: float | None = Field( default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0), validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"), examples=[1.0], ge=0, ) - logprobs: Optional[int] = Field( + logprobs: int | None = Field( default_factory=lambda: get_default_sampler_value("logprobs", 0), ge=0, ) @@ -335,7 +334,7 @@ def after_validate(self): class SamplerOverridesContainer(BaseModel): - selected_preset: Optional[str] = None + selected_preset: str | None = None overrides: dict = {} diff --git a/common/tabby_config.py b/common/tabby_config.py index 9c4cc5d6..535cb6b2 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -2,7 +2,6 @@ from inspect import getdoc from os import getenv from textwrap import dedent -from typing import Optional from loguru import logger from pydantic import BaseModel @@ -22,7 +21,7 @@ class TabbyConfig(TabbyConfigModel): model_defaults: dict = {} draft_model_defaults: dict = {} - def load(self, arguments: Optional[dict] = None): + def load(self, arguments: dict | None = None): """Synchronously loads the global application config""" # config is applied in order of items in the list diff --git 
a/common/templating.py b/common/templating.py index cc0cceb1..864ee337 100644 --- a/common/templating.py +++ b/common/templating.py @@ -7,7 +7,6 @@ from dataclasses import dataclass, field from datetime import datetime from importlib.metadata import version as package_version -from typing import List, Optional from jinja2 import Template, TemplateError from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -28,8 +27,8 @@ class TemplateLoadError(Exception): class TemplateMetadata: """Represents the parsed metadata from a template.""" - stop_strings: List[str] = field(default_factory=list) - tool_start: Optional[str] = None + stop_strings: list[str] = field(default_factory=list) + tool_start: str | None = None class PromptTemplate: @@ -44,7 +43,7 @@ class PromptTemplate: enable_async=True, extensions=[loopcontrols], ) - metadata: Optional[TemplateMetadata] = None + metadata: TemplateMetadata | None = None async def extract_metadata(self, template_vars: dict): """ @@ -145,7 +144,7 @@ async def from_file(cls, template_path: pathlib.Path): @classmethod async def from_model_json( - cls, json_path: pathlib.Path, key: str, name: Optional[str] = None + cls, json_path: pathlib.Path, key: str, name: str | None = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): diff --git a/common/transformers_utils.py b/common/transformers_utils.py index a7b0f0c1..92441564 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -3,7 +3,6 @@ import pathlib from loguru import logger from pydantic import BaseModel -from typing import Dict, List, Optional, Set, Union class GenerationConfig(BaseModel): @@ -12,7 +11,7 @@ class GenerationConfig(BaseModel): Will be expanded as needed. """ - eos_token_id: Optional[Union[int, List[int]]] = None + eos_token_id: int | list[int] | None = None @classmethod async def from_directory(cls, model_directory: pathlib.Path): @@ -44,8 +43,8 @@ class HuggingFaceConfig(BaseModel): """ max_position_embeddings: int = 4096 - eos_token_id: Optional[Union[int, List[int]]] = None - quantization_config: Optional[Dict] = None + eos_token_id: int | list[int] | None = None + quantization_config: dict | None = None @classmethod async def from_directory(cls, model_directory: pathlib.Path): @@ -62,7 +61,7 @@ async def from_directory(cls, model_directory: pathlib.Path): def quant_method(self): """Wrapper method to fetch quant type""" - if isinstance(self.quantization_config, Dict): + if isinstance(self.quantization_config, dict): return self.quantization_config.get("quant_method") else: return None @@ -83,7 +82,7 @@ class TokenizerConfig(BaseModel): An abridged version of HuggingFace's tokenizer config. 
""" - add_bos_token: Optional[bool] = True + add_bos_token: bool | None = True @classmethod async def from_directory(cls, model_directory: pathlib.Path): @@ -111,8 +110,8 @@ class HFModel: """ hf_config: HuggingFaceConfig - tokenizer_config: Optional[TokenizerConfig] = None - generation_config: Optional[GenerationConfig] = None + tokenizer_config: TokenizerConfig | None = None + generation_config: GenerationConfig | None = None @classmethod async def from_directory(cls, model_directory: pathlib.Path): @@ -156,7 +155,7 @@ def quant_method(self): def eos_tokens(self): """Combines and returns EOS tokens from various configs""" - eos_ids: Set[int] = set() + eos_ids: set[int] = set() eos_ids.update(self.hf_config.eos_tokens()) diff --git a/common/utils.py b/common/utils.py index b0d7ad24..da841799 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,12 +1,12 @@ """Common utility functions""" -from types import NoneType -from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar +from types import NoneType, UnionType +from typing import Type, get_args, get_origin, TypeVar T = TypeVar("T") -def unwrap(wrapped: Optional[T], default: T = None) -> T: +def unwrap(wrapped: T | None, default: T = None) -> T: """Unwrap function for Optionals.""" if wrapped is None: return default @@ -19,7 +19,7 @@ def coalesce(*args): return next((arg for arg in args if arg is not None), None) -def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]: +def filter_none_values(collection: dict | list) -> dict | list: """Remove None values from a collection.""" if isinstance(collection, dict): @@ -32,7 +32,7 @@ def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]: return collection -def deep_merge_dict(dict1: Dict, dict2: Dict, copy: bool = False) -> Dict: +def deep_merge_dict(dict1: dict, dict2: dict, copy: bool = False) -> dict: """ Merge 2 dictionaries. If copy is true, the original dictionary isn't modified. """ @@ -49,7 +49,7 @@ def deep_merge_dict(dict1: Dict, dict2: Dict, copy: bool = False) -> Dict: return dict1 -def deep_merge_dicts(*dicts: Dict) -> Dict: +def deep_merge_dicts(*dicts: dict) -> dict: """ Merge an arbitrary amount of dictionaries. We wanna do in-place modification for each level, so do not copy. @@ -84,11 +84,13 @@ def is_list_type(type_hint) -> bool: def unwrap_optional_type(type_hint) -> Type: """ - Unwrap Optional[type] annotations. + Unwrap type | None annotations to extract the base type. This is not the same as unwrap. """ - if get_origin(type_hint) is Union: + origin = get_origin(type_hint) + + if origin is UnionType: args = get_args(type_hint) if NoneType in args: for arg in args: diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 98cee556..00546e1b 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -21,7 +21,7 @@ All of these options have descriptive comments above them. You should not need t | disable_auth | Bool (False) | Disables API authentication | | disable_fetch_requests | Bool (False) | Disables fetching external content when responding to requests (ex. fetching images from URLs) | | send_tracebacks | Bool (False) | Send server tracebacks to client.
Note: It's not recommended to enable this if sharing the instance with others. | -| api_servers | List[String] (["OAI"]) | API servers to enable. Possible values `"OAI", "Kobold"` | +| api_servers | list[String] (["OAI"]) | API servers to enable. Possible values `"OAI", "Kobold"` | ### Logging Options @@ -58,14 +58,14 @@ Note: Most of the options here will only apply on initial model load/startup (ep | model_dir | String ("models") | Directory to look for models.
Note: Persisted across subsequent load requests | | inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | | use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.
Note: Persisted across subsequent load requests | -| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | +| dummy_model_names | list[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | | model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | -| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.
Note: Also applies to the `draft` sub-block | +| use_as_default | list[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.
Note: Also applies to the `draft` sub-block | | max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | | tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.
Note: `gpu_split_auto` is ignored when this is enabled. | | gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | -| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.
Represented as an array of MB per GPU used. | -| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | +| autosplit_reserve | list[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.
Represented as an array of MB per GPU used. | +| gpu_split | list[Float] ([]) | Float array of GBs to split a model between GPUs. | | rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)
Note: If the model has YaRN support, this option will not apply. | | rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.
Note: If the model has YaRN support, this option will not apply. | | cache_mode | String ("FP16") | Cache mode for the model.
Options: FP16, Q8, Q6, Q4 | @@ -86,7 +86,7 @@ Note: Sub-block of Model Options. Same rules apply. | draft_rope_scale | Float (1.0) | String: RoPE scale value for the draft model. | | draft_rope_alpha | Float (1.0) | RoPE alpha value for the draft model. Leave blank for auto-calculation. | | draft_cache_mode | String ("FP16") | Cache mode for the draft model.
Options: FP16, Q8, Q6, Q4 | -| draft_gpu_split | List[Float] ([]) | Float array of GBs to split a draft model between GPUs. | +| draft_gpu_split | list[Float] ([]) | Float array of GBs to split a draft model between GPUs. | ### Lora Options @@ -95,7 +95,7 @@ Note: Sub-block of Mode Options. Same rules apply. | Config Option | Type (Default) | Description | |---------------|------------------|--------------------------------------------------------------| | lora_dir | String ("loras") | Directory to look for loras.
Note: Persisted across subsequent load requests | -| loras | List[loras] ([]) | List of lora objects to apply to the model. Each object contains a name and scaling. | +| loras | list[loras] ([]) | List of lora objects to apply to the model. Each object contains a name and scaling. | | name | String (None) | Folder name of a lora to load.
Note: An element of the `loras` key | | scaling | Float (1.0) | "Weight" to apply the lora on the parent model. For example, applying a lora with 0.9 scaling will lower the amount of application on the parent model.
Note: An element of the `loras` key | diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 5432130b..a58bceba 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,6 +1,5 @@ from functools import partial from pydantic import BaseModel, Field, field_validator -from typing import List, Optional from common.sampling import BaseSamplerRequest, get_default_sampler_value from common.utils import unwrap @@ -8,9 +7,9 @@ class GenerateRequest(BaseSamplerRequest): prompt: str - genkey: Optional[str] = None - use_default_badwordsids: Optional[bool] = False - dynatemp_range: Optional[float] = Field( + genkey: str | None = None + use_default_badwordsids: bool | None = False + dynatemp_range: float | None = Field( default_factory=partial(get_default_sampler_value, "dynatemp_range") ) @@ -43,7 +42,7 @@ class GenerateResponseResult(BaseModel): class GenerateResponse(BaseModel): - results: List[GenerateResponseResult] = Field(default_factory=list) + results: list[GenerateResponseResult] = Field(default_factory=list) class StreamGenerateChunk(BaseModel): diff --git a/endpoints/Kobold/types/token.py b/endpoints/Kobold/types/token.py index e6639d94..3f5a9fe6 100644 --- a/endpoints/Kobold/types/token.py +++ b/endpoints/Kobold/types/token.py @@ -1,5 +1,4 @@ from pydantic import BaseModel -from typing import List class TokenCountRequest(BaseModel): @@ -12,4 +11,4 @@ class TokenCountResponse(BaseModel): """Represents a KAI tokenization response.""" value: int - ids: List[int] + ids: list[int] diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 52523149..86118463 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from pydantic import AliasChoices, BaseModel, Field, field_validator from time import time -from typing import Literal, Union, List, Optional, Dict +from typing import Literal from uuid import uuid4 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest @@ -10,11 +12,11 @@ class ChatCompletionLogprob(BaseModel): token: str logprob: float - top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list) + top_logprobs: list[ChatCompletionLogprob] | None = Field(default_factory=list) class ChatCompletionLogprobs(BaseModel): - content: List[ChatCompletionLogprob] = Field(default_factory=list) + content: list[ChatCompletionLogprob] = Field(default_factory=list) class ChatCompletionImageUrl(BaseModel): @@ -23,58 +25,58 @@ class ChatCompletionImageUrl(BaseModel): class ChatCompletionMessagePart(BaseModel): type: Literal["text", "image_url"] = "text" - text: Optional[str] = None - image_url: Optional[ChatCompletionImageUrl] = None + text: str | None = None + image_url: ChatCompletionImageUrl | None = None class ChatCompletionMessage(BaseModel): role: str = "user" - content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None - tool_calls: Optional[List[ToolCall]] = None - tool_call_id: Optional[str] = None + content: str | list[ChatCompletionMessagePart] | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None class ChatCompletionRespChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 - finish_reason: Optional[str] = None + finish_reason: str | None = None # let's us understand why it stopped and if we need to generate a tool_call - stop_str: Optional[str] = None + 
stop_str: str | None = None message: ChatCompletionMessage - logprobs: Optional[ChatCompletionLogprobs] = None + logprobs: ChatCompletionLogprobs | None = None class ChatCompletionStreamChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 - finish_reason: Optional[str] = None - delta: Union[ChatCompletionMessage, dict] = {} - logprobs: Optional[ChatCompletionLogprobs] = None + finish_reason: str | None = None + delta: ChatCompletionMessage | dict = {} + logprobs: ChatCompletionLogprobs | None = None # Inherited from common request class ChatCompletionRequest(CommonCompletionRequest): - messages: List[ChatCompletionMessage] - prompt_template: Optional[str] = None - add_generation_prompt: Optional[bool] = True - template_vars: Optional[dict] = Field( + messages: list[ChatCompletionMessage] + prompt_template: str | None = None + add_generation_prompt: bool | None = True + template_vars: dict | None = Field( default={}, validation_alias=AliasChoices("template_vars", "chat_template_kwargs"), description="Aliases: chat_template_kwargs", ) - response_prefix: Optional[str] = None - model: Optional[str] = None + response_prefix: str | None = None + model: str | None = None # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. - tools: Optional[List[ToolSpec]] = None - functions: Optional[List[Dict]] = None + tools: list[ToolSpec] | None = None + functions: list[dict] | None = None # Chat completions requests do not have a BOS token preference. Backend # respects the tokenization config for the individual model. - add_bos_token: Optional[bool] = None + add_bos_token: bool | None = None @field_validator("add_bos_token", mode="after") def force_bos_token(cls, v): @@ -84,17 +86,17 @@ def force_bos_token(cls, v): class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") - choices: List[ChatCompletionRespChoice] + choices: list[ChatCompletionRespChoice] created: int = Field(default_factory=lambda: int(time())) model: str object: str = "chat.completion" - usage: Optional[UsageStats] = None + usage: UsageStats | None = None class ChatCompletionStreamChunk(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") - choices: List[ChatCompletionStreamChoice] + choices: list[ChatCompletionStreamChoice] created: int = Field(default_factory=lambda: int(time())) model: str object: str = "chat.completion.chunk" - usage: Optional[UsageStats] = None + usage: UsageStats | None = None diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py index 16ef2edd..fb7a4a82 100644 --- a/endpoints/OAI/types/common.py +++ b/endpoints/OAI/types/common.py @@ -1,7 +1,6 @@ """Common types for OAI.""" from pydantic import BaseModel, Field -from typing import Optional, Union from common.sampling import BaseSamplerRequest, get_default_sampler_value @@ -10,13 +9,13 @@ class UsageStats(BaseModel): """Represents usage stats.""" prompt_tokens: int - prompt_time: Optional[float] = None - prompt_tokens_per_sec: Optional[Union[float, str]] = None + prompt_time: float | None = None + prompt_tokens_per_sec: float | str | None = None completion_tokens: int - completion_time: Optional[float] = None - completion_tokens_per_sec: Optional[Union[float, str]] = None + completion_time: float | None = None + completion_tokens_per_sec: float | str | None = None total_tokens: int - total_time: Optional[float] = None + total_time: float | None = None class 
CompletionResponseFormat(BaseModel): @@ -24,7 +23,7 @@ class CompletionResponseFormat(BaseModel): class ChatCompletionStreamOptions(BaseModel): - include_usage: Optional[bool] = False + include_usage: bool | None = False class CommonCompletionRequest(BaseSamplerRequest): @@ -32,29 +31,29 @@ class CommonCompletionRequest(BaseSamplerRequest): # Model information # This parameter is not used, the loaded model is used instead - model: Optional[str] = None + model: str | None = None # Generation info (remainder is in BaseSamplerRequest superclass) - stream: Optional[bool] = False - stream_options: Optional[ChatCompletionStreamOptions] = None - response_format: Optional[CompletionResponseFormat] = Field( + stream: bool | None = False + stream_options: ChatCompletionStreamOptions | None = None + response_format: CompletionResponseFormat | None = Field( default_factory=CompletionResponseFormat ) - n: Optional[int] = Field( + n: int | None = Field( default_factory=lambda: get_default_sampler_value("n", 1), ge=1, ) # Extra OAI request stuff - best_of: Optional[int] = Field( + best_of: int | None = Field( description="Not parsed. Only used for OAI compliance.", default=None ) - echo: Optional[bool] = Field( + echo: bool | None = Field( description="Not parsed. Only used for OAI compliance.", default=False ) - suffix: Optional[str] = Field( + suffix: str | None = Field( description="Not parsed. Only used for OAI compliance.", default=None ) - user: Optional[str] = Field( + user: str | None = Field( description="Not parsed. Only used for OAI compliance.", default=None ) diff --git a/endpoints/OAI/types/completion.py b/endpoints/OAI/types/completion.py index d0a7187e..bb3f8443 100644 --- a/endpoints/OAI/types/completion.py +++ b/endpoints/OAI/types/completion.py @@ -2,7 +2,6 @@ from pydantic import BaseModel, Field from time import time -from typing import Dict, List, Optional, Union from uuid import uuid4 from endpoints.OAI.types.common import CommonCompletionRequest, UsageStats @@ -11,10 +10,10 @@ class CompletionLogProbs(BaseModel): """Represents log probabilities for a completion request.""" - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + text_offset: list[int] = Field(default_factory=list) + token_logprobs: list[float | None] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) class CompletionRespChoice(BaseModel): @@ -22,8 +21,8 @@ class CompletionRespChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 - finish_reason: Optional[str] = None - logprobs: Optional[CompletionLogProbs] = None + finish_reason: str | None = None + logprobs: CompletionLogProbs | None = None text: str @@ -33,15 +32,15 @@ class CompletionRequest(CommonCompletionRequest): # Prompt can also contain token ids, but that's out of scope # for this project. 
- prompt: Union[str, List[str]] + prompt: str | list[str] class CompletionResponse(BaseModel): """Represents a completion response.""" id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}") - choices: List[CompletionRespChoice] + choices: list[CompletionRespChoice] created: int = Field(default_factory=lambda: int(time())) model: str object: str = "text_completion" - usage: Optional[UsageStats] = None + usage: UsageStats | None = None diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py index 41419c4d..abb68fd7 100644 --- a/endpoints/OAI/types/embedding.py +++ b/endpoints/OAI/types/embedding.py @@ -1,23 +1,21 @@ -from typing import List, Optional, Union - from pydantic import BaseModel, Field class UsageInfo(BaseModel): prompt_tokens: int = 0 total_tokens: int = 0 - completion_tokens: Optional[int] = 0 + completion_tokens: int | None = 0 class EmbeddingsRequest(BaseModel): - input: Union[str, List[str]] = Field( + input: str | list[str] = Field( ..., description="List of input texts to generate embeddings for." ) encoding_format: str = Field( "float", description="Encoding format for the embeddings. Can be 'float' or 'base64'.", ) - model: Optional[str] = Field( + model: str | None = Field( None, description="Name of the embedding model to use. " "If not provided, the default model will be used.", @@ -26,7 +24,7 @@ class EmbeddingsRequest(BaseModel): class EmbeddingObject(BaseModel): object: str = Field("embedding", description="Type of the object.") - embedding: Union[List[float], str] = Field( + embedding: list[float] | str = Field( ..., description="Embedding values as a list of floats." ) index: int = Field( @@ -36,6 +34,6 @@ class EmbeddingObject(BaseModel): class EmbeddingsResponse(BaseModel): object: str = Field("list", description="Type of the response object.") - data: List[EmbeddingObject] = Field(..., description="List of embedding objects.") + data: list[EmbeddingObject] = Field(..., description="List of embedding objects.") model: str = Field(..., description="Name of the embedding model used.") usage: UsageInfo = Field(..., description="Information about token usage.") diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index b5b9611f..138fda9d 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Dict, Literal +from typing import Literal from uuid import uuid4 @@ -8,7 +8,7 @@ class Function(BaseModel): name: str description: str - parameters: Dict[str, object] + parameters: dict[str, object] class ToolSpec(BaseModel): diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b559bb2b..157c96ba 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -3,7 +3,6 @@ import asyncio import pathlib from asyncio import CancelledError -from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError from loguru import logger @@ -33,7 +32,7 @@ def _create_response( - request_id: str, generations: List[dict], model_name: Optional[str] + request_id: str, generations: list[dict], model_name: str | None ): """Create a chat completion response from the provided text.""" @@ -111,8 +110,8 @@ def _create_response( def _create_stream_chunk( request_id: str, - generation: Optional[dict] = None, - model_name: Optional[str] = None, + generation: dict | None = None, + model_name: str | None = None, is_usage_chunk: bool = 
False, ): """Create a chat completion stream chunk from the provided text.""" @@ -212,8 +211,8 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars: async def format_messages_with_template( - messages: List[ChatCompletionMessage], - existing_template_vars: Optional[dict] = None, + messages: list[ChatCompletionMessage], + existing_template_vars: dict | None = None, ): """Barebones function to format chat completion messages into a prompt.""" @@ -221,7 +220,7 @@ async def format_messages_with_template( mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None # Convert all messages to a dictionary representation - message_dicts: List[dict] = [] + message_dicts: list[dict] = [] for message in messages: if isinstance(message.content, list): concatenated_content = "" @@ -317,7 +316,7 @@ async def stream_generate_chat_completion( """Generator for the generation process.""" abort_event = asyncio.Event() gen_queue = asyncio.Queue() - gen_tasks: List[asyncio.Task] = [] + gen_tasks: list[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start disconnect_task = asyncio.create_task(request_disconnect_loop(request)) @@ -414,7 +413,7 @@ async def generate_chat_completion( request: Request, model_path: pathlib.Path, ): - gen_tasks: List[asyncio.Task] = [] + gen_tasks: list[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start try: @@ -462,14 +461,14 @@ async def generate_tool_calls( prompt: str, embeddings: MultimodalEmbeddingWrapper, data: ChatCompletionRequest, - generations: List[str], + generations: list[str], request: Request, ): - gen_tasks: List[asyncio.Task] = [] + gen_tasks: list[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start # Tracks which generations asked for a tool call - tool_idx: List[int] = [] + tool_idx: list[int] = [] # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index f66d381d..4de10361 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -9,7 +9,6 @@ from asyncio import CancelledError from fastapi import HTTPException, Request from loguru import logger -from typing import List, Optional, Union from common import model from common.auth import get_key_permission @@ -39,7 +38,7 @@ def _parse_gen_request_id(n: int, request_id: str, task_idx: int): def _create_response( - request_id: str, generations: Union[dict, List[dict]], model_name: str = "" + request_id: str, generations: dict | list[dict], model_name: str = "" ): """Create a completion response from the provided choices.""" @@ -47,7 +46,7 @@ def _create_response( if not isinstance(generations, list): generations = [generations] - choices: List[CompletionRespChoice] = [] + choices: list[CompletionRespChoice] = [] for index, generation in enumerate(generations): logprob_response = None @@ -103,7 +102,7 @@ async def _stream_collector( prompt: str, params: CompletionRequest, abort_event: asyncio.Event, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + mm_embeddings: MultimodalEmbeddingWrapper | None = None, ): """Collects a stream and places results in a common queue""" @@ -200,7 +199,7 @@ async def stream_generate_completion( abort_event = asyncio.Event() gen_queue = asyncio.Queue() - gen_tasks: List[asyncio.Task] = [] + gen_tasks: list[asyncio.Task] = [] disconnect_task = 
asyncio.create_task(request_disconnect_loop(request)) try: @@ -261,7 +260,7 @@ async def generate_completion( ): """Non-streaming generate for completions""" - gen_tasks: List[asyncio.Task] = [] + gen_tasks: list[asyncio.Task] = [] try: logger.info(f"Received completion request {request.state.id}") diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index c1ebdedf..66e466fe 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,6 +1,5 @@ import json from loguru import logger -from typing import List from endpoints.OAI.types.tools import ToolCall @@ -30,7 +29,7 @@ class ToolCallProcessor: @staticmethod - def from_json(tool_calls_str: str) -> List[ToolCall]: + def from_json(tool_calls_str: str) -> list[ToolCall]: """Postprocess tool call JSON to a parseable class""" tool_calls = json.loads(tool_calls_str) @@ -42,15 +41,15 @@ def from_json(tool_calls_str: str) -> List[ToolCall]: return [ToolCall(**tool_call) for tool_call in tool_calls] @staticmethod - def dump(tool_calls: List[ToolCall]) -> List[dict]: + def dump(tool_calls: list[ToolCall]) -> list[dict]: """ Convert ToolCall objects to a list of dictionaries. Args: - tool_calls (List[ToolCall]): List of ToolCall objects to convert + tool_calls (list[ToolCall]): list of ToolCall objects to convert Returns: - List[dict]: List of dictionaries representing the tool calls + list[dict]: list of dictionaries representing the tool calls """ # Don't use list comprehension here @@ -64,12 +63,12 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]: return dumped_tool_calls @staticmethod - def to_json(tool_calls: List[ToolCall]) -> str: + def to_json(tool_calls: list[ToolCall]) -> str: """ Convert ToolCall objects to JSON string representation. Args: - tool_calls (List[ToolCall]): List of ToolCall objects to convert + tool_calls (list[ToolCall]): list of ToolCall objects to convert Returns: str: JSON representation of the tool calls diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 800c7d90..d69d683f 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -1,7 +1,6 @@ import asyncio import pathlib from sys import maxsize -from typing import Optional from common.multimodal import MultimodalEmbeddingWrapper from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi.responses import JSONResponse @@ -403,7 +402,7 @@ async def unload_embedding_model(): async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None + mm_embeddings: MultimodalEmbeddingWrapper | None = None if isinstance(data.text, str): text = data.text diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index cf49501f..c9f8a040 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -1,5 +1,4 @@ from pydantic import BaseModel, Field -from typing import List, Optional def _generate_include_list(): @@ -11,13 +10,13 @@ class DownloadRequest(BaseModel): repo_id: str repo_type: str = "model" - folder_name: Optional[str] = None - revision: Optional[str] = None - token: Optional[str] = None - include: List[str] = Field(default_factory=_generate_include_list) - exclude: List[str] = Field(default_factory=list) - chunk_limit: Optional[int] = None - timeout: Optional[int] = None + folder_name: str | None = None + revision: str | None = None + token: str | None = None + include: list[str] = 
Field(default_factory=_generate_include_list) + exclude: list[str] = Field(default_factory=list) + chunk_limit: int | None = None + timeout: int | None = None class DownloadResponse(BaseModel): diff --git a/endpoints/core/types/health.py b/endpoints/core/types/health.py index ad5fffef..189fa008 100644 --- a/endpoints/core/types/health.py +++ b/endpoints/core/types/health.py @@ -11,5 +11,5 @@ class HealthCheckResponse(BaseModel): "healthy", description="System health status" ) issues: list[UnhealthyEvent] = Field( - default_factory=list, description="List of issues" + default_factory=list, description="list of issues" ) diff --git a/endpoints/core/types/lora.py b/endpoints/core/types/lora.py index 8435a8a4..88e9ba8f 100644 --- a/endpoints/core/types/lora.py +++ b/endpoints/core/types/lora.py @@ -2,7 +2,6 @@ from pydantic import BaseModel, Field from time import time -from typing import Optional, List class LoraCard(BaseModel): @@ -12,32 +11,32 @@ class LoraCard(BaseModel): object: str = "lora" created: int = Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" - scaling: Optional[float] = None + scaling: float | None = None class LoraList(BaseModel): """Represents a list of Lora cards.""" object: str = "list" - data: List[LoraCard] = Field(default_factory=list) + data: list[LoraCard] = Field(default_factory=list) class LoraLoadInfo(BaseModel): """Represents a single Lora load info.""" name: str - scaling: Optional[float] = 1.0 + scaling: float | None = 1.0 class LoraLoadRequest(BaseModel): """Represents a Lora load request.""" - loras: List[LoraLoadInfo] + loras: list[LoraLoadInfo] skip_queue: bool = False class LoraLoadResponse(BaseModel): """Represents a Lora load response.""" - success: List[str] = Field(default_factory=list) - failure: List[str] = Field(default_factory=list) + success: list[str] = Field(default_factory=list) + failure: list[str] = Field(default_factory=list) diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 84229294..ad0d8a85 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -1,8 +1,10 @@ """Contains model card types.""" +from __future__ import annotations + from pydantic import BaseModel, Field, ConfigDict from time import time -from typing import List, Literal, Optional, Union +from typing import Literal from common.config_models import LoggingConfig from common.tabby_config import config @@ -13,19 +15,19 @@ class ModelCardParameters(BaseModel): # Safe to do this since it's guaranteed to fetch a max seq len # from model_container - max_seq_len: Optional[int] = None - cache_size: Optional[int] = None - cache_mode: Optional[str] = "FP16" - rope_scale: Optional[float] = 1.0 - rope_alpha: Optional[float] = 1.0 - max_batch_size: Optional[int] = 1 - chunk_size: Optional[int] = 2048 - prompt_template: Optional[str] = None - prompt_template_content: Optional[str] = None - use_vision: Optional[bool] = False + max_seq_len: int | None = None + cache_size: int | None = None + cache_mode: str | None = "FP16" + rope_scale: float | None = 1.0 + rope_alpha: float | None = 1.0 + max_batch_size: int | None = 1 + chunk_size: int | None = 2048 + prompt_template: str | None = None + prompt_template_content: str | None = None + use_vision: bool | None = False # Draft is another model, so include it in the card params - draft: Optional["ModelCard"] = None + draft: ModelCard | None = None class ModelCard(BaseModel): @@ -35,15 +37,15 @@ class ModelCard(BaseModel): object: str = "model" created: int = 
Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" - logging: Optional[LoggingConfig] = None - parameters: Optional[ModelCardParameters] = None + logging: LoggingConfig | None = None + parameters: ModelCardParameters | None = None class ModelList(BaseModel): """Represents a list of model cards.""" object: str = "list" - data: List[ModelCard] = Field(default_factory=list) + data: list[ModelCard] = Field(default_factory=list) class DraftModelLoadRequest(BaseModel): @@ -53,13 +55,13 @@ class DraftModelLoadRequest(BaseModel): draft_model_name: str # Config arguments - draft_rope_scale: Optional[float] = None - draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + draft_rope_scale: float | None = None + draft_rope_alpha: float | Literal["auto"] | None = Field( description='Automatically calculated if set to "auto"', default=None, examples=[1.0], ) - draft_gpu_split: Optional[List[float]] = Field( + draft_gpu_split: list[float] | None = Field( default_factory=list, examples=[[24.0, 20.0]], ) @@ -75,54 +77,54 @@ class ModelLoadRequest(BaseModel): model_name: str # Config arguments - backend: Optional[str] = Field( + backend: str | None = Field( description="Backend to use", default=None, ) - max_seq_len: Optional[int] = Field( + max_seq_len: int | None = Field( description="Leave this blank to use the model's base sequence length", default=None, examples=[4096], ) - cache_size: Optional[int] = Field( + cache_size: int | None = Field( description="Number in tokens, must be multiple of 256", default=None, examples=[4096], ) - cache_mode: Optional[str] = None - tensor_parallel: Optional[bool] = None - tensor_parallel_backend: Optional[str] = "native" - gpu_split_auto: Optional[bool] = None - autosplit_reserve: Optional[List[float]] = None - gpu_split: Optional[List[float]] = Field( + cache_mode: str | None = None + tensor_parallel: bool | None = None + tensor_parallel_backend: str | None = "native" + gpu_split_auto: bool | None = None + autosplit_reserve: list[float] | None = None + gpu_split: list[float] | None = Field( default_factory=list, examples=[[24.0, 20.0]], ) - rope_scale: Optional[float] = Field( + rope_scale: float | None = Field( description="Automatically pulled from the model's config if not present", default=None, examples=[1.0], ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + rope_alpha: float | Literal["auto"] | None = Field( description='Automatically calculated if set to "auto"', default=None, examples=[1.0], ) - chunk_size: Optional[int] = None - output_chunking: Optional[bool] = True - prompt_template: Optional[str] = None - vision: Optional[bool] = None + chunk_size: int | None = None + output_chunking: bool | None = True + prompt_template: str | None = None + vision: bool | None = None # Non-config arguments - draft_model: Optional[DraftModelLoadRequest] = None - skip_queue: Optional[bool] = False + draft_model: DraftModelLoadRequest | None = None + skip_queue: bool | None = False class EmbeddingModelLoadRequest(BaseModel): embedding_model_name: str # Set default from the config - embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device) + embeddings_device: str | None = Field(config.embeddings.embeddings_device) class ModelLoadResponse(BaseModel): diff --git a/endpoints/core/types/sampler_overrides.py b/endpoints/core/types/sampler_overrides.py index 18627829..2a2efab0 100644 --- a/endpoints/core/types/sampler_overrides.py +++ b/endpoints/core/types/sampler_overrides.py @@ -1,5 +1,4 @@ from pydantic 
import BaseModel, Field -from typing import List, Optional from common.sampling import SamplerOverridesContainer @@ -7,17 +6,17 @@ class SamplerOverrideListResponse(SamplerOverridesContainer): """Sampler override list response""" - presets: Optional[List[str]] + presets: list[str] | None class SamplerOverrideSwitchRequest(BaseModel): """Sampler override switch request""" - preset: Optional[str] = Field( + preset: str | None = Field( default=None, description="Pass a sampler override preset name" ) - overrides: Optional[dict] = Field( + overrides: dict | None = Field( default=None, description=( "Sampling override parent takes in individual keys and overrides. " diff --git a/endpoints/core/types/template.py b/endpoints/core/types/template.py index a82ef48d..a3b98fb2 100644 --- a/endpoints/core/types/template.py +++ b/endpoints/core/types/template.py @@ -1,12 +1,11 @@ from pydantic import BaseModel, Field -from typing import List class TemplateList(BaseModel): """Represents a list of templates.""" object: str = "list" - data: List[str] = Field(default_factory=list) + data: list[str] = Field(default_factory=list) class TemplateSwitchRequest(BaseModel): diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index d43e65e4..8a28d880 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -1,7 +1,6 @@ """Tokenization types""" from pydantic import BaseModel -from typing import List, Union from endpoints.OAI.types.chat_completion import ChatCompletionMessage @@ -25,20 +24,20 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[ChatCompletionMessage]] + text: str | list[ChatCompletionMessage] class TokenEncodeResponse(BaseModel): """Represents a tokenization response.""" - tokens: List[int] + tokens: list[int] length: int class TokenDecodeRequest(CommonTokenRequest): """ " Represents a detokenization request.""" - tokens: List[int] + tokens: list[int] class TokenDecodeResponse(BaseModel): diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 20c9433c..f2b39850 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -1,6 +1,5 @@ import pathlib from asyncio import CancelledError -from typing import Optional from common import model from common.networking import get_generator_error, handle_request_disconnect @@ -13,7 +12,7 @@ ) -def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): +def get_model_list(model_path: pathlib.Path, draft_model_path: str | None = None): """Get the list of models from the provided path.""" # Convert the provided draft model path to a pathlib path for diff --git a/endpoints/server.py b/endpoints/server.py index c17cbb44..4287f747 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -3,7 +3,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger -from typing import Optional from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends @@ -13,7 +12,7 @@ from endpoints.core.router import router as CoreRouter -def setup_app(host: Optional[str] = None, port: Optional[int] = None): +def setup_app(host: str | None = None, port: int | None = None): """Includes the correct routers for startup""" app = FastAPI( diff --git a/main.py b/main.py index 661613eb..3f497901 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,6 @@ import platform import signal from loguru import logger -from 
typing import Optional from common import gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser @@ -104,8 +103,8 @@ async def entrypoint_async(): def entrypoint( - args: Optional[argparse.Namespace] = None, - parser: Optional[argparse.ArgumentParser] = None, + args: argparse.Namespace | None = None, + parser: argparse.ArgumentParser | None = None, ): setup_logger() diff --git a/start.py b/start.py index 95bdd366..4811c9ac 100644 --- a/start.py +++ b/start.py @@ -9,7 +9,6 @@ import sys import traceback from shutil import copyfile, which -from typing import List # Checks for uv installation has_uv = which("uv") is not None @@ -154,7 +153,7 @@ def migrate_start_options(start_options: dict): return migrated -def run_pip(command: List[str]): +def run_pip(command: list[str]): if has_uv: command.insert(0, "uv")
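The annotation style used throughout the new code above — builtin generics (`list[int]`, `dict[str, Any]`) and PEP 604 unions (`str | None`) in place of `typing.List`, `Dict`, `Optional`, and `Union`, plus `from __future__ import annotations` where a model refers to itself — is sketched below on a hypothetical Pydantic model (not a file from this repository):

```python
from __future__ import annotations

from pydantic import BaseModel, Field


class ExampleCard(BaseModel):
    """Hypothetical model annotated in the modernized style."""

    name: str
    # was: tags: Optional[List[str]] = None
    tags: list[str] | None = None
    # was: extra: Dict[str, str] = Field(default_factory=dict)
    extra: dict[str, str] = Field(default_factory=dict)
    # Self-reference needs no quotes because __future__ annotations defers evaluation
    parent: ExampleCard | None = None


# Quick check that the annotations resolve and validate as expected
card = ExampleCard(name="demo", tags=["a", "b"])
print(card.model_dump())
```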