handle the correct prompting style
app.py CHANGED
@@ -21,12 +21,26 @@ dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
 # Get a reference to the table
 table = dynamodb.Table('oaaic_chatbot_arena')
 
+
+def prompt_instruct(system_msg, history):
+    return system_msg.strip() + "\n" + \
+        "\n".join(["\n".join(["### Instruction: "+item[0], "### Response: "+item[1]])
+                   for item in history])
+
+
+def prompt_chat(system_msg, history):
+    return system_msg.strip() + "\n" + \
+        "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
+                   for item in history])
+
+
 class Pipeline:
     prefer_async = True
 
-    def __init__(self, endpoint_id, name):
+    def __init__(self, endpoint_id, name, prompt_fn):
         self.endpoint_id = endpoint_id
         self.name = name
+        self.prompt_fn = prompt_fn
         self.generation_config = {
             "max_tokens": 1024,
             "top_k": 40,
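For context, a quick sketch of what the two builders render for a one-turn history (illustration only, not part of the commit; outputs derived from the functions above):

history = [["What is 2+2?", "4"]]
system = "- The Assistant is helpful and transparent."

prompt_chat(system, history)
# - The Assistant is helpful and transparent.
# USER: What is 2+2?
# ASSISTANT: 4

prompt_instruct(system, history)
# - The Assistant is helpful and transparent.
# ### Instruction: What is 2+2?
# ### Response: 4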
@@ -37,7 +51,7 @@ class Pipeline:
             "seed": -1,
             "batch_size": 8,
             "threads": -1,
-            "stop": ["</s>", "USER:"],
+            "stop": ["</s>", "USER:", "### Instruction:"],
         }
 
     def __call__(self, prompt):
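Note on the new stop string: instruct-tuned models tend to run on and draft the next turn themselves, so "### Instruction:" now halts generation the same way "USER:" does for chat-style models. A hypothetical helper (not in app.py) showing the kind of truncation the endpoint is assumed to perform server-side:

def truncate_at_stop(text, stops=("</s>", "USER:", "### Instruction:")):
    # cut at the earliest stop string, if any appears in the output
    cuts = [text.find(s) for s in stops if s in text]
    return text[:min(cuts)] if cuts else text

truncate_at_stop("4\n### Instruction: and what is 3+3?")  # -> "4\n"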
@@ -79,13 +93,16 @@ class Pipeline:
             # Sleep for 3 seconds between each request
             sleep(3)
 
+    def transform_prompt(self, system_msg, history):
+        return self.prompt_fn(system_msg, history)
+
 
 AVAILABLE_MODELS = {
-    "hermes-13b": "p0zqb2gkcwp0ww",
-    "manticore-13b-chat": "u6tv84bpomhfei",
-    "airoboros-13b": "rglzxnk80660ja",
-    "supercot-13b": "0be7865dwxpwqk",
-    "mpt-7b-instruct": "jpqbvnyluj18b0",
+    "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
+    "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
+    "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
+    "supercot-13b": ("0be7865dwxpwqk", prompt_instruct),
+    "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
 }
 
 _memoized_models = defaultdict()
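Each registry value is now an (endpoint_id, prompt_fn) pair, binding every model to the prompting style it was trained with. Reading an entry (illustration only):

endpoint_id, prompt_fn = AVAILABLE_MODELS["hermes-13b"]
endpoint_id         # "p0zqb2gkcwp0ww"
prompt_fn.__name__  # "prompt_instruct"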
@@ -93,7 +110,7 @@ _memoized_models = defaultdict()
 
 def get_model_pipeline(model_name):
     if not _memoized_models.get(model_name):
-        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name], model_name)
+        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1])
     return _memoized_models.get(model_name)
 
 start_message = """- The Assistant is helpful and transparent.
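The memoization means each model keeps a single long-lived Pipeline rather than getting a fresh one on every chat turn; the intended behavior, sketched:

p1 = get_model_pipeline("hermes-13b")
p2 = get_model_pipeline("hermes-13b")
assert p1 is p2  # same memoized Pipeline instance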
@@ -116,20 +133,17 @@ def chat(history1, history2, system_msg):
     history1 = history1 or []
     history2 = history2 or []
 
-    messages1 = system_msg.strip() + "\n" + \
-        "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
-                   for item in history1])
-    messages2 = system_msg.strip() + "\n" + \
-        "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
-                   for item in history2])
+    random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
+    model1 = get_model_pipeline(random_battle[0])
+    model2 = get_model_pipeline(random_battle[1])
+
+    messages1 = model1.transform_prompt(system_msg, history1)
+    messages2 = model2.transform_prompt(system_msg, history2)
 
     # remove last space from assistant, some models output a ZWSP if you leave a space
     messages1 = messages1.rstrip()
     messages2 = messages2.rstrip()
 
-    random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
-    model1 = get_model_pipeline(random_battle[0])
-    model2 = get_model_pipeline(random_battle[1])
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
         futures = []
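One caveat, separate from this commit's intent: random.sample() requires a sequence on Python 3.11+, so passing dict keys directly raises a TypeError there; wrapping them in list() avoids it:

random_battle = random.sample(list(AVAILABLE_MODELS), 2)  # hypothetical fix for Python >= 3.11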
@@ -212,14 +226,14 @@ with gr.Blocks() as arena:
             dismiss_reveal = gr.Button(value="Dismiss & Continue", variant="secondary", visible=False).style(full_width=True)
     with gr.Row():
         with gr.Column():
-            rlhf_persona = gr.Textbox(
-                "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
             message = gr.Textbox(
                 label="What do you want to ask?",
                 placeholder="Ask me anything.",
                 lines=3,
             )
         with gr.Column():
+            rlhf_persona = gr.Textbox(
+                "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
             system_msg = gr.Textbox(
                 start_message, label="System Message", interactive=True, visible=True, placeholder="system prompt", lines=8)
 