Skip to content

Commit df93637

Browse files
authored
chore: Update CI to run on external PRs, fix test import (#303)
* Updated CI * Fixed import
1 parent d83d247 commit df93637

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
name: Run tests and upload coverage
22

33
on:
4-
push
4+
push:
5+
branches:
6+
- main
7+
pull_request:
58

69
jobs:
710
test:

tests/test_distillation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import numpy as np
99
import pytest
1010
from pytest import LogCaptureFixture
11-
from transformers import BertTokenizerFast
1211
from transformers.modeling_utils import PreTrainedModel
12+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
1313

1414
from model2vec.distill.distillation import distill, distill_from_model
1515
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
@@ -42,7 +42,7 @@
4242
def test_distill_from_model(
4343
mock_auto_model: MagicMock,
4444
mock_model_info: MagicMock,
45-
mock_berttokenizer: BertTokenizerFast,
45+
mock_berttokenizer: PreTrainedTokenizerFast,
4646
mock_transformer: PreTrainedModel,
4747
vocabulary: list[str] | None,
4848
pca_dims: int | None,
@@ -83,7 +83,7 @@ def test_distill_from_model(
8383
def test_distill_removal_pattern(
8484
mock_auto_model: MagicMock,
8585
mock_model_info: MagicMock,
86-
mock_berttokenizer: BertTokenizerFast,
86+
mock_berttokenizer: PreTrainedTokenizerFast,
8787
mock_transformer: PreTrainedModel,
8888
) -> None:
8989
"""Test the removal pattern."""
@@ -180,10 +180,12 @@ def test_distill(
180180
def test_missing_modelinfo(
181181
mock_model_info: MagicMock,
182182
mock_transformer: PreTrainedModel,
183-
mock_berttokenizer: BertTokenizerFast,
183+
mock_berttokenizer: PreTrainedTokenizerFast,
184184
) -> None:
185185
"""Test that missing model info does not crash."""
186-
mock_model_info.side_effect = RepositoryNotFoundError("Model not found")
186+
mock_response = MagicMock()
187+
mock_response.status_code = 404
188+
mock_model_info.side_effect = RepositoryNotFoundError("Model not found", response=mock_response)
187189
static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu")
188190
assert static_model.language is None
189191

@@ -237,7 +239,7 @@ def test__post_process_embeddings(
237239
],
238240
)
239241
def test_clean_and_create_vocabulary(
240-
mock_berttokenizer: BertTokenizerFast,
242+
mock_berttokenizer: PreTrainedTokenizerFast,
241243
added_tokens: list[str],
242244
expected_output: list[str],
243245
expected_warnings: list[str],

0 commit comments

Comments
 (0)