155 changes: 83 additions & 72 deletions CLI.py
@@ -1,35 +1,26 @@
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Optional
from typing import Dict, List, Literal, Optional, Tuple
from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable

import lightning as L
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import models.vqvae as vqvae
from generate import generate
from lit_llama import Tokenizer, LLaMA, LLaMAConfig
from lit_llama.lora import lora
from lit_llama.utils import EmptyInitOnDevice
from lit_gpt.utils import lazy_load
from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp
from options import option
import imageio
from tqdm import tqdm
from models.multimodal_encoder.builder import build_image_tower, build_video_tower
from models.multimodal_projector.builder import build_vision_projector

warnings.filterwarnings('ignore')

args = option.get_args_parser()


class LlavaMetaModel:

@@ -57,16 +48,13 @@ def get_video_tower(self):
video_tower = video_tower[0]
return video_tower


def get_all_tower(self, keys):
tower = {key: getattr(self, f'get_{key}_tower') for key in keys}
return tower


def load_video_tower_pretrained(self, pretrained_checkpoint):
self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True)


def initialize_image_modules(self, model_args, fsdp=None):
image_tower = model_args.image_tower
mm_vision_select_layer = model_args.mm_vision_select_layer
@@ -92,6 +80,7 @@ def initialize_image_modules(self, model_args, fsdp=None):

if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

@@ -122,6 +111,7 @@ def initialize_video_modules(self, model_args, fsdp=None):

if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

@@ -133,27 +123,26 @@ def encode_images(self, images):
return image_features

def encode_videos(self, videos):
video_features = self.get_video_tower()(videos)  # torch.Size([1, 2048, 1024])
video_features = self.mm_projector(video_features.float())  # torch.Size([1, 2048, 4096])

return video_features

def get_multimodal_embeddings(self, X_modalities):
Xs, keys= X_modalities
Xs, keys = X_modalities
X_features = getattr(self, f'encode_{keys[0]}s')(Xs) # expand to get batchsize

return X_features

def get_processor(X, config, device, pretrained_checkpoint_tower, model_path = 'LanguageBind/Video-LLaVA-7B'):


def get_processor(X, config, device, pretrained_checkpoint_tower, model_path='LanguageBind/Video-LLaVA-7B'):
processor = {}

mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower)

print(X)
if 'Image' in X:
image_tower = mm_backbone_mlp_model.get_image_tower()  # LanguageBindImageTower()
if not image_tower.is_loaded:
image_tower.load_model()
image_tower.to(device=device, dtype=torch.float16)
@@ -175,6 +164,7 @@ class Projection(nn.Module):
def __init__(self, ):
super().__init__()
self.linear_proj = nn.Linear(512, 4096)

def forward(self, x):
return self.linear_proj(x)

@@ -187,24 +177,23 @@ def __init__(self, ):
nn.GELU(),
nn.Linear(4096, 4096)
)

def forward(self, x):
return self.proj(x)


def main(
def load_model(
args: any,
quantize: Optional[str] = None,
dtype: str = "float32",
max_new_tokens: int = 200,
top_k: int = 200,
temperature: float = 0.8,
accelerator: str = "auto",
) -> None:
checkpoint_dir: str = "./checkpoints",
) -> dict:
# import pdb; pdb.set_trace()
lora_path = Path(args.lora_path)
pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth")
tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model")
pretrained_llm_path = Path(f"{checkpoint_dir}/vicuna-7b-v1.5/lit_model.pth")
tokenizer_llm_path = Path(f"{checkpoint_dir}/vicuna-7b-v1.5/tokenizer.model")

# assert lora_path.is_file()
assert pretrained_llm_path.is_file()
assert tokenizer_llm_path.is_file()
@@ -223,7 +212,7 @@ def main(
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True):
checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5")
checkpoint_dir = Path(f"{checkpoint_dir}/vicuna-7b-v1.5")
lora_query = True
lora_key = False
lora_value = True
@@ -243,13 +232,14 @@ def main(
to_head=lora_head,
)
model = GPT(config).bfloat16()

mlp_path = args.mlp_path
pretrained_checkpoint_mlp = torch.load(mlp_path)

X = ['Video']

mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B')
mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp,
model_path='LanguageBind/Video-LLaVA-7B')
video_processor = processor['video']

linear_proj = mm_backbone_mlp_model.mm_projector
@@ -264,66 +254,87 @@ def main(
print('Load llm base model from', pretrained_llm_path)
print('Load lora model from', lora_path)

# load the MLP weights again to be sure; not strictly necessary
linear_proj.load_state_dict(pretrained_checkpoint_mlp)
linear_proj = linear_proj.cuda()
print('Load mlp model again from', mlp_path)


print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()
model = fabric.setup_module(model)
linear_proj.eval()


tokenizer = Tokenizer(tokenizer_llm_path)
print('Load tokenizer from', tokenizer_llm_path)
return {
"tokenizer": tokenizer,
"model": model,
"mm_backbone_mlp_model": mm_backbone_mlp_model,
"video_processor": video_processor,
}

def predict(tokenizer: Tokenizer,
model: GPT,
mm_backbone_mlp_model: LlavaMetaModel,
video_processor: any,
input_video_path: str,
prompt: str,

max_new_tokens: int = 200,
top_k: int = 200,
temperature: float = 0.8) -> str:

video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values']

if type(video_tensor) is list:
tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor]
else:
tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224)

X_modalities = [tensor, ['video']]

video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities)

sample = {"instruction": prompt, "input": input_video_path}

prefix = generate_prompt_mlp(sample)
pre = torch.cat(
(tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1,
-1),
tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1)

prompt = (pre, ". ASSISTANT: ")
encoded = (prompt[0], video_feature[0],
tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1))

output_seq = generate(
model,
idx=encoded,
max_seq_length=4096,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
eos_id=tokenizer.eos_id,
tokenizer=tokenizer,
)
outputfull = tokenizer.decode(output_seq)
return outputfull.split("ASSISTANT:")[-1].strip()

def main():
args = option.get_args_parser()


torch.set_float32_matmul_precision("high")
model_components = load_model(args)
while True:

input_video_path = input("\033[0;34;40m Input video path: \033[0m")
video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values']

if type(video_tensor) is list:
tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor]
else:
tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224)

X_modalities = [tensor,['video']]
prompt = input("\033[0;34;40m Your question: \033[0m")

video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities)
output = predict(prompt=prompt, input_video_path=input_video_path, **model_components)

prompt = input("\033[0;34;40m Your question: \033[0m")
sample = {"instruction": prompt, "input": input_video_path}

prefix = generate_prompt_mlp(sample)
pre = torch.cat((tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1), tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1)

prompt = (pre, ". ASSISTANT: ")
encoded = (prompt[0], video_feature[0], tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1))


t0 = time.perf_counter()

output_seq = generate(
model,
idx=encoded,
max_seq_length=4096,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
eos_id=tokenizer.eos_id,
tokenizer = tokenizer,
)
outputfull = tokenizer.decode(output_seq)
output = outputfull.split("ASSISTANT:")[-1].strip()
print("================================")
print("Model output", output)
print("================================")


if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
main()
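
Usage sketch (not part of the diff): splitting the old main() into load_model() and predict() makes the pipeline callable from other scripts instead of only through the interactive loop. The snippet below is a minimal, hypothetical example of that flow; the import path, video path, and prompt are placeholders, and it assumes option.get_args_parser() already supplies lora_path and mlp_path.

import torch

from CLI import load_model, predict  # assumes CLI.py is importable from the working directory
from options import option


def run_batch(video_paths, prompt):
    # Parse the repo's standard CLI options (lora_path, mlp_path, LoRA hyperparameters, ...).
    args = option.get_args_parser()
    torch.set_float32_matmul_precision("high")

    # load_model() returns {"tokenizer", "model", "mm_backbone_mlp_model", "video_processor"};
    # it requires a CUDA device, and predict() moves video tensors to 'cuda' internally.
    components = load_model(args)

    # Reuse the loaded weights for every video instead of reloading per question.
    return {
        path: predict(prompt=prompt, input_video_path=path, **components)
        for path in video_paths
    }


if __name__ == "__main__":
    answers = run_batch(["./demo/sample.mp4"], "Describe what happens in this video.")
    for path, answer in answers.items():
        print(path, "->", answer)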