# Tarsier2-7b / dataset / mm_dataset.py
# omni-research - commit "init" (97a05c0)
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataset.utils import get_visual_type, sample_frame_indices
from .processor import Processor
from tools.rw_utils import read_jsonlines
class MMDataset(object):
    """Index-addressable collection of multimodal annotations.

    Every annotation is a mapping that must provide ``ann['text']['prompt']``
    and ``ann['video_file']`` (one path, or a list of frame paths) and may
    provide ``start_time`` / ``end_time``.  Indexing the dataset loads the
    referenced visuals through the processor and returns the annotation
    together with the processor output.
    """

    # The annotation is quoted so the project type is not evaluated when the
    # class body is executed.
    def __init__(self, ann_path="", anns=None, processor: "Processor" = None):
        """Store the processor and the annotations.

        Args:
            ann_path: A single path, or an iterable of paths, each read with
                ``read_jsonlines``.  Only used when ``anns`` is None.
            anns: Already-loaded annotations.  When given, ``ann_path`` is
                ignored and this object is stored as-is (not copied).
            processor: Object used by ``__getitem__``.  It must expose
                ``max_n_frames`` and ``load_images(...)`` and be callable as
                ``processor(prompt, images=..., edit_prompt=..., return_prompt=...)``.
        """
        self.processor = processor
        if anns is None:
            self.anns = []
            if isinstance(ann_path, str):
                ann_path = [ann_path]
            for path in ann_path:
                self.anns.extend(read_jsonlines(path))
        else:
            self.anns = anns

    def __len__(self):
        """Return the number of annotations."""
        return len(self.anns)

    def __getitem__(self, index):
        """Build the model inputs for the annotation at ``index``.

        Returns:
            ``(ann, model_inputs)`` on success, or ``(ann, None)`` when the
            sample could not be prepared (missing keys, no usable visual
            file, loading or processing failure).  The failure reason is
            printed.

        Raises:
            IndexError: If ``index`` is out of range.  Any other error raised
                by the annotation lookup itself also propagates.
        """
        # Looked up outside the try block: an invalid index must surface as
        # IndexError rather than reach the handler below with `ann` unbound.
        ann = self.anns[index]
        try:
            prompt = ann['text']['prompt']
            video_file = ann['video_file']
            start_time = ann.get("start_time", 0)
            end_time = ann.get("end_time", -1)

            visual_files = []
            if isinstance(video_file, list):
                # This is for MVBench/Episodic Reasoning:
                # video_file is a list of sorted frames extracted from the
                # target video.  Keep only the image entries, then subsample
                # them down to at most processor.max_n_frames frames.
                visual_files = [
                    img_file for img_file in video_file
                    if get_visual_type(img_file) == 'image'
                ]
                frame_indices = sample_frame_indices(
                    start_frame=0,
                    total_frames=len(visual_files),
                    n_frames=min(len(visual_files), self.processor.max_n_frames),
                )
                visual_files = [
                    v for i, v in enumerate(visual_files) if i in frame_indices
                ]
            elif get_visual_type(video_file) in ['image', 'video', 'gif']:
                visual_files.append(video_file)

            # An explicit raise (not `assert`, which is removed under -O) and
            # a strict emptiness check: a sample without any usable visual
            # file must be reported as a failure.
            if not visual_files:
                raise ValueError(
                    f"Failed to load valid visual file from anns[{index}]!"
                )

            images = []
            for v_f in visual_files:
                images.extend(
                    self.processor.load_images(
                        v_f, start_time=start_time, end_time=end_time
                    )
                )
            model_inputs = self.processor(
                prompt, images=images, edit_prompt=True, return_prompt=True
            )
        except Exception as e:
            # Deliberate best-effort: a broken sample is reported and
            # returned with None instead of aborting the caller's loop.
            print(f"Load data error: {e}")
            return ann, None
        return ann, model_inputs