WEBing committed
Commit 1f1db3a
Parent: 732992b

fix eva in data_utils

Files changed (1): data_utils.py (+34 -18)
data_utils.py CHANGED
@@ -1,11 +1,29 @@
 import decord
-import numpy as np
-import torch
-from PIL import Image
 import random
+import numpy as np
+from PIL import Image
+
+import torch
+from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize
+
+
+def _convert_to_rgb(image):
+    return image.convert('RGB')
+
+
+def image_transform(image_size: int):
+    mean = (0.48145466, 0.4578275, 0.40821073)
+    std = (0.26862954, 0.26130258, 0.27577711)
+
+    normalize = Normalize(mean=mean, std=std)
+    transforms = [
+        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+        _convert_to_rgb,
+        ToTensor(),
+        normalize,
+    ]
+    return Compose(transforms)
 
-from eva_clip.transform import image_transform
-image_processor = image_transform(image_size=448, is_train=False)
 
 def preprocess_multimodal(sources, num_segments):
     for source in sources:
@@ -26,6 +44,7 @@ def preprocess_multimodal(sources, num_segments):
             sentence["content"] = sentence["content"].replace(X_token, replace_token)
     return sources
 
+
 def preprocess(
     sources,
     tokenizer,
@@ -60,9 +79,6 @@ def preprocess(
     else:
         index = random.choice(range(len(en_qa_templates)))
     system_prompt = f"""You are a helpful assistant, {en_qa_templates[index]} 你是一个乐于助人的助手,{ch_qa_templates[index]}"""
-    chat_template = """{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>'
-    + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}
-    {% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"""
     messages = []
     for source in sources:
         message = [{'role': 'system', 'content': system_prompt}]
@@ -70,14 +86,14 @@ def preprocess(
             message.append(sentence)
         messages.append(message)
 
-    #input_ids = tokenizer.apply_chat_template(messages, chat_template, add_generation_prompt=True, return_tensors='pt')
     input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
     return input_ids
-
+
+
 def get_index(fps, max_frame, num_segments):
     num_frames = max_frame
     if num_frames <= num_segments:
-        out_indices = start_idx + np.array([(idx % num_frames) for idx in range(num_segments)])
+        out_indices = np.array([(idx % num_frames) for idx in range(num_segments)])
         out_indices = np.sort(out_indices)
     else:
         out_indices = np.linspace(0, num_frames-1, num_segments)
@@ -85,22 +101,23 @@ def get_index(fps, max_frame, num_segments):
     durations = [idx.item() / fps for idx in out_indices]
     return out_indices.astype(np.int64), durations
 
+
 def read_video(video_path, num_segments):
+    image_processor = image_transform(image_size=448)
     vr = decord.VideoReader(video_path)
-    max_frame = len(vr) - 1
     fps = float(vr.get_avg_fps())
 
-    total_duration = len(vr) / fps
-    frame_indices, durations = get_index(fps, max_frame, num_segments)
+    frame_indices, durations = get_index(fps, len(vr) - 1, num_segments)
     video = []
     for frame_index in frame_indices:
        image = Image.fromarray(vr[frame_index].asnumpy())
        video.append(image_processor(image).unsqueeze(0))
     video = torch.concat(video)
-    return video, torch.Tensor(durations), total_duration
+    return video, torch.Tensor(durations)
+
 
 def get_input(video_path, num_segments, question, history, tokenizer, s_id):
-    video, durations, total_duration = read_video(video_path, num_segments)
+    video, durations = read_video(video_path, num_segments)
     if history == None:
         conversations = []
         conversations.append({'role': 'user', 'content': f'<video>\n{question}'})
@@ -113,8 +130,7 @@ def get_input(video_path, num_segments, question, history, tokenizer, s_id):
 
     return video, durations, input_ids, conversations
 
+
 def add_pred_to_history(history, pred):
     history.append({'role': 'assistant', 'content': pred})
     return history
-
-
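
The sketch below is a minimal, hypothetical usage example of the rebuilt preprocessing; it is not part of the commit. It assumes the updated data_utils.py is importable, that decord, torch, torchvision and Pillow are installed, and that 'sample.mp4' and num_segments=16 are placeholder values. The mean/std in the new local image_transform are the usual CLIP-style normalization constants, and read_video now returns only (frames, durations).

# Hypothetical usage sketch (not from the commit); module path, video path and
# segment count are assumptions.
from PIL import Image
from data_utils import image_transform, read_video

# Local replacement for eva_clip's transform: bicubic resize to 448x448,
# RGB conversion, ToTensor, CLIP-style normalization.
processor = image_transform(image_size=448)
dummy_frame = Image.new('RGB', (640, 360))   # stand-in for a decoded video frame
print(processor(dummy_frame).shape)          # torch.Size([3, 448, 448])

# read_video now returns (frames, durations); total_duration was dropped.
video, durations = read_video('sample.mp4', num_segments=16)
print(video.shape)       # torch.Size([16, 3, 448, 448])
print(durations.shape)   # torch.Size([16])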