"""Fine-tune Bio_ClinicalBERT for medical-condition sequence classification.

Loads a locally saved pretrained checkpoint, splits the project dataset
90/10 into train/validation, and runs a standard fine-tuning loop with a
linear warmup-free learning-rate schedule.
"""

from transformers import (
    get_linear_schedule_with_warmup,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss
import torch
from sklearn.model_selection import train_test_split

from dataset.load_dataset import df, prepare_dataset

EPOCHS = 10
BATCH_SIZE = 64

# Tokenizer that converts raw text into the model's expected input format.
tokenizer = AutoTokenizer.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split the dataset: 90% training, 10% validation.
# random_state pins the split so runs are reproducible.
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

# Tokenize/prepare the training and validation datasets.
train_dataset = prepare_dataset(train_df, tokenizer)
val_dataset = prepare_dataset(val_df, tokenizer)

# Dataloaders: shuffle the training set each epoch; keep validation order fixed.
train_dataloader = DataLoader(train_dataset,
                              sampler=RandomSampler(train_dataset),
                              batch_size=BATCH_SIZE)
validation_dataloader = DataLoader(val_dataset,
                                   sampler=SequentialSampler(val_dataset),
                                   batch_size=BATCH_SIZE)

model = AutoModelForSequenceClassification.from_pretrained(
    "pretrained_models/Bio_ClinicalBERT-finetuned-medicalcondition").to(device)

# Optimizer and learning-rate scheduler.
# NOTE: transformers.AdamW is deprecated/removed; torch.optim.AdamW with
# weight_decay=0.0 reproduces the old transformers default (decoupled
# weight decay of 0).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8,
                              weight_decay=0.0)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

# Fine-tune the model.
model.train()
for epoch in range(EPOCHS):
    for step, batch in enumerate(train_dataloader):
        # Move the batch tensors onto the training device.
        # Assumes each batch is (input_ids, attention_mask, labels)
        # as produced by prepare_dataset — TODO confirm against that helper.
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        # Clear gradients accumulated from the previous step.
        model.zero_grad()

        # Forward pass; the model computes the loss itself when labels
        # are supplied.
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        loss = outputs.loss

        # Backward pass, parameter update, and LR-schedule step.
        loss.backward()
        optimizer.step()
        scheduler.step()

# The evaluation phase is omitted here, but it is essential in real
# applications (run the model on validation_dataloader in eval mode).