import streamlit as st
import numpy as np
import PIL
import math
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.distributions import Categorical
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def extract_patches(image_tensor, patch_size=16):
    # Get the dimensions of the image tensor
    bs, c, h, w = image_tensor.size()

    # Define the Unfold layer with appropriate parameters
    unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    # Apply Unfold to the image tensor
    unfolded = unfold(image_tensor)

    # Reshape the unfolded tensor to the desired output shape
    # Output shape: BS x L x (C * patch_size * patch_size), where L is the total number of patches
    unfolded = unfolded.transpose(1, 2).reshape(bs, -1, c * patch_size * patch_size)

    return unfolded
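
# Illustrative shape check for extract_patches (not part of the original app; assumes a
# 3-channel 128x128 input, which matches the dashboard preprocessing below):
#   dummy = torch.zeros(2, 3, 128, 128)
#   patches = extract_patches(dummy, patch_size=16)
#   patches.shape  # -> torch.Size([2, 64, 768]): (128 // 16) ** 2 = 64 patches of 3 * 16 * 16 values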
# Sinusoidal positional embeddings
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
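
# Illustrative usage (not part of the original app): for a sequence of length 10 and a
# hidden size of 128, each position index is mapped to a fixed embedding vector:
#   positions = torch.arange(10)
#   pos = SinusoidalPosEmb(128)(positions)  # -> (10, 128), sines in the first half, cosines in the second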
# Define a module for attention blocks
class AttentionBlock(nn.Module):
    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super(AttentionBlock, self).__init__()
        self.masking = masking

        # Multi-head attention mechanism
        self.multihead_attn = nn.MultiheadAttention(hidden_size,
                                                    num_heads=num_heads,
                                                    batch_first=True,
                                                    dropout=0.0)

    def forward(self, x_in, kv_in, key_mask=None):
        # Apply causal masking if enabled
        if self.masking:
            bs, l, h = x_in.shape
            mask = torch.triu(torch.ones(l, l, device=x_in.device), 1).bool()
        else:
            mask = None

        # Perform the multi-head attention operation
        return self.multihead_attn(x_in, kv_in, kv_in, attn_mask=mask,
                                   key_padding_mask=key_mask)[0]
# Define a module for a transformer block with self-attention
# and optional causal masking
class TransformerBlock(nn.Module):
    def __init__(self, hidden_size=128, num_heads=4, decoder=False, masking=True):
        super(TransformerBlock, self).__init__()
        self.decoder = decoder

        # Layer normalization applied after the self-attention residual
        self.norm1 = nn.LayerNorm(hidden_size)
        # Self-attention mechanism
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads,
                                    masking=masking)

        if self.decoder:
            # Layer normalization applied after the cross-attention residual
            self.norm2 = nn.LayerNorm(hidden_size)
            # Cross-attention mechanism over the encoder output, with no causal masking
            self.attn2 = AttentionBlock(hidden_size=hidden_size,
                                        num_heads=num_heads, masking=False)

        # Layer normalization applied after the MLP residual
        self.norm_mlp = nn.LayerNorm(hidden_size)
        # Multi-layer perceptron (MLP)
        self.mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size * 4),
                                 nn.ELU(),
                                 nn.Linear(hidden_size * 4, hidden_size))

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        # Self-attention with a residual connection, followed by layer normalization (post-norm)
        x = self.attn1(x, x, key_mask=input_key_mask) + x
        x = self.norm1(x)

        # If this is a decoder block, also cross-attend to the encoder output
        if self.decoder:
            x = self.attn2(x, kv_cross, key_mask=cross_key_mask) + x
            x = self.norm2(x)

        # Apply the MLP with a residual connection and a final layer normalization
        x = self.mlp(x) + x
        return self.norm_mlp(x)
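
# Illustrative shapes (not part of the original app; hidden_size=128 assumed):
#   blk = TransformerBlock(hidden_size=128, num_heads=4, decoder=True)
#   x = torch.zeros(2, 10, 128)    # decoder token embeddings
#   kv = torch.zeros(2, 64, 128)   # encoder patch embeddings
#   blk(x, kv_cross=kv).shape      # -> (2, 10, 128), same shape as x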
# Define a decoder module for the Transformer architecture
class Decoder(nn.Module):
    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder, self).__init__()

        # Create an embedding layer for tokens
        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Initialize the embedding weights with a small scale
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        # Initialize sinusoidal positional embeddings
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Create multiple transformer blocks as layers
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=True) for _ in range(num_layers)
        ])

        # Define a linear layer for output prediction
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        # Embed the input sequence
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add positional embeddings to the input embeddings
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        embs = input_embs + pos_emb

        # Pass the embeddings through each transformer block
        for block in self.blocks:
            embs = block(embs,
                         input_key_mask=input_padding_mask,
                         cross_key_mask=encoder_padding_mask,
                         kv_cross=encoder_output)

        return self.fc_out(embs)
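
# Illustrative usage (not part of the original app; 30522 is the distilbert-base-uncased
# vocabulary size, hidden_size=128 assumed):
#   dec = Decoder(num_emb=30522, hidden_size=128, num_layers=3, num_heads=4)
#   tokens = torch.zeros(2, 10, dtype=torch.long)   # a batch of token ids
#   enc_out = torch.zeros(2, 64, 128)               # encoder patch embeddings
#   dec(tokens, enc_out).shape                      # -> (2, 10, 30522) logits over the vocabulary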
# Define a Vision Encoder module for the Transformer architecture
class VisionEncoder(nn.Module):
    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128,
                 num_layers=3, num_heads=4):
        super(VisionEncoder, self).__init__()
        self.patch_size = patch_size

        # Linear projection from flattened patches to the hidden size
        self.fc_in = nn.Linear(channels_in * patch_size * patch_size, hidden_size)

        # Learnable positional embedding, one vector per patch
        seq_length = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length,
                                                      hidden_size).normal_(std=0.02))

        # Create multiple transformer blocks as layers
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads,
                             decoder=False, masking=False) for _ in range(num_layers)
        ])

    def forward(self, image):
        bs = image.shape[0]

        # Split the image into patches and project each one to the hidden size
        patch_seq = extract_patches(image, patch_size=self.patch_size)
        patch_emb = self.fc_in(patch_seq)

        # Add a unique positional embedding to each patch embedding
        embs = patch_emb + self.pos_embedding

        # Pass the embeddings through each transformer block
        for block in self.blocks:
            embs = block(embs)

        return embs
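
# Illustrative shapes (not part of the original app): with image_size=128, patch_size=16 and
# hidden_size=128, a batch of images becomes a sequence of 64 patch embeddings:
#   enc = VisionEncoder(image_size=128, channels_in=3)
#   enc(torch.zeros(2, 3, 128, 128)).shape   # -> (2, 64, 128)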
# Define a Vision Encoder-Decoder module for the Transformer architecture
class VisionEncoderDecoder(nn.Module):
    def __init__(self, image_size, channels_in, num_emb, patch_size=16,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super(VisionEncoderDecoder, self).__init__()

        # Create an encoder and decoder with the specified parameters
        self.encoder = VisionEncoder(image_size=image_size, channels_in=channels_in,
                                     patch_size=patch_size, hidden_size=hidden_size,
                                     num_layers=num_layers[0], num_heads=num_heads)
        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=num_layers[1], num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        # Generate a boolean padding mask for the target sequence
        bool_padding_mask = padding_mask == 0

        # Encode the input image into a sequence of patch embeddings
        encoded_seq = self.encoder(image=input_image)

        # Decode the target sequence conditioned on the encoded image
        decoded_seq = self.decoder(input_seq=target_seq,
                                   encoder_output=encoded_seq,
                                   input_padding_mask=bool_padding_mask)
        return decoded_seq
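
# Illustrative training-style forward pass (not part of the original app; shapes and the
# 30522 vocabulary size are assumptions matching distilbert-base-uncased):
#   captioner = VisionEncoderDecoder(image_size=128, channels_in=3, num_emb=30522)
#   images = torch.zeros(2, 3, 128, 128)
#   tokens = torch.zeros(2, 10, dtype=torch.long)
#   mask = torch.ones(2, 10)                       # 1 for real tokens, 0 for padding
#   captioner(images, tokens, mask).shape          # -> (2, 10, 30522)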
# Load the full captioning model onto the selected device
model = torch.load("caption_model.pth", map_location=device, weights_only=False)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
def pred_transformer_caption(test_img):
    # Start with the Start-Of-Sentence token ([CLS], id 101 in the distilbert-base-uncased
    # vocabulary) to signal the network to start generating the caption
    sos_token = 101 * torch.ones(1, 1).long()

    # Set the temperature for sampling during generation
    temp = 0.5
    log_tokens = [sos_token]

    model.eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Encode the input image
            image_embedding = model.encoder(test_img.to(device))

            # Generate the caption tokens autoregressively
            for i in range(50):
                input_tokens = torch.cat(log_tokens, 1)

                # Decode the tokens generated so far into logits for the next token
                data_pred = model.decoder(input_tokens.to(device), image_embedding)

                # Sample from the temperature-scaled distribution of predicted probabilities
                dist = Categorical(logits=data_pred[:, -1] / temp)
                next_tokens = dist.sample().reshape(1, 1)

                # Append the sampled token to the sequence
                log_tokens.append(next_tokens.cpu())

                # Stop if the End-Of-Caption token ([SEP], id 102) is predicted
                if next_tokens.item() == 102:
                    break

    # Convert the list of token indices to a tensor
    pred_tokens = torch.cat(log_tokens, 1)

    # Decode the token indices into the predicted caption string
    pred_text = tokenizer.decode(pred_tokens[0], skip_special_tokens=True)
    return pred_text
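
# Illustrative call (not part of the original app): given a preprocessed image tensor of
# shape (1, 3, 128, 128) with values in [0, 1], the function returns the sampled caption:
#   caption = pred_transformer_caption(torch.zeros(1, 3, 128, 128))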
## Dashboard
st.title("Caption_APP")

uploaded_file = st.file_uploader(label="upload the funny pic :) :", type=["png", "jpg", "jpeg"])
caption = ""

if uploaded_file:
    # Load the upload as an RGB image and resize it to the model's input resolution
    image = PIL.Image.open(uploaded_file).convert("RGB").resize((128, 128))
    image = np.array(image).astype("float32")

    # Min-max normalise the pixel values to [0, 1]
    image = (image - np.amin(image)) / (np.amax(image) - np.amin(image))

    # Rearrange to (batch, channels, height, width) for the model
    test_img = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)

    caption = str(pred_transformer_caption(test_img))
    st.image(image=image, caption=caption)
    # st.write(caption)