# llm.py
import os

# Clearing these variables disables TLS certificate verification for curl /
# requests (e.g. when running behind a proxy with self-signed certificates).
os.environ["CURL_CA_BUNDLE"] = ""
os.environ["REQUESTS_CA_BUNDLE"] = ""

import urllib3
import warnings

warnings.filterwarnings("ignore", category=urllib3.exceptions.InsecureRequestWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

from pathlib import Path

from constants import HF_CACHE, HF_TOKEN

# HF_HOME must be set before transformers/huggingface_hub are imported,
# since the cache location is read at import time.
os.environ["HF_HOME"] = HF_CACHE

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM
class LLM:
    def __init__(self, path, peft=False, use_flash_attention=False):
        # Load the model in 4-bit NF4 with double quantization to cut memory use.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
        # PEFT checkpoints (e.g. LoRA adapters) need AutoPeftModelForCausalLM,
        # which loads the base model and applies the adapter weights on top.
        LMModel = AutoPeftModelForCausalLM if peft else AutoModelForCausalLM
        self.model = LMModel.from_pretrained(
            pretrained_model_name_or_path=path,
            cache_dir=HF_CACHE,
            quantization_config=quant_config,
            # flash-attention-2 requires bf16 (or fp16); default to fp16 otherwise.
            torch_dtype=torch.bfloat16 if use_flash_attention else torch.float16,
            device_map="auto",
            token=HF_TOKEN,
            attn_implementation="flash_attention_2" if use_flash_attention else None,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=path,
            cache_dir=HF_CACHE,
            token=HF_TOKEN,
        )
        # Many causal LMs ship without a pad token; reuse EOS for batched padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.name = Path(path).stem

    def label_states(self, batch_sentences, labels):
        with torch.no_grad():
            batch_token_ids = self.tokenizer(batch_sentences, return_tensors="pt", padding=True).to("cuda")
            # Number of real (non-padding) tokens per sequence.
            batch_nb_tokens = batch_token_ids.attention_mask.sum(-1)
            batch_outputs = self.model(**batch_token_ids, output_hidden_states=True)
            # Next-token logits at the last real token of each sequence
            # (with right padding, position -1 may land on a pad token).
            batch_logits = batch_outputs.logits
            batch_logits = torch.stack(
                [batch_logits[i, nb_tokens - 1, :] for i, nb_tokens in enumerate(batch_nb_tokens)]
            ).cpu()
            # Hidden states of the last real token at every layer, skipping the
            # embedding layer (hidden_states[0]): shape (batch, layers, hidden).
            all_hidden_states = torch.stack(batch_outputs.hidden_states, dim=1)
            last_hidden_states = torch.stack(
                [all_hidden_states[i, 1:, nb_tokens - 1, :] for i, nb_tokens in enumerate(batch_nb_tokens)]
            ).cpu()
        # Each label must map to a single known vocabulary token; depending on
        # the tokenizer, unknown tokens come back as None or as the unk id.
        label_tokens = self.tokenizer.convert_tokens_to_ids(labels)
        for token, token_id in zip(labels, label_tokens):
            if token_id is None or token_id == self.tokenizer.unk_token_id:
                raise ValueError(f"Token '{token}' not in vocabulary.")
        # Keep only the logits of the candidate label tokens: shape (batch, n_labels).
        logits = torch.stack([batch_logits[:, token_id] for token_id in label_tokens], dim=1)
        return last_hidden_states.squeeze(0), logits.squeeze(0)
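

# A minimal usage sketch, assuming a single-GPU setup; the model path and the
# candidate labels below are illustrative, not taken from this repo (each label
# must be a single token in the model's vocabulary, or label_states raises).
if __name__ == "__main__":
    llm = LLM("meta-llama/Llama-2-7b-hf", peft=False, use_flash_attention=False)
    hidden_states, logits = llm.label_states(
        ["Answer with A or B. The sky is blue: A) true B) false. Answer:"],
        labels=["A", "B"],
    )
    # hidden_states: per-layer activations of the last real token;
    # logits: one next-token logit per candidate label.
    print(hidden_states.shape, logits.shape)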