34 changes: 14 additions & 20 deletions backends/base_model_container.py
@@ -2,13 +2,7 @@
 import asyncio
 import pathlib
 from loguru import logger
-from typing import (
-    Any,
-    AsyncIterator,
-    Dict,
-    List,
-    Optional,
-)
+from typing import Any, AsyncIterator
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate
@@ -21,7 +15,7 @@ class BaseModelContainer(abc.ABC):
 
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
-    prompt_template: Optional[PromptTemplate] = None
+    prompt_template: PromptTemplate | None = None
 
     # HF Model instance
     hf_model: HFModel
@@ -34,7 +28,7 @@ class BaseModelContainer(abc.ABC):
     # The bool is a master switch for accepting requests
     # The lock keeps load tasks sequential
     # The condition notifies any waiting tasks
-    active_job_ids: Dict[str, Any] = {}
+    active_job_ids: dict[str, Any] = {}
     loaded: bool = False
     load_lock: asyncio.Lock
     load_condition: asyncio.Condition
@@ -98,7 +92,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
         pass
 
     @abc.abstractmethod
-    def encode_tokens(self, text: str, **kwargs) -> List[int]:
+    def encode_tokens(self, text: str, **kwargs) -> list[int]:
         """
         Encodes a string of text into a list of token IDs.
 
@@ -113,7 +107,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]:
         pass
 
     @abc.abstractmethod
-    def decode_tokens(self, ids: List[int], **kwargs) -> str:
+    def decode_tokens(self, ids: list[int], **kwargs) -> str:
         """
         Decodes a list of token IDs back into a string.
 
@@ -128,7 +122,7 @@ def decode_tokens(self, ids: List[int], **kwargs) -> str:
         pass
 
     @abc.abstractmethod
-    def get_special_tokens(self) -> Dict[str, Any]:
+    def get_special_tokens(self) -> dict[str, Any]:
         """
         Gets special tokens used by the model/tokenizer.
 
@@ -164,7 +158,7 @@ async def wait_for_jobs(self, skip_wait: bool = False):
     # Optional methods
     async def load_loras(
         self, lora_directory: pathlib.Path, **kwargs
-    ) -> Dict[str, List[str]]:
+    ) -> dict[str, list[str]]:
         """
         Loads LoRA adapters. Base implementation does nothing or raises error.
 
@@ -184,7 +178,7 @@ async def load_loras(
             ],
         }
 
-    def get_loras(self) -> List[Any]:
+    def get_loras(self) -> list[Any]:
         """
         Gets the currently loaded LoRA adapters. Base implementation returns empty list.
 
@@ -200,9 +194,9 @@ async def generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> Dict[str, Any]:
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+    ) -> dict[str, Any]:
         """
         Generates a complete response for a given prompt and parameters.
 
@@ -225,9 +219,9 @@ async def stream_generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> AsyncIterator[Dict[str, Any]]:
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+    ) -> AsyncIterator[dict[str, Any]]:
         """
         Generates a response iteratively (streaming) for a given prompt.
 
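Across all four files the change is the same typing cleanup: drop typing.Dict, typing.List, and typing.Optional in favor of the built-in generics from PEP 585 and the X | None union syntax from PEP 604. As a point of reference, here is a minimal sketch of the two styles; the function and names below are illustrative only, not taken from this PR, and the new syntax assumes Python 3.10+ at runtime.

from typing import Any

# Old style, as removed by this diff:
#   from typing import Dict, List, Optional
#   def first_match(ids: List[int], table: Dict[str, Any]) -> Optional[str]: ...

# New style, as adopted by this diff (hypothetical example function):
def first_match(ids: list[int], table: dict[str, Any]) -> str | None:
    """Return the first key in `table` whose value appears in `ids`, else None."""
    for key, value in table.items():
        if value in ids:
            return key
    return None


print(first_match([1, 2, 3], {"a": 5, "b": 2}))  # -> "b"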
7 changes: 3 additions & 4 deletions backends/exllamav2/grammar.py
@@ -1,7 +1,6 @@
 import traceback
-import typing
 from functools import lru_cache
-from typing import List
+from typing import Any
 
 import torch
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
@@ -16,7 +15,7 @@
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
 
-    filters: List[ExLlamaV2Filter]
+    filters: list[ExLlamaV2Filter]
 
     def __init__(self):
         self.filters = []
@@ -123,7 +122,7 @@ def __init__(self, nonterminal: str, kbnf_string: str):
         self.kbnf_string = kbnf_string
 
     # Return the entire input string as the extracted string
-    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+    def extract(self, input_str: str) -> tuple[str, Any] | None:
         return "", input_str
 
     @property
54 changes: 27 additions & 27 deletions backends/exllamav2/model.py
@@ -25,7 +25,6 @@
 )
 from itertools import zip_longest
 from loguru import logger
-from typing import Dict, List, Optional
 
 from backends.base_model_container import BaseModelContainer
 from backends.exllamav2.grammar import (
@@ -58,45 +57,45 @@ class ExllamaV2Container(BaseModelContainer):
     # Model directories
     model_dir: pathlib.Path = pathlib.Path("models")
     draft_model_dir: pathlib.Path = pathlib.Path("models")
-    prompt_template: Optional[PromptTemplate] = None
+    prompt_template: PromptTemplate | None = None
 
     # HF model instance
    hf_model: HFModel
 
     # Exl2 vars
-    config: Optional[ExLlamaV2Config] = None
-    model: Optional[ExLlamaV2] = None
-    cache: Optional[ExLlamaV2Cache] = None
-    tokenizer: Optional[ExLlamaV2Tokenizer] = None
-    generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
-    prompt_template: Optional[PromptTemplate] = None
+    config: ExLlamaV2Config | None = None
+    model: ExLlamaV2 | None = None
+    cache: ExLlamaV2Cache | None = None
+    tokenizer: ExLlamaV2Tokenizer | None = None
+    generator: ExLlamaV2DynamicGeneratorAsync | None = None
+    prompt_template: PromptTemplate | None = None
     paged: bool = True
 
     # Draft model vars
     use_draft_model: bool = False
-    draft_config: Optional[ExLlamaV2Config] = None
-    draft_model: Optional[ExLlamaV2] = None
-    draft_cache: Optional[ExLlamaV2Cache] = None
+    draft_config: ExLlamaV2Config | None = None
+    draft_model: ExLlamaV2 | None = None
+    draft_cache: ExLlamaV2Cache | None = None
 
     # Internal config vars
     cache_size: int = None
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
-    max_batch_size: Optional[int] = None
+    max_batch_size: int | None = None
 
     # GPU split vars
-    gpu_split: List[float] = []
-    draft_gpu_split: List[float] = []
+    gpu_split: list[float] = []
+    draft_gpu_split: list[float] = []
     gpu_split_auto: bool = True
-    autosplit_reserve: List[float] = [96 * 1024**2]
+    autosplit_reserve: list[float] = [96 * 1024**2]
     use_tp: bool = False
 
     # Vision vars
     use_vision: bool = False
-    vision_model: Optional[ExLlamaV2VisionTower] = None
+    vision_model: ExLlamaV2VisionTower | None = None
 
     # Load synchronization
-    active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
+    active_job_ids: dict[str, ExLlamaV2DynamicJobAsync | None] = {}
     loaded: bool = False
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
@@ -272,6 +271,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs
         self.config.max_seq_len = unwrap(
             user_max_seq_len, min(hf_model.hf_config.max_position_embeddings, 4096)
         )
+        self.cache_size = self.config.max_seq_len
 
         # Set the rope scale
         self.config.scale_pos_emb = unwrap(
@@ -750,9 +750,9 @@ async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         # Wait for existing generation jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
-        loras_to_load: List[ExLlamaV2Lora] = []
-        success: List[str] = []
-        failure: List[str] = []
+        loras_to_load: list[ExLlamaV2Lora] = []
+        success: list[str] = []
+        failure: list[str] = []
 
         for lora in loras:
             lora_name = lora.get("name")
@@ -869,7 +869,7 @@ def encode_tokens(self, text: str, **kwargs):
             .tolist()
         )
 
-    def decode_tokens(self, ids: List[int], **kwargs):
+    def decode_tokens(self, ids: list[int], **kwargs):
         """Wrapper to decode tokens from a list of IDs"""
 
         ids = torch.tensor([ids])
@@ -908,8 +908,8 @@ async def generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
     ):
         """Generate a response to a prompt."""
         generations = []
@@ -969,8 +969,8 @@ async def stream_generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
     ):
         try:
             # Wait for load lock to be freed before processing
@@ -1242,8 +1242,8 @@ async def generate_gen(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
     ):
         """
         Create generator function for prompt completion.
53 changes: 25 additions & 28 deletions backends/exllamav3/model.py
@@ -7,9 +7,6 @@
 from typing import (
     Any,
     AsyncIterator,
-    Dict,
-    List,
-    Optional,
 )
 
 from exllamav3 import (
@@ -49,7 +46,7 @@ class ExllamaV3Container(BaseModelContainer):
 
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
-    prompt_template: Optional[PromptTemplate] = None
+    prompt_template: PromptTemplate | None = None
 
     # HF Model instance
     hf_model: HFModel
@@ -58,35 +55,35 @@ class ExllamaV3Container(BaseModelContainer):
     # The bool is a master switch for accepting requests
     # The lock keeps load tasks sequential
     # The condition notifies any waiting tasks
-    active_job_ids: Dict[str, Any] = {}
+    active_job_ids: dict[str, Any] = {}
     loaded: bool = False
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
     # Exl3 vars
-    model: Optional[Model] = None
-    cache: Optional[Cache] = None
-    draft_model: Optional[Model] = None
-    draft_cache: Optional[Cache] = None
-    tokenizer: Optional[Tokenizer] = None
-    config: Optional[Config] = None
-    draft_config: Optional[Config] = None
-    generator: Optional[AsyncGenerator] = None
-    vision_model: Optional[Model] = None
+    model: Model | None = None
+    cache: Cache | None = None
+    draft_model: Model | None = None
+    draft_cache: Cache | None = None
+    tokenizer: Tokenizer | None = None
+    config: Config | None = None
+    draft_config: Config | None = None
+    generator: AsyncGenerator | None = None
+    vision_model: Model | None = None
 
     # Class-specific vars
-    gpu_split: Optional[List[float]] = None
+    gpu_split: list[float] | None = None
     gpu_split_auto: bool = True
-    autosplit_reserve: Optional[List[float]] = [96 / 1024]
+    autosplit_reserve: list[float] | None = [96 / 1024]
     use_tp: bool = False
     tp_backend: str = "native"
     max_seq_len: int = 4096
     cache_size: int = 4096
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
     chunk_size: int = 2048
-    max_rq_tokens: Optional[int] = 2048
-    max_batch_size: Optional[int] = None
+    max_rq_tokens: int | None = 2048
+    max_batch_size: int | None = None
 
     # Required methods
     @classmethod
@@ -579,7 +576,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
         async with self.load_condition:
             self.load_condition.notify_all()
 
-    def encode_tokens(self, text: str, **kwargs) -> List[int]:
+    def encode_tokens(self, text: str, **kwargs) -> list[int]:
         """
         Encodes a string of text into a list of token IDs.
 
@@ -607,7 +604,7 @@ def encode_tokens(self, text: str, **kwargs) -> List[int]:
             .tolist()
         )
 
-    def decode_tokens(self, ids: List[int], **kwargs) -> str:
+    def decode_tokens(self, ids: list[int], **kwargs) -> str:
         """
         Decodes a list of token IDs back into a string.
 
@@ -666,9 +663,9 @@ async def generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> Dict[str, Any]:
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+    ) -> dict[str, Any]:
         """
         Generates a complete response for a given prompt and parameters.
 
@@ -738,9 +735,9 @@ async def stream_generate(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> AsyncIterator[Dict[str, Any]]:
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
+    ) -> AsyncIterator[dict[str, Any]]:
         """
         Generates a response iteratively (streaming) for a given prompt.
 
@@ -859,8 +856,8 @@ async def generate_gen(
         request_id: str,
         prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+        abort_event: asyncio.Event | None = None,
+        mm_embeddings: MultimodalEmbeddingWrapper | None = None,
     ):
         """
         Create generator function for prompt completion.