JustinLin610 commited on
Commit
97fc61f
·
1 Parent(s): dd78d66

remove unnecessary files

Browse files
data/mm_data/image_gen_dataset.py DELETED
@@ -1,171 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
- import base64
11
- import random
12
-
13
- import numpy as np
14
- import torch
15
-
16
- from PIL import Image, ImageFile
17
- from itertools import chain
18
- from data.ofa_dataset import OFADataset
19
- from data import data_utils
20
-
21
- from PIL import Image
22
- from io import BytesIO
23
- import base64
24
-
25
- ImageFile.LOAD_TRUNCATED_IMAGES = True
26
- ImageFile.MAX_IMAGE_PIXELS = None
27
- Image.MAX_IMAGE_PIXELS = None
28
-
29
- logger = logging.getLogger(__name__)
30
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
31
-
32
-
33
- def collate(
34
- samples,
35
- pad_idx,
36
- eos_idx,
37
- left_pad_source=False,
38
- left_pad_target=False,
39
- ):
40
- if len(samples) == 0:
41
- return {}
42
-
43
- def merge(key, left_pad, move_eos_to_beginning=False):
44
- return data_utils.collate_tokens(
45
- [s[key] for s in samples],
46
- pad_idx,
47
- eos_idx,
48
- left_pad,
49
- move_eos_to_beginning,
50
- )
51
-
52
- id = np.array([s["id"] for s in samples])
53
- src_tokens = merge("source", left_pad=left_pad_source)
54
- # sort by descending source length
55
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
56
-
57
- code_images = np.array([s["code_image"] for s in samples])
58
- code_masks = torch.cat([sample['code_mask'] for sample in samples])
59
-
60
- prev_output_tokens = None
61
- target = None
62
- if samples[0].get("target", None) is not None:
63
- target = merge("target", left_pad=left_pad_target)
64
- tgt_lengths = torch.LongTensor(
65
- [s["target"].ne(pad_idx).long().sum() for s in samples]
66
- )
67
- ntokens = tgt_lengths.sum().item()
68
-
69
- if samples[0].get("prev_output_tokens", None) is not None:
70
- prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
71
- else:
72
- ntokens = src_lengths.sum().item()
73
-
74
- batch = {
75
- "id": id,
76
- "nsentences": len(samples),
77
- "ntokens": ntokens,
78
- "net_input": {
79
- "src_tokens": src_tokens,
80
- "src_lengths": src_lengths,
81
- "code_masks": code_masks,
82
- "prev_output_tokens": prev_output_tokens
83
- },
84
- "code_images": code_images,
85
- "target": target
86
- }
87
-
88
- return batch
89
-
90
-
91
- def preprocess_vqgan(x):
92
- x = 2. * x - 1.
93
- return x
94
-
95
-
96
- class ImageGenDataset(OFADataset):
97
- def __init__(
98
- self,
99
- split,
100
- dataset,
101
- bpe,
102
- src_dict,
103
- tgt_dict=None,
104
- max_src_length=128,
105
- code_dict_size=8192,
106
- code_image_size=256,
107
- num_bins=1000
108
- ):
109
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
110
- self.max_src_length = max_src_length
111
-
112
- self.code_dict_size = code_dict_size
113
- self.num_codes = (code_image_size // 8) ** 2
114
- self.num_bins = num_bins
115
-
116
- slice_id = self.dataset.slice_id
117
- empty_img = Image.new('RGB', (code_image_size, code_image_size))
118
- empty_img.save(f'temp_{slice_id}.png')
119
- img = Image.open(f'temp_{slice_id}.png')
120
- img_buffer = BytesIO()
121
- img.save(img_buffer, format=img.format)
122
- byte_data = img_buffer.getvalue()
123
- self.empty_image_base64 = base64.urlsafe_b64encode(byte_data)
124
-
125
- def __getitem__(self, index):
126
-
127
- data = self.dataset[index]
128
- if len(data) == 2:
129
- uniq_id, text = data
130
- image_code = [0] * 1024
131
- image = self.empty_image_base64
132
- elif len(data) == 3:
133
- uniq_id, text, image_code = data
134
- image_code = [int(num) for num in image_code.strip().split()]
135
- image = self.empty_image_base64
136
- elif len(data) == 4:
137
- uniq_id, image, text, image_code = data
138
- image_code = [int(num) for num in image_code.strip().split()]
139
- else:
140
- raise NotImplementedError
141
- code_mask = torch.tensor([True])
142
- image_code = torch.LongTensor(image_code)
143
- tgt_item = image_code + len(self.src_dict) - self.code_dict_size - self.num_bins
144
- target_item = torch.cat([tgt_item, self.eos_item])
145
- prev_output_item = torch.cat([self.bos_item, tgt_item])
146
-
147
- caption_token_list = text.strip().split()
148
- caption = ' '.join(caption_token_list[:self.max_src_length])
149
- src_item = self.encode_text(
150
- " what is the complete image? caption: {}".format(caption),
151
- append_bos=True,
152
- append_eos=True
153
- )
154
- example = {
155
- "id": uniq_id,
156
- "source": src_item,
157
- "code_mask": code_mask,
158
- "code_image": image,
159
- "target": target_item,
160
- "prev_output_tokens": prev_output_item
161
- }
162
- return example
163
-
164
- def collater(self, samples, pad_to_length=None):
165
- """Merge a list of samples to form a mini-batch.
166
- Args:
167
- samples (List[dict]): samples to collate
168
- Returns:
169
- dict: a mini-batch containing the data of the task
170
- """
171
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tasks/mm_tasks/__init__.py CHANGED
@@ -1,5 +1,4 @@
1
  from .caption import CaptionTask
2
- from .image_gen import ImageGenTask
3
  from .refcoco import RefcocoTask
4
  from .snli_ve import SnliVeTask
5
  from .vqa_gen import VqaGenTask
 
1
  from .caption import CaptionTask
 
2
  from .refcoco import RefcocoTask
3
  from .snli_ve import SnliVeTask
4
  from .vqa_gen import VqaGenTask
tasks/mm_tasks/image_gen.py DELETED
@@ -1,329 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from dataclasses import dataclass, field
7
- import json
8
- import logging
9
- import os
10
- import math
11
- import base64
12
- from typing import Optional
13
- from argparse import Namespace
14
- from omegaconf import DictConfig, OmegaConf
15
- from torchvision import transforms
16
- from PIL import Image
17
- from io import BytesIO
18
-
19
- import torch
20
- import numpy as np
21
- from fairseq import metrics
22
- from fairseq.tasks import register_task
23
- from fairseq.dataclass import ChoiceEnum
24
-
25
- from models import search, clip
26
- from models.taming.models.vqgan import GumbelVQ
27
- from data.mm_data.image_gen_dataset import ImageGenDataset
28
- from data.file_dataset import FileDataset
29
-
30
- from tasks.ofa_task import OFATask, OFAConfig
31
-
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- def custom_to_pil(x):
36
- x = x.detach().cpu()
37
- x = torch.clamp(x, -1., 1.)
38
- x = (x + 1.) / 2.
39
- x = x.permute(1, 2, 0).numpy()
40
- x = (255 * x).astype(np.uint8)
41
- x = Image.fromarray(x)
42
- if not x.mode == "RGB":
43
- x = x.convert("RGB")
44
- return x
45
-
46
-
47
- EVAL_CLIP_METHOD = ChoiceEnum(["ii_sim", "ti_sim"])
48
-
49
- @dataclass
50
- class ImageGenConfig(OFAConfig):
51
- sampling_times: int = field(
52
- default=1, metadata={"help": "sample times"}
53
- )
54
-
55
- code_image_size: int = field(
56
- default=256, metadata={"help": "code image size"}
57
- )
58
-
59
- # options for reporting CLIP score during validation
60
- eval_clip_method: EVAL_CLIP_METHOD = field(
61
- default='ti_sim',
62
- metadata={
63
- "help": "evaluation with CLIP scores. ii_sim means Similarity between generated Images and ref Images, ti_sim means Similarity between generated Images and input Text"}
64
- )
65
-
66
- eval_args: Optional[str] = field(
67
- default='{}',
68
- metadata={
69
- "help": 'generation args for clip scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
70
- },
71
- )
72
-
73
- scst: bool = field(
74
- default=False, metadata={"help": "Self-critical sequence training"}
75
- )
76
- scst_args: str = field(
77
- default='{}',
78
- metadata={
79
- "help": 'generation args for Self-critical sequence training, as JSON string'
80
- },
81
- )
82
-
83
- vqgan_model_path: Optional[str] = field(
84
- default=None,
85
- metadata={"help": "path of vqgan model"}
86
- )
87
- vqgan_config_path: Optional[str] = field(
88
- default=None,
89
- metadata={"help": "path of vqgan config"}
90
- )
91
- clip_model_path: Optional[str] = field(
92
- default=None,
93
- metadata={"help": "clip model path"}
94
- )
95
- gen_images_path: str = field(
96
- default='', metadata={"help": "where to store generated images during evalution. Don't dump images if None. "}
97
- )
98
-
99
-
100
- @register_task("image_gen", dataclass=ImageGenConfig)
101
- class ImageGenTask(OFATask):
102
- def __init__(self, cfg: ImageGenConfig, src_dict, tgt_dict):
103
- super().__init__(cfg, src_dict, tgt_dict)
104
-
105
- def load_dataset(self, split, epoch=1, combine=False, **kwargs):
106
- paths = self.cfg.data.split(',')
107
- assert len(paths) > 0
108
-
109
- if split == 'train':
110
- file_path = paths[(epoch - 1) % (len(paths) - 1)]
111
- else:
112
- file_path = paths[-1]
113
- dataset = FileDataset(file_path, self.cfg.selected_cols)
114
-
115
- self.datasets[split] = ImageGenDataset(
116
- split,
117
- dataset,
118
- self.bpe,
119
- self.src_dict,
120
- self.tgt_dict,
121
- max_src_length=self.cfg.max_src_length,
122
- code_dict_size=self.cfg.code_dict_size,
123
- code_image_size=self.cfg.code_image_size
124
- )
125
-
126
- def build_model(self, cfg):
127
- model = super().build_model(cfg)
128
-
129
- device = torch.cuda.current_device()
130
- clip_model, clip_preprocess = clip.load(self.cfg.clip_model_path, device=device)
131
- self.clip_model = clip_model
132
- self.clip_preprocess = clip_preprocess
133
- self.clip_model.to(device)
134
- self.clip_model.eval()
135
-
136
- vqgan_config = OmegaConf.load(self.cfg.vqgan_config_path)
137
- vqgan = GumbelVQ(**vqgan_config.model.params)
138
- sd = torch.load(self.cfg.vqgan_model_path, map_location="cpu")["state_dict"]
139
- missing, unexpected = vqgan.load_state_dict(sd, strict=False)
140
- for k, v in vqgan.named_parameters():
141
- v.requires_grad = False
142
- self.image_tokenizer = vqgan
143
- self.image_tokenizer.to(device)
144
- self.image_tokenizer.eval()
145
-
146
- gen_args = json.loads(self.cfg.eval_args)
147
- self.sequence_generator = self.build_generator(
148
- [model], Namespace(**gen_args)
149
- )
150
- if self.cfg.scst:
151
- scst_args = json.loads(self.cfg.scst_args)
152
- self.scst_generator = self.build_generator(
153
- [model], Namespace(**scst_args)
154
- )
155
-
156
- return model
157
-
158
- def build_generator(
159
- self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
160
- ):
161
- """
162
- Build a :class:`~fairseq.SequenceGenerator` instance for this
163
- task.
164
-
165
- Args:
166
- models (List[~fairseq.models.FairseqModel]): ensemble of models
167
- args (fairseq.dataclass.configs.GenerationConfig):
168
- configuration object (dataclass) for generation
169
- extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
170
- through to SequenceGenerator
171
- prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
172
- If provided, this function constrains the beam search to
173
- allowed tokens only at each step. The provided function
174
- should take 2 arguments: the batch ID (`batch_id: int`)
175
- and a unidimensional tensor of token ids (`inputs_ids:
176
- torch.Tensor`). It has to return a `List[int]` with the
177
- allowed tokens for the next generation step conditioned
178
- on the previously generated tokens (`inputs_ids`) and
179
- the batch ID (`batch_id`). This argument is useful for
180
- constrained generation conditioned on the prefix, as
181
- described in "Autoregressive Entity Retrieval"
182
- (https://arxiv.org/abs/2010.00904) and
183
- https://github.com/facebookresearch/GENRE.
184
- """
185
- from models.sequence_generator import SequenceGenerator
186
-
187
- # Choose search strategy. Defaults to Sampling.
188
- self.sampling_times = self.cfg.sampling_times
189
- sampling = True # we have to use sampling instead of beam search in image generation task
190
- sampling_topk = getattr(args, "sampling_topk", -1)
191
- sampling_topp = getattr(args, "sampling_topp", -1.0)
192
-
193
- assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
194
- assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
195
-
196
- search_strategy = search.Sampling(
197
- self.target_dictionary, sampling_topk, sampling_topp
198
- )
199
- extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
200
-
201
- return SequenceGenerator(
202
- models,
203
- self.target_dictionary,
204
- beam_size=getattr(args, "beam", 5),
205
- max_len_a=getattr(args, "max_len_a", 0),
206
- max_len_b=getattr(args, "max_len_b", 200),
207
- min_len=getattr(args, "min_len", 1),
208
- normalize_scores=(not getattr(args, "unnormalized", False)),
209
- len_penalty=getattr(args, "lenpen", 1),
210
- unk_penalty=getattr(args, "unkpen", 0),
211
- temperature=getattr(args, "temperature", 1.0),
212
- match_source_len=getattr(args, "match_source_len", False),
213
- no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
214
- search_strategy=search_strategy,
215
- constraint_range=self.cfg.constraint_range,
216
- gen_code=True,
217
- **extra_gen_cls_kwargs,
218
- )
219
-
220
- def compute_ref_image_similarity(self, hyps, ref, device):
221
- hyp_images = torch.stack(
222
- [self.clip_preprocess(hyp_image) for hyp_image in hyps], dim=0
223
- ).to(device)
224
-
225
- ref_images = self.clip_preprocess(ref).unsqueeze(0).to(device)
226
- with torch.no_grad():
227
- hyp_image_features = self.clip_model.encode_image(hyp_images)
228
- ref_image_features = self.clip_model.encode_image(ref_images)
229
- hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
230
- ref_image_features /= ref_image_features.norm(dim=-1, keepdim=True)
231
- similarity = hyp_image_features @ ref_image_features.T
232
- # scores.append(similarity.max().item())
233
- sorted_score, indices = torch.sort(similarity.view(-1), descending=True)
234
- return sorted_score, indices
235
-
236
- def compute_text_similarity(self, hyps, text, device):
237
- hyp_images = torch.stack(
238
- [self.clip_preprocess(hyp_image) for hyp_image in hyps], dim=0
239
- ).to(device)
240
-
241
- clip_input = clip.tokenize([text]).to(device)
242
- with torch.no_grad():
243
- hyp_image_features = self.clip_model.encode_image(hyp_images)
244
- hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
245
- text_features = self.clip_model.encode_text(clip_input)
246
- text_features /= text_features.norm(dim=-1, keepdim=True)
247
- ti_similarity = hyp_image_features @ text_features.T
248
- sorted_score, indices = torch.sort(ti_similarity.view(-1), descending=True)
249
- return sorted_score, indices
250
-
251
- def valid_step(self, sample, model, criterion):
252
- loss, sample_size, logging_output = criterion(model, sample)
253
-
254
- model.eval()
255
- device = sample['target'].device
256
-
257
- hyps, ref = self.inference_image(self.sequence_generator, sample, [model])
258
- scores = []
259
-
260
- tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
261
- caption = self.bpe.decode(self.tgt_dict.string([token for token in tokens if token >= 4]))[
262
- 38:].replace('/', '')
263
- if self.cfg.eval_clip_method == 'ii_sim':
264
- similarity_score, indices = self.compute_ref_image_similarity(hyps, ref, device)
265
- elif self.cfg.eval_clip_method == 'ti_sim':
266
- similarity_score, indices = self.compute_text_similarity(hyps, caption, device)
267
- else:
268
- raise ValueError("unsupported eval method.")
269
-
270
- scores.append(similarity_score.max().item())
271
- sorted_hyps = [hyps[indice] for indice in indices]
272
-
273
- if self.cfg.gen_images_path:
274
- caption_tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
275
- caption = self.bpe.decode(self.tgt_dict.string([token for token in caption_tokens if token >= 4]))[
276
- 38:].replace('/', '')
277
- self.dump_images(sorted_hyps, text=caption, path=os.path.join(self.cfg.gen_images_path, 'all_results'))
278
- self.dump_images(sorted_hyps, text=caption, path=os.path.join(self.cfg.gen_images_path, 'top1'), topk=1)
279
-
280
- logging_output["_score_sum"] = sum(scores)
281
- logging_output["_score_cnt"] = len(scores)
282
-
283
- return loss, sample_size, logging_output
284
-
285
- def reduce_metrics(self, logging_outputs, criterion):
286
- super().reduce_metrics(logging_outputs, criterion)
287
-
288
- def sum_logs(key):
289
- import torch
290
- result = sum(log.get(key, 0) for log in logging_outputs)
291
- if torch.is_tensor(result):
292
- result = result.cpu()
293
- return result
294
-
295
- def compute_score(meters):
296
- score = meters["_score_sum"].sum / meters["_score_cnt"].sum
297
- score = score if isinstance(score, float) else score.item()
298
- return round(score, 3)
299
-
300
- if sum_logs("_score_cnt") > 0:
301
- metrics.log_scalar("_score_sum", sum_logs("_score_sum"))
302
- metrics.log_scalar("_score_cnt", sum_logs("_score_cnt"))
303
- metrics.log_derived("score", compute_score)
304
-
305
- def inference_image(self, generator, sample, models):
306
- hyps, ref = [], None
307
- for j in range(self.sampling_times):
308
- gen_out = self.inference_step(generator, models, sample)
309
- for i in range(len(gen_out)):
310
- with torch.no_grad():
311
- tokens = torch.stack([item['tokens'][:-1] for item in gen_out[i]], dim=0)
312
- tokens += -len(self.src_dict) + self.cfg.code_dict_size + self.cfg.num_bins
313
- images = self.image_tokenizer.decode_code(
314
- tokens.view(-1, self.cfg.code_image_size // 8, self.cfg.code_image_size // 8)
315
- )
316
- images = [custom_to_pil(image) for image in images]
317
- hyps += images
318
- if 'code_images' in sample:
319
- ref = Image.open(BytesIO(base64.urlsafe_b64decode(sample['code_images'][0]))).convert('RGB')
320
-
321
- return hyps, ref
322
-
323
- def dump_images(self, images, text, path, topk=None):
324
- os.makedirs(path, exist_ok=True)
325
- if topk:
326
- images = images[:topk]
327
- for j, image in enumerate(images):
328
- save_path = os.path.join(path, f'{text}_{j}.png')
329
- image.save(save_path)