Lwasinam committed
Commit 63dd1a3 · verified · 1 Parent(s): cb0caa0

Upload app.py

Files changed (1):
app.py: +182, -0
app.py ADDED
@@ -0,0 +1,182 @@
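+ # Streamlit demo for image captioning: a ViT feature extractor turns the
+ # uploaded image into pixel values, a custom transformer (built in model.py)
+ # greedily decodes a caption, and a GPT-2 tokenizer maps token ids back to text.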
+ import base64
+ import torch
+ import torch.nn.functional as F
+ import streamlit as st
+ from PIL import Image
+ from accelerate import Accelerator
+ from transformers import ViTFeatureExtractor, GPT2TokenizerFast
+
+ from config import get_config
+ from model import build_transformer
+
+
+ # Run the model on one image and return the decoded caption
+ def process(model, image, tokenizer, device):
+     image = get_image(image)
+     model.eval()
+     with torch.no_grad():
+         encoder_input = image.unsqueeze(0).to(device)  # (1, 3, 224, 224)
+         model_out = greedy_decode(model, encoder_input, None, tokenizer, 196, device)
+         model_text = tokenizer.decode(model_out.detach().cpu().numpy())
+     return model_text
+
+
+ # Convert an uploaded image (file path or file-like object) into ViT pixel values
+ def get_image(image):
+     model_id = 'google/vit-base-patch16-224-in21k'
+     feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
+
+     image = Image.open(image)
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+
+     enc_input = feature_extractor(image, return_tensors='pt')
+     # pixel_values is (1, 3, 224, 224); a single squeeze drops the batch dim
+     return enc_input['pixel_values'].squeeze(0)
+
+
+ # Get the GPT-2 tokenizer with custom special tokens
+ def get_or_build_tokenizer(config):
+     tokenizer = GPT2TokenizerFast.from_pretrained(
+         "openai-community/gpt2",
+         unk_token='[UNK]', bos_token='[SOS]', eos_token='[EOS]', pad_token='[PAD]'
+     )
+     return tokenizer
+
+
+ # Mask that lets each position attend only to itself and earlier positions
+ def causal_mask(size):
+     mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
+     return mask == 0
+
+
+ # Build the captioning transformer (defined in model.py)
+ def get_model(config, vocab_tgt_len):
+     model = build_transformer(vocab_tgt_len, config['seq_len'], d_model=config['d_model'])
+     return model
+
+
+ # Greedy decoding: start from [SOS] and repeatedly pick the most likely next token
+ def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
+     sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
+     eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')
+
+     # Precompute the encoder output and reuse it for every step
+     encoder_output = model.encode(source, None)
+
+     # Initialize the decoder input with the sos token
+     decoder_input = torch.empty(1, 1).fill_(sos_idx).long().to(device)
+     while True:
+         if decoder_input.size(1) == max_len:
+             break
+
+         # Build the causal mask for the tokens generated so far
+         decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)
+
+         # Decode one step
+         out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
+
+         # Project the last position to vocabulary logits
+         logits = model.project(out[:, -1])
+         probabilities = F.softmax(logits, dim=-1)
+
+         # Greedily select the next token
+         next_word = torch.argmax(probabilities, dim=1)
+
+         # Append it to the decoder input
+         decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
+
+         if next_word.item() == eos_idx:
+             break
+
+     return decoder_input.squeeze(0)
+
+
+ # Encode an image file on disk as a base64 string (helper, not used by the Streamlit flow)
+ def image_base64(image_path):
+     with open(image_path, 'rb') as image_file:
+         base64_bytes = base64.b64encode(image_file.read())
+     return base64_bytes.decode()
+
+
+ # Local smoke test (not used by the Streamlit app)
+ def start():
+     print('start')
+     accelerator = Accelerator()
+     device = accelerator.device
+
+     config = get_config()
+     tokenizer = get_or_build_tokenizer(config)
+     model = get_model(config, len(tokenizer))
+     model = accelerator.prepare(model)
+     accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')
+
+     # get_image() opens the image with PIL, so the raw path is enough here
+     image = 'C:/AI/projects/vision_model_pretrained/validation/content/memory_image_23330.jpg'
+     print(process(model, image, tokenizer, device))
+
+ # start()
+
+
+ def main():
+     st.title("Image Captioning with Transformer Models")
+     image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
+     if image is not None:
+         accelerator = Accelerator()
+         device = accelerator.device
+
+         config = get_config()
+         tokenizer = get_or_build_tokenizer(config)
+         model = get_model(config, len(tokenizer))
+         model = accelerator.prepare(model)
+         accelerator.load_state('C:/AI/projects/vision_model_pretrained/Vision_Model_pretrained/models/vision_model_04')
+
+         text_output = process(model, image, tokenizer, device)
+         st.write(text_output)
+
+
+ if __name__ == "__main__":
+     main()