aim98 committed
Commit 9fd2d07
1 Parent(s): 54b6102

Update README.md

Files changed (1)
  1. README.md +133 -2
README.md CHANGED
@@ -87,8 +87,139 @@ Use the code below to get started with the model.
 
 #### Preprocessing [optional]
 
-[More Information Needed]
-
+import json
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class ResponseHandler:
+    def __init__(self):
+        # Registries mapping call names to local implementations; both point
+        # at the same dummy tool here.
+        self.function_list = {
+            'get_weather': self.get_weather
+        }
+        self.tool_list = {
+            'get_weather': self.get_weather
+        }
+
+    def handle_function_call(self, response):
+        function_calls = self.extract_function_calls(response)
+        results = []
+        for func_call in function_calls:
+            function_name = func_call.get('name')
+            arguments = func_call.get('arguments')
+            result = self.execute_function(function_name, arguments)
+            print(f"Function call result: {result}")
+            results.append(result)
+        return results
+
+    def extract_function_calls(self, response):
+        return response.get('tool_calls', [])
+
+    def handle_content_filter_error(self, response):
+        error_message = "Content filtered due to policy restrictions. Please modify your request."
+        print(error_message)
+        return error_message
+
+    def handle_length_error(self, response):
+        truncated_conversation = self.truncate_conversation(response)
+        print("Truncated conversation to fit within the context window.")
+        return truncated_conversation
+
+    def truncate_conversation(self, response):
+        # Implement actual truncation logic here
+        return response
+
+    def execute_function(self, function_name, arguments):
+        func = self.function_list.get(function_name)
+        if func:
+            return func(**arguments)
+        else:
+            error_message = f"Function '{function_name}' not found."
+            print(error_message)
+            return error_message
+
+    def execute_tool(self, tool_name, arguments):
+        tool = self.tool_list.get(tool_name)
+        if tool:
+            return tool(**arguments)
+        else:
+            error_message = f"Tool '{tool_name}' not found."
+            print(error_message)
+            return error_message
+
+    def extract_tool_calls(self, response):
+        return response.get('tool_calls', [])
+
+    def handle_tool_call(self, response):
+        tool_calls = self.extract_tool_calls(response)
+        results = []
+        for tool_call in tool_calls:
+            tool_name = tool_call.get('name')
+            arguments = tool_call.get('arguments')
+            result = self.execute_tool(tool_name, arguments)
+            print(f"Tool call result: {result}")
+            results.append(result)
+        return results
+
+    def get_weather(self, location, unit):
+        # Dummy implementation for testing
+        return f"Weather in {location} is 75 degrees {unit}."
+
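+# A minimal usage sketch: hand the handler a response dict that has already
+# been parsed from model output; the {'tool_calls': [{'name', 'arguments'}]}
+# shape is what extract_tool_calls above expects.
+handler = ResponseHandler()
+example_response = {'tool_calls': [{'name': 'get_weather', 'arguments': {'location': 'New York', 'unit': 'fahrenheit'}}]}
+handler.handle_tool_call(example_response)  # prints: Tool call result: Weather in New York is 75 degrees fahrenheit.
+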
+def convert_to_xlam_tool(tools):
+    if isinstance(tools, dict):
+        return {
+            "name": tools["name"],
+            "description": tools["description"],
+            "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
+        }
+    elif isinstance(tools, list):
+        return [convert_to_xlam_tool(tool) for tool in tools]
+    else:
+        return tools
+
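+# For instance, convert_to_xlam_tool({"name": "f", "description": "d",
+# "parameters": {"properties": {"x": {"type": "string"}}}}) yields
+# {"name": "f", "description": "d", "parameters": {"x": {"type": "string"}}},
+# i.e. "parameters" maps argument names directly to their schemas.
+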
+def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str):
+    prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
+    prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
+    prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
+    prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
+    return prompt
+
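+# build_prompt concatenates four bracket-delimited sections in a fixed order:
+# task instruction, available tools serialized as JSON, format instruction,
+# and finally the user query.
+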
+# Example usage
+openai_format_tools = [
+    {
+        "name": "get_weather",
+        "description": "Get the current weather for a location.",
+        "parameters": {
+            "properties": {
+                "location": {"type": "string"},
+                "unit": {"type": "string"}
+            }
+        }
+    }
+]
+
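+# Note: complete OpenAI tool specs nest these fields under a "function" key
+# (with "type": "function"); the flattened shape used here is what
+# convert_to_xlam_tool expects in this example.
+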
+task_instruction = "Provide the current weather."
+format_instruction = "Return the weather information in a readable format."
+query = "What is the weather in New York in fahrenheit?"
+
+xlam_format_tools = convert_to_xlam_tool(openai_format_tools)
+content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query)
+
+messages = [
+    {'role': 'user', 'content': content}
+]
+
+
+# Assuming a temporary model name and tokenizer; "temp_model" is a placeholder,
+# substitute the actual checkpoint id.
+model_name = "temp_model"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
+
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+# tokenizer.eos_token_id is the id of the <|EOT|> token
+outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
+decoded = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
+print(decoded)
+
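+# If the model emits a JSON object with a "tool_calls" list (an assumption
+# about this model's output format), the handler above can execute it.
+try:
+    parsed = json.loads(decoded)
+    handler.handle_tool_call(parsed)
+except (json.JSONDecodeError, TypeError, AttributeError):
+    print("Model output did not match the assumed JSON tool-call format.")
+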
 
 #### Training Hyperparameters
 