Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
import captioning.utils.opts as opts
|
6 |
+
import captioning.models as models
|
7 |
+
import captioning.utils.misc as utils
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
|
12 |
+
# Checkpoint class
|
13 |
+
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
|
14 |
+
def on_keyboard_interrupt(self, trainer, pl_module):
|
15 |
+
# Save model when keyboard interrupt
|
16 |
+
filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
|
17 |
+
self._save_model(filepath)
|
18 |
+
|
19 |
+
device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
|
20 |
+
reward = 'clips_grammar'
|
21 |
+
|
22 |
+
cfg = f'./configs/phase2/clipRN50_{reward}.yml'
|
23 |
+
|
24 |
+
print("Loading cfg from", cfg)
|
25 |
+
|
26 |
+
opt = opts.parse_opt(parse=False, cfg=cfg)
|
27 |
+
|
28 |
+
import gdown
|
29 |
+
|
30 |
+
url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
|
31 |
+
gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
|
32 |
+
|
33 |
+
url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
|
34 |
+
gdown.download(url, quiet=True, use_cookies=False, output="data/")
|
35 |
+
|
36 |
+
dict_json = json.load(open('./data/cocotalk.json'))
|
37 |
+
print(dict_json.keys())
|
38 |
+
|
39 |
+
ix_to_word = dict_json['ix_to_word']
|
40 |
+
vocab_size = len(ix_to_word)
|
41 |
+
print('vocab size:', vocab_size)
|
42 |
+
|
43 |
+
seq_length = 1
|
44 |
+
|
45 |
+
opt.vocab_size = vocab_size
|
46 |
+
opt.seq_length = seq_length
|
47 |
+
|
48 |
+
opt.batch_size = 1
|
49 |
+
opt.vocab = ix_to_word
|
50 |
+
# opt.use_grammar = False
|
51 |
+
|
52 |
+
model = models.setup(opt)
|
53 |
+
del opt.vocab
|
54 |
+
|
55 |
+
ckpt_path = opt.checkpoint_path + '-last.ckpt'
|
56 |
+
|
57 |
+
print("Loading checkpoint from", ckpt_path)
|
58 |
+
raw_state_dict = torch.load(
|
59 |
+
ckpt_path,
|
60 |
+
map_location=device)
|
61 |
+
|
62 |
+
strict = True
|
63 |
+
|
64 |
+
state_dict = raw_state_dict['state_dict']
|
65 |
+
|
66 |
+
if '_vocab' in state_dict:
|
67 |
+
model.vocab = utils.deserialize(state_dict['_vocab'])
|
68 |
+
del state_dict['_vocab']
|
69 |
+
elif strict:
|
70 |
+
raise KeyError
|
71 |
+
if '_opt' in state_dict:
|
72 |
+
saved_model_opt = utils.deserialize(state_dict['_opt'])
|
73 |
+
del state_dict['_opt']
|
74 |
+
# Make sure the saved opt is compatible with the curren topt
|
75 |
+
need_be_same = ["caption_model",
|
76 |
+
"rnn_type", "rnn_size", "num_layers"]
|
77 |
+
for checkme in need_be_same:
|
78 |
+
if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
|
79 |
+
getattr(opt, checkme) in ['updown', 'topdown']:
|
80 |
+
continue
|
81 |
+
assert getattr(saved_model_opt, checkme) == getattr(
|
82 |
+
opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
|
83 |
+
elif strict:
|
84 |
+
raise KeyError
|
85 |
+
res = model.load_state_dict(state_dict, strict)
|
86 |
+
print(res)
|
87 |
+
|
88 |
+
model = model.to(device)
|
89 |
+
model.eval();
|
90 |
+
|
91 |
+
import clip
|
92 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
93 |
+
from PIL import Image
|
94 |
+
from timm.models.vision_transformer import resize_pos_embed
|
95 |
+
|
96 |
+
clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
|
97 |
+
|
98 |
+
preprocess = Compose([
|
99 |
+
Resize((448, 448), interpolation=Image.BICUBIC),
|
100 |
+
CenterCrop((448, 448)),
|
101 |
+
ToTensor()
|
102 |
+
])
|
103 |
+
|
104 |
+
image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
|
105 |
+
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
|
106 |
+
|
107 |
+
num_patches = 196 #600 * 1000 // 32 // 32
|
108 |
+
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
|
109 |
+
pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
|
110 |
+
clip_model.visual.attnpool.positional_embedding = pos_embed
|
111 |
+
|
112 |
+
|
113 |
+
# End below
|
114 |
+
|
115 |
+
def generate_image(img, steps=100, seed=42, guidance_scale=6.0):
|
116 |
+
|
117 |
+
with torch.no_grad():
|
118 |
+
image = preprocess(img)
|
119 |
+
image = torch.tensor(np.stack([image])).to(device)
|
120 |
+
image -= image_mean
|
121 |
+
image /= image_std
|
122 |
+
|
123 |
+
tmp_att, tmp_fc = clip_model.encode_image(image)
|
124 |
+
tmp_att = tmp_att[0].permute(1, 2, 0)
|
125 |
+
tmp_fc = tmp_fc[0]
|
126 |
+
|
127 |
+
att_feat = tmp_att
|
128 |
+
fc_feat = tmp_fc
|
129 |
+
|
130 |
+
|
131 |
+
# Inference configurations
|
132 |
+
eval_kwargs = {}
|
133 |
+
eval_kwargs.update(vars(opt))
|
134 |
+
|
135 |
+
verbose = eval_kwargs.get('verbose', True)
|
136 |
+
verbose_beam = eval_kwargs.get('verbose_beam', 0)
|
137 |
+
verbose_loss = eval_kwargs.get('verbose_loss', 1)
|
138 |
+
|
139 |
+
# dataset = eval_kwargs.get('dataset', 'coco')
|
140 |
+
beam_size = eval_kwargs.get('beam_size', 1)
|
141 |
+
sample_n = eval_kwargs.get('sample_n', 1)
|
142 |
+
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
|
143 |
+
|
144 |
+
with torch.no_grad():
|
145 |
+
fc_feats = torch.zeros((1,0)).to(device)
|
146 |
+
att_feats = att_feat.view(1, 196, 2048).float().to(device)
|
147 |
+
att_masks = None
|
148 |
+
|
149 |
+
# forward the model to also get generated samples for each image
|
150 |
+
# Only leave one feature for each image, in case duplicate sample
|
151 |
+
tmp_eval_kwargs = eval_kwargs.copy()
|
152 |
+
tmp_eval_kwargs.update({'sample_n': 1})
|
153 |
+
seq, seq_logprobs = model(
|
154 |
+
fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
155 |
+
seq = seq.data
|
156 |
+
|
157 |
+
sents = utils.decode_sequence(model.vocab, seq)
|
158 |
+
|
159 |
+
return sents[0]
|
160 |
+
|
161 |
+
gr.Interface(
|
162 |
+
generate_image,
|
163 |
+
inputs=[
|
164 |
+
gr.Image(type="pil"),
|
165 |
+
gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
|
166 |
+
gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1),
|
167 |
+
gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=6.0, step=0.1),
|
168 |
+
],
|
169 |
+
outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
|
170 |
+
css="#output_image{width: 256px}",
|
171 |
+
).launch()
|