jiwan-chung commited on
Commit
99aac23
1 Parent(s): 901ff30

backbone choice dropdown

Browse files
Files changed (3) hide show
  1. app.py +4 -4
  2. arguments.py +1 -1
  3. run.py +24 -9
app.py CHANGED
@@ -39,11 +39,11 @@ for k, v in images.items():
39
  prompts = ['blog:', 'dialogue:', 'This is my favorite poem:']
40
 
41
  title = 'Demo for ESPER'
42
- description = 'backbone: style-finetuned GPT-2-base'
43
- prompt_label = f'Prompt (try pretrained styles such as "blog:" or "dialogue:" or unseen prompts such as "{prompts[-1]}")'
44
 
45
- examples = [[[v, prompt, 20, False ] for prompt in prompts]
46
  for v in images.values()]
47
  examples = list(chain(*examples))
48
 
49
- launch(examples, title=title, description=description, prompt_label=prompt_label)
 
39
  prompts = ['blog:', 'dialogue:', 'This is my favorite poem:']
40
 
41
  title = 'Demo for ESPER'
42
+ description = None
43
+ prompt_eg = f'try pretrained styles such as "blog:" or "dialogue:"\n or unseen prompts such as "{prompts[-1]}"'
44
 
45
+ examples = [[[v, prompt, 20, False] for prompt in prompts]
46
  for v in images.values()]
47
  examples = list(chain(*examples))
48
 
49
+ launch(examples, title=title, description=description, prompt_eg=prompt_eg)
arguments.py CHANGED
@@ -20,7 +20,7 @@ def get_args():
20
  parser.add_argument(
21
  '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
22
  parser.add_argument(
23
- '--checkpoint', type=str, default='./data/esper_demo/ckpt/gpt2_style', help='checkpoint file path')
24
 
25
  parser.add_argument(
26
  '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
 
20
  parser.add_argument(
21
  '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
22
  parser.add_argument(
23
+ '--checkpoint', type=str, default='./data/esper_demo/ckpt', help='checkpoint file path')
24
 
25
  parser.add_argument(
26
  '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
run.py CHANGED
@@ -1,8 +1,10 @@
1
  import os
2
  import math
 
3
  import platform
4
  import logging
5
  from pathlib import Path
 
6
 
7
  import torch
8
  from transformers import AutoModelForCausalLM
@@ -122,13 +124,14 @@ def prepare(args):
122
 
123
 
124
  class Runner:
125
- def __init__(self, inferer):
126
- self.inferer = inferer
127
 
128
- def __call__(self, inp, prompt, length, sample):
 
129
  # inp = inp.reshape((224, 224, 3))
130
  img = Image.fromarray(np.uint8(inp))
131
- text = self.inferer(img, prompt, length, window_size=10, sample=sample)
132
  return prompt, text
133
  # return inp, prompt, text
134
 
@@ -140,17 +143,29 @@ img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
140
  print('test_run:', text)
141
  '''
142
 
143
- def launch(examples=None, title='Demo for ESPER', description=None, prompt_label='Prompt'):
144
  args = get_args()
145
- inferer = prepare(args)
146
- runner = Runner(inferer)
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  iface = gr.Interface(
149
  title=title,
150
  description=description,
151
  fn=runner.__call__,
152
- inputs=[gr.components.Image(shape=(224, 224), label='Image'),
153
- gr.components.Textbox(label=prompt_label),
 
154
  gr.components.Slider(20, 40, step=1, label='Length'),
155
  # gr.components.Slider(10, 100, step=1, label='window_size'),
156
  gr.components.Checkbox(label='do sample')],
 
1
  import os
2
  import math
3
+ import copy
4
  import platform
5
  import logging
6
  from pathlib import Path
7
+ from itertools import chain
8
 
9
  import torch
10
  from transformers import AutoModelForCausalLM
 
124
 
125
 
126
  class Runner:
127
+ def __init__(self, inferers):
128
+ self.inferers = inferers
129
 
130
+ def __call__(self, model_name, inp, prompt, length, sample):
131
+ inferer = self.inferers[model_name]
132
  # inp = inp.reshape((224, 224, 3))
133
  img = Image.fromarray(np.uint8(inp))
134
+ text = inferer(img, prompt, length, window_size=10, sample=sample)
135
  return prompt, text
136
  # return inp, prompt, text
137
 
 
143
  print('test_run:', text)
144
  '''
145
 
146
+ def launch(examples=None, title='Demo for ESPER', description=None, prompt_eg=None):
147
  args = get_args()
 
 
148
 
149
+ ckpts = [p.parent / p.stem for p in Path(args.checkpoint).glob('*.ckpt')]
150
+ ckpts = {p.stem: p for p in ckpts}
151
+
152
+ inferers = {}
153
+ for model_name, ckpt in ckpts.items():
154
+ ckpt_args = copy.deepcopy(args)
155
+ ckpt_args.checkpoint = str(ckpt)
156
+ inferer = prepare(ckpt_args)
157
+ inferers[model_name] = inferer
158
+ runner = Runner(inferers)
159
+ model_names = sorted(list(ckpts.keys()))
160
+
161
+ examples = list(chain(*[[[n, *ex] for n in model_names] for ex in examples]))
162
  iface = gr.Interface(
163
  title=title,
164
  description=description,
165
  fn=runner.__call__,
166
+ inputs=[gr.components.Dropdown(choices=model_names, value=model_names[0], label='Backbone'),
167
+ gr.components.Image(shape=(224, 224), label='Image'),
168
+ gr.components.Textbox(label='Prompt', placeholder=prompt_eg),
169
  gr.components.Slider(20, 40, step=1, label='Length'),
170
  # gr.components.Slider(10, 100, step=1, label='window_size'),
171
  gr.components.Checkbox(label='do sample')],