# babe24 / ram_train_eval.py
# mart9992's picture
# m
# 9231ab9
import os
import time
from datetime import timedelta
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.config import Config
from mmengine.utils import ProgressBar
from transformers import AutoConfig, AutoModel
class RamDataset(torch.utils.data.Dataset):
    """Per-image object features and relation annotations read from one archive.

    The archive's ``arr_0`` entry holds one record per image; each record is
    indexed with the ``"is_train"``, ``"feat"`` and ``"relations"`` keys.
    """

    def __init__(self, data_path, is_train=True, num_relation_classes=56):
        """
        Args:
            data_path: path passed to ``np.load``; only its ``arr_0`` entry is used.
            is_train: keep only the records whose ``"is_train"`` value equals this.
            num_relation_classes: length of the relation axis of the dense target.
        """
        super().__init__()
        self.num_relation_classes = num_relation_classes
        # SECURITY: allow_pickle=True lets the archive unpickle arbitrary
        # objects, which can execute code -- only load files from a trusted source.
        data = np.load(data_path, allow_pickle=True)
        try:
            self.samples = data["arr_0"]
        finally:
            # An .npz archive keeps its file handle open until closed; objects
            # without a close() method need no cleanup.
            close = getattr(data, "close", None)
            if close is not None:
                close()
        # Positions (into self.samples) of the records belonging to this split.
        self.sample_idx_list = [
            idx
            for idx, sample in enumerate(self.samples)
            if sample["is_train"] == is_train
        ]

    def __getitem__(self, idx):
        """Return ``(embedding, rel_target, gt_rels)`` for the idx-th record of the split.

        ``embedding`` is ``torch.from_numpy(feat)`` (shares memory with the
        stored array); ``rel_target`` is the dense target built by
        ``_get_target``; ``gt_rels`` is the record's raw ``"relations"`` value.
        """
        sample = self.samples[self.sample_idx_list[idx]]
        object_num = sample["feat"].shape[0]
        embedding = torch.from_numpy(sample["feat"])
        gt_rels = sample["relations"]
        rel_target = self._get_target(object_num, gt_rels)
        return embedding, rel_target, gt_rels

    def __len__(self):
        return len(self.sample_idx_list)

    def _get_target(self, object_num, gt_rels):
        """Build a (num_relation_classes, object_num, object_num) 0/1 float tensor.

        Each ``(subject, object, relation)`` triple in ``gt_rels`` sets
        ``rel_target[relation, subject, object] = 1``.
        """
        rel_target = torch.zeros([self.num_relation_classes, object_num, object_num])
        for ii, jj, cls_relationship in gt_rels:
            rel_target[cls_relationship, ii, jj] = 1
        return rel_target
class RamModel(nn.Module):
    """Relation-scoring head on top of a (possibly truncated) HuggingFace transformer.

    Per-object input features are projected to the transformer width, encoded
    jointly, and turned into one logit per (relation, subject, object) triple.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        load_pretrained_weights=True,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type="attention",
        loss_type="bce",
    ):
        """
        Args:
            pretrained_model_name_or_path: name/path given to ``AutoModel`` / ``AutoConfig``.
            load_pretrained_weights: if False, build the architecture from its
                config only (weights are not loaded).
            num_transformer_layer: keep only the first N encoder layers when an
                int; any other value (e.g. "all") keeps every layer.
            input_feature_size: width of the incoming per-object features.
            output_feature_size: width fed to the transformer as ``inputs_embeds``;
                presumably must match the transformer's hidden size -- TODO confirm.
            cls_feature_size: per-relation embedding width of the prediction head.
            num_relation_classes: number of relation categories scored.
            pred_type: "attention" (scaled dot product) or "einsum" (unscaled).
            loss_type: "bce" or "multi_label_ce".
        """
        super().__init__()
        # 0. config
        self.cls_feature_size = cls_feature_size
        self.num_relation_classes = num_relation_classes
        self.pred_type = pred_type
        self.loss_type = loss_type
        # 1. fc input and output
        self.fc_input = nn.Sequential(
            nn.Linear(input_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )
        self.fc_output = nn.Sequential(
            nn.Linear(output_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )
        # 2. transformer model
        if load_pretrained_weights:
            self.model = AutoModel.from_pretrained(pretrained_model_name_or_path)
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            self.model = AutoModel.from_config(config)
        # Assumes a BERT-style model exposing ``encoder.layer`` -- TODO confirm for others.
        if num_transformer_layer != "all" and isinstance(num_transformer_layer, int):
            self.model.encoder.layer = self.model.encoder.layer[:num_transformer_layer]
        # 3. predict head: one cls_feature_size vector per relation class, for the
        # subject role and the object role separately.
        self.cls_sub = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
        self.cls_obj = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
        # 4. loss
        if self.loss_type == "bce":
            self.bce_loss = nn.BCEWithLogitsLoss()
        elif self.loss_type == "multi_label_ce":
            print("Use Multi Label Cross Entropy Loss.")

    def forward(self, embeds, attention_mask=None):
        """Score every (relation, subject, object) triple.

        Args:
            embeds: (batch_size, token_num, input_feature_size) object features.
            attention_mask: (batch_size, token_num), forwarded to the transformer.

        Returns:
            (batch_size, num_relation_classes, token_num, token_num) logits,
            indexed as [batch, relation, subject, object].

        Raises:
            ValueError: if ``pred_type`` is not "attention" or "einsum".
        """
        # 1. fc input
        embeds = self.fc_input(embeds)
        # 2. transformer model. Every token gets the same position id (1), so
        # the encoder receives no ordering information about the objects.
        position_ids = torch.ones([1, embeds.shape[1]], dtype=torch.long, device=embeds.device)
        outputs = self.model(inputs_embeds=embeds, attention_mask=attention_mask, position_ids=position_ids)
        # 3. fc output
        embeds = self.fc_output(outputs["last_hidden_state"])
        # 4. predict head: (B, T, R*C) -> (B, R, T, C)
        batch_size, token_num, _ = embeds.shape
        head_shape = [batch_size, token_num, self.num_relation_classes, self.cls_feature_size]
        sub_embeds = self.cls_sub(embeds).reshape(head_shape).permute([0, 2, 1, 3])
        obj_embeds = self.cls_obj(embeds).reshape(head_shape).permute([0, 2, 1, 3])
        if self.pred_type == "attention":
            return sub_embeds @ obj_embeds.transpose(2, 3) / self.cls_feature_size**0.5
        if self.pred_type == "einsum":
            return torch.einsum("nrsc,nroc->nrso", sub_embeds, obj_embeds)
        raise ValueError("Unknown pred_type: {!r}".format(self.pred_type))

    def loss(self, pred, target, attention_mask):
        """Compute the training loss and a running recall@20 metric.

        Args:
            pred: (batch_size, relation_num, N, N) logits from ``forward``.
            target: 0/1 tensor with the same shape as ``pred``.
            attention_mask: (batch_size, N); its row sum is taken as the number of
                real tokens, assumed to come first -- TODO confirm padding side.

        Returns:
            dict with "loss" (scalar tensor) and "recall@20" (percentage tensor).

        Raises:
            ValueError: if ``loss_type`` is not "bce" or "multi_label_ce".
        """
        loss_dict = dict()
        batch_size, relation_num, _, _ = pred.shape
        # Keep the top-left n x n block of real tokens; every other cell becomes
        # a large negative logit so it acts as a confident negative.
        mask = torch.zeros_like(pred)
        for idx in range(batch_size):
            n = int(torch.sum(attention_mask[idx]))
            mask[idx, :, :n, :n] = 1
        pred = pred * mask - 9999 * (1 - mask)
        if self.loss_type == "bce":
            loss = self.bce_loss(pred, target)
        elif self.loss_type == "multi_label_ce":
            # One row per relation class spanning every (sample, subject, object)
            # cell: move the relation axis first BEFORE flattening, otherwise rows
            # mix relation classes across samples whenever batch_size > 1.
            input_tensor = pred.permute(1, 0, 2, 3).reshape([relation_num, -1])
            target_tensor = target.permute(1, 0, 2, 3).reshape([relation_num, -1])
            loss = self.multilabel_categorical_crossentropy(target_tensor, input_tensor)
            # Scale each relation's loss by its size relative to the largest one.
            # NOTE(review): weight is not detached, so gradients also flow through it.
            weight = loss / loss.max()
            loss = loss * weight
            loss = loss.mean()
        else:
            raise ValueError("Unknown loss_type: {!r}".format(self.loss_type))
        loss_dict["loss"] = loss
        # running metric
        loss_dict["recall@20"] = get_recall_N(pred, target, object_num=20)
        return loss_dict

    def multilabel_categorical_crossentropy(self, y_true, y_pred):
        """Multi-label categorical cross-entropy, one value per row.

        See https://kexue.fm/archives/7359. ``y_true`` is 0/1 and ``y_pred`` holds
        logits of the same shape; the logsumexp runs over the last dimension,
        with an extra 0 logit appended to both the positive and negative terms.
        """
        y_pred = (1 - 2 * y_true) * y_pred
        y_pred_neg = y_pred - y_true * 9999
        y_pred_pos = y_pred - (1 - y_true) * 9999
        zeros = torch.zeros_like(y_pred[..., :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return neg_loss + pos_loss
def _topk_triplets(scores, k, drop_zero):
    """Return the top-``k`` cells of one (R, N, N) score map as (subject, object, relation) tuples.

    ``k`` is clamped to the number of cells so small maps do not make
    ``torch.topk`` raise. With ``drop_zero`` set, selected cells whose score is
    exactly 0 are discarded (used to skip non-relations in a 0/1 target).
    """
    side = scores.shape[1]
    plane = side * side
    flat = scores.reshape(-1)
    values, indices = torch.topk(flat, k=min(k, flat.numel()))
    if drop_zero:
        indices = indices[values != 0]
    return {(idx % plane // side, idx % plane % side, idx // plane) for idx in indices.tolist()}


def get_recall_N(y_pred, y_true, object_num=20):
    """Mean top-K relation recall over a batch, as a percentage.

    For each sample, the ground-truth set is the top ``object_num`` cells of the
    target that are non-zero, so at most ``object_num`` positives are counted.
    The prediction set is the top ``object_num`` scored cells. Recall is
    |pred ∩ gt| / |gt| (0 when there is no ground truth).

    Args:
        y_pred: [batch_size, num_relations, object_num, object_num] scores.
        y_true: [batch_size, num_relations, object_num, object_num] 0/1 targets.
        object_num: K, the number of top cells taken from each map.

    Returns:
        0-d float tensor on ``y_pred.device``: the batch-mean recall times 100.
    """
    device = y_pred.device
    recall_list = []
    for sample_pred, sample_true in zip(y_pred, y_true):
        gt_triplets = _topk_triplets(sample_true, object_num, drop_zero=True)
        pred_triplets = _topk_triplets(sample_pred, object_num, drop_zero=False)
        # Triplets are unique within each set, so the intersection size equals
        # the number of predictions that hit a ground-truth triplet.
        recall_list.append(len(pred_triplets & gt_triplets) / (len(gt_triplets) + 1e-8))
    return torch.tensor(recall_list).to(device).mean() * 100
class RamTrainer(object):
    """Train a RamModel on the RamDataset split selected by ``config.dataset``."""

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()
        self._build_optimizer()
        self._build_lr_scheduler()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        # Shuffle only when iterating the training split.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.config.dataloader.batch_size,
            shuffle=bool(self.config.dataset.is_train),
        )

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            state_dict = torch.load(self.config.load_from, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.model.train()

    def _build_optimizer(self):
        optim_cfg = self.config.optim
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=optim_cfg.lr,
            weight_decay=optim_cfg.weight_decay,
            eps=optim_cfg.eps,
            betas=optim_cfg.betas,
        )

    def _build_lr_scheduler(self):
        sched_cfg = self.config.optim.lr_scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=sched_cfg.step, gamma=sched_cfg.gamma
        )

    def train(self):
        """Run ``config.num_epoch`` epochs, logging every 100 batches.

        After each epoch the LR scheduler steps and the model's ``state_dict``
        is saved to ``<output_dir>/epoch_<n>.pth``.

        Returns:
            Path of the last saved checkpoint, or None if no epoch was run.
        """
        num_epoch = self.config.num_epoch
        steps_per_epoch = len(self.dataloader)
        total_steps = num_epoch * steps_per_epoch
        t_start = time.perf_counter()
        running_avg_loss = 0
        save_path = None
        for epoch_idx in range(num_epoch):
            for batch_idx, batch_data in enumerate(self.dataloader):
                batch_embeds = batch_data[0].to(torch.float32).to(self.device)
                batch_target = batch_data[1].to(torch.float32).to(self.device)
                attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
                batch_pred = self.model(batch_embeds, attention_mask)
                loss_dict = self.model.loss(batch_pred, batch_target, attention_mask)
                loss = loss_dict["loss"]
                recall_20 = loss_dict["recall@20"]
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.optim.max_norm, self.config.optim.norm_type)
                self.optimizer.step()
                running_avg_loss += loss.item()
                if batch_idx % 100 == 0:
                    # Global 1-based step count across all epochs so far.
                    num_finished_step = epoch_idx * steps_per_epoch + batch_idx + 1
                    num_to_do_step = total_steps - num_finished_step
                    # Guard against a zero elapsed time on coarse clocks.
                    elapsed = max(time.perf_counter() - t_start, 1e-6)
                    eta = num_to_do_step * elapsed / num_finished_step
                    print(
                        "ETA={:0>8}, Epoch={}, Batch={}/{}, LR={}, Loss={:.4f}, RunningAvgLoss={:.4f}, Recall@20={:.2f}%".format(
                            str(timedelta(seconds=int(eta))),
                            epoch_idx + 1,
                            batch_idx,
                            steps_per_epoch,
                            self.lr_scheduler.get_last_lr()[0],
                            loss.item(),
                            running_avg_loss / num_finished_step,
                            recall_20.item(),
                        )
                    )
            self.lr_scheduler.step()
            os.makedirs(self.config.output_dir, exist_ok=True)
            save_path = os.path.join(self.config.output_dir, "epoch_{}.pth".format(epoch_idx + 1))
            print("Save epoch={} checkpoint to {}".format(epoch_idx + 1, save_path))
            torch.save(self.model.state_dict(), save_path)
        return save_path
class RamPredictor(object):
    """Run a RamModel over a RamDataset split and compute recall@{20,50,100}."""

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        # Fixed order so evaluation is reproducible.
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.config.dataloader.batch_size, shuffle=False)

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            state_dict = torch.load(self.config.load_from, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, batch_embeds, pred_keep_num=100):
        """Score all relation triples and decode the highest-scoring ones.

        Parameters
        ----------
        batch_embeds: (batch_size=1, token_num, feature_size), tensor or array-like
        pred_keep_num: int, number of triples to decode (clamped to the number of
            scores available)

        Returns
        -------
        batch_pred: (batch_size, relation_num, object_num, object_num) sigmoid
            scores; self-pairs (subject == object) are suppressed to ~0.
        pred_rels: [[sub_id, obj_id, rel_id], ...] in descending score order.
            The top-k runs over the whole flattened batch, so rel_id is only
            meaningful for batch_size == 1.
        """
        if not isinstance(batch_embeds, torch.Tensor):
            batch_embeds = torch.asarray(batch_embeds)
        batch_embeds = batch_embeds.to(torch.float32).to(self.device)
        attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            batch_pred = self.model(batch_embeds, attention_mask)
            num_tokens = batch_pred.shape[2]
            diag = torch.arange(num_tokens, device=batch_pred.device)
            # A token cannot relate to itself.
            batch_pred[:, :, diag, diag] = -9999
            batch_pred = batch_pred.sigmoid()
        flat_scores = batch_pred.reshape(-1)
        _, topk_indices = torch.topk(flat_scores, k=min(pred_keep_num, flat_scores.numel()))
        # Decode flat indices of the (R, N, N) layout: subject, object, relation.
        plane = num_tokens**2
        pred_rels = [
            [index % plane // num_tokens, index % plane % num_tokens, index // plane]
            for index in topk_indices.tolist()
        ]
        return batch_pred, pred_rels

    def eval(self):
        """Return the batch-averaged recall@20/50/100 (percentages) over the dataloader."""
        recall_sums = {20: 0.0, 50: 0.0, 100: 0.0}
        prog_bar = ProgressBar(len(self.dataloader))
        for batch_data in self.dataloader:
            batch_embeds = batch_data[0]
            batch_target = batch_data[1]
            batch_pred, _ = self.predict(batch_embeds)
            for top_k in recall_sums:
                recall_sums[top_k] += get_recall_N(batch_pred, batch_target, object_num=top_k).item()
            prog_bar.update()
        num_batches = len(self.dataloader)
        metric = {
            "recall_20": recall_sums[20] / num_batches,
            "recall_50": recall_sums[50] / num_batches,
            "recall_100": recall_sums[100] / num_batches,
        }
        return metric
if __name__ == "__main__":
    # ---- Configuration: data, model, optimizer and schedule --------------
    config = Config(
        {
            "dataset": {
                "data_path": "./data/feat_0420.npz",
                "is_train": True,
                "num_relation_classes": 56,
            },
            "dataloader": {"batch_size": 4},
            "model": {
                "pretrained_model_name_or_path": "bert-base-uncased",
                "load_pretrained_weights": True,
                "num_transformer_layer": 2,
                "input_feature_size": 256,
                "output_feature_size": 768,
                "cls_feature_size": 512,
                "num_relation_classes": 56,
                "pred_type": "attention",
                "loss_type": "multi_label_ce",
            },
            "optim": {
                "lr": 1e-4,
                "weight_decay": 0.05,
                "eps": 1e-8,
                "betas": (0.9, 0.999),
                "max_norm": 0.01,
                "norm_type": 2,
                "lr_scheduler": {"step": [6, 10], "gamma": 0.1},
            },
            "num_epoch": 12,
            "output_dir": "./work_dirs",
            "load_from": None,
        }
    )

    # ---- Training on the train split --------------------------------------
    trainer = RamTrainer(config)
    last_model_path = trainer.train()

    # ---- Evaluate the final checkpoint on the held-out split --------------
    config.load_from = last_model_path
    config.dataset.is_train = False
    predictor = RamPredictor(config)
    metric = predictor.eval()
    print(metric)