Upload app.py
app.py
ADDED
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import base64
import streamlit as st
from io import BytesIO
from base64 import b64decode
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from accelerate import Accelerator
from transformers import ViTFeatureExtractor, GPT2TokenizerFast
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from config import get_config
from model import build_transformer


def process(model, image, tokenizer, device):
    # Caption a single image: preprocess it, then greedily decode a caption
    image = get_image(image)
    model.eval()
    with torch.no_grad():
        encoder_input = image.unsqueeze(0).to(device)  # (1, 3, 224, 224) pixel values

        model_out = greedy_decode(model, encoder_input, None, tokenizer, 196, device)
        model_text = tokenizer.decode(model_out.detach().cpu().numpy())
        return model_text


# Preprocess an uploaded image into ViT pixel values
def get_image(image):
    # Load the image processor matching the ViT encoder checkpoint
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)

    image = Image.open(image)

    if image.mode != 'RGB':
        image = image.convert('RGB')

    enc_input = feature_extractor(
        image,
        return_tensors='pt'
    )

    # pixel_values is (1, 3, 224, 224); drop the batch dimension -> (3, 224, 224)
    return enc_input['pixel_values'].squeeze(0)


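# Added illustration (hypothetical helper, not called by the app): for the
# patch16-224 checkpoint above, get_image() should return a (3, 224, 224)
# tensor regardless of the uploaded image's original size.
def _preprocess_shape_check(image_path):
    pixel_values = get_image(image_path)
    print(pixel_values.shape)  # expected: torch.Size([3, 224, 224])

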
# Get the tokenizer: GPT-2's tokenizer with custom special tokens added
def get_or_build_tokenizer(config):
    tokenizer = GPT2TokenizerFast.from_pretrained(
        "openai-community/gpt2",
        unk_token='[UNK]',
        bos_token='[SOS]',
        eos_token='[EOS]',
        pad_token='[PAD]',
    )
    return tokenizer


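# Added illustration (hypothetical helper, not called by the app): the bracketed
# special tokens are not part of GPT-2's base vocabulary, so they are appended to
# it and len(tokenizer) grows accordingly, which is why get_model() below is
# sized with len(tokenizer) rather than the base GPT-2 vocab size.
def _tokenizer_sanity_check():
    tok = get_or_build_tokenizer(get_config())
    print(len(tok))                            # base vocab (50257 for GPT-2) + added tokens
    print(tok.convert_tokens_to_ids('[SOS]'))  # id assigned to the added [SOS] token

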
# Causal mask: position i may attend only to positions <= i
def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


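# Added illustration (comment only): causal_mask(3) evaluates to
#   [[[ True, False, False],
#     [ True,  True, False],
#     [ True,  True,  True]]]
# so each decoding step can see itself and earlier tokens but nothing ahead.

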
# get model
def get_model(config, vocab_tgt_len):
    model = build_transformer(vocab_tgt_len, config['seq_len'], d_model=config['d_model'])
    return model


# Greedy decoding: pick the most probable next token until [EOS] or max_len
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # Precompute the encoder output and reuse it for every decoding step
    encoder_output = model.encode(source, None)

    # Initialize the decoder input with the SOS token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).long().to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build the causal mask for the tokens generated so far
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)

        # Run one decoder step over the current prefix
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Next-token distribution from the projection layer
        logits = model.project(out[:, -1])
        probabilities = F.softmax(logits, dim=-1)

        # Greedily select the most probable next token
        next_word = torch.argmax(probabilities, dim=1)

        # Append it to the decoder input
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)

        if next_word.item() == eos_idx:
            break

    return decoder_input.squeeze(0)


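# Added note (hypothetical helper, not used above): greedy_decode() returns the
# token ids including [SOS] (and [EOS] when it was generated), so a cleaner
# caption string can be produced by skipping special tokens at decode time.
def _decode_caption(token_ids, tokenizer):
    return tokenizer.decode(token_ids, skip_special_tokens=True)

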
def image_base64(image_path):
    # Read an image file from disk and return it as a base64-encoded string
    with open(image_path, 'rb') as image_file:
        base64_bytes = base64.b64encode(image_file.read())

    base64_string = base64_bytes.decode()
    return base64_string


def start():
    # Local test entry point (not used by the Streamlit app)
    print('start')
    accelerator = Accelerator()
    device = accelerator.device

    config = get_config()
    tokenizer = get_or_build_tokenizer(config)
    model = get_model(config, len(tokenizer))
    model = accelerator.prepare(model)
    accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')

    # process() expects something PIL.Image.open can read (a path or file object),
    # not a base64 string, so pass the test image path directly
    image = 'C:/AI/projects/vision_model_pretrained/validation/content/memory_image_23330.jpg'

    process(model, image, tokenizer, device)

# start()


def main():
    st.title("Image Captioning with Transformer Models")
    image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
    if image is not None:
        accelerator = Accelerator()
        device = accelerator.device

        config = get_config()
        tokenizer = get_or_build_tokenizer(config)
        model = get_model(config, len(tokenizer))
        model = accelerator.prepare(model)
        # NOTE: this checkpoint path is local; the weights load only if the same
        # path exists wherever the app runs
        accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')

        text_output = process(model, image, tokenizer, device)
        st.write(text_output)


if __name__ == "__main__":
    main()
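
# To try the app locally (assuming config.py, model.py, and the checkpoint
# referenced above are available), it can be launched with Streamlit's CLI:
#
#   streamlit run app.py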