Hansimov commited on
Commit
30421b7
1 Parent(s): 4ba2ca6

:gem: [Feature] MessageComposer: Support openchat-3.5, and generalize preprocessing stages

Browse files
Files changed (1) hide show
  1. messagers/message_composer.py +111 -36
messagers/message_composer.py CHANGED
@@ -3,12 +3,18 @@ from pprint import pprint
3
 
4
 
5
  class MessageComposer:
6
- """
7
- models:
8
- - mixtral-8x7b (mistralai/Mixtral-8x7B-Instruct-v0.1)
9
- """
 
 
10
 
11
  def __init__(self, model: str = None):
 
 
 
 
12
  self.inst_roles = ["user", "system", "inst"]
13
  self.answer_roles = ["assistant", "bot", "answer"]
14
 
@@ -40,37 +46,62 @@ class MessageComposer:
40
  return concat_messages
41
 
42
  def merge(self, messages) -> str:
43
- # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
 
 
 
44
 
45
  self.messages = self.concat_messages_by_role(messages)
46
  self.merged_str = ""
47
- self.cached_str = ""
48
- for message in self.messages:
49
- role = message["role"]
50
- content = message["content"]
51
- if role in self.inst_roles:
52
- self.cached_str = f"[INST] {content} [/INST]"
53
- elif role in self.answer_roles:
54
- self.merged_str += f"<s> {self.cached_str} {content} </s>\n"
55
- self.cached_str = ""
56
- else:
57
- self.cached_str = f"[INST] {content} [/INST]"
58
- if self.cached_str:
59
- self.merged_str += f"{self.cached_str}"
60
 
61
- return self.merged_str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- def split(self, merged_str) -> list:
64
- self.messages = []
65
- self.merged_str = merged_str
66
- pair_pattern = (
67
- r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
68
- )
69
- pair_matches = re.finditer(pair_pattern, self.merged_str, re.MULTILINE)
70
- pair_matches_list = list(pair_matches)
71
 
 
 
72
  if len(pair_matches_list) <= 0:
73
- self.messages = [
74
  {
75
  "role": "user",
76
  "content": self.merged_str,
@@ -80,17 +111,15 @@ class MessageComposer:
80
  for match in pair_matches_list:
81
  inst = match.group("inst")
82
  answer = match.group("answer")
83
- self.messages.extend(
84
  [
85
  {"role": "user", "content": inst.strip()},
86
  {"role": "assistant", "content": answer.strip()},
87
  ]
88
  )
 
89
 
90
- inst_pattern = r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]"
91
- inst_matches = re.finditer(inst_pattern, self.merged_str, re.MULTILINE)
92
- inst_matches_list = list(inst_matches)
93
-
94
  if len(inst_matches_list) > len(pair_matches_list):
95
  self.messages.extend(
96
  [
@@ -101,11 +130,56 @@ class MessageComposer:
101
  ]
102
  )
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  return self.messages
105
 
106
 
107
  if __name__ == "__main__":
108
- composer = MessageComposer()
109
  messages = [
110
  {
111
  "role": "system",
@@ -113,8 +187,8 @@ if __name__ == "__main__":
113
  },
114
  {"role": "user", "content": "Hello, who are you?"},
115
  {"role": "assistant", "content": "I am a bot."},
116
- # {"role": "user", "content": "What is your name?"},
117
- {"role": "assistant", "content": "My name is Bing."},
118
  # {"role": "user", "content": "Tell me a joke."},
119
  # {"role": "assistant", "content": "What is a robot's favorite type of music?"},
120
  # {
@@ -122,6 +196,7 @@ if __name__ == "__main__":
122
  # "content": "How many questions have I asked? Please list them.",
123
  # },
124
  ]
 
125
  merged_str = composer.merge(messages)
126
  print(merged_str)
127
  pprint(composer.split(merged_str))
 
3
 
4
 
5
  class MessageComposer:
6
+ # LINK - apis/chat_api.py#available-models
7
+ AVALAIBLE_MODELS = [
8
+ "mixtral-8x7b",
9
+ "mistral-7b",
10
+ "openchat-3.5",
11
+ ]
12
 
13
  def __init__(self, model: str = None):
14
+ if model in self.AVALAIBLE_MODELS:
15
+ self.model = model
16
+ else:
17
+ self.model = "mixtral-8x7b"
18
  self.inst_roles = ["user", "system", "inst"]
19
  self.answer_roles = ["assistant", "bot", "answer"]
20
 
 
46
  return concat_messages
47
 
48
  def merge(self, messages) -> str:
49
+ # Mistral and Mixtral:
50
+ # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
51
+ # OpenChat:
52
+ # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
53
 
54
  self.messages = self.concat_messages_by_role(messages)
55
  self.merged_str = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ if self.model in ["mixtral-8x7b", "mistral-7b"]:
58
+ self.cached_str = ""
59
+ for message in self.messages:
60
+ role = message["role"]
61
+ content = message["content"]
62
+ if role in self.inst_roles:
63
+ self.cached_str = f"[INST] {content} [/INST]"
64
+ elif role in self.answer_roles:
65
+ self.merged_str += f"<s> {self.cached_str} {content} </s>\n"
66
+ self.cached_str = ""
67
+ else:
68
+ self.cached_str = f"[INST] {content} [/INST]"
69
+ if self.cached_str:
70
+ self.merged_str += f"{self.cached_str}"
71
+ elif self.model in ["openchat-3.5"]:
72
+ self.merged_str_list = []
73
+ self.end_of_turn = "<|end_of_turn|>"
74
+ for message in self.messages:
75
+ role = message["role"]
76
+ content = message["content"]
77
+ if role in self.inst_roles:
78
+ self.merged_str_list.append(
79
+ f"GPT4 Correct User:\n{content}{self.end_of_turn}"
80
+ )
81
+ elif role in self.answer_roles:
82
+ self.merged_str_list.append(
83
+ f"GPT4 Correct Assistant:\n{content}{self.end_of_turn}"
84
+ )
85
+ else:
86
+ self.merged_str_list.append(
87
+ f"GPT4 Correct User: {content}{self.end_of_turn}"
88
+ )
89
+ self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
90
+ self.merged_str = "\n".join(self.merged_str_list)
91
+ else:
92
+ self.merged_str = "\n".join(
93
+ [
94
+ f'`{message["role"]}`:\n{message["content"]}\n'
95
+ for message in self.messages
96
+ ]
97
+ )
98
 
99
+ return self.merged_str
 
 
 
 
 
 
 
100
 
101
+ def convert_pair_matches_to_messages(self, pair_matches_list):
102
+ messages = []
103
  if len(pair_matches_list) <= 0:
104
+ messages = [
105
  {
106
  "role": "user",
107
  "content": self.merged_str,
 
111
  for match in pair_matches_list:
112
  inst = match.group("inst")
113
  answer = match.group("answer")
114
+ messages.extend(
115
  [
116
  {"role": "user", "content": inst.strip()},
117
  {"role": "assistant", "content": answer.strip()},
118
  ]
119
  )
120
+ return messages
121
 
122
+ def append_last_instruction_to_messages(self, inst_matches_list, pair_matches_list):
 
 
 
123
  if len(inst_matches_list) > len(pair_matches_list):
124
  self.messages.extend(
125
  [
 
130
  ]
131
  )
132
 
133
+ def split(self, merged_str) -> list:
134
+ self.merged_str = merged_str
135
+ self.messages = []
136
+
137
+ if self.model in ["mixtral-8x7b", "mistral-7b"]:
138
+ pair_pattern = (
139
+ r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
140
+ )
141
+ pair_matches = re.finditer(pair_pattern, self.merged_str, re.MULTILINE)
142
+ pair_matches_list = list(pair_matches)
143
+
144
+ self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
145
+
146
+ inst_pattern = r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]"
147
+ inst_matches = re.finditer(inst_pattern, self.merged_str, re.MULTILINE)
148
+ inst_matches_list = list(inst_matches)
149
+
150
+ self.append_last_instruction_to_messages(
151
+ inst_matches_list, pair_matches_list
152
+ )
153
+
154
+ elif self.model in ["openchat-3.5"]:
155
+ pair_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>\s*GPT4 Correct Assistant:(?P<answer>[\s\S]*?)<\|end_of_turn\|>"
156
+ # ignore case
157
+ pair_matches = re.finditer(
158
+ pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
159
+ )
160
+ pair_matches_list = list(pair_matches)
161
+ self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
162
+ inst_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>"
163
+ inst_matches = re.finditer(
164
+ inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
165
+ )
166
+ inst_matches_list = list(inst_matches)
167
+ self.append_last_instruction_to_messages(
168
+ inst_matches_list, pair_matches_list
169
+ )
170
+ else:
171
+ self.messages = [
172
+ {
173
+ "role": "user",
174
+ "content": self.merged_str,
175
+ }
176
+ ]
177
+
178
  return self.messages
179
 
180
 
181
  if __name__ == "__main__":
182
+ composer = MessageComposer(model="openchat-3.5")
183
  messages = [
184
  {
185
  "role": "system",
 
187
  },
188
  {"role": "user", "content": "Hello, who are you?"},
189
  {"role": "assistant", "content": "I am a bot."},
190
+ {"role": "user", "content": "What is your name?"},
191
+ # {"role": "assistant", "content": "My name is Bing."},
192
  # {"role": "user", "content": "Tell me a joke."},
193
  # {"role": "assistant", "content": "What is a robot's favorite type of music?"},
194
  # {
 
196
  # "content": "How many questions have I asked? Please list them.",
197
  # },
198
  ]
199
+ print("model:", composer.model)
200
  merged_str = composer.merge(messages)
201
  print(merged_str)
202
  pprint(composer.split(merged_str))