# internlm-xcomposer2d5-ol-7b / memory / model_grounding_ego.py
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
from grounding_qwen import GroundQwenForCausalLM
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import math
import pickle
from decord import VideoReader
import numpy as np
from transformers import StoppingCriteria, StoppingCriteriaList, AutoTokenizer
from petrel_client.client import Client
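# Petrel/Ceph storage client for cluster-hosted data; it is configured here but the
# script below reads the test video straight from disk with decord.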
client = Client('~/petreloss.conf')
def get_seq_frames(total_num_frames, desired_num_frames):
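    """Uniformly split [0, total_num_frames) into `desired_num_frames` segments
    and return the middle frame index of each segment."""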
seg_size = float(total_num_frames - 1) / desired_num_frames
seq = []
for i in range(desired_num_frames):
start = int(np.round(seg_size * i))
end = int(np.round(seg_size * (i + 1)))
seq.append((start + end) // 2)
return seq
def get_seq_time(vr, frame_idx, num_clip):
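    """Return the start/end timestamps (in seconds) of each clip.

    `frame_idx` is divided into `num_clip` contiguous clips; the timestamps of each
    clip's first and last frame are looked up in the VideoReader and returned as a
    flat array [start_0..start_{n-1}, end_0..end_{n-1}] of length 2 * num_clip."""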
frm_per_clip = len(frame_idx) // num_clip
key_frame = [[frame_idx[i*frm_per_clip], frame_idx[i*frm_per_clip+frm_per_clip-1]] for i in range(num_clip)]
time = vr.get_frame_timestamp(key_frame)
return np.hstack([time[:, 0, 0], time[:, 1, 1]])
def calculate_diff(scene_sep, start_frame):
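    """Length in frames of each scene segment delimited by `scene_sep`,
    with the first segment measured from `start_frame`."""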
diff = [scene_sep[0]-start_frame]
for i in range(len(scene_sep)-1):
diff.append(scene_sep[i+1]-scene_sep[i])
return diff
def load_video(vis_path, scene_sep, num_frm=16, max_clip=4):
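    """Sample frames from a video, grouped into clips of `num_frm` frames each.

    If `scene_sep` (a list of scene-boundary frame indices) is empty, clips are
    sampled uniformly over the whole video; otherwise each scene becomes one clip,
    after padding the scene list to at least 5 boundaries or trimming it to
    `max_clip`. Frames are resized to 224x224 and returned as PIL images, together
    with the clip start/end timestamps and the number of clips."""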
block_size = 1
vr = VideoReader(vis_path)
total_frame_num = len(vr)
fps = vr.get_avg_fps()
total_time = total_frame_num / fps
if len(scene_sep) == 0:
num_clip = total_time / num_frm
num_clip = int(block_size*np.round(num_clip/block_size)) if num_clip > block_size else int(np.round(num_clip))
num_clip = max(num_clip, 5)
num_clip = min(num_clip, max_clip)
total_num_frm = num_frm * num_clip
start_frame = 0
frame_idx = get_seq_frames(total_frame_num, total_num_frm)
else:
num_clip = max(len(scene_sep), 5)
num_clip = min(num_clip, max_clip)
total_num_frm = num_frm * num_clip
start_frame = 0
frame_idx = []
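        # Normalize the number of scene boundaries: split the longest segment until
        # there are at least 5 boundaries, or drop the shortest segments down to max_clip.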
if len(scene_sep) < 5:
diff = calculate_diff(scene_sep, start_frame)
new_sep = max(diff) / (5-len(scene_sep)+1)
max_idx = np.argmax(diff)
if max_idx == 0:
scene_sep = [int(start_frame+new_sep*i) for i in range(1, 5-len(scene_sep)+1)] + scene_sep
else:
scene_sep = scene_sep[:max_idx]+[int(scene_sep[max_idx-1]+i*new_sep) for i in range(1, 5-len(scene_sep)+1)] + scene_sep[max_idx:]
elif len(scene_sep) > max_clip:
diff = calculate_diff(scene_sep, start_frame)
            min_idx = np.argsort(diff[:-1])[:len(scene_sep) - max_clip]  # shortest segments to drop (the last boundary is always kept)
for i in np.sort(min_idx)[::-1]:
del scene_sep[i]
start_ = start_frame
for end_frame in scene_sep:
idx_list = np.linspace(start_, end_frame, num=num_frm, endpoint=False)
frame_idx.extend([int(id) for id in idx_list])
start_ = end_frame
time_idx = get_seq_time(vr, frame_idx, num_clip)
img_array = vr.get_batch(frame_idx).asnumpy() # (n_clips*num_frm, H, W, 3)
a, H, W, _ = img_array.shape
h, w = 224, 224
if img_array.shape[-3] != h or img_array.shape[-2] != w:
img_array = torch.from_numpy(img_array).permute(0, 3, 1, 2).float()
img_array = torch.nn.functional.interpolate(img_array, size=(h, w))
img_array = img_array.permute(0, 2, 3, 1).to(torch.uint8).numpy()
img_array = img_array.reshape((1, total_num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))
clip_imgs = []
for j in range(total_num_frm):
clip_imgs.append(Image.fromarray(img_array[0, j]))
return clip_imgs, time_idx, num_clip
def preprocess_time(time, num_clip, tokenizer):
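    """Build one tokenized prompt per clip of the form
    'This contains a clip sampled in <start> to <end> seconds' + DEFAULT_IMAGE_TOKEN,
    using the flattened start/end times produced by get_seq_time."""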
time = time.reshape(2, num_clip)
seq = []
block_size = 1
for i in range(num_clip):
start, end = time[:, i]
start = int(np.round(start))
end = int(np.round(end))
if (i+1) % block_size == 0:
history_end = end
sentence = 'This contains a clip sampled in %d to %d seconds' % (start, end) + DEFAULT_IMAGE_TOKEN
sentence = tokenizer_image_token(sentence, tokenizer, return_tensors='pt')
seq.append(sentence)
return seq
def preprocess_question(questions, tokenizer):
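    """Tokenize each question string, mapping any image placeholder to its token id."""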
seq = []
for q in questions:
sentence = tokenizer_image_token(q, tokenizer, return_tensors='pt')
seq.append(sentence)
return seq
def eval_dataset(args):
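    """Load the grounding model, run forward_grounding on one hard-coded test video,
    and print the resulting clip-question similarity tensor."""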
disable_torch_init()
model_path = os.path.expanduser(args.model_path)
model_name = get_model_name_from_path(model_path)
device = 'cuda'
kwargs = {"device_map": 'auto'}
kwargs['torch_dtype'] = torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = GroundQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
vision_tower.load_model()
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048
video_path = '/mnt/hwfile/mllm/qianrui/test_video/air.mp4'
question = ['How many pilots are shown in the video?', 'What airlines are shown in the video?', 'How is the decoration of the airport?']
scene_sep = json.load(open('/mnt/hwfile/mllm/qianrui/test_video/air.json', 'r'))['scene_sep']
frames, time_idx, num_clips = load_video(video_path, scene_sep, num_frm=16, max_clip=32)
video = image_processor.preprocess(frames, return_tensors='pt')['pixel_values']
video = video.view(num_clips, 16, *video.shape[1:])
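    # (num_clips, frames_per_clip=16, C, H, W)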
seqs = preprocess_time(time_idx, num_clips, tokenizer)
seqs = torch.nn.utils.rnn.pad_sequence(
seqs,
batch_first=True,
padding_value=tokenizer.pad_token_id)
compress_mask = seqs.ne(tokenizer.pad_token_id)
question = preprocess_question(question, tokenizer)
question = torch.nn.utils.rnn.pad_sequence(
question,
batch_first=True,
padding_value=tokenizer.pad_token_id)
qs_mask = question.ne(tokenizer.pad_token_id)
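    # Score the questions against the sampled clips with the model's grounding head.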
with torch.inference_mode():
similarity = model.forward_grounding(
input_ids=seqs.to(device='cuda', non_blocking=True),
attention_mask=compress_mask.to(device='cuda', non_blocking=True),
images=video.to(dtype=torch.float16, device='cuda', non_blocking=True),
qs_ids=question.to(device='cuda', non_blocking=True),
qs_mask=qs_mask.to(device='cuda', non_blocking=True))
print(similarity.shape, similarity)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--conv-mode", type=str, default="llava_v1")
args = parser.parse_args()
eval_dataset(args)
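# Example invocation (hypothetical checkpoint path):
#   python model_grounding_ego.py --model-path /path/to/grounding-qwen-7b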