File size: 1,646 Bytes
1f4f3bd
 
 
 
 
 
 
 
 
 
 
 
 
fc547f0
 
 
 
1f4f3bd
 
 
 
 
 
 
 
 
 
 
fc547f0
 
 
 
 
 
 
 
 
1f4f3bd
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import torch

# Load the preprocessed dataset; the path is resolved relative to the
# current working directory.
df = pd.read_csv(filepath_or_buffer="dataset/processed_new_data.csv")
def _one_hot(value, num_classes):
    """Return a ``num_classes``-long 0/1 list with a single 1 at ``int(value)``.

    Raises:
        IndexError: if ``int(value)`` is not in ``[0, num_classes)``. Without
            this check a negative class would silently wrap around to the end
            of the list and mark the wrong position.
    """
    index = int(value)
    if not 0 <= index < num_classes:
        raise IndexError(
            f"class index {index} out of range for {num_classes} classes"
        )
    vector = [0] * num_classes
    vector[index] = 1
    return vector


def prepare_dataset(df, tokenizer, max_length=512, num_classes=4):
    """Tokenize descriptions and build a TensorDataset with one-hot SAS/SDS labels.

    Rows whose ``SAS_Class`` or ``SDS_Class`` is missing (NaN) are skipped.
    Each remaining row's ``Description`` is encoded with
    ``tokenizer.encode_plus`` (special tokens added, padded and truncated to
    ``max_length``, returned as PyTorch tensors).

    Args:
        df: DataFrame with ``Description``, ``SAS_Class`` and ``SDS_Class``
            columns.
        tokenizer: object exposing ``encode_plus`` that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors. Presumably a Hugging
            Face tokenizer; the caller is not visible here.
        max_length: sequence length to pad/truncate each description to.
        num_classes: number of classes for each of the two labels.

    Returns:
        TensorDataset of ``(input_ids, attention_masks, labels)``. ``labels``
        is a float tensor of shape ``(N, 2 * num_classes)``: the SAS one-hot
        vector followed by the SDS one-hot vector. If no row has both labels,
        an empty dataset is returned instead of failing inside ``torch.cat``.

    Raises:
        IndexError: if a class value is outside ``[0, num_classes)``.
    """
    input_ids = []
    attention_masks = []
    labels = []
    for _, row in df.iterrows():
        sas_class = row['SAS_Class']
        sds_class = row['SDS_Class']
        # Skip samples with a missing (NaN) label.
        if pd.isna(sas_class) or pd.isna(sds_class):
            continue

        # Build and validate the label before the (expensive) tokenization so
        # an invalid class fails fast.
        combined_label = _one_hot(sas_class, num_classes) + _one_hot(
            sds_class, num_classes
        )

        encoded = tokenizer.encode_plus(
            row['Description'],
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
        labels.append(combined_label)

    if not labels:
        # torch.cat raises on an empty list; return an empty dataset with
        # shapes consistent with the non-empty case instead.
        empty_ids = torch.empty((0, max_length), dtype=torch.long)
        return TensorDataset(
            empty_ids,
            empty_ids.clone(),
            torch.empty((0, 2 * num_classes), dtype=torch.float),
        )

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels, dtype=torch.float)
    return TensorDataset(input_ids, attention_masks, labels)