Spaces:
Runtime error
ydshieh committed
Commit c951094 · 1 Parent(s): 5bedd3a
update model
app.py CHANGED
@@ -1,19 +1,17 @@
 import streamlit as st
-from PIL import Image
-import numpy as np
 
 
 # Designing the interface
-st.title("🖼️
+st.title("🖼️ Image Captioning Demo 📝")
 st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")
 
 st.sidebar.markdown(
     """
-    An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model
+    An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model with the GPT2 model.
     [Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
-    The
-    The
-    The model is trained on
+    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' `FlaxVisionEncoderDecoderModel`.
+    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
+    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
     """
 )
 
@@ -21,6 +19,7 @@ st.sidebar.markdown(
 #show = st.image(image, use_column_width=True)
 #show.image(image, 'Preloaded Image', use_column_width=True)
 
+
 with st.spinner('Loading and compiling ViT-GPT2 model ...'):
 
     from model import *
@@ -43,16 +42,21 @@ show.image(image, '\n\nSelected Image', width=480)
 # For newline
 st.sidebar.write('\n')
 
+
 with st.spinner('Generating image caption ...'):
 
     caption = predict(image)
+
+    caption_en = caption
+st.header(f'**Prediction (in English) **{caption_en}')
 
-    caption_en = translator.translate(caption, src='fr', dest='en').text
-st.header(f'**Prediction (in French) **{caption}')
-st.header(f'**English Translation**: {caption_en}')
+# caption_en = translator.translate(caption, src='fr', dest='en').text
+# st.header(f'**Prediction (in French) **{caption}')
+# st.header(f'**English Translation**: {caption_en}')
+
 
 st.sidebar.header("ViT-GPT2 predicts:")
-st.sidebar.write(f"**
-
+st.sidebar.write(f"**English**: {caption}")
+
 
-image.close()
+image.close()
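The sidebar text above describes gluing a pretrained ViT encoder to a pretrained GPT2 decoder, with only the connecting cross-attention weights randomly initialized. A minimal sketch of that assembly with the transformers Flax API follows; the checkpoint names are assumptions for illustration (`google/vit-base-patch16-224-in21k` is the ViT checkpoint the previous model.py used, and plain `gpt2` is a guess for the decoder), not values taken from this commit:

# Sketch: combine a pretrained image encoder and a pretrained text decoder.
# The encoder-decoder cross-attention is created fresh (randomly initialized),
# which is why the combined model is fine-tuned (here: on COCO 2017) before it
# produces usable captions.
from transformers import FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",  # pretrained ViT encoder
    "gpt2",                               # pretrained GPT2 decoder (assumed)
)
model.save_pretrained("./vit-gpt2/")      # checkpoint can then be trained and shared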
model.py CHANGED
@@ -1,83 +1,67 @@
-import os, sys, shutil
-import numpy as np
+import os, shutil
 from PIL import Image
-
 import jax
-from transformers import ViTFeatureExtractor
-from transformers import GPT2Tokenizer
+from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
 from huggingface_hub import hf_hub_download
 
 from googletrans import Translator
 translator = Translator()
 
-current_path = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(current_path)
-
-# Main model - ViTGPT2LM
-from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
 
 # create target model directory
 model_dir = './models/'
 os.makedirs(model_dir, exist_ok=True)
-# copy config file
-filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/config.json")
-shutil.copyfile(filepath, os.path.join(model_dir, 'config.json'))
-# copy model file
-filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/flax_model.msgpack")
-shutil.copyfile(filepath, os.path.join(model_dir, 'flax_model.msgpack'))
-
-flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
-
-vit_model_name = 'google/vit-base-patch16-224-in21k'
-feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
 
-
-
-
-
-
+files_to_download = [
+    "config.json",
+    "flax_model.msgpack",
+    "merges.txt",
+    "special_tokens_map.json",
+    "tokenizer.json",
+    "tokenizer_config.json",
+    "vocab.json",
+]
+
+# copy files from checkpoint hub:
+for fn in files_to_download:
+    file_path = hf_hub_download("ydshieh/vit-gpt2-coco-en", f"ckpt_epoch_3_step_6900/{fn}")
+    shutil.copyfile(file_path, os.path.join(model_dir, fn))
+
+model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+max_length = 16
+num_beams = 4
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
 
 @jax.jit
-def predict_fn(pixel_values):
-
-    return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
+def generate(pixel_values):
+    output_ids = model.generate(pixel_values, **gen_kwargs).sequences
+    return output_ids
+
 
 def predict(image):
 
-
-
-
-
-    # generation
-    generation = predict_fn(pixel_values)
-
-
-    caption = tokenizer.decode(token_ids)
-    caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
-    caption = caption.replace("à l'arrière-plan", '').replace("Une photo en noir et blanc d'", '').replace("Une photo noire et blanche d'", '').replace("en arrière-plan", '').replace("Un gros plan d'", '').replace("un gros plan d'", '').replace("Une image d'", '')
-    while '  ' in caption:
-        caption = caption.replace('  ', ' ')
-    caption = caption.strip()
-    if caption:
-        caption = caption[0].upper() + caption[1:]
+    pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+    output_ids = generate(pixel_values)
+    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [pred.strip() for pred in preds]
 
-    return caption
+    return preds[0]
 
-def
+
+def _compile():
 
     image_path = 'samples/val_000000039769.jpg'
     image = Image.open(image_path)
-
     caption = predict(image)
     image.close()
 
-def predict_dummy(image):
-
-    return 'dummy caption!'
-
+
+_compile()
+
 
 sample_dir = './samples/'
 sample_fns = tuple([f"{int(f.replace('COCO_val2014_', '').replace('.jpg', ''))}.jpg" for f in os.listdir(sample_dir) if f.startswith('COCO_val2014_')])
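A note on the `_compile()` call added at the bottom of model.py: `jax.jit` traces and compiles a function on its first invocation with a given input shape, so calling `predict` once at import time pays the compilation cost while app.py's 'Loading and compiling ViT-GPT2 model ...' spinner is showing, rather than on the first user request. Because `max_length` and `num_beams` are fixed module-level Python values, the traced computation is compiled once and reused for every later call. A self-contained sketch of this behavior (the function and timings are illustrative, not measured from this Space):

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # stand-in for the jitted generate(pixel_values) above
    return (x * x).sum()

x = jnp.ones((1, 3, 224, 224))  # same shape as a single ViT pixel_values batch

t0 = time.time()
f(x).block_until_ready()        # first call: trace + compile, slow
t1 = time.time()
f(x).block_until_ready()        # later calls: reuse compiled code, fast
t2 = time.time()
print(f"first call {t1 - t0:.3f}s, second call {t2 - t1:.3f}s")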