docs/optimizer_config.schema.json (1 addition, 1 deletion)
@@ -720,7 +720,7 @@
           "type": "null"
         }
       ],
-      "default": null,
+      "default": 200000,
       "description": "When set, cap each embeddings API call by the summed tiktoken length of inputs (using the encoding for `model_name`). Requests are also limited to at most `batch_size` strings. Use values around 200000 to avoid OpenAI `max_tokens_per_request` errors on long texts. Requires `tiktoken` (installed with `autointent[openai]`).",
       "title": "Max Tokens In Batch"
     },
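
For illustration, a config fragment that sets this cap explicitly might look like the following (a sketch: only the embedding-model object is shown, its location inside the optimizer config is elided, and 150000 is an arbitrary example value):

    {
        "model_name": "text-embedding-3-small",
        "batch_size": 100,
        "max_tokens_in_batch": 150000
    }

Setting the field to null stays schema-valid (note the "null" branch above) and restores the old size-only batching.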
src/autointent/_wrappers/embedder/openai.py (63 additions, 6 deletions)
@@ -24,13 +24,73 @@

 import numpy.typing as npt
 import openai
+from tiktoken import Encoding
 from typing_extensions import NotRequired
 
 from autointent.configs import TaskTypeEnum
 
 
 logger = logging.getLogger(__name__)

+# Third-party embedding model ids (e.g. OpenRouter) are unknown to tiktoken; use a conservative encoding
+# only for counting tokens when splitting batches.
+_FALLBACK_TIKTOKEN_ENCODING = "cl100k_base"
+_ERROR_DETAIL_LIMIT = 2000
+
+
+def _compact_error_detail(value: object) -> str:
+    """Render provider error details without letting huge bodies flood logs/results."""
+    if isinstance(value, (dict, list, tuple)):
+        try:
+            text = json.dumps(value, ensure_ascii=False)
+        except TypeError:
+            text = repr(value)
+    else:
+        text = str(value)
+
+    if len(text) <= _ERROR_DETAIL_LIMIT:
+        return text
+    return f"{text[:_ERROR_DETAIL_LIMIT]}... <truncated>"
+
+
+def _openai_api_error_message(exc: BaseException, *, batch_size: int) -> str:
+    """Build a RuntimeError message that preserves useful OpenAI/provider details."""
+    details = [f"{exc.__class__.__name__}: {_compact_error_detail(exc)}"]
+
+    for attr in ("status_code", "code", "type", "body"):
+        value = getattr(exc, attr, None)
+        if value is not None:
+            details.append(f"{attr}={_compact_error_detail(value)}")
+
+    response = getattr(exc, "response", None)
+    if response is not None:
+        status_code = getattr(response, "status_code", None)
+        if status_code is not None:
+            details.append(f"response_status_code={status_code}")
+
+        response_text = getattr(response, "text", None)
+        if response_text:
+            details.append(f"response_text={_compact_error_detail(response_text)}")
+
+    return f"Error calling OpenAI API (batch_size={batch_size}): {'; '.join(details)}"
+
+
+def _tiktoken_encoding_for_embedding_model(model_name: str) -> Encoding:
+    """Resolve tiktoken encoding for batch sizing; fallback for unknown provider model ids."""
+    require("tiktoken", "openai")
+    import tiktoken
+
+    try:
+        return tiktoken.encoding_for_model(model_name)
+    except KeyError:
+        logger.warning(
+            "tiktoken has no mapping for embedding model %r; using %r for token counting "
+            "(per-request batch limits are approximate).",
+            model_name,
+            _FALLBACK_TIKTOKEN_ENCODING,
+        )
+        return tiktoken.get_encoding(_FALLBACK_TIKTOKEN_ENCODING)


 class EmbeddingsCreateKwargs(TypedDict):
     input: list[str]

@@ -208,7 +268,7 @@ def _process_embeddings_sync(self, utterances: list[str]) -> npt.NDArray[np.floa
                 all_embeddings.extend(batch_embeddings)
 
             except Exception as e:
-                msg = "Error calling OpenAI API"
+                msg = _openai_api_error_message(e, batch_size=len(batch))
                 logger.exception(msg)
                 raise RuntimeError(msg) from e

@@ -253,7 +313,7 @@ def _process_batch_async(self, batch: list[str]) -> list[list[float]]:
             response = await client.embeddings.create(**kwargs)
             return [data.embedding for data in response.data]
         except Exception as e:
-            msg = f"Error calling OpenAI API for batch: {e}"
+            msg = _openai_api_error_message(e, batch_size=len(batch))
            logger.exception(msg)
            raise RuntimeError(msg) from e

@@ -325,10 +385,7 @@ def _batch_strings_by_token_budget(
     if max_tokens_per_batch is None:
         return [texts[i : i + max_strings_per_batch] for i in range(0, len(texts), max_strings_per_batch)]
 
-    require("tiktoken", "openai")
-    import tiktoken
-
-    encoding = tiktoken.encoding_for_model(model_name)
+    encoding = _tiktoken_encoding_for_embedding_model(model_name)

Member commented:
Maybe add requests to `/v1/tokenize`? Most OpenAI-compatible APIs support this.
Collaborator (author) replied:
OpenAI and OpenRouter don't have one; I found that vLLM and SGLang do.

It would be a nice feature; one option is to add `token_counter: Literal["tiktoken", "endpoint", "transformers"]` to `OpenaiEmbeddingConfig`.

There's no critical need for it yet; rough token estimates like these are fine for the current tasks.
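
For concreteness, a hypothetical sketch of that proposed option (not part of this PR; the field name and mode semantics are illustrative only):

    from typing import Literal

    class OpenaiEmbeddingConfig(BaseEmbedderConfig):
        # "tiktoken": count locally with tiktoken (current behavior);
        # "endpoint": ask the serving engine via /v1/tokenize where supported (e.g. vLLM, SGLang);
        # "transformers": count with the model's Hugging Face tokenizer.
        token_counter: Literal["tiktoken", "endpoint", "transformers"] = "tiktoken"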

     batches: list[list[str]] = []
     current_batch: list[str] = []
     current_tokens = 0
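
As a quick check of the new fallback path, a minimal sketch (assuming `tiktoken` is installed; the unknown model id is the one used in the test below):

    from autointent._wrappers.embedder.openai import _tiktoken_encoding_for_embedding_model

    # "qwen/qwen3-embedding-8b" is not in tiktoken's model map, so the helper logs
    # a warning and falls back to the conservative cl100k_base encoding.
    encoding = _tiktoken_encoding_for_embedding_model("qwen/qwen3-embedding-8b")
    print(encoding.name)  # cl100k_base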
src/autointent/configs/_embedder.py (1 addition, 1 deletion)
@@ -90,7 +90,7 @@ class OpenaiEmbeddingConfig(BaseEmbedderConfig):
     model_name: str = Field("text-embedding-3-small", description="Name of the OpenAI embedding model.")
     batch_size: int = Field(100, description="Batch size for API requests.")
     max_tokens_in_batch: PositiveInt | None = Field(
-        None,
+        200_000,
         description=(
             "When set, cap each embeddings API call by the summed tiktoken length of inputs "
             "(using the encoding for `model_name`). Requests are also limited to at most "
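
A minimal usage sketch of the updated config (assuming no other required fields; passing `max_tokens_in_batch=None` opts back out of the cap):

    from autointent.configs import OpenaiEmbeddingConfig

    config = OpenaiEmbeddingConfig(model_name="text-embedding-3-small")
    assert config.max_tokens_in_batch == 200_000  # new default caps tokens per request

    unbounded = OpenaiEmbeddingConfig(max_tokens_in_batch=None)  # pre-PR behavior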
tests/embedder/test_openai_token_batching.py (29 additions, 0 deletions)
@@ -8,6 +8,7 @@
 from autointent._wrappers.embedder.openai import (  # noqa: E402
     OpenaiEmbeddingBackend,
     _batch_strings_by_token_budget,
+    _openai_api_error_message,
 )
 from autointent.configs import OpenaiEmbeddingConfig  # noqa: E402

@@ -23,6 +24,17 @@ def test_batch_strings_none_max_tokens_uses_batch_size_only() -> None:
     assert batches == [["a", "b", "c"], ["d", "e", "f"]]
 
 
+def test_batch_strings_unknown_model_uses_fallback_encoding() -> None:
+    """Third-party embedding ids (e.g. OpenRouter) are not in tiktoken's model map."""
+    batches = _batch_strings_by_token_budget(
+        ["hello", "world"],
+        model_name="qwen/qwen3-embedding-8b",
+        max_strings_per_batch=10,
+        max_tokens_per_batch=100,
+    )
+    assert batches == [["hello", "world"]]


 def test_batch_strings_respects_token_budget() -> None:
     encoding = tiktoken.encoding_for_model("text-embedding-3-small")
     batches = _batch_strings_by_token_budget(

@@ -61,3 +73,20 @@ def test_embedding_request_batches_on_backend() -> None:
     backend = OpenaiEmbeddingBackend(config)
     batches = backend._embedding_request_batches(["hello"] * 12)
     assert sum(len(b) for b in batches) == 12
+
+
+def test_openai_api_error_message_preserves_provider_details() -> None:
+    class ProviderError(Exception):
+        def __init__(self, message: str) -> None:
+            super().__init__(message)
+            self.status_code = 400
+            self.code = "context_length_exceeded"
+            self.body = {"error": {"message": "input is too long"}}
+
+    message = _openai_api_error_message(ProviderError("No embedding data received"), batch_size=3)
+
+    assert "Error calling OpenAI API (batch_size=3)" in message
+    assert "ProviderError: No embedding data received" in message
+    assert "status_code=400" in message
+    assert "code=context_length_exceeded" in message
+    assert "input is too long" in message