Tokymin's picture
Epoch 10/10
fc547f0
raw
history blame
No virus
1.65 kB
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import torch
# Load the preprocessed data set once, at import time. The path is relative,
# so it is resolved against the current working directory of the process.
_CSV_PATH = "dataset/processed_new_data.csv"
df = pd.read_csv(_CSV_PATH)
def _one_hot(value, num_classes, column):
    """Return a length-``num_classes`` one-hot list with a 1 at ``int(value)``.

    Raises:
        IndexError: if ``int(value)`` is not in ``[0, num_classes)``. Checking
            explicitly stops negative values from silently wrapping around
            via Python's negative list indexing.
    """
    index = int(value)
    if not 0 <= index < num_classes:
        raise IndexError(
            f"{column}={value!r} is outside the valid class range [0, {num_classes})"
        )
    vector = [0] * num_classes
    vector[index] = 1
    return vector


# Prepare the dataset
def prepare_dataset(df, tokenizer, max_length=512, num_classes=4):
    """Tokenize descriptions and build a ``TensorDataset`` with combined SAS/SDS labels.

    Rows whose ``SAS_Class`` or ``SDS_Class`` is missing (NaN) are skipped.
    For every kept row, ``Description`` is encoded with
    ``tokenizer.encode_plus`` (special tokens added, padded and truncated to
    ``max_length``). The SAS and SDS class indices are each one-hot encoded
    and concatenated, SAS first, into a label vector of length
    ``2 * num_classes``.

    Args:
        df: DataFrame with ``Description``, ``SAS_Class`` and ``SDS_Class`` columns.
        tokenizer: Object exposing ``encode_plus`` that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors. These are presumably
            (1, max_length) tensors from a ``transformers`` tokenizer; TODO confirm.
        max_length: Length every description is padded/truncated to.
        num_classes: Number of classes for each of SAS and SDS. The default of
            4 matches the previously hard-coded value.

    Returns:
        ``TensorDataset(input_ids, attention_masks, labels)``, where ``labels``
        is a float tensor of shape (num_samples, 2 * num_classes). If no row
        survives filtering, all three tensors have zero rows. ``torch.cat``
        cannot be used in that case because it rejects an empty list.

    Raises:
        IndexError: if a class value falls outside ``[0, num_classes)``.
    """
    input_ids = []
    attention_masks = []
    labels = []
    for _, row in df.iterrows():
        # Skip samples with a missing label (NaN); they cannot be one-hot encoded.
        if pd.isna(row['SAS_Class']) or pd.isna(row['SDS_Class']):
            continue
        # Build the label first so a bad class value fails before tokenizing.
        combined_label = (
            _one_hot(row['SAS_Class'], num_classes, 'SAS_Class')
            + _one_hot(row['SDS_Class'], num_classes, 'SDS_Class')
        )
        encoded = tokenizer.encode_plus(
            row['Description'],
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
        labels.append(combined_label)

    if not labels:
        # torch.cat raises on an empty list. Return an empty dataset whose
        # widths match what a non-empty call would produce.
        empty_tokens = torch.empty((0, max_length), dtype=torch.long)
        return TensorDataset(
            empty_tokens,
            empty_tokens.clone(),
            torch.empty((0, 2 * num_classes), dtype=torch.float),
        )

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels, dtype=torch.float)
    return TensorDataset(input_ids, attention_masks, labels)