import torch
import numpy as np
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms

from transformers import CLIPTokenizer, CLIPImageProcessor
from transformers.utils import ModelOutput
from typing import Iterable, List, Optional, Union

import craftsman
from craftsman.utils.typing import *
from .clip.modeling_clip import CLIPModel
from .clip.modeling_conditional_clip import ConditionalCLIPModel
from .base import BaseEmbedder, ImageType

@dataclass
class CLIPEmbedOutput(ModelOutput):
    """CLIP encoder outputs: token-level hidden states, pooled state, and projected embeddings."""
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    embeds: Optional[torch.FloatTensor] = None

@craftsman.register("clip-embedder")
class CLIPEmbedder(BaseEmbedder):

    @dataclass
    class Config(BaseEmbedder.Config):
        freeze_modulation: bool = False  # if True, keep the camera-modulation norm layers frozen as well
        config_path: str = ''  # JSON config used to build a conditional CLIP when no pretrained weights are given

    cfg: Config

    def configure(self) -> None:
        super().configure()

        # Load the CLIP model and processor
        if not self.cfg.encode_camera:
            self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path)
        else:
            if self.cfg.pretrained_model_name_or_path == '':
                assert self.cfg.config_path, "The config path should be provided"
                conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path)
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config)
            else:
                conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
                    self.cfg.pretrained_model_name_or_path,
                )
                conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim
                self.model: CLIPModel = ConditionalCLIPModel.from_pretrained(
                    self.cfg.pretrained_model_name_or_path, 
                    vision_config=conditional_clip_config.vision_config
                )
                
        self.tokenizer = None  # lazily instantiated in encode_text
        self.image_preprocess = CLIPImageProcessor()
        # CLIP preprocessing for tensor inputs already scaled to [0, 1]; PIL inputs go through CLIPImageProcessor instead
        self.transform = transforms.Compose(
            [
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
                transforms.CenterCrop(224),  # crop a (224, 224) square
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )

        self.logit_scale = self.model.logit_scale.exp()

        if self.cfg.zero_uncond_embeds:
            # Fixed CLIP ViT-L/14 shapes: 77 text tokens x 768 dims, 257 vision tokens (CLS + 16x16 patches) x 1024 dims
            self.empty_text_embeds = torch.zeros((1, 77, 768)).detach()
            self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach()
        else:
            try:
                self.empty_text_embeds = self.encode_text([""]).detach()  # [1, 77, 768]
            except Exception:
                self.empty_text_embeds = None
            if self.cfg.encode_camera:
                self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3), self.cameras[:self.cfg.n_views]).detach()
            else:
                self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach()

        # Freeze the model; only the camera-modulation norms stay trainable unless freeze_modulation is set
        self.model.eval()
        for k, p in self.model.named_parameters():
            ks = k.split('.')
            if ('mod_norm1' in ks or 'mod_norm2' in ks) and not self.cfg.freeze_modulation:
                p.requires_grad_(True)
            else:
                p.requires_grad_(False)

    def encode_image(
        self,
        images: Iterable[Optional[ImageType]],
        cameras: Optional[torch.Tensor] = None,
        force_none_camera_embeds: bool = False,
        return_dict: bool = False,
        **kwargs,
    ) -> Union[torch.FloatTensor, CLIPEmbedOutput]:
        camera_embeds = None
        if isinstance(images, (np.ndarray, torch.Tensor)):  # training path: [B, H, W, C] tensors in [0, 1]
            assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]"
            do_rescale = False
            if self.cfg.encode_camera:
                assert cameras is not None, "The cameras should be provided"
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.transform(images.permute(0, 3, 1, 2))  # NHWC -> NCHW before resizing/normalizing
        else:  # inference path: PIL images, rescaled from [0, 255] by CLIPImageProcessor
            do_rescale = True
            if self.cfg.encode_camera:
                if cameras is None:  # fall back to the embedder's default cameras, one set per view
                    bs = len(images) // self.cfg.n_views
                    cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device)
                camera_embeds = self.encode_camera(cameras)
            pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values

        if force_none_camera_embeds:
            camera_embeds = None

        # Unpacked batches arrive as [B, C, H, W]; add a singleton view axis so both paths share the (B N) layout
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
            if camera_embeds is not None:
                camera_embeds = camera_embeds.unsqueeze(1)

        if self.cfg.encode_camera and camera_embeds is not None:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"),
                condition=rearrange(camera_embeds, "B N C -> (B N) C")
            )
        else:
            vision_outputs = self.model.vision_model(
                pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"), 
            )

        if return_dict:
            pooler_output = vision_outputs[1]  # pooled_output
            image_features = self.model.visual_projection(pooler_output)

            return CLIPEmbedOutput(
                last_hidden_state=vision_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=image_features
            )
        else:
            return vision_outputs.last_hidden_state
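
    # Shape sketch for reference (assuming CLIP ViT-L/14 at 224x224, matching the zero
    # embeddings in configure): a 4D input [B, H, W, C] or a list of B PIL images yields
    # last_hidden_state of shape [B, 257, 1024]; a packed multi-view input [B, N, C, H, W]
    # yields [(B * N), 257, 1024], with camera conditioning applied per flattened view
    # when encode_camera is enabled.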

    @torch.no_grad()
    def encode_text(self, text_inputs: Union[List[str], torch.Tensor], return_dict: bool = False) -> Union[torch.FloatTensor, CLIPEmbedOutput]:
        if self.tokenizer is None:  # lazily instantiate the tokenizer on first use
            self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path)
            
        if isinstance(text_inputs, list):  # raw strings: tokenize to fixed-length (77-token) input ids
            text_inputs = self.tokenizer(
                text_inputs,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids
        text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device))

        pooler_output = text_outputs[1]  # pooled_output
        text_features = self.model.text_projection(pooler_output)

        if return_dict:
            return CLIPEmbedOutput(
                last_hidden_state=text_outputs.last_hidden_state,
                pooler_output=pooler_output,
                embeds=text_features
            )
        else:
            return text_outputs.last_hidden_state
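
# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). The registry
# lookup, checkpoint name, and config fields below are assumptions inferred from
# how `craftsman.register` and BaseEmbedder.Config are used above; adapt them to
# your checkout before running.
#
#     import torch, craftsman
#
#     embedder = craftsman.find("clip-embedder")(
#         {"pretrained_model_name_or_path": "openai/clip-vit-large-patch14",
#          "encode_camera": False, "zero_uncond_embeds": True, "n_views": 1}
#     )
#     images = torch.rand(1, 224, 224, 3)                    # [B, H, W, C] in [0, 1]
#     img_tokens = embedder.encode_image(images)             # -> [1, 257, 1024]
#     txt_tokens = embedder.encode_text(["a wooden chair"])  # -> [1, 77, 768]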