zamborg committed on
Commit
65193db
·
1 Parent(s): 7332d54

- Rename redCaps
- naming fix
- allow png

Files changed (3) hide show
  1. app.py +25 -11
  2. model.py +3 -3
  3. requirements.txt +1 -0
app.py CHANGED
@@ -22,10 +22,16 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
22
  )
23
 
24
 
25
- st.title("Image Captioning Demo from Redcaps")
26
  st.sidebar.markdown(
27
  """
28
- Image Captioning Model from VirTex trained on Redcaps
 
 
 
 
 
 
29
  """
30
  )
31
 
@@ -48,6 +54,15 @@ else:
48
 
49
  sample_image = sample_images[0 if select_idx is None else select_idx]
50
 
 
 
 
 
 
 
 
 
 
51
  # class OnChange():
52
  # def __init__(self, idx):
53
  # self.idx = idx
@@ -75,16 +90,12 @@ else:
75
  value=""
76
  )
77
 
78
-
79
- uploaded_image = None
80
- with st.sidebar.form("file-uploader-form", clear_on_submit=True):
81
- uploaded_file = st.file_uploader("Choose a file")
82
- submitted = st.form_submit_button("Submit")
83
- if uploaded_file is not None and submitted:
84
- uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
85
- select_idx = None # set this to help rewrite the cache
86
-
87
  _ = st.sidebar.button("Regenerate Caption")
 
 
 
 
 
88
 
89
  if uploaded_image is None and submitted:
90
  st.write("Please select a file to upload")
@@ -100,8 +111,11 @@ else:
100
  else:
101
  image = Image.open(image_file)
102
 
 
 
103
  st.session_state['image'] = image
104
 
 
105
  image_dict = imageLoader.transform(image)
106
 
107
  show_image = imageLoader.show_resize(image)
 
22
  )
23
 
24
 
25
+ st.title("Image Captioning Demo from RedCaps")
26
  st.sidebar.markdown(
27
  """
28
+ ### Image Captioning Model from VirTex trained on RedCaps
29
+
30
+ Use this page to caption your own images or try out some of our samples.
31
+ You can also generate captions as if they are from specific subreddits,
32
+ as if they start with a particular prompt, or even both.
33
+
34
+ Feel free to share your results on twitter with #redcaps or with a friend.
35
  """
36
  )
37
 
 
54
 
55
  sample_image = sample_images[0 if select_idx is None else select_idx]
56
 
57
+
58
+ uploaded_image = None
59
+ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
60
+ uploaded_file = st.file_uploader("Choose a file")
61
+ submitted = st.form_submit_button("Submit")
62
+ if uploaded_file is not None and submitted:
63
+ uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
64
+ select_idx = None # set this to help rewrite the cache
65
+
66
  # class OnChange():
67
  # def __init__(self, idx):
68
  # self.idx = idx
 
90
  value=""
91
  )
92
 
 
 
 
 
 
 
 
 
 
93
  _ = st.sidebar.button("Regenerate Caption")
94
+
95
+ # advanced = st.sidebar.checkbox("Advanced Options")
96
+
97
+ # if advanced:
98
+ # nuc_size = st.sidebar.slider("")
99
 
100
  if uploaded_image is None and submitted:
101
  st.write("Please select a file to upload")
 
111
  else:
112
  image = Image.open(image_file)
113
 
114
+ image = image.convert("RGB")
115
+
116
  st.session_state['image'] = image
117
 
118
+
119
  image_dict = imageLoader.transform(image)
120
 
121
  show_image = imageLoader.show_resize(image)
model.py CHANGED
@@ -22,7 +22,7 @@ SAMPLES_PATH = "./samples/*.jpg"
22
 
23
  class ImageLoader():
24
  def __init__(self):
25
- self.transformer = torchvision.transforms.Compose([
26
  torchvision.transforms.ToTensor(),
27
  torchvision.transforms.Resize(256),
28
  torchvision.transforms.CenterCrop(224),
@@ -30,7 +30,7 @@ class ImageLoader():
30
  self.show_size=500
31
 
32
  def load(self, im_path):
33
- im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
34
  return {"image": im}
35
 
36
  def raw_load(self, im_path):
@@ -38,7 +38,7 @@ class ImageLoader():
38
  return {"image": im}
39
 
40
  def transform(self, image):
41
- im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
42
  return {"image": im}
43
 
44
  def text_transform(self, text):
 
22
 
23
  class ImageLoader():
24
  def __init__(self):
25
+ self.image_transform = torchvision.transforms.Compose([
26
  torchvision.transforms.ToTensor(),
27
  torchvision.transforms.Resize(256),
28
  torchvision.transforms.CenterCrop(224),
 
30
  self.show_size=500
31
 
32
  def load(self, im_path):
33
+ im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
34
  return {"image": im}
35
 
36
  def raw_load(self, im_path):
 
38
  return {"image": im}
39
 
40
  def transform(self, image):
41
+ im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
42
  return {"image": im}
43
 
44
  def text_transform(self, text):
requirements.txt CHANGED
@@ -14,4 +14,5 @@ torch==1.7.0
14
  torchvision==0.8
15
  tqdm>=4.50.0
16
  wordsegment==1.3.1
 
17
  git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
 
14
  torchvision==0.8
15
  tqdm>=4.50.0
16
  wordsegment==1.3.1
17
+ whatimage==0.0.3
18
  git+git://github.com/facebookresearch/fvcore.git#egg=fvcore