Update README.md
Browse files
README.md
CHANGED
@@ -87,8 +87,139 @@ Use the code below to get started with the model.
|
|
87 |
|
88 |
#### Preprocessing [optional]
|
89 |
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
#### Training Hyperparameters
|
94 |
|
|
|
87 |
|
88 |
#### Preprocessing [optional]
|
89 |
|
90 |
+
import json
|
91 |
+
import torch
|
92 |
+
|
93 |
+
class ResponseHandler:
|
94 |
+
def __init__(self):
|
95 |
+
self.function_list = {
|
96 |
+
'get_weather': self.get_weather
|
97 |
+
}
|
98 |
+
self.tool_list = {
|
99 |
+
'get_weather': self.get_weather
|
100 |
+
}
|
101 |
+
|
102 |
+
def handle_function_call(self, response):
|
103 |
+
function_calls = self.extract_function_calls(response)
|
104 |
+
results = []
|
105 |
+
for func_call in function_calls:
|
106 |
+
function_name = func_call.get('name')
|
107 |
+
arguments = func_call.get('arguments')
|
108 |
+
result = self.execute_function(function_name, arguments)
|
109 |
+
print(f"Function call result: {result}")
|
110 |
+
results.append(result)
|
111 |
+
return results
|
112 |
+
|
113 |
+
def extract_function_calls(self, response):
|
114 |
+
return response.get('tool_calls', [])
|
115 |
+
|
116 |
+
def handle_content_filter_error(self, response):
|
117 |
+
error_message = "Content filtered due to policy restrictions. Please modify your request."
|
118 |
+
print(error_message)
|
119 |
+
return error_message
|
120 |
+
|
121 |
+
def handle_length_error(self, response):
|
122 |
+
truncated_conversation = self.truncate_conversation(response)
|
123 |
+
print("Truncated conversation to fit within the context window.")
|
124 |
+
return truncated_conversation
|
125 |
+
|
126 |
+
def truncate_conversation(self, response):
|
127 |
+
# Implement actual truncation logic here
|
128 |
+
return response
|
129 |
+
|
130 |
+
def execute_function(self, function_name, arguments):
|
131 |
+
func = self.function_list.get(function_name)
|
132 |
+
if func:
|
133 |
+
return func(**arguments)
|
134 |
+
else:
|
135 |
+
error_message = f"Function '{function_name}' not found."
|
136 |
+
print(error_message)
|
137 |
+
return error_message
|
138 |
+
|
139 |
+
def execute_tool(self, tool_name, arguments):
|
140 |
+
tool = self.tool_list.get(tool_name)
|
141 |
+
if tool:
|
142 |
+
return tool(**arguments)
|
143 |
+
else:
|
144 |
+
error_message = f"Tool '{tool_name}' not found."
|
145 |
+
print(error_message)
|
146 |
+
return error_message
|
147 |
+
|
148 |
+
def extract_tool_calls(self, response):
|
149 |
+
return response.get('tool_calls', [])
|
150 |
+
|
151 |
+
def handle_tool_call(self, response):
|
152 |
+
tool_calls = self.extract_tool_calls(response)
|
153 |
+
results = []
|
154 |
+
for tool_call in tool_calls:
|
155 |
+
tool_name = tool_call.get('name')
|
156 |
+
arguments = tool_call.get('arguments')
|
157 |
+
result = self.execute_tool(tool_name, arguments)
|
158 |
+
print(f"Tool call result: {result}")
|
159 |
+
results.append(result)
|
160 |
+
return results
|
161 |
+
|
162 |
+
def get_weather(self, location, unit):
|
163 |
+
# Dummy implementation for testing
|
164 |
+
return f"Weather in {location} is 75 degrees {unit}."
|
165 |
+
|
166 |
+
def convert_to_xlam_tool(tools):
|
167 |
+
if isinstance(tools, dict):
|
168 |
+
return {
|
169 |
+
"name": tools["name"],
|
170 |
+
"description": tools["description"],
|
171 |
+
"parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
|
172 |
+
}
|
173 |
+
elif isinstance(tools, list):
|
174 |
+
return [convert_to_xlam_tool(tool) for tool in tools]
|
175 |
+
else:
|
176 |
+
return tools
|
177 |
+
|
178 |
+
def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str):
|
179 |
+
prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
|
180 |
+
prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
|
181 |
+
prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
|
182 |
+
prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
|
183 |
+
return prompt
|
184 |
+
|
185 |
+
# Example usage
|
186 |
+
openai_format_tools = [
|
187 |
+
{
|
188 |
+
"name": "get_weather",
|
189 |
+
"description": "Get the current weather for a location.",
|
190 |
+
"parameters": {
|
191 |
+
"properties": {
|
192 |
+
"location": {"type": "string"},
|
193 |
+
"unit": {"type": "string"}
|
194 |
+
}
|
195 |
+
}
|
196 |
+
}
|
197 |
+
]
|
198 |
+
|
199 |
+
task_instruction = "Provide the current weather."
|
200 |
+
format_instruction = "Return the weather information in a readable format."
|
201 |
+
query = "What is the weather in New York in fahrenheit?"
|
202 |
+
|
203 |
+
xlam_format_tools = convert_to_xlam_tool(openai_format_tools)
|
204 |
+
content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query)
|
205 |
+
|
206 |
+
messages = [
|
207 |
+
{'role': 'user', 'content': content}
|
208 |
+
]
|
209 |
+
|
210 |
+
# Assuming a temporary model name and tokenizer
|
211 |
+
model_name = "temp_model"
|
212 |
+
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', model_name)
|
213 |
+
model = torch.hub.load('huggingface/pytorch-transformers', 'modelWithLMHead', model_name)
|
214 |
+
|
215 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
216 |
+
|
217 |
+
# tokenizer.eos_token_id is the id of <|EOT|> token
|
218 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
219 |
+
print(tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True))
|
220 |
+
|
221 |
+
|
222 |
+
##end
|
223 |
|
224 |
#### Training Hyperparameters
|
225 |
|