# Copyright (2024) Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from PIL import Image from typing import List import torch from transformers import DataCollatorForSeq2Seq from transformers.models.llava import LlavaProcessor import re import os from .utils import sample_image, sample_video, sample_gif, get_visual_type HF_TOKEN = os.environ.get('HF_TOKEN', '') ext2sampler = { 'image': sample_image, 'gif': sample_gif, 'video': sample_video } class CustomImageProcessor: def __init__(self, processor) -> None: self.processor = processor def __call__(self, images: List[Image.Image], do_padding=False) -> torch.Tensor: if do_padding: images = [self.expand2square( img, tuple(int(x * 255) for x in self.processor.image_processor.image_mean) ) for img in images] else: images = [self.resize2square(img) for img in images] images_pixel = self.processor(text="", images=images, return_tensors="pt")['pixel_values'] return images_pixel # [num_images, 3, 336, 336] def expand2square(self, pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result def resize2square(self, pil_img: Image.Image): width, height = pil_img.size pil_img = pil_img.resize((max(width, height), max(width, height))) return pil_img class Processor(object): def __init__( self, model_name_or_path, max_n_frames=8, max_seq_len=None, add_sep=False, do_image_padding=False, ): self.max_n_frames = max_n_frames self.max_seq_len = max_seq_len, self.add_sep = add_sep self.do_image_padding = do_image_padding if not self.do_image_padding: print(f"### do_image_padding is set as False, images will be resized directly!") self.setup(model_name_or_path) def setup(self, model_name_or_path): sub_processor = LlavaProcessor.from_pretrained( model_name_or_path, padding_side='left', trust_remote_code=True, token=HF_TOKEN, ) self.processor = CustomImageProcessor(sub_processor) self.tokenizer = sub_processor.tokenizer # self.pad_collator = DataCollatorForSeq2Seq(self.tokenizer, padding='longest') self.sep_id = self.tokenizer.sep_token_id self.pad_id = self.tokenizer.pad_token_id self.eos_id = self.tokenizer.eos_token_id if self.sep_id is None: self.add_sep = False if not self.max_seq_len: self.max_seq_len = self.tokenizer.model_max_length def process_prompt(self, prompt, images: List[Image.Image]=None): if not images: prompt = prompt.replace("", "").replace("