"""
Merge LoRA weights with the original model into a new model
"""
import torch
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Optional

from accelerate import Accelerator, InitProcessGroupKwargs
from peft import PeftModel
from transformers import HfArgumentParser

from src import get_model_and_tokenizer, ModelArgs

MAX_POSITION_ID = 256 * 1024 # Determined by the model
TRUNCATE_LEN = 256 * 1024
device = torch.device("cuda")

@dataclass
class Args(ModelArgs):
    eval_data: str = field(
        default="activation-beacon:lm/pg19.json",
        metadata={'help': 'The evaluation json data path.'}
    )
    output_dir: str = field(
        default="data/results/lm/",
        metadata={'help': 'Output directory for results and logs.'}
    )
    retokenize: bool = field(
        default=False,
        metadata={'help': 'Whether to retokenize the corpus.'}
    )
    tokenize_max_char: Optional[int] = field(
        default=None,
        metadata={'help': 'The number of characters to truncate to.'}
    )
    batch_size: int = field(
        default=1,
        metadata={'help': 'Evaluation batch size.'}
    )
    padding_side: str = field(
        default="right",
        metadata={'help': 'Which side to pad on.'}
    )
    stride: int = field(
        default=2048,
        metadata={'help': 'Streaming stride when evaluating perplexity.'}
    )
    max_sample_num: int = field(
        default=100,
        metadata={'help': 'How many samples from eval_data to evaluate.'}
    )
    min_length: Optional[int] = field(
        default=None,
        metadata={'help': 'Minimum length for input_ids.'}
    )

    ###### RAMP ######
    training_stage: str = field(
        default="finetune",
        metadata={'help': 'Training stage: pretrain or finetune.'}
    )
    mp_layer_num: int = field(
        default=3,
        metadata={'help': 'The number of MP layers.'}
    )
    c_ratio: float = field(
        default=0.2,
        metadata={'help': 'Compression ratio on node content.'}
    )
    ###### RAMP ######

    ###### RAMP Merge ######
    lora_path: str = field(
        default="./",
        metadata={'help': 'Path to the LoRA adapter to merge.'}
    )
    final_output_dir: str = field(
        default="./",
        metadata={'help': 'Directory to save the merged model to.'}
    )
    ###### RAMP Merge ######

if __name__ == "__main__":
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    # Increase the process-group timeout so distributed initialization does not
    # time out while the base model is being loaded.
    accelerator = Accelerator(
        cpu=args.cpu,
        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))],
    )
    model, tokenizer = get_model_and_tokenizer(
        args, mp_layer_num=args.mp_layer_num, accelerator=accelerator
    )

    ################### merge model ###################
    lora_path = args.lora_path
    final_output_dir = args.final_output_dir

    print(f"Loading LoRA adapter from {lora_path}")
    lora_model = PeftModel.from_pretrained(
        model,
        lora_path,
        torch_dtype=torch.float16
    )

    # Merge the adapter weights into the base weights and drop the PEFT wrappers
    print("Merging LoRA weights with base model")
    merged_model = lora_model.merge_and_unload()

    # Save the merged model, casting to float16 so the checkpoint has a uniform precision
    print(f"Saving merged model to {final_output_dir}")
    merged_model = merged_model.to(torch.float16)
    merged_model.save_pretrained(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    print("Merge completed successfully!")

    # Exit after the first save
    exit(0)
    ################### merge model ###################
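
    # Optional follow-up (a minimal sketch, left commented out): if the merged
    # checkpoint is a standard Transformers model, it can be reloaded directly
    # to verify the save succeeded. Custom architectures from `src` may need
    # their own loading path instead.
    #
    #   from transformers import AutoModelForCausalLM
    #   reloaded = AutoModelForCausalLM.from_pretrained(
    #       final_output_dir, torch_dtype=torch.float16, trust_remote_code=True,
    #   )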