@@ -14,15 +14,26 @@ class HuggingFaceVectoriser(VectoriserBase):
1414 tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the specified model.
1515 model (transformers.PreTrainedModel): The Huggingface model instance.
1616 device (torch.device): The device (CPU or GPU) on which the model is loaded.
17+ tokenizer_kwargs (dict): Additional keyword arguments passed to the tokenizer.
18+ model_kwargs (dict): Additional keyword arguments passed to the model.
1719 """
1820
19- def __init__ (self , model_name , device = None , model_revision = "main" ):
21+ def __init__ (
22+ self ,
23+ model_name ,
24+ device = None ,
25+ model_revision = "main" ,
26+ tokenizer_kwargs : dict | None = None ,
27+ model_kwargs : dict | None = None ,
28+ ):
2029 """Initializes the HuggingFaceVectoriser with the specified model name and device.
2130
2231 Args:
2332 model_name (str): The name of the Huggingface model to use.
2433 device (torch.device, optional): The device to use for computation. Defaults to GPU if available, otherwise CPU.
2534 model_revision (str, optional): The specific model revision to use. Defaults to "main".
35+ tokenizer_kwargs (dict, optional): Additional keyword arguments to pass to the tokenizer. Defaults to None.
36+ model_kwargs (dict, optional): Additional keyword arguments to pass to the model. Defaults to None.
2637
2738 Raises:
2839 ExternalServiceError: If the model or tokenizer cannot be loaded.
@@ -33,9 +44,17 @@ def __init__(self, model_name, device=None, model_revision="main"):
3344 from transformers import AutoModel , AutoTokenizer # type: ignore
3445
3546 self .model_name = model_name
47+
48+ tokenizer_kwargs = dict (tokenizer_kwargs or {})
49+ model_kwargs = dict (model_kwargs or {})
50+
51+ # Default trust_remote_code to False for both tokenizer and model unless the caller explicitly sets it
52+ tokenizer_kwargs .setdefault ("trust_remote_code" , False )
53+ model_kwargs .setdefault ("trust_remote_code" , False )
54+
3655 try :
37- self .tokenizer = AutoTokenizer .from_pretrained (model_name , revision = model_revision ) # nosec: B615
38- self .model = AutoModel .from_pretrained (model_name , revision = model_revision ) # nosec: B615
56+ self .tokenizer = AutoTokenizer .from_pretrained (model_name , revision = model_revision , ** tokenizer_kwargs ) # nosec: B615
57+ self .model = AutoModel .from_pretrained (model_name , revision = model_revision , ** model_kwargs ) # nosec: B615
3958 except Exception as e :
4059 raise ExternalServiceError (
4160 "Failed to load HuggingFace model/tokenizer." ,
0 commit comments