lewtun committed on
Commit caf57a3
1 Parent(s): d0f9fb0

Use proper prompt for Joi

Files changed (3)
  1. app.ipynb +156 -36
  2. app.py +24 -26
  3. prompt_templates/openassistant_joi.json +1 -0
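
In short, sketched below from the diff that follows (a minimal sketch, not a verbatim copy of the module): the "joi" branch of inference_chat now loads a dedicated openassistant_joi.json template with User:/Joi: turns, and the inline history-building loop is factored out into a shared format_history helper.

    import json

    # Helper introduced in this commit: render the flat chat history as
    # alternating "speaker: text" turns separated by blank lines.
    def format_history(history, human="Human", bot="Assistant"):
        history_input = ""
        for idx, text in enumerate(history):
            speaker = human if idx % 2 == 0 else bot
            history_input += f"{speaker}: {text}\n\n"
        return history_input.rstrip("\n")

    # New Joi prompt (see prompt_templates/openassistant_joi.json below).
    joi_template = "{history}\n\nUser: {human_input}\n\nJoi:"

    history = ["What is 2 times 3?", "3*2=6"]
    prompt = joi_template.format(
        history=format_history(history, human="User", bot="Joi"),
        human_input="What about 4 times 3?",
    )
    print(prompt)
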
app.ipynb CHANGED
@@ -50,7 +50,7 @@
50
  },
51
  {
52
  "cell_type": "code",
53
- "execution_count": 4,
54
  "metadata": {},
55
  "outputs": [],
56
  "source": [
@@ -75,7 +75,7 @@
75
  " if max_new_tokens_supported is True:\n",
76
  " payload[\"parameters\"][\"max_new_tokens\"] = 100\n",
77
  " payload[\"parameters\"][\"repetition_penalty\"]: 1.03\n",
78
- " payload[\"parameters\"][\"stop\"] = [\"Human:\"]\n",
79
  " else:\n",
80
  " payload[\"parameters\"][\"max_length\"] = 512\n",
81
  "\n",
@@ -95,7 +95,7 @@
95
  {
96
  "data": {
97
  "text/plain": [
98
- "{'generated_text': '\\n\\nJoi: Black holes are one of the most fascinating topics in astronomy. They’re objects in space that contain massive amounts of matter, and have such powerful gravity that they warp spacetime. It is thought that black holes might be the most compact objects in the universe. It is thought that black holes are the most powerful sources of gravity in the universe and that they occur in various forms, from stellar-sized black holes to the supermassive black holes at the hearts of galaxies. Black'}"
99
  ]
100
  },
101
  "execution_count": 5,
@@ -112,7 +112,46 @@
112
  },
113
  {
114
  "cell_type": "code",
115
- "execution_count": 6,
116
  "metadata": {},
117
  "outputs": [],
118
  "source": [
@@ -125,20 +164,14 @@
125
  " history=[],\n",
126
  "):\n",
127
  " if \"joi\" in model_id:\n",
128
- " prompt_filename = \"langchain_default.json\"\n",
 
129
  " else:\n",
130
  " prompt_filename = \"anthropic_hhh_single.json\"\n",
131
- " print(prompt_filename)\n",
132
  " with open(f\"prompt_templates/{prompt_filename}\", \"r\") as f:\n",
133
  " prompt_template = json.load(f)\n",
134
  "\n",
135
- " history_input = \"\"\n",
136
- " for idx, text in enumerate(history):\n",
137
- " if idx % 2 == 0:\n",
138
- " history_input += f\"Human: {text}\\n\"\n",
139
- " else:\n",
140
- " history_input += f\"Assistant: {text}\\n\"\n",
141
- " history_input = history_input.rstrip(\"\\n\")\n",
142
  " inputs = prompt_template[\"prompt\"].format(human_input=text_input, history=history_input)\n",
143
  " history.append(text_input)\n",
144
  "\n",
@@ -146,9 +179,13 @@
146
  " print(f\"Inputs: {inputs}\")\n",
147
  "\n",
148
  " output = query_chat_api(model_id, inputs, temperature, top_p)\n",
 
149
  " if isinstance(output, list):\n",
150
  " output = output[0]\n",
151
- " output = output[\"generated_text\"].rstrip(\" Human:\")\n",
152
  " history.append(\" \" + output)\n",
153
  "\n",
154
  " chat = [\n",
@@ -179,7 +216,7 @@
179
  },
180
  {
181
  "cell_type": "code",
182
- "execution_count": 20,
183
  "metadata": {},
184
  "outputs": [
185
  {
@@ -217,6 +254,39 @@
217
  " json.dump({\"prompt\": template}, f)"
218
  ]
219
  },
220
  {
221
  "cell_type": "code",
222
  "execution_count": 28,
@@ -772,7 +842,7 @@
772
  },
773
  {
774
  "cell_type": "code",
775
- "execution_count": 7,
776
  "metadata": {},
777
  "outputs": [],
778
  "source": [
@@ -794,7 +864,7 @@
794
  },
795
  {
796
  "cell_type": "code",
797
- "execution_count": 32,
798
  "metadata": {},
799
  "outputs": [
800
  {
@@ -803,7 +873,7 @@
803
  "'So far, the following prompts are available:\\n\\n* `langchain_default`: The default prompt used in the [LangChain library](https://github.com/hwchase17/langchain/blob/bc53c928fc1b221d0038b839d111039d31729def/langchain/chains/conversation/prompt.py#L4). Around 67 tokens long.\\n* `openai_chatgpt`: The prompt used in the OpenAI ChatGPT model. Around 261 tokens long.\\n* `deepmind_Assistant`: The prompt used in the DeepMind Assistant model (Table 7 of [their paper](https://arxiv.org/abs/2209.14375)). Around 880 tokens long.\\n* `deepmind_gopher`: The prompt used in the DeepMind Assistant model (Table A30 of [their paper](https://arxiv.org/abs/2112.11446)). Around 791 tokens long.\\n* `anthropic_hhh`: The prompt used in the [Anthropic HHH models](https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt). A whopping 6,341 tokens long!\\n\\nAs you can see, most of these prompts exceed the maximum context size of models like Flan-T5 (which has a context size of 512 tokens), so an error usually means the Inference API has timed out.'"
804
  ]
805
  },
806
- "execution_count": 32,
807
  "metadata": {},
808
  "output_type": "execute_result"
809
  }
@@ -822,14 +892,14 @@
822
  },
823
  {
824
  "cell_type": "code",
825
- "execution_count": 8,
826
  "metadata": {},
827
  "outputs": [
828
  {
829
  "name": "stdout",
830
  "output_type": "stream",
831
  "text": [
832
- "Running on local URL: http://127.0.0.1:7860\n",
833
  "\n",
834
  "To create a public link, set `share=True` in `launch()`.\n"
835
  ]
@@ -837,7 +907,7 @@
837
  {
838
  "data": {
839
  "text/html": [
840
- "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
841
  ],
842
  "text/plain": [
843
  "<IPython.core.display.HTML object>"
@@ -850,9 +920,60 @@
850
  "data": {
851
  "text/plain": []
852
  },
853
- "execution_count": 8,
854
  "metadata": {},
855
  "output_type": "execute_result"
856
  }
857
  ],
858
  "source": [
@@ -876,18 +997,6 @@
876
  " label=\"Model\",\n",
877
  " interactive=True,\n",
878
  " )\n",
879
- " # prompt_template = gr.Dropdown(\n",
880
- " # choices=[\n",
881
- " # \"langchain_default\",\n",
882
- " # \"openai_chatgpt\",\n",
883
- " # \"deepmind_sparrow\",\n",
884
- " # \"deepmind_gopher\",\n",
885
- " # \"anthropic_hhh\",\n",
886
- " # ],\n",
887
- " # value=\"langchain_default\",\n",
888
- " # label=\"Prompt Template\",\n",
889
- " # interactive=True,\n",
890
- " # )\n",
891
  " temperature = gr.Slider(\n",
892
  " minimum=0.0,\n",
893
  " maximum=2.0,\n",
@@ -971,9 +1080,20 @@
971
  },
972
  {
973
  "cell_type": "code",
974
- "execution_count": 9,
975
  "metadata": {},
976
- "outputs": [],
977
  "source": [
978
  "from nbdev.export import nb_export\n",
979
  "nb_export('app.ipynb', lib_path='.', name='app')"
 
50
  },
51
  {
52
  "cell_type": "code",
53
+ "execution_count": 11,
54
  "metadata": {},
55
  "outputs": [],
56
  "source": [
 
75
  " if max_new_tokens_supported is True:\n",
76
  " payload[\"parameters\"][\"max_new_tokens\"] = 100\n",
77
  " payload[\"parameters\"][\"repetition_penalty\"]: 1.03\n",
78
+ " payload[\"parameters\"][\"stop\"] = [\"User:\"]\n",
79
  " else:\n",
80
  " payload[\"parameters\"][\"max_length\"] = 512\n",
81
  "\n",
 
95
  {
96
  "data": {
97
  "text/plain": [
98
+ "{'generated_text': '\\n\\nJoi: Black holes are regions of spacetime where gravity is so strong that nothing, not even light, can escape from inside them. They are the result of huge amounts of mass concentrated in a small space, which causes intense gravitational force. The more massive the mass, the stronger the gravity, and the faster the force of gravity increases with increased mass. Black holes have no size or shape, as they are just a point in spacetime, the event horizon, from which light can no longer'}"
99
  ]
100
  },
101
  "execution_count": 5,
 
112
  },
113
  {
114
  "cell_type": "code",
115
+ "execution_count": 37,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "#|export\n",
120
+ "def format_history(history, human=\"Human\", bot=\"Assistant\"):\n",
121
+ " history_input = \"\"\n",
122
+ " for idx, text in enumerate(history):\n",
123
+ " if idx % 2 == 0:\n",
124
+ " history_input += f\"{human}: {text}\\n\\n\"\n",
125
+ " else:\n",
126
+ " history_input += f\"{bot}: {text}\\n\\n\"\n",
127
+ " history_input = history_input.rstrip(\"\\n\")\n",
128
+ " return history_input"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 38,
134
+ "metadata": {},
135
+ "outputs": [
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "Human: Hello\n",
141
+ "\n",
142
+ "Assistant: Hi\n",
143
+ "\n",
144
+ "Human: How are you?\n"
145
+ ]
146
+ }
147
+ ],
148
+ "source": [
149
+ "print(format_history([\"Hello\", \"Hi\", \"How are you?\"]))"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 44,
155
  "metadata": {},
156
  "outputs": [],
157
  "source": [
 
164
  " history=[],\n",
165
  "):\n",
166
  " if \"joi\" in model_id:\n",
167
+ " prompt_filename = \"openassistant_joi.json\"\n",
168
+ " history_input = format_history(history, human=\"User\", bot=\"Joi\")\n",
169
  " else:\n",
170
  " prompt_filename = \"anthropic_hhh_single.json\"\n",
171
+ " history_input = format_history(history, human=\"Human\", bot=\"Assistant\")\n",
172
  " with open(f\"prompt_templates/{prompt_filename}\", \"r\") as f:\n",
173
  " prompt_template = json.load(f)\n",
174
  "\n",
175
  " inputs = prompt_template[\"prompt\"].format(human_input=text_input, history=history_input)\n",
176
  " history.append(text_input)\n",
177
  "\n",
 
179
  " print(f\"Inputs: {inputs}\")\n",
180
  "\n",
181
  " output = query_chat_api(model_id, inputs, temperature, top_p)\n",
182
+ " print(output)\n",
183
  " if isinstance(output, list):\n",
184
  " output = output[0]\n",
185
+ " if \"joi\" in model_id:\n",
186
+ " output = output[\"generated_text\"].rstrip(\"\\n\\nUser:\")\n",
187
+ " else:\n",
188
+ " output = output[\"generated_text\"].rstrip(\" Human:\")\n",
189
  " history.append(\" \" + output)\n",
190
  "\n",
191
  " chat = [\n",
 
216
  },
217
  {
218
  "cell_type": "code",
219
+ "execution_count": 8,
220
  "metadata": {},
221
  "outputs": [
222
  {
 
254
  " json.dump({\"prompt\": template}, f)"
255
  ]
256
  },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 17,
260
+ "metadata": {},
261
+ "outputs": [
262
+ {
263
+ "name": "stdout",
264
+ "output_type": "stream",
265
+ "text": [
266
+ "17\n"
267
+ ]
268
+ }
269
+ ],
270
+ "source": [
271
+ "template = \"\"\"{history}\n",
272
+ "\n",
273
+ "User: {human_input}\n",
274
+ "\n",
275
+ "Joi:\"\"\"\n",
276
+ "\n",
277
+ "print(len(tokenizer(template)[\"input_ids\"]))"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": 18,
283
+ "metadata": {},
284
+ "outputs": [],
285
+ "source": [
286
+ "with open(\"prompt_templates/openassistant_joi.json\", \"w\") as f:\n",
287
+ " json.dump({\"prompt\": template}, f)"
288
+ ]
289
+ },
290
  {
291
  "cell_type": "code",
292
  "execution_count": 28,
 
842
  },
843
  {
844
  "cell_type": "code",
845
+ "execution_count": 45,
846
  "metadata": {},
847
  "outputs": [],
848
  "source": [
 
864
  },
865
  {
866
  "cell_type": "code",
867
+ "execution_count": 46,
868
  "metadata": {},
869
  "outputs": [
870
  {
 
873
  "'So far, the following prompts are available:\\n\\n* `langchain_default`: The default prompt used in the [LangChain library](https://github.com/hwchase17/langchain/blob/bc53c928fc1b221d0038b839d111039d31729def/langchain/chains/conversation/prompt.py#L4). Around 67 tokens long.\\n* `openai_chatgpt`: The prompt used in the OpenAI ChatGPT model. Around 261 tokens long.\\n* `deepmind_Assistant`: The prompt used in the DeepMind Assistant model (Table 7 of [their paper](https://arxiv.org/abs/2209.14375)). Around 880 tokens long.\\n* `deepmind_gopher`: The prompt used in the DeepMind Assistant model (Table A30 of [their paper](https://arxiv.org/abs/2112.11446)). Around 791 tokens long.\\n* `anthropic_hhh`: The prompt used in the [Anthropic HHH models](https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt). A whopping 6,341 tokens long!\\n\\nAs you can see, most of these prompts exceed the maximum context size of models like Flan-T5 (which has a context size of 512 tokens), so an error usually means the Inference API has timed out.'"
874
  ]
875
  },
876
+ "execution_count": 46,
877
  "metadata": {},
878
  "output_type": "execute_result"
879
  }
 
892
  },
893
  {
894
  "cell_type": "code",
895
+ "execution_count": 47,
896
  "metadata": {},
897
  "outputs": [
898
  {
899
  "name": "stdout",
900
  "output_type": "stream",
901
  "text": [
902
+ "Running on local URL: http://127.0.0.1:7866\n",
903
  "\n",
904
  "To create a public link, set `share=True` in `launch()`.\n"
905
  ]
 
907
  {
908
  "data": {
909
  "text/html": [
910
+ "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
911
  ],
912
  "text/plain": [
913
  "<IPython.core.display.HTML object>"
 
920
  "data": {
921
  "text/plain": []
922
  },
923
+ "execution_count": 47,
924
  "metadata": {},
925
  "output_type": "execute_result"
926
+ },
927
+ {
928
+ "name": "stdout",
929
+ "output_type": "stream",
930
+ "text": [
931
+ "History: ['What is 2 times 3?']\n",
932
+ "Inputs: \n",
933
+ "\n",
934
+ "User: What is 2 times 3?\n",
935
+ "\n",
936
+ "Joi:\n",
937
+ "{'generated_text': ' 3*2=6\\n\\nUser:'}\n",
938
+ "History: ['What is 2 times 3?', ' 3*2=6', 'What about 4 times 3?']\n",
939
+ "Inputs: User: What is 2 times 3?\n",
940
+ "\n",
941
+ "Joi: 3*2=6\n",
942
+ "\n",
943
+ "User: What about 4 times 3?\n",
944
+ "\n",
945
+ "Joi:\n",
946
+ "{'generated_text': ' 3*4=12\\n\\nUser:'}\n",
947
+ "History: ['What is 2 times 3?', ' 3*2=6', 'What about 4 times 3?', ' 3*4=12', 'What about -1 times -3?']\n",
948
+ "Inputs: User: What is 2 times 3?\n",
949
+ "\n",
950
+ "Joi: 3*2=6\n",
951
+ "\n",
952
+ "User: What about 4 times 3?\n",
953
+ "\n",
954
+ "Joi: 3*4=12\n",
955
+ "\n",
956
+ "User: What about -1 times -3?\n",
957
+ "\n",
958
+ "Joi:\n",
959
+ "{'generated_text': ' -3*(-1)=3\\n\\nUser:'}\n",
960
+ "History: ['What can you tell me about llamas?']\n",
961
+ "Inputs: \n",
962
+ "\n",
963
+ "User: What can you tell me about llamas?\n",
964
+ "\n",
965
+ "Joi:\n",
966
+ "{'generated_text': ' Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.'}\n",
967
+ "History: ['What can you tell me about llamas?', ' Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.', 'Who would win in a battle between a llama and an alpaca?']\n",
968
+ "Inputs: User: What can you tell me about llamas?\n",
969
+ "\n",
970
+ "Joi: Llamas are a large mammal native to South America. They are related to the camelids, which include the alpaca, vicuna, and guanaco. Llamas have a long, thick, curly coat of fur and long, sharp horns. They are very social and socialize with each other. They are also known for their amazing agility and speed. They are considered to be the fastest land animals in the world.\n",
971
+ "\n",
972
+ "User: Who would win in a battle between a llama and an alpaca?\n",
973
+ "\n",
974
+ "Joi:\n",
975
+ "{'generated_text': \" That depends on the alpaca. If they are of the same gender, then it depends on the alpaca's age, size, and condition. Generally speaking, the alpaca would win.\"}\n"
976
+ ]
977
  }
978
  ],
979
  "source": [
 
997
  " label=\"Model\",\n",
998
  " interactive=True,\n",
999
  " )\n",
1000
  " temperature = gr.Slider(\n",
1001
  " minimum=0.0,\n",
1002
  " maximum=2.0,\n",
 
1080
  },
1081
  {
1082
  "cell_type": "code",
1083
+ "execution_count": 48,
1084
  "metadata": {},
1085
+ "outputs": [
1086
+ {
1087
+ "name": "stdout",
1088
+ "output_type": "stream",
1089
+ "text": [
1090
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1091
+ "To disable this warning, you can either:\n",
1092
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1093
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
1094
+ ]
1095
+ }
1096
+ ],
1097
  "source": [
1098
  "from nbdev.export import nb_export\n",
1099
  "nb_export('app.ipynb', lib_path='.', name='app')"
app.py CHANGED
@@ -1,7 +1,8 @@
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
2
 
3
  # %% auto 0
4
- __all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', 'inference_chat']
 
5
 
6
  # %% app.ipynb 0
7
  import gradio as gr
@@ -54,7 +55,7 @@ def query_chat_api(
54
  if max_new_tokens_supported is True:
55
  payload["parameters"]["max_new_tokens"] = 100
56
  payload["parameters"]["repetition_penalty"]: 1.03
57
- payload["parameters"]["stop"] = ["Human:"]
58
  else:
59
  payload["parameters"]["max_length"] = 512
60
 
@@ -67,6 +68,17 @@ def query_chat_api(
67
 
68
 
69
  # %% app.ipynb 5
70
  def inference_chat(
71
  model_id,
72
  text_input,
@@ -75,20 +87,14 @@ def inference_chat(
75
  history=[],
76
  ):
77
  if "joi" in model_id:
78
- prompt_filename = "langchain_default.json"
 
79
  else:
80
  prompt_filename = "anthropic_hhh_single.json"
81
- print(prompt_filename)
82
  with open(f"prompt_templates/{prompt_filename}", "r") as f:
83
  prompt_template = json.load(f)
84
 
85
- history_input = ""
86
- for idx, text in enumerate(history):
87
- if idx % 2 == 0:
88
- history_input += f"Human: {text}\n"
89
- else:
90
- history_input += f"Assistant: {text}\n"
91
- history_input = history_input.rstrip("\n")
92
  inputs = prompt_template["prompt"].format(human_input=text_input, history=history_input)
93
  history.append(text_input)
94
 
@@ -96,9 +102,13 @@ def inference_chat(
96
  print(f"Inputs: {inputs}")
97
 
98
  output = query_chat_api(model_id, inputs, temperature, top_p)
 
99
  if isinstance(output, list):
100
  output = output[0]
101
- output = output["generated_text"].rstrip(" Human:")
102
  history.append(" " + output)
103
 
104
  chat = [
@@ -108,7 +118,7 @@ def inference_chat(
108
  return {chatbot: chat, state: history}
109
 
110
 
111
- # %% app.ipynb 21
112
  title = """<h1 align="center">Chatty Language Models</h1>"""
113
  description = """Pretrained language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
114
 
@@ -123,7 +133,7 @@ Assistant: <utterance>
123
  In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they will have different personalities and strengths.
124
  """
125
 
126
- # %% app.ipynb 23
127
  with gr.Blocks(
128
  css="""
129
  .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
@@ -143,18 +153,6 @@ with gr.Blocks(
143
  label="Model",
144
  interactive=True,
145
  )
146
- # prompt_template = gr.Dropdown(
147
- # choices=[
148
- # "langchain_default",
149
- # "openai_chatgpt",
150
- # "deepmind_sparrow",
151
- # "deepmind_gopher",
152
- # "anthropic_hhh",
153
- # ],
154
- # value="langchain_default",
155
- # label="Prompt Template",
156
- # interactive=True,
157
- # )
158
  temperature = gr.Slider(
159
  minimum=0.0,
160
  maximum=2.0,
 
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
2
 
3
  # %% auto 0
4
+ __all__ = ['HF_TOKEN', 'ENDPOINT_URL', 'title', 'description', 'get_model_endpoint_params', 'query_chat_api', 'format_history',
5
+ 'inference_chat']
6
 
7
  # %% app.ipynb 0
8
  import gradio as gr
 
55
  if max_new_tokens_supported is True:
56
  payload["parameters"]["max_new_tokens"] = 100
57
  payload["parameters"]["repetition_penalty"]: 1.03
58
+ payload["parameters"]["stop"] = ["User:"]
59
  else:
60
  payload["parameters"]["max_length"] = 512
61
 
 
68
 
69
 
70
  # %% app.ipynb 5
71
+ def format_history(history, human="Human", bot="Assistant"):
72
+ history_input = ""
73
+ for idx, text in enumerate(history):
74
+ if idx % 2 == 0:
75
+ history_input += f"{human}: {text}\n\n"
76
+ else:
77
+ history_input += f"{bot}: {text}\n\n"
78
+ history_input = history_input.rstrip("\n")
79
+ return history_input
80
+
81
+ # %% app.ipynb 7
82
  def inference_chat(
83
  model_id,
84
  text_input,
 
87
  history=[],
88
  ):
89
  if "joi" in model_id:
90
+ prompt_filename = "openassistant_joi.json"
91
+ history_input = format_history(history, human="User", bot="Joi")
92
  else:
93
  prompt_filename = "anthropic_hhh_single.json"
94
+ history_input = format_history(history, human="Human", bot="Assistant")
95
  with open(f"prompt_templates/{prompt_filename}", "r") as f:
96
  prompt_template = json.load(f)
97
 
98
  inputs = prompt_template["prompt"].format(human_input=text_input, history=history_input)
99
  history.append(text_input)
100
 
 
102
  print(f"Inputs: {inputs}")
103
 
104
  output = query_chat_api(model_id, inputs, temperature, top_p)
105
+ print(output)
106
  if isinstance(output, list):
107
  output = output[0]
108
+ if "joi" in model_id:
109
+ output = output["generated_text"].rstrip("\n\nUser:")
110
+ else:
111
+ output = output["generated_text"].rstrip(" Human:")
112
  history.append(" " + output)
113
 
114
  chat = [
 
118
  return {chatbot: chat, state: history}
119
 
120
 
121
+ # %% app.ipynb 25
122
  title = """<h1 align="center">Chatty Language Models</h1>"""
123
  description = """Pretrained language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
124
 
 
133
  In this app, you can explore the outputs of several language models conditioned on different conversational prompts. The models are trained on different datasets and have different objectives, so they will have different personalities and strengths.
134
  """
135
 
136
+ # %% app.ipynb 27
137
  with gr.Blocks(
138
  css="""
139
  .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
 
153
  label="Model",
154
  interactive=True,
155
  )
156
  temperature = gr.Slider(
157
  minimum=0.0,
158
  maximum=2.0,
prompt_templates/openassistant_joi.json ADDED
@@ -0,0 +1 @@
1
+ {"prompt": "{history}\n\nUser: {human_input}\n\nJoi:"}
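
For reference, a minimal sketch of how the new template is consumed by inference_chat (the file path and format call are taken from the diff above; the history string is an illustrative stand-in):

    import json

    # Load the Joi prompt added by this commit and fill in its two slots,
    # mirroring prompt_template["prompt"].format(...) in inference_chat.
    with open("prompt_templates/openassistant_joi.json") as f:
        prompt_template = json.load(f)

    history_input = "User: What can you tell me about llamas?\n\nJoi: They are a camelid native to South America."
    inputs = prompt_template["prompt"].format(
        history=history_input,
        human_input="Who would win in a battle between a llama and an alpaca?",
    )
    print(inputs)

Generation is then stopped at the next user turn by the stop=["User:"] parameter that query_chat_api now sets, which is why the logged generated_text values above end in "\n\nUser:".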