33
44from transformers import AutoTokenizer
55from vllm import LLM , SamplingParams
6+ from vllm .lora .request import LoRARequest
7+ from huggingface_hub import snapshot_download
68
79from bigcodebench .provider .base import DecoderBase
810from bigcodebench .provider .utility import (
1113)
1214
1315class VllmDecoder (DecoderBase ):
14- def __init__ (self , name : str , dataset : str , tp : int , ** kwargs ) -> None :
16+ def __init__ (self , name : str , lora_path : str , dataset : str , tp : int , ** kwargs ) -> None :
1517 super ().__init__ (name , ** kwargs )
1618
1719 kwargs = {
@@ -29,7 +31,17 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
2931 else :
3032 if self .prefill and "```" in self .response_prefix :
3133 self .eos += ["\n ```\n " ]
32- self .llm = LLM (model = name , max_model_len = self .max_new_tokens , ** kwargs )
34+
35+ self .lora_request = None
36+ if lora_path :
37+ local_lora_path = snapshot_download (lora_path )
38+ self .lora_request = LoRARequest (
39+ "lora" ,
40+ 1 ,
41+ local_lora_path ,
42+ )
43+
44+ self .llm = LLM (model = name , max_model_len = self .max_new_tokens , enable_lora = True if lora_path else False , ** kwargs )
3345 self .llm .set_tokenizer (tokenizer = self .tokenizer )
3446
3547 def is_direct_completion (self ) -> bool :
@@ -64,6 +76,7 @@ def codegen(
6476 stop = self .eos ,
6577 skip_special_tokens = self .skip_special_tokens ,
6678 ),
79+ lora_request = self .lora_request ,
6780 use_tqdm = True ,
6881 )
6982
0 commit comments