Spaces:
Sleeping
Sleeping
import torch | |
from diffusers import ConfigMixin, ModelMixin | |
class ImageProjModel(ModelMixin, ConfigMixin):
    """Project an image embedding into a fixed number of extra context tokens.

    A single linear layer maps each input vector of size ``clip_embeddings_dim``
    to ``clip_extra_context_tokens * cross_attention_dim`` features. These are
    reshaped into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalised over the token width.

    Args:
        cross_attention_dim: Width of each produced token.
        clip_embeddings_dim: Size of the last dimension of the input embedding.
        clip_extra_context_tokens: Number of tokens produced per input vector.
    """

    def __init__(
        self,
        cross_attention_dim=768,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ):
        super().__init__()
        # ConfigMixin only records constructor arguments that are explicitly
        # registered (diffusers models normally use the @register_to_config
        # decorator for this). Without registration the saved config carries
        # none of these sizes, so reloading through from_pretrained would fall
        # back to the defaults and mismatch non-default weights.
        self.register_to_config(
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )
        # Not read anywhere in this class; kept so external code that accesses
        # `.generator` keeps working.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim,
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Turn image embeddings into normalised context tokens.

        Args:
            image_embeds: Tensor whose last dimension is ``clip_embeddings_dim``
                (presumably ``(batch, clip_embeddings_dim)`` -- TODO confirm
                against callers).

        Returns:
            Tensor of shape ``(N, clip_extra_context_tokens, cross_attention_dim)``.
            ``N`` is the product of the input's leading dimensions, because the
            reshape below uses ``-1`` for the first axis.
        """
        tokens = self.proj(image_embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)