|
8 | 8 | import numpy as np |
9 | 9 | import pytest |
10 | 10 | from pytest import LogCaptureFixture |
11 | | -from transformers import BertTokenizerFast |
12 | 11 | from transformers.modeling_utils import PreTrainedModel |
| 12 | +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
13 | 13 |
|
14 | 14 | from model2vec.distill.distillation import distill, distill_from_model |
15 | 15 | from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings |
|
42 | 42 | def test_distill_from_model( |
43 | 43 | mock_auto_model: MagicMock, |
44 | 44 | mock_model_info: MagicMock, |
45 | | - mock_berttokenizer: BertTokenizerFast, |
| 45 | + mock_berttokenizer: PreTrainedTokenizerFast, |
46 | 46 | mock_transformer: PreTrainedModel, |
47 | 47 | vocabulary: list[str] | None, |
48 | 48 | pca_dims: int | None, |
@@ -83,7 +83,7 @@ def test_distill_from_model( |
83 | 83 | def test_distill_removal_pattern( |
84 | 84 | mock_auto_model: MagicMock, |
85 | 85 | mock_model_info: MagicMock, |
86 | | - mock_berttokenizer: BertTokenizerFast, |
| 86 | + mock_berttokenizer: PreTrainedTokenizerFast, |
87 | 87 | mock_transformer: PreTrainedModel, |
88 | 88 | ) -> None: |
89 | 89 | """Test the removal pattern.""" |
@@ -180,10 +180,12 @@ def test_distill( |
180 | 180 | def test_missing_modelinfo( |
181 | 181 | mock_model_info: MagicMock, |
182 | 182 | mock_transformer: PreTrainedModel, |
183 | | - mock_berttokenizer: BertTokenizerFast, |
| 183 | + mock_berttokenizer: PreTrainedTokenizerFast, |
184 | 184 | ) -> None: |
185 | 185 | """Test that missing model info does not crash.""" |
186 | | - mock_model_info.side_effect = RepositoryNotFoundError("Model not found") |
| 186 | + mock_response = MagicMock() |
| 187 | + mock_response.status_code = 404 |
| 188 | + mock_model_info.side_effect = RepositoryNotFoundError("Model not found", response=mock_response) |
187 | 189 | static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu") |
188 | 190 | assert static_model.language is None |
189 | 191 |
|
@@ -237,7 +239,7 @@ def test__post_process_embeddings( |
237 | 239 | ], |
238 | 240 | ) |
239 | 241 | def test_clean_and_create_vocabulary( |
240 | | - mock_berttokenizer: BertTokenizerFast, |
| 242 | + mock_berttokenizer: PreTrainedTokenizerFast, |
241 | 243 | added_tokens: list[str], |
242 | 244 | expected_output: list[str], |
243 | 245 | expected_warnings: list[str], |
|
0 commit comments