import random
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
from torch import nn
from PIL import Image
from einops import rearrange
from torchvision import transforms
from torchvision.transforms import Normalize, InterpolationMode
from torchvision.transforms.transforms import _interpolation_modes_from_int

from transformers import CLIPTokenizer, CLIPImageProcessor
from transformers.utils import ModelOutput

import craftsman
from craftsman.utils.typing import *

from .clip.modeling_clip import CLIPModel
from .clip.modeling_conditional_clip import ConditionalCLIPModel
from .base import BaseEmbedder, ImageType


@dataclass
class CLIPEmbedOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    embeds: torch.FloatTensor = None


@craftsman.register("clip-embedder")
class CLIPEmbedder(BaseEmbedder):
    """CLIP-based image (and text) embedder, optionally conditioned on per-view camera embeddings."""

    @dataclass
    class Config(BaseEmbedder.Config):
        freeze_modulation: bool = False
        config_path: str = ''

    cfg: Config

    def configure(self) -> None:
        super().configure()

        # Plain CLIP when camera conditioning is disabled; otherwise a camera-conditioned
        # CLIP whose vision tower is modulated by camera embeddings of size camera_embeds_dim.
        if not self.cfg.encode_camera:
            self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path)
        else:
            if self.cfg.pretrained_model_name_or_path == '':
                # No pretrained weights: build the conditional model from a config file.
                assert self.cfg.config_path, "The config path should be provided"
                conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path)
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config)
            else:
                conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
                    self.cfg.pretrained_model_name_or_path,
                )
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel.from_pretrained(
                    self.cfg.pretrained_model_name_or_path,
                    vision_config=conditional_clip_config.vision_config
                )

        self.tokenizer = None
        self.image_preprocess = CLIPImageProcessor()
        # Standard CLIP preprocessing for tensor inputs: bicubic resize, center crop,
        # and normalization with the CLIP image statistics.
        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )

        self.logit_scale = self.model.logit_scale.exp()

        # Unconditional embeddings (e.g. for classifier-free guidance): either all-zero
        # tensors or the actual embeddings of an empty prompt / blank images.
        if self.cfg.zero_uncond_embeds:
            self.empty_text_embeds = torch.zeros((1, 77, 768)).detach()
            self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
        else:
            try:
                self.empty_text_embeds = self.encode_text([""]).detach()
            except Exception:
                self.empty_text_embeds = None
            if self.cfg.encode_camera:
                self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3), self.cameras[:self.cfg.n_views]).detach()
            else:
                self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach()

        # Freeze the CLIP backbone; only the modulation norm layers stay trainable,
        # and even those are frozen when freeze_modulation is set.
        self.model.eval()
        for k, p in self.model.named_parameters():
            ks = k.split('.')
            if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation:
                p.requires_grad_(True)
            else:
                p.requires_grad_(False)

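    # encode_image accepts either a float tensor/array of shape (B * n_views, H, W, 3) with
    # values in [0, 1] (normalized via self.transform) or PIL-style images handled by
    # CLIPImageProcessor; when cfg.encode_camera is set, per-view camera embeddings are
    # passed to the conditional vision tower. Returns the vision last_hidden_state, or a
    # CLIPEmbedOutput when return_dict=True.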
    def encode_image(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> torch.FloatTensor:
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):
            assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]"
            do_rescale = False
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform(images.permute(0, 3, 1, 2))
        else:
            do_rescale = True
            if self.cfg.encode_camera:
                if cameras is None:
                    bs = len(images) // self.cfg.n_views
                    cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device)
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        packed = False
        if pixel_values.ndim == 4:
            packed = True
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
                condition=rearrange(camera_embeds, "B N C -> (B N) C")
            )
        else:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
            )

        if return_dict:
            pooler_output = vision_outputs[1]
            image_features = self.model.visual_projection(pooler_output)

            return CLIPEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=image_features
            )
        else:
            return vision_outputs.last_hidden_state

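    # encode_text accepts either pre-tokenized input ids or a list of strings (tokenized
    # lazily with the CLIP tokenizer) and returns the text last_hidden_state, or a
    # CLIPEmbedOutput when return_dict=True. It always runs under torch.no_grad().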
    @torch.no_grad()
    def encode_text(self, text_inputs: Union[List[str], torch.Tensor], return_dict: bool = False) -> torch.FloatTensor:
        if self.tokenizer is None:
            self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path)

        if isinstance(text_inputs, list):
            text_inputs = self.tokenizer(
                text_inputs,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                return_tensors="pt"
            ).input_ids
        text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device))

        pooler_output = text_outputs[1]
        text_features = self.model.text_projection(pooler_output)

        if return_dict:
            return CLIPEmbedOutput(
                last_hidden_state=text_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=text_features
            )
        else:
            return text_outputs.last_hidden_state
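

# Usage sketch (illustrative only; assumes an already-configured `embedder = CLIPEmbedder(cfg)`
# built through craftsman's registry, with cfg.encode_camera disabled; otherwise pass the
# matching camera tensor as the second argument to encode_image). The output shapes shown
# correspond to the zero unconditional embeddings defined in configure() above:
#
#   images = torch.rand(cfg.n_views, 224, 224, 3)           # float images in [0, 1], NHWC
#   image_embeds = embedder.encode_image(images)             # -> (n_views, 257, 1024)
#   text_embeds = embedder.encode_text(["a wooden chair"])   # -> (1, 77, 768)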