# MEMO: memo/models/image_proj.py
# Uploaded by fffiloni (migrated from GitHub), commit 1a9b87d.
import torch
from diffusers import ConfigMixin, ModelMixin
class ImageProjModel(ModelMixin, ConfigMixin):
    """Project an image embedding into a short sequence of context tokens.

    A single linear layer maps each ``clip_embeddings_dim``-wide embedding to
    ``clip_extra_context_tokens * cross_attention_dim`` features. Those
    features are split into ``clip_extra_context_tokens`` tokens of width
    ``cross_attention_dim``, and each token is layer-normalized.

    NOTE(review): ``__init__`` is not decorated with diffusers'
    ``register_to_config``, so the constructor arguments are presumably not
    recorded in ``self.config``. Confirm this is intended if the model is
    saved with ``save_pretrained``.
    """

    def __init__(
        self,
        cross_attention_dim: int = 768,
        clip_embeddings_dim: int = 512,
        clip_extra_context_tokens: int = 4,
    ):
        """Build the projection and normalization layers.

        Args:
            cross_attention_dim: Width of each output token (also the size
                normalized by the LayerNorm).
            clip_embeddings_dim: Width of the incoming embedding, i.e. the
                last dimension ``forward`` expects.
            clip_extra_context_tokens: Number of tokens produced per
                input embedding.
        """
        super().__init__()

        # Not read anywhere in this class. Kept for compatibility in case
        # external code accesses ``model.generator``.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # One linear layer emits every token's features at once;
        # forward() splits them into separate tokens.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, clip_extra_context_tokens * cross_attention_dim
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Turn image embeddings into normalized context tokens.

        Args:
            image_embeds: Tensor whose last dimension is
                ``clip_embeddings_dim``. Presumably shaped
                ``(batch, clip_embeddings_dim)``; confirm against callers.

        Returns:
            Tensor of shape
            ``(-1, clip_extra_context_tokens, cross_attention_dim)``.
            Any leading dimensions of the input are flattened into the
            first axis by ``reshape(-1, ...)``.
        """
        tokens = self.proj(image_embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        # LayerNorm(cross_attention_dim) normalizes each token independently.
        return self.norm(tokens)