Xhaheen committed on
Commit
eb35bd7
1 Parent(s): 27d5ee1

Upload app.py

Files changed (1)
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ import torch  # backend used by the transformers model
+ import gradio as gr
+ from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
+
+
+ def create_caption_transformer(img):
+     """
+     Create a caption for an image using a transformer model
+     that was trained on the Flickr image dataset.
+     :param img: a numpy array of the image
+     :return: a string of the image caption
+     """
+     # Preprocess the image, generate token ids, decode, and keep the first sentence.
+     sample = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
+     caption_ids = model.generate(sample)[0]
+     caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
+     caption_text = caption_text.split('.')[0]
+     return caption_text
+
+
+ # Gather the bundled example images for the Gradio demo.
+ IMAGES_EXAMPLES_FOLDER = 'examples/'
+ images = os.listdir(IMAGES_EXAMPLES_FOLDER)
+ IMAGES_EXAMPLES = [IMAGES_EXAMPLES_FOLDER + img for img in images]
+ model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
+ feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
+ iface = gr.Interface(fn=create_caption_transformer,
+                      inputs="image",
+                      outputs='text',
+                      examples=IMAGES_EXAMPLES
+                      ).launch(share=True)
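
For context, a minimal sketch of how the same captioning pipeline could be smoke-tested outside the Gradio interface. This is not part of the commit; the PIL/numpy usage and the 'examples/dog.jpg' path are illustrative assumptions, and it presumes the fine-tuned weights sit in the working directory as in app.py.

import os
import numpy as np
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer

# Load the same weights and preprocessors that app.py loads.
model = VisionEncoderDecoderModel.from_pretrained(os.getcwd()).to('cpu')
feature_extractor = AutoFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# 'examples/dog.jpg' is a hypothetical path; substitute any bundled example image.
img = np.array(Image.open('examples/dog.jpg').convert('RGB'))
pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to('cpu')
caption_ids = model.generate(pixel_values)[0]
print(tokenizer.decode(caption_ids, skip_special_tokens=True).split('.')[0])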