Soumen committed
Commit 4b17ef5
1 Parent(s): 049e141

Update app.py

Files changed (1): app.py (+29, -27)
app.py CHANGED
@@ -2,52 +2,54 @@ import streamlit as st
 import torch
 from PIL import Image
 from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
-st.title("Image_Captioning_App")
-model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+@st.cache
+def load_models():
+    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+    feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+    return model, feature_extractor, tokenizer
 #pickle.load(open('energy_model.pkl', 'rb'))
 #vocab = np.load('w2i.p', allow_pickle=True)
+st.title("Image_Captioning_App")
+
 #st.text("Build with Streamlit and OpenCV")
 if "photo" not in st.session_state:
     st.session_state["photo"]="not done"
-
 c2, c3 = st.columns([2,1])
 def change_photo_state():
     st.session_state["photo"]="done"
-print("="*150)
-print("RESNET MODEL LOADED")
-
 @st.cache
 def load_image(img):
     im = Image.open(img)
     return im
 uploaded_photo = c2.file_uploader("Upload Image",type=['jpg','png','jpeg'], on_change=change_photo_state)
 camera_photo = c2.camera_input("Take a photo", on_change=change_photo_state)
-st.subheader("The following activity is occurring ...")
-if st.session_state["photo"]=="done":
-    if uploaded_photo:
-        our_image = load_image(uploaded_photo)
-    elif camera_photo:
-        our_image = load_image(camera_photo)
-    elif uploaded_photo==None and camera_photo==None:
-        our_image = load_image('image.jpg')
+#st.subheader("Detection")
+if st.checkbox("Generate_Caption"):
+    model, feature_extractor, tokenizer = load_models()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
     max_length = 16
     num_beams = 4
     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
     def predict_step(our_image):
-        if our_image.mode != "RGB":
-            our_image = our_image.convert(mode="RGB")
-        pixel_values = feature_extractor(images=our_image, return_tensors="pt").pixel_values
-        pixel_values = pixel_values.to(device)
-        output_ids = model.generate(pixel_values, **gen_kwargs)
-        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        preds = [pred.strip() for pred in preds]
-        return preds
-    st.success(predict_step(our_image))
-if st.checkbox('About'):
+        if our_image.mode != "RGB":
+            our_image = our_image.convert(mode="RGB")
+        pixel_values = feature_extractor(images=our_image, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(device)
+        output_ids = model.generate(pixel_values, **gen_kwargs)
+        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        preds = [pred.strip() for pred in preds]
+        return preds
+    if st.session_state["photo"]=="done":
+        if uploaded_photo:
+            our_image = load_image(uploaded_photo)
+        elif camera_photo:
+            our_image = load_image(camera_photo)
+        elif uploaded_photo==None and camera_photo==None:
+            our_image = load_image('image.jpg')
+        st.success(predict_step(our_image))
+elif st.checkbox("About"):
     st.subheader("About Image Captioning App")
-    st.markdown("Built with Streamlit by [Soumen Sarker](https://soumen-sarker-personal-site.streamlit.app/)")
+    st.markdown("Built with Streamlit by [Soumen Sarker](https://soumen-sarker-personal-website.streamlit.app/)")
    st.markdown("Demo application of the following model [credit](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning/)")