winglian commited on
Commit
403f43d
·
1 Parent(s): 4a2d9ba

handle the correct prompting style

Browse files
Files changed (1) hide show
  1. app.py +33 -19
app.py CHANGED
@@ -21,12 +21,26 @@ dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
21
  # Get a reference to the table
22
  table = dynamodb.Table('oaaic_chatbot_arena')
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class Pipeline:
25
  prefer_async = True
26
 
27
- def __init__(self, endpoint_id, name):
28
  self.endpoint_id = endpoint_id
29
  self.name = name
 
30
  self.generation_config = {
31
  "max_tokens": 1024,
32
  "top_k": 40,
@@ -37,7 +51,7 @@ class Pipeline:
37
  "seed": -1,
38
  "batch_size": 8,
39
  "threads": -1,
40
- "stop": ["</s>", "USER:"],
41
  }
42
 
43
  def __call__(self, prompt):
@@ -79,13 +93,16 @@ class Pipeline:
79
  # Sleep for 3 seconds between each request
80
  sleep(3)
81
 
 
 
 
82
 
83
  AVAILABLE_MODELS = {
84
- "hermes-13b": "p0zqb2gkcwp0ww",
85
- "manticore-13b-chat": "u6tv84bpomhfei",
86
- "airoboros-13b": "rglzxnk80660ja",
87
- "supercot-13b": "0be7865dwxpwqk",
88
- "mpt-7b-instruct": "jpqbvnyluj18b0",
89
  }
90
 
91
  _memoized_models = defaultdict()
@@ -93,7 +110,7 @@ _memoized_models = defaultdict()
93
 
94
  def get_model_pipeline(model_name):
95
  if not _memoized_models.get(model_name):
96
- _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name], model_name)
97
  return _memoized_models.get(model_name)
98
 
99
  start_message = """- The Assistant is helpful and transparent.
@@ -116,20 +133,17 @@ def chat(history1, history2, system_msg):
116
  history1 = history1 or []
117
  history2 = history2 or []
118
 
119
- messages1 = system_msg.strip() + "\n" + \
120
- "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
121
- for item in history1])
122
- messages2 = system_msg.strip() + "\n" + \
123
- "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
124
- for item in history2])
125
 
126
  # remove last space from assistant, some models output a ZWSP if you leave a space
127
  messages1 = messages1.rstrip()
128
  messages2 = messages2.rstrip()
129
 
130
- random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
131
- model1 = get_model_pipeline(random_battle[0])
132
- model2 = get_model_pipeline(random_battle[1])
133
 
134
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
135
  futures = []
@@ -212,14 +226,14 @@ with gr.Blocks() as arena:
212
  dismiss_reveal = gr.Button(value="Dismiss & Continue", variant="secondary", visible=False).style(full_width=True)
213
  with gr.Row():
214
  with gr.Column():
215
- rlhf_persona = gr.Textbox(
216
- "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
217
  message = gr.Textbox(
218
  label="What do you want to ask?",
219
  placeholder="Ask me anything.",
220
  lines=3,
221
  )
222
  with gr.Column():
 
 
223
  system_msg = gr.Textbox(
224
  start_message, label="System Message", interactive=True, visible=True, placeholder="system prompt", lines=8)
225
 
 
21
  # Get a reference to the table
22
  table = dynamodb.Table('oaaic_chatbot_arena')
23
 
24
+
25
+ def prompt_instruct(system_msg, history):
26
+ return system_msg.strip() + "\n" + \
27
+ "\n".join(["\n".join(["### Instruction: "+item[0], "### Response: "+item[1]])
28
+ for item in history])
29
+
30
+
31
+ def prompt_chat(system_msg, history):
32
+ return system_msg.strip() + "\n" + \
33
+ "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
34
+ for item in history])
35
+
36
+
37
  class Pipeline:
38
  prefer_async = True
39
 
40
+ def __init__(self, endpoint_id, name, prompt_fn):
41
  self.endpoint_id = endpoint_id
42
  self.name = name
43
+ self.prompt_fn = prompt_fn
44
  self.generation_config = {
45
  "max_tokens": 1024,
46
  "top_k": 40,
 
51
  "seed": -1,
52
  "batch_size": 8,
53
  "threads": -1,
54
+ "stop": ["</s>", "USER:", "### Instruction:"],
55
  }
56
 
57
  def __call__(self, prompt):
 
93
  # Sleep for 3 seconds between each request
94
  sleep(3)
95
 
96
+ def transform_prompt(self, system_msg, history):
97
+ return self.prompt_fn(system_msg, history)
98
+
99
 
100
  AVAILABLE_MODELS = {
101
+ "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
102
+ "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
103
+ "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
104
+ "supercot-13b": ("0be7865dwxpwqk", prompt_instruct),
105
+ "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
106
  }
107
 
108
  _memoized_models = defaultdict()
 
110
 
111
  def get_model_pipeline(model_name):
112
  if not _memoized_models.get(model_name):
113
+ _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1])
114
  return _memoized_models.get(model_name)
115
 
116
  start_message = """- The Assistant is helpful and transparent.
 
133
  history1 = history1 or []
134
  history2 = history2 or []
135
 
136
+ random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
137
+ model1 = get_model_pipeline(random_battle[0])
138
+ model2 = get_model_pipeline(random_battle[1])
139
+
140
+ messages1 = model1.transform_prompt(system_msg, history1)
141
+ messages2 = model2.transform_prompt(system_msg, history2)
142
 
143
  # remove last space from assistant, some models output a ZWSP if you leave a space
144
  messages1 = messages1.rstrip()
145
  messages2 = messages2.rstrip()
146
 
 
 
 
147
 
148
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
149
  futures = []
 
226
  dismiss_reveal = gr.Button(value="Dismiss & Continue", variant="secondary", visible=False).style(full_width=True)
227
  with gr.Row():
228
  with gr.Column():
 
 
229
  message = gr.Textbox(
230
  label="What do you want to ask?",
231
  placeholder="Ask me anything.",
232
  lines=3,
233
  )
234
  with gr.Column():
235
+ rlhf_persona = gr.Textbox(
236
+ "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
237
  system_msg = gr.Textbox(
238
  start_message, label="System Message", interactive=True, visible=True, placeholder="system prompt", lines=8)
239