Spaces:
Runtime error
Runtime error
Delete finetune_vqa.py
Browse files- finetune_vqa.py +0 -204
finetune_vqa.py
DELETED
@@ -1,204 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the BSD-style license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
import argparse
|
8 |
-
import datetime
|
9 |
-
import os
|
10 |
-
import random
|
11 |
-
import time
|
12 |
-
|
13 |
-
import ruamel.yaml as yaml
|
14 |
-
import torch
|
15 |
-
import torch.backends.cudnn as cudnn
|
16 |
-
import torch.distributed as dist
|
17 |
-
from data.vqa_datamodules import VQADataModule
|
18 |
-
from model import albef_model_for_vqa
|
19 |
-
from torch.optim import AdamW
|
20 |
-
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
21 |
-
|
22 |
-
from utils import (
|
23 |
-
add_weight_decay,
|
24 |
-
get_rank,
|
25 |
-
get_world_size,
|
26 |
-
init_distributed_mode,
|
27 |
-
is_dist_avail_and_initialized,
|
28 |
-
is_main_process,
|
29 |
-
save_result,
|
30 |
-
)
|
31 |
-
|
32 |
-
|
33 |
-
def train(model, datamodule, args, device):
|
34 |
-
model_without_ddp = model.module if is_dist_avail_and_initialized() else model
|
35 |
-
model.train()
|
36 |
-
|
37 |
-
optimizer_params = add_weight_decay(model, args["weight_decay"])
|
38 |
-
optimizer = AdamW(optimizer_params, lr=args["lr"])
|
39 |
-
scheduler = CosineAnnealingWarmRestarts(
|
40 |
-
optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
|
41 |
-
)
|
42 |
-
|
43 |
-
step_size = args["step_size"]
|
44 |
-
warmup_steps = args["warmup_steps"]
|
45 |
-
warmup_iterations = warmup_steps * step_size
|
46 |
-
|
47 |
-
data_loader = datamodule.train_dataloader(
|
48 |
-
is_distributed=is_dist_avail_and_initialized(),
|
49 |
-
num_tasks=get_world_size(),
|
50 |
-
global_rank=get_rank(),
|
51 |
-
)
|
52 |
-
|
53 |
-
start_time = time.time()
|
54 |
-
|
55 |
-
for epoch in range(args["max_epochs"]):
|
56 |
-
if is_dist_avail_and_initialized():
|
57 |
-
data_loader.sampler.set_epoch(epoch)
|
58 |
-
|
59 |
-
if epoch > 0:
|
60 |
-
scheduler.step(epoch + warmup_steps)
|
61 |
-
|
62 |
-
for batch, (
|
63 |
-
images,
|
64 |
-
questions,
|
65 |
-
questions_atts,
|
66 |
-
answers,
|
67 |
-
answers_atts,
|
68 |
-
ans_weights,
|
69 |
-
ans_lengths,
|
70 |
-
) in enumerate(data_loader):
|
71 |
-
if epoch > 0:
|
72 |
-
alpha = args["alpha"]
|
73 |
-
else:
|
74 |
-
alpha = args["alpha"] * min(1, batch / len(data_loader))
|
75 |
-
|
76 |
-
images = images.to(device, non_blocking=True)
|
77 |
-
questions = questions.to(device)
|
78 |
-
questions_atts = questions_atts.to(device)
|
79 |
-
answers = answers.to(device)
|
80 |
-
answers_atts = answers_atts.to(device)
|
81 |
-
ans_weights = ans_weights.to(device)
|
82 |
-
|
83 |
-
loss = model(
|
84 |
-
images,
|
85 |
-
questions,
|
86 |
-
questions_atts,
|
87 |
-
answers,
|
88 |
-
answers_atts,
|
89 |
-
ans_weights=ans_weights,
|
90 |
-
ans_lengths=ans_lengths,
|
91 |
-
alpha=alpha,
|
92 |
-
is_train=True,
|
93 |
-
)
|
94 |
-
|
95 |
-
optimizer.zero_grad()
|
96 |
-
loss.backward()
|
97 |
-
optimizer.step()
|
98 |
-
|
99 |
-
if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
|
100 |
-
scheduler.step(batch // step_size)
|
101 |
-
|
102 |
-
if batch % args["log_every_n_steps"] == 0:
|
103 |
-
total_time = time.time() - start_time
|
104 |
-
time_str = "time {},".format(
|
105 |
-
datetime.timedelta(seconds=int(total_time))
|
106 |
-
)
|
107 |
-
epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
|
108 |
-
batch_str = "batch {}/{},".format(batch, len(data_loader))
|
109 |
-
loss_str = "loss {}".format(loss.item())
|
110 |
-
print(time_str, epoch_str, batch_str, loss_str)
|
111 |
-
|
112 |
-
if is_main_process():
|
113 |
-
save_obj = {
|
114 |
-
"model": model_without_ddp.state_dict(),
|
115 |
-
"optimizer": optimizer.state_dict(),
|
116 |
-
"scheduler": scheduler.state_dict(),
|
117 |
-
"epoch": epoch,
|
118 |
-
}
|
119 |
-
torch.save(
|
120 |
-
save_obj,
|
121 |
-
os.path.join(args["checkpoint_root"], "vqa_checkpoint_%02d.pt" % epoch),
|
122 |
-
)
|
123 |
-
|
124 |
-
if is_dist_avail_and_initialized():
|
125 |
-
dist.barrier()
|
126 |
-
|
127 |
-
|
128 |
-
@torch.no_grad()
|
129 |
-
def evaluation(model, datamodule, args, device):
|
130 |
-
model.eval()
|
131 |
-
|
132 |
-
result = []
|
133 |
-
|
134 |
-
answer_list = datamodule.test_dataset.answer_list
|
135 |
-
answer_input_ids = datamodule.test_dataset.answer_input_ids.to(device)
|
136 |
-
answer_atts = datamodule.test_dataset.answer_attention_mask.to(device)
|
137 |
-
data_loader = datamodule.test_dataloader(
|
138 |
-
is_distributed=is_dist_avail_and_initialized(),
|
139 |
-
num_tasks=get_world_size(),
|
140 |
-
global_rank=get_rank(),
|
141 |
-
)
|
142 |
-
|
143 |
-
start_time = time.time()
|
144 |
-
|
145 |
-
for batch, (img, ques, ques_atts, ques_ids) in enumerate(data_loader):
|
146 |
-
img = img.to(device, non_blocking=True)
|
147 |
-
ques = ques.to(device)
|
148 |
-
ques_atts = ques_atts.to(device)
|
149 |
-
|
150 |
-
topk_ids, topk_probs = model(
|
151 |
-
img,
|
152 |
-
ques,
|
153 |
-
ques_atts,
|
154 |
-
answer_input_ids,
|
155 |
-
answer_atts,
|
156 |
-
k=args["k_test"],
|
157 |
-
is_train=False,
|
158 |
-
)
|
159 |
-
|
160 |
-
for ques_id, topk_id, topk_prob in zip(ques_ids, topk_ids, topk_probs):
|
161 |
-
_, pred = topk_prob.max(dim=0)
|
162 |
-
result.append(
|
163 |
-
{"question_id": ques_id, "answer": answer_list[topk_id[pred]]}
|
164 |
-
)
|
165 |
-
|
166 |
-
if batch % args["log_every_n_steps"] == 0:
|
167 |
-
total_time = time.time() - start_time
|
168 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
169 |
-
print(
|
170 |
-
"time {}, batch {}/{}".format(total_time_str, batch, len(data_loader))
|
171 |
-
)
|
172 |
-
|
173 |
-
return result
|
174 |
-
|
175 |
-
|
176 |
-
def main():
|
177 |
-
parser = argparse.ArgumentParser()
|
178 |
-
parser.add_argument("--config", default="./examples/albef/configs/vqa.yaml")
|
179 |
-
args = parser.parse_args()
|
180 |
-
config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
|
181 |
-
|
182 |
-
init_distributed_mode(config)
|
183 |
-
device = torch.device(config["device"])
|
184 |
-
|
185 |
-
seed = config["seed"] + get_rank()
|
186 |
-
torch.manual_seed(seed)
|
187 |
-
random.seed(seed)
|
188 |
-
cudnn.benchmark = True
|
189 |
-
|
190 |
-
datamodule = VQADataModule(**config["datamodule_args"])
|
191 |
-
model = albef_model_for_vqa(config, pretrained=True)
|
192 |
-
model = model.to(device)
|
193 |
-
if is_dist_avail_and_initialized():
|
194 |
-
model = torch.nn.parallel.DistributedDataParallel(
|
195 |
-
model, device_ids=[config["gpu"]]
|
196 |
-
)
|
197 |
-
|
198 |
-
train(model, datamodule, config["training_args"], device)
|
199 |
-
result = evaluation(model, datamodule, config["eval_args"], device)
|
200 |
-
save_result(result, config["output_root"], "vqa_output")
|
201 |
-
|
202 |
-
|
203 |
-
if __name__ == "__main__":
|
204 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|