zuxin-llm commited on
Commit
b41c8df
1 Parent(s): 1b6ae43

Upload multi_turn_xlam.ipynb

Browse files
Files changed (1) hide show
  1. example/multi_turn_xlam.ipynb +459 -0
example/multi_turn_xlam.ipynb ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "ce4a9ccf-4bd6-43fb-a24d-b6a7da401a96",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Load xLAM model"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "b1351d81-4502-4b65-b88a-464acd0e80f8",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import torch \n",
19
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
20
+ "torch.random.manual_seed(0) \n",
21
+ "\n",
22
+ "model_name = \"Salesforce/xLAM-7b-r\"\n",
23
+ "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=\"auto\", trust_remote_code=True)\n",
24
+ "tokenizer = AutoTokenizer.from_pretrained(model_name) "
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "2cdd5bae-da43-4713-9956-360f1f3a9721",
30
+ "metadata": {},
31
+ "source": [
32
+ "## Build the prompt"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 1,
38
+ "id": "e138e9f6-0543-427c-bce6-b4f14765a040",
39
+ "metadata": {
40
+ "tags": []
41
+ },
42
+ "outputs": [],
43
+ "source": [
44
+ "import json\n",
45
+ "\n",
46
+ "# Please use our provided instruction prompt for best performance\n",
47
+ "task_instruction = \"\"\"\n",
48
+ "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\"\"\".strip()\n",
49
+ "\n",
50
+ "format_instruction = \"\"\"\n",
51
+ "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
52
+ "tool_calls an empty list \"[]\".\n",
53
+ "```\n",
54
+ "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
55
+ "```\n",
56
+ "\"\"\".strip()\n",
57
+ "\n",
58
+ "get_weather_api = {\n",
59
+ " \"name\": \"get_weather\",\n",
60
+ " \"description\": \"Get the current weather for a location\",\n",
61
+ " \"parameters\": {\n",
62
+ " \"type\": \"object\",\n",
63
+ " \"properties\": {\n",
64
+ " \"location\": {\n",
65
+ " \"type\": \"string\",\n",
66
+ " \"description\": \"The city and state, e.g. San Francisco, New York\"\n",
67
+ " },\n",
68
+ " \"unit\": {\n",
69
+ " \"type\": \"string\",\n",
70
+ " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
71
+ " \"description\": \"The unit of temperature to return\"\n",
72
+ " }\n",
73
+ " },\n",
74
+ " \"required\": [\"location\"]\n",
75
+ " }\n",
76
+ "}\n",
77
+ "\n",
78
+ "search_api = {\n",
79
+ " \"name\": \"search\",\n",
80
+ " \"description\": \"Search for information on the internet\",\n",
81
+ " \"parameters\": {\n",
82
+ " \"type\": \"object\",\n",
83
+ " \"properties\": {\n",
84
+ " \"query\": {\n",
85
+ " \"type\": \"string\",\n",
86
+ " \"description\": \"The search query, e.g. 'latest news on AI'\"\n",
87
+ " }\n",
88
+ " },\n",
89
+ " \"required\": [\"query\"]\n",
90
+ " }\n",
91
+ "}\n",
92
+ "\n",
93
+ "openai_format_tools = [get_weather_api, search_api]\n",
94
+ "\n",
95
+ "# Define the input query and available tools\n",
96
+ "query = \"What's the weather like in New York in fahrenheit?\"\n",
97
+ "\n",
98
+ "# Helper function to convert openai format tools to our more concise xLAM format\n",
99
+ "def convert_to_xlam_tool(tools):\n",
100
+ " ''''''\n",
101
+ " if isinstance(tools, dict):\n",
102
+ " return {\n",
103
+ " \"name\": tools[\"name\"],\n",
104
+ " \"description\": tools[\"description\"],\n",
105
+ " \"parameters\": {k: v for k, v in tools[\"parameters\"].get(\"properties\", {}).items()}\n",
106
+ " }\n",
107
+ " elif isinstance(tools, list):\n",
108
+ " return [convert_to_xlam_tool(tool) for tool in tools]\n",
109
+ " else:\n",
110
+ " return tools\n",
111
+ "\n",
112
+ "def build_conversation_history_prompt(conversation_history: str):\n",
113
+ " parsed_history = []\n",
114
+ " for step_data in conversation_history:\n",
115
+ " parsed_history.append({\n",
116
+ " \"step_id\": step_data[\"step_id\"],\n",
117
+ " \"thought\": step_data[\"thought\"],\n",
118
+ " \"tool_calls\": step_data[\"tool_calls\"],\n",
119
+ " \"next_observation\": step_data[\"next_observation\"],\n",
120
+ " \"user_input\": step_data['user_input']\n",
121
+ " })\n",
122
+ " \n",
123
+ " history_string = json.dumps(parsed_history)\n",
124
+ " return f\"\\n[BEGIN OF HISTORY STEPS]\\n{history_string}\\n[END OF HISTORY STEPS]\\n\"\n",
125
+ " \n",
126
+ " \n",
127
+ "# Helper function to build the input prompt for our model\n",
128
+ "def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: list):\n",
129
+ " prompt = f\"[BEGIN OF TASK INSTRUCTION]\\n{task_instruction}\\n[END OF TASK INSTRUCTION]\\n\\n\"\n",
130
+ " prompt += f\"[BEGIN OF AVAILABLE TOOLS]\\n{json.dumps(xlam_format_tools)}\\n[END OF AVAILABLE TOOLS]\\n\\n\"\n",
131
+ " prompt += f\"[BEGIN OF FORMAT INSTRUCTION]\\n{format_instruction}\\n[END OF FORMAT INSTRUCTION]\\n\\n\"\n",
132
+ " prompt += f\"[BEGIN OF QUERY]\\n{query}\\n[END OF QUERY]\\n\\n\"\n",
133
+ " \n",
134
+ " if len(conversation_history) > 0: prompt += build_conversation_history_prompt(conversation_history)\n",
135
+ " return prompt\n",
136
+ "\n",
137
+ "\n",
138
+ " \n",
139
+ "# Build the input and start the inference\n",
140
+ "xlam_format_tools = convert_to_xlam_tool(openai_format_tools)\n",
141
+ "\n",
142
+ "conversation_history = []\n",
143
+ "content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)\n",
144
+ "\n",
145
+ "messages=[\n",
146
+ " { 'role': 'user', 'content': content}\n",
147
+ "]\n"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 2,
153
+ "id": "ff7bccd5-fa04-4fbe-92b3-13f58914da4d",
154
+ "metadata": {
155
+ "tags": []
156
+ },
157
+ "outputs": [
158
+ {
159
+ "name": "stdout",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "[BEGIN OF TASK INSTRUCTION]\n",
163
+ "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\n",
164
+ "[END OF TASK INSTRUCTION]\n",
165
+ "\n",
166
+ "[BEGIN OF AVAILABLE TOOLS]\n",
167
+ "[{\"name\": \"get_weather\", \"description\": \"Get the current weather for a location\", \"parameters\": {\"location\": {\"type\": \"string\", \"description\": \"The city and state, e.g. San Francisco, New York\"}, \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"], \"description\": \"The unit of temperature to return\"}}}, {\"name\": \"search\", \"description\": \"Search for information on the internet\", \"parameters\": {\"query\": {\"type\": \"string\", \"description\": \"The search query, e.g. 'latest news on AI'\"}}}]\n",
168
+ "[END OF AVAILABLE TOOLS]\n",
169
+ "\n",
170
+ "[BEGIN OF FORMAT INSTRUCTION]\n",
171
+ "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
172
+ "tool_calls an empty list \"[]\".\n",
173
+ "```\n",
174
+ "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
175
+ "```\n",
176
+ "[END OF FORMAT INSTRUCTION]\n",
177
+ "\n",
178
+ "[BEGIN OF QUERY]\n",
179
+ "What's the weather like in New York in fahrenheit?\n",
180
+ "[END OF QUERY]\n",
181
+ "\n",
182
+ "\n"
183
+ ]
184
+ }
185
+ ],
186
+ "source": [
187
+ "print(content)"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "markdown",
192
+ "id": "a5fb0006-9f5d-4d79-a8cd-819bad627441",
193
+ "metadata": {},
194
+ "source": [
195
+ "## Get the model output (agent_action)"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "id": "cbe56588-c786-4913-9062-373a22a92e08",
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
206
+ "\n",
207
+ "# tokenizer.eos_token_id is the id of <|EOT|> token\n",
208
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
209
+ "agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "markdown",
214
+ "id": "b20ed2ae-86f6-489b-ad54-fe7ea911667b",
215
+ "metadata": {},
216
+ "source": [
217
+ "For demo purpose, we use an example agent_action"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 3,
223
+ "id": "ab20c084-44fa-403d-92a5-1b8ced72e9be",
224
+ "metadata": {
225
+ "tags": []
226
+ },
227
+ "outputs": [],
228
+ "source": [
229
+ "agent_action = \"\"\"{\"thought\": \"\", \"tool_calls\": [{\"name\": \"get_weather\", \"arguments\": {\"location\": \"New York\"}}]}\n",
230
+ "\"\"\".strip()"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "markdown",
235
+ "id": "1cd4d8e4-ee6b-499e-b75f-a48df7848a60",
236
+ "metadata": {},
237
+ "source": [
238
+ "### Add follow-up question"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": 4,
244
+ "id": "825649ba-2691-43a2-b3d8-7baf8b66d46e",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "def parse_agent_action(agent_action: str):\n",
249
+ " \"\"\"\n",
250
+ " Given an agent's action, parse it to add to conversation history\n",
251
+ " \"\"\"\n",
252
+ " try: parsed_agent_action_json = json.loads(agent_action)\n",
253
+ " except: return \"\", []\n",
254
+ " \n",
255
+ " if \"thought\" not in parsed_agent_action_json.keys(): thought = \"\"\n",
256
+ " else: thought = parsed_agent_action_json[\"thought\"]\n",
257
+ " \n",
258
+ " if \"tool_calls\" not in parsed_agent_action_json.keys(): tool_calls = []\n",
259
+ " else: tool_calls = parsed_agent_action_json[\"tool_calls\"]\n",
260
+ " \n",
261
+ " return thought, tool_calls\n",
262
+ "\n",
263
+ "def update_conversation_history(conversation_history: list, agent_action: str, environment_response: str, user_input: str):\n",
264
+ " \"\"\"\n",
265
+ " Update the conversation history list based on the new agent_action, environment_response, and/or user_input\n",
266
+ " \"\"\"\n",
267
+ " thought, tool_calls = parse_agent_action(agent_action)\n",
268
+ " new_step_data = {\n",
269
+ " \"step_id\": len(conversation_history) + 1,\n",
270
+ " \"thought\": thought,\n",
271
+ " \"tool_calls\": tool_calls,\n",
272
+ " \"next_observation\": environment_response,\n",
273
+ " \"user_input\": user_input,\n",
274
+ " }\n",
275
+ " \n",
276
+ " conversation_history.append(new_step_data)\n",
277
+ "\n",
278
+ "def get_environment_response(agent_action: str):\n",
279
+ " \"\"\"\n",
280
+ " Get the environment response for the agent_action\n",
281
+ " \"\"\"\n",
282
+ " # TODO: add custom implementation here\n",
283
+ " error_message, response_message = \"\", \"Sunny, 81 degrees\"\n",
284
+ " return {\"error\": error_message, \"response\": response_message}\n",
285
+ "\n"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "id": "051e6aff-c21b-4dcb-9eb8-c34154d90c39",
291
+ "metadata": {},
292
+ "source": [
293
+ "1. **Get the next state after agent's response:**\n",
294
+ " The next 2 lines are examples of getting environment response and user_input.\n",
295
+ " It is depended on particular usage, we can have either one or both of those."
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 5,
301
+ "id": "649a8e9d-9757-408c-9214-0590556c2db4",
302
+ "metadata": {
303
+ "tags": []
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "environment_response = get_environment_response(agent_action)\n",
308
+ "user_input = \"Now, search on the Internet for cute puppies\""
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "id": "9c9c9418-1c54-4381-81d1-7f3834037739",
314
+ "metadata": {},
315
+ "source": [
316
+ "2. After we got environment_response and (or) user_input, we want to add to our conversation history"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 6,
322
+ "id": "bcfe89f3-8237-41bf-b92c-7c7568366042",
323
+ "metadata": {
324
+ "tags": []
325
+ },
326
+ "outputs": [
327
+ {
328
+ "data": {
329
+ "text/plain": [
330
+ "[{'step_id': 1,\n",
331
+ " 'thought': '',\n",
332
+ " 'tool_calls': [{'name': 'get_weather',\n",
333
+ " 'arguments': {'location': 'New York'}}],\n",
334
+ " 'next_observation': {'error': '', 'response': 'Sunny, 81 degrees'},\n",
335
+ " 'user_input': 'Now, search on the Internet for cute puppies'}]"
336
+ ]
337
+ },
338
+ "execution_count": 6,
339
+ "metadata": {},
340
+ "output_type": "execute_result"
341
+ }
342
+ ],
343
+ "source": [
344
+ "update_conversation_history(conversation_history, agent_action, environment_response, user_input)\n",
345
+ "conversation_history"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "markdown",
350
+ "id": "23ba97c6-2356-49e8-a07b-0e664b7f505c",
351
+ "metadata": {},
352
+ "source": [
353
+ "3. We now can build the prompt with the updated history, and prepare the inputs for the LLM"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": 7,
359
+ "id": "ed204b3a-3be5-431b-b355-facaf31309d2",
360
+ "metadata": {
361
+ "tags": []
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)\n",
366
+ "messages=[\n",
367
+ " { 'role': 'user', 'content': content}\n",
368
+ "]\n"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": 8,
374
+ "id": "8af843aa-6a47-4938-a455-567ea0cccce3",
375
+ "metadata": {
376
+ "tags": []
377
+ },
378
+ "outputs": [
379
+ {
380
+ "name": "stdout",
381
+ "output_type": "stream",
382
+ "text": [
383
+ "[BEGIN OF TASK INSTRUCTION]\n",
384
+ "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\n",
385
+ "[END OF TASK INSTRUCTION]\n",
386
+ "\n",
387
+ "[BEGIN OF AVAILABLE TOOLS]\n",
388
+ "[{\"name\": \"get_weather\", \"description\": \"Get the current weather for a location\", \"parameters\": {\"location\": {\"type\": \"string\", \"description\": \"The city and state, e.g. San Francisco, New York\"}, \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"], \"description\": \"The unit of temperature to return\"}}}, {\"name\": \"search\", \"description\": \"Search for information on the internet\", \"parameters\": {\"query\": {\"type\": \"string\", \"description\": \"The search query, e.g. 'latest news on AI'\"}}}]\n",
389
+ "[END OF AVAILABLE TOOLS]\n",
390
+ "\n",
391
+ "[BEGIN OF FORMAT INSTRUCTION]\n",
392
+ "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
393
+ "tool_calls an empty list \"[]\".\n",
394
+ "```\n",
395
+ "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
396
+ "```\n",
397
+ "[END OF FORMAT INSTRUCTION]\n",
398
+ "\n",
399
+ "[BEGIN OF QUERY]\n",
400
+ "What's the weather like in New York in fahrenheit?\n",
401
+ "[END OF QUERY]\n",
402
+ "\n",
403
+ "\n",
404
+ "[BEGIN OF HISTORY STEPS]\n",
405
+ "[{\"step_id\": 1, \"thought\": \"\", \"tool_calls\": [{\"name\": \"get_weather\", \"arguments\": {\"location\": \"New York\"}}], \"next_observation\": {\"error\": \"\", \"response\": \"Sunny, 81 degrees\"}, \"user_input\": \"Now, search on the Internet for cute puppies\"}]\n",
406
+ "[END OF HISTORY STEPS]\n",
407
+ "\n"
408
+ ]
409
+ }
410
+ ],
411
+ "source": [
412
+ "print(content)"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "markdown",
417
+ "id": "71f76a10-a152-49d7-aa6f-3060cc49b935",
418
+ "metadata": {},
419
+ "source": [
420
+ "## Get the model output for follow-up question"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "id": "30af06fd-4aa7-4550-af39-3a77b5951882",
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": [
430
+ "inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
431
+ "# 5. Generate the outputs & decode\n",
432
+ "# tokenizer.eos_token_id is the id of <|EOT|> token\n",
433
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
434
+ "agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n"
435
+ ]
436
+ }
437
+ ],
438
+ "metadata": {
439
+ "kernelspec": {
440
+ "display_name": "Python 3 (ipykernel) (Local)",
441
+ "language": "python",
442
+ "name": "python3"
443
+ },
444
+ "language_info": {
445
+ "codemirror_mode": {
446
+ "name": "ipython",
447
+ "version": 3
448
+ },
449
+ "file_extension": ".py",
450
+ "mimetype": "text/x-python",
451
+ "name": "python",
452
+ "nbconvert_exporter": "python",
453
+ "pygments_lexer": "ipython3",
454
+ "version": "3.10.13"
455
+ }
456
+ },
457
+ "nbformat": 4,
458
+ "nbformat_minor": 5
459
+ }