File size: 2,757 Bytes
97a05c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataset.utils import get_visual_type, sample_frame_indices
from .processor import Processor
from tools.rw_utils import read_jsonlines

class MMDataset(object):
    """Index-addressable dataset that turns annotations into model inputs.

    Each annotation is expected to be a mapping providing
    ``ann['text']['prompt']`` and ``ann['video_file']`` (either a single path
    or a list of frame paths), and optionally ``start_time`` / ``end_time``.
    """

    def __init__(self, ann_path="", anns=None, processor: "Processor" = None):
        """Store the processor and load the annotations.

        Args:
            ann_path: A path string, or an iterable of paths. Every path is
                read with ``read_jsonlines`` and the results are concatenated.
                Ignored when ``anns`` is given.
            anns: Pre-loaded annotations. Used as-is when not ``None``.
            processor: Object used by ``__getitem__``; it is called to build
                the model inputs and must expose ``load_images`` and
                ``max_n_frames``.
        """
        self.processor = processor
        if anns is not None:
            self.anns = anns
            return

        paths = [ann_path] if isinstance(ann_path, str) else ann_path
        self.anns = []
        for path in paths:
            self.anns.extend(read_jsonlines(path))

    def __len__(self):
        """Return the number of annotations."""
        return len(self.anns)

    def __getitem__(self, index):
        """Return ``(ann, model_inputs)`` for the annotation at ``index``.

        Any failure while preparing the sample is printed and reported as
        ``(ann, None)`` so one broken sample does not abort a whole run.

        Raises:
            IndexError: Propagated from ``self.anns`` for an invalid index.
        """
        # The lookup stays outside the ``try``: a bad index must surface as
        # the container's own error (sequence-style iteration relies on
        # IndexError to stop) rather than reach the handler below with
        # ``ann`` still unbound.
        ann = self.anns[index]
        try:
            model_inputs = self._build_model_inputs(ann, index)
        except Exception as e:
            print(f"Load data error: {e}")
            return ann, None
        return ann, model_inputs

    def _build_model_inputs(self, ann, index):
        """Load the visual inputs of ``ann`` and run the processor on them."""
        prompt = ann['text']['prompt']
        start_time = ann.get("start_time", 0)
        end_time = ann.get("end_time", -1)

        visual_files = self._select_visual_files(ann['video_file'])
        if not visual_files:
            # Explicit raise instead of ``assert`` so the check also runs
            # under ``python -O``.
            raise ValueError(f"Failed to load valid visual file from anns[{index}]!")

        images = []
        for visual_file in visual_files:
            images.extend(
                self.processor.load_images(visual_file, start_time=start_time, end_time=end_time)
            )
        return self.processor(prompt, images=images, edit_prompt=True, return_prompt=True)

    def _select_visual_files(self, video_file):
        """Return the visual file paths to load for one ``video_file`` entry."""
        if not isinstance(video_file, list):
            if get_visual_type(video_file) in ['image', 'video', 'gif']:
                return [video_file]
            return []

        # This is for MVBench/Episodic Reasoning:
        # ``video_file`` is a list of sorted frames extracted from the target
        # video; only entries recognised as images are kept.
        frames = [img_file for img_file in video_file if get_visual_type(img_file) == 'image']
        if not frames:
            return []
        frame_indices = sample_frame_indices(
            start_frame=0,
            total_frames=len(frames),
            n_frames=min(len(frames), self.processor.max_n_frames),
        )
        # Keep the sampled frames in their original order.
        return [frame for i, frame in enumerate(frames) if i in frame_indices]