import random

import torch
from torch import nn
import numpy as np
from PIL import Image
from einops import rearrange
from dataclasses import dataclass
from torchvision.transforms import Normalize
from torchvision.transforms import InterpolationMode
from torchvision.transforms.transforms import _interpolation_modes_from_int
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPImageProcessor
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import craftsman
from craftsman.utils.typing import *

from .clip.modeling_clip import CLIPModel
from .clip.modeling_conditional_clip import ConditionalCLIPModel
from .base import BaseEmbedder, ImageType


@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


@craftsman.register("clip-embedder")
class CLIPEmbedder(BaseEmbedder):

    @dataclass
    class Config(BaseEmbedder.Config):
        freeze_modulation: bool = False
        config_path: str = ''

    cfg: Config

    def configure(self) -> None:
        super().configure()

        # Load the CLIP model and processor
        if not self.cfg.encode_camera:
            self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path)
        else:
            if self.cfg.pretrained_model_name_or_path == '':
                # Build a camera-conditioned CLIP from a local config file
                assert self.cfg.config_path, "The config path should be provided"
                conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path)
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config)
            else:
                # Build a camera-conditioned CLIP on top of a pretrained checkpoint
                conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
                    self.cfg.pretrained_model_name_or_path,
                )
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel.from_pretrained(
                    self.cfg.pretrained_model_name_or_path,
                    vision_config=conditional_clip_config.vision_config
                )

        self.tokenizer = None
        self.image_preprocess = CLIPImageProcessor()
        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.logit_scale = self.model.logit_scale.exp()

        # Precompute the unconditional (empty) embeddings used for classifier-free guidance
        if self.cfg.zero_uncond_embeds:
            self.empty_text_embeds = torch.zeros((1, 77, 768)).detach()
            self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
        else:
            try:
                self.empty_text_embeds = self.encode_text([""]).detach()  # [1, 77, 768]
            except Exception:
                self.empty_text_embeds = None

            if self.cfg.encode_camera:
                self.empty_image_embeds = self.encode_image(
                    torch.zeros(self.cfg.n_views, 224, 224, 3),
                    self.cameras[:self.cfg.n_views]
                ).detach()
            else:
                self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach()

        # Freeze the model parameters; only the camera-modulation layer norms
        # stay trainable, and only when freeze_modulation is disabled.
        self.model.eval()
        for k, p in self.model.named_parameters():
            ks = k.split('.')
            if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation:
                p.requires_grad_(True)
            else:
                p.requires_grad_(False)

    def encode_image(self,
                     images: Iterable[Optional[ImageType]],
                     cameras: Optional[torch.Tensor] = None,
                     force_none_camera_embeds: bool = False,
                     return_dict: bool = False,
                     **kwargs) -> torch.FloatTensor:
        camera_embeds = None

        if isinstance(images, (np.ndarray, torch.Tensor)):  # for training process
            assert images.min() >= 0.0 and images.max() <= 1.0, \
                "The pixel values should be in the range of [0, 1]"
            do_rescale = False
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform(images.permute(0, 3, 1, 2))
        else:  # for inference process
            do_rescale = True
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device)
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        # Add a view dimension if the inputs come in as plain (B, C, H, W) batches
        packed = False
        if pixel_values.ndim == 4:
            packed = True
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
                condition=rearrange(camera_embeds, "B N C -> (B N) C")
            )
        else:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
            )

        if return_dict:
            pooler_output = vision_outputs[1]  # pooled_output
            image_features = self.model.visual_projection(pooler_output)

            return CLIPEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=image_features
            )
        else:
            return vision_outputs.last_hidden_state

    @torch.no_grad()
    def encode_text(self, text_inputs: Union[List[str], torch.Tensor], return_dict: bool = False) -> torch.FloatTensor:
        if self.tokenizer is None:
            self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path)

        if isinstance(text_inputs, list):
            text_inputs = self.tokenizer(
                text_inputs,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                return_tensors="pt"
            ).input_ids

        text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device))

        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)

        if return_dict:
            return CLIPEmbedOutput(
                last_hidden_state=text_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=text_features
            )
        else:
            return text_outputs.last_hidden_state