Spaces:
Runtime error
Runtime error
Alberto Carmona
commited on
Commit
·
2773b59
1
Parent(s):
4973507
Enable all the code
Browse files
app.py
CHANGED
@@ -1,181 +1,181 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
import gradio as gr
|
10 |
|
11 |
-
|
12 |
-
#
|
13 |
import random
|
14 |
-
|
15 |
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
20 |
-
#
|
21 |
-
|
22 |
-
|
23 |
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
28 |
|
29 |
-
|
30 |
|
31 |
-
|
32 |
|
33 |
-
|
34 |
|
35 |
-
|
36 |
-
|
37 |
|
38 |
-
|
39 |
-
|
40 |
|
41 |
-
|
42 |
-
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
49 |
|
50 |
-
|
51 |
-
|
52 |
|
53 |
-
|
54 |
-
|
55 |
|
56 |
-
|
57 |
-
|
58 |
|
59 |
-
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
|
66 |
-
|
67 |
|
68 |
-
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
#
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
|
92 |
-
|
93 |
-
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
|
100 |
-
|
101 |
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
|
108 |
-
|
109 |
-
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
|
116 |
|
117 |
-
#
|
118 |
-
|
119 |
-
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
|
140 |
-
|
141 |
-
|
142 |
|
143 |
-
#
|
144 |
-
|
145 |
-
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
151 |
-
#
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
|
161 |
-
#
|
162 |
-
#
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
|
169 |
-
|
170 |
|
171 |
-
|
172 |
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
|
180 |
|
181 |
random_seed = random.randint(0, 2147483647)
|
@@ -183,10 +183,9 @@ random_seed = random.randint(0, 2147483647)
|
|
183 |
def test_fn(**kwargs):
|
184 |
return None
|
185 |
|
186 |
-
|
187 |
gr.Interface(
|
188 |
-
|
189 |
-
test_fn,
|
190 |
inputs=[
|
191 |
gr.Image(type="pil"),
|
192 |
gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
import captioning.utils.opts as opts
|
6 |
+
import captioning.models as models
|
7 |
+
import captioning.utils.misc as utils
|
8 |
+
import pytorch_lightning as pl
|
9 |
import gradio as gr
|
10 |
|
11 |
+
from diffusers import LDMTextToImagePipeline
|
12 |
+
# import PIL.Image
|
13 |
import random
|
14 |
+
import os
|
15 |
|
16 |
|
17 |
+
# Checkpoint class
|
18 |
+
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
|
19 |
+
def on_keyboard_interrupt(self, trainer, pl_module):
|
20 |
+
# Save model when keyboard interrupt
|
21 |
+
filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
|
22 |
+
self._save_model(filepath)
|
23 |
|
24 |
+
device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
|
25 |
+
reward = 'clips_grammar'
|
26 |
|
27 |
+
cfg = f'./configs/phase2/clipRN50_{reward}.yml'
|
28 |
|
29 |
+
print("Loading cfg from", cfg)
|
30 |
|
31 |
+
opt = opts.parse_opt(parse=False, cfg=cfg)
|
32 |
|
33 |
+
import gdown
|
34 |
|
35 |
+
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
|
36 |
+
gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
|
37 |
|
38 |
+
url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
|
39 |
+
gdown.download(url, quiet=True, use_cookies=False, output="data/")
|
40 |
|
41 |
+
dict_json = json.load(open('./data/cocotalk.json'))
|
42 |
+
print(dict_json.keys())
|
43 |
|
44 |
+
ix_to_word = dict_json['ix_to_word']
|
45 |
+
vocab_size = len(ix_to_word)
|
46 |
+
print('vocab size:', vocab_size)
|
47 |
|
48 |
+
seq_length = 1
|
49 |
|
50 |
+
opt.vocab_size = vocab_size
|
51 |
+
opt.seq_length = seq_length
|
52 |
|
53 |
+
opt.batch_size = 1
|
54 |
+
opt.vocab = ix_to_word
|
55 |
|
56 |
+
model = models.setup(opt)
|
57 |
+
del opt.vocab
|
58 |
|
59 |
+
ckpt_path = opt.checkpoint_path + '-last.ckpt'
|
60 |
|
61 |
+
print("Loading checkpoint from", ckpt_path)
|
62 |
+
raw_state_dict = torch.load(
|
63 |
+
ckpt_path,
|
64 |
+
map_location=device)
|
65 |
|
66 |
+
strict = True
|
67 |
|
68 |
+
state_dict = raw_state_dict['state_dict']
|
69 |
|
70 |
+
if '_vocab' in state_dict:
|
71 |
+
model.vocab = utils.deserialize(state_dict['_vocab'])
|
72 |
+
del state_dict['_vocab']
|
73 |
+
elif strict:
|
74 |
+
raise KeyError
|
75 |
+
if '_opt' in state_dict:
|
76 |
+
saved_model_opt = utils.deserialize(state_dict['_opt'])
|
77 |
+
del state_dict['_opt']
|
78 |
+
# Make sure the saved opt is compatible with the curren topt
|
79 |
+
need_be_same = ["caption_model",
|
80 |
+
"rnn_type", "rnn_size", "num_layers"]
|
81 |
+
for checkme in need_be_same:
|
82 |
+
if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
|
83 |
+
getattr(opt, checkme) in ['updown', 'topdown']:
|
84 |
+
continue
|
85 |
+
assert getattr(saved_model_opt, checkme) == getattr(
|
86 |
+
opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
|
87 |
+
elif strict:
|
88 |
+
raise KeyError
|
89 |
+
res = model.load_state_dict(state_dict, strict)
|
90 |
+
print(res)
|
91 |
|
92 |
+
model = model.to(device)
|
93 |
+
model.eval();
|
94 |
|
95 |
+
import clip
|
96 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
97 |
+
from PIL import Image
|
98 |
+
from timm.models.vision_transformer import resize_pos_embed
|
99 |
|
100 |
+
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
|
101 |
|
102 |
+
preprocess = Compose([
|
103 |
+
Resize((448, 448), interpolation=Image.BICUBIC),
|
104 |
+
CenterCrop((448, 448)),
|
105 |
+
ToTensor()
|
106 |
+
])
|
107 |
|
108 |
+
image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
|
109 |
+
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
|
110 |
|
111 |
+
num_patches = 196 #600 * 1000 // 32 // 32
|
112 |
+
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
|
113 |
+
pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
|
114 |
+
clip_model.visual.attnpool.positional_embedding = pos_embed
|
115 |
|
116 |
|
117 |
+
# End below
|
118 |
+
print('Loading the model: CompVis/ldm-text2im-large-256')
|
119 |
+
ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
120 |
|
121 |
+
def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
|
122 |
+
print('RUN: generate_image_from_text')
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
generator = torch.manual_seed(seed)
|
125 |
+
images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
|
126 |
+
return images[0]
|
127 |
+
|
128 |
+
def generate_text_from_image(img):
|
129 |
+
print('RUN: generate_text_from_image')
|
130 |
+
with torch.no_grad():
|
131 |
+
image = preprocess(img)
|
132 |
+
image = torch.tensor(np.stack([image])).to(device)
|
133 |
+
image -= image_mean
|
134 |
+
image /= image_std
|
135 |
|
136 |
+
tmp_att, tmp_fc = clip_model.encode_image(image)
|
137 |
+
tmp_att = tmp_att[0].permute(1, 2, 0)
|
138 |
+
tmp_fc = tmp_fc[0]
|
139 |
|
140 |
+
att_feat = tmp_att
|
141 |
+
fc_feat = tmp_fc
|
142 |
|
143 |
+
# Inference configurations
|
144 |
+
eval_kwargs = {}
|
145 |
+
eval_kwargs.update(vars(opt))
|
146 |
|
147 |
+
verbose = eval_kwargs.get('verbose', True)
|
148 |
+
verbose_beam = eval_kwargs.get('verbose_beam', 0)
|
149 |
+
verbose_loss = eval_kwargs.get('verbose_loss', 1)
|
150 |
|
151 |
+
# dataset = eval_kwargs.get('dataset', 'coco')
|
152 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
153 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
154 |
+
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
|
155 |
|
156 |
+
with torch.no_grad():
|
157 |
+
fc_feats = torch.zeros((1,0)).to(device)
|
158 |
+
att_feats = att_feat.view(1, 196, 2048).float().to(device)
|
159 |
+
att_masks = None
|
160 |
|
161 |
+
# forward the model to also get generated samples for each image
|
162 |
+
# Only leave one feature for each image, in case duplicate sample
|
163 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
164 |
+
tmp_eval_kwargs.update({'sample_n': 1})
|
165 |
+
seq, seq_logprobs = model(
|
166 |
+
fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
167 |
+
seq = seq.data
|
168 |
|
169 |
+
sents = utils.decode_sequence(model.vocab, seq)
|
170 |
|
171 |
+
return sents[0]
|
172 |
|
173 |
|
174 |
+
def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
|
175 |
+
print('RUN: generate_drawing_from_image')
|
176 |
+
caption = generate_text_from_image(img)
|
177 |
+
gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
|
178 |
+
return gen_image
|
179 |
|
180 |
|
181 |
random_seed = random.randint(0, 2147483647)
|
|
|
183 |
def test_fn(**kwargs):
|
184 |
return None
|
185 |
|
|
|
186 |
gr.Interface(
|
187 |
+
generate_drawing_from_image,
|
188 |
+
# test_fn,
|
189 |
inputs=[
|
190 |
gr.Image(type="pil"),
|
191 |
gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
|