import math

import torch
import torch.nn as nn
from torch.nn import functional as F

class PrefixEncoder(torch.nn.Module):
    """Holds the frozen per-layer prefix keys/values used for prefix tuning."""

    def __init__(self, config):
        super(PrefixEncoder, self).__init__()
        self.config = config
        self.device = config.device
        self.dtype = config.dtype
        self.num_virtual_tokens = config.num_virtual_tokens
        self.token_dim = config.token_dim
        self.encoder_hidden_size = config.encoder_hidden_size
        self.num_layers = config.num_layers
        # Reparameterisation MLP from the original, left disabled: the prefix is stored
        # directly as a parameter instead of being produced by this projection.
        # self.transformer = torch.nn.Sequential(
        #     torch.nn.Linear(self.token_dim, self.encoder_hidden_size, device=self.device, dtype=self.dtype),
        #     torch.nn.Tanh(),
        #     torch.nn.Linear(self.encoder_hidden_size, self.num_layers * 2 * self.token_dim, device=self.device, dtype=self.dtype),
        # )
        # One key and one value vector per virtual token and per layer, loaded from the
        # checkpoint and kept frozen (requires_grad=False).
        self.prefix_embedding = nn.Parameter(
            torch.zeros(1, self.num_virtual_tokens, self.token_dim * 2 * self.num_layers,
                        device=self.device, dtype=self.dtype),
            requires_grad=False,
        )

    def forward(self, batch_size):
        # Disabled variant from the original: build the prefix from an embedding table
        # plus the reparameterisation MLP, then register the result as a parameter.
        # input_ids = input_ids.unsqueeze(0).expand(batch_size, self.num_virtual_tokens)
        # prefix_embedding = self.embedding(input_ids)
        # prefix_embedding = self.transformer(prefix_embedding)
        # self.register_parameter("prefix_embedding", nn.Parameter(prefix_embedding, requires_grad=False))
        prefix_embedding = self.prefix_embedding.expand(
            batch_size, self.num_virtual_tokens, self.token_dim * 2 * self.num_layers)
        prefix_embedding = prefix_embedding.reshape(
            batch_size, self.num_virtual_tokens, self.num_layers, 2, self.token_dim)
        # (2, num_layers, batch, num_virtual_tokens, token_dim)
        prefix_embedding = prefix_embedding.permute(3, 2, 0, 1, 4)
        k, v = prefix_embedding.chunk(2, dim=0)
        # Each returned tensor has shape (num_layers, batch, num_virtual_tokens, token_dim).
        return (k.squeeze(0), v.squeeze(0))
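
# Illustrative shape check (an addition, not part of the original script): with a tiny
# CPU config, forward(batch_size) should return per-layer prefix keys and values of
# shape (num_layers, batch_size, num_virtual_tokens, token_dim).
from types import SimpleNamespace

_pe_cfg = SimpleNamespace(device=torch.device("cpu"), dtype=torch.float32,
                          num_virtual_tokens=4, token_dim=8, encoder_hidden_size=8,
                          num_layers=2)
_pe_k, _pe_v = PrefixEncoder(_pe_cfg)(batch_size=3)
assert _pe_k.shape == _pe_v.shape == (2, 3, 4, 8)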

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_size = self.hidden_size // self.num_heads
        # Packed Q/K/V projection, matching the parameter layout of the CLIP checkpoint.
        self.in_proj_weight = nn.Parameter(
            torch.zeros(3 * config.hidden_size, config.hidden_size,
                        device=config.device, dtype=config.dtype),
            requires_grad=True,
        )
        self.in_proj_bias = nn.Parameter(
            torch.zeros(3 * config.hidden_size, device=config.device, dtype=config.dtype),
            requires_grad=True,
        )
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True,
                                  device=config.device, dtype=config.dtype)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        b, n, c = hidden_state.shape
        # Packed projection, then split into queries, keys and values.
        q, k, v = (torch.matmul(hidden_state, self.in_proj_weight.T)
                   + self.in_proj_bias.expand(b, n, -1)).chunk(3, dim=-1)
        # Prefix tuning: prepend the learned prefix to the keys and values only, so the
        # query (and output) sequence length is unchanged.
        if prefix_k is not None and prefix_v is not None:
            k = torch.cat((prefix_k, k), dim=1)
            v = torch.cat((prefix_v, v), dim=1)
        bk, nk, hk = k.shape
        bq, nq, hq = q.shape
        # Split heads: (batch, num_heads, seq_len, head_size).
        q = q.view(bq, nq, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        k = k.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(bk, nk, self.num_heads, self.head_size).permute(0, 2, 1, 3)
        attention_output = F.scaled_dot_product_attention(q, k, v)
        # Merge heads back: (batch, query_len, hidden_size).
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous().view(bq, nq, self.hidden_size)
        attention_output = self.out_proj(attention_output)
        return attention_output
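
# Illustrative check (an addition, not in the original script): the prefix lengthens
# only the key/value sequence inside attention, so the output keeps the query length.
from types import SimpleNamespace

_mha_cfg = SimpleNamespace(hidden_size=16, num_heads=4,
                           device=torch.device("cpu"), dtype=torch.float32)
_mha_out = MultiHeadAttention(_mha_cfg)(torch.randn(2, 5, 16),
                                        prefix_k=torch.zeros(2, 3, 16),
                                        prefix_v=torch.zeros(2, 3, 16))
assert _mha_out.shape == (2, 5, 16)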

class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU used by CLIP: x * sigmoid(1.702 * x)."""

    def __init__(self):
        super(QuickGELU, self).__init__()

    def forward(self, x):
        # Compute the activation in float32, then cast back to the input dtype.
        old_dtype = x.dtype
        x = x.to(torch.float32)
        return (x * torch.sigmoid(1.702 * x)).to(old_dtype)
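
# Illustrative comparison (an addition, not in the original script): the sigmoid
# approximation stays within a few hundredths of the exact GELU on a moderate range.
_g = torch.linspace(-3.0, 3.0, steps=101)
assert (QuickGELU()(_g) - F.gelu(_g)).abs().max() < 0.03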

class MLP(nn.Module):
    """Feed-forward block: hidden_size -> 4 * hidden_size -> hidden_size with QuickGELU."""

    def __init__(self, config):
        super(MLP, self).__init__()
        self.hidden_size = config.hidden_size
        self.c_fc = nn.Linear(self.hidden_size, 4 * self.hidden_size,
                              device=config.device, bias=True, dtype=config.dtype)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(self.hidden_size * 4, self.hidden_size,
                                device=config.device, bias=True, dtype=config.dtype)

    def forward(self, hidden_state):
        hidden_state = self.c_fc(hidden_state)
        hidden_state = self.gelu(hidden_state)
        hidden_state = self.c_proj(hidden_state)
        return hidden_state

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, then LN -> MLP -> residual."""

    def __init__(self, config):
        super(ResidualAttentionBlock, self).__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True,
                                 device=config.device, dtype=config.dtype)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, elementwise_affine=True,
                                 device=config.device, dtype=config.dtype)
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)

    def forward(self, hidden_state, prefix_k=None, prefix_v=None):
        residual = hidden_state
        hidden_state = self.ln_1(hidden_state)
        hidden_state = self.attn(hidden_state, prefix_k, prefix_v)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = self.ln_2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state
        return hidden_state

class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(config.num_layers)])
        self.prefix = PrefixEncoder(config)

    def forward(self, hidden_state):
        b, n, h = hidden_state.shape
        # Per-layer prefix keys/values, each of shape (num_layers, b, num_virtual_tokens, token_dim).
        prefix_k, prefix_v = self.prefix(b)
        for index, resblock in enumerate(self.resblocks):
            hidden_state = resblock(hidden_state, prefix_k[index], prefix_v[index])
        return hidden_state

class TextEncoder_Config:
    """Configuration for the CLIP text transformer, including prefix-tuning settings."""

    def __init__(self, vocab_size, max_position_embeddings, hidden_size, num_layers, num_heads, device, dtype):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.device = device
        self.dtype = dtype
        self.norm_eps = 1e-5
        # Prefix-tuning defaults.
        self.num_virtual_tokens = 20
        self.token_dim = hidden_size
        self.encoder_hidden_size = hidden_size


textencoder_config = TextEncoder_Config(
    vocab_size=49408,
    max_position_embeddings=77,
    hidden_size=512,
    num_layers=12,
    num_heads=8,
    device=torch.device('cuda:0'),
    dtype=torch.float16,
)

Encoder_model = Transformer(textencoder_config)
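
# Illustrative count (an addition, not in the original script): how many parameters the
# frozen per-layer prefix contributes compared with the text transformer as a whole.
_prefix_numel = sum(p.numel() for p in Encoder_model.prefix.parameters())
_total_numel = sum(p.numel() for p in Encoder_model.parameters())
print(f"text-encoder prefix parameters: {_prefix_numel:,} / {_total_numel:,}")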

def position_embedding(x, position_ids):
    """Standalone sinusoidal positional-encoding helper."""
    hidden_size = x.size(2)
    seq_len = x.size(1)
    div_term = torch.exp(torch.arange(0, hidden_size, 2, device=x.device).float()
                         * (-math.log(10000.0) / hidden_size))
    positional_encoding = torch.zeros(seq_len, hidden_size, device=x.device)
    positional_encoding[:, 0::2] = torch.sin(position_ids.float()[:, None] * div_term)
    positional_encoding[:, 1::2] = torch.cos(position_ids.float()[:, None] * div_term)
    positional_encoding = positional_encoding.unsqueeze(0)
    return positional_encoding
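
# Illustrative usage (an addition, not in the original script): the helper returns a
# (1, seq_len, hidden_size) sinusoidal table that can be broadcast over a batch.
_pe_demo = position_embedding(torch.zeros(2, 6, 8), torch.arange(6))
assert _pe_demo.shape == (1, 6, 8)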

class VisionTransformer(nn.Module):
    def __init__(self, config):
        super(VisionTransformer, self).__init__()
        self.image_channel = config.image_channel
        self.hidden_size = config.hidden_size
        self.norm_eps = config.norm_eps
        self.patch_size = config.patch_size
        self.output_dim = config.output_dim
        self.dtype = config.dtype
        self.num_virtual_tokens = config.num_virtual_tokens if hasattr(config, "num_virtual_tokens") else None
        # Patch embedding via non-overlapping patch_size x patch_size convolutions.
        self.conv1 = nn.Conv2d(self.image_channel, self.hidden_size, self.patch_size,
                               stride=self.patch_size, bias=False,
                               device=config.device, dtype=config.dtype)
        self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.norm_eps, elementwise_affine=True,
                                   device=config.device, dtype=config.dtype)
        self.transformer = Transformer(config)
        self.class_embedding = nn.Parameter(torch.zeros(config.hidden_size, device=config.device),
                                            requires_grad=True)
        self.positional_embedding = nn.Parameter(torch.zeros(config.num_patches + 1, config.hidden_size,
                                                             device=config.device),
                                                 requires_grad=True)
        self.proj = nn.Parameter(torch.zeros(config.hidden_size, config.output_dim,
                                             device=config.device, dtype=config.dtype),
                                 requires_grad=True)
        self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.norm_eps, elementwise_affine=True,
                                    device=config.device, dtype=config.dtype)

    def forward(self, hidden_state):
        b, c, h, w = hidden_state.shape
        # (b, hidden_size, h/patch, w/patch) -> (b, num_patches, hidden_size)
        hidden_state = self.conv1(hidden_state)
        hidden_state = hidden_state.reshape(b, self.hidden_size, -1).transpose(1, 2)
        # Prepend the class token, then add the learned positional embeddings.
        hidden_state = torch.cat((self.class_embedding.expand(b, 1, -1).to(hidden_state.dtype), hidden_state), dim=1)
        hidden_state = hidden_state + self.positional_embedding.unsqueeze(0).to(hidden_state.dtype)
        hidden_state = self.ln_pre(hidden_state)
        hidden_state = self.transformer(hidden_state)
        # Pool a single token as the image representation.
        if self.num_virtual_tokens is not None:
            hidden_state = hidden_state[:, self.num_virtual_tokens, :]
        else:
            hidden_state = hidden_state[:, 0, :]
        hidden_state = self.ln_post(hidden_state)
        hidden_state = torch.matmul(hidden_state, self.proj)
        return hidden_state

class ViTConfig:
    """Configuration for the vision transformer, including prefix-tuning settings."""

    def __init__(self, image_channel, hidden_size, num_heads, num_layers, patch_size, num_patches,
                 output_dim, norm_eps, device):
        self.image_channel = image_channel
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.norm_eps = norm_eps
        self.device = device
        self.dtype = torch.float16
        self.patch_token_num = self.hidden_size // self.patch_size ** 2 + 1
        self.output_dim = output_dim
        # Prefix-tuning defaults.
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


config = ViTConfig(3, 768, 12, 12, 32, 49, 512, 1e-5, torch.device("cuda"))
VIT_model = VisionTransformer(config)
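
# Illustrative smoke test (an addition, not in the original script; assumptions: a CUDA
# device is available, and the input resolution is 224x224, which matches num_patches=49
# for patch_size=32).
if torch.cuda.is_available():
    with torch.no_grad():
        _dummy_image = torch.zeros(1, 3, 224, 224, device=config.device, dtype=config.dtype)
        assert VIT_model(_dummy_image).shape == (1, config.output_dim)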

class CLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual = VIT_model
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size,
                                            dtype=config.dtype, device=config.device)
        self.transformer = Encoder_model
        self.positional_embedding = nn.Parameter(torch.randn(config.max_position_embeddings, config.hidden_size,
                                                             device=config.device))
        self.ln_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps,
                                     dtype=config.dtype, device=config.device)
        self.text_projection = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size,
                                                        device=config.device))
        self.logit_scale = nn.Parameter(torch.ones([], dtype=config.dtype, device=config.device) * config.logit_scale_init,
                                        requires_grad=True)

    def encode_image(self, img):
        return self.visual(img)

    def encode_text(self, text):
        token_embedding = self.token_embedding(text)
        position_embedding = self.positional_embedding[None, :text.shape[1], :].to(self.dtype)
        text_embedding = token_embedding + position_embedding
        text_embedding = self.transformer(text_embedding)
        text_embedding = self.ln_final(text_embedding)
        # Pool the feature at the end-of-text token (the highest token id in each sequence).
        text_embedding = text_embedding[torch.arange(text.shape[0]), text.argmax(dim=-1)]
        text_embedding = text_embedding @ self.text_projection.to(self.dtype)
        return text_embedding

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # Cosine similarity via L2-normalised features, scaled by the learned temperature.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

class CLIPConfig:
    def __init__(self):
        self.vocab_size = 49408
        self.hidden_size = 512
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")
        self.logit_scale_init = 4.6052
        # Prefix-tuning defaults.
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


CLIPconfig = CLIPConfig()
model = CLIP(CLIPconfig)

# Load the fine-tuned weights; strict=False ignores missing and unexpected keys.
model.load_state_dict(torch.load(r'./Mix_CLIP.pth', weights_only=True), strict=False)

import pickle

# Load the pickled image preprocessing pipeline and text tokenizer.
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)

with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)
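
# Illustrative end-to-end usage (an addition; assumptions: `preprocess` maps a PIL image
# to a (3, 224, 224) tensor and `tokenizer` maps a list of strings to a batch of token
# ids, as in the reference CLIP pipeline; "example.jpg" is a hypothetical file).
#
# from PIL import Image
#
# image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(CLIPconfig.device, CLIPconfig.dtype)
# text = tokenizer(["a photo of a cat", "a photo of a dog"]).to(CLIPconfig.device)
# with torch.no_grad():
#     logits_per_image, logits_per_text = model(image, text)
#     probs = logits_per_image.softmax(dim=-1)
# print(probs)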