# Streamlit demo: image captioning with a ViT encoder and a transformer decoder.
import base64
import functools
from base64 import b64decode
from io import BytesIO
from pathlib import Path

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from PIL import Image
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from torch.utils.data import Dataset
from transformers import GPT2TokenizerFast, ViTFeatureExtractor

from config import get_config
from model import build_transformer
def process(model, image, tokenizer, device):
    """Generate a caption for one uploaded image.

    Args:
        model: captioning transformer with encode/decode/project methods.
        image: path or file-like object accepted by PIL.Image.open.
        tokenizer: tokenizer providing [SOS]/[EOS] ids and decode().
        device: torch device the model lives on.

    Returns:
        The decoded caption string.
    """
    pixel_values = get_image(image)
    model.eval()
    with torch.no_grad():
        # Add the batch dimension the encoder expects: (1, 3, H, W).
        encoder_input = pixel_values.unsqueeze(0).to(device)
        # 196 = max generated sequence length (ViT patch count for 224x224).
        token_ids = greedy_decode(model, encoder_input, None, tokenizer, 196, device)
        caption = tokenizer.decode(token_ids.detach().cpu().numpy())
    return caption
# get image prompt | |
def get_image(image):
    """Load an image and convert it to ViT pixel values.

    Args:
        image: path or file-like object accepted by PIL.Image.open.

    Returns:
        A float tensor of pixel values with the batch dimension removed
        (for the default extractor: shape (3, 224, 224)).
    """
    feature_extractor = _get_feature_extractor()
    img = Image.open(image)
    # The ViT extractor expects 3-channel input; normalize grayscale/RGBA.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    enc_input = feature_extractor(img, return_tensors='pt')
    # pixel_values is (1, 3, H, W): a single squeeze(0) drops the batch dim.
    # (The original chained five .squeeze(0) calls — the last four were no-ops
    # because dim 0 is then the 3-channel axis.)
    return enc_input['pixel_values'].squeeze(0)


@functools.lru_cache(maxsize=1)
def _get_feature_extractor():
    # Constructing the extractor hits the HF cache/network; do it once,
    # not on every uploaded image.
    return ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
#get tokenizer | |
def get_or_build_tokenizer(config):
    """Return the GPT-2 fast tokenizer configured with the project's special tokens.

    ``config`` is accepted for interface compatibility with the training
    pipeline but is not used here.
    """
    special_tokens = {
        'unk_token': '[UNK]',
        'bos_token': '[SOS]',
        'eos_token': '[EOS]',
        'pad_token': '[PAD]',
    }
    return GPT2TokenizerFast.from_pretrained("openai-community/gpt2", **special_tokens)
def causal_mask(size):
    """Return a (1, size, size) boolean causal mask.

    True on and below the diagonal, i.e. position i may attend to
    positions 0..i only.
    """
    ones = torch.ones((1, size, size), dtype=torch.int)
    return torch.tril(ones) == 1
# get model | |
def get_model(config, vocab_tgt_len):
    """Build the captioning transformer for a target vocabulary of the given size.

    Reads ``seq_len`` and ``d_model`` from the config dict.
    """
    seq_len = config['seq_len']
    d_model = config['d_model']
    return build_transformer(vocab_tgt_len, seq_len, d_model=d_model)
# greedy decode | |
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    """Autoregressively decode a caption with greedy (argmax) sampling.

    Args:
        model: transformer exposing encode/decode/project.
        source: encoder input batch of size 1.
        source_mask: mask passed to ``model.decode`` (may be None).
        tokenizer_tgt: tokenizer providing [SOS]/[EOS] token ids.
        max_len: hard cap on the generated sequence length.
        device: torch device for the decoder tensors.

    Returns:
        1-D LongTensor of token ids, starting with [SOS] and ending with
        [EOS] if it was produced within ``max_len`` steps.
    """
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')
    # Encode once and reuse for every decoding step.
    # NOTE(review): encode deliberately receives None instead of source_mask
    # (preserved from the original) — confirm this is intended.
    encoder_output = model.encode(source, None)
    # Start the target sequence with the [SOS] token.
    decoder_input = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)
    while decoder_input.size(1) < max_len:
        size = decoder_input.size(1)
        # Causal mask: position i may attend to positions 0..i only.
        decoder_mask = (torch.triu(torch.ones((1, size, size)), diagonal=1) == 0).long().to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        # Project only the last position's hidden state to vocabulary logits.
        logits = model.project(out[:, -1])
        # argmax over logits equals argmax over softmax(logits); the original's
        # explicit softmax was redundant for greedy selection.
        next_word = torch.argmax(logits, dim=1)
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == eos_idx:
            break
    return decoder_input.squeeze(0)
def image_base64(image):
    """Encode the contents of a file-like image object as a base64 string.

    Args:
        image: a binary file-like object supporting .read().

    Returns:
        The base64-encoded contents as an ASCII str.
    """
    # BUG FIX: the original read from an undefined name ``image_file``
    # (NameError on every call); the parameter is ``image``.
    base64_bytes = base64.b64encode(image.read())
    return base64_bytes.decode()
def main():
    """Streamlit entry point: upload an image, caption it, display the text."""
    st.title("Image Captioning with Vision Transformer")
    uploaded = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
    # Guard clause: nothing to do until the user uploads a file.
    if uploaded is None:
        return
    st.image(uploaded, use_column_width=True)
    with st.empty():
        st.write("Processing the image... Please wait.")
        accelerator = Accelerator()
        device = accelerator.device
        config = get_config()
        tokenizer = get_or_build_tokenizer(config)
        model = accelerator.prepare(get_model(config, len(tokenizer)))
        # Restore trained weights saved by accelerate.
        accelerator.load_state('models/')
        caption = process(model, uploaded, tokenizer, device)
        st.write(caption)


if __name__ == "__main__":
    main()