from transformers import AdamW, get_linear_schedule_with_warmup, AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import torch
from sklearn.model_selection import train_test_split
from dataset.load_dataset import df, prepare_dataset
from torch.nn import BCEWithLogitsLoss
from transformers import BertConfig
from tqdm.auto import tqdm
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
import datetime
import os

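# One TensorBoard run directory per training session, stamped with the start time.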
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = f'runs/train_{current_time}'
writer = SummaryWriter(log_dir)

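# Training hyperparameters; they are also encoded in the checkpoint file name below.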
epochs = 10
lr = 1e-5
optimizer_name = 'AdamW'
loss_fn_name = 'BCEWithLogitsLoss'
batch_size = 16

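# Build the checkpoint path from the run configuration and make sure the output directory exists.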
model_save_name = f'model_{current_time}_lr{lr}_opt{optimizer_name}_loss{loss_fn_name}_batch{batch_size}_epoch{epochs}.pt'
model_save_path = f'./saved_models/{model_save_name}'
os.makedirs('./saved_models', exist_ok=True)

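# The tokenizer and model come from a local fine-tuned Bio_ClinicalBERT checkpoint.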
tokenizer = AutoTokenizer.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

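# Hold out 10% of the data for validation.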
train_df, val_df = train_test_split(df, test_size=0.1)

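# prepare_dataset (from dataset.load_dataset) is assumed to return a dataset of
# (input_ids, attention_mask, labels) tuples; the unpacking in the loops below relies on that order.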
train_dataset = prepare_dataset(train_df, tokenizer)
val_dataset = prepare_dataset(val_df, tokenizer)

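# Random sampling for training batches; sequential order for validation.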
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
validation_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)

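# Multi-label setup: the classifier head is resized to 8 outputs, and ignore_mismatched_sizes
# lets from_pretrained reinitialize the head when its shape differs from the saved checkpoint.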
config = BertConfig.from_pretrained("pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition")
config.num_labels = 8

model = AutoModelForSequenceClassification.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition",
    config=config, ignore_mismatched_sizes=True).to(device)

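# Linear learning-rate decay (no warmup) over the total number of training steps.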
optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
loss_fn = BCEWithLogitsLoss()

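# GradScaler handles loss scaling for mixed-precision training; the forward pass in the
# training loop runs under autocast().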
scaler = GradScaler()

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    print('-------------------------------')
    model.train()
    total_loss = 0
    train_progress_bar = tqdm(train_dataloader, desc="Training", leave=False)
    for step, batch in enumerate(train_progress_bar):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        model.zero_grad()

        # Run the forward pass and loss computation in mixed precision.
        with autocast():
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            logits = outputs.logits
            # BCEWithLogitsLoss expects float targets.
            loss = loss_fn(logits, b_labels.float())

        # Skip the update entirely if the loss is NaN.
        if torch.isnan(loss).any():
            print(f"Loss is nan in epoch {epoch + 1}, step {step}.")
            continue

        total_loss += loss.item()

        # Backward on the scaled loss, unscale before clipping, then step the optimizer and scheduler.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        train_progress_bar.set_postfix({'loss': f"{loss.item():.2f}"})
        writer.add_scalar('Loss/train', loss.item(), epoch * len(train_dataloader) + step)

    avg_train_loss = total_loss / len(train_dataloader)
    print(f"Training loss: {avg_train_loss:.2f}")

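    # Validation: no gradient tracking; report loss and mean per-label accuracy at a 0.5 threshold.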
    model.eval()
    total_eval_accuracy = 0
    eval_progress_bar = tqdm(validation_dataloader, desc="Validation", leave=False)
    total_eval_loss = 0

    for batch in eval_progress_bar:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
            logits = outputs.logits

        loss = loss_fn(logits, b_labels.float())
        total_eval_loss += loss.item()

        # Sigmoid turns logits into per-label probabilities; threshold at 0.5 for predictions.
        probs = torch.sigmoid(logits)
        predictions = (probs > 0.5).int()

        correct_predictions = (predictions == b_labels.int()).float()
        # Mean over labels gives per-sample accuracy; mean over the batch gives batch accuracy.
        accuracy_per_sample = correct_predictions.mean(dim=1)
        accuracy = accuracy_per_sample.mean().item()
        total_eval_accuracy += accuracy

        eval_progress_bar.set_postfix({'accuracy': f"{accuracy:.2f}"})

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    print(f"Validation Loss: {avg_val_loss:.2f}")
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)
    print(f"Validation Accuracy: {avg_val_accuracy:.2f}")

writer.close()

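# Save only the model weights; reload them later with model.load_state_dict(torch.load(model_save_path)).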
torch.save(model.state_dict(), model_save_path)
print(f"Training finished, model saved to: {model_save_path}")