zamborg committed
Commit 5650fb4
1 Parent(s): 0de1d2d

update model to reflect dev changes

Files changed (2):
  1. app.py +41 -36
  2. model.py +38 -3
app.py CHANGED
@@ -8,31 +8,33 @@ from model import *
 
 # # TODO:
 # - Reformat the model introduction
-# - Center the images using the 3 column method
 # - Make the iterative text generation
 
 def gen_show_caption(sub_prompt=None, cap_prompt = ""):
     with st.spinner("Generating Caption"):
-        if sub_prompt is None and cap_prompt is not "":
-            st.write("Without a specified subreddit we default to /r/pics")
-        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt = cap_prompt)
+        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
         st.markdown(
             f"""
             <style>
                 red{{
                     color:#c62828
                 }}
+                blue{{
+                    color:#2a72d5
+                }}
                 mono{{
                     font-family: "Inconsolata";
                 }}
             </style>
 
-            ### <red> r/{subreddit} </red> {caption}
+            ### <red> r/{subreddit} </red> <blue> {cap_prompt} </blue> {caption}
             """,
             unsafe_allow_html=True)
 
-st.title("Image Captioning Demo from RedCaps")
+_, center, _ = st.columns([1,8,1])
 
+with center:
+    st.title("Image Captioning Demo from RedCaps")
 st.sidebar.markdown(
     """
     ### Image Captioning Model from VirTex trained on RedCaps
@@ -41,20 +43,15 @@ st.sidebar.markdown(
     You can also generate captions as if they are from specific subreddits,
     as if they start with a particular prompt, or even both.
 
-    Share your results on twitter with #redcaps or with a friend.
+    Share your results on twitter with #redcaps or with a friend*.
     """
 )
+# st.markdown(footer,unsafe_allow_html=True)
 
 with st.spinner("Loading Model"):
     virtexModel, imageLoader, sample_images, valid_subs = create_objects()
 
 
-# staggered = st.sidebar.checkbox("Iteratively Generate Captions")
-
-# if staggered:
-#     pass
-# else:
-
 select_idx = None
 
 st.sidebar.title("Select a sample image")
@@ -102,38 +99,46 @@ cap_prompt = st.sidebar.text_input(
 
 _ = st.sidebar.button("Regenerate Caption")
 
-advanced = st.sidebar.checkbox("Advanced Options")
-num_captions=1
-if advanced:
-    num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
-    nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
-    virtexModel.model.decoder.nucleus_size = nuc_size
+
+st.sidebar.write("Advanced Options:")
+num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
+nuc_size = st.sidebar.slider("Nucelus Size:\nLarger values lead to more diverse captions", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
+virtexModel.model.decoder.nucleus_size = nuc_size
 
-if False: #uploaded_image is None:# and submitted:
-    st.write("Please select a file to upload")
+image_file = sample_image
 
-else:
-    image_file = sample_image
-
-    # LOAD AND CACHE THE IMAGE
-    if uploaded_image is not None:
-        image = uploaded_image
-    elif select_idx is None and 'image' in st.session_state:
-        image = st.session_state['image']
-    else:
-        image = Image.open(image_file)
-
-    image = image.convert("RGB")
-
-    st.session_state['image'] = image
-
-
-    image_dict = imageLoader.transform(image)
-
-    show_image = imageLoader.show_resize(image)
-
-    show = st.image(show_image)
-    show.image(show_image, "Your Image")
-
-    for i in range(num_captions):
-        gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
+# LOAD AND CACHE THE IMAGE
+if uploaded_image is not None:
+    image = uploaded_image
+elif select_idx is None and 'image' in st.session_state:
+    image = st.session_state['image']
+else:
+    image = Image.open(image_file)
+
+image = image.convert("RGB")
+
+st.session_state['image'] = image
+
+
+image_dict = imageLoader.transform(image)
+
+show_image = imageLoader.show_resize(image)
+
+with center:
+    show = st.image(show_image)
+    show.image(show_image)
+
+if sub is None and imageLoader.text_transform(cap_prompt) is not "":
+    st.write("Without a specified subreddit we default to /r/pics")
+for i in range(num_captions):
+    gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
+
+st.sidebar.markdown(
+    """
+    *Please note that this model was explicitly not trained on images of people, and as a result is not designed to caption images with humans.
+
+    This demo accompanies our paper RedCaps.
+
+    Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
+    """
+)
model.py CHANGED
@@ -92,7 +92,6 @@ class VirTexModel():
92
  subreddit_tokens = torch.cat(
93
  [
94
  subreddit_tokens,
95
- torch.tensor([self.tokenizer.token_to_id("[SEP]")], device=self.device).long(),
96
  cap_tokens
97
  ])
98
 
@@ -118,11 +117,14 @@ class VirTexModel():
118
  subreddit = "".join(subreddit.split())
119
  rest_of_caption = rest_of_caption.strip()
120
  else:
121
- subreddit, rest_of_caption = "", caption
 
 
 
 
122
 
123
  is_valid_subreddit = subreddit in self.valid_subs
124
 
125
-
126
  return subreddit, rest_of_caption
127
 
128
  def download_files():
@@ -147,3 +149,36 @@ def create_objects():
147
  valid_subs.insert(0, None)
148
  return virtexModel, imageLoader, sample_images, valid_subs
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  subreddit_tokens = torch.cat(
93
  [
94
  subreddit_tokens,
 
95
  cap_tokens
96
  ])
97
 
 
117
  subreddit = "".join(subreddit.split())
118
  rest_of_caption = rest_of_caption.strip()
119
  else:
120
+ subreddit, rest_of_caption = "", caption.strip()
121
+
122
+ # split prompt for coloring:
123
+ if prompt is not "":
124
+ _, rest_of_caption = caption.split(prompt.strip())
125
 
126
  is_valid_subreddit = subreddit in self.valid_subs
127
 
 
128
  return subreddit, rest_of_caption
129
 
130
  def download_files():
 
149
  valid_subs.insert(0, None)
150
  return virtexModel, imageLoader, sample_images, valid_subs
151
 
152
+ footer="""<style>
153
+ a:link , a:visited{
154
+ color: blue;
155
+ background-color: transparent;
156
+ text-decoration: underline;
157
+ }
158
+
159
+ a:hover, a:active {
160
+ color: red;
161
+ background-color: transparent;
162
+ text-decoration: underline;
163
+ }
164
+
165
+ .footer {
166
+ position: fixed;
167
+ left: 0;
168
+ bottom: 0;
169
+ width: 100%;
170
+ background-color: white;
171
+ color: black;
172
+ text-align: center;
173
+ }
174
+ </style>
175
+ <div class="footer">
176
+ <p>
177
+ *Please note that this model was explicitly not trained on images of people, and as a result is not designed to caption images with humans.
178
+
179
+ This demo accompanies our paper RedCaps.
180
+
181
+ Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
182
+ </p>
183
+ </div>
184
+ """