Description
import torch
from bminf.wrapper import wrapper
from bminf.scheduler import TransformerBlockList
from transformers import RobertaTokenizer, RobertaForQuestionAnswering

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

# Move the model to the CUDA device
model.to('cuda')

# Quantize the model with the wrapper function
quantized_model = wrapper(model, quantization=True)

# Make sure quantized_model is a list containing an nn.Module
quantized_model_list = [quantized_model]  # wrap the model in a list
print(quantized_model)

# Create a TransformerBlockList instance
gpus = [(0, None)]  # assume a single GPU with index 0
torch.cuda.empty_cache()
scheduled_model = TransformerBlockList(quantized_model_list, gpus)

# Sample text data
texts = ["Hello, my dog is cute", "I love natural language processing"]

# Helper that pads the sequence length up to a multiple of 4
def pad_to_multiple_of_4(input_ids, attention_mask):
    seq_len = input_ids.shape[-1]
    padding_length = (4 - seq_len % 4) % 4  # padding_length is always between 0 and 3
    if padding_length > 0:
        input_ids = torch.cat([input_ids, torch.zeros((input_ids.shape[0], padding_length), dtype=torch.long)], dim=1)
        attention_mask = torch.cat([attention_mask, torch.zeros((attention_mask.shape[0], padding_length), dtype=torch.long)], dim=1)
    return input_ids, attention_mask

# Tokenize the texts and make sure the sequence length is a multiple of 4
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", padding_side="right", max_length=512)
input_ids = inputs['input_ids']
attention_mask = inputs.get('attention_mask')
input_ids, attention_mask = pad_to_multiple_of_4(input_ids, attention_mask)

# Verify the sequence length
if input_ids.shape[-1] % 4 != 0:
    raise ValueError("Input sequence length is not a multiple of 4 after padding.")

# Run the forward pass
with torch.no_grad():
    output = scheduled_model(input_ids, attention_mask=attention_mask)

# Print the result
print(output)
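Side note on the padding step: as far as I can tell, the manual pad_to_multiple_of_4 helper above could also be replaced by the tokenizer's built-in pad_to_multiple_of argument (a minimal sketch, not what the failing script uses; texts and the 512 max_length are the same as above):

inputs = tokenizer(
    texts,
    padding=True,
    truncation=True,
    max_length=512,
    pad_to_multiple_of=4,  # pad the batch up to a multiple of 4 tokens
    return_tensors="pt",
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

Either way, running the script above ends in the traceback below.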
Traceback (most recent call last):
  File "D:\LLM\模型压缩\BMinf_\demo.py", line 52, in <module>
    output = scheduled_model(input_ids, attention_mask=attention_mask)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\bminf\scheduler\__init__.py", line 417, in forward
    x = sched.forward(x, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\bminf\scheduler\__init__.py", line 348, in forward
    return OpDeviceLayer.apply(self, x, len(kwargs), *cuda_lst)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\autograd\function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\bminf\scheduler\__init__.py", line 105, in forward
    hidden_state = self._layers[i](hidden_state, *args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\transformers\models\roberta\modeling_roberta.py", line 1634, in forward
    logits = self.qa_outputs(sequence_output)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\bminf\quantization\__init__.py", line 81, in forward
    out = OpLinear.apply(x, self.weight_quant, self.weight_scale)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\torch\autograd\function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\bminf\quantization\__init__.py", line 31, in forward
    gemm_int8(
  File "C:\Users\Zeng xiang xi\.conda\envs\LLM\Lib\site-packages\cpm_kernels\kernels\gemm.py", line 139, in gemm_int8
    assert m % 4 == 0 and n % 4 == 0 and k % 4 == 0
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
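For context, my reading of the failing assert (an assumption on my side, not verified against the BMInf internals): gemm_int8 requires all three GEMM dimensions m, n, k to be multiples of 4, so padding the sequence length only fixes one of them. The crash happens inside the qa_outputs head of RobertaForQuestionAnswering, which is an nn.Linear with out_features=2, and that dimension cannot satisfy the check once the layer is quantized. A small sketch to list the linear layers whose shapes are not multiples of 4:

import torch.nn as nn
from transformers import RobertaForQuestionAnswering

model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

# Print every nn.Linear whose in/out features are not divisible by 4; these are the
# layers whose int8 GEMM would hit `assert m % 4 == 0 and n % 4 == 0 and k % 4 == 0`.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        if module.in_features % 4 != 0 or module.out_features % 4 != 0:
            print(name, module.in_features, module.out_features)

# On roberta-base this should print something like: qa_outputs 768 2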