ydshieh committed
Commit: c951094
Parent: 5bedd3a

update model

Files changed (2):
  1. app.py +17 -13
  2. model.py +34 -50
app.py CHANGED
@@ -1,19 +1,17 @@
import streamlit as st
- from PIL import Image
- import numpy as np


# Designing the interface
- st.title("🖼️ French Image Captioning Demo 📝")
+ st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

st.sidebar.markdown(
"""
- An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model and a French GPT2 model.
+ An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model with the GPT2 model.
[Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
- The GPT2 model source code is modified so it can accept an encoder's output.
- The pretained weights of both models are loaded, with a set of randomly initialized cross-attention weigths.
- The model is trained on 65000 images from the COCO dataset for about 1500 steps (batch_size=256), with the original English cpationis being translated to French for training purpose.
+ The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' `FlaxVisionEncoderDecoderModel`.
+ The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
+ The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
"""
)

@@ -21,6 +19,7 @@ st.sidebar.markdown(
#show = st.image(image, use_column_width=True)
#show.image(image, 'Preloaded Image', use_column_width=True)

+
with st.spinner('Loading and compiling ViT-GPT2 model ...'):

    from model import *
@@ -43,16 +42,21 @@ show.image(image, '\n\nSelected Image', width=480)
# For newline
st.sidebar.write('\n')

+
with st.spinner('Generating image caption ...'):

    caption = predict(image)
+
+     caption_en = caption
+     st.header(f'**Prediction (in English) **{caption_en}')

-     caption_en = translator.translate(caption, src='fr', dest='en').text
-     st.header(f'**Prediction (in French) **{caption}')
-     st.header(f'**English Translation**: {caption_en}')
+     # caption_en = translator.translate(caption, src='fr', dest='en').text
+     # st.header(f'**Prediction (in French) **{caption}')
+     # st.header(f'**English Translation**: {caption_en}')
+

st.sidebar.header("ViT-GPT2 predicts:")
- st.sidebar.write(f"**French**: {caption}")
- st.sidebar.write(f"**English Translation**: {caption_en}")
+ st.sidebar.write(f"**English**: {caption}")
+

- image.close()
+ image.close()
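
The updated sidebar text describes the architecture: a pretrained ViT encoder and a pretrained GPT2 decoder joined by cross-attention weights that start out randomly initialized. A minimal sketch of how such a model can be assembled with transformers' FlaxVisionEncoderDecoderModel (the ViT checkpoint is the one the previous model.py loaded; using the plain gpt2 decoder checkpoint here is an assumption for illustration):

from transformers import FlaxVisionEncoderDecoderModel

# Combine a pretrained vision encoder with a pretrained text decoder.
# The GPT2 checkpoint has no cross-attention weights, so those layers are
# randomly initialized and have to be learned during captioning fine-tuning.
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",  # encoder (checkpoint used by the previous model.py)
    "gpt2",                               # decoder (illustrative choice)
)

Fine-tuning on image-caption pairs is what trains those new cross-attention layers; the checkpoint this commit switches to (ydshieh/vit-gpt2-coco-en) is the result of such training on COCO 2017.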
model.py CHANGED
@@ -1,83 +1,67 @@
- import os, sys, shutil
- import numpy as np
+ import os, shutil
from PIL import Image
-
import jax
- from transformers import ViTFeatureExtractor
- from transformers import GPT2Tokenizer
+ from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from huggingface_hub import hf_hub_download

from googletrans import Translator
translator = Translator()

- current_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(current_path)
-
- # Main model - ViTGPT2LM
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

# create target model directory
model_dir = './models/'
os.makedirs(model_dir, exist_ok=True)
- # copy config file
- filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/config.json")
- shutil.copyfile(filepath, os.path.join(model_dir, 'config.json'))
- # copy model file
- filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/flax_model.msgpack")
- shutil.copyfile(filepath, os.path.join(model_dir, 'flax_model.msgpack'))
-
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
-
- vit_model_name = 'google/vit-base-patch16-224-in21k'
- feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)

- gpt2_model_name = 'asi/gpt-fr-cased-small'
- tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
-
- max_length = 32
- num_beams = 8
+ files_to_download = [
+     "config.json",
+     "flax_model.msgpack",
+     "merges.txt",
+     "special_tokens_map.json",
+     "tokenizer.json",
+     "tokenizer_config.json",
+     "vocab.json",
+ ]
+
+ # copy files from checkpoint hub:
+ for fn in files_to_download:
+     file_path = hf_hub_download("ydshieh/vit-gpt2-coco-en", f"ckpt_epoch_3_step_6900/{fn}")
+     shutil.copyfile(file_path, os.path.join(model_dir, fn))
+
+ model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+ max_length = 16
+ num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


@jax.jit
- def predict_fn(pixel_values):
+ def generate(pixel_values):
+     output_ids = model.generate(pixel_values, **gen_kwargs).sequences
+     return output_ids

-     return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)

def predict(image):

-     # batch dim is added automatically
-     encoder_inputs = feature_extractor(images=image, return_tensors="jax")
-     pixel_values = encoder_inputs.pixel_values
-
-     # generation
-     generation = predict_fn(pixel_values)
+     pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+     output_ids = generate(pixel_values)
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]

-     token_ids = np.array(generation.sequences)[0]
-     caption = tokenizer.decode(token_ids)
-     caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
-     caption = caption.replace("à l'arrière-plan", '').replace("Une photo en noir et blanc d'", '').replace("Une photo noire et blanche d'", '').replace("en arrière-plan", '').replace("Un gros plan d'", '').replace("un gros plan d'", '').replace("Une image d'", '')
-     while '  ' in caption:
-         caption = caption.replace('  ', ' ')
-     caption = caption.strip()
-     if caption:
-         caption = caption[0].upper() + caption[1:]
+     return preds[0]

-     return caption

- def compile():
+ def _compile():

    image_path = 'samples/val_000000039769.jpg'
    image = Image.open(image_path)
-
    caption = predict(image)
    image.close()

- def predict_dummy(image):
-
-     return 'dummy caption!'

- compile()
+ _compile()
+

sample_dir = './samples/'
sample_fns = tuple([f"{int(f.replace('COCO_val2014_', '').replace('.jpg', ''))}.jpg" for f in os.listdir(sample_dir) if f.startswith('COCO_val2014_')])
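
The new model.py keeps the pattern of calling _compile() at import time: because generate is wrapped in @jax.jit, its first call triggers tracing and XLA compilation, and running it once on a sample image moves that cost into the "Loading and compiling ViT-GPT2 model ..." spinner instead of the first user request. A standalone sketch of that warm-up behaviour, using a toy jitted function rather than the real model (names and shapes here are illustrative only):

import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_generate(pixel_values):
    # Stand-in for model.generate: any traced computation shows the same effect.
    return jnp.tanh(pixel_values).sum()

dummy = jnp.zeros((1, 3, 224, 224))  # shape of one preprocessed image batch

start = time.time()
toy_generate(dummy).block_until_ready()  # first call: trace + compile + run
compile_time = time.time() - start

start = time.time()
toy_generate(dummy).block_until_ready()  # later calls reuse the compiled code
run_time = time.time() - start

print(f"first call: {compile_time:.3f}s, second call: {run_time:.3f}s")

Because jax.jit specializes on input shapes, the warm-up only pays off if later inputs have the same pixel_values shape, which ViTFeatureExtractor guarantees by resizing every image to a fixed resolution.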