from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaModel,
    LlamaForCausalLM,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from PIL import Image
from abc import ABC, abstractmethod
import os
import math
import re  # needed by build_vision_projector (mlpNx_gelu parsing); was missing
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from functools import partial
from transformers.configuration_utils import PretrainedConfig
from timm.models.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from torch.nn import functional as F
import math
from einops import rearrange


CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100          # label value ignored by CrossEntropyLoss
IMAGE_TOKEN_INDEX = -200     # placeholder id marking where image features are spliced in
# NOTE(review): the four token strings below are empty in this copy of the file
# and are not referenced anywhere in it. If they are meant to be special-token
# strings, their text may have been lost -- confirm.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""


class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder returning the hidden states of one layer.

    Args:
        vision_tower: model name or path passed to ``from_pretrained``.
        args: config object providing ``mm_vision_select_layer`` and,
            optionally, ``mm_vision_select_feature`` ('patch' or 'cls_patch').
        delay_load: when True only the ``CLIPVisionConfig`` is loaded here;
            ``load_model`` must be called before running ``forward``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Load the image processor and encoder weights, and freeze the encoder."""
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Take the ``select_layer`` hidden state; 'patch' drops the first token."""
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode images without gradients.

        A list is encoded one image at a time (each gets a batch dim of 1) and a
        list of features is returned; a tensor is encoded as one batch. Outputs
        are cast back to the input dtype.
        """
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before load_model() only the config fetched in __init__ is available.
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build a CLIPVisionTower from ``mm_vision_tower`` (or ``vision_tower``).

    Accepts an existing local path or a name starting with "openai"/"laion";
    anything else raises ValueError.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')


class HoneybeeVisualProjectorConfig(PretrainedConfig):
    """Configuration for the Honeybee-style visual projector (abstractor).

    Keys not listed in ``__init__`` (``CAbstractor`` reads ``depth``,
    ``mlp_depth`` and ``num_queries``) are presumably supplied through
    ``**kwargs``, which ``PretrainedConfig`` stores as attributes.
    """

    model_type = "mllm_visual_projector"

    def __init__(
        self,
        projector_type: str = "resampler",
        hidden_size: int = 1024,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 16,
        intermediate_size: int = 4096,
        attention_probs_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-6,
        encoder_hidden_size: int = 1024,  # This will be overwritten by vision_model's hidden_size
        pos_emb=False,
        feature_layer_index=-1,  # vision feature layer index; -1: last layer
        num_eos_tokens=1,
        use_cls=True,
        prenorm=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.projector_type = projector_type
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.encoder_hidden_size = encoder_hidden_size

        self.pos_emb = pos_emb
        self.feature_layer_index = feature_layer_index
        self.num_eos_tokens = num_eos_tokens
        self.use_cls = use_cls
        self.prenorm = prenorm

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the config, unwrapping ``visual_projector_config`` from a QH_360VL config."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the visual_projector config dict if we are loading from HoneybeeConfig
        if config_dict.get("model_type") == "QH_360VL":
            config_dict = config_dict["visual_projector_config"]

        return cls.from_dict(config_dict, **kwargs)


def build_pos_embeds(
    config: HoneybeeVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
):
    """Return a (1, num_input_tokens, vision_hidden_size) learnable position
    embedding (truncated-normal init, std 0.02), or None if ``config.pos_emb``
    is falsy."""
    # Original author's note: pos_emb is true for this model.
    if config.pos_emb:
        pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
        nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
    else:
        pos_emb = None

    return pos_emb


def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: int):
    """Return ``config.num_eos_tokens`` learnable tokens appended after the
    projected features, or None when that count is 0."""
    # think tokens
    num_eos_tokens = config.num_eos_tokens  # original author's note: 0 for this model
    if num_eos_tokens:
        eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
        nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
    else:
        eos_tokens = None

    return eos_tokens


def build_prenorm(config: HoneybeeVisualProjectorConfig):
    """Return a LayerNorm over ``encoder_hidden_size`` when ``config.prenorm`` is set, else None."""
    # Original author's note: prenorm is false for this model.
    if config.prenorm:
        prenorm = LayerNorm(config.encoder_hidden_size)
    else:
        prenorm = None
    return prenorm


def build_mlp(depth, hidden_size, output_hidden_size):
    """Linear(hidden->out) followed by (depth - 1) x [SiLU, Linear(out->out)]."""
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        layers.append(nn.SiLU())
        layers.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*layers)


def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of position embeddings to ``tgt_size`` positions.

    Args:
        abs_pos: (1, L, C) tensor whose L positions form a sqrt(L) x sqrt(L) grid.
        tgt_size: target number of positions M (also treated as a square grid).

    Returns:
        (M, C) bicubically interpolated embeddings when the grid sizes differ;
        otherwise ``abs_pos`` unchanged, still (1, L, C). Both broadcast
        against a (B, M, C) tensor.
    """
    src_size = int(math.sqrt(abs_pos.size(1)))  # e.g. 16 or 24
    tgt_size = int(math.sqrt(tgt_size))          # e.g. 32 or 48
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        # Interpolate in float32, then cast back to the original dtype.
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos


class Projector(nn.Module):
    """Base projector: optional prenorm + position embedding, a subclass network, then optional eos tokens."""

    def __init__(
        self,
        config: HoneybeeVisualProjectorConfig,
        num_input_tokens: int,
        output_hidden_size: int,
    ):
        super().__init__()
        self.config = config
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # think tokens
        self.eos_tokens = build_eos_tokens(config, output_hidden_size)

        # pos emb
        self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)

        self.prenorm = build_prenorm(config)

        self.build_net()

    def build_net(self):
        raise NotImplementedError()

    def _forward(self, x):
        raise NotImplementedError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, encoder_hidden_size) features from the visual backbone.
                The first position-embedding slot is skipped (``pos_emb[:, 1:]``),
                so L is presumably the patch count without a cls token.
                NOTE(review): an earlier docstring said "including cls token" -- confirm.

        Returns:
            (B, L', output_hidden_size), with eos tokens appended when configured.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            pos_emb = get_abs_pos(self.pos_emb[:, 1:], x.size(1))
            # Out-of-place add: the old `x += pos_emb` mutated the caller's tensor
            # (a view into the crop features in the sliding-window path). Casting to
            # x.dtype keeps the result dtype identical to the in-place version.
            x = x + pos_emb.to(device=x.device, dtype=x.dtype)

        x = self._forward(x)  # (B, L, output_hidden_size)

        B = x.size(0)
        if self.eos_tokens is not None:
            x = torch.cat([x, self.eos_tokens.expand(B, -1, -1)], dim=1)
        return x


class ConvProjector(Projector):
    """Projector that runs ``self.net`` on a 2-D grid, then ``self.readout`` per token."""

    def _forward(self, x):
        # x: [B, L, dim]; L must be a perfect square for the 2-D reshape below.
        hw = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
        x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)

        return x


class CAbstractor(ConvProjector):
    """C-Abstractor: RegStage -> adaptive pool to a sqrt(num_queries) grid -> RegStage -> MLP."""

    def build_net(self):
        encoder_hidden_size = self.config.encoder_hidden_size
        hidden_size = self.config.hidden_size
        output_hidden_size = self.output_hidden_size
        depth = self.config.depth
        mlp_depth = self.config.mlp_depth

        n_queries = self.config.num_queries
        assert (n_queries ** 0.5).is_integer(), "n_queries must be square number"
        hw = int(n_queries ** 0.5)

        # RegBlock = ResBlock + SE
        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )

        self.net = nn.Sequential(s1, sampler, s2)
        self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)


class IdentityMap(nn.Module):
    """Pass-through projector."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_honeybee_projector(config, projector_type, num_tokens, lm_hidden_size):
    """Build the abstractor for ``projector_type`` (only "c-abs" is registered;
    other values raise KeyError). Output width is the LM hidden size."""
    abstractor_cls = {
        "c-abs": CAbstractor,
    }[projector_type]
    return abstractor_cls(config, num_tokens, lm_hidden_size)


def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the vision->LM projector named by ``config.mm_projector_type``.

    Supported: 'linear', 'c-abs' (config loaded from ``mm_projector_config``),
    'mlp<N>x_gelu' and 'identity'. Anything else raises ValueError.
    ``delay_load`` and ``kwargs`` are accepted but unused.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == 'c-abs':
        local_config_path = config.mm_projector_config
        honeybee_config = HoneybeeVisualProjectorConfig.from_pretrained(local_config_path)

        num_tokens = config.mm_num_tokens

        lm_hidden_size = config.hidden_size

        abstractor = build_honeybee_projector(honeybee_config, projector_type, num_tokens, lm_hidden_size)
        return abstractor

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')


class QH360_VL_MetaModel:
    """Mixin adding a (delay-loaded) vision tower and two projectors to a base model.

    ``mm_projector_ctt`` is applied to the stitched crop features and
    ``mm_projector_ori`` to the full-image features (see
    ``QH360_VL_MetaForCausalLM._encode_sliding_window_image``).
    """

    def __init__(self, config):
        super(QH360_VL_MetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector_ctt = build_vision_projector(config)
            self.mm_projector_ori = build_vision_projector(config)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower


class QH360_VL_MetaForCausalLM(ABC):
    """Mixin that splices projected image features into the token embeddings."""

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run the vision tower on ``images`` and project the features.

        The model built by ``QH360_VL_MetaModel`` has ``mm_projector_ctt`` and
        ``mm_projector_ori`` but no ``mm_projector`` (the old lookup raised
        AttributeError). A single, uncropped view is projected with
        ``mm_projector_ori``, the projector the crop path applies to the
        full-image view; an ``mm_projector`` attribute, if present, still wins.
        """
        model = self.get_model()
        image_features = model.get_vision_tower()(images)
        projector = getattr(model, 'mm_projector', None)
        if projector is None:
            projector = model.mm_projector_ori
        return projector(image_features)

    def encode_images_noprojector(self, images):
        """Vision-tower features only, detached from the autograd graph."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = image_features.detach()
        return image_features

    def _encode_sliding_window_image(self, image):
        """Encode one group of 4 crops + 1 full view (the "cc-plan").

        Assumes ``image`` stacks views in the order produced by
        ``process_images_slid_window``: top-left, top-right, bottom-left and
        bottom-right crops, then the full image, repeated per group of 5.

        Returns the ``mm_projector_ctt`` output for the stitched crops followed
        by the ``mm_projector_ori`` output for the full view, concatenated
        along the token axis (batch dim squeezed).
        """
        model = self.get_model()
        temp_feats = self.encode_images_noprojector(image)
        # Each view is assumed to yield a square src_size x src_size token grid.
        src_size = int(math.sqrt(temp_feats.shape[1]))
        temp_feats = temp_feats.reshape(temp_feats.shape[0] // 5, 5, -1, temp_feats.shape[-1])
        x1 = temp_feats[:, 4, :, :]  # full-image view
        x = temp_feats[:, :4, :, :]  # the four crops
        # Stitch the 2x2 crop grids into one (2*src_size) x (2*src_size) grid,
        # flattened row-major: (b, crop, r, c) -> (b, crop_row, r, crop_col, c).
        x = x.reshape(x.shape[0], -1, src_size, src_size, x.shape[-1])
        x = x.transpose(1, 2).reshape(x.shape[0], src_size, 2, 2, src_size, x.shape[-1])
        x = x.transpose(1, 2).reshape(x.shape[0], -1, x.shape[-1])
        x1 = model.mm_projector_ori(x1).squeeze(0)
        x = model.mm_projector_ctt(x).squeeze(0)
        return torch.cat([x, x1], dim=0)

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Replace IMAGE_TOKEN_INDEX placeholders with image features.

        Returns ``(None, attention_mask, past_key_values, inputs_embeds, labels)``
        where image positions get IGNORE_INDEX labels and sequences of different
        length are right-padded (zeros / IGNORE_INDEX / mask False). When there
        is nothing to splice (no vision tower, no images, or a single-token
        decoding step) the inputs are returned with ``inputs_embeds=None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (past_key_values is not None and vision_tower is not None and images is not None
                    and input_ids.shape[1] == 1 and attention_mask is not None):
                # Decoding step: the mask must cover every cached position plus the new token.
                attention_mask = torch.ones(
                    (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            image_features = []
            for image in images:
                if image.ndim == 3:
                    image_features.append(self.encode_images(image.unsqueeze(0)).squeeze(0))
                elif image.ndim == 4:
                    image_features.append(self._encode_sliding_window_image(image))
                else:
                    # Silently skipping would misalign images with their placeholders.
                    raise ValueError(f'Unexpected image tensor rank: {image.ndim}')
        else:
            image_features = self.encode_images(images)

        use_im_start_end = (getattr(self.config, 'tune_mm_mlp_adapter', False)
                            and getattr(self.config, 'mm_use_im_start_end', False))

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                # (a zero-length slice of image features keeps the projector in the graph)
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if use_im_start_end:
                    # Text embeddings are detached except the tokens right around the image.
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if use_im_start_end:
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                if use_im_start_end:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Unpadded per-sample lengths; the attention-mask padding is derived
            # from these so it no longer requires labels (previously a NameError
            # on `_new_labels` / an unpadded mask when labels was None).
            new_lengths = [x.shape[0] for x in new_input_embeds]
            max_len = max(new_lengths)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_len in zip(attention_mask, new_lengths):
                    # Left: positions added by image expansion; right: batch padding.
                    new_attn_mask_pad_left = torch.full((cur_len - input_ids.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((max_len - cur_len,), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels


class QH360_VLConfig(LlamaConfig):
    model_type = "QH_360VL"


class QH360_VL_LlamaModel(QH360_VL_MetaModel, LlamaModel):
    config_class = QH360_VLConfig

    def __init__(self, config: LlamaConfig):
        super(QH360_VL_LlamaModel, self).__init__(config)


class QH360_VL_LlamaForCausalLM(LlamaForCausalLM, QH360_VL_MetaForCausalLM):
    """Llama causal LM whose inputs may contain IMAGE_TOKEN_INDEX placeholders."""

    config_class = QH360_VLConfig

    def __init__(self, config):
        # Skip LlamaForCausalLM.__init__ so the multimodal model replaces LlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        # NOTE: a bare `config._attn_implementation == "flash_attention_2"`
        # comparison used to sit here; it had no effect and was removed.
        self.model = QH360_VL_LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Splice image features into the inputs, run the LM, and compute the
        shifted cross-entropy loss when ``labels`` are given."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Only build multimodal embeddings when the caller did not supply them;
        # previously caller-provided inputs_embeds were overwritten with None.
        if inputs_embeds is None:
            input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Feed only the last token once a cache exists and forward ``images``."""
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def build_conversation_input_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        image=None,
        image_processor=None,
    ):
        """Tokenize a single-turn chat with one image and preprocess the image.

        The image slot is written as ``<|reserved_special_token_44|>`` in the
        prompt, and its token id is then replaced by IMAGE_TOKEN_INDEX. Returns
        ``{'input_ids': (1, T) tensor, 'image': (1, 5, C, H, W) tensor}``.
        """
        input_msg = [
            {
                "role": "system",
                "content": "You are a multilingual, helpful, respectful and honest assistant who can respond in the same language, depending on the language of the question. Try to be as helpful as possible while still being safe. Your answer should not contain anything that is false, unhealthy, harmful, immoral, racist, sexist, toxic, dangerous, or illegal, and if the question relates to such content, please decline to answer. Make sure your answer is socially fair and positive. If a question doesn't make any sense, or is inconsistent with the facts, explain why instead of answering the wrong answer. If you don't know the answer to a question, don't share false information."
            },
            {
                "role": "user",
                "content": "<|reserved_special_token_44|>" + '\n' + query
            }
        ]

        input_ids = tokenizer.apply_chat_template(
            input_msg,
            add_generation_prompt=True,
            padding="longest",
            return_tensors="pt",
        )
        # Presumably the id of <|reserved_special_token_44|> in the Llama-3
        # tokenizer -- confirm if the tokenizer changes (.index raises if absent).
        image_placeholder_id = 128049
        input_id_list = input_ids[0].tolist()
        input_id_list[input_id_list.index(image_placeholder_id)] = IMAGE_TOKEN_INDEX
        input_ids = torch.tensor(input_id_list, dtype=input_ids.dtype, device=input_ids.device)
        input_ids = input_ids.unsqueeze(0)
        image_tensor = self.process_images_slid_window(image, image_processor).unsqueeze(0)

        return {
            'input_ids': input_ids,
            'image': image_tensor,
        }

    def process_images_slid_window(self, image, image_processor, vit_is=336):
        """Turn a PIL image into 5 views: 4 corner crops of a (2*vit_is)-square
        resize plus the processor's default-size view of the same image.

        The image is first padded to a square with the processor's mean colour.
        Returns a (5, C, H, W) tensor ordered top-left, top-right, bottom-left,
        bottom-right crop, then the full view.
        """

        def get_proper_imgsize(pil_img, vit_is):
            max_w_h = vit_is * 2
            new_pil_img = pil_img.resize((max_w_h, max_w_h))
            return new_pil_img

        def tensor_crop(tensor_array, left, upper, right, lower):
            # tensor_array: C * H * W
            return tensor_array[:, upper:lower, left:right]

        def image_slid_window(image, num_slid_window):
            # image: [large view (e.g. 3 * 672 * 672), small view (e.g. 3 * 336 * 336)]
            if num_slid_window == 5:
                image_x2, image_x1 = image[0], image[1]
                vit_is = image_x1.shape[1]
                h, w = image_x2.shape[1], image_x2.shape[2]
                image0 = tensor_crop(image_x2, 0, 0, vit_is, vit_is)
                image1 = tensor_crop(image_x2, w - vit_is, 0, w, vit_is)
                image2 = tensor_crop(image_x2, 0, h - vit_is, vit_is, h)
                image3 = tensor_crop(image_x2, w - vit_is, h - vit_is, w, h)
                return torch.stack([image0, image1, image2, image3, image_x1])
            else:
                return image

        def expand2square(pil_img, background_color):
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        num_slid_window = 5
        image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
        image = get_proper_imgsize(image, vit_is)
        # Large view keeps the (2*vit_is) size; the small view uses the processor's default resize/crop.
        image_x2 = image_processor.preprocess(image, return_tensors='pt', do_resize=False, do_center_crop=False)['pixel_values'][0]
        image_x1 = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image = [image_x2, image_x1]
        image = image_slid_window(image, num_slid_window)
        return image