File size: 2,571 Bytes
bd6c4af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
image_proj_model.py

This module defines the ImageProjModel class, which is responsible for
projecting image embeddings into a different dimensional space. The model 
leverages a linear transformation followed by a layer normalization to 
reshape and normalize the input image embeddings for further processing in 
cross-attention mechanisms or other downstream tasks.

Classes:
    ImageProjModel

Dependencies:
    torch
    diffusers.ModelMixin

"""

import torch
from diffusers import ModelMixin


class ImageProjModel(ModelMixin):
    """
    Project image embeddings into a short sequence of normalized context tokens.

    A linear layer maps each ``clip_embeddings_dim``-sized embedding vector to
    ``clip_extra_context_tokens * cross_attention_dim`` features. Those features
    are reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim``, and each token is layer-normalized.

    Args:
        cross_attention_dim (int): Width of each output token. Defaults to 1024.
        clip_embeddings_dim (int): Size of the last dimension of the input
            embeddings. Defaults to 1024.
        clip_extra_context_tokens (int): Number of output tokens produced per
            input embedding vector. Defaults to 4.

    Attributes:
        cross_attention_dim (int): Width of each output token.
        clip_embeddings_dim (int): Expected size of the input's last dimension.
        clip_extra_context_tokens (int): Number of tokens produced per
            embedding vector.
        proj (torch.nn.Linear): Linear projection from ``clip_embeddings_dim``
            to ``clip_extra_context_tokens * cross_attention_dim`` features.
        norm (torch.nn.LayerNorm): Layer normalization over the last
            (``cross_attention_dim``) axis of each token.
        generator: Initialized to ``None``; not used by this class.

    Methods:
        forward(image_embeds): Project embeddings into normalized tokens of
            shape ``(N, clip_extra_context_tokens, cross_attention_dim)``.
    """

    def __init__(
        self,
        cross_attention_dim=1024,
        clip_embeddings_dim=1024,
        clip_extra_context_tokens=4,
    ):
        super().__init__()

        # Not read anywhere in this class. It is kept so that any external
        # code accessing ``model.generator`` keeps working.
        # NOTE(review): confirm no caller relies on it before removing.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        # Stored so the documented attribute actually exists on the instance.
        self.clip_embeddings_dim = clip_embeddings_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """
        Project image embeddings into normalized context tokens.

        Args:
            image_embeds (torch.Tensor): Input embeddings whose last dimension
                is ``clip_embeddings_dim``, e.g. ``(batch_size,
                clip_embeddings_dim)``.

        Returns:
            torch.Tensor: Tensor of shape
            ``(N, clip_extra_context_tokens, cross_attention_dim)``.
            ``N`` is the product of all leading (non-last) dimensions of
            ``image_embeds``:

            - For a 2-D input, ``N`` is ``batch_size``.
            - For a 3-D input ``(batch_size, num_tokens, clip_embeddings_dim)``,
              the token axis is folded into the first axis, giving
              ``N = batch_size * num_tokens``.
            - For a 1-D input, ``N`` is 1.
        """
        # (..., clip_embeddings_dim) -> (..., K * C), where
        # K = clip_extra_context_tokens and C = cross_attention_dim.
        tokens = self.proj(image_embeds)
        # Collapse all leading dims into one axis and split the features into
        # K tokens of width C.
        tokens = tokens.reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        # Normalize each token independently over its C features.
        return self.norm(tokens)