Spaces:
Runtime error
Runtime error
nucleus size updates in advanced"
Browse files
app.py
CHANGED
@@ -6,6 +6,11 @@ import json
|
|
6 |
sys.path.append("./virtex/")
|
7 |
from model import *
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
def gen_show_caption(sub_prompt=None, cap_prompt = ""):
|
10 |
with st.spinner("Generating Caption"):
|
11 |
if sub_prompt is None and cap_prompt is not "":
|
@@ -102,6 +107,7 @@ num_captions=1
|
|
102 |
if advanced:
|
103 |
nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
|
104 |
num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
|
|
|
105 |
|
106 |
if uploaded_image is None and submitted:
|
107 |
st.write("Please select a file to upload")
|
@@ -130,14 +136,4 @@ else:
|
|
130 |
show.image(show_image, "Your Image")
|
131 |
|
132 |
for i in range(num_captions):
|
133 |
-
gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
|
134 |
-
|
135 |
-
# from model import *
|
136 |
-
# sample_images = get_samples()
|
137 |
-
# v, il = VirTexModel(), ImageLoader()
|
138 |
-
|
139 |
-
# for s in sample_images:
|
140 |
-
# subreddit, caption = v.predict(il.load(s))
|
141 |
-
# print("=====================")
|
142 |
-
# print(subreddit)
|
143 |
-
# print(caption)
|
|
|
6 |
sys.path.append("./virtex/")
|
7 |
from model import *
|
8 |
|
9 |
+
# # TODO:
|
10 |
+
# - Reformat the model introduction
|
11 |
+
# - Center the images using the 3 column method
|
12 |
+
# - Make the iterative text generation
|
13 |
+
|
14 |
def gen_show_caption(sub_prompt=None, cap_prompt = ""):
|
15 |
with st.spinner("Generating Caption"):
|
16 |
if sub_prompt is None and cap_prompt is not "":
|
|
|
107 |
if advanced:
|
108 |
nuc_size = st.sidebar.slider("Nucelus Size:", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
|
109 |
num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
|
110 |
+
virtexModel.model.decoder.nucleus_size = nuc_size
|
111 |
|
112 |
if uploaded_image is None and submitted:
|
113 |
st.write("Please select a file to upload")
|
|
|
136 |
show.image(show_image, "Your Image")
|
137 |
|
138 |
for i in range(num_captions):
|
139 |
+
gen_show_caption(sub, imageLoader.text_transform(cap_prompt))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.py
CHANGED
@@ -61,7 +61,7 @@ class VirTexModel():
|
|
61 |
self.device = 'cpu'
|
62 |
self.tokenizer = TokenizerFactory.from_config(self.config)
|
63 |
self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
|
64 |
-
CheckpointManager(model=self.model).load(
|
65 |
self.model.eval()
|
66 |
self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
|
67 |
|
|
|
61 |
self.device = 'cpu'
|
62 |
self.tokenizer = TokenizerFactory.from_config(self.config)
|
63 |
self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
|
64 |
+
CheckpointManager(model=self.model).load(MODEL_PATH)
|
65 |
self.model.eval()
|
66 |
self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
|
67 |
|