import argparse
import logging
import pickle

import numpy as np
from tqdm import tqdm
from transformers import BertTokenizerFast

def create_logger(log_path):
    """
    Write log output to both a log file and the console.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Log INFO and above to the file at log_path.
    file_handler = logging.FileHandler(filename=log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Mirror log records to the console as well.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
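
# Illustrative usage (a sketch; not executed by this script):
#   logger = create_logger("data/preprocess.log")
#   logger.info("hello")  # written to data/preprocess.log and echoed to the console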


def preprocess():
    """
    Tokenize the raw corpus, turning each dialogue into the form:
    "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, help='path to the vocabulary file')
    parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False, help='where to store the preprocessing log')
    parser.add_argument('--train_path', default='data/train.txt', type=str, required=False, help='path to the raw training corpus')
    parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='where to store the tokenized training set')
    args = parser.parse_args()

    logger = create_logger(args.log_path)

    # Build the tokenizer from the supplied vocab and look up the special tokens.
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
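    # Note: the concrete id values depend entirely on the vocab file passed
    # via --vocab_path; with the stock BERT vocab, [CLS] is typically 101 and
    # [SEP] 102.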

    logger.info("preprocessing data, data path:{}, save path:{}".format(args.train_path, args.save_path))

    # Read the raw corpus as a single UTF-8 string.
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Dialogues are separated by a blank line; handle both CRLF and LF corpora.
    if "\r\n" in data:
        train_data = data.split("\r\n\r\n")
        eol = "\r\n"
    else:
        train_data = data.split("\n\n")
        eol = "\n"
    logger.info("there are {} dialogues in the dataset".format(len(train_data)))

    # Tokenize each dialogue into one flat id sequence:
    # [CLS] utterance1 [SEP] utterance2 [SEP] ... [SEP]
    dialogue_len = []
    dialogue_list = []
    for dialogue in tqdm(train_data):
        utterances = dialogue.split(eol)

        input_ids = [cls_id]
        for utterance in utterances:
            input_ids += tokenizer.encode(utterance, add_special_tokens=False)
            input_ids.append(sep_id)
        dialogue_len.append(len(input_ids))
        dialogue_list.append(input_ids)

    # Length statistics over the tokenized dialogues.
    len_mean = np.mean(dialogue_len)
    len_median = np.median(dialogue_len)
    len_max = np.max(dialogue_len)

    # Serialize the tokenized dialogues to save_path.
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finished preprocessing data, the result is stored in {}".format(args.save_path))
    logger.info("mean of dialogue len:{}, median of dialogue len:{}, max len:{}".format(len_mean, len_median, len_max))


if __name__ == '__main__':
    preprocess()
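
# Reading the output back (a hedged sketch; the file name assumes the default
# --save_path above):
#   import pickle
#   with open("data/train.pkl", "rb") as f:
#       dialogue_list = pickle.load(f)
#   # dialogue_list[i] is one dialogue as a flat list of token ids:
#   #   [cls_id, id1, ..., sep_id, id2, ..., sep_id, ...]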