Alberto Carmona commited on
Commit
2773b59
·
1 Parent(s): 4973507

Enable all the code

Browse files
Files changed (1) hide show
  1. app.py +137 -138
app.py CHANGED
@@ -1,181 +1,181 @@
1
- # import torch
2
- # import torch.nn as nn
3
- # import numpy as np
4
- # import json
5
- # import captioning.utils.opts as opts
6
- # import captioning.models as models
7
- # import captioning.utils.misc as utils
8
- # import pytorch_lightning as pl
9
  import gradio as gr
10
 
11
- # from diffusers import LDMTextToImagePipeline
12
- # # import PIL.Image
13
  import random
14
- # import os
15
 
16
 
17
- # # Checkpoint class
18
- # class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
19
- # def on_keyboard_interrupt(self, trainer, pl_module):
20
- # # Save model when keyboard interrupt
21
- # filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
22
- # self._save_model(filepath)
23
 
24
- # device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
25
- # reward = 'clips_grammar'
26
 
27
- # cfg = f'./configs/phase2/clipRN50_{reward}.yml'
28
 
29
- # print("Loading cfg from", cfg)
30
 
31
- # opt = opts.parse_opt(parse=False, cfg=cfg)
32
 
33
- # import gdown
34
 
35
- # url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
36
- # gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
37
 
38
- # url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
39
- # gdown.download(url, quiet=True, use_cookies=False, output="data/")
40
 
41
- # dict_json = json.load(open('./data/cocotalk.json'))
42
- # print(dict_json.keys())
43
 
44
- # ix_to_word = dict_json['ix_to_word']
45
- # vocab_size = len(ix_to_word)
46
- # print('vocab size:', vocab_size)
47
 
48
- # seq_length = 1
49
 
50
- # opt.vocab_size = vocab_size
51
- # opt.seq_length = seq_length
52
 
53
- # opt.batch_size = 1
54
- # opt.vocab = ix_to_word
55
 
56
- # model = models.setup(opt)
57
- # del opt.vocab
58
 
59
- # ckpt_path = opt.checkpoint_path + '-last.ckpt'
60
 
61
- # print("Loading checkpoint from", ckpt_path)
62
- # raw_state_dict = torch.load(
63
- # ckpt_path,
64
- # map_location=device)
65
 
66
- # strict = True
67
 
68
- # state_dict = raw_state_dict['state_dict']
69
 
70
- # if '_vocab' in state_dict:
71
- # model.vocab = utils.deserialize(state_dict['_vocab'])
72
- # del state_dict['_vocab']
73
- # elif strict:
74
- # raise KeyError
75
- # if '_opt' in state_dict:
76
- # saved_model_opt = utils.deserialize(state_dict['_opt'])
77
- # del state_dict['_opt']
78
- # # Make sure the saved opt is compatible with the curren topt
79
- # need_be_same = ["caption_model",
80
- # "rnn_type", "rnn_size", "num_layers"]
81
- # for checkme in need_be_same:
82
- # if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
83
- # getattr(opt, checkme) in ['updown', 'topdown']:
84
- # continue
85
- # assert getattr(saved_model_opt, checkme) == getattr(
86
- # opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
87
- # elif strict:
88
- # raise KeyError
89
- # res = model.load_state_dict(state_dict, strict)
90
- # print(res)
91
 
92
- # model = model.to(device)
93
- # model.eval();
94
 
95
- # import clip
96
- # from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
97
- # from PIL import Image
98
- # from timm.models.vision_transformer import resize_pos_embed
99
 
100
- # clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
101
 
102
- # preprocess = Compose([
103
- # Resize((448, 448), interpolation=Image.BICUBIC),
104
- # CenterCrop((448, 448)),
105
- # ToTensor()
106
- # ])
107
 
108
- # image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
109
- # image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
110
 
111
- # num_patches = 196 #600 * 1000 // 32 // 32
112
- # pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
113
- # pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
114
- # clip_model.visual.attnpool.positional_embedding = pos_embed
115
 
116
 
117
- # # End below
118
- # print('Loading the model: CompVis/ldm-text2im-large-256')
119
- # ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
120
 
121
- # def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
122
- # print('RUN: generate_image_from_text')
123
- # torch.cuda.empty_cache()
124
- # generator = torch.manual_seed(seed)
125
- # images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
126
- # return images[0]
127
-
128
- # def generate_text_from_image(img):
129
- # print('RUN: generate_text_from_image')
130
- # with torch.no_grad():
131
- # image = preprocess(img)
132
- # image = torch.tensor(np.stack([image])).to(device)
133
- # image -= image_mean
134
- # image /= image_std
135
 
136
- # tmp_att, tmp_fc = clip_model.encode_image(image)
137
- # tmp_att = tmp_att[0].permute(1, 2, 0)
138
- # tmp_fc = tmp_fc[0]
139
 
140
- # att_feat = tmp_att
141
- # fc_feat = tmp_fc
142
 
143
- # # Inference configurations
144
- # eval_kwargs = {}
145
- # eval_kwargs.update(vars(opt))
146
 
147
- # verbose = eval_kwargs.get('verbose', True)
148
- # verbose_beam = eval_kwargs.get('verbose_beam', 0)
149
- # verbose_loss = eval_kwargs.get('verbose_loss', 1)
150
 
151
- # # dataset = eval_kwargs.get('dataset', 'coco')
152
- # beam_size = eval_kwargs.get('beam_size', 1)
153
- # sample_n = eval_kwargs.get('sample_n', 1)
154
- # remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
155
 
156
- # with torch.no_grad():
157
- # fc_feats = torch.zeros((1,0)).to(device)
158
- # att_feats = att_feat.view(1, 196, 2048).float().to(device)
159
- # att_masks = None
160
 
161
- # # forward the model to also get generated samples for each image
162
- # # Only leave one feature for each image, in case duplicate sample
163
- # tmp_eval_kwargs = eval_kwargs.copy()
164
- # tmp_eval_kwargs.update({'sample_n': 1})
165
- # seq, seq_logprobs = model(
166
- # fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
167
- # seq = seq.data
168
 
169
- # sents = utils.decode_sequence(model.vocab, seq)
170
 
171
- # return sents[0]
172
 
173
 
174
- # def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
175
- # print('RUN: generate_drawing_from_image')
176
- # caption = generate_text_from_image(img)
177
- # gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
178
- # return gen_image
179
 
180
 
181
  random_seed = random.randint(0, 2147483647)
@@ -183,10 +183,9 @@ random_seed = random.randint(0, 2147483647)
183
  def test_fn(**kwargs):
184
  return None
185
 
186
-
187
  gr.Interface(
188
- # generate_drawing_from_image,
189
- test_fn,
190
  inputs=[
191
  gr.Image(type="pil"),
192
  gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import json
5
+ import captioning.utils.opts as opts
6
+ import captioning.models as models
7
+ import captioning.utils.misc as utils
8
+ import pytorch_lightning as pl
9
  import gradio as gr
10
 
11
+ from diffusers import LDMTextToImagePipeline
12
+ # import PIL.Image
13
  import random
14
+ import os
15
 
16
 
17
+ # Checkpoint class
18
+ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
19
+ def on_keyboard_interrupt(self, trainer, pl_module):
20
+ # Save model when keyboard interrupt
21
+ filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
22
+ self._save_model(filepath)
23
 
24
+ device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
25
+ reward = 'clips_grammar'
26
 
27
+ cfg = f'./configs/phase2/clipRN50_{reward}.yml'
28
 
29
+ print("Loading cfg from", cfg)
30
 
31
+ opt = opts.parse_opt(parse=False, cfg=cfg)
32
 
33
+ import gdown
34
 
35
+ url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
36
+ gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
37
 
38
+ url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
39
+ gdown.download(url, quiet=True, use_cookies=False, output="data/")
40
 
41
+ dict_json = json.load(open('./data/cocotalk.json'))
42
+ print(dict_json.keys())
43
 
44
+ ix_to_word = dict_json['ix_to_word']
45
+ vocab_size = len(ix_to_word)
46
+ print('vocab size:', vocab_size)
47
 
48
+ seq_length = 1
49
 
50
+ opt.vocab_size = vocab_size
51
+ opt.seq_length = seq_length
52
 
53
+ opt.batch_size = 1
54
+ opt.vocab = ix_to_word
55
 
56
+ model = models.setup(opt)
57
+ del opt.vocab
58
 
59
+ ckpt_path = opt.checkpoint_path + '-last.ckpt'
60
 
61
+ print("Loading checkpoint from", ckpt_path)
62
+ raw_state_dict = torch.load(
63
+ ckpt_path,
64
+ map_location=device)
65
 
66
+ strict = True
67
 
68
+ state_dict = raw_state_dict['state_dict']
69
 
70
+ if '_vocab' in state_dict:
71
+ model.vocab = utils.deserialize(state_dict['_vocab'])
72
+ del state_dict['_vocab']
73
+ elif strict:
74
+ raise KeyError
75
+ if '_opt' in state_dict:
76
+ saved_model_opt = utils.deserialize(state_dict['_opt'])
77
+ del state_dict['_opt']
78
+ # Make sure the saved opt is compatible with the curren topt
79
+ need_be_same = ["caption_model",
80
+ "rnn_type", "rnn_size", "num_layers"]
81
+ for checkme in need_be_same:
82
+ if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
83
+ getattr(opt, checkme) in ['updown', 'topdown']:
84
+ continue
85
+ assert getattr(saved_model_opt, checkme) == getattr(
86
+ opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
87
+ elif strict:
88
+ raise KeyError
89
+ res = model.load_state_dict(state_dict, strict)
90
+ print(res)
91
 
92
+ model = model.to(device)
93
+ model.eval();
94
 
95
+ import clip
96
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
97
+ from PIL import Image
98
+ from timm.models.vision_transformer import resize_pos_embed
99
 
100
+ clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
101
 
102
+ preprocess = Compose([
103
+ Resize((448, 448), interpolation=Image.BICUBIC),
104
+ CenterCrop((448, 448)),
105
+ ToTensor()
106
+ ])
107
 
108
+ image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
109
+ image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
110
 
111
+ num_patches = 196 #600 * 1000 // 32 // 32
112
+ pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
113
+ pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
114
+ clip_model.visual.attnpool.positional_embedding = pos_embed
115
 
116
 
117
+ # End below
118
+ print('Loading the model: CompVis/ldm-text2im-large-256')
119
+ ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
120
 
121
+ def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
122
+ print('RUN: generate_image_from_text')
123
+ torch.cuda.empty_cache()
124
+ generator = torch.manual_seed(seed)
125
+ images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
126
+ return images[0]
127
+
128
+ def generate_text_from_image(img):
129
+ print('RUN: generate_text_from_image')
130
+ with torch.no_grad():
131
+ image = preprocess(img)
132
+ image = torch.tensor(np.stack([image])).to(device)
133
+ image -= image_mean
134
+ image /= image_std
135
 
136
+ tmp_att, tmp_fc = clip_model.encode_image(image)
137
+ tmp_att = tmp_att[0].permute(1, 2, 0)
138
+ tmp_fc = tmp_fc[0]
139
 
140
+ att_feat = tmp_att
141
+ fc_feat = tmp_fc
142
 
143
+ # Inference configurations
144
+ eval_kwargs = {}
145
+ eval_kwargs.update(vars(opt))
146
 
147
+ verbose = eval_kwargs.get('verbose', True)
148
+ verbose_beam = eval_kwargs.get('verbose_beam', 0)
149
+ verbose_loss = eval_kwargs.get('verbose_loss', 1)
150
 
151
+ # dataset = eval_kwargs.get('dataset', 'coco')
152
+ beam_size = eval_kwargs.get('beam_size', 1)
153
+ sample_n = eval_kwargs.get('sample_n', 1)
154
+ remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
155
 
156
+ with torch.no_grad():
157
+ fc_feats = torch.zeros((1,0)).to(device)
158
+ att_feats = att_feat.view(1, 196, 2048).float().to(device)
159
+ att_masks = None
160
 
161
+ # forward the model to also get generated samples for each image
162
+ # Only leave one feature for each image, in case duplicate sample
163
+ tmp_eval_kwargs = eval_kwargs.copy()
164
+ tmp_eval_kwargs.update({'sample_n': 1})
165
+ seq, seq_logprobs = model(
166
+ fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
167
+ seq = seq.data
168
 
169
+ sents = utils.decode_sequence(model.vocab, seq)
170
 
171
+ return sents[0]
172
 
173
 
174
+ def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
175
+ print('RUN: generate_drawing_from_image')
176
+ caption = generate_text_from_image(img)
177
+ gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
178
+ return gen_image
179
 
180
 
181
  random_seed = random.randint(0, 2147483647)
 
183
  def test_fn(**kwargs):
184
  return None
185
 
 
186
  gr.Interface(
187
+ generate_drawing_from_image,
188
+ # test_fn,
189
  inputs=[
190
  gr.Image(type="pil"),
191
  gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),