Use proper prompt for Joi
- app.ipynb +156 -36
- app.py +24 -26
- prompt_templates/openassistant_joi.json +1 -0
app.ipynb
CHANGED
@@ -50,7 +50,7 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 11,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -75,7 +75,7 @@
 " if max_new_tokens_supported is True:\n",
 " payload[\"parameters\"][\"max_new_tokens\"] = 100\n",
 " payload[\"parameters\"][\"repetition_penalty\"]: 1.03\n",
-" payload[\"parameters\"][\"stop\"] = [\"
+" payload[\"parameters\"][\"stop\"] = [\"User:\"]\n",
 " else:\n",
 " payload[\"parameters\"][\"max_length\"] = 512\n",
 "\n",
@@ -95,7 +95,7 @@
 {
 "data": {
 "text/plain": [
-"{'generated_text': '\\n\\nJoi: Black holes are
+"{'generated_text': '\\n\\nJoi: Black holes are regions of spacetime where gravity is so strong that nothing, not even light, can escape from inside them. They are the result of huge amounts of mass concentrated in a small space, which causes intense gravitational force. The more massive the mass, the stronger the gravity, and the faster the force of gravity increases with increased mass. Black holes have no size or shape, as they are just a point in spacetime, the event horizon, from which light can no longer'}"
 ]
 },
 "execution_count": 5,
@@ -112,7 +112,46 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 37,
+"metadata": {},
+"outputs": [],
+"source": [
+"#|export\n",
+"def format_history(history, human=\"Human\", bot=\"Assistant\"):\n",
+" history_input = \"\"\n",
+" for idx, text in enumerate(history):\n",
+" if idx % 2 == 0:\n",
+" history_input += f\"{human}: {text}\\n\\n\"\n",
+" else:\n",
+" history_input += f\"{bot}: {text}\\n\\n\"\n",
+" history_input = history_input.rstrip(\"\\n\")\n",
+" return history_input"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 38,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Human: Hello\n",
+"\n",
+"Assistant: Hi\n",
+"\n",
+"Human: How are you?\n"
+]
+}
+],
+"source": [
+"print(format_history([\"Hello\", \"Hi\", \"How are you?\"]))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 44,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -125,20 +164,14 @@
 " history=[],\n",
 "):\n",
 " if \"joi\" in model_id:\n",
-" prompt_filename = \"
+" prompt_filename = \"openassistant_joi.json\"\n",
+" history_input = format_history(history, human=\"User\", bot=\"Joi\")\n",
 " else:\n",
 " prompt_filename = \"anthropic_hhh_single.json\"\n",
-"
+" history_input = format_history(history, human=\"Human\", bot=\"Assistant\")\n",
 " with open(f\"prompt_templates/{prompt_filename}\", \"r\") as f:\n",
 " prompt_template = json.load(f)\n",
 "\n",
-" history_input = \"\"\n",
-" for idx, text in enumerate(history):\n",
-" if idx % 2 == 0:\n",
-" history_input += f\"Human: {text}\\n\"\n",
-" else:\n",
-" history_input += f\"Assistant: {text}\\n\"\n",
-" history_input = history_input.rstrip(\"\\n\")\n",
 " inputs = prompt_template[\"prompt\"].format(human_input=text_input, history=history_input)\n",
 " history.append(text_input)\n",
 "\n",
@@ -146,9 +179,13 @@
 " print(f\"Inputs: {inputs}\")\n",
 "\n",
 " output = query_chat_api(model_id, inputs, temperature, top_p)\n",
+" print(output)\n",
 " if isinstance(output, list):\n",
 " output = output[0]\n",
-"
+" if \"joi\" in model_id:\n",
+" output = output[\"generated_text\"].rstrip(\"\\n\\nUser:\")\n",
+" else:\n",
+" output = output[\"generated_text\"].rstrip(\" Human:\")\n",
 " history.append(\" \" + output)\n",
 "\n",
 " chat = [\n",
@@ -179,7 +216,7 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -217,6 +254,39 @@
 " json.dump({\"prompt\": template}, f)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 17,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"17\n"
+]
+}
+],
+"source": [
+"template = \"\"\"{history}\n",
+"\n",
+"User: {human_input}\n",
+"\n",
+"Joi:\"\"\"\n",
+"\n",
+"print(len(tokenizer(template)[\"input_ids\"]))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 18,
+"metadata": {},
+"outputs": [],
+"source": [
+"with open(\"prompt_templates/openassistant_joi.json\", \"w\") as f:\n",
+" json.dump({\"prompt\": template}, f)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 28,
@@ -772,7 +842,7 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 45,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -794,7 +864,7 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 46,
 "metadata": {},
 "outputs": [
 {
@@ -803,7 +873,7 @@
 "'So far, the following prompts are available:\\n\\n* `langchain_default`: The default prompt used in the [LangChain library](https://github.com/hwchase17/langchain/blob/bc53c928fc1b221d0038b839d111039d31729def/langchain/chains/conversation/prompt.py#L4). Around 67 tokens long.\\n* `openai_chatgpt`: The prompt used in the OpenAI ChatGPT model. Around 261 tokens long.\\n* `deepmind_Assistant`: The prompt used in the DeepMind Assistant model (Table 7 of [their paper](https://arxiv.org/abs/2209.14375)). Around 880 tokens long.\\n* `deepmind_gopher`: The prompt used in the DeepMind Assistant model (Table A30 of [their paper](https://arxiv.org/abs/2112.11446)). Around 791 tokens long.\\n* `anthropic_hhh`: The prompt used in the [Anthropic HHH models](https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt). A whopping 6,341 tokens long!\\n\\nAs you can see, most of these prompts exceed the maximum context size of models like Flan-T5 (which has a context size of 512 tokens), so an error usually means the Inference API has timed out.'"
 ]
 },
-"execution_count":
+"execution_count": 46,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -822,14 +892,14 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 47,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Running on local URL: http://127.0.0.1:
+"Running on local URL: http://127.0.0.1:7866\n",
 "\n",
 "To create a public link, set `share=True` in `launch()`.\n"
 ]
@@ -837,7 +907,7 @@
 {
 "data": {
 "text/html": [
-"<div><iframe src=\"http://127.0.0.1:
+"<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
 ],
 "text/plain": [
 "<IPython.core.display.HTML object>"
@@ -850,9 +920,60 @@
 "data": {
 "text/plain": []
 },
-"execution_count":
+"execution_count": 47,
 "metadata": {},
 "output_type": "execute_result"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"History: ['What is 2 times 3?']\n",
+"Inputs: \n",
+"\n",
+"User: What is 2 times 3?\n",
+"\n",
+"Joi:\n",
+"{'generated_text': ' 3*2=6\\n\\nUser:'}\n",
+"History: ['What is 2 times 3?', ' 3*2=6', 'What about 4 times 3?']\n",
+"Inputs: User: What is 2 times 3?\n",
+"\n",
+"Joi: 3*2=6\n",
+"\n",
+"User: What about 4 times 3?\n",
+"\n",
+"Joi:\n",
+"{'generated_text': ' 3*4=12\\n\\nUser:'}\n",
+"History: ['What is 2 times 3?', ' 3*2=6', 'What about 4 times 3?', ' 3*4=12', 'What about -1 times -3?']\n",
+"Inputs: User: What is 2 times 3?\n",
+"\n",
+"Joi: 3*2=6\n",
+"\n",
+"User: What about 4 times 3?\n",
+"\n",
+"Joi: 3*4=12\n",
+"\n",
+"User: What about -1 times -3?\n",
+"\n",
+"Joi:\n",
+"{'generated_text': ' -3*(-1)=3\\n\\nUser:'}\n",
+"History: ['What can you tell me about llamas?']\n",
+"Inputs: \n",
+"\n",
+"User: What can you tell me about llamas?\n",
+"\n",
+"Joi:\n",
+"{'generated_text': ' Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.'}\n",
+"History: ['What can you tell me about llamas?', ' Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.', 'Who would win in a battle between a llama and an alpaca?']\n",
+"Inputs: User: What can you tell me about llamas?\n",
+"\n",
+"Joi: Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.\n",
+"\n",
+"User: Who would win in a battle between a llama and an alpaca?\n",
+"\n",
+"Joi:\n",
+"{'generated_text': \" That depends on the alpaca. If they are of the same gender, then it depends on the alpaca's age, size, and condition. Generally speaking, the alpaca would win.\"}\n"
+]
 }
 ],
 "source": [
@@ -876,18 +997,6 @@
 " label=\"Model\",\n",
 " interactive=True,\n",
 " )\n",
-" # prompt_template = gr.Dropdown(\n",
-" # choices=[\n",
-" # \"langchain_default\",\n",
-" # \"openai_chatgpt\",\n",
-" # \"deepmind_sparrow\",\n",
-" # \"deepmind_gopher\",\n",
-" # \"anthropic_hhh\",\n",
-" # ],\n",
-" # value=\"langchain_default\",\n",
-" # label=\"Prompt Template\",\n",
-" # interactive=True,\n",
-" # )\n",
 " temperature = gr.Slider(\n",
 " minimum=0.0,\n",
 " maximum=2.0,\n",
@@ -971,9 +1080,20 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": 48,
 "metadata": {},
-"outputs": [
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+"To disable this warning, you can either:\n",
+"\t- Avoid using `tokenizers` before the fork if possible\n",
+"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+]
+}
+],
 "source": [
 "from nbdev.export import nb_export\n",
 "nb_export('app.ipynb', lib_path='.', name='app')"
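For reference, here is the `format_history` helper added in the notebook above, rewritten as plain Python (it is exported verbatim into app.py below); the printed output matches the notebook cell output in the diff:

# format_history builds the alternating dialogue transcript used in the prompt.
def format_history(history, human="Human", bot="Assistant"):
    history_input = ""
    for idx, text in enumerate(history):
        # Even indices are user turns, odd indices are bot turns.
        if idx % 2 == 0:
            history_input += f"{human}: {text}\n\n"
        else:
            history_input += f"{bot}: {text}\n\n"
    # Drop the trailing blank line so the prompt ends cleanly.
    return history_input.rstrip("\n")

print(format_history(["Hello", "Hi", "How are you?"]))
# Human: Hello
#
# Assistant: Hi
#
# Human: How are you?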
app.py
CHANGED
@@ -1,7 +1,8 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.

 # %% auto 0
-__all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', '
+__all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', 'format_history',
+'inference_chat']

 # %% app.ipynb 0
 import gradio as gr
@@ -54,7 +55,7 @@ def query_chat_api(
 if max_new_tokens_supported is True:
 payload["parameters"]["max_new_tokens"] = 100
 payload["parameters"]["repetition_penalty"]: 1.03
-payload["parameters"]["stop"] = ["
+payload["parameters"]["stop"] = ["User:"]
 else:
 payload["parameters"]["max_length"] = 512

@@ -67,6 +68,17 @@ def query_chat_api(


 # %% app.ipynb 5
+def format_history(history, human="Human", bot="Assistant"):
+history_input = ""
+for idx, text in enumerate(history):
+if idx % 2 == 0:
+history_input += f"{human}: {text}\n\n"
+else:
+history_input += f"{bot}: {text}\n\n"
+history_input = history_input.rstrip("\n")
+return history_input
+
+# %% app.ipynb 7
 def inference_chat(
 model_id,
 text_input,
@@ -75,20 +87,14 @@ def inference_chat(
 history=[],
 ):
 if "joi" in model_id:
-prompt_filename = "
+prompt_filename = "openassistant_joi.json"
+history_input = format_history(history, human="User", bot="Joi")
 else:
 prompt_filename = "anthropic_hhh_single.json"
-
+history_input = format_history(history, human="Human", bot="Assistant")
 with open(f"prompt_templates/{prompt_filename}", "r") as f:
 prompt_template = json.load(f)

-history_input = ""
-for idx, text in enumerate(history):
-if idx % 2 == 0:
-history_input += f"Human: {text}\n"
-else:
-history_input += f"Assistant: {text}\n"
-history_input = history_input.rstrip("\n")
 inputs = prompt_template["prompt"].format(human_input=text_input, history=history_input)
 history.append(text_input)

@@ -96,9 +102,13 @@ def inference_chat(
 print(f"Inputs: {inputs}")

 output = query_chat_api(model_id, inputs, temperature, top_p)
+print(output)
 if isinstance(output, list):
 output = output[0]
-
+if "joi" in model_id:
+output = output["generated_text"].rstrip("\n\nUser:")
+else:
+output = output["generated_text"].rstrip(" Human:")
 history.append(" " + output)

 chat = [
@@ -108,7 +118,7 @@ def inference_chat(
 return {chatbot: chat, state: history}


-# %% app.ipynb
+# %% app.ipynb 25
 title = """<h1 align="center">Chatty Language Models</h1>"""
 description = """Pretrained language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:

@@ -123,7 +133,7 @@ Assistant: <utterance>
 In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they will have different personalities and strengths.
 """

-# %% app.ipynb
+# %% app.ipynb 27
 with gr.Blocks(
 css="""
 .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
@@ -143,18 +153,6 @@ with gr.Blocks(
 label="Model",
 interactive=True,
 )
-# prompt_template = gr.Dropdown(
-# choices=[
-# "langchain_default",
-# "openai_chatgpt",
-# "deepmind_sparrow",
-# "deepmind_gopher",
-# "anthropic_hhh",
-# ],
-# value="langchain_default",
-# label="Prompt Template",
-# interactive=True,
-# )
 temperature = gr.Slider(
 minimum=0.0,
 maximum=2.0,
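A side note on the response post-processing added above: `str.rstrip("\n\nUser:")` strips a trailing run of those characters, not the literal suffix, so it can remove more than intended. A small sketch (not part of this commit) of a suffix-aware alternative:

# Sketch only: trim a literal trailing role marker instead of a character set.
def strip_role_suffix(text: str, suffix: str = "\n\nUser:") -> str:
    # Remove the exact suffix if present, otherwise return the text unchanged.
    if text.endswith(suffix):
        return text[: -len(suffix)]
    return text

print(strip_role_suffix(" 3*2=6\n\nUser:"))            # ' 3*2=6'
print(" 3*2=6\n\nUser:".rstrip("\n\nUser:"))            # ' 3*2=6' here, but only because '6' is not in the strip set
print("Yes, dinner\n\nUser:".rstrip("\n\nUser:"))       # 'Yes, dinn' -- rstrip over-strips trailing 'er'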
prompt_templates/openassistant_joi.json
ADDED
@@ -0,0 +1 @@
+{"prompt": "{history}\n\nUser: {human_input}\n\nJoi:"}
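For reference, a usage sketch (not part of the commit) of how `inference_chat` fills in this new Joi template; it assumes the exported app.py above is importable from the repo root:

import json

from app import format_history  # exported in app.py's __all__ by this commit

# Load the template added in this commit and build a prompt for the next turn.
with open("prompt_templates/openassistant_joi.json") as f:
    prompt_template = json.load(f)

history = ["What is 2 times 3?", "3*2=6"]  # alternating user/bot turns
history_input = format_history(history, human="User", bot="Joi")
inputs = prompt_template["prompt"].format(human_input="What about 4 times 3?", history=history_input)
print(inputs)
# User: What is 2 times 3?
#
# Joi: 3*2=6
#
# User: What about 4 times 3?
#
# Joi: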