random_random / app.py
raun12345678's picture
Update app.py
9f385c9 verified
raw
history blame
11.1 kB
import streamlit as st
import numpy as np
import PIL
import math
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.distributions import Categorical
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from transformers import AutoTokenizer
# Run on the first CUDA GPU (ordinal 0) when one is visible, otherwise on the CPU.
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
def extract_patches(image_tensor, patch_size=16):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        image_tensor: 4-D tensor (batch, channels, height, width).
        patch_size: side length of each square patch. It is also the stride,
            so patches do not overlap. Trailing rows/columns that cannot fill
            a whole patch are dropped by ``nn.Unfold``.

    Returns:
        Tensor of shape (batch, num_patches, channels * patch_size * patch_size),
        with patches in the order produced by ``nn.Unfold``.
    """
    batch, channels, _, _ = image_tensor.size()
    patcher = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    columns = patcher(image_tensor)  # (batch, C*k*k, num_patches)
    flat_dim = channels * patch_size * patch_size
    # Move the patch axis next to the batch axis so each row is one patch.
    return columns.transpose(1, 2).reshape(batch, -1, flat_dim)
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed (non-learned) sine/cosine encoding of scalar positions.

    For an input of N positions, returns an (N, 2 * (dim // 2)) tensor. The
    first half holds sin(pos * f_k) and the second half cos(pos * f_k), with
    frequencies f_k spaced geometrically from 1 down to 1/10000.

    Args:
        dim: requested embedding width. ``self.dim`` is read at forward time
            (keep the attribute name: it is part of pickled instances).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Encode a 1-D tensor of positions ``x`` (presumably shape (N,))."""
        half = self.dim // 2
        # Log-spaced step between consecutive frequencies.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -step)
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)  # (N, half)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
# Define a module for attention blocks
class AttentionBlock(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` with optional causal masking.

    Args:
        hidden_size: embedding width of queries, keys and values
            (``nn.MultiheadAttention`` requires it to be divisible by num_heads).
        num_heads: number of attention heads.
        masking: if True, query position i may only attend to key positions
            j <= i. This assumes queries and keys have the same length.
    """

    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super().__init__()
        self.masking = masking
        self.multihead_attn = nn.MultiheadAttention(
            hidden_size, num_heads=num_heads, batch_first=True, dropout=0.0
        )

    def forward(self, x_in, kv_in, key_mask=None):
        """Attend from ``x_in`` (queries) over ``kv_in`` (keys and values).

        Both inputs are batch-first: (batch, seq, hidden_size).
        ``key_mask`` is passed as ``key_padding_mask``; True entries are ignored.
        Returns only the attended output; attention weights are discarded.
        """
        causal = None
        if self.masking:
            _, seq_len, _ = x_in.shape
            # True above the diagonal = blocked: position i cannot see j > i.
            causal = torch.ones(seq_len, seq_len, device=x_in.device).triu(1).bool()
        attended, _weights = self.multihead_attn(
            x_in, kv_in, kv_in, attn_mask=causal, key_padding_mask=key_mask
        )
        return attended
# Define a module for a transformer block with self-attention
# and optional causal masking
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, optional cross-attention, MLP.

    Every sub-layer is applied as ``norm(h + sublayer(h))``.

    Args:
        hidden_size: model width.
        num_heads: attention heads per attention sub-layer.
        decoder: if True, add an unmasked cross-attention sub-layer whose
            keys/values come from ``kv_cross`` in ``forward``.
        masking: whether the self-attention sub-layer is causally masked.
    """

    def __init__(self, hidden_size=128, num_heads=4, decoder=False, masking=True):
        super().__init__()
        self.decoder = decoder
        # Keep this creation order and these attribute names: they fix the
        # parameter registration order, the sequence of random initialisations
        # and the state_dict / pickled-instance layout.
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads,
                                    masking=masking)
        if decoder:
            self.norm2 = nn.LayerNorm(hidden_size)
            # Cross-attention: queries from the running sequence, keys/values
            # from kv_cross; never causally masked.
            self.attn2 = AttentionBlock(hidden_size=hidden_size,
                                        num_heads=num_heads, masking=False)
        self.norm_mlp = nn.LayerNorm(hidden_size)
        expanded = hidden_size * 4
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, expanded),
            nn.ELU(),
            nn.Linear(expanded, hidden_size),
        )

    def forward(self, x, input_key_mask=None, cross_key_mask=None, kv_cross=None):
        """Run the block on ``x`` (batch, seq, hidden_size).

        Args:
            input_key_mask: key padding mask for self-attention (True = ignore).
            cross_key_mask: key padding mask for cross-attention (True = ignore).
            kv_cross: keys/values for cross-attention; required when decoder=True.
        """
        h = self.norm1(x + self.attn1(x, x, key_mask=input_key_mask))
        if self.decoder:
            h = self.norm2(h + self.attn2(h, kv_cross, key_mask=cross_key_mask))
        return self.norm_mlp(h + self.mlp(h))
# Define a decoder module for the Transformer architecture
class Decoder(nn.Module):
    """Token decoder: embeddings + sinusoidal positions + cross-attending blocks.

    Args:
        num_emb: vocabulary size (embedding rows and output logits).
        hidden_size: model width.
        num_layers: number of decoder TransformerBlocks.
        num_heads: attention heads per attention sub-layer.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Scale the default embedding init down by 1000x, presumably so token
        # embeddings start small next to the positional encodings.
        with torch.no_grad():
            self.embedding.weight.mul_(0.001)
        self.pos_emb = SinusoidalPosEmb(hidden_size)
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, decoder=True)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        """Return per-position logits of shape (batch, seq, num_emb).

        Args:
            input_seq: (batch, seq) token ids.
            encoder_output: keys/values for cross-attention, last dim hidden_size.
            input_padding_mask: (batch, seq) bool, True marks padded tokens.
            encoder_padding_mask: bool mask over encoder positions, True = ignore.
        """
        tokens = self.embedding(input_seq)
        batch, length, width = tokens.shape
        positions = self.pos_emb(torch.arange(length, device=input_seq.device))
        hidden = tokens + positions.reshape(1, length, width).expand(batch, length, width)
        for block in self.blocks:
            hidden = block(
                hidden,
                input_key_mask=input_padding_mask,
                cross_key_mask=encoder_padding_mask,
                kv_cross=encoder_output,
            )
        return self.fc_out(hidden)
# Define a Vision Encoder module for the Transformer architecture
class VisionEncoder(nn.Module):
    """ViT-style encoder: patchify, project, add learned positions, self-attend.

    Args:
        image_size: side length of the (square) input image; fixes the number
            of learned positional embeddings to (image_size // patch_size) ** 2.
        channels_in: image channels.
        patch_size: side length of each square patch.
        hidden_size: model width.
        num_layers: number of unmasked TransformerBlocks.
        num_heads: attention heads per block.
    """

    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128,
                 num_layers=3, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        patch_dim = channels_in * patch_size * patch_size
        self.fc_in = nn.Linear(patch_dim, hidden_size)
        num_patches = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(
            torch.empty(1, num_patches, hidden_size).normal_(std=0.02)
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, decoder=False, masking=False)
             for _ in range(num_layers)]
        )

    def forward(self, image):
        """Encode ``image`` (batch, channels, H, W) to (batch, num_patches, hidden_size)."""
        patches = extract_patches(image, patch_size=self.patch_size)
        # Each patch position gets its own learned embedding added.
        hidden = self.fc_in(patches) + self.pos_embedding
        for block in self.blocks:
            hidden = block(hidden)
        return hidden
# Define a Vision Encoder-Decoder module for the Transformer architecture
class VisionEncoderDecoder(nn.Module):
    """Image-to-text model: VisionEncoder over the image, Decoder over tokens.

    Args:
        image_size, channels_in, patch_size: forwarded to VisionEncoder.
        num_emb: vocabulary size for the Decoder.
        hidden_size, num_heads: shared by encoder and decoder.
        num_layers: (encoder_layers, decoder_layers).
    """

    def __init__(self, image_size, channels_in, num_emb, patch_size=16,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super().__init__()
        enc_depth = num_layers[0]
        dec_depth = num_layers[1]
        self.encoder = VisionEncoder(image_size=image_size, channels_in=channels_in,
                                     patch_size=patch_size, hidden_size=hidden_size,
                                     num_layers=enc_depth, num_heads=num_heads)
        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=dec_depth, num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        """Return decoder logits for ``target_seq`` conditioned on ``input_image``.

        ``padding_mask`` marks padded target positions with 0. It is turned
        into a boolean key-padding mask (True = ignore) for decoder self-attention.
        """
        ignore = padding_mask == 0
        memory = self.encoder(image=input_image)
        return self.decoder(input_seq=target_seq, encoder_output=memory,
                            input_padding_mask=ignore)
@st.cache_resource
def _load_caption_model():
    """Load the captioning model once per server process and put it in eval mode.

    Streamlit re-executes the script on every interaction; caching avoids
    re-reading the checkpoint each time.

    SECURITY: weights_only=False unpickles arbitrary Python objects. Only load a
    trusted checkpoint file.
    """
    # map_location places every tensor on `device`. A GPU-saved checkpoint then
    # loads on CPU-only hosts, and a CPU checkpoint lands on the same device the
    # inputs are sent to. .to(device) is a no-op safety net after that.
    loaded = torch.load("caption_model.pth", map_location=device, weights_only=False)
    loaded = loaded.to(device)
    loaded.eval()
    return loaded


@st.cache_resource
def _load_tokenizer():
    """Load the distilbert-base-uncased tokenizer once per server process."""
    return AutoTokenizer.from_pretrained("distilbert-base-uncased")


model = _load_caption_model()
tokenizer = _load_tokenizer()
def pred_transformer_caption(test_img, temp=0.5, max_new_tokens=50, sos_id=101, eos_id=102):
    """Sample a caption for one preprocessed image, one token at a time.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        test_img: image tensor accepted by ``model.encoder``; presumably shape
            (1, C, H, W), since a single caption is generated (TODO confirm).
        temp: sampling temperature applied to the logits; lower is greedier.
        max_new_tokens: maximum number of tokens sampled after the start token.
        sos_id: id fed first to start generation (101 is [CLS] in the
            distilbert-base-uncased vocabulary).
        eos_id: id that stops generation once sampled (102 is [SEP] there).

    Returns:
        The decoded caption with special tokens removed.
    """
    generated = [sos_id]
    model.eval()
    with torch.no_grad():
        # Mixed precision only on CUDA. The previous torch.cuda.amp.autocast was
        # a warning-emitting no-op on CPU, so autocast stays disabled there.
        with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            image_embedding = model.encoder(test_img.to(device))
        for _ in range(max_new_tokens):
            input_tokens = torch.tensor([generated], dtype=torch.long, device=device)
            # The decoder re-processes the whole prefix each step (no KV cache);
            # only the last position's logits are needed.
            logits = model.decoder(input_tokens, image_embedding)
            next_token = Categorical(logits=logits[:, -1] / temp).sample().item()
            generated.append(next_token)
            if next_token == eos_id:
                break
    # decode() already returns a single string; no further joining is needed.
    return tokenizer.decode(generated, skip_special_tokens=True)
# --- Dashboard ---------------------------------------------------------------
# Streamlit re-runs this script on every interaction. Show the uploader and,
# once an image is present, preprocess it, caption it and display both.
st.title("Caption_APP")
test_img = st.file_uploader(label="upload the funny pic :) :", type=["png", "jpg", "jpeg"])
caption = ""
if test_img:
    from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

    # VisionEncoder builds fc_in with channels_in * patch_size**2 inputs, so the
    # channel count the loaded encoder expects can be read back from that layer.
    patch_size = model.encoder.patch_size
    channels_in = model.encoder.fc_in.in_features // (patch_size * patch_size)
    pil_mode = {1: "L", 3: "RGB", 4: "RGBA"}.get(channels_in, "RGB")
    # convert() also normalises palette/RGBA/greyscale uploads to that mode.
    pil_img = Image.open(test_img).convert(pil_mode).resize((128, 128))

    # Min-max scale to [0, 1]; a flat image (zero range) becomes all zeros
    # instead of NaNs from a division by zero.
    pixels = np.asarray(pil_img, dtype=np.float32)
    lo, hi = pixels.min(), pixels.max()
    pixels = (pixels - lo) / (hi - lo) if hi > lo else np.zeros_like(pixels)
    if pixels.ndim == 2:  # "L" mode yields (H, W); add the channel axis
        pixels = pixels[:, :, np.newaxis]

    # (H, W, C) -> (1, C, H, W): the 4-D layout extract_patches unpacks.
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)
    caption = str(pred_transformer_caption(img_tensor))
    st.image(image=np.squeeze(pixels), caption=caption)