bingnoi commited on
Commit
e535922
1 Parent(s): aca2cb2

Upload 6 files

Browse files
Files changed (5) hide show
  1. __init__.py +6 -0
  2. get_dataset.py +68 -0
  3. logger.py +60 -0
  4. prompt_concat.py +170 -0
  5. utils.py +59 -0
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from .get_dataset import *
3
+ from .logger import *
4
+ from .prompt_concat import *
5
+ from .retrieve_dialog import *
6
+ from .utils import *
get_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ import sys
3
+ sys.path.append("../")
4
+
5
+ from collections import defaultdict
6
+ from .utils import is_float, load_txt
7
+
8
+ import random
9
+
10
+ random.seed(1234)
11
+
12
+
13
+ class CreateDataset:
14
+ def __init__(self, max_input_len=1500):
15
+ self.prompt = load_txt("../prompt/dataset_character.txt")
16
+ self.max_input_len = max_input_len # 小于(seq-length)-(max-gen-length)
17
+ self.example_split_flag = f"\n{'-' * 20}\n"
18
+
19
+ self.dataset = defaultdict(list)
20
+ self.manual_dataset = []
21
+
22
+ @staticmethod
23
+ def choose_examples(similar_examples,
24
+ max_length,
25
+ train_flag=False,
26
+ dialog=None,
27
+ example_split_flag=f"\n{'-' * 20}\n"):
28
+ if isinstance(similar_examples, str):
29
+ new_similar_examples = [x.strip() for x in similar_examples.split(example_split_flag)]
30
+ else:
31
+ # 去重
32
+ new_similar_examples = []
33
+ for example in similar_examples:
34
+ if (isinstance(example, list) or isinstance(example, tuple)) and len(example) == 2 and is_float(
35
+ example[0]):
36
+ # 包含score
37
+ example = example[1]
38
+
39
+ try:
40
+ example = "\n".join(example).strip()
41
+ except TypeError:
42
+ raise TypeError(f"example: {example}")
43
+ if train_flag and dialog and (example in dialog or dialog in example):
44
+ continue
45
+
46
+ # example去重
47
+ if train_flag:
48
+ # 部分相似也去掉
49
+ flag = False
50
+ for n_example in new_similar_examples:
51
+ if example in n_example or n_example in example:
52
+ flag = True
53
+ break
54
+ if not flag:
55
+ new_similar_examples.append(example)
56
+ else:
57
+ if example not in new_similar_examples:
58
+ new_similar_examples.append(example)
59
+
60
+ results = []
61
+ total_length = 0
62
+ for example in new_similar_examples:
63
+ total_length += len(example) if not total_length else len(example_split_flag) + len(example)
64
+ if total_length > max_length:
65
+ break
66
+ results.append(example)
67
+ results = example_split_flag.join(results).strip()
68
+ return results
logger.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from logging.handlers import TimedRotatingFileHandler
3
+
4
+ import os
5
+ import sys
6
+ import logging
7
+
8
+
9
+ class LoggerFactory:
10
+
11
+ @staticmethod
12
+ def create_logger(name=None, level=logging.INFO):
13
+ """create a logger
14
+
15
+ Args:
16
+ name (str): name of the logger
17
+ level: level of logger
18
+
19
+ Raises:
20
+ ValueError is name is None
21
+ """
22
+
23
+ if name is None:
24
+ raise ValueError("name for logger cannot be None")
25
+
26
+ formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] "
27
+ "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
28
+
29
+ logger_ = logging.getLogger(name)
30
+ logger_.setLevel(level)
31
+ logger_.propagate = False
32
+ ch = logging.StreamHandler(stream=sys.stdout)
33
+ ch.setLevel(level)
34
+ ch.setFormatter(formatter)
35
+ logger_.addHandler(ch)
36
+ return logger_
37
+
38
+ @staticmethod
39
+ def create_logger_with_file(log_file_path: str = None, logger_level=logging.INFO):
40
+ logger_inner = logging.getLogger()
41
+ logger_inner.setLevel(logger_level)
42
+ logger_inner.propagate = True
43
+
44
+ formatter = logging.Formatter(fmt="[%(asctime)s] [%(filename)s:%(lineno)s - %(levelname)s] %(message)s",
45
+ datefmt="%Y-%m-%d %H:%M:%S")
46
+
47
+ # TimedRotatingFileHandler
48
+ if log_file_path:
49
+ basedir = os.path.dirname(log_file_path)
50
+ if not os.path.isdir(basedir):
51
+ os.makedirs(basedir, exist_ok=True)
52
+ handler_file = TimedRotatingFileHandler(log_file_path, when="d", interval=1, backupCount=30)
53
+ handler_file.setFormatter(formatter)
54
+ logger_inner.addHandler(handler_file)
55
+
56
+ # StreamHandler
57
+ handler_console = logging.StreamHandler()
58
+ handler_console.setFormatter(formatter)
59
+ logger_inner.addHandler(handler_console)
60
+ return logger_inner
prompt_concat.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from copy import deepcopy
3
+ from .get_dataset import CreateDataset
4
+ from .logger import LoggerFactory
5
+ from .retrieve_dialog import RetrieveDialog
6
+ from .utils import load_json, load_txt, save_to_json
7
+
8
+ import logging
9
+ import os
10
+
11
+ logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
12
+
13
+
14
+ class GetManualTestSamples:
15
+ def __init__(
16
+ self,
17
+ role_name,
18
+ role_data_path,
19
+ save_samples_dir,
20
+ save_samples_path=None,
21
+ prompt_path="dataset_character.txt",
22
+ max_seq_len=4000,
23
+ retrieve_num=20,
24
+ ):
25
+ self.role_name = role_name.strip()
26
+ self.role_data = load_json(role_data_path)
27
+ self.role_info = self.role_data[0]["role_info"].strip()
28
+
29
+ self.prompt = load_txt(prompt_path)
30
+ self.prompt = self.prompt.replace("${role_name}", self.role_name)
31
+ self.prompt = self.prompt.replace("${role_info}",
32
+ f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()
33
+
34
+ self.retrieve_num = retrieve_num
35
+ self.retrieve = RetrieveDialog(role_name=self.role_name,
36
+ raw_dialog_list=[d["dialog"] for d in self.role_data],
37
+ retrieve_num=retrieve_num)
38
+
39
+ self.max_seq_len = max_seq_len
40
+ if not save_samples_path:
41
+ save_samples_path = f"{self.role_name}.json"
42
+ self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)
43
+
44
+ def _add_simi_dialog(self, history: list, content_length):
45
+ retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
46
+ simi_dialogs = deepcopy(retrieve_results)
47
+
48
+ if simi_dialogs:
49
+ simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
50
+ max_length=self.max_seq_len - content_length,
51
+ train_flag=False)
52
+ logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
53
+ return simi_dialogs, retrieve_results
54
+
55
+ def get_qa_samples_by_file(self,
56
+ questions_path,
57
+ user_name="user",
58
+ keep_retrieve_results_flag=False
59
+ ):
60
+ questions = load_txt(questions_path).splitlines()
61
+ samples = []
62
+ for question in questions:
63
+ question = question.replace('\\n', "\n")
64
+ query = f"{user_name}:{question}" if ":" not in question else question
65
+ content = self.prompt.replace("${dialog}", query)
66
+ content = content.replace("${user_name}", user_name).strip()
67
+
68
+ history = [query]
69
+ simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))
70
+
71
+ sample = {
72
+ "role_name": self.role_name,
73
+ "role_info": self.role_info,
74
+ "user_name": user_name,
75
+ "dialog": history,
76
+ "simi_dialogs": simi_dialogs,
77
+ }
78
+ if keep_retrieve_results_flag and retrieve_results:
79
+ sample["retrieve_results"] = retrieve_results
80
+ samples.append(sample)
81
+ self._save_samples(samples)
82
+
83
+ def get_qa_samples_by_query(self,
84
+ questions_query,
85
+ user_name="user",
86
+ keep_retrieve_results_flag=False
87
+ ):
88
+ question = questions_query
89
+ samples = []
90
+ question = question.replace('\\n', "\n")
91
+ query = f"{user_name}: {question}" if ":" not in question else question
92
+ content = self.prompt.replace("${dialog}", query)
93
+ content = content.replace("${user_name}", user_name).strip()
94
+
95
+ history = [query]
96
+ simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))
97
+
98
+ sample = {
99
+ "role_name": self.role_name,
100
+ "role_info": self.role_info,
101
+ "user_name": user_name,
102
+ "dialog": history,
103
+ "simi_dialogs": simi_dialogs,
104
+ }
105
+ if keep_retrieve_results_flag and retrieve_results:
106
+ sample["retrieve_results"] = retrieve_results
107
+ samples.append(sample)
108
+ self._save_samples(samples)
109
+
110
+ def _save_samples(self, samples):
111
+ data = samples
112
+ save_to_json(data, self.save_samples_path)
113
+
114
+
115
+ class CreateTestDataset:
116
+ def __init__(self,
117
+ role_name,
118
+ role_samples_path=None,
119
+ role_data_path=None,
120
+ prompt_path="dataset_character.txt",
121
+ max_seq_len=4000):
122
+ self.max_seq_len = max_seq_len
123
+ self.role_name = role_name
124
+
125
+ self.prompt = load_txt(prompt_path)
126
+ self.prompt = self.prompt.replace("${role_name}", role_name).strip()
127
+
128
+ if not role_data_path:
129
+ print("need role_data_path, check please!")
130
+ self.default_simi_dialogs = None
131
+ if os.path.exists(role_data_path):
132
+ data = load_json(role_data_path)
133
+ role_info = data[0]["role_info"]
134
+ else:
135
+ raise ValueError(f"{self.role_name} didn't find role_info.")
136
+ self.role_info = role_info
137
+ self.prompt = self.prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()
138
+
139
+ if role_samples_path:
140
+ self.role_samples_path = role_samples_path
141
+ else:
142
+ print("check role_samples_path please!")
143
+
144
+ def load_samples(self):
145
+ samples = load_json(self.role_samples_path)
146
+ results = []
147
+ for sample in samples:
148
+ input_text = self.prompt
149
+
150
+ simi_dialogs = sample.get("simi_dialogs", None)
151
+ if not simi_dialogs:
152
+ simi_dialogs = self.default_simi_dialogs
153
+ if not simi_dialogs:
154
+ raise ValueError(f"didn't find simi_dialogs.")
155
+ simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
156
+ max_length=self.max_seq_len - len(input_text),
157
+ train_flag=False)
158
+
159
+ input_text = input_text.replace("${simi_dialog}", simi_dialogs)
160
+ user_name = sample.get("user_name", "user")
161
+ input_text = input_text.replace("${user_name}", user_name)
162
+
163
+ dialog = "\n".join(sample["dialog"]) if isinstance(sample["dialog"], list) else sample["dialog"]
164
+ input_text = input_text.replace("${dialog}", dialog)
165
+
166
+ assert len(input_text) < self.max_seq_len
167
+ results.append({
168
+ "input_text": input_text,
169
+ })
170
+ return results
utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ import csv
3
+ import json
4
+ import os
5
+
6
+
7
+ def read_csv_to_json(file_path, role_name, role_info):
8
+ json_list = []
9
+
10
+ with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
11
+ csv_reader = csv.reader(csvfile)
12
+ _ = next(csv_reader)
13
+
14
+ for row in csv_reader:
15
+ json_object = {
16
+ "role_name": role_name,
17
+ "role_info": role_info,
18
+ "dialog": row[1].split("\n"),
19
+ }
20
+ json_list.append(json_object)
21
+
22
+ return json_list
23
+
24
+
25
+ def save_json(json_list, output_path):
26
+ with open(output_path, "w", encoding="utf-8") as jsonfile:
27
+ json.dump(json_list, jsonfile, ensure_ascii=False, indent=4)
28
+
29
+
30
+ def decode_csv_to_json(role_data_path, role_name, role_info, json_output_path):
31
+ json_data = read_csv_to_json(role_data_path, role_name, role_info)
32
+ save_json(json_data, json_output_path)
33
+
34
+
35
+ def load_txt(path):
36
+ with open(path, "r", encoding="utf-8", errors="ignore") as file:
37
+ text = file.read()
38
+ return text
39
+
40
+
41
+ def load_json(path):
42
+ with open(path, "r", encoding="utf-8") as f:
43
+ data = json.load(f)
44
+ return data
45
+
46
+
47
+ def save_to_json(data, filepath, flag="w"):
48
+ if not os.path.exists(os.path.dirname(filepath)):
49
+ os.makedirs(os.path.dirname(filepath))
50
+ with open(filepath, flag, encoding="utf-8") as f:
51
+ f.write(json.dumps(data, ensure_ascii=False, indent=3))
52
+
53
+
54
+ def is_float(my_str):
55
+ try:
56
+ num = float(my_str)
57
+ return True
58
+ except ValueError:
59
+ return False