I am trying to reduce the memory footprint of a LLaMA2 model pruned to 2:4 sparsity with SparseGPT, using the to_sparse_semi_structured method from PyTorch. However, when I apply it to change how the sparse parameters are stored, I run out of memory. Please note that I do not run out of memory with the original dense model.
Below is the code I am running, where model_path is the path to the pruned model.
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device).half()
for fqn, module in model.named_modules():
    # print(fqn)
    if isinstance(module, nn.Linear):
        # Swap the dense weight for its 2:4 semi-structured sparse representation.
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))
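
For reference, here is a rough sketch of the saving one would expect per layer. It assumes fp16 weights, where a 2:4 layout keeps 2 of every 4 values plus 2 bits of index metadata per kept value, which gives about 9/16 of the dense size. The function name estimate_fp16_bytes and the 4096 x 4096 layer shape are only illustrative, and the actual kernels may pad the metadata.

```python
FP16_BYTES = 2          # bytes per fp16 value
GROUP = 4               # 2:4 sparsity works on groups of 4 values
KEPT_PER_GROUP = 2      # 2 values are kept in every group
META_BITS_PER_KEPT = 2  # assumed: 2-bit position index per kept value


def estimate_fp16_bytes(rows, cols):
    """Return (dense_bytes, compressed_bytes) for one rows x cols weight.

    Hypothetical helper: counts only the final storage, not any
    temporary buffers allocated while the conversion is running.
    """
    n = rows * cols
    dense = n * FP16_BYTES
    groups = n // GROUP
    values = groups * KEPT_PER_GROUP * FP16_BYTES
    metadata = groups * KEPT_PER_GROUP * META_BITS_PER_KEPT // 8
    return dense, values + metadata


if __name__ == "__main__":
    # Illustrative shape only; substitute the shapes of the real layers.
    dense, compressed = estimate_fp16_bytes(4096, 4096)
    print(f"dense: {dense} B, compressed: {compressed} B, "
          f"ratio: {compressed / dense:.4f}")
```

This only estimates the size of the stored result. While a layer is being converted, its dense weight and the new compressed tensor both exist, so peak usage during the loop can be higher than either figure.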