# summary/fengshen/examples/clue_sim/finetune_clue_sim.py
# skf15963's picture
# Duplicate from fclong/summary
# fb238e8
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from sklearn import metrics
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import pytorch_lightning as pl
from collections import defaultdict
from transformers import AutoConfig, AutoModel, get_cosine_schedule_with_warmup
from loss import FocalLoss, LabelSmoothingCorrectionCrossEntropy
class CustomDataset(Dataset):
    """Query/title pair dataset backed by a JSON-lines file.

    Every non-blank line of the file is parsed as one JSON object. The keys
    ``"query"``, ``"title"`` and ``"id"`` are always read; ``"label"`` is read
    only when ``mode == 'no_test'``.

    Args:
        file: File name, resolved relative to ``data_dir``.
        tokenizer: Object exposing ``encode_plus(query, title, ...)`` that
            returns a mapping with ``input_ids``, ``attention_mask`` and
            ``token_type_ids``.
        max_len: Passed to the tokenizer as ``max_length``; sequences are
            truncated and padded to this length.
        mode: ``'no_test'`` for labelled data (items carry ``targets``); any
            other value yields unlabelled items.
        data_dir: Directory that contains ``file``. Defaults to the
            ``./dataset`` directory the script has always used.
    """

    def __init__(self, file, tokenizer, max_len, mode='no_test', data_dir='./dataset'):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mode = mode
        self.ex_list = []
        with open(os.path.join(data_dir, file), "r", encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline) instead of
                    # failing inside json.loads.
                    continue
                sample = json.loads(line)
                query = sample["query"]
                title = sample["title"]
                sample_id = int(sample["id"])
                if self.mode == 'no_test':
                    relevant = int(sample["label"])
                    self.ex_list.append((query, title, relevant, sample_id))
                else:
                    self.ex_list.append((query, title, sample_id))

    def __len__(self):
        """Return the number of examples read from the file."""
        return len(self.ex_list)

    def __getitem__(self, index):
        """Tokenize example ``index`` and return it as a dict of tensors.

        The dict always contains ``ids``, ``mask``, ``token_type_ids`` and
        ``id`` (all ``torch.long``); in ``'no_test'`` mode it also contains
        ``targets`` (``torch.float``).
        """
        labelled = self.mode == 'no_test'
        if labelled:
            query, title, relevant, sample_id = self.ex_list[index]
        else:
            query, title, sample_id = self.ex_list[index]

        inputs = self.tokenizer.encode_plus(
            query, title,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )

        item = {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs["token_type_ids"], dtype=torch.long),
        }
        if labelled:
            # Stored as float here; the model casts to long before the loss.
            item['targets'] = torch.tensor(relevant, dtype=torch.float)
        item['id'] = torch.tensor(sample_id, dtype=torch.long)
        return item
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module for the files under ``./dataset``.

    Split usage, as implemented below:

    * training batches come from ``train.json`` concatenated with ``dev.json``;
    * validation and test batches both come from ``test_public.json``.

    Args:
        args: Namespace providing ``max_seq_length``, ``batch_size`` and
            ``val_batch_size``.
        tokenizer: Tokenizer handed to every ``CustomDataset``.
    """

    def __init__(self, args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.max_len = self.args.max_seq_length
        self.train_dataset = None
        self.val_dataset = None
        # Initialised here so the attribute exists before setup() runs.
        self.test_dataset = None

    def setup(self, stage=None):
        """Load the datasets required for ``stage``.

        ``'fit'`` (or ``None``) loads all three files; ``'validate'`` and
        ``'test'`` load only ``test_public.json``, which backs both the
        validation and the test dataloader. Other stages load nothing.

        Raises:
            FileNotFoundError: If a file needed for ``stage`` is missing.
        """
        data_path = "./dataset"
        if stage == 'fit' or stage is None:
            required = ('train.json', 'dev.json', 'test_public.json')
        elif stage in ('validate', 'test'):
            required = ('test_public.json',)
        else:
            required = ()

        # Explicit check instead of `assert`, which is stripped under -O.
        for name in required:
            path = os.path.join(data_path, name)
            if not os.path.exists(path):
                raise FileNotFoundError(path)

        if 'train.json' in required:
            self.train_dataset = CustomDataset('train.json', self.tokenizer, self.max_len)
            self.val_dataset = CustomDataset('dev.json', self.tokenizer, self.max_len)
        if 'test_public.json' in required:
            self.test_dataset = CustomDataset('test_public.json', self.tokenizer, self.max_len)

    def train_dataloader(self):
        """Return a shuffled loader over train + dev, dropping the last partial batch."""
        full_dataset = ConcatDataset([self.train_dataset, self.val_dataset])
        return DataLoader(
            full_dataset,
            batch_size=self.args.batch_size,
            num_workers=4,
            shuffle=True,
            pin_memory=True,
            drop_last=True)

    def _eval_dataloader(self):
        """Build the ordered, non-dropping loader shared by validation and test."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.args.val_batch_size,
            num_workers=4,
            shuffle=False,
            pin_memory=True,
            drop_last=False)

    def val_dataloader(self):
        """Return the validation loader (reads ``test_public.json``)."""
        return self._eval_dataloader()

    def test_dataloader(self):
        """Return the test loader (reads ``test_public.json``)."""
        return self._eval_dataloader()
class CustomModel(pl.LightningModule):
    """Transformer encoder with a dropout + linear classification head.

    Args:
        args: Namespace providing ``model_name``, ``model_path``,
            ``scheduler``, ``optimizer``, ``use_original_pooler``,
            ``cate_performance``, ``loss_function``, ``num_labels``,
            ``learning_rate``, ``batch_size``, ``warmup_proportion`` and
            ``warmup_step``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = self.args.model_name          # pretrained model identifier
        self.cache_dir = self.args.model_path
        self.scheduler = self.args.scheduler        # 'StepLR' | 'CosineWarmup' | 'CosineAnnealingLR'
        self.step_scheduler_after = "batch"
        self.optimizer = self.args.optimizer        # 'Adam' | 'AdamW'
        self.pooler = self.args.use_original_pooler
        self.category = self.args.cate_performance  # log per-class F1 instead of macro F1
        self.loss_func = self.args.loss_function    # 'CE' | 'Focal' | 'LSCE_correction'

        hidden_dropout_prob: float = 0.1
        layer_norm_eps: float = 1e-7

        config = AutoConfig.from_pretrained(self.model, cache_dir=self.cache_dir)
        config.update(
            {
                "output_hidden_states": False,
                "hidden_dropout_prob": hidden_dropout_prob,
                "layer_norm_eps": layer_norm_eps,
            }
        )
        self.transformer = AutoModel.from_pretrained(self.model, config=config, cache_dir=self.cache_dir)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Classification head (the original comment notes three classes).
        self.linear = torch.nn.Linear(config.hidden_size, self.args.num_labels, bias=True)

    def configure_optimizers(self):
        """Build the optimizer and the per-step LR scheduler selected in ``args``.

        Raises:
            ValueError: If the optimizer or scheduler name is not recognised.
        """
        no_decay = ("bias", "LayerNorm.weight")
        decay_params, no_decay_params = [], []
        # Iterate over the whole module, not only self.transformer, so that the
        # classification head (self.linear) is optimised as well.
        for name, param in self.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": 0.01},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        if self.optimizer == 'Adam':
            optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=self.args.learning_rate)
        elif self.optimizer == 'AdamW':
            optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
        else:
            raise ValueError("unknown optimizer: %r" % (self.optimizer,))

        # Only the selected scheduler is instantiated: constructing a scheduler
        # already adjusts the optimizer's learning rate, so building unused
        # ones on the same optimizer would interfere with the chosen one.
        if self.scheduler == 'StepLR':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=self.args.warmup_step, gamma=self.args.warmup_proportion)
        elif self.scheduler == 'CosineWarmup':
            total_steps = int(self.total_steps)
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(self.args.warmup_proportion * total_steps),
                num_training_steps=total_steps,
            )
        elif self.scheduler == 'CosineAnnealingLR':
            # NOTE(review): T_max=5 combined with interval="step" gives a very
            # short cosine period — confirm this is intended.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)
        else:
            raise ValueError("unknown scheduler: %r" % (self.scheduler,))

        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def setup(self, stage=None):
        """Compute ``self.total_steps`` (optimizer steps over the whole fit run)."""
        if stage != "fit":
            return
        train_dataloader = self.trainer.datamodule.train_dataloader()
        # NOTE(review): relies on `trainer.gpus` being None or an int — confirm
        # against the Lightning version in use.
        gpus = 0 if self.trainer.gpus is None else self.trainer.gpus
        tb_size = self.args.batch_size * max(1, gpus)
        batches_per_epoch = len(train_dataloader.dataset) // tb_size
        steps_per_epoch = batches_per_epoch // self.trainer.accumulate_grad_batches
        # Multiply by the epoch count; dividing by it (as before) made the
        # schedule end after a fraction of a single epoch.
        self.total_steps = steps_per_epoch * self.trainer.max_epochs

    def loss(self, outputs, targets):
        """Apply the loss selected by ``args.loss_function`` to ``outputs``/``targets``.

        Raises:
            ValueError: If the loss name is not recognised.
        """
        loss_factories = {
            'CE': nn.CrossEntropyLoss,
            'Focal': FocalLoss,
            'LSCE_correction': LabelSmoothingCorrectionCrossEntropy,
        }
        if self.loss_func not in loss_factories:
            raise ValueError("unknown loss function: %r" % (self.loss_func,))
        # Instantiate only the requested loss.
        loss_fct = loss_factories[self.loss_func]()
        return loss_fct(outputs, targets)

    def category_performance_measure(self, labels_right, labels_pred, num_label=3):
        """Return the per-class F1 scores as a list indexed by class id.

        Args:
            labels_right: Gold labels, integers in ``range(num_label)``.
            labels_pred: Predicted labels, same length and range.
            num_label: Number of classes.
        """
        classes = range(num_label)
        correct = dict.fromkeys(classes, 0)      # correctly predicted, per class
        gold_count = dict.fromkeys(classes, 0)   # occurrences in the gold labels
        pred_count = dict.fromkeys(classes, 0)   # occurrences in the predictions

        for gold, guess in zip(labels_right, labels_pred):
            gold_count[gold] += 1
            pred_count[guess] += 1
            if gold == guess:
                correct[gold] += 1

        results = []
        for key in classes:
            # The epsilon avoids division by zero for classes absent from a batch.
            recall = float(correct[key]) / float(gold_count[key] + 1e-9)
            precision = float(correct[key]) / float(pred_count[key] + 1e-9)
            total = recall + precision
            results.append(recall * precision * 2 / total if total != 0 else 0)
        return results

    def monitor_metrics(self, outputs, targets):
        """Return ``{"f1": ...}``: a per-class list if ``self.category`` else macro F1."""
        pred = torch.argmax(outputs, dim=1).cpu().numpy().tolist()
        targets = targets.int().cpu().numpy().tolist()
        if self.category:
            category_results = self.category_performance_measure(
                labels_right=targets,
                labels_pred=pred,
                num_label=self.args.num_labels
            )
            return {"f1": category_results}
        f1_score = metrics.f1_score(targets, pred, average="macro")
        return {"f1": f1_score}

    def _pooled_output(self, transformer_out):
        """Return the sentence vector: the encoder's pooler output, or the mean of
        the last hidden states over the sequence dimension."""
        if self.pooler:
            return transformer_out.pooler_output
        # NOTE(review): the mean is taken over every position, padding included.
        return torch.mean(transformer_out.last_hidden_state, dim=1)

    def forward(self, ids, mask, token_type_ids, labels):
        """Return ``(logits, correct_count)`` for a labelled batch."""
        transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids)
        pooler_output = self._pooled_output(transformer_out)
        logits = self.linear(self.dropout(pooler_output))
        labels_hat = torch.argmax(logits, dim=1)
        correct_count = torch.sum(labels == labels_hat)
        return logits, correct_count

    def predict(self, ids, mask, token_type_ids):
        """Return the predicted class index for each example in the batch."""
        transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Use the same pooling as forward() so inference matches training.
        pooler_output = self._pooled_output(transformer_out)
        logits = self.linear(self.dropout(pooler_output))
        return torch.argmax(logits, dim=1)

    def _shared_step(self, batch, prefix):
        """Run a labelled batch, log ``{prefix}_loss/_acc/_f1*`` and return
        ``(loss, logits, labels)``."""
        ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets']
        logits, correct_count = self.forward(ids, mask, token_type_ids, labels)
        loss = self.loss(logits, labels.long())
        f1 = self.monitor_metrics(logits, labels)["f1"]
        self.log(prefix + "_loss", loss, logger=True, prog_bar=True)
        self.log(prefix + "_acc", correct_count.float() / len(labels), logger=True, prog_bar=True)
        if self.category:
            # One entry per class, whatever num_labels is.
            for class_idx, class_f1 in enumerate(f1):
                self.log("%s_f1_key%d" % (prefix, class_idx), class_f1, logger=True, prog_bar=True)
        else:
            self.log(prefix + "_f1", f1, logger=True, prog_bar=True)
        return loss, logits, labels

    def training_step(self, batch, batch_idx):
        """Log training metrics and return the loss."""
        loss, _, _ = self._shared_step(batch, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics."""
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log test metrics and return the loss, logits and labels."""
        loss, logits, labels = self._shared_step(batch, "test")
        return {"test_loss": loss, "logits": logits, "labels": labels}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return example ids and predicted class indices as Python lists.

        The ``'logits'`` key holds the argmax class indices returned by
        ``predict``, not raw logits.
        """
        ids, mask, token_type_ids, sample_ids = batch['ids'], batch['mask'], batch['token_type_ids'], batch['id']
        logits = self.predict(ids, mask, token_type_ids)
        return {'id': sample_ids.cpu().numpy().tolist(), 'logits': logits.cpu().numpy().tolist()}