Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ sh = uv run --no-sync --frozen
.PHONY: install
install:
rm -rf uv.lock
uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers
uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers --extra openai

.PHONY: test
test:
Expand Down
9 changes: 7 additions & 2 deletions src/autointent/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Utils."""

from __future__ import annotations

import importlib
from typing import Any, TypeVar
from typing import TYPE_CHECKING, TypeVar

import torch

if TYPE_CHECKING:
from types import ModuleType

T = TypeVar("T")


Expand All @@ -28,7 +33,7 @@ def detect_device() -> str:
return "cpu"


def require(dependency: str, extra: str | None = None) -> Any: # noqa: ANN401
def require(dependency: str, extra: str | None = None) -> ModuleType:
"""Try to import dependency, raise informative ImportError if missing.

Args:
Expand Down
5 changes: 3 additions & 2 deletions src/autointent/_wrappers/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, config: OpenaiEmbeddingConfig) -> None:

def _get_client(self) -> openai.OpenAI:
"""Get or create OpenAI client instance."""
import openai
openai = require("openai", "openai")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Это в init уже проверяется. Мне кажется, что можно просто импортировать.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

без этих штук ruff ругается, что не установлен openai, + так консистентнее, потому что в большинстве случаев импорт идет именно через require

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Я думаю, ты про mypy, и такой способ не позволяет валидировать типизацию.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. окей, это не ruff, это какой-то другой линтер
image
  2. хорошо, это, кажется, действительно так; тогда я еще подумаю, потому что тогда мы должны отказаться от использования require во всех местах, где только можно

Copy link
Copy Markdown
Collaborator Author

@voorhs voorhs Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тогда, мне кажется, стоит отказаться от использования require для импорта — его стоит использовать только для того, чтобы выводить информационное сообщение об отсутствии необходимой зависимости

импорты опциональных зависимостей стоит делать только внутри TYPE_CHECKING секции и внутри функций и методов классов. первое обеспечит тайпчекинг, второе позволит лениво подгружать опциональные зависимости

а на эти проблемы reportMissingImports придется забить

что думаешь?

Copy link
Copy Markdown
Member

@Samoed Samoed Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Да, согласен. В mteb так и делаем


if self._client is None:
self._client = openai.OpenAI(
Expand All @@ -71,7 +71,7 @@ def _get_client(self) -> openai.OpenAI:

def _get_async_client(self) -> openai.AsyncOpenAI:
"""Get or create async OpenAI client instance."""
import openai
openai = require("openai", "openai")

if self._async_client is None:
self._async_client = openai.AsyncOpenAI(
Expand Down Expand Up @@ -308,3 +308,4 @@ def load(cls, path: Path) -> OpenaiEmbeddingBackend:

# Create instance
return cls(config)

8 changes: 4 additions & 4 deletions src/autointent/generation/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ async def _get_structured_output_openai_async(
Returns:
Tuple of (parsed_result, error_message, raw_response).
"""
from openai import LengthFinishReasonError
openai = require("openai", "openai")

res: T | None = None
msg: str | None = None
Expand All @@ -235,7 +235,7 @@ async def _get_structured_output_openai_async(
)
raw = response.choices[0].message.content
res = response.choices[0].message.parsed
except (ValidationError, ValueError, LengthFinishReasonError) as e:
except (ValidationError, ValueError, openai.LengthFinishReasonError) as e:
msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}"
logger.warning(msg)
else:
Expand Down Expand Up @@ -307,7 +307,7 @@ def _get_structured_output_openai_sync(
Returns:
Tuple of (parsed_result, error_message, raw_response).
"""
from openai import LengthFinishReasonError
openai = require("openai", "openai")

res: T | None = None
msg: str | None = None
Expand All @@ -322,7 +322,7 @@ def _get_structured_output_openai_sync(
)
raw = response.choices[0].message.content
res = response.choices[0].message.parsed
except (ValidationError, ValueError, LengthFinishReasonError) as e:
except (ValidationError, ValueError, openai.LengthFinishReasonError) as e:
msg = f"Failed to obtain structured output for model {self.model_name} and messages {messages}: {e!s}"
logger.warning(msg)
else:
Expand Down
24 changes: 12 additions & 12 deletions src/autointent/modules/scoring/_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def get_implicit_initialization_params(self) -> dict[str, Any]:
}

def _initialize_model(self) -> Any: # noqa: ANN401
from transformers import AutoModelForSequenceClassification
transformers = require("transformers", "transformers")

label2id = {i: i for i in range(self._n_classes)}
id2label = {i: i for i in range(self._n_classes)}

return AutoModelForSequenceClassification.from_pretrained(
return transformers.AutoModelForSequenceClassification.from_pretrained(
self.classification_model_config.model_name,
trust_remote_code=self.classification_model_config.trust_remote_code,
num_labels=self._n_classes,
Expand All @@ -147,11 +147,11 @@ def fit(
utterances: list[str],
labels: ListOfLabels,
) -> None:
from transformers import AutoTokenizer
transformers = require("transformers", "transformers")

self._validate_task(labels)

self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name) # type: ignore[no-untyped-call]
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
self._model = self._initialize_model()
tokenized_dataset = self._get_tokenized_dataset(utterances, labels)
self._train(tokenized_dataset)
Expand All @@ -164,10 +164,10 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
Args:
tokenized_dataset: output from :py:meth:`BertScorer._get_tokenized_dataset`
"""
from transformers import DataCollatorWithPadding, PrinterCallback, ProgressCallback, Trainer, TrainingArguments
transformers = require("transformers", "transformers")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
training_args = transformers.TrainingArguments(
output_dir=tmp_dir,
num_train_epochs=self.num_train_epochs,
per_device_train_batch_size=self.batch_size,
Expand All @@ -186,29 +186,29 @@ def _train(self, tokenized_dataset: DatasetDict) -> None:
load_best_model_at_end=self.early_stopping_config.metric is not None,
)

trainer = Trainer(
trainer = transformers.Trainer(
model=self._model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
processing_class=self._tokenizer,
data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
data_collator=transformers.DataCollatorWithPadding(tokenizer=self._tokenizer),
compute_metrics=self._get_compute_metrics(),
callbacks=self._get_trainer_callbacks(),
)
if not self.print_progress:
trainer.remove_callback(PrinterCallback)
trainer.remove_callback(ProgressCallback)
trainer.remove_callback(transformers.PrinterCallback)
trainer.remove_callback(transformers.ProgressCallback)

trainer.train()

def _get_trainer_callbacks(self) -> list[TrainerCallback]:
from transformers import EarlyStoppingCallback
transformers = require("transformers", "transformers")

res: list[TrainerCallback] = []
if self.early_stopping_config.metric is not None:
res.append(
EarlyStoppingCallback(
transformers.EarlyStoppingCallback(
early_stopping_patience=self.early_stopping_config.patience,
early_stopping_threshold=self.early_stopping_config.threshold,
)
Expand Down
10 changes: 8 additions & 2 deletions tests/modules/test_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ def check_attributes(self):

class TestVectorIndex:
def init_attributes(self):
pytest.importorskip("sentence_transformers", reason="Sentence Transformers library is required for these tests")

self.vector_index = VectorIndex(
embedder_config=initialize_embedder_config("bert-base-uncased"),
config=FaissConfig(),
Expand Down Expand Up @@ -178,6 +176,14 @@ def _transformers_is_installed() -> bool:
id="transformer",
),
TestVectorIndex,
pytest.param(
TestVectorIndex,
marks=pytest.mark.skipif(
not _st_is_installed(),
reason="need sentence-transformers dependency",
),
id="vector_index",
),
pytest.param(
TestEmbedder,
marks=pytest.mark.skipif(
Expand Down