zamborg commited on
Commit
5d3b8a6
·
1 Parent(s): d6fffee

added caching to loading

Browse files
Files changed (2) hide show
  1. app.py +9 -7
  2. model.py +10 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import io
3
  import sys
4
  import time
5
  import json
 
6
  sys.path.append("./virtex/")
7
 
8
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
@@ -12,7 +13,13 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
12
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
13
  st.header("Predicted Caption:\n\n")
14
  # st.subheader(f"r/{subreddit}:\t{caption}\n")
15
- st.markdown(f"## r/ {subreddit} \t **{caption}**")
 
 
 
 
 
 
16
 
17
 
18
  st.title("Image Captioning Demo from Redcaps")
@@ -23,12 +30,7 @@ st.sidebar.markdown(
23
  )
24
 
25
  with st.spinner("Loading Model"):
26
- from model import *
27
- sample_images = get_samples()
28
- virtexModel = VirTexModel()
29
- imageLoader = ImageLoader()
30
- valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
31
- valid_subs.insert(0, None)
32
 
33
  random_image = get_rand_img(sample_images)
34
  rand_idx = 0
 
3
  import sys
4
  import time
5
  import json
6
+ from model import *
7
  sys.path.append("./virtex/")
8
 
9
  def gen_show_caption(sub_prompt=None, cap_prompt = ""):
 
13
  subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
14
  st.header("Predicted Caption:\n\n")
15
  # st.subheader(f"r/{subreddit}:\t{caption}\n")
16
+ st.markdown(
17
+ f"""
18
+ r/{subreddit}
19
+ \t
20
+ **{caption}**
21
+ """
22
+ )
23
 
24
 
25
  st.title("Image Captioning Demo from Redcaps")
 
30
  )
31
 
32
  with st.spinner("Loading Model"):
33
+ virtexModel, imageLoader, sample_images, valid_subs = create_objects()
 
 
 
 
 
34
 
35
  random_image = get_rand_img(sample_images)
36
  rand_idx = 0
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from huggingface_hub import hf_hub_url, cached_download
2
  from PIL import Image
3
  import os
@@ -127,3 +128,12 @@ def get_rand_img(samples):
127
  i = random.randint(0,len(samples)-1)
128
  return i, samples[i]
129
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  from huggingface_hub import hf_hub_url, cached_download
3
  from PIL import Image
4
  import os
 
128
  i = random.randint(0,len(samples)-1)
129
  return i, samples[i]
130
 
131
+ @st.cache
132
+ def create_objects():
133
+ sample_images = get_samples()
134
+ virtexModel = VirTexModel()
135
+ imageLoader = ImageLoader()
136
+ valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
137
+ valid_subs.insert(0, None)
138
+ return virtexModel, imageLoader, sample_images, valid_subs
139
+