updates:
- Rename redCaps
- naming fix
- allow png

Files changed:
- app.py +25 -11
- model.py +3 -3
- requirements.txt +1 -0
app.py CHANGED
@@ -22,10 +22,16 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
     )
 
 
-st.title("Image Captioning Demo from redCaps")
+st.title("Image Captioning Demo from RedCaps")
 st.sidebar.markdown(
     """
-    Image Captioning Model from VirTex trained on redCaps
+    ### Image Captioning Model from VirTex trained on RedCaps
+
+    Use this page to caption your own images or try out some of our samples.
+    You can also generate captions as if they are from specific subreddits,
+    as if they start with a particular prompt, or even both.
+
+    Feel free to share your results on twitter with #redcaps or with a friend.
     """
 )
 
@@ -48,6 +54,15 @@ else:
 
     sample_image = sample_images[0 if select_idx is None else select_idx]
 
+
+    uploaded_image = None
+    with st.sidebar.form("file-uploader-form", clear_on_submit=True):
+        uploaded_file = st.file_uploader("Choose a file")
+        submitted = st.form_submit_button("Submit")
+        if uploaded_file is not None and submitted:
+            uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
+            select_idx = None # set this to help rewrite the cache
+
     # class OnChange():
     #     def __init__(self, idx):
     #         self.idx = idx
@@ -75,16 +90,12 @@ else:
         value=""
     )
 
-
-    uploaded_image = None
-    with st.sidebar.form("file-uploader-form", clear_on_submit=True):
-        uploaded_file = st.file_uploader("Choose a file")
-        submitted = st.form_submit_button("Submit")
-        if uploaded_file is not None and submitted:
-            uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
-            select_idx = None # set this to help rewrite the cache
-
     _ = st.sidebar.button("Regenerate Caption")
+
+    # advanced = st.sidebar.checkbox("Advanced Options")
+
+    # if advanced:
+    #     nuc_size = st.sidebar.slider("")
 
     if uploaded_image is None and submitted:
         st.write("Please select a file to upload")
@@ -100,8 +111,11 @@ else:
     else:
         image = Image.open(image_file)
 
+    image = image.convert("RGB")
+
     st.session_state['image'] = image
 
+
    image_dict = imageLoader.transform(image)
 
     show_image = imageLoader.show_resize(image)
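Two notes on the app.py changes. Moving the file-uploader form earlier in the script changes where it renders, since Streamlit draws sidebar widgets in call order. And the new image.convert("RGB") call is what makes the "allow png" fix safe: PNGs commonly decode as RGBA or palette ("P") mode images, which would hand the 3-channel captioning pipeline a tensor of the wrong shape. A minimal sketch of the idea (the load_rgb helper is illustrative, not part of the commit):

import io
from PIL import Image

def load_rgb(data: bytes) -> Image.Image:
    """Decode uploaded bytes and force three RGB channels."""
    image = Image.open(io.BytesIO(data))
    # A transparent PNG reports mode "RGBA"; convert() drops the
    # alpha channel so downstream transforms see a 3-channel image.
    return image.convert("RGB")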
model.py CHANGED
@@ -22,7 +22,7 @@ SAMPLES_PATH = "./samples/*.jpg"
 
 class ImageLoader():
     def __init__(self):
-        self.transform = torchvision.transforms.Compose([
+        self.image_transform = torchvision.transforms.Compose([
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Resize(256),
             torchvision.transforms.CenterCrop(224),
@@ -30,7 +30,7 @@ class ImageLoader():
         self.show_size=500
 
     def load(self, im_path):
-        im = torch.FloatTensor(self.transform(Image.open(im_path))).unsqueeze(0)
+        im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
         return {"image": im}
 
     def raw_load(self, im_path):
@@ -38,7 +38,7 @@ class ImageLoader():
         return {"image": im}
 
     def transform(self, image):
-        im = torch.FloatTensor(self.transform(image)).unsqueeze(0)
+        im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
         return {"image": im}
 
     def text_transform(self, text):
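The model.py rename is more than cosmetic. Assuming the old attribute was self.transform (which the "naming fix" note suggests), assigning it in __init__ shadows the transform method, because instance attributes win over class-level functions on attribute lookup; any call like imageLoader.transform(image) would then hit the Compose object instead of the method. A small repro of the pattern, with hypothetical names:

class Demo:
    def __init__(self):
        # Instance attribute with the same name as the method below...
        self.transform = lambda x: x + 1

    def transform(self, x):
        # ...so this method is unreachable through an instance.
        return x * 2

d = Demo()
print(d.transform(3))  # prints 4 (the lambda), not 6 (the method)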
requirements.txt CHANGED
@@ -14,4 +14,5 @@ torch==1.7.0
 torchvision==0.8
 tqdm>=4.50.0
 wordsegment==1.3.1
+whatimage==0.0.3
 git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
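The new whatimage pin presumably supports the "allow png" change by sniffing the format of uploaded bytes via its identify_image function, which takes raw bytes and returns a lowercase format string (or None if unrecognized). A sketch of typical usage under that assumption — the commit itself doesn't show the call site:

import whatimage

def is_supported(data: bytes) -> bool:
    fmt = whatimage.identify_image(data)  # e.g. "png", "jpeg", or None
    return fmt in ("png", "jpeg")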