Skip to content

Commit 700ecae

Browse files
committed
feat: add legacy option
1 parent 8a4b402 commit 700ecae

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

bigcodebench/generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def main():
119119
parser.add_argument("--base_url", default=None, type=str)
120120
parser.add_argument("--tp", default=1, type=int)
121121
parser.add_argument("--trust_remote_code", action="store_true")
122+
parser.add_argument("--tokenizer_legacy", action="store_true")
122123
parser.add_argument("--tokenizer_name", default=None, type=str)
123124

124125
args = parser.parse_args()
@@ -144,7 +145,8 @@ def main():
144145
base_url=args.base_url,
145146
tp=args.tp,
146147
trust_remote_code=args.trust_remote_code,
147-
tokenizer_name=args.tokenizer_name
148+
tokenizer_name=args.tokenizer_name,
149+
tokenizer_legacy=args.tokenizer_legacy
148150
)
149151

150152
extra = "-" + args.subset if args.subset != "full" else ""

bigcodebench/model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
dtype: str = "bfloat16", # default
9393
trust_remote_code: bool = False,
9494
tokenizer_name: str = None,
95+
tokenizer_legacy: bool = False,
9596
) -> None:
9697
print("Initializing a decoder model: {} ...".format(name))
9798
self.name = name
@@ -103,6 +104,7 @@ def __init__(
103104
self.dtype = dtype
104105
self.trust_remote_code = trust_remote_code
105106
self.tokenizer_name = tokenizer_name
107+
self.tokenizer_legacy = tokenizer_legacy
106108

107109
@abstractmethod
108110
def codegen(
@@ -133,7 +135,7 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
133135
if self.tokenizer_name is None:
134136
self.tokenizer_name = self.name
135137

136-
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
138+
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs, legacy=not self.tokenizer_legacy)
137139
if self.tokenizer.chat_template is None:
138140
self.eos += extra_eos_for_direct_completion(dataset)
139141
self.llm = LLM(model=name, max_model_len=2048, **kwargs)
@@ -193,7 +195,7 @@ def __init__(self, name: str, dataset: str, **kwargs):
193195
if self.tokenizer_name is None:
194196
self.tokenizer_name = self.name
195197

196-
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
198+
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs, legacy=not self.tokenizer_legacy)
197199

198200
if self.tokenizer.chat_template is None:
199201
self.eos += extra_eos_for_direct_completion(dataset)
@@ -249,7 +251,8 @@ def __init__(self, name: str, **kwargs):
249251
super().__init__(name=name, **kwargs)
250252
self.eos += ["\n```\n"]
251253
print(f"EOS strings: {self.eos}")
252-
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
254+
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name,
255+
**kwargs, legacy=not self.tokenizer_legacy)
253256

254257
def codegen(
255258
self, prompt: str, do_sample: bool = True, num_samples: int = 200
@@ -483,6 +486,7 @@ def make_model(
483486
base_url=None,
484487
trust_remote_code=False,
485488
tokenizer_name=None,
489+
tokenizer_legacy=True,
486490
):
487491
if backend == "vllm":
488492
return GeneralVllmDecoder(
@@ -493,6 +497,7 @@ def make_model(
493497
tp=tp,
494498
trust_remote_code=trust_remote_code,
495499
tokenizer_name=tokenizer_name,
500+
tokenizer_legacy=tokenizer_legacy,
496501
)
497502
elif backend == "hf":
498503
return GenenralHfTorchDecoder(
@@ -502,6 +507,7 @@ def make_model(
502507
dataset=dataset,
503508
trust_remote_code=trust_remote_code,
504509
tokenizer_name=tokenizer_name,
510+
tokenizer_legacy=tokenizer_legacy,
505511
)
506512
elif backend == "openai":
507513
return OpenAIChatDecoder(

0 commit comments

Comments (0)