Spaces:
Runtime error
Runtime error
jiwan-chung
commited on
Commit
•
99aac23
1
Parent(s):
901ff30
backbone choice dropdown
Browse files- app.py +4 -4
- arguments.py +1 -1
- run.py +24 -9
app.py
CHANGED
@@ -39,11 +39,11 @@ for k, v in images.items():
|
|
39 |
prompts = ['blog:', 'dialogue:', 'This is my favorite poem:']
|
40 |
|
41 |
title = 'Demo for ESPER'
|
42 |
-
description =
|
43 |
-
|
44 |
|
45 |
-
examples = [[[v, prompt, 20, False
|
46 |
for v in images.values()]
|
47 |
examples = list(chain(*examples))
|
48 |
|
49 |
-
launch(examples, title=title, description=description,
|
|
|
39 |
prompts = ['blog:', 'dialogue:', 'This is my favorite poem:']
|
40 |
|
41 |
title = 'Demo for ESPER'
|
42 |
+
description = None
|
43 |
+
prompt_eg = f'try pretrained styles such as "blog:" or "dialogue:"\n or unseen prompts such as "{prompts[-1]}"'
|
44 |
|
45 |
+
examples = [[[v, prompt, 20, False] for prompt in prompts]
|
46 |
for v in images.values()]
|
47 |
examples = list(chain(*examples))
|
48 |
|
49 |
+
launch(examples, title=title, description=description, prompt_eg=prompt_eg)
|
arguments.py
CHANGED
@@ -20,7 +20,7 @@ def get_args():
|
|
20 |
parser.add_argument(
|
21 |
'--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
|
22 |
parser.add_argument(
|
23 |
-
'--checkpoint', type=str, default='./data/esper_demo/ckpt
|
24 |
|
25 |
parser.add_argument(
|
26 |
'--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
|
|
|
20 |
parser.add_argument(
|
21 |
'--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
|
22 |
parser.add_argument(
|
23 |
+
'--checkpoint', type=str, default='./data/esper_demo/ckpt', help='checkpoint file path')
|
24 |
|
25 |
parser.add_argument(
|
26 |
'--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
|
run.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
import os
|
2 |
import math
|
|
|
3 |
import platform
|
4 |
import logging
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
import torch
|
8 |
from transformers import AutoModelForCausalLM
|
@@ -122,13 +124,14 @@ def prepare(args):
|
|
122 |
|
123 |
|
124 |
class Runner:
|
125 |
-
def __init__(self,
|
126 |
-
self.
|
127 |
|
128 |
-
def __call__(self, inp, prompt, length, sample):
|
|
|
129 |
# inp = inp.reshape((224, 224, 3))
|
130 |
img = Image.fromarray(np.uint8(inp))
|
131 |
-
text =
|
132 |
return prompt, text
|
133 |
# return inp, prompt, text
|
134 |
|
@@ -140,17 +143,29 @@ img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
|
|
140 |
print('test_run:', text)
|
141 |
'''
|
142 |
|
143 |
-
def launch(examples=None, title='Demo for ESPER', description=None,
|
144 |
args = get_args()
|
145 |
-
inferer = prepare(args)
|
146 |
-
runner = Runner(inferer)
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
iface = gr.Interface(
|
149 |
title=title,
|
150 |
description=description,
|
151 |
fn=runner.__call__,
|
152 |
-
inputs=[gr.components.
|
153 |
-
gr.components.
|
|
|
154 |
gr.components.Slider(20, 40, step=1, label='Length'),
|
155 |
# gr.components.Slider(10, 100, step=1, label='window_size'),
|
156 |
gr.components.Checkbox(label='do sample')],
|
|
|
1 |
import os
|
2 |
import math
|
3 |
+
import copy
|
4 |
import platform
|
5 |
import logging
|
6 |
from pathlib import Path
|
7 |
+
from itertools import chain
|
8 |
|
9 |
import torch
|
10 |
from transformers import AutoModelForCausalLM
|
|
|
124 |
|
125 |
|
126 |
class Runner:
|
127 |
+
def __init__(self, inferers):
|
128 |
+
self.inferers = inferers
|
129 |
|
130 |
+
def __call__(self, model_name, inp, prompt, length, sample):
|
131 |
+
inferer = self.inferers[model_name]
|
132 |
# inp = inp.reshape((224, 224, 3))
|
133 |
img = Image.fromarray(np.uint8(inp))
|
134 |
+
text = inferer(img, prompt, length, window_size=10, sample=sample)
|
135 |
return prompt, text
|
136 |
# return inp, prompt, text
|
137 |
|
|
|
143 |
print('test_run:', text)
|
144 |
'''
|
145 |
|
146 |
+
def launch(examples=None, title='Demo for ESPER', description=None, prompt_eg=None):
|
147 |
args = get_args()
|
|
|
|
|
148 |
|
149 |
+
ckpts = [p.parent / p.stem for p in Path(args.checkpoint).glob('*.ckpt')]
|
150 |
+
ckpts = {p.stem: p for p in ckpts}
|
151 |
+
|
152 |
+
inferers = {}
|
153 |
+
for model_name, ckpt in ckpts.items():
|
154 |
+
ckpt_args = copy.deepcopy(args)
|
155 |
+
ckpt_args.checkpoint = str(ckpt)
|
156 |
+
inferer = prepare(ckpt_args)
|
157 |
+
inferers[model_name] = inferer
|
158 |
+
runner = Runner(inferers)
|
159 |
+
model_names = sorted(list(ckpts.keys()))
|
160 |
+
|
161 |
+
examples = list(chain(*[[[n, *ex] for n in model_names] for ex in examples]))
|
162 |
iface = gr.Interface(
|
163 |
title=title,
|
164 |
description=description,
|
165 |
fn=runner.__call__,
|
166 |
+
inputs=[gr.components.Dropdown(choices=model_names, value=model_names[0], label='Backbone'),
|
167 |
+
gr.components.Image(shape=(224, 224), label='Image'),
|
168 |
+
gr.components.Textbox(label='Prompt', placeholder=prompt_eg),
|
169 |
gr.components.Slider(20, 40, step=1, label='Length'),
|
170 |
# gr.components.Slider(10, 100, step=1, label='window_size'),
|
171 |
gr.components.Checkbox(label='do sample')],
|