from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
from deepseek_vl.utils.conversation import get_conv_template


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        system: str = (
            "You are a helpful language and vision assistant. "
            "You are able to understand the visual content that the user provides, "
            "and assist the user with a variety of tasks using natural language."
        ),
        **kwargs,
    ):
        self.system_prompt = system
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to conversation.

        An example of conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def pad_id(self):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add image tokens, and set the mask as False
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."
        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
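
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# Assumptions: a DeepSeek-VL chat checkpoint such as
# "deepseek-ai/deepseek-vl-7b-chat" and the helper
# deepseek_vl.utils.io.load_pil_images from the same repository; adjust
# the model path, image paths, and device to your setup.
#
#   from transformers import AutoModelForCausalLM
#   from deepseek_vl.utils.io import load_pil_images
#
#   model_path = "deepseek-ai/deepseek-vl-7b-chat"
#   processor = VLChatProcessor.from_pretrained(model_path)
#
#   conversation = [
#       {
#           "role": "User",
#           "content": "<image_placeholder>Describe this image.",
#           "images": ["./images/example.png"],
#       },
#       {"role": "Assistant", "content": ""},
#   ]
#
#   # Load the PIL images referenced by the conversation, then build the
#   # left-padded batch (input_ids with expanded image tokens, attention_mask,
#   # pixel_values, images_seq_mask, images_emb_mask) on the target device.
#   pil_images = load_pil_images(conversation)
#   inputs = processor(
#       conversations=conversation, images=pil_images, force_batchify=True
#   ).to("cuda")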