Skip to content

Commit 6d96733

Browse files
committed
feat: support vllm lora
1 parent f087e3b commit 6d96733

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

bigcodebench/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def run_codegen(
127127
split: str,
128128
subset: str,
129129
root: str = "bcb_results",
130+
lora_path: str = None,
130131
bs: Optional[int] = None,
131132
n_samples: int = 1,
132133
temperature: float = 0.0,
@@ -174,6 +175,7 @@ def run_codegen(
174175
backend=backend,
175176
subset=subset,
176177
split=split,
178+
lora_path=lora_path,
177179
temperature=temperature,
178180
max_new_tokens=max_new_tokens,
179181
reasoning_effort=reasoning_effort,

bigcodebench/provider/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ def make_model(
66
backend: str,
77
subset: str,
88
split: str,
9+
lora_path: str = None,
910
dataset: str = "bigcodebench",
1011
temperature: float = 0.0,
1112
max_new_tokens: int = 1280,
@@ -38,6 +39,7 @@ def make_model(
3839
name=model,
3940
subset=subset,
4041
split=split,
42+
lora_path=lora_path,
4143
temperature=temperature,
4244
max_new_tokens=max_new_tokens,
4345
revision=revision,
@@ -58,6 +60,7 @@ def make_model(
5860
name=model,
5961
subset=subset,
6062
split=split,
63+
lora_path=lora_path,
6164
temperature=temperature,
6265
max_new_tokens=max_new_tokens,
6366
revision=revision,

bigcodebench/provider/vllm.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from transformers import AutoTokenizer
55
from vllm import LLM, SamplingParams
6+
from vllm.lora.request import LoRARequest
7+
from huggingface_hub import snapshot_download
68

79
from bigcodebench.provider.base import DecoderBase
810
from bigcodebench.provider.utility import (
@@ -11,7 +13,7 @@
1113
)
1214

1315
class VllmDecoder(DecoderBase):
14-
def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
16+
def __init__(self, name: str, lora_path: str, dataset: str, tp: int, **kwargs) -> None:
1517
super().__init__(name, **kwargs)
1618

1719
kwargs = {
@@ -29,7 +31,17 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
2931
else:
3032
if self.prefill and "```" in self.response_prefix:
3133
self.eos += ["\n```\n"]
32-
self.llm = LLM(model=name, max_model_len=self.max_new_tokens, **kwargs)
34+
35+
self.lora_request = None
36+
if lora_path:
37+
local_lora_path = snapshot_download(lora_path)
38+
self.lora_request = LoRARequest(
39+
"lora",
40+
1,
41+
local_lora_path,
42+
)
43+
44+
self.llm = LLM(model=name, max_model_len=self.max_new_tokens, enable_lora=True if lora_path else False, **kwargs)
3345
self.llm.set_tokenizer(tokenizer=self.tokenizer)
3446

3547
def is_direct_completion(self) -> bool:
@@ -64,6 +76,7 @@ def codegen(
6476
stop=self.eos,
6577
skip_special_tokens=self.skip_special_tokens,
6678
),
79+
lora_request=self.lora_request,
6780
use_tqdm=True,
6881
)
6982

0 commit comments

Comments (0)