ryanramos committed
Commit be8362f · 1 parent: f7522c4

Delete finetune_vqa.py

Files changed (1)
  1. finetune_vqa.py +0 -204
finetune_vqa.py DELETED
@@ -1,204 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-import datetime
-import os
-import random
-import time
-
-import ruamel.yaml as yaml
-import torch
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-from data.vqa_datamodules import VQADataModule
-from model import albef_model_for_vqa
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
-
-from utils import (
-    add_weight_decay,
-    get_rank,
-    get_world_size,
-    init_distributed_mode,
-    is_dist_avail_and_initialized,
-    is_main_process,
-    save_result,
-)
-
-
-def train(model, datamodule, args, device):
-    model_without_ddp = model.module if is_dist_avail_and_initialized() else model
-    model.train()
-
-    optimizer_params = add_weight_decay(model, args["weight_decay"])
-    optimizer = AdamW(optimizer_params, lr=args["lr"])
-    scheduler = CosineAnnealingWarmRestarts(
-        optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
-    )
-
-    step_size = args["step_size"]
-    warmup_steps = args["warmup_steps"]
-    warmup_iterations = warmup_steps * step_size
-
-    data_loader = datamodule.train_dataloader(
-        is_distributed=is_dist_avail_and_initialized(),
-        num_tasks=get_world_size(),
-        global_rank=get_rank(),
-    )
-
-    start_time = time.time()
-
-    for epoch in range(args["max_epochs"]):
-        if is_dist_avail_and_initialized():
-            data_loader.sampler.set_epoch(epoch)
-
-        if epoch > 0:
-            scheduler.step(epoch + warmup_steps)
-
-        for batch, (
-            images,
-            questions,
-            questions_atts,
-            answers,
-            answers_atts,
-            ans_weights,
-            ans_lengths,
-        ) in enumerate(data_loader):
-            if epoch > 0:
-                alpha = args["alpha"]
-            else:
-                alpha = args["alpha"] * min(1, batch / len(data_loader))
-
-            images = images.to(device, non_blocking=True)
-            questions = questions.to(device)
-            questions_atts = questions_atts.to(device)
-            answers = answers.to(device)
-            answers_atts = answers_atts.to(device)
-            ans_weights = ans_weights.to(device)
-
-            loss = model(
-                images,
-                questions,
-                questions_atts,
-                answers,
-                answers_atts,
-                ans_weights=ans_weights,
-                ans_lengths=ans_lengths,
-                alpha=alpha,
-                is_train=True,
-            )
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
-                scheduler.step(batch // step_size)
-
-            if batch % args["log_every_n_steps"] == 0:
-                total_time = time.time() - start_time
-                time_str = "time {},".format(
-                    datetime.timedelta(seconds=int(total_time))
-                )
-                epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
-                batch_str = "batch {}/{},".format(batch, len(data_loader))
-                loss_str = "loss {}".format(loss.item())
-                print(time_str, epoch_str, batch_str, loss_str)
-
-        if is_main_process():
-            save_obj = {
-                "model": model_without_ddp.state_dict(),
-                "optimizer": optimizer.state_dict(),
-                "scheduler": scheduler.state_dict(),
-                "epoch": epoch,
-            }
-            torch.save(
-                save_obj,
-                os.path.join(args["checkpoint_root"], "vqa_checkpoint_%02d.pt" % epoch),
-            )
-
-        if is_dist_avail_and_initialized():
-            dist.barrier()
-
-
-@torch.no_grad()
-def evaluation(model, datamodule, args, device):
-    model.eval()
-
-    result = []
-
-    answer_list = datamodule.test_dataset.answer_list
-    answer_input_ids = datamodule.test_dataset.answer_input_ids.to(device)
-    answer_atts = datamodule.test_dataset.answer_attention_mask.to(device)
-    data_loader = datamodule.test_dataloader(
-        is_distributed=is_dist_avail_and_initialized(),
-        num_tasks=get_world_size(),
-        global_rank=get_rank(),
-    )
-
-    start_time = time.time()
-
-    for batch, (img, ques, ques_atts, ques_ids) in enumerate(data_loader):
-        img = img.to(device, non_blocking=True)
-        ques = ques.to(device)
-        ques_atts = ques_atts.to(device)
-
-        topk_ids, topk_probs = model(
-            img,
-            ques,
-            ques_atts,
-            answer_input_ids,
-            answer_atts,
-            k=args["k_test"],
-            is_train=False,
-        )
-
-        for ques_id, topk_id, topk_prob in zip(ques_ids, topk_ids, topk_probs):
-            _, pred = topk_prob.max(dim=0)
-            result.append(
-                {"question_id": ques_id, "answer": answer_list[topk_id[pred]]}
-            )
-
-        if batch % args["log_every_n_steps"] == 0:
-            total_time = time.time() - start_time
-            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-            print(
-                "time {}, batch {}/{}".format(total_time_str, batch, len(data_loader))
-            )
-
-    return result
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--config", default="./examples/albef/configs/vqa.yaml")
-    args = parser.parse_args()
-    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
-
-    init_distributed_mode(config)
-    device = torch.device(config["device"])
-
-    seed = config["seed"] + get_rank()
-    torch.manual_seed(seed)
-    random.seed(seed)
-    cudnn.benchmark = True
-
-    datamodule = VQADataModule(**config["datamodule_args"])
-    model = albef_model_for_vqa(config, pretrained=True)
-    model = model.to(device)
-    if is_dist_avail_and_initialized():
-        model = torch.nn.parallel.DistributedDataParallel(
-            model, device_ids=[config["gpu"]]
-        )
-
-    train(model, datamodule, config["training_args"], device)
-    result = evaluation(model, datamodule, config["eval_args"], device)
-    save_result(result, config["output_root"], "vqa_output")


-if __name__ == "__main__":
-    main()
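
Note for anyone who still depends on this file: the deleted script resolved every hyperparameter from the YAML file passed via --config. Below is a minimal sketch of the keys it read, reconstructed purely from the lookups in the diff above; all values are illustrative placeholders, not the contents of the repository's actual examples/albef/configs/vqa.yaml.

# Hypothetical config for the deleted finetune_vqa.py, shown as a Python
# dict. Keys are taken from config[...] and args[...] lookups in the code
# above; the values are placeholder guesses, not the real config.
config = {
    "device": "cuda",           # torch.device(config["device"])
    "seed": 42,                 # offset by get_rank() in each process
    "gpu": 0,                   # device id passed to DistributedDataParallel
    "output_root": "output",    # save_result(result, ..., "vqa_output")
    "datamodule_args": {},      # forwarded verbatim to VQADataModule(**...)
    "training_args": {
        "weight_decay": 0.02,   # AdamW weight decay via add_weight_decay
        "lr": 2e-5,             # AdamW learning rate
        "max_epochs": 8,        # epoch count; also T_0 for CosineAnnealingWarmRestarts
        "min_lr": 1e-6,         # eta_min for the scheduler
        "step_size": 100,       # scheduler step granularity during epoch-0 warmup
        "warmup_steps": 4,      # warmup_iterations = warmup_steps * step_size
        "alpha": 0.4,           # loss interpolation weight, linearly ramped in epoch 0
        "log_every_n_steps": 50,
        "checkpoint_root": "checkpoints",  # vqa_checkpoint_NN.pt files land here
    },
    "eval_args": {
        "k_test": 128,          # top-k answer candidates scored at inference
        "log_every_n_steps": 50,
    },
}

Two caveats under these assumptions: init_distributed_mode(config) presumably reads or populates further keys (the distributed branch consults config["gpu"], and the script is expected to be launched per-process, e.g. via torchrun, so that rank and world size can be discovered), and albef_model_for_vqa(config, pretrained=True) receives the whole dict, so the real config file likely carries model and checkpoint settings not shown here.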